Explorar el Código

[spirv] Handle RecordDecls defined inside DeclStmt (#855)

Ehsan hace 7 años
padre
commit
1d92f9a129

+ 21 - 15
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -312,21 +312,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
     } else if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
       doVarDecl(varDecl);
     } else if (auto *recordDecl = dyn_cast<RecordDecl>(decl)) {
-      // Ignore implict records
-      // Somehow we'll have implicit records with:
-      //   static const int Length = count;
-      // that can mess up with the normal CodeGen.
-      if (recordDecl->isImplicit())
-        continue;
-
-      // Handle each static member with inline initializer.
-      // Each static member has a corresponding VarDecl inside the
-      // RecordDecl. For those defined in the translation unit,
-      // their VarDecls do not have initializer.
-      for (auto *subDecl : recordDecl->decls())
-        if (auto *varDecl = dyn_cast<VarDecl>(subDecl))
-          if (varDecl->isStaticDataMember() && varDecl->hasInit())
-            doVarDecl(varDecl);
+      doRecordDecl(recordDecl);
     } else if (auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
       // This is a cbuffer/tbuffer decl.
 
@@ -410,6 +396,8 @@ void SPIRVEmitter::doDecl(const Decl *decl) {
     doFunctionDecl(funcDecl);
   } else if (dyn_cast<HLSLBufferDecl>(decl)) {
     llvm_unreachable("HLSLBufferDecl should not be handled here");
+  } else if (const auto *recordDecl = dyn_cast<RecordDecl>(decl)) {
+    doRecordDecl(recordDecl);
   } else {
     emitError("decl type %0 unimplemented", decl->getLocation())
         << decl->getDeclKindName();
@@ -736,6 +724,24 @@ void SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) {
   }
 }
 
+void SPIRVEmitter::doRecordDecl(const RecordDecl *recordDecl) {
+  // Ignore implict records
+  // Somehow we'll have implicit records with:
+  //   static const int Length = count;
+  // that can mess up with the normal CodeGen.
+  if (recordDecl->isImplicit())
+    return;
+
+  // Handle each static member with inline initializer.
+  // Each static member has a corresponding VarDecl inside the
+  // RecordDecl. For those defined in the translation unit,
+  // their VarDecls do not have initializer.
+  for (auto *subDecl : recordDecl->decls())
+    if (auto *varDecl = dyn_cast<VarDecl>(subDecl))
+      if (varDecl->isStaticDataMember() && varDecl->hasInit())
+        doVarDecl(varDecl);
+}
+
 void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   validateVKAttributes(decl);
 

+ 1 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -76,6 +76,7 @@ public:
 private:
   void doFunctionDecl(const FunctionDecl *decl);
   void doVarDecl(const VarDecl *decl);
+  void doRecordDecl(const RecordDecl *decl);
 
   void doBreakStmt(const BreakStmt *stmt);
   void doDiscardStmt(const DiscardStmt *stmt);

+ 45 - 0
tools/clang/test/CodeGenSPIRV/type.class.hlsl

@@ -0,0 +1,45 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK:      OpName %N "N"
+
+// CHECK:      OpName %S "S"
+// CHECK-NEXT: OpMemberName %S 0 "a"
+// CHECK-NEXT: OpMemberName %S 1 "b"
+// CHECK-NEXT: OpMemberName %S 2 "c"
+
+// CHECK:      OpName %T "T"
+// CHECK-NEXT: OpMemberName %T 0 "x"
+// CHECK-NEXT: OpMemberName %T 1 "y"
+// CHECK-NEXT: OpMemberName %T 2 "z"
+
+// CHECK:      %N = OpTypeStruct
+class N {};
+
+// CHECK:      %S = OpTypeStruct %uint %v4float %mat2v3float
+class S {
+  uint a;
+  float4 b;
+  float2x3 c;
+};
+
+// CHECK:      %T = OpTypeStruct %S %v3int %S
+class T {
+  S x;
+  int3 y;
+  S z;
+};
+
+void main() {
+  N n;
+  S s;
+  T t;
+
+// CHECK: %R = OpTypeStruct %v2float
+// CHECK: %r0 = OpVariable %_ptr_Function_R Function
+  class R {
+    float2 rVal;
+  } r0;
+
+// CHECK: %r1 = OpVariable %_ptr_Function_R Function
+  R r1;
+}

+ 18 - 9
tools/clang/test/CodeGenSPIRV/type.struct.hlsl

@@ -17,20 +17,29 @@ struct N {};
 
 // CHECK:      %S = OpTypeStruct %uint %v4float %mat2v3float
 struct S {
-    uint a;
-    float4 b;
-    float2x3 c;
+  uint a;
+  float4 b;
+  float2x3 c;
 };
 
 // CHECK:      %T = OpTypeStruct %S %v3int %S
 struct T {
-    S x;
-    int3 y;
-    S z;
+  S x;
+  int3 y;
+  S z;
 };
 
 void main() {
-    N n;
-    S s;
-    T t;
+  N n;
+  S s;
+  T t;
+
+// CHECK: %R = OpTypeStruct %v2float
+// CHECK: %r0 = OpVariable %_ptr_Function_R Function
+  struct R {
+    float2 rVal;
+  } r0;
+
+// CHECK: %r1 = OpVariable %_ptr_Function_R Function
+  R r1;
 }

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -46,6 +46,7 @@ TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 TEST_F(FileTest, VectorTypes) { runFileTest("type.vector.hlsl"); }
 TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
 TEST_F(FileTest, StructTypes) { runFileTest("type.struct.hlsl"); }
+TEST_F(FileTest, ClassTypes) { runFileTest("type.class.hlsl"); }
 TEST_F(FileTest, ArrayTypes) { runFileTest("type.array.hlsl"); }
 TEST_F(FileTest, TypedefTypes) { runFileTest("type.typedef.hlsl"); }
 TEST_F(FileTest, SamplerTypes) { runFileTest("type.sampler.hlsl"); }