Преглед на файлове

spirv: fix GL semantic in inherited classes (#5144)

When a semantic was in an inherited class, it was missed and caused
issues later down the pipe. This was because we assumed semantic fields
would only be present at the top-level struct, ignoring inheritance.

This commit also changes the parameter type from `DeclaratorDecl` to
`NamedDecl` as it only requires methods up to `NamedDecl`. Loosening
this requirement allowed me to pass a CXXRecordDecl.

Fixes #5138

Signed-off-by: Nathan Gauër <[email protected]>
Nathan Gauër преди 2 години
родител
ревизия
cbe8ec28a3

+ 13 - 5
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -52,7 +52,7 @@ inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
 
 /// Returns true if the given declaration has a primitive type qualifier.
 /// Returns false otherwise.
-inline bool hasGSPrimitiveTypeQualifier(const DeclaratorDecl *decl) {
+inline bool hasGSPrimitiveTypeQualifier(const NamedDecl *decl) {
   return decl->hasAttr<HLSLTriangleAttr>() ||
          decl->hasAttr<HLSLTriangleAdjAttr>() ||
          decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
@@ -322,8 +322,8 @@ bool GlPerVertex::setClipCullDistanceType(SemanticIndexToTypeMap *typeMap,
   return true;
 }
 
-bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
-                                     QualType baseType, bool asInput) {
+bool GlPerVertex::doGlPerVertexFacts(const NamedDecl *decl, QualType baseType,
+                                     bool asInput) {
 
   llvm::StringRef semanticStr;
   const hlsl::Semantic *semantic = {};
@@ -332,13 +332,21 @@ bool GlPerVertex::doGlPerVertexFacts(const DeclaratorDecl *decl,
 
   if (!getStageVarSemantic(decl, &semanticStr, &semantic, &semanticIndex)) {
     if (baseType->isStructureType()) {
-      const auto *structDecl = baseType->getAs<RecordType>()->getDecl();
+      const auto *recordType = baseType->getAs<RecordType>();
+      const auto *recordDecl = recordType->getAsCXXRecordDecl();
       // Go through each field to see if there is any usage of
       // SV_ClipDistance/SV_CullDistance.
-      for (const auto *field : structDecl->fields()) {
+      for (const auto *field : recordDecl->fields()) {
         if (!doGlPerVertexFacts(field, field->getType(), asInput))
           return false;
       }
+
+      // We should also recursively go through each inherited class.
+      for (const auto &base : recordDecl->bases()) {
+        const auto *baseDecl = base.getType()->getAsCXXRecordDecl();
+        if (!doGlPerVertexFacts(baseDecl, base.getType(), asInput))
+          return false;
+      }
       return true;
     }
 

+ 1 - 2
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -129,8 +129,7 @@ private:
                   SourceLocation loc, SourceRange range = {});
 
   /// Internal implementation for recordClipCullDistanceDecl().
-  bool doGlPerVertexFacts(const DeclaratorDecl *decl, QualType type,
-                          bool asInput);
+  bool doGlPerVertexFacts(const NamedDecl *decl, QualType type, bool asInput);
 
   /// Returns whether the type is a scalar, vector, or array that contains
   /// only scalars with float type.

+ 24 - 0
tools/clang/test/CodeGenSPIRV/spirv.interface.ps.inheritance.sv_clipdistance.hlsl

@@ -0,0 +1,24 @@
+// RUN: %dxc -T ps_6_0 -E main
+
+struct Parent {
+  float clipDistance : SV_ClipDistance;
+};
+
+struct PSInput : Parent
+{ };
+
+float main(PSInput input) : SV_TARGET
+{
+// CHECK:  [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_float %gl_ClipDistance %uint_0
+// CHECK: [[load0:%\d+]] = OpLoad %float [[ptr0]]
+
+// CHECK: [[parent:%\d+]] = OpCompositeConstruct %Parent [[load0]]
+// CHECK:  [[input:%\d+]] = OpCompositeConstruct %PSInput [[parent]]
+
+
+// CHECK: [[access0:%\d+]] = OpAccessChain %_ptr_Function_Parent %input %uint_0
+// CHECK: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float [[access0]] %int_0
+// CHECK:   [[load1:%\d+]] = OpLoad %float [[access1]]
+    return input.clipDistance;
+}
+

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1734,6 +1734,9 @@ TEST_F(FileTest, SpirvStageIOInterfaceVSClipDistanceInvalidType) {
   runFileTest("spirv.interface.vs.clip_distance.type.error.hlsl",
               Expect::Failure);
 }
+TEST_F(FileTest, SpirvStageIOInterfacePSInheritanceSVClipDistance) {
+  runFileTest("spirv.interface.ps.inheritance.sv_clipdistance.hlsl");
+}
 
 TEST_F(FileTest, SpirvStageIOAliasBuiltIn) {
   runFileTest("spirv.interface.alias-builtin.hlsl");