Parcourir la source

[spirv] Add vk::image_format attribute for Buffers, RWBuffers and RWTextures (#3395)

According to Vulkan specification when using `OpImageRead/OpImageWrite`, the `OpTypeImage` (`Buffers`, `RWBuffers`, `RWTextures`) must have a format that matches the format on the API side, unless the StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat is added and `Unknown` is used as the format.

This pull request addressess #2498 for the format part by adding an attribute `[[vk::image_format("<image format as spelled in SPIR-V spec>")]].` Example of the syntax:

```
[[vk::image_format("rgba8")]]
RWBuffer<float4> Buf;

[[vk::image_format("rg16f")]]
RWTexture2D<float2> Tex;

RWTexture2D<float2> Tex2; // Works like before
```

The `image_format` only applies to **global variables** of type `Buffer`, `RWBuffer`, `RWTexture`. For variables and function parameters it is propagated by the inlining pass in legalization. This required a small change to one of the passes in SPIRV-Tools, that should be also checked by someone more familiar with the codebase: https://github.com/KhronosGroup/SPIRV-Tools/pull/4126

Note that this does not fix the handling of unspecified format (that case still works like before, using `R32f`, etc. based on the type in shader), although it should be still fixed to add the
StorageImageReadWithoutFormat and/or StorageImageWriteWithoutFormat and use Undefined. But I think the ability to specify the format is more urgent.

Design note from Jaebaek:
Since the `image_format` attribute only applies to **global variables**, under the DXC architecture
only `DeclResultIdMapper` can check the attribute when it handles `VarDecl`s. It means
we have to pass the `image_format` information to `LowerTypeVisitor` because it cannot access to
`VarDecl`. In order to pass the `image_format`, we use `SpirvContext` that can be accessed by
`SpirvEmitter` and all visitors. We use `SpirvVariable` to `spv::ImageFormat` mapping because the
attribute only applies to **global variables** (not to image types).
See how we use `llvm::DenseMap<const SpirvVariable *, spv::ImageFormat> spvVarToImageFormat`.
seppala2 il y a 4 ans
Parent
commit
03cc4f8663

+ 67 - 0
docs/SPIR-V.rst

@@ -786,6 +786,73 @@ are translated into SPIR-V ``OpTypeImage``, with parameters:
 The meanings of the headers in the above table is explained in ``OpTypeImage``
 of the SPIR-V spec.
 
+Vulkan specific Image Formats
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Since HLSL lacks the syntax for fully specifying image formats for textures in
+SPIR-V, we introduce ``[[vk::image_format("FORMAT")]]`` attribute for texture types.
+For example,
+
+.. code:: hlsl
+  [[vk::image_format("rgba8")]]
+  RWBuffer<float4> Buf;
+
+  [[vk::image_format("rg16f")]]
+  RWTexture2D<float2> Tex;
+
+  RWTexture2D<float2> Tex2; // Works like before
+
+``rgba8`` means ``Rgba8`` `SPIR-V Image Format <https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_image_format_a_image_format>`_.
+The following table lists the mapping between ``FORMAT`` of
+``[[vk::image_format("FORMAT")]]`` and its corresponding SPIR-V Image Format.
+
+======================= ============================================
+       FORMAT                   SPIR-V Image Format
+======================= ============================================
+``unknown``             ``Unknown``
+``rgba32f``             ``Rgba32f``
+``rgba16f``             ``Rgba16f``
+``r32f``                ``R32f``
+``rgba8``               ``Rgba8``
+``rgba8snorm``          ``Rgba8Snorm``
+``rg32f``               ``Rg32f``
+``rg16f``               ``Rg16f``
+``r11g11b10f``          ``R11fG11fB10f``
+``r16f``                ``R16f``
+``rgba16``              ``Rgba16``
+``rgb10a2``             ``Rgb10A2``
+``rg16``                ``Rg16``
+``rg8``                 ``Rg8``
+``r16``                 ``R16``
+``r8``                  ``R8``
+``rgba16snorm``         ``Rgba16Snorm``
+``rg16snorm``           ``Rg16Snorm``
+``rg8snorm``            ``Rg8Snorm``
+``r16snorm``            ``R16Snorm``
+``r8snorm``             ``R8Snorm``
+``rgba32i``             ``Rgba32i``
+``rgba16i``             ``Rgba16i``
+``rgba8i``              ``Rgba8i``
+``r32i``                ``R32i``
+``rg32i``               ``Rg32i``
+``rg16i``               ``Rg16i``
+``rg8i``                ``Rg8i``
+``r16i``                ``R16i``
+``r8i``                 ``R8i``
+``rgba32ui``            ``Rgba32ui``
+``rgba16ui``            ``Rgba16ui``
+``rgba8ui``             ``Rgba8ui``
+``r32ui``               ``R32ui``
+``rgb10a2ui``           ``Rgb10a2ui``
+``rg32ui``              ``Rg32ui``
+``rg16ui``              ``Rg16ui``
+``rg8ui``               ``Rg8ui``
+``r16ui``               ``R16ui``
+``r8ui``                ``R8ui``
+``r64ui``               ``R64ui``
+``r64i``                ``R64i``
+======================= ============================================
+
 Constant/Texture/Structured/Byte Buffers
 ----------------------------------------
 

+ 1 - 1
external/SPIRV-Headers

@@ -1 +1 @@
-Subproject commit 75b30a659c8a4979104986652c54cc421fc51129
+Subproject commit a3fdfe81465d57efc97cfd28ac6c8190fb31a6c8

+ 1 - 1
external/SPIRV-Tools

@@ -1 +1 @@
-Subproject commit cfa1dadb1e62312655531de3cb97cecb0b21a737
+Subproject commit ef3290bbea35935ba8fd623970511ed9f045bbd7

+ 49 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -937,6 +937,32 @@ def ConstantTextureBuffer
                   S->getType()->getAs<RecordType>()->getDecl()->getName() ==
                       "TextureBuffer")}]>;
 
+// Global variable with "RWTexture" type
+def RWTexture
+    : SubsetSubject<
+          Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
+                 S->getType()->getAs<RecordType>()->getDecl() &&
+                  (S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWTexture1D" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWTexture1DArray" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWTexture2D" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWTexture2DArray" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWTexture3D")}]>;
+
+// Global variable with "[RW]Buffer" type
+def Buffer
+    : SubsetSubject<
+          Var, [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&
+                 S->getType()->getAs<RecordType>()->getDecl() &&
+                 (S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "Buffer" ||
+                  S->getType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWBuffer")}]>;
+
 def VKBuiltIn : InheritableAttr {
   let Spellings = [CXX11<"vk", "builtin">];
   let Subjects = SubjectList<[Function, ParmVar, Field], ErrorDiag>;
@@ -997,6 +1023,29 @@ def VKOffset : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+def VKImageFormat : InheritableAttr {
+  let Spellings = [CXX11<"vk", "image_format">];
+  let Subjects = SubjectList<[RWTexture, Buffer],
+                             ErrorDiag, "ExpectedRWTextureOrBuffer">;
+  let Args = [EnumArgument<"ImageFormat", "ImageFormatType",
+                           ["unknown", "rgba32f", "rgba16f", "r32f", "rgba8", "rgba8snorm",
+                           "rg32f", "rg16f", "r11g11b10f", "r16f", "rgba16", "rgb10a2",
+                           "rg16", "rg8", "r16", "r8", "rgba16snorm", "rg16snorm", "rg8snorm",
+                           "r16snorm", "r8snorm", "rgba32i", "rgba16i", "rgba8i", "r32i",
+                           "rg32i", "rg16i", "rg8i", "r16i", "r8i", "rgba32ui", "rgba16ui", "rgba8ui",
+                           "r32ui", "rgb10a2ui", "rg32ui", "rg16ui", "rg8ui", "r16ui",
+                           "r8ui", "r64ui", "r64i"],
+                           ["unknown", "rgba32f", "rgba16f", "r32f", "rgba8", "rgba8snorm",
+                           "rg32f", "rg16f", "r11g11b10f", "r16f", "rgba16", "rgb10a2",
+                           "rg16", "rg8", "r16", "r8", "rgba16snorm", "rg16snorm", "rg8snorm",
+                           "r16snorm", "r8snorm", "rgba32i", "rgba16i", "rgba8i", "r32i",
+                           "rg32i", "rg16i", "rg8i", "r16i", "r8i", "rgba32ui", "rgba16ui", "rgba8ui",
+                           "r32ui", "rgb10a2ui", "rg32ui", "rg16ui", "rg8ui", "r16ui",
+                           "r8ui", "r64ui", "r64i"]>];
+  let LangOpts = [SPIRV];
+  let Documentation = [Undocumented];
+}
+
 def SubpassInput : SubsetSubject<
     Var,
     [{S->hasGlobalStorage() && S->getType()->getAs<RecordType>() &&

+ 1 - 0
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -2332,6 +2332,7 @@ def warn_attribute_wrong_decl_type : Warning<
   "global variables of scalar type|"
   "global variables of struct type|"
   "global variables, cbuffers, and tbuffers|"
+  "RWTextures, Buffers and RWBuffers|"
   "RWStructuredBuffers, AppendStructuredBuffers, and ConsumeStructuredBuffers|"
   "SubpassInput, SubpassInputMS|"
   "cbuffer or ConstantBuffer|"

+ 21 - 0
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -235,6 +235,10 @@ public:
                                 ImageType::WithDepth, bool arrayed, bool ms,
                                 ImageType::WithSampler sampled,
                                 spv::ImageFormat);
+  // Get ImageType whose attributes are the same with imageTypeWithUnknownFormat
+  // but it has spv::ImageFormat format.
+  const ImageType *getImageType(const ImageType *imageTypeWithUnknownFormat,
+                                spv::ImageFormat format);
   const SamplerType *getSamplerType() const { return samplerType; }
   const SampledImageType *getSampledImageType(const ImageType *image);
   const HybridSampledImageType *getSampledImageType(QualType image);
@@ -335,6 +339,20 @@ public:
     return currentLexicalScope;
   }
 
+  /// Function to add/get the mapping from a SPIR-V OpVariable to its image
+  /// format.
+  void registerImageFormatForSpirvVariable(const SpirvVariable *spvVar,
+                                           spv::ImageFormat format) {
+    assert(spvVar != nullptr);
+    spvVarToImageFormat[spvVar] = format;
+  }
+  spv::ImageFormat getImageFormatForSpirvVariable(const SpirvVariable *spvVar) {
+    auto itr = spvVarToImageFormat.find(spvVar);
+    if (itr == spvVarToImageFormat.end())
+      return spv::ImageFormat::Unknown;
+    return itr->second;
+  }
+
   /// Function to add/get the mapping from a SPIR-V type to its Decl for
   /// a struct type.
   void registerStructDeclForSpirvType(const SpirvType *spvTy,
@@ -442,6 +460,9 @@ private:
   // Mapping from FunctionDecl to SPIR-V debug function.
   llvm::DenseMap<const FunctionDecl *, SpirvDebugFunction *>
       declToDebugFunction;
+
+  // Mapping from SPIR-V OpVariable to SPIR-V image format.
+  llvm::DenseMap<const SpirvVariable *, spv::ImageFormat> spvVarToImageFormat;
 };
 
 } // end namespace spirv

+ 10 - 9
tools/clang/include/clang/Sema/AttributeList.h

@@ -854,18 +854,19 @@ enum AttributeDeclKind {
   ExpectedStructOrUnionOrTypedef,
   ExpectedStructOrTypedef,
   ExpectedObjectiveCInterfaceOrProtocol,
-  ExpectedKernelFunction
+  ExpectedKernelFunction,
   // SPIRV Change Begins
-  ,ExpectedField
-  ,ExpectedScalarGlobalVar
-  ,ExpectedStructGlobalVar
-  ,ExpectedGlobalVarOrCTBuffer
-  ,ExpectedCounterStructuredBuffer
-  ,ExpectedSubpassInput
-  ,ExpectedCTBuffer
+  ExpectedField,
+  ExpectedScalarGlobalVar,
+  ExpectedStructGlobalVar,
+  ExpectedGlobalVarOrCTBuffer,
+  ExpectedRWTextureOrBuffer,
+  ExpectedCounterStructuredBuffer,
+  ExpectedSubpassInput,
+  ExpectedCTBuffer,
   // SPIRV Change Ends
   // HLSL Change Begins - add attribute decl combinations
-  ,ExpectedVariableOrParam,
+  ExpectedVariableOrParam,
   ExpectedFunctionOrParamOrField,
   ExpectedFunctionOrVariableOrParamOrFieldOrType
   // HLSL Change Ends

+ 100 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -398,6 +398,99 @@ SpirvLayoutRule getLayoutRuleForExternVar(QualType type,
   return SpirvLayoutRule::Void;
 }
 
+spv::ImageFormat getSpvImageFormat(const VKImageFormatAttr *imageFormatAttr) {
+  if (imageFormatAttr == nullptr)
+    return spv::ImageFormat::Unknown;
+
+  switch (imageFormatAttr->getImageFormat()) {
+  case VKImageFormatAttr::unknown:
+    return spv::ImageFormat::Unknown;
+  case VKImageFormatAttr::rgba32f:
+    return spv::ImageFormat::Rgba32f;
+  case VKImageFormatAttr::rgba16f:
+    return spv::ImageFormat::Rgba16f;
+  case VKImageFormatAttr::r32f:
+    return spv::ImageFormat::R32f;
+  case VKImageFormatAttr::rgba8:
+    return spv::ImageFormat::Rgba8;
+  case VKImageFormatAttr::rgba8snorm:
+    return spv::ImageFormat::Rgba8Snorm;
+  case VKImageFormatAttr::rg32f:
+    return spv::ImageFormat::Rg32f;
+  case VKImageFormatAttr::rg16f:
+    return spv::ImageFormat::Rg16f;
+  case VKImageFormatAttr::r11g11b10f:
+    return spv::ImageFormat::R11fG11fB10f;
+  case VKImageFormatAttr::r16f:
+    return spv::ImageFormat::R16f;
+  case VKImageFormatAttr::rgba16:
+    return spv::ImageFormat::Rgba16;
+  case VKImageFormatAttr::rgb10a2:
+    return spv::ImageFormat::Rgb10A2;
+  case VKImageFormatAttr::rg16:
+    return spv::ImageFormat::Rg16;
+  case VKImageFormatAttr::rg8:
+    return spv::ImageFormat::Rg8;
+  case VKImageFormatAttr::r16:
+    return spv::ImageFormat::R16;
+  case VKImageFormatAttr::r8:
+    return spv::ImageFormat::R8;
+  case VKImageFormatAttr::rgba16snorm:
+    return spv::ImageFormat::Rgba16Snorm;
+  case VKImageFormatAttr::rg16snorm:
+    return spv::ImageFormat::Rg16Snorm;
+  case VKImageFormatAttr::rg8snorm:
+    return spv::ImageFormat::Rg8Snorm;
+  case VKImageFormatAttr::r16snorm:
+    return spv::ImageFormat::R16Snorm;
+  case VKImageFormatAttr::r8snorm:
+    return spv::ImageFormat::R8Snorm;
+  case VKImageFormatAttr::rgba32i:
+    return spv::ImageFormat::Rgba32i;
+  case VKImageFormatAttr::rgba16i:
+    return spv::ImageFormat::Rgba16i;
+  case VKImageFormatAttr::rgba8i:
+    return spv::ImageFormat::Rgba8i;
+  case VKImageFormatAttr::r32i:
+    return spv::ImageFormat::R32i;
+  case VKImageFormatAttr::rg32i:
+    return spv::ImageFormat::Rg32i;
+  case VKImageFormatAttr::rg16i:
+    return spv::ImageFormat::Rg16i;
+  case VKImageFormatAttr::rg8i:
+    return spv::ImageFormat::Rg8i;
+  case VKImageFormatAttr::r16i:
+    return spv::ImageFormat::R16i;
+  case VKImageFormatAttr::r8i:
+    return spv::ImageFormat::R8i;
+  case VKImageFormatAttr::rgba32ui:
+    return spv::ImageFormat::Rgba32ui;
+  case VKImageFormatAttr::rgba16ui:
+    return spv::ImageFormat::Rgba16ui;
+  case VKImageFormatAttr::rgba8ui:
+    return spv::ImageFormat::Rgba8ui;
+  case VKImageFormatAttr::r32ui:
+    return spv::ImageFormat::R32ui;
+  case VKImageFormatAttr::rgb10a2ui:
+    return spv::ImageFormat::Rgb10a2ui;
+  case VKImageFormatAttr::rg32ui:
+    return spv::ImageFormat::Rg32ui;
+  case VKImageFormatAttr::rg16ui:
+    return spv::ImageFormat::Rg16ui;
+  case VKImageFormatAttr::rg8ui:
+    return spv::ImageFormat::Rg8ui;
+  case VKImageFormatAttr::r16ui:
+    return spv::ImageFormat::R16ui;
+  case VKImageFormatAttr::r8ui:
+    return spv::ImageFormat::R8ui;
+  case VKImageFormatAttr::r64ui:
+    return spv::ImageFormat::R64ui;
+  case VKImageFormatAttr::r64i:
+    return spv::ImageFormat::R64i;
+  }
+  return spv::ImageFormat::Unknown;
+}
+
 } // anonymous namespace
 
 std::string StageVar::getSemanticStr() const {
@@ -847,6 +940,13 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
       type, storageClass, var->hasAttr<HLSLPreciseAttr>(), name, llvm::None,
       loc);
   varInstr->setLayoutRule(rule);
+
+  // If this variable has [[vk::image_format("..")]] attribute, we have to keep
+  // it in the SpirvContext and use it when we lower the QualType to SpirvType.
+  auto spvImageFormat = getSpvImageFormat(var->getAttr<VKImageFormatAttr>());
+  if (spvImageFormat != spv::ImageFormat::Unknown)
+    spvContext.registerImageFormatForSpirvVariable(varInstr, spvImageFormat);
+
   DeclSpirvInfo info(varInstr);
   astDecls[var] = info;
 

+ 8 - 0
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -121,6 +121,14 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
       if (var->hasBinding() && var->getHlslUserType().empty()) {
         var->setHlslUserType(getHlslResourceTypeName(var->getAstResultType()));
       }
+
+      auto spvImageFormat = spvContext.getImageFormatForSpirvVariable(var);
+      if (spvImageFormat != spv::ImageFormat::Unknown) {
+        if (const auto *imageType = dyn_cast<ImageType>(resultType)) {
+          resultType = spvContext.getImageType(imageType, spvImageFormat);
+          instr->setResultType(resultType);
+        }
+      }
     }
     const SpirvType *pointerType =
         spvContext.getPointerType(resultType, instr->getStorageClass());

+ 11 - 0
tools/clang/lib/SPIRV/SpirvContext.cpp

@@ -184,6 +184,17 @@ const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
   return ptr;
 }
 
+const ImageType *
+SpirvContext::getImageType(const ImageType *imageTypeWithUnknownFormat,
+                           spv::ImageFormat format) {
+  return getImageType(imageTypeWithUnknownFormat->getSampledType(),
+                      imageTypeWithUnknownFormat->getDimension(),
+                      imageTypeWithUnknownFormat->getDepth(),
+                      imageTypeWithUnknownFormat->isArrayedImage(),
+                      imageTypeWithUnknownFormat->isMSImage(),
+                      imageTypeWithUnknownFormat->withSampler(), format);
+}
+
 const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
                                             spv::Dim dim,
                                             ImageType::WithDepth depth,

+ 31 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -11137,6 +11137,28 @@ static int ValidateAttributeFloatArg(Sema &S, const AttributeList &Attr,
   return value;
 }
 
+template <typename AttrType, typename EnumType,
+          bool (*ConvertStrToEnumType)(StringRef, EnumType &)>
+static EnumType ValidateAttributeEnumArg(Sema &S, const AttributeList &Attr,
+                                         EnumType defaultValue,
+                                         unsigned index = 0) {
+  EnumType value(defaultValue);
+  StringRef Str = "";
+  SourceLocation ArgLoc;
+
+  if (Attr.getNumArgs() > index) {
+    if (!S.checkStringLiteralArgumentAttr(Attr, 0, Str, &ArgLoc))
+      return value;
+
+    if (!ConvertStrToEnumType(Str, value)) {
+      S.Diag(Attr.getLoc(), diag::warn_attribute_type_not_supported)
+          << Attr.getName() << Str << ArgLoc;
+    }
+    return value;
+  }
+  return value;
+}
+
 static Stmt* IgnoreParensAndDecay(Stmt* S)
 {
   for (;;)
@@ -11683,6 +11705,15 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
     declAttr = ::new (S.Context) VKOffsetAttr(A.getRange(), S.Context,
       ValidateAttributeIntArg(S, A), A.getAttributeSpellingListIndex());
     break;
+  case AttributeList::AT_VKImageFormat: {
+    VKImageFormatAttr::ImageFormatType Kind = ValidateAttributeEnumArg<
+        VKImageFormatAttr, VKImageFormatAttr::ImageFormatType,
+        VKImageFormatAttr::ConvertStrToImageFormatType>(
+        S, A, VKImageFormatAttr::ImageFormatType::unknown);
+    declAttr = ::new (S.Context) VKImageFormatAttr(
+        A.getRange(), S.Context, Kind, A.getAttributeSpellingListIndex());
+    break;
+  }
   case AttributeList::AT_VKInputAttachmentIndex:
     declAttr = ::new (S.Context) VKInputAttachmentIndexAttr(
         A.getRange(), S.Context, ValidateAttributeIntArg(S, A),

+ 91 - 0
tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.hlsl

@@ -0,0 +1,91 @@
+// Run: %dxc -T cs_6_0 -E main
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba16f
+[[vk::image_format("rgba16f")]]
+RWBuffer<float4> Buf;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 R32f
+[[vk::image_format("r32f")]]
+RWBuffer<float4> Buf_r32f;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba8Snorm
+[[vk::image_format("rgba8snorm")]]
+RWBuffer<float4> Buf_rgba8snorm;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rg16f
+[[vk::image_format("rg16f")]]
+RWBuffer<float4> Buf_rg16f;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 R11fG11fB10f
+[[vk::image_format("r11g11b10f")]]
+RWBuffer<float4> Buf_r11g11b10f;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgb10A2
+[[vk::image_format("rgb10a2")]]
+RWBuffer<float4> Buf_rgb10a2;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rg8
+[[vk::image_format("rg8")]]
+RWBuffer<float4> Buf_rg8;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 R8
+[[vk::image_format("r8")]]
+RWBuffer<float4> Buf_r8;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rg16Snorm
+[[vk::image_format("rg16snorm")]]
+RWBuffer<float4> Buf_rg16snorm;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba32i
+[[vk::image_format("rgba32i")]]
+RWBuffer<float4> Buf_rgba32i;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rg8i
+[[vk::image_format("rg8i")]]
+RWBuffer<float4> Buf_rg8i;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba16ui
+[[vk::image_format("rgba16ui")]]
+RWBuffer<float4> Buf_rgba16ui;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgb10a2ui
+[[vk::image_format("rgb10a2ui")]]
+RWBuffer<float4> Buf_rgb10a2ui;
+
+struct S {
+    RWBuffer<float4> b;
+};
+
+float4 getVal(RWBuffer<float4> b) {
+    return b[0];
+}
+
+float4 getValStruct(S s) {
+    return s.b[1];
+}
+
+[numthreads(1, 1, 1)]
+void main() {
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba32f
+    RWBuffer<float4> foo;
+
+    foo = Buf;
+
+    float4 test = getVal(foo);
+    test += getVal(Buf_r32f);
+
+    S s;
+    s.b = Buf;
+    test += getValStruct(s);
+
+    S s2;
+    s2.b = Buf_r32f;
+    test += getValStruct(s2);
+
+    RWBuffer<float4> var = Buf;
+    RWBuffer<float4> var2 = Buf_r32f;
+    test += var[2];
+    test += var2[2];
+
+    Buf[10] = test + 1;
+}

+ 79 - 0
tools/clang/test/CodeGenSPIRV/vk.attribute.image-format.o3.hlsl

@@ -0,0 +1,79 @@
+// Run: %dxc -T cs_6_0 -E main -O3
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 Rgba16f
+[[vk::image_format("rgba16f")]]
+RWBuffer<float4> Buf;
+
+//CHECK: OpTypeImage %float Buffer 2 0 0 2 R32f
+[[vk::image_format("r32f")]]
+RWBuffer<float4> Buf_r32f;
+
+[[vk::image_format("rgba8snorm")]]
+RWBuffer<float4> Buf_rgba8snorm;
+
+[[vk::image_format("rg16f")]]
+RWBuffer<float4> Buf_rg16f;
+
+[[vk::image_format("r11g11b10f")]]
+RWBuffer<float4> Buf_r11g11b10f;
+
+[[vk::image_format("rgb10a2")]]
+RWBuffer<float4> Buf_rgb10a2;
+
+[[vk::image_format("rg8")]]
+RWBuffer<float4> Buf_rg8;
+
+[[vk::image_format("r8")]]
+RWBuffer<float4> Buf_r8;
+
+[[vk::image_format("rg16snorm")]]
+RWBuffer<float4> Buf_rg16snorm;
+
+[[vk::image_format("rgba32i")]]
+RWBuffer<float4> Buf_rgba32i;
+
+[[vk::image_format("rg8i")]]
+RWBuffer<float4> Buf_rg8i;
+
+[[vk::image_format("rgba16ui")]]
+RWBuffer<float4> Buf_rgba16ui;
+
+[[vk::image_format("rgb10a2ui")]]
+RWBuffer<float4> Buf_rgb10a2ui;
+
+struct S {
+    RWBuffer<float4> b;
+};
+
+float4 getVal(RWBuffer<float4> b) {
+    return b[0];
+}
+
+float4 getValStruct(S s) {
+    return s.b[1];
+}
+
+[numthreads(1, 1, 1)]
+void main() {
+    RWBuffer<float4> foo;
+
+    foo = Buf;
+
+    float4 test = getVal(foo);
+    test += getVal(Buf_r32f);
+
+    S s;
+    s.b = Buf;
+    test += getValStruct(s);
+
+    S s2;
+    s2.b = Buf_r32f;
+    test += getValStruct(s2);
+
+    RWBuffer<float4> var = Buf;
+    RWBuffer<float4> var2 = Buf_r32f;
+    test += var[2];
+    test += var2[2];
+
+    Buf[10] = test + 1;
+}

+ 8 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1708,6 +1708,14 @@ TEST_F(FileTest, VulkanAttributeShaderRecordEXTInvalidUsages) {
   runFileTest("vk.attribute.shader-record-ext.invalid.hlsl", Expect::Failure);
 }
 
+TEST_F(FileTest, VulkanAttributeImageFormat) {
+  runFileTest("vk.attribute.image-format.hlsl", Expect::Success,
+              /*runValidation*/ false);
+}
+TEST_F(FileTest, VulkanAttributeImageFormatO3) {
+  runFileTest("vk.attribute.image-format.o3.hlsl");
+}
+
 TEST_F(FileTest, VulkanCLOptionInvertYVS) {
   runFileTest("vk.cloption.invert-y.vs.hlsl");
 }