浏览代码

Add attribute to mark shader entry.

Xiang Li 8 年之前
父节点
当前提交
85063375b3

+ 4 - 0
docs/DXIL.rst

@@ -2888,6 +2888,10 @@ Modules and Linking
 ===================
 
 HLSL has linking capabilities to enable third-party libraries. The linking step happens before shader DXIL is given to the driver compilers.
+Experimental library generation is added in DXIL1.1. A library could be created by compile with lib_6_1 profile.
+A library is a dxil container like the compile result of other shader profiles. The difference is library will keep information for linking like resource link info and entry function signatures.
+Library support is not part of DXIL spec. Only requirement is linked shader must be valid DXIL.
+
 
 Additional Notes
 ================

+ 6 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -844,6 +844,12 @@ def HLSLGloballyCoherent : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+def HLSLShader : InheritableAttr {
+  let Spellings = [CXX11<"", "shader", 2017>];
+  let Args = [StringArgument<"stage">]; // one of compute, pixel, vertex, hull, domain, geomery
+  let Documentation = [Undocumented];
+}
+
 // HLSL Change Ends
 
 def C11NoReturn : InheritableAttr {

+ 43 - 8
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1059,10 +1059,52 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   if (isEntry)
     EntryFunc = F;
 
+  DiagnosticsEngine &Diags = CGM.getDiags();
+
   std::unique_ptr<DxilFunctionProps> funcProps =
       llvm::make_unique<DxilFunctionProps>();
-  // TODO: add attribute to mark shader entry.
   funcProps->shaderKind = DXIL::ShaderKind::Invalid;
+  bool isCS = false;
+  bool isGS = false;
+  bool isHS = false;
+  bool isDS = false;
+  bool isVS = false;
+  bool isPS = false;
+  if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
+    // Stage is already validate in HandleDeclAttributeForHLSL.
+    // Here just check first letter.
+    switch (Attr->getStage()[0]) {
+    case 'c':
+      isCS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Compute;
+      break;
+    case 'v':
+      isVS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Vertex;
+      break;
+    case 'h':
+      isHS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Hull;
+      break;
+    case 'd':
+      isDS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Domain;
+      break;
+    case 'g':
+      isGS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Geometry;
+      break;
+    case 'p':
+      isPS = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Pixel;
+      break;
+    default: {
+      unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error, "Invalid profile for shader attribute");
+      Diags.Report(Attr->getLocation(), DiagID);
+    } break;
+    }
+  }
 
   // Save patch constant function to patchConstantFunctionMap.
   bool isPatchConstantFunction = false;
@@ -1103,9 +1145,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->shaderKind = SM->GetKind();
   }
 
-  DiagnosticsEngine &Diags = CGM.getDiags();
   // Geometry shader.
-  bool isGS = false;
   if (const HLSLMaxVertexCountAttr *Attr =
           FD->getAttr<HLSLMaxVertexCountAttr>()) {
     isGS = true;
@@ -1138,7 +1178,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   }
 
   // Computer shader.
-  bool isCS = false;
   if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
     isCS = true;
     funcProps->shaderKind = DXIL::ShaderKind::Compute;
@@ -1156,7 +1195,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   }
 
   // Hull shader.
-  bool isHS = false;
   if (const HLSLPatchConstantFuncAttr *Attr =
           FD->getAttr<HLSLPatchConstantFuncAttr>()) {
     if (isEntry && !SM->IsHS()) {
@@ -1262,7 +1300,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   }
 
   // Hull or domain shader.
-  bool isDS = false;
   if (const HLSLDomainAttr *Attr = FD->getAttr<HLSLDomainAttr>()) {
     if (isEntry && !SM->IsHS() && !SM->IsDS()) {
       unsigned DiagID =
@@ -1284,7 +1321,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   }
 
   // Vertex shader.
-  bool isVS = false;
   if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
     if (isEntry && !SM->IsVS()) {
       unsigned DiagID = Diags.getCustomDiagID(
@@ -1300,7 +1336,6 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   }
 
   // Pixel shader.
-  bool isPS = false;
   if (const HLSLEarlyDepthStencilAttr *Attr =
           FD->getAttr<HLSLEarlyDepthStencilAttr>()) {
     if (isEntry && !SM->IsPS()) {

+ 1 - 0
tools/clang/lib/Parse/ParseDecl.cpp

@@ -706,6 +706,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
     case AttributeList::AT_HLSLLoop:
     case AttributeList::AT_HLSLMaxTessFactor:
     case AttributeList::AT_HLSLNumThreads:
+    case AttributeList::AT_HLSLShader:
     case AttributeList::AT_HLSLRootSignature:
     case AttributeList::AT_HLSLOutputControlPoints:
     case AttributeList::AT_HLSLOutputTopology:

+ 16 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10058,6 +10058,13 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
     declAttr = ::new (S.Context) HLSLPatchConstantFuncAttr(A.getRange(), S.Context,
       ValidateAttributeStringArg(S, A, nullptr), A.getAttributeSpellingListIndex());
     break;
+  case AttributeList::AT_HLSLShader:
+    declAttr = ::new (S.Context) HLSLShaderAttr(
+        A.getRange(), S.Context,
+        ValidateAttributeStringArg(S, A,
+                                   "compute,vertex,pixel,hull,domain,geometry"),
+        A.getAttributeSpellingListIndex());
+    break;
   case AttributeList::AT_HLSLMaxVertexCount:
 	  declAttr = ::new (S.Context) HLSLMaxVertexCountAttr(A.getRange(), S.Context,
 		  ValidateAttributeIntArg(S, A), A.getAttributeSpellingListIndex());
@@ -11056,6 +11063,15 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, con
     Out << "[patchconstantfunc(\"" << ACast->getFunctionName() << "\")]\n";
     break;
   }
+
+  case clang::attr::HLSLShader:
+  {
+    Attr * noconst = const_cast<Attr*>(A);
+    HLSLShaderAttr *ACast = static_cast<HLSLShaderAttr*>(noconst);
+    Indent(Indentation, Out);
+    Out << "[shader(\"" << ACast->getStage() << "\")]\n";
+    break;
+  }
   
   case clang::attr::HLSLMaxVertexCount:
   {

+ 197 - 0
tools/clang/test/CodeGenHLSL/lib_entries2.hlsl

@@ -0,0 +1,197 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// Make sure entry function exist.
+// CHECK: @cs_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.threadId
+// CHECK: dx.op.groupId
+
+// Make sure entry function exist.
+// CHECK: @gs_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.loadInput
+// CHECK: dx.op.storeOutput
+// CHECK: dx.op.emitStream
+// CHECK: dx.op.cutStream
+
+// Make sure entry function exist.
+// CHECK: @ds_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.loadPatchConstant
+// CHECK: dx.op.domainLocation
+// CHECK: dx.op.loadInput
+// CHECK: dx.op.storeOutput
+
+// Make sure patch constant function exist.
+// CHECK: HSPerPatchFunc
+// Make sure signatures are lowered.
+// CHECK: dx.op.storePatchConstant
+
+// Make sure entry function exist.
+// CHECK: @hs_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.outputControlPointID
+// CHECK: dx.op.loadInput
+// CHECK: dx.op.storeOutput
+
+// Make sure entry function exist.
+// CHECK: @vs_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.loadInput
+// CHECK: dx.op.storeOutput
+
+// Make sure entry function exist.
+// CHECK: @ps_main()
+// Make sure signatures are lowered.
+// CHECK: dx.op.loadInput
+// CHECK: dx.op.storeOutput
+// Finish ps_main
+// CHECK: ret void
+
+// Make sure cloned function signatures are not lowered.
+// CHECK-NOT: call float @dx.op.loadInput
+// CHECK-NOT: call void @dx.op.storeOutput
+
+
+// Make sure cloned function exist.
+// CHECK: @"\01?ps_main
+
+
+// Make sure function entrys exist.
+// CHECK: dx.func.signatures
+// Make sure cs don't have signature.
+// CHECK: @cs_main, null
+
+void StoreCSOutput(uint2 tid, uint2 gid);
+
+[shader("compute")]
+[numthreads(8,8,1)]
+void cs_main( uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID, uint2 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex )
+{
+    StoreCSOutput(tid, gid);
+}
+
+// GS
+
+struct GSOut {
+  float2 uv : TEXCOORD0;
+  float4 pos : SV_Position;
+};
+
+// geometry shader that outputs 3 vertices from a point
+[shader("geometry")]
+[maxvertexcount(3)]
+[instance(24)]
+void gs_main(InputPatch<GSOut, 2>points, inout PointStream<GSOut> stream) {
+
+  stream.Append(points[0]);
+
+  stream.RestartStrip();
+}
+
+// DS
+struct PSSceneIn {
+  float4 pos : SV_Position;
+  float2 tex : TEXCOORD0;
+  float3 norm : NORMAL;
+
+uint   RTIndex      : SV_RenderTargetArrayIndex;
+};
+
+struct HSPerVertexData {
+  // This is just the original vertex verbatim. In many real life cases this would be a
+  // control point instead
+  PSSceneIn v;
+};
+
+struct HSPerPatchData {
+  // We at least have to specify tess factors per patch
+  // As we're tesselating triangles, there will be 4 tess factors
+  // In real life case this might contain face normal, for example
+  float edges[3] : SV_TessFactor;
+  float inside : SV_InsideTessFactor;
+};
+
+// domain shader that actually outputs the triangle vertices
+[shader("domain")]
+[domain("tri")] PSSceneIn ds_main(const float3 bary
+                               : SV_DomainLocation,
+                                 const OutputPatch<HSPerVertexData, 3> patch,
+                                 const HSPerPatchData perPatchData) {
+  PSSceneIn v;
+
+  // Compute interpolated coordinates
+  v.pos = patch[0].v.pos * bary.x + patch[1].v.pos * bary.y + patch[2].v.pos * bary.z + perPatchData.edges[1];
+  v.tex = patch[0].v.tex * bary.x + patch[1].v.tex * bary.y + patch[2].v.tex * bary.z + perPatchData.edges[0];
+  v.norm = patch[0].v.norm * bary.x + patch[1].v.norm * bary.y + patch[2].v.norm * bary.z + perPatchData.inside;
+  v.RTIndex = 0;
+  return v;
+}
+
+// HS
+
+HSPerPatchData HSPerPatchFunc( const InputPatch< PSSceneIn, 3 > points, OutputPatch<HSPerVertexData, 3> outp)
+{
+    HSPerPatchData d;
+
+    d.edges[ 0 ] = 1;
+    d.edges[ 1 ] = 1;
+    d.edges[ 2 ] = 1;
+    d.inside = 1;
+
+    return d;
+}
+
+// hull per-control point shader
+[shader("hull")]
+[domain("tri")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_cw")]
+[patchconstantfunc("HSPerPatchFunc")]
+[outputcontrolpoints(3)]
+HSPerVertexData hs_main( const uint id : SV_OutputControlPointID,
+                               const InputPatch< PSSceneIn, 3 > points)
+{
+    HSPerVertexData v;
+
+    // Just forward the vertex
+    v.v = points[ id ];
+
+	return v;
+}
+
+// VS
+
+struct VS_INPUT
+{
+	float3 vPosition	: POSITION;
+	float3 vNormal		: NORMAL;
+	float2 vTexcoord	: TEXCOORD0;
+};
+
+struct VS_OUTPUT
+{
+	float3 vNormal		: NORMAL;
+	float2 vTexcoord	: TEXCOORD0;
+	float4 vPosition	: SV_POSITION;
+};
+
+
+[shader("vertex")]
+VS_OUTPUT vs_main(VS_INPUT Input)
+{
+	VS_OUTPUT Output;
+
+	Output.vPosition = float4( Input.vPosition, 1.0 );
+	Output.vNormal = Input.vNormal;
+	Output.vTexcoord = Input.vTexcoord;
+
+       return Output;
+}
+
+// PS
+[shader("pixel")]
+float4 ps_main(float4 a : A) : SV_TARGET
+{
+  return a;
+}

+ 9 - 0
tools/clang/test/CodeGenHLSL/shader_attr.hlsl

@@ -0,0 +1,9 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// CHECK:attribute 'shader' must have one of these values: compute,vertex,pixel,hull,domain,geometry
+
+[shader("lib")]
+float4 ps_main(float4 a : A) : SV_TARGET
+{
+  return a;
+}

+ 10 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -505,6 +505,7 @@ public:
   TEST_METHOD(CodeGenLibCsEntry2)
   TEST_METHOD(CodeGenLibCsEntry3)
   TEST_METHOD(CodeGenLibEntries)
+  TEST_METHOD(CodeGenLibEntries2)
   TEST_METHOD(CodeGenLibResource)
   TEST_METHOD(CodeGenLibUnusedFunc)
   TEST_METHOD(CodeGenLitInParen)
@@ -614,6 +615,7 @@ public:
   TEST_METHOD(CodeGenSelectObj5)
   TEST_METHOD(CodeGenSelfCopy)
   TEST_METHOD(CodeGenSelMat)
+  TEST_METHOD(CodeGenShaderAttr)
   TEST_METHOD(CodeGenShare_Mem_Dbg)
   TEST_METHOD(CodeGenShare_Mem_Phi)
   TEST_METHOD(CodeGenShare_Mem1)
@@ -2793,6 +2795,10 @@ TEST_F(CompilerTest, CodeGenLibEntries) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_entries.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenLibEntries2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_entries2.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenLibResource) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\lib_resource.hlsl");
 }
@@ -3241,6 +3247,10 @@ TEST_F(CompilerTest, CodeGenSelMat) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\selMat.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenShaderAttr) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\shader_attr.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenShare_Mem_Dbg) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\share_mem_dbg.hlsl");
 }