Browse Source

[spirv] Handle class template instance (#3687)

This change correctly generates struct type for the following template
instance:
```
template <typename T>
struct Foo { ... use T ... }

...

Foo<int>
Foo<float>
...
```

Fixes #3557
Jaebaek Seo 4 years ago
parent
commit
95176bf696

+ 12 - 1
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -720,7 +720,7 @@ void SpirvEmitter::doDecl(const Decl *decl) {
     doEnumDecl(enumDecl);
     doEnumDecl(enumDecl);
   } else if (const auto *classTemplateDecl =
   } else if (const auto *classTemplateDecl =
                  dyn_cast<ClassTemplateDecl>(decl)) {
                  dyn_cast<ClassTemplateDecl>(decl)) {
-    // nothing to do.
+    doClassTemplateDecl(classTemplateDecl);
   } else if (const auto *functionTemplateDecl =
   } else if (const auto *functionTemplateDecl =
                  dyn_cast<FunctionTemplateDecl>(decl)) {
                  dyn_cast<FunctionTemplateDecl>(decl)) {
     // nothing to do.
     // nothing to do.
@@ -1362,6 +1362,17 @@ void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
   }
   }
 }
 }
 
 
+void SpirvEmitter::doClassTemplateDecl(
+    const ClassTemplateDecl *classTemplateDecl) {
+  for (auto classTemplateSpecializationDeclItr :
+       classTemplateDecl->specializations()) {
+    if (const CXXRecordDecl *recordDecl =
+            dyn_cast<CXXRecordDecl>(&*classTemplateSpecializationDeclItr)) {
+      doRecordDecl(recordDecl);
+    }
+  }
+}
+
 void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
 void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
   // Ignore implict records
   // Ignore implict records
   // Somehow we'll have implicit records with:
   // Somehow we'll have implicit records with:

+ 1 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -82,6 +82,7 @@ private:
   void doFunctionDecl(const FunctionDecl *decl);
   void doFunctionDecl(const FunctionDecl *decl);
   void doVarDecl(const VarDecl *decl);
   void doVarDecl(const VarDecl *decl);
   void doRecordDecl(const RecordDecl *decl);
   void doRecordDecl(const RecordDecl *decl);
+  void doClassTemplateDecl(const ClassTemplateDecl *classTemplateDecl);
   void doEnumDecl(const EnumDecl *decl);
   void doEnumDecl(const EnumDecl *decl);
   void doHLSLBufferDecl(const HLSLBufferDecl *decl);
   void doHLSLBufferDecl(const HLSLBufferDecl *decl);
   void doImplicitDecl(const Decl *decl);
   void doImplicitDecl(const Decl *decl);

+ 51 - 0
tools/clang/test/CodeGenSPIRV/type.template.struct.template-instance.hlsl

@@ -0,0 +1,51 @@
+// Run: %dxc -T ps_6_0 -E main -enable-templates
+
+// The SPIR-V backend correctly handles the template instance `Foo<int>`.
+// The created template instance is ClassTemplateSpecializationDecl in AST.
+
+template <typename T>
+struct Foo {
+    static const T bar = 0;
+
+    T value;
+    void set(T value_) { value = value_; }
+    T get() { return value; }
+};
+
+void main() {
+// CHECK: [[bar_int:%\w+]] = OpVariable %_ptr_Private_int Private
+// CHECK: [[bar_float:%\w+]] = OpVariable %_ptr_Private_float Private
+
+// CHECK: OpStore [[bar_int]] %int_0
+// CHECK: OpStore [[bar_float]] %float_0
+
+    Foo<int>::bar;
+
+// CHECK: %x = OpVariable %_ptr_Function_int Function
+    int x;
+
+// CHECK: %y = OpVariable %_ptr_Function_Foo Function
+    Foo<float> y;
+
+// CHECK:       [[x:%\w+]] = OpLoad %int %x
+// CHECK: [[float_x:%\w+]] = OpConvertSToF %float [[x]]
+// CHECK:                    OpStore [[param_value_:%\w+]] [[float_x]]
+// CHECK:                    OpFunctionCall %void %Foo_set %y [[param_value_]]
+    y.set(x);
+
+// CHECK:     [[y_get:%\w+]] = OpFunctionCall %float %Foo_get %y
+// CHECK: [[y_get_int:%\w+]] = OpConvertFToS %int [[y_get]]
+// CHECK:                      OpStore %x [[y_get_int]]
+    x = y.get();
+}
+
+// CHECK:         %Foo_set = OpFunction
+// CHECK-NEXT: %param_this = OpFunctionParameter %_ptr_Function_Foo
+// CHECK-NEXT:     %value_ = OpFunctionParameter %_ptr_Function_float
+
+// CHECK:         %Foo_get = OpFunction %float
+// CHECK-NEXT: [[this:%\w+]] = OpFunctionParameter %_ptr_Function_Foo
+
+// CHECK: [[ptr_value:%\w+]] = OpAccessChain %_ptr_Function_float [[this]] %int_0
+// CHECK:     [[value:%\w+]] = OpLoad %float [[ptr_value]]
+// CHECK:                      OpReturnValue [[value]]

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -148,6 +148,9 @@ TEST_F(FileTest, TriangleStreamTypes) {
 TEST_F(FileTest, TemplateFunctionInstance) {
 TEST_F(FileTest, TemplateFunctionInstance) {
   runFileTest("type.template.function.template-instance.hlsl");
   runFileTest("type.template.function.template-instance.hlsl");
 }
 }
+TEST_F(FileTest, TemplateStructInstance) {
+  runFileTest("type.template.struct.template-instance.hlsl");
+}
 
 
 // For constants
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }