|
@@ -3349,6 +3349,22 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
|
|
|
theBuilder.beginFunction(funcType, voidType, decl->getName());
|
|
|
declIdMapper.setEntryFunctionId(entryFunctionId);
|
|
|
|
|
|
+ // Handle translation of numthreads attribute for compute shaders.
|
|
|
+ if (shaderModel.IsCS()) {
|
|
|
+ // Number of threads attributes are stored as integers. We cast them to
|
|
|
+ // uint32_t to pass to OpExecutionMode SPIR-V instruction.
|
|
|
+ if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
|
|
|
+ theBuilder.addExecutionMode(
|
|
|
+ entryFunctionId, spv::ExecutionMode::LocalSize,
|
|
|
+ {static_cast<uint32_t>(numThreadsAttr->getX()),
|
|
|
+ static_cast<uint32_t>(numThreadsAttr->getY()),
|
|
|
+ static_cast<uint32_t>(numThreadsAttr->getZ())});
|
|
|
+ } else {
|
|
|
+ theBuilder.addExecutionMode(entryFunctionId,
|
|
|
+ spv::ExecutionMode::LocalSize, {1, 1, 1});
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
// The entry basic block.
|
|
|
const uint32_t entryLabel = theBuilder.createBasicBlock();
|
|
|
theBuilder.setInsertPoint(entryLabel);
|