Browse Source

Add RT shader types for function attribution in lib target.

Tex Riddell 7 years ago
parent
commit
e9ce3e97d1

+ 6 - 0
include/dxc/HLSL/DxilConstants.h

@@ -124,6 +124,12 @@ namespace DXIL {
     Domain,
     Compute,
     Library,
+    RayGeneration,
+    Intersection,
+    AnyHit,
+    ClosestHit,
+    Miss,
+    Callable,
     Invalid,
   };
 

+ 4 - 8
include/dxc/HLSL/DxilFunctionProps.h

@@ -65,14 +65,10 @@ struct DxilFunctionProps {
   bool IsDS() const     { return shaderKind == DXIL::ShaderKind::Domain; }
   bool IsCS() const     { return shaderKind == DXIL::ShaderKind::Compute; }
   bool IsGraphics() const {
-    switch (shaderKind) {
-    case DXIL::ShaderKind::Compute:
-    case DXIL::ShaderKind::Library:
-    case DXIL::ShaderKind::Invalid:
-      return false;
-    default:
-      return true;
-    }
+    return (shaderKind >= DXIL::ShaderKind::Pixel && shaderKind <= DXIL::ShaderKind::Domain);
+  }
+  bool IsRay() const {
+    return (shaderKind >= DXIL::ShaderKind::RayGeneration && shaderKind <= DXIL::ShaderKind::Callable);
   }
 };
 

+ 2 - 0
include/dxc/HLSL/DxilShaderModel.h

@@ -38,8 +38,10 @@ public:
   bool IsDS() const     { return m_Kind == Kind::Domain; }
   bool IsCS() const     { return m_Kind == Kind::Compute; }
   bool IsLib() const    { return m_Kind == Kind::Library; }
+  bool IsRay() const    { return m_Kind >= Kind::RayGeneration && m_Kind <= Kind::Callable; }
   bool IsValid() const;
   bool IsValidForDxil() const;
+  bool IsValidForModule() const;
 
   Kind GetKind() const      { return m_Kind; }
   unsigned GetMajor() const { return m_Major; }

+ 3 - 1
lib/HLSL/DxilLinker.cpp

@@ -510,7 +510,9 @@ DxilLinkJob::Link(std::pair<DxilFunctionLinkInfo *, DxilLib *> &entryLinkPair,
 
   DxilFunctionProps props = entryDM.GetDxilFunctionProps(entryFunc);
   if (props.shaderKind == DXIL::ShaderKind::Library ||
-      props.shaderKind == DXIL::ShaderKind::Invalid) {
+      props.shaderKind == DXIL::ShaderKind::Invalid ||
+      (props.shaderKind >= DXIL::ShaderKind::RayGeneration &&
+      props.shaderKind <= DXIL::ShaderKind::Callable)) {
     m_ctx.emitError(profile + Twine(kInvalidProfile));
     // Invalid profile.
     return nullptr;

+ 1 - 0
lib/HLSL/DxilModule.cpp

@@ -137,6 +137,7 @@ OP *DxilModule::GetOP() const { return m_pOP.get(); }
 void DxilModule::SetShaderModel(const ShaderModel *pSM) {
   DXASSERT(m_pSM == nullptr || (pSM != nullptr && *m_pSM == *pSM), "shader model must not change for the module");
   DXASSERT(pSM != nullptr && pSM->IsValidForDxil(), "shader model must be valid");
+  DXASSERT(pSM->IsValidForModule(), "shader model must be valid for top-level module use");
   m_pSM = pSM;
   m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->SetShaderModel(m_pSM);

+ 5 - 0
lib/HLSL/DxilShaderModel.cpp

@@ -63,6 +63,11 @@ bool ShaderModel::IsValidForDxil() const {
   return false;
 }
 
+bool ShaderModel::IsValidForModule() const {
+  // Ray tracing shader model should only be used on functions in a lib
+  return IsValid() && !IsRay();
+}
+
 const ShaderModel *ShaderModel::Get(unsigned Idx) {
   DXASSERT_NOMSG(Idx < kNumShaderModels - 1);
   if (Idx < kNumShaderModels - 1)

+ 45 - 7
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -1165,13 +1165,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   bool isDS = false;
   bool isVS = false;
   bool isPS = false;
+  bool isRay = false;
   if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
     // Stage is already validate in HandleDeclAttributeForHLSL.
-    // Here just check first letter.
+    // Here just check first letter (or two).
     switch (Attr->getStage()[0]) {
     case 'c':
-      isCS = true;
-      funcProps->shaderKind = DXIL::ShaderKind::Compute;
+      switch (Attr->getStage()[1]) {
+      case 'o':
+        isCS = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Compute;
+        break;
+      case 'l':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::ClosestHit;
+        break;
+      case 'a':
+        isRay = true;
+        funcProps->shaderKind = DXIL::ShaderKind::Callable;
+        break;
+      default:
+        break;
+      }
       break;
     case 'v':
       isVS = true;
@@ -1193,11 +1208,34 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       isPS = true;
       funcProps->shaderKind = DXIL::ShaderKind::Pixel;
       break;
-    default: {
+    case 'r':
+      isRay = true;
+      funcProps->shaderKind = DXIL::ShaderKind::RayGeneration;
+      break;
+    case 'i':
+      isRay = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Intersection;
+      break;
+    case 'a':
+      isRay = true;
+      funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
+      break;
+    case 'm':
+      isRay = true;
+      funcProps->shaderKind = DXIL::ShaderKind::Miss;
+      break;
+    default:
+      break;
+    }
+    if (funcProps->shaderKind == DXIL::ShaderKind::Invalid) {
       unsigned DiagID = Diags.getCustomDiagID(
-          DiagnosticsEngine::Error, "Invalid profile for shader attribute");
+        DiagnosticsEngine::Error, "Invalid profile for shader attribute");
+      Diags.Report(Attr->getLocation(), DiagID);
+    }
+    if (isEntry && isRay) {
+      unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error, "Ray function cannot be used as a global entry point");
       Diags.Report(Attr->getLocation(), DiagID);
-    } break;
     }
   }
 
@@ -1414,7 +1452,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->shaderKind = DXIL::ShaderKind::Pixel;
   }
 
-  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS;
+  const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay;
 
   // TODO: check this in front-end and report error.
   DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive");

+ 1 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10409,7 +10409,7 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
     declAttr = ::new (S.Context) HLSLShaderAttr(
         A.getRange(), S.Context,
         ValidateAttributeStringArg(S, A,
-                                   "compute,vertex,pixel,hull,domain,geometry"),
+                                   "compute,vertex,pixel,hull,domain,geometry,raygeneration,intersection,anyhit,closesthit,miss,callable"),
         A.getAttributeSpellingListIndex());
     break;
   case AttributeList::AT_HLSLMaxVertexCount: