Browse Source

[spirv] Hull shader output patch must be shared between threads (#3059)

The existing code uses the Function storage class for the output patch
of a hull shader that will be used by the patch constant function.
DirectX automatically transmits the output patch to the all threads,
which are identical between threads. Even though Vulkan does not have
the patch constant function, we can simulate it using the execution
barrier and the variable with WorkGroup storage class.
Jaebaek Seo 5 years ago
parent
commit
b94ee767cc

+ 14 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -3570,5 +3570,19 @@ void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
   astDecls[varDecl].instr = constVal;
 }
 
+SpirvInstruction *DeclResultIdMapper::createHullMainOutputPatch(
+    const ParmVarDecl *param, const QualType retType,
+    uint32_t numOutputControlPoints, SourceLocation loc) {
+  const QualType hullMainRetType = astContext.getConstantArrayType(
+      retType, llvm::APInt(32, numOutputControlPoints),
+      clang::ArrayType::Normal, 0);
+  SpirvInstruction *hullMainOutputPatch = spvBuilder.addModuleVar(
+      hullMainRetType, spv::StorageClass::Workgroup, false,
+      "temp.var.hullMainRetVal", llvm::None, loc);
+  assert(astDecls[param].instr == nullptr);
+  astDecls[param].instr = hullMainOutputPatch;
+  return hullMainOutputPatch;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 7 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -402,6 +402,13 @@ public:
   /// VarDecls (such as some ray tracing enums).
   void tryToCreateImplicitConstVar(const ValueDecl *);
 
+  /// \brief Creates a variable for hull shader output patch with Workgroup
+  /// storage class, and registers the SPIR-V variable for the given decl.
+  SpirvInstruction *createHullMainOutputPatch(const ParmVarDecl *param,
+                                              const QualType retType,
+                                              uint32_t numOutputControlPoints,
+                                              SourceLocation loc);
+
   /// Raytracing specific functions
   /// \brief Creates a ShaderRecordBufferNV block from the given decl.
   SpirvVariable *createShaderRecordBufferNV(const VarDecl *decl);

+ 17 - 10
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -56,11 +56,11 @@ bool hasSemantic(const DeclaratorDecl *decl,
   return false;
 }
 
-bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
+const ParmVarDecl *patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
   for (const auto *param : pcf->parameters())
     if (hlsl::IsHLSLOutputPatchType(param->getType()))
-      return true;
-  return false;
+      return param;
+  return nullptr;
 }
 
 inline bool isSpirvMatrixOp(spv::Op opcode) {
@@ -977,6 +977,13 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Create all parameters.
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
+    if (spvContext.isHS() && decl == patchConstFunc &&
+        hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
+      // Since the output patch used in hull shaders is translated to
+      // a variable with Workgroup storage class, there is no need
+      // to pass the variable as function parameter in SPIR-V.
+      continue;
+    }
     (void)declIdMapper.createFnParam(paramDecl);
   }
 
@@ -10889,12 +10896,9 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
   // If the patch constant function (PCF) takes the result of the Hull main
   // entry point, create a temporary function-scope variable and write the
   // results to it, so it can be passed to the PCF.
-  if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
-    const QualType hullMainRetType = astContext.getConstantArrayType(
-        retType, llvm::APInt(32, numOutputControlPoints),
-        clang::ArrayType::Normal, 0);
-    hullMainOutputPatch =
-        spvBuilder.addFnVar(hullMainRetType, locEnd, "temp.var.hullMainRetVal");
+  if (const auto *param = patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
+    hullMainOutputPatch = declIdMapper.createHullMainOutputPatch(
+        param, retType, numOutputControlPoints, locEnd);
     auto *tempLocation = spvBuilder.createAccessChain(
         retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
     spvBuilder.createStore(tempLocation, retVal, locEnd);
@@ -10956,7 +10960,10 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
     if (hlsl::IsHLSLInputPatchType(param->getType())) {
       pcfParams.push_back(hullMainInputPatch);
     } else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
-      pcfParams.push_back(hullMainOutputPatch);
+      // Since the output patch used in hull shaders is translated to
+      // a variable with Workgroup storage class, there is no need
+      // to pass the variable as function parameter in SPIR-V.
+      continue;
     } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
       if (!primitiveId) {
         primitiveId = createParmVarAndInitFromStageInputVar(param);

+ 5 - 6
tools/clang/test/CodeGenSPIRV/hs.pcf.output-patch.hlsl

@@ -6,21 +6,20 @@
 
 
 // CHECK:               %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
-// CHECK: %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Function %_arr_BEZIER_CONTROL_POINT_uint_16
-// CHECK:                                   [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16
+// CHECK: %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Workgroup %_arr_BEZIER_CONTROL_POINT_uint_16
+// CHECK:                                   [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT
+// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 Workgroup
 
 // CHECK:                    %main = OpFunction %void None {{%\d+}}
-// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16 Function
 
 // CHECK:              [[id:%\d+]] = OpLoad %uint %gl_InvocationID
 // CHECK:      [[mainResult:%\d+]] = OpFunctionCall %BEZIER_CONTROL_POINT %src_main %param_var_ip %param_var_i %param_var_PatchID
-// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Function_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
+// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Workgroup_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
 // CHECK:                            OpStore [[loc]] [[mainResult]]
 
-// CHECK:                 {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF %temp_var_hullMainRetVal
+// CHECK:                 {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF
 
 // CHECK:      %PCF = OpFunction %HS_CONSTANT_DATA_OUTPUT None [[fType]]
-// CHECK-NEXT:  %op = OpFunctionParameter %_ptr_Function__arr_BEZIER_CONTROL_POINT_uint_16
 
 HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
   HS_CONSTANT_DATA_OUTPUT Output;

+ 2 - 2
tools/clang/test/CodeGenSPIRV/method.input-output-patch.access.hlsl

@@ -42,12 +42,12 @@ HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
 
   uint x = 5;
 
-// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Function_v3float %op %uint_1 %int_0
+// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_v3float %temp_var_hullMainRetVal %uint_1 %int_0
 // CHECK-NEXT:          {{%\d+}} = OpLoad %v3float [[op_1_loc]]
   float3 out1pos = op[1].vPosition;
 
 // CHECK:             [[x:%\d+]] = OpLoad %uint %x
-// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Function_uint %op [[x]] %int_1
+// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_uint %temp_var_hullMainRetVal [[x]] %int_1
 // CHECK-NEXT:          {{%\d+}} = OpLoad %uint [[op_x_loc]]
   uint out5id = op[x].pointID;