浏览代码

[spirv] Support struct -> struct FlatConversion.

Ehsan Nasiri 6 年之前
父节点
当前提交
8cdb542225

+ 24 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -648,6 +648,30 @@ bool isSameType(const ASTContext &astContext, QualType type1, QualType type2) {
                           arrType2->getElementType());
   }
 
+  { // Two structures with identical fields
+    if (const auto *structType1 = type1->getAs<RecordType>()) {
+      if (const auto *structType2 = type2->getAs<RecordType>()) {
+        llvm::SmallVector<QualType, 4> fieldTypes1;
+        llvm::SmallVector<QualType, 4> fieldTypes2;
+        for (const auto *field : structType1->getDecl()->fields())
+          fieldTypes1.push_back(field->getType());
+        for (const auto *field : structType2->getDecl()->fields())
+          fieldTypes2.push_back(field->getType());
+        // Note: We currently do NOT consider such cases as equal types:
+        // struct s1 { int x; int y; }
+        // struct s2 { int2 x; }
+        // Therefore, if two structs have a different number of members, we
+        // consider them different.
+        if (fieldTypes1.size() != fieldTypes2.size())
+          return false;
+        for (auto i = 0; i < fieldTypes1.size(); ++i)
+          if (!isSameType(astContext, fieldTypes1[i], fieldTypes2[i]))
+            return false;
+        return true;
+      }
+    }
+  }
+
   // TODO: support other types if needed
 
   return false;

+ 14 - 5
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2329,15 +2329,24 @@ SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
     }
     // For assigning one array instance to another one with the same array type
     // (regardless of constness and literalness), the rhs will be wrapped in a
-    // FlatConversion:
+    // FlatConversion. Similarly for assigning a struct to another struct with
+    // identical members.
     //  |- <lhs>
     //  `- ImplicitCastExpr <FlatConversion>
     //     `- ImplicitCastExpr <LValueToRValue>
     //        `- <rhs>
-    // This FlatConversion does not affect CodeGen, so that we can ignore it.
-    else if (subExprType->isArrayType() &&
-             isSameType(astContext, expr->getType(), subExprType)) {
-      return doExpr(subExpr);
+    else if (isSameType(astContext, toType, evalType) ||
+             // We can have casts changing the shape but without affecting
+             // memory order, e.g., `float4 a[2]; float b[8] = (float[8])a;`.
+             // This is also represented as FlatConversion. For such cases, we
+             // can rely on the InitListHandler, which can decompose
+             // vectors/matrices.
+             subExprType->isArrayType()) {
+      auto *valInstr =
+          InitListHandler(astContext, *this).processCast(toType, subExpr);
+      if (valInstr)
+        valInstr->setRValue();
+      return valInstr;
     }
     // We can have casts changing the shape but without affecting memory order,
     // e.g., `float4 a[2]; float b[8] = (float[8])a;`. This is also represented

+ 60 - 0
tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct-to-struct.hlsl

@@ -0,0 +1,60 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// Processing FlatConversion when source and destination
+// are both structures with identical members.
+
+struct FirstStruct {
+  float3 anArray[4];
+  float2x3 mats[1];
+  int2 ints[3];
+};
+
+struct SecondStruct {
+  float3 anArray[4];
+  float2x3 mats[1];
+  int2 ints[3];
+};
+
+RWStructuredBuffer<FirstStruct> rwBuf : register(u0);
+[ numthreads ( 16 , 16 , 1 ) ]
+void main() {
+  SecondStruct values;
+  FirstStruct v;
+
+// Yes, this is a FlatConversion!
+// CHECK:      [[v0ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v3float_uint_4_0 %values %int_0
+// CHECK-NEXT:    [[v0:%\d+]] = OpLoad %_arr_v3float_uint_4_0 [[v0ptr]]
+// CHECK-NEXT: [[v1ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_0 %values %int_1
+// CHECK-NEXT:    [[v1:%\d+]] = OpLoad %_arr_mat2v3float_uint_1_0 [[v1ptr]]
+// CHECK-NEXT: [[v2ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v2int_uint_3_0 %values %int_2
+// CHECK-NEXT:    [[v2:%\d+]] = OpLoad %_arr_v2int_uint_3_0 [[v2ptr]]
+// CHECK-NEXT:     [[v:%\d+]] = OpCompositeConstruct %FirstStruct_0 [[v0]] [[v1]] [[v2]]
+// CHECK-NEXT:                  OpStore %v [[v]]
+  v = values;
+
+// CHECK:          [[v0ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v3float_uint_4_0 %values %int_0
+// CHECK-NEXT:        [[v0:%\d+]] = OpLoad %_arr_v3float_uint_4_0 [[v0ptr]]
+// CHECK-NEXT:     [[v1ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_mat2v3float_uint_1_0 %values %int_1
+// CHECK-NEXT:        [[v1:%\d+]] = OpLoad %_arr_mat2v3float_uint_1_0 [[v1ptr]]
+// CHECK-NEXT:     [[v2ptr:%\d+]] = OpAccessChain %_ptr_Function__arr_v2int_uint_3_0 %values %int_2
+// CHECK-NEXT:        [[v2:%\d+]] = OpLoad %_arr_v2int_uint_3_0 [[v2ptr]]
+// CHECK-NEXT:    [[values:%\d+]] = OpCompositeConstruct %FirstStruct_0 [[v0]] [[v1]] [[v2]]
+// CHECK-NEXT: [[rwBuf_ptr:%\d+]] = OpAccessChain %_ptr_Uniform_FirstStruct %rwBuf %int_0 %uint_0
+// CHECK-NEXT:   [[anArray:%\d+]] = OpCompositeExtract %_arr_v3float_uint_4_0 [[values]] 0
+// CHECK-NEXT:  [[anArray1:%\d+]] = OpCompositeExtract %v3float [[anArray]] 0
+// CHECK-NEXT:  [[anArray2:%\d+]] = OpCompositeExtract %v3float [[anArray]] 1
+// CHECK-NEXT:  [[anArray3:%\d+]] = OpCompositeExtract %v3float [[anArray]] 2
+// CHECK-NEXT:  [[anArray4:%\d+]] = OpCompositeExtract %v3float [[anArray]] 3
+// CHECK-NEXT:      [[res1:%\d+]] = OpCompositeConstruct %_arr_v3float_uint_4 [[anArray1]] [[anArray2]] [[anArray3]] [[anArray4]]
+// CHECK-NEXT:      [[mats:%\d+]] = OpCompositeExtract %_arr_mat2v3float_uint_1_0 [[values]] 1
+// CHECK-NEXT:       [[mat:%\d+]] = OpCompositeExtract %mat2v3float [[mats]] 0
+// CHECK-NEXT:      [[res2:%\d+]] = OpCompositeConstruct %_arr_mat2v3float_uint_1 [[mat]]
+// CHECK-NEXT:      [[ints:%\d+]] = OpCompositeExtract %_arr_v2int_uint_3_0 [[values]] 2
+// CHECK-NEXT:     [[ints1:%\d+]] = OpCompositeExtract %v2int [[ints]] 0
+// CHECK-NEXT:     [[ints2:%\d+]] = OpCompositeExtract %v2int [[ints]] 1
+// CHECK-NEXT:     [[ints3:%\d+]] = OpCompositeExtract %v2int [[ints]] 2
+// CHECK-NEXT:      [[res3:%\d+]] = OpCompositeConstruct %_arr_v2int_uint_3 [[ints1]] [[ints2]] [[ints3]]
+// CHECK-NEXT:    [[result:%\d+]] = OpCompositeConstruct %FirstStruct [[res1]] [[res2]] [[res3]]
+// CHECK-NEXT:                      OpStore [[rwBuf_ptr]] [[result]]
+  rwBuf[0] = values;
+}

+ 5 - 3
tools/clang/test/CodeGenSPIRV/cast.flat-conversion.struct.hlsl

@@ -16,9 +16,11 @@ float4 main(float4 a: A) : SV_Target {
 // CHECK-NEXT:              OpStore %s [[s]]
     S s = (S)a;
 
-// CHECK:      [[s:%\d+]] = OpLoad %S %s
-// CHECK-NEXT: [[t:%\d+]] = OpCompositeConstruct %T [[s]]
-// CHECK-NEXT:              OpStore %t [[t]]
+// CHECK:      [[valptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %s %int_0
+// CHECK-NEXT:    [[val:%\d+]] = OpLoad %v4float [[valptr]]
+// CHECK-NEXT:      [[s:%\d+]] = OpCompositeConstruct %S [[val]]
+// CHECK-NEXT:      [[t:%\d+]] = OpCompositeConstruct %T [[s]]
+// CHECK-NEXT:                   OpStore %t [[t]]
     T t = (T)s;
 
     return s.val + t.val.val;

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -395,6 +395,9 @@ TEST_F(FileTest, CastFlatConversionStruct) {
 TEST_F(FileTest, CastFlatConversionNoOp) {
   runFileTest("cast.flat-conversion.no-op.hlsl");
 }
+TEST_F(FileTest, CastFlatConversionStructToStruct) {
+  runFileTest("cast.flat-conversion.struct-to-struct.hlsl");
+}
 TEST_F(FileTest, CastFlatConversionLiteralInitializer) {
   runFileTest("cast.flat-conversion.literal-initializer.hlsl");
 }