Browse Source

[spirv] Hull shader output patch must be shared between threads (#3059)

The existing code uses the Function storage class for the output patch
of a hull shader that will be used by the patch constant function.
DirectX automatically transmits the output patch to the all threads,
which are identical between threads. Even though Vulkan does not have
the patch constant function, we can simulate it using the execution
barrier and the variable with WorkGroup storage class.
Jaebaek Seo 5 năm trước cách đây
mục cha
commit
b94ee767cc

+ 14 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -3570,5 +3570,19 @@ void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
   astDecls[varDecl].instr = constVal;
   astDecls[varDecl].instr = constVal;
 }
 }
 
 
+SpirvInstruction *DeclResultIdMapper::createHullMainOutputPatch(
+    const ParmVarDecl *param, const QualType retType,
+    uint32_t numOutputControlPoints, SourceLocation loc) {
+  const QualType hullMainRetType = astContext.getConstantArrayType(
+      retType, llvm::APInt(32, numOutputControlPoints),
+      clang::ArrayType::Normal, 0);
+  SpirvInstruction *hullMainOutputPatch = spvBuilder.addModuleVar(
+      hullMainRetType, spv::StorageClass::Workgroup, false,
+      "temp.var.hullMainRetVal", llvm::None, loc);
+  assert(astDecls[param].instr == nullptr);
+  astDecls[param].instr = hullMainOutputPatch;
+  return hullMainOutputPatch;
+}
+
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang

+ 7 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -402,6 +402,13 @@ public:
   /// VarDecls (such as some ray tracing enums).
   /// VarDecls (such as some ray tracing enums).
   void tryToCreateImplicitConstVar(const ValueDecl *);
   void tryToCreateImplicitConstVar(const ValueDecl *);
 
 
+  /// \brief Creates a variable for hull shader output patch with Workgroup
+  /// storage class, and registers the SPIR-V variable for the given decl.
+  SpirvInstruction *createHullMainOutputPatch(const ParmVarDecl *param,
+                                              const QualType retType,
+                                              uint32_t numOutputControlPoints,
+                                              SourceLocation loc);
+
   /// Raytracing specific functions
   /// Raytracing specific functions
   /// \brief Creates a ShaderRecordBufferNV block from the given decl.
   /// \brief Creates a ShaderRecordBufferNV block from the given decl.
   SpirvVariable *createShaderRecordBufferNV(const VarDecl *decl);
   SpirvVariable *createShaderRecordBufferNV(const VarDecl *decl);

+ 17 - 10
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -56,11 +56,11 @@ bool hasSemantic(const DeclaratorDecl *decl,
   return false;
   return false;
 }
 }
 
 
-bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
+const ParmVarDecl *patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
   for (const auto *param : pcf->parameters())
   for (const auto *param : pcf->parameters())
     if (hlsl::IsHLSLOutputPatchType(param->getType()))
     if (hlsl::IsHLSLOutputPatchType(param->getType()))
-      return true;
-  return false;
+      return param;
+  return nullptr;
 }
 }
 
 
 inline bool isSpirvMatrixOp(spv::Op opcode) {
 inline bool isSpirvMatrixOp(spv::Op opcode) {
@@ -977,6 +977,13 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Create all parameters.
   // Create all parameters.
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
+    if (spvContext.isHS() && decl == patchConstFunc &&
+        hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
+      // Since the output patch used in hull shaders is translated to
+      // a variable with Workgroup storage class, there is no need
+      // to pass the variable as function parameter in SPIR-V.
+      continue;
+    }
     (void)declIdMapper.createFnParam(paramDecl);
     (void)declIdMapper.createFnParam(paramDecl);
   }
   }
 
 
@@ -10889,12 +10896,9 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
   // If the patch constant function (PCF) takes the result of the Hull main
   // If the patch constant function (PCF) takes the result of the Hull main
   // entry point, create a temporary function-scope variable and write the
   // entry point, create a temporary function-scope variable and write the
   // results to it, so it can be passed to the PCF.
   // results to it, so it can be passed to the PCF.
-  if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
-    const QualType hullMainRetType = astContext.getConstantArrayType(
-        retType, llvm::APInt(32, numOutputControlPoints),
-        clang::ArrayType::Normal, 0);
-    hullMainOutputPatch =
-        spvBuilder.addFnVar(hullMainRetType, locEnd, "temp.var.hullMainRetVal");
+  if (const auto *param = patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
+    hullMainOutputPatch = declIdMapper.createHullMainOutputPatch(
+        param, retType, numOutputControlPoints, locEnd);
     auto *tempLocation = spvBuilder.createAccessChain(
     auto *tempLocation = spvBuilder.createAccessChain(
         retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
         retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
     spvBuilder.createStore(tempLocation, retVal, locEnd);
     spvBuilder.createStore(tempLocation, retVal, locEnd);
@@ -10956,7 +10960,10 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
     if (hlsl::IsHLSLInputPatchType(param->getType())) {
     if (hlsl::IsHLSLInputPatchType(param->getType())) {
       pcfParams.push_back(hullMainInputPatch);
       pcfParams.push_back(hullMainInputPatch);
     } else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
     } else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
-      pcfParams.push_back(hullMainOutputPatch);
+      // Since the output patch used in hull shaders is translated to
+      // a variable with Workgroup storage class, there is no need
+      // to pass the variable as function parameter in SPIR-V.
+      continue;
     } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
     } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
       if (!primitiveId) {
       if (!primitiveId) {
         primitiveId = createParmVarAndInitFromStageInputVar(param);
         primitiveId = createParmVarAndInitFromStageInputVar(param);

+ 5 - 6
tools/clang/test/CodeGenSPIRV/hs.pcf.output-patch.hlsl

@@ -6,21 +6,20 @@
 
 
 
 
 // CHECK:               %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
 // CHECK:               %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
-// CHECK: %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Function %_arr_BEZIER_CONTROL_POINT_uint_16
-// CHECK:                                   [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16
+// CHECK: %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Workgroup %_arr_BEZIER_CONTROL_POINT_uint_16
+// CHECK:                                   [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT
+// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 Workgroup
 
 
 // CHECK:                    %main = OpFunction %void None {{%\d+}}
 // CHECK:                    %main = OpFunction %void None {{%\d+}}
-// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16 Function
 
 
 // CHECK:              [[id:%\d+]] = OpLoad %uint %gl_InvocationID
 // CHECK:              [[id:%\d+]] = OpLoad %uint %gl_InvocationID
 // CHECK:      [[mainResult:%\d+]] = OpFunctionCall %BEZIER_CONTROL_POINT %src_main %param_var_ip %param_var_i %param_var_PatchID
 // CHECK:      [[mainResult:%\d+]] = OpFunctionCall %BEZIER_CONTROL_POINT %src_main %param_var_ip %param_var_i %param_var_PatchID
-// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Function_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
+// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Workgroup_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
 // CHECK:                            OpStore [[loc]] [[mainResult]]
 // CHECK:                            OpStore [[loc]] [[mainResult]]
 
 
-// CHECK:                 {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF %temp_var_hullMainRetVal
+// CHECK:                 {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF
 
 
 // CHECK:      %PCF = OpFunction %HS_CONSTANT_DATA_OUTPUT None [[fType]]
 // CHECK:      %PCF = OpFunction %HS_CONSTANT_DATA_OUTPUT None [[fType]]
-// CHECK-NEXT:  %op = OpFunctionParameter %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16
 
 
 HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
 HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
   HS_CONSTANT_DATA_OUTPUT Output;
   HS_CONSTANT_DATA_OUTPUT Output;

+ 2 - 2
tools/clang/test/CodeGenSPIRV/method.input-output-patch.access.hlsl

@@ -42,12 +42,12 @@ HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
 
 
   uint x = 5;
   uint x = 5;
 
 
-// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Function_v3float %op %uint_1 %int_0
+// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_v3float %temp_var_hullMainRetVal %uint_1 %int_0
 // CHECK-NEXT:          {{%\d+}} = OpLoad %v3float [[op_1_loc]]
 // CHECK-NEXT:          {{%\d+}} = OpLoad %v3float [[op_1_loc]]
   float3 out1pos = op[1].vPosition;
   float3 out1pos = op[1].vPosition;
 
 
 // CHECK:             [[x:%\d+]] = OpLoad %uint %x
 // CHECK:             [[x:%\d+]] = OpLoad %uint %x
-// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Function_uint %op [[x]] %int_1
+// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_uint %temp_var_hullMainRetVal [[x]] %int_1
 // CHECK-NEXT:          {{%\d+}} = OpLoad %uint [[op_x_loc]]
 // CHECK-NEXT:          {{%\d+}} = OpLoad %uint [[op_x_loc]]
   uint out5id = op[x].pointID;
   uint out5id = op[x].pointID;