Browse Source

[spirv] do not use WorkGroup storage class for hull shader output patch (#3271)

* [spirv] do not use WorkGroup storage class for hull shader output patch

* Code review
Jaebaek Seo 4 years ago
parent
commit
a30d76ab78

+ 39 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -2946,6 +2946,39 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
   return var;
 }
 
+SpirvVariable *DeclResultIdMapper::createSpirvIntermediateOutputStageVar(
+    const NamedDecl *decl, const llvm::StringRef name, QualType type) { // Creates and records a non-input stage variable of `type` for `decl`; returns nullptr if creation fails.
+  const auto *semantic = hlsl::Semantic::GetByName(name); // `name` is used both as the variable name and as the semantic string
+  SemanticInfo thisSemantic{name, semantic, name, 0, decl->getLocation()};
+  // Treat the variable as an output: deduce its SigPoint with asInput=false.
+  const auto *sigPoint =
+      deduceSigPoint(cast<DeclaratorDecl>(decl), /*asInput=*/false,
+                     spvContext.getCurrentShaderModelKind(), /*forPCF=*/false);
+  // NOTE(review): locCount is fixed at 1 regardless of `type` -- confirm.
+  StageVar stageVar(sigPoint, thisSemantic, decl->getAttr<VKBuiltInAttr>(),
+                    type, /*locCount=*/1);
+  SpirvVariable *varInstr =
+      createSpirvStageVar(&stageVar, decl, name, thisSemantic.loc);
+  // On failure nothing is added to stageVars or stageVarInstructions.
+  if (!varInstr)
+    return nullptr;
+  // Finish configuring stageVar first: push_back below stores a copy of it.
+  stageVar.setSpirvInstr(varInstr);
+  stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
+  stageVar.setIndexAttr(decl->getAttr<VKIndexAttr>());
+  stageVars.push_back(stageVar);
+
+  // Emit OpDecorate* instructions to link this stage variable with the HLSL
+  // semantic it is created for.
+  spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr());
+
+  // decl was already cast to DeclaratorDecl above, so this cast is safe.
+  // Remember which SPIR-V variable was created for this decl.
+  stageVarInstructions[cast<DeclaratorDecl>(decl)] = varInstr;
+
+  return varInstr;
+}
+
 SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     StageVar *stageVar, const NamedDecl *decl, const llvm::StringRef name,
     SourceLocation srcLoc) {
@@ -3662,15 +3695,15 @@ void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
   astDecls[varDecl].instr = constVal;
 }
 
-SpirvInstruction *DeclResultIdMapper::createHullMainOutputPatch(
-    const ParmVarDecl *param, const QualType retType,
-    uint32_t numOutputControlPoints, SourceLocation loc) {
+SpirvInstruction *
+DeclResultIdMapper::createHullMainOutputPatch(const ParmVarDecl *param,
+                                              const QualType retType,
+                                              uint32_t numOutputControlPoints) {
   const QualType hullMainRetType = astContext.getConstantArrayType(
       retType, llvm::APInt(32, numOutputControlPoints),
       clang::ArrayType::Normal, 0);
-  SpirvInstruction *hullMainOutputPatch = spvBuilder.addModuleVar(
-      hullMainRetType, spv::StorageClass::Workgroup, false,
-      "temp.var.hullMainRetVal", llvm::None, loc);
+  SpirvInstruction *hullMainOutputPatch = createSpirvIntermediateOutputStageVar(
+      param, "temp.var.hullMainRetVal", hullMainRetType);
   assert(astDecls[param].instr == nullptr);
   astDecls[param].instr = hullMainOutputPatch;
   return hullMainOutputPatch;

+ 7 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -410,12 +410,11 @@ public:
   /// VarDecls (such as some ray tracing enums).
   void tryToCreateImplicitConstVar(const ValueDecl *);
 
-  /// \brief Creates a variable for hull shader output patch with Workgroup
+  /// \brief Creates a variable for hull shader output patch with Output
   /// storage class, and registers the SPIR-V variable for the given decl.
   SpirvInstruction *createHullMainOutputPatch(const ParmVarDecl *param,
                                               const QualType retType,
-                                              uint32_t numOutputControlPoints,
-                                              SourceLocation loc);
+                                              uint32_t numOutputControlPoints);
 
   /// Raytracing specific functions
   /// \brief Creates a ShaderRecordBufferNV block from the given decl.
@@ -667,6 +666,11 @@ private:
                                      const llvm::StringRef name,
                                      SourceLocation);
 
+  // Creates an intermediate output stage variable `name` of type `asType` for
+  // `decl`; used in hull shaders since workgroup memory is not allowed there.
+  SpirvVariable *createSpirvIntermediateOutputStageVar(
+      const NamedDecl *decl, const llvm::StringRef name, QualType asType); // returns nullptr if the variable cannot be created
+
   /// Returns true if all vk:: attributes usages are valid.
   bool validateVKAttributes(const NamedDecl *decl);
 

+ 2 - 2
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1122,7 +1122,7 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
     if (spvContext.isHS() && decl == patchConstFunc &&
         hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
       // Since the output patch used in hull shaders is translated to
-      // a variable with Workgroup storage class, there is no need
+      // a variable with Output storage class, there is no need
       // to pass the variable as function parameter in SPIR-V.
       continue;
     }
@@ -11270,7 +11270,7 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF(
   // results to it, so it can be passed to the PCF.
   if (const auto *param = patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
     hullMainOutputPatch = declIdMapper.createHullMainOutputPatch(
-        param, retType, numOutputControlPoints, locEnd);
+        param, retType, numOutputControlPoints);
     auto *tempLocation = spvBuilder.createAccessChain(
         retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
     spvBuilder.createStore(tempLocation, retVal, locEnd);

+ 1 - 1
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -275,4 +275,4 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 //                OpStore %124 %123
 //         %125 = OpLoad %BEZIER_CONTROL_POINT %result
 //                OpReturnValue %125
-//                OpFunctionEnd
+//                OpFunctionEnd

+ 51 - 0
tools/clang/test/CodeGenSPIRV/hs.const.output-patch.hlsl

@@ -0,0 +1,51 @@
+// Run: %dxc -T hs_6_0 -E main
+
+struct HSCtrlPt { // Control point type: element of main's InputPatch and main's return type
+  float4 ctrlPt : CONTROLPOINT;
+};
+
+struct HSPatchConstData { // Value returned by HSPatchConstantFunc
+  float tessFactor[3] : SV_TessFactor;
+  float insideTessFactor[1] : SV_InsideTessFactor;
+  float4 constData : CONSTANTDATA; // filled with the sum of the three output control points
+};
+
+// CHECK: OpDecorate %temp_var_hullMainRetVal Location 2
+
+// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Output__arr_HSCtrlPt_uint_3 Output
+// CHECK:        [[invoc_id:%\d+]] = OpLoad %uint %gl_InvocationID
+// CHECK:        [[HSResult:%\d+]] = OpFunctionCall %HSCtrlPt %src_main
+// CHECK:         [[OutCtrl:%\d+]] = OpAccessChain %_ptr_Output_HSCtrlPt %temp_var_hullMainRetVal [[invoc_id]]
+// CHECK:                            OpStore [[OutCtrl]] [[HSResult]]
+
+HSPatchConstData HSPatchConstantFunc(const OutputPatch<HSCtrlPt, 3> input) { // patch constant function taking the output patch as a const parameter
+  HSPatchConstData data;
+
+// CHECK: [[OutCtrl0:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_0 %int_0
+// CHECK:   [[input0:%\d+]] = OpLoad %v4float [[OutCtrl0]]
+// CHECK: [[OutCtrl1:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_1 %int_0
+// CHECK:   [[input1:%\d+]] = OpLoad %v4float [[OutCtrl1]]
+// CHECK:      [[add:%\d+]] = OpFAdd %v4float [[input0]] [[input1]]
+// CHECK: [[OutCtrl2:%\d+]] = OpAccessChain %_ptr_Output_v4float %temp_var_hullMainRetVal %uint_2 %int_0
+// CHECK:   [[input2:%\d+]] = OpLoad %v4float [[OutCtrl2]]
+// CHECK:                     OpFAdd %v4float [[add]] [[input2]]
+  data.constData = input[0].ctrlPt + input[1].ctrlPt + input[2].ctrlPt;
+  // Every tessellation factor gets the same constant value.
+  data.tessFactor[0] = 3.0;
+  data.tessFactor[1] = 3.0;
+  data.tessFactor[2] = 3.0;
+  data.insideTessFactor[0] = 3.0;
+  return data;
+}
+
+[domain("tri")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_cw")]
+[outputcontrolpoints(3)]
+[patchconstantfunc("HSPatchConstantFunc")]
+[maxtessfactor(15)]
+HSCtrlPt main(InputPatch<HSCtrlPt, 3> input, uint CtrlPtID : SV_OutputControlPointID) { // hull shader entry point, emits one control point per call
+  HSCtrlPt data;
+  data.ctrlPt = input[CtrlPtID].ctrlPt; // pass the selected input control point through unchanged
+  return data;
+}

+ 5 - 5
tools/clang/test/CodeGenSPIRV/hs.pcf.output-patch.hlsl

@@ -5,16 +5,16 @@
 // Test: PCF takes the output (OutputPatch) of the main entry point function.
 
 
-// CHECK:               %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
-// CHECK: %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Workgroup %_arr_BEZIER_CONTROL_POINT_uint_16
-// CHECK:                                   [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT
-// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Workgroup__arr_BEZIER_CONTROL_POINT_uint_16 Workgroup
+// CHECK:             %_arr_BEZIER_CONTROL_POINT_uint_16 = OpTypeArray %BEZIER_CONTROL_POINT %uint_16
+// CHECK: %_ptr_Output__arr_BEZIER_CONTROL_POINT_uint_16 = OpTypePointer Output %_arr_BEZIER_CONTROL_POINT_uint_16
+// CHECK:                                 [[fType:%\d+]] = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT
+// CHECK: %temp_var_hullMainRetVal = OpVariable %_ptr_Output__arr_BEZIER_CONTROL_POINT_uint_16 Output
 
 // CHECK:                    %main = OpFunction %void None {{%\d+}}
 
 // CHECK:              [[id:%\d+]] = OpLoad %uint %gl_InvocationID
 // CHECK:      [[mainResult:%\d+]] = OpFunctionCall %BEZIER_CONTROL_POINT %src_main %param_var_ip %param_var_i %param_var_PatchID
-// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Workgroup_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
+// CHECK:             [[loc:%\d+]] = OpAccessChain %_ptr_Output_BEZIER_CONTROL_POINT %temp_var_hullMainRetVal [[id]]
 // CHECK:                            OpStore [[loc]] [[mainResult]]
 
 // CHECK:                 {{%\d+}} = OpFunctionCall %HS_CONSTANT_DATA_OUTPUT %PCF

+ 2 - 2
tools/clang/test/CodeGenSPIRV/method.input-output-patch.access.hlsl

@@ -42,12 +42,12 @@ HS_CONSTANT_DATA_OUTPUT PCF(OutputPatch<BEZIER_CONTROL_POINT, MAX_POINTS> op) {
 
   uint x = 5;
 
-// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_v3float %temp_var_hullMainRetVal %uint_1 %int_0
+// CHECK:      [[op_1_loc:%\d+]] = OpAccessChain %_ptr_Output_v3float %temp_var_hullMainRetVal %uint_1 %int_0
 // CHECK-NEXT:          {{%\d+}} = OpLoad %v3float [[op_1_loc]]
   float3 out1pos = op[1].vPosition;
 
 // CHECK:             [[x:%\d+]] = OpLoad %uint %x
-// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Workgroup_uint %temp_var_hullMainRetVal [[x]] %int_1
+// CHECK-NEXT: [[op_x_loc:%\d+]] = OpAccessChain %_ptr_Output_uint %temp_var_hullMainRetVal [[x]] %int_1
 // CHECK-NEXT:          {{%\d+}} = OpLoad %uint [[op_x_loc]]
   uint out5id = op[x].pointID;
 

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -2083,6 +2083,9 @@ TEST_F(FileTest, HullShaderPCFTakesViewId) {
 TEST_F(FileTest, HullShaderPCFTakesViewIdButMainDoesnt) {
   runFileTest("hs.pcf.view-id.2.hlsl");
 }
+TEST_F(FileTest, HullShaderConstOutputPatch) { // PCF takes a const OutputPatch parameter
+  runFileTest("hs.const.output-patch.hlsl");
+}
 // HS: for the structure of hull shaders
 TEST_F(FileTest, HullShaderStructure) { runFileTest("hs.structure.hlsl"); }