Jelajahi Sumber

[spirv] Support OO inheritance in GS stage IOs (#1043)

We have special handling of the extra level of arrayness and
different code path for emitting vertices in GS.
Lei Zhang 7 tahun lalu
induk
melakukan
a7aaccef9f

+ 30 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1432,6 +1432,16 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
     for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) {
       llvm::SmallVector<uint32_t, 8> fields;
 
+      // If we have base classes, we need to handle them first.
+      if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
+        uint32_t baseIndex = 0;
+        for (auto base : cxxDecl->bases()) {
+          const auto baseType = typeTranslator.translateType(base.getType());
+          fields.push_back(theBuilder.createCompositeExtract(
+              baseType, subValues[baseIndex++], {arrayIndex}));
+        }
+      }
+
       // Extract the element at index arrayIndex from each field
       for (const auto *field : structDecl->fields()) {
         const uint32_t fieldType =
@@ -1497,12 +1507,10 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
   return true;
 }
 
-bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
-                                               uint32_t value) {
+bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
+                                               QualType type, uint32_t value) {
   assert(shaderModel.IsGS()); // Only for GS use
 
-  QualType type = decl->getType();
-
   if (hlsl::IsHLSLStreamOutputType(type))
     type = hlsl::GetHLSLResourceResultType(type);
   if (hasGSPrimitiveTypeQualifier(decl))
@@ -1524,7 +1532,9 @@ bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
 
     // Query the <result-id> for the stage output variable generated out
     // of this decl.
-    const auto found = stageVarIds.find(decl);
+    // We have semantic string attached to this decl; therefore, it must be a
+    // DeclaratorDecl.
+    const auto found = stageVarIds.find(cast<DeclaratorDecl>(decl));
 
     // We should have recorded its stage output variable previously.
     assert(found != stageVarIds.end());
@@ -1554,6 +1564,20 @@ bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
     return false;
   }
 
+  // If we have base classes, we need to handle them first.
+  if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
+    uint32_t baseIndex = 0;
+    for (auto base : cxxDecl->bases()) {
+      const auto baseType = typeTranslator.translateType(base.getType());
+      const auto subValue =
+          theBuilder.createCompositeExtract(baseType, value, {baseIndex++});
+
+      if (!writeBackOutputStream(base.getType()->getAsCXXRecordDecl(),
+                                 base.getType(), subValue))
+        return false;
+    }
+  }
+
   const auto *structDecl = type->getAs<RecordType>()->getDecl();
 
   // Write out each field
@@ -1562,7 +1586,7 @@ bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
     const uint32_t subValue = theBuilder.createCompositeExtract(
         fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()});
 
-    if (!writeBackOutputStream(field, subValue))
+    if (!writeBackOutputStream(field, field->getType(), subValue))
       return false;
   }
 

+ 2 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -424,7 +424,8 @@ public:
   ///
   /// This method is specially for writing back per-vertex data at the time of
   /// OpEmitVertex in GS.
-  bool writeBackOutputStream(const ValueDecl *decl, uint32_t value);
+  bool writeBackOutputStream(const NamedDecl *decl, QualType type,
+                             uint32_t value);
 
   /// \brief Decorates all stage input and output variables with proper
   /// location and returns true on success.

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -3210,7 +3210,7 @@ SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
   const auto *stream = cast<DeclRefExpr>(object)->getDecl();
   const uint32_t value = doExpr(expr->getArg(0));
 
-  declIdMapper.writeBackOutputStream(stream, value);
+  declIdMapper.writeBackOutputStream(stream, stream->getType(), value);
   theBuilder.createEmitVertex();
 
   return 0;

+ 78 - 0
tools/clang/test/CodeGenSPIRV/oo.inheritance.stage-io.gs.hlsl

@@ -0,0 +1,78 @@
+// Run: %dxc -T gs_6_0 -E main
+
+struct Empty { };
+
+struct Base : Empty {
+    float4 a  : AAA;
+    float4 pos: SV_Position;
+};
+
+struct Derived : Base {
+    float4 b  : BBB;
+};
+
+// CHECK-LABEL: %main = OpFunction
+
+// CHECK:         [[empty0:%\d+]] = OpCompositeConstruct %Empty
+// CHECK-NEXT:    [[empty1:%\d+]] = OpCompositeConstruct %Empty
+// CHECK-NEXT: [[empty_arr:%\d+]] = OpCompositeConstruct %_arr_Empty_uint_2 [[empty0]] [[empty1]]
+
+// CHECK-NEXT:     [[a_arr:%\d+]] = OpLoad %_arr_v4float_uint_2 %in_var_AAA
+
+// CHECK-NEXT:  [[pos0_ptr:%\d+]] = OpAccessChain %_ptr_Input_v4float %gl_PerVertexIn %uint_0 %uint_0
+// CHECK-NEXT:      [[pos0:%\d+]] = OpLoad %v4float [[pos0_ptr]]
+// CHECK-NEXT:  [[pos1_ptr:%\d+]] = OpAccessChain %_ptr_Input_v4float %gl_PerVertexIn %uint_1 %uint_0
+// CHECK-NEXT:      [[pos1:%\d+]] = OpLoad %v4float [[pos1_ptr]]
+// CHECK-NEXT:   [[pos_arr:%\d+]] = OpCompositeConstruct %_arr_v4float_uint_2 [[pos0]] [[pos1]]
+
+// CHECK-NEXT:    [[empty0:%\d+]] = OpCompositeExtract %Empty [[empty_arr]] 0
+// CHECK-NEXT:        [[a0:%\d+]] = OpCompositeExtract %v4float [[a_arr]] 0
+// CHECK-NEXT:      [[pos0:%\d+]] = OpCompositeExtract %v4float [[pos_arr]] 0
+// CHECK-NEXT:     [[base0:%\d+]] = OpCompositeConstruct %Base [[empty0]] [[a0]] [[pos0]]
+
+// CHECK-NEXT:    [[empty1:%\d+]] = OpCompositeExtract %Empty [[empty_arr]] 1
+// CHECK-NEXT:        [[a1:%\d+]] = OpCompositeExtract %v4float [[a_arr]] 1
+// CHECK-NEXT:      [[pos1:%\d+]] = OpCompositeExtract %v4float [[pos_arr]] 1
+// CHECK-NEXT:     [[base1:%\d+]] = OpCompositeConstruct %Base [[empty1]] [[a1]] [[pos1]]
+
+// CHECK-NEXT:  [[base_arr:%\d+]] = OpCompositeConstruct %_arr_Base_uint_2 [[base0]] [[base1]]
+
+// CHECK-NEXT:     [[b_arr:%\d+]] = OpLoad %_arr_v4float_uint_2 %in_var_BBB
+
+// CHECK-NEXT:     [[base0:%\d+]] = OpCompositeExtract %Base [[base_arr]] 0
+// CHECK-NEXT:        [[b0:%\d+]] = OpCompositeExtract %v4float [[b_arr]] 0
+// CHECK-NEXT:  [[derived0:%\d+]] = OpCompositeConstruct %Derived [[base0]] [[b0]]
+
+// CHECK-NEXT:     [[base1:%\d+]] = OpCompositeExtract %Base [[base_arr]] 1
+// CHECK-NEXT:        [[b1:%\d+]] = OpCompositeExtract %v4float [[b_arr]] 1
+// CHECK-NEXT:  [[derived1:%\d+]] = OpCompositeConstruct %Derived [[base1]] [[b1]]
+
+// CHECK-NEXT:    [[inData:%\d+]] = OpCompositeConstruct %_arr_Derived_uint_2 [[derived0]] [[derived1]]
+// CHECK-NEXT:                      OpStore %param_var_inData [[inData]]
+
+// CHECK-LABEL: %src_main = OpFunction
+
+[maxvertexcount(2)]
+void main(in    line Derived             inData[2],
+          inout      LineStream<Derived> outData)
+{
+// CHECK:            [[ptr:%\d+]] = OpAccessChain %_ptr_Function_Derived %inData %int_0
+// CHECK-NEXT:   [[inData0:%\d+]] = OpLoad %Derived [[ptr]]
+// CHECK-NEXT:      [[base:%\d+]] = OpCompositeExtract %Base [[inData0]] 0
+
+// CHECK-NEXT:           {{%\d+}} = OpCompositeExtract %Empty [[base]] 0
+
+// CHECK-NEXT:         [[a:%\d+]] = OpCompositeExtract %v4float [[base]] 1
+// CHECK-NEXT:                      OpStore %out_var_AAA [[a]]
+
+// CHECK-NEXT:       [[pos:%\d+]] = OpCompositeExtract %v4float [[base]] 2
+// CHECK-NEXT:                      OpStore %gl_Position [[pos]]
+
+// CHECK-NEXT:         [[b:%\d+]] = OpCompositeExtract %v4float [[inData0]] 1
+// CHECK-NEXT:                      OpStore %out_var_BBB [[b]]
+
+// CHECK-NEXT:       OpEmitVertex
+    outData.Append(inData[0]);
+
+    outData.RestartStrip();
+}

+ 0 - 0
tools/clang/test/CodeGenSPIRV/oo.inheritance.stage-io.hlsl → tools/clang/test/CodeGenSPIRV/oo.inheritance.stage-io.vs.hlsl


+ 5 - 2
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -417,8 +417,11 @@ TEST_F(FileTest, MethodCallOnStaticVar) {
   runFileTest("oo.method.on-static-var.hlsl");
 }
 TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
-TEST_F(FileTest, InheritanceStageIO) {
-  runFileTest("oo.inheritance.stage-io.hlsl");
+TEST_F(FileTest, InheritanceStageIOVS) {
+  runFileTest("oo.inheritance.stage-io.vs.hlsl");
+}
+TEST_F(FileTest, InheritanceStageIOGS) {
+  runFileTest("oo.inheritance.stage-io.gs.hlsl");
 }
 
 // For semantics