Browse Source

[spirv] Support struct as Buffer/Texture template type (#1371)

fxc.exe supports templating Buffer/Texture with a struct type,
as long asthe struct type can be fit into a 4-component vector.
Lei Zhang 7 years ago
parent
commit
310c305150

+ 51 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -3151,7 +3151,21 @@ SpirvEvalInfo SPIRVEmitter::processBufferTextureLoad(
   QualType elemType = sampledType;
   uint32_t elemCount = 1;
   uint32_t elemTypeId = 0;
-  (void)TypeTranslator::isVectorType(sampledType, &elemType, &elemCount);
+  bool isTemplateOverStruct = false;
+
+  // Check whether the template type is a vector type or struct type.
+  if (!TypeTranslator::isVectorType(sampledType, &elemType, &elemCount)) {
+    if (sampledType->getAsStructureType()) {
+      isTemplateOverStruct = true;
+      // For struct type, we need to make sure it can fit into a 4-component
+      // vector. Detailed failing reasons will be emitted by the function so
+      // we don't need to emit errors here.
+      if (!typeTranslator.canFitIntoOneRegister(sampledType, &elemType,
+                                                &elemCount))
+        return 0;
+    }
+  }
+
   if (elemType->isFloatingType()) {
     elemTypeId = theBuilder.getFloat32Type();
   } else if (elemType->isSignedIntegerType()) {
@@ -3172,6 +3186,10 @@ SpirvEvalInfo SPIRVEmitter::processBufferTextureLoad(
   // If the result type is a vec1, vec2, or vec3, some extra processing
   // (extraction) is required.
   uint32_t retVal = extractVecFromVec4(texel, elemCount, elemTypeId);
+  if (isTemplateOverStruct) {
+    // Convert to the struct so that we are consistent with types in the AST.
+    retVal = convertVectorToStruct(sampledType, elemTypeId, retVal);
+  }
   return SpirvEvalInfo(retVal).setRValue();
 }
 
@@ -5446,6 +5464,38 @@ void SPIRVEmitter::splitVecLastElement(QualType vecType, uint32_t vec,
       theBuilder.createCompositeExtract(elemTypeId, vec, {count - 1});
 }
 
+uint32_t SPIRVEmitter::convertVectorToStruct(QualType structType,
+                                             uint32_t elemTypeId,
+                                             uint32_t vector) {
+  assert(structType->isStructureType());
+
+  const auto *structDecl = structType->getAsStructureType()->getDecl();
+  uint32_t vectorIndex = 0;
+  uint32_t elemCount = 1;
+  llvm::SmallVector<uint32_t, 4> members;
+
+  for (const auto *field : structDecl->fields()) {
+    if (TypeTranslator::isScalarType(field->getType())) {
+      members.push_back(theBuilder.createCompositeExtract(elemTypeId, vector,
+                                                          {vectorIndex++}));
+    } else if (TypeTranslator::isVectorType(field->getType(), nullptr,
+                                            &elemCount)) {
+      llvm::SmallVector<uint32_t, 4> indices;
+      for (uint32_t i = 0; i < elemCount; ++i)
+        indices.push_back(vectorIndex++);
+
+      const uint32_t type = theBuilder.getVecType(elemTypeId, elemCount);
+      members.push_back(
+          theBuilder.createVectorShuffle(type, vector, vector, indices));
+    } else {
+      assert(false && "unhandled type");
+    }
+  }
+
+  return theBuilder.createCompositeConstruct(
+      typeTranslator.translateType(structType), members);
+}
+
 SpirvEvalInfo
 SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
   const QualType type = expr->getType();

+ 8 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -222,6 +222,14 @@ private:
   void splitVecLastElement(QualType vecType, uint32_t vec, uint32_t *residual,
                            uint32_t *lastElement);
 
+  /// Converts a vector value into the given struct type with its element type's
+  /// <result-id> as elemTypeId.
+  ///
+  /// Assumes the vector and the struct have matching number of elements. Panics
+  /// otherwise.
+  uint32_t convertVectorToStruct(QualType structType, uint32_t elemTypeId,
+                                 uint32_t vector);
+
   /// Translates a floatN * float multiplication into SPIR-V instructions and
   /// returns the <result-id>. Returns 0 if the given binary operation is not
   /// floatN * float.

+ 53 - 2
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -1191,7 +1191,7 @@ bool TypeTranslator::isSameType(QualType type1, QualType type2) {
 QualType TypeTranslator::getElementType(QualType type) {
   QualType elemType = {};
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType) ||
-      isMxNMatrix(type, &elemType)) {
+      isMxNMatrix(type, &elemType) || canFitIntoOneRegister(type, &elemType)) {
     return elemType;
   }
 
@@ -1598,7 +1598,8 @@ TypeTranslator::translateSampledTypeToImageFormat(QualType sampledType) {
   uint32_t elemCount = 1;
   QualType ty = {};
   if (isScalarType(sampledType, &ty) ||
-      isVectorType(sampledType, &ty, &elemCount)) {
+      isVectorType(sampledType, &ty, &elemCount) ||
+      canFitIntoOneRegister(sampledType, &ty, &elemCount)) {
     if (const auto *builtinType = ty->getAs<BuiltinType>()) {
       switch (builtinType->getKind()) {
       case BuiltinType::Int:
@@ -1625,6 +1626,56 @@ TypeTranslator::translateSampledTypeToImageFormat(QualType sampledType) {
   return spv::ImageFormat::Unknown;
 }
 
+bool TypeTranslator::canFitIntoOneRegister(QualType structType,
+                                           QualType *elemType,
+                                           uint32_t *elemCount) {
+  if (structType->getAsStructureType() == nullptr)
+    return false;
+
+  const auto *structDecl = structType->getAsStructureType()->getDecl();
+  QualType firstElemType;
+  uint32_t totalCount = 0;
+
+  for (const auto *field : structDecl->fields()) {
+    QualType type;
+    uint32_t count = 1;
+
+    if (isScalarType(field->getType(), &type) ||
+        isVectorType(field->getType(), &type, &count)) {
+      if (firstElemType.isNull()) {
+        firstElemType = type;
+      } else {
+        if (!canTreatAsSameScalarType(firstElemType, type)) {
+          emitError("all struct members should have the same element type for "
+                    "resource template instantiation",
+                    structDecl->getLocation());
+          return false;
+        }
+      }
+      totalCount += count;
+    } else {
+      emitError("unsupported struct element type for resource template "
+                "instantiation",
+                structDecl->getLocation());
+      return false;
+    }
+  }
+
+  if (totalCount > 4) {
+    emitError(
+        "resource template element type %0 cannot fit into four 32-bit scalars",
+        structDecl->getLocation())
+        << structType;
+    return false;
+  }
+
+  if (elemType)
+    *elemType = firstElemType;
+  if (elemCount)
+    *elemCount = totalCount;
+  return true;
+}
+
 void TypeTranslator::alignUsingHLSLRelaxedLayout(QualType fieldType,
                                                  uint32_t fieldSize,
                                                  uint32_t *fieldAlignment,

+ 7 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -233,6 +233,13 @@ public:
   /// matrix type.
   uint32_t getComponentVectorType(QualType matrixType);
 
+  /// \brief Returns true if all members in structType are of the same element
+  /// type and can be fit into a 4-component vector. Writes element type and
+  /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
+  /// explaining why not.
+  bool canFitIntoOneRegister(QualType structType, QualType *elemType = nullptr,
+                             uint32_t *elemCount = nullptr);
+
   /// \brief Returns the capability required for the given storage image type.
   /// Returns Capability::Max to mean no capability requirements.
   static spv::Capability getCapabilityForStorageImageReadWrite(QualType type);

+ 33 - 0
tools/clang/test/CodeGenSPIRV/op.buffer.access.hlsl

@@ -15,6 +15,14 @@ RWBuffer<int4> int4buf;
 RWBuffer<uint4> uint4buf;
 RWBuffer<float4> float4buf;
 
+struct S {
+  float  a;
+  float2 b;
+  float1 c;
+};
+
+  Buffer<S> sBuf;
+
 void main() {
   int address;
 
@@ -102,4 +110,29 @@ void main() {
 // CHECK-NEXT:     [[b:%\d+]] = OpLoad %float [[ac14]]
 // CHECK-NEXT:                  OpStore %b [[b]]
   float b = float4buf[address][2];
+
+// CHECK:        [[img:%\d+]] = OpLoad %type_buffer_image_7 %sBuf
+// CHECK-NEXT: [[fetch:%\d+]] = OpImageFetch %v4float [[img]] %uint_0 None
+// CHECK-NEXT:   [[s_a:%\d+]] = OpCompositeExtract %float [[fetch]] 0
+// CHECK-NEXT:   [[s_b:%\d+]] = OpVectorShuffle %v2float [[fetch]] [[fetch]] 1 2
+// CHECK-NEXT:   [[s_c:%\d+]] = OpCompositeExtract %float [[fetch]] 3
+// CHECK-NEXT:     [[s:%\d+]] = OpCompositeConstruct %S [[s_a]] [[s_b]] [[s_c]]
+// CHECK-NEXT:                  OpStore %temp_var_S [[s]]
+// CHECK-NEXT:   [[ptr:%\d+]] = OpAccessChain %_ptr_Function_float %temp_var_S %int_0
+// CHECK-NEXT:     [[c:%\d+]] = OpLoad %float [[ptr]]
+// CHECK-NEXT:                  OpStore %c [[c]]
+  float c = sBuf[0].a;
+
+// CHECK:        [[img:%\d+]] = OpLoad %type_buffer_image_7 %sBuf
+// CHECK-NEXT: [[fetch:%\d+]] = OpImageFetch %v4float [[img]] %uint_1 None
+// CHECK-NEXT:   [[s_a:%\d+]] = OpCompositeExtract %float [[fetch]] 0
+// CHECK-NEXT:   [[s_b:%\d+]] = OpVectorShuffle %v2float [[fetch]] [[fetch]] 1 2
+// CHECK-NEXT:   [[s_c:%\d+]] = OpCompositeExtract %float [[fetch]] 3
+// CHECK-NEXT:     [[s:%\d+]] = OpCompositeConstruct %S [[s_a]] [[s_b]] [[s_c]]
+// CHECK-NEXT:                  OpStore %temp_var_S_0 [[s]]
+// CHECK-NEXT:   [[ptr:%\d+]] = OpAccessChain %_ptr_Function_v2float %temp_var_S_0 %int_1
+// CHECK-NEXT:   [[val:%\d+]] = OpLoad %v2float [[ptr]]
+// CHECK-NEXT:     [[d:%\d+]] = OpCompositeExtract %float [[val]] 1
+// CHECK-NEXT:                  OpStore %d [[d]]
+  float d = sBuf[1].b.y;
 }

+ 19 - 0
tools/clang/test/CodeGenSPIRV/op.texture.access.hlsl

@@ -10,6 +10,13 @@ Texture2DArray   <int3>   t6;
 // There is no operator[] for TextureCubeArray in HLSL reference.
 // There is no operator[] for Texture2DMSArray in HLSL reference.
 
+struct S {
+  float  a;
+  float2 b;
+  float1 c;
+};
+
+Texture2D <S> tStruct;
 
 // CHECK:  [[cu12:%\d+]] = OpConstantComposite %v2uint %uint_1 %uint_2
 // CHECK: [[cu123:%\d+]] = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3
@@ -55,4 +62,16 @@ void main() {
 // CHECK-NEXT: [[result6:%\d+]] = OpVectorShuffle %v3int [[f6]] [[f6]] 0 1 2
 // CHECK-NEXT: OpStore %a6 [[result6]]
   int3   a6 = t6[uint3(1,2,3)];
+
+// CHECK:        [[tex:%\d+]] = OpLoad %type_2d_image_1 %tStruct
+// CHECK-NEXT: [[fetch:%\d+]] = OpImageFetch %v4float [[tex]] {{%\d+}} Lod %uint_0
+// CHECK-NEXT:     [[a:%\d+]] = OpCompositeExtract %float [[fetch]] 0
+// CHECK-NEXT:     [[b:%\d+]] = OpVectorShuffle %v2float [[fetch]] [[fetch]] 1 2
+// CHECK-NEXT:     [[c:%\d+]] = OpCompositeExtract %float [[fetch]] 3
+// CHECK-NEXT:     [[s:%\d+]] = OpCompositeConstruct %S [[a]] [[b]] [[c]]
+// CHECK-NEXT:                  OpStore %temp_var_S [[s]]
+// CHECK-NEXT:   [[ptr:%\d+]] = OpAccessChain %_ptr_Function_float %temp_var_S %int_2
+// CHECK-NEXT:   [[val:%\d+]] = OpLoad %float [[ptr]]
+// CHECK-NEXT:                  OpStore %a7 [[val]]
+  float  a7 = tStruct[uint2(1, 2)].c;
 }

+ 16 - 0
tools/clang/test/CodeGenSPIRV/type.buffer.hlsl

@@ -69,6 +69,20 @@ RWBuffer<uint4> uint4rwbuf;
 RWBuffer<float3> float3rwbuf;
 RWBuffer<float4> float4rwbuf;
 
+struct S {
+    float a;
+    float b;
+};
+
+struct T {
+    float1 a;
+    float2 b;
+};
+
+  Buffer<S> sBuf;
+
+  Buffer<T> tBuf;
+
 // CHECK: %intbuf = OpVariable %_ptr_UniformConstant_type_buffer_image UniformConstant
 // CHECK: %uintbuf = OpVariable %_ptr_UniformConstant_type_buffer_image_0 UniformConstant
 // CHECK: %floatbuf = OpVariable %_ptr_UniformConstant_type_buffer_image_1 UniformConstant
@@ -93,5 +107,7 @@ RWBuffer<float4> float4rwbuf;
 // CHECK: %uint4rwbuf = OpVariable %_ptr_UniformConstant_type_buffer_image_15 UniformConstant
 // CHECK: %float3rwbuf = OpVariable %_ptr_UniformConstant_type_buffer_image_16 UniformConstant
 // CHECK: %float4rwbuf = OpVariable %_ptr_UniformConstant_type_buffer_image_16 UniformConstant
+// CHECK:   %sBuf = OpVariable %_ptr_UniformConstant_type_buffer_image_7 UniformConstant
+// CHECK:   %tBuf = OpVariable %_ptr_UniformConstant_type_buffer_image_13 UniformConstant
 
 void main() {}

+ 14 - 0
tools/clang/test/CodeGenSPIRV/type.buffer.struct.error1.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float4 a;
+    float3 b;
+};
+
+Buffer<S> MyBuffer;
+
+float4 main(): SV_Target {
+    return MyBuffer[0].a;
+}
+
+// CHECK: :3:8: error: resource template element type 'S' cannot fit into four 32-bit scalars

+ 14 - 0
tools/clang/test/CodeGenSPIRV/type.buffer.struct.error2.hlsl

@@ -0,0 +1,14 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float2 a;
+    int1   b;
+};
+
+Buffer<S> MyBuffer;
+
+float4 main(): SV_Target {
+    return MyBuffer[0].a.x;
+}
+
+// CHECK: :3:8: error: all struct members should have the same element type for resource template instantiation

+ 18 - 0
tools/clang/test/CodeGenSPIRV/type.buffer.struct.error3.hlsl

@@ -0,0 +1,18 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct B {
+    float2 b;
+};
+
+struct S {
+    float2 a;
+    B      b;
+};
+
+Buffer<S> MyBuffer;
+
+float4 main(): SV_Target {
+    return MyBuffer[0].a.x;
+}
+
+// CHECK: :7:8: error: unsupported struct element type for resource template instantiation

+ 0 - 4
tools/clang/test/CodeGenSPIRV/type.buffer.struct.error.hlsl → tools/clang/test/CodeGenSPIRV/type.rwbuffer.struct.error.hlsl

@@ -5,10 +5,6 @@ struct S {
   float b;
 };
 
-// CHECK: error: cannot translate resource type parameter 'S' to proper image format
-// CHECK: error: unsupported resource type parameter 'S'
-  Buffer<S> MyBuffer;
-
 // CHECK: error: cannot instantiate RWBuffer with struct type 'S'
 RWBuffer<S> MyRWBuffer;
 

+ 17 - 0
tools/clang/test/CodeGenSPIRV/type.texture.hlsl

@@ -30,6 +30,8 @@
 // CHECK: %type_2d_image_array_0 = OpTypeImage %uint 2D 2 1 1 1 Unknown
 // CHECK: %_ptr_UniformConstant_type_2d_image_array_0 = OpTypePointer UniformConstant %type_2d_image_array_0
 
+// CHECK: %type_2d_image_array_1 = OpTypeImage %float 2D 2 1 0 1 Unknown
+// CHECK: %_ptr_UniformConstant_type_2d_image_array_1 = OpTypePointer UniformConstant %type_2d_image_array_1
 
 // CHECK: %t1 = OpVariable %_ptr_UniformConstant_type_1d_image UniformConstant
 Texture1D   <float4> t1 : register(t1);
@@ -52,6 +54,21 @@ Texture2DMS      <int3>   t8 : register(t8);
 // CHECK: %t9 = OpVariable %_ptr_UniformConstant_type_2d_image_array_0 UniformConstant
 Texture2DMSArray <uint4>  t9 : register(t9);
 
+struct S {
+    float a;
+    float b;
+};
+
+struct T {
+    float1 a;
+    float2 b;
+};
+
+// CHECK: %sTex = OpVariable %_ptr_UniformConstant_type_1d_image UniformConstant
+Texture1D<S>      sTex;
+// CHECK: %tTex = OpVariable %_ptr_UniformConstant_type_2d_image_array_1 UniformConstant
+Texture2DArray<T> tTex;
+
 void main() {
 // CHECK-LABEL: %main = OpFunction
 }

+ 11 - 2
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -63,8 +63,17 @@ TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }
 TEST_F(FileTest, TextureTypes) { runFileTest("type.texture.hlsl"); }
 TEST_F(FileTest, RWTextureTypes) { runFileTest("type.rwtexture.hlsl"); }
 TEST_F(FileTest, BufferType) { runFileTest("type.buffer.hlsl"); }
-TEST_F(FileTest, BufferTypeStructError) {
-  runFileTest("type.buffer.struct.error.hlsl", Expect::Failure);
+TEST_F(FileTest, BufferTypeStructError1) {
+  runFileTest("type.buffer.struct.error1.hlsl", Expect::Failure);
+}
+TEST_F(FileTest, BufferTypeStructError2) {
+  runFileTest("type.buffer.struct.error2.hlsl", Expect::Failure);
+}
+TEST_F(FileTest, BufferTypeStructError3) {
+  runFileTest("type.buffer.struct.error3.hlsl", Expect::Failure);
+}
+TEST_F(FileTest, RWBufferTypeStructError) {
+  runFileTest("type.rwbuffer.struct.error.hlsl", Expect::Failure);
 }
 TEST_F(FileTest, CBufferType) { runFileTest("type.cbuffer.hlsl"); }
 TEST_F(FileTest, ConstantBufferType) {