Browse Source

[spirv] Translate numthreads attribute for CS. (#580)

Ehsan 8 years ago
parent
commit
1e6a9afd47

+ 16 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -3349,6 +3349,22 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
       theBuilder.beginFunction(funcType, voidType, decl->getName());
   declIdMapper.setEntryFunctionId(entryFunctionId);
 
+  // Handle translation of numthreads attribute for compute shaders.
+  if (shaderModel.IsCS()) {
+    // Number of threads attributes are stored as integers. We cast them to
+    // uint32_t to pass to OpExecutionMode SPIR-V instruction.
+    if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
+      theBuilder.addExecutionMode(
+          entryFunctionId, spv::ExecutionMode::LocalSize,
+          {static_cast<uint32_t>(numThreadsAttr->getX()),
+           static_cast<uint32_t>(numThreadsAttr->getY()),
+           static_cast<uint32_t>(numThreadsAttr->getZ())});
+    } else {
+      theBuilder.addExecutionMode(entryFunctionId,
+                                  spv::ExecutionMode::LocalSize, {1, 1, 1});
+    }
+  }
+
   // The entry basic block.
   const uint32_t entryLabel = theBuilder.createBasicBlock();
   theBuilder.setInsertPoint(entryLabel);

+ 7 - 0
tools/clang/test/CodeGenSPIRV/attribute.numthreads.hlsl

@@ -0,0 +1,7 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: OpEntryPoint GLCompute %main "main"
+// CHECK: OpExecutionMode %main LocalSize 8 4 1
+
+[numthreads(8, 4, 1)]
+void main() {}

+ 6 - 0
tools/clang/test/CodeGenSPIRV/attribute.numthreads.missing.hlsl

@@ -0,0 +1,6 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: OpEntryPoint GLCompute %main "main"
+// CHECK: OpExecutionMode %main LocalSize 1 1 1
+
+void main() {}

+ 8 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -373,6 +373,14 @@ TEST_F(FileTest, IntrinsicsAsin) { runFileTest("intrinsics.asin.hlsl"); }
 TEST_F(FileTest, IntrinsicsAcos) { runFileTest("intrinsics.acos.hlsl"); }
 TEST_F(FileTest, IntrinsicsAtan) { runFileTest("intrinsics.atan.hlsl"); }
 
+// For attributes
+TEST_F(FileTest, AttributeNumThreads) {
+  runFileTest("attribute.numthreads.hlsl");
+}
+TEST_F(FileTest, AttributeMissingNumThreads) {
+  runFileTest("attribute.numthreads.missing.hlsl");
+}
+
 // Vulkan/SPIR-V specific
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }