Quellcode durchsuchen

[spirv] Handle OO inheritance in stage IOs (#1041)

We need to go over all base classes and handle stage IOs inside
them too.
Lei Zhang vor 7 Jahren
Ursprung
Commit
902b1ee3c1

+ 47 - 11
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -191,7 +191,7 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
 }
 
 DeclResultIdMapper::SemanticInfo
-DeclResultIdMapper::getStageVarSemantic(const ValueDecl *decl) {
+DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
   for (auto *annotation : decl->getUnusualAnnotations()) {
     if (auto *sema = dyn_cast<hlsl::SemanticDecl>(annotation)) {
       llvm::StringRef semanticStr = sema->SemanticName;
@@ -1080,11 +1080,13 @@ bool DeclResultIdMapper::decorateResourceBindings() {
   return true;
 }
 
-bool DeclResultIdMapper::createStageVars(
-    const hlsl::SigPoint *sigPoint, const DeclaratorDecl *decl, bool asInput,
-    QualType type, uint32_t arraySize, const llvm::StringRef namePrefix,
-    llvm::Optional<uint32_t> invocationId, uint32_t *value, bool noWriteBack,
-    SemanticInfo *inheritSemantic) {
+bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
+                                         const NamedDecl *decl, bool asInput,
+                                         QualType type, uint32_t arraySize,
+                                         const llvm::StringRef namePrefix,
+                                         llvm::Optional<uint32_t> invocationId,
+                                         uint32_t *value, bool noWriteBack,
+                                         SemanticInfo *inheritSemantic) {
   // invocationId should only be used for handling HS per-vertex output.
   if (invocationId.hasValue()) {
     assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
@@ -1221,7 +1223,9 @@ bool DeclResultIdMapper::createStageVars(
     stageVar.setSpirvId(varId);
     stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
     stageVars.push_back(stageVar);
-    stageVarIds[decl] = varId;
+    // We have semantics attached to this decl, which means it must be a
+    // function/parameter/variable. All are DeclaratorDecls.
+    stageVarIds[cast<DeclaratorDecl>(decl)] = varId;
 
     // Mark that we have used one index for this semantic
     ++semanticToUse->index;
@@ -1388,6 +1392,18 @@ bool DeclResultIdMapper::createStageVars(
     // load their values into a composite.
     llvm::SmallVector<uint32_t, 4> subValues;
 
+    // If we have base classes, we need to handle them first.
+    if (const auto *cxxDecl = type->getAsCXXRecordDecl())
+      for (auto base : cxxDecl->bases()) {
+        uint32_t subValue = 0;
+        if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
+                             asInput, base.getType(), arraySize, namePrefix,
+                             invocationId, &subValue, noWriteBack,
+                             semanticToUse))
+          return false;
+        subValues.push_back(subValue);
+      }
+
     for (const auto *field : structDecl->fields()) {
       uint32_t subValue = 0;
       if (!createStageVars(sigPoint, field, asInput, field->getType(),
@@ -1432,6 +1448,24 @@ bool DeclResultIdMapper::createStageVars(
 
     *value = theBuilder.createCompositeConstruct(arrayType, arrayElements);
   } else {
+    // If we have base classes, we need to handle them first.
+    if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
+      uint32_t baseIndex = 0;
+      for (auto base : cxxDecl->bases()) {
+        uint32_t subValue = 0;
+        if (!noWriteBack)
+          subValue = theBuilder.createCompositeExtract(
+              typeTranslator.translateType(base.getType()), *value,
+              {baseIndex++});
+
+        if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
+                             asInput, base.getType(), arraySize, namePrefix,
+                             invocationId, &subValue, noWriteBack,
+                             semanticToUse))
+          return false;
+      }
+    }
+
     // Unlike reading, which may require us to read stand-alone builtins and
     // stage input variables and compose an array of structs out of them,
     // it happens that we don't need to write an array of structs in a bunch
@@ -1535,7 +1569,7 @@ bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
   return true;
 }
 
-void DeclResultIdMapper::decoratePSInterpolationMode(const DeclaratorDecl *decl,
+void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
                                                      QualType type,
                                                      uint32_t varId) {
   const QualType elemType = typeTranslator.getElementType(type);
@@ -1569,7 +1603,7 @@ void DeclResultIdMapper::decoratePSInterpolationMode(const DeclaratorDecl *decl,
 }
 
 uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
-                                                 const DeclaratorDecl *decl,
+                                                 const NamedDecl *decl,
                                                  const llvm::StringRef name,
                                                  SourceLocation srcLoc) {
   using spv::BuiltIn;
@@ -1924,12 +1958,14 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   return 0;
 }
 
-bool DeclResultIdMapper::validateVKBuiltins(const DeclaratorDecl *decl,
+bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
                                             const hlsl::SigPoint *sigPoint) {
   bool success = true;
 
   if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
-    const auto declType = getTypeOrFnRetType(decl);
+    // The front end parsing only allows vk::builtin to be attached to a
+    // function/parameter/variable; all of them are DeclaratorDecls.
+    const auto declType = getTypeOrFnRetType(cast<DeclaratorDecl>(decl));
     const auto loc = builtinAttr->getLocation();
 
     if (decl->hasAttr<VKLocationAttr>()) {

+ 7 - 7
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -526,7 +526,7 @@ private:
   };
 
   /// Returns the given decl's HLSL semantic information.
-  static SemanticInfo getStageVarSemantic(const ValueDecl *decl);
+  static SemanticInfo getStageVarSemantic(const NamedDecl *decl);
 
   /// Creates all the stage variables mapped from semantics on the given decl.
   /// Returns true on sucess.
@@ -552,9 +552,9 @@ private:
   /// If inheritSemantic is valid, it will override all semantics attached to
   /// the children of this decl, and the children of this decl will be using
   /// the semantic in inheritSemantic, with index increasing sequentially.
-  bool createStageVars(const hlsl::SigPoint *sigPoint,
-                       const DeclaratorDecl *decl, bool asInput, QualType type,
-                       uint32_t arraySize, const llvm::StringRef namePrefix,
+  bool createStageVars(const hlsl::SigPoint *sigPoint, const NamedDecl *decl,
+                       bool asInput, QualType type, uint32_t arraySize,
+                       const llvm::StringRef namePrefix,
                        llvm::Optional<uint32_t> invocationId, uint32_t *value,
                        bool noWriteBack, SemanticInfo *inheritSemantic);
 
@@ -562,11 +562,11 @@ private:
   /// the <result-id>. Also sets whether the StageVar is a SPIR-V builtin and
   /// its storage class accordingly. name will be used as the debug name when
   /// creating a stage input/output variable.
-  uint32_t createSpirvStageVar(StageVar *, const DeclaratorDecl *decl,
+  uint32_t createSpirvStageVar(StageVar *, const NamedDecl *decl,
                                const llvm::StringRef name, SourceLocation);
 
   /// Returns true if all vk::builtin usages are valid.
-  bool validateVKBuiltins(const DeclaratorDecl *decl,
+  bool validateVKBuiltins(const NamedDecl *decl,
                           const hlsl::SigPoint *sigPoint);
 
   /// Methods for creating counter variables associated with the given decl.
@@ -595,7 +595,7 @@ private:
 
   /// Decorates varId of the given asType with proper interpolation modes
   /// considering the attributes on the given decl.
-  void decoratePSInterpolationMode(const DeclaratorDecl *decl, QualType asType,
+  void decoratePSInterpolationMode(const NamedDecl *decl, QualType asType,
                                    uint32_t varId);
 
   /// Returns the proper SPIR-V storage class (Input or Output) for the given

+ 71 - 0
tools/clang/test/CodeGenSPIRV/oo.inheritance.stage-io.hlsl

@@ -0,0 +1,71 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct S {
+    float4 m : MMM;
+};
+
+struct T {
+    float3 n : NNN;
+};
+
+struct Base {
+    float4 a : AAA;
+    float4 b : BBB;
+    S      s;
+    float4 p : SV_Position;
+};
+
+struct Derived : Base {
+    T      t;
+    float4 c : CCC;
+    float4 d : DDD;
+};
+
+void main(in Derived input, out Derived output) {
+// CHECK:         [[a:%\d+]] = OpLoad %v4float %in_var_AAA
+// CHECK-NEXT:    [[b:%\d+]] = OpLoad %v4float %in_var_BBB
+
+// CHECK-NEXT:    [[m:%\d+]] = OpLoad %v4float %in_var_MMM
+// CHECK-NEXT:    [[s:%\d+]] = OpCompositeConstruct %S [[m]]
+
+// CHECK-NEXT:  [[pos:%\d+]] = OpLoad %v4float %in_var_SV_Position
+
+// CHECK-NEXT: [[base:%\d+]] = OpCompositeConstruct %Base [[a]] [[b]] [[s]] [[pos]]
+
+// CHECK-NEXT:    [[n:%\d+]] = OpLoad %v3float %in_var_NNN
+// CHECK-NEXT:    [[t:%\d+]] = OpCompositeConstruct %T [[n]]
+
+// CHECK-NEXT:    [[c:%\d+]] = OpLoad %v4float %in_var_CCC
+// CHECK-NEXT:    [[d:%\d+]] = OpLoad %v4float %in_var_DDD
+
+// CHECK-NEXT:  [[drv:%\d+]] = OpCompositeConstruct %Derived [[base]] [[t]] [[c]] [[d]]
+// CHECK-NEXT:                 OpStore %param_var_input [[drv]]
+
+// CHECK-NEXT:      {{%\d+}} = OpFunctionCall %void %src_main %param_var_input %param_var_output
+
+// CHECK-NEXT:  [[drv:%\d+]] = OpLoad %Derived %param_var_output
+
+// CHECK-NEXT: [[base:%\d+]] = OpCompositeExtract %Base [[drv]] 0
+// CHECK-NEXT:    [[a:%\d+]] = OpCompositeExtract %v4float [[base]] 0
+// CHECK-NEXT:                 OpStore %out_var_AAA [[a]]
+// CHECK-NEXT:    [[b:%\d+]] = OpCompositeExtract %v4float [[base]] 1
+// CHECK-NEXT:                 OpStore %out_var_BBB [[b]]
+
+// CHECK-NEXT:    [[s:%\d+]] = OpCompositeExtract %S [[base]] 2
+// CHECK-NEXT:    [[m:%\d+]] = OpCompositeExtract %v4float [[s]] 0
+// CHECK-NEXT:                 OpStore %out_var_MMM [[m]]
+
+// CHECK-NEXT:  [[pos:%\d+]] = OpCompositeExtract %v4float [[base]] 3
+// CHECK-NEXT:  [[ptr:%\d+]] = OpAccessChain %_ptr_Output_v4float %gl_PerVertexOut %uint_0
+// CHECK-NEXT:                 OpStore [[ptr]] [[pos]]
+
+// CHECK-NEXT:    [[t:%\d+]] = OpCompositeExtract %T [[drv]] 1
+// CHECK-NEXT:    [[n:%\d+]] = OpCompositeExtract %v3float [[t]] 0
+// CHECK-NEXT:                 OpStore %out_var_NNN [[n]]
+
+// CHECK-NEXT:    [[c:%\d+]] = OpCompositeExtract %v4float [[drv]] 2
+// CHECK-NEXT:                 OpStore %out_var_CCC [[c]]
+// CHECK-NEXT:    [[d:%\d+]] = OpCompositeExtract %v4float [[drv]] 3
+// CHECK-NEXT:                 OpStore %out_var_DDD [[d]]
+    output = input;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -417,6 +417,9 @@ TEST_F(FileTest, MethodCallOnStaticVar) {
   runFileTest("oo.method.on-static-var.hlsl");
 }
 TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
+TEST_F(FileTest, InheritanceStageIO) {
+  runFileTest("oo.inheritance.stage-io.hlsl");
+}
 
 // For semantics
 // SV_Position, SV_ClipDistance, and SV_CullDistance are covered in