浏览代码

Merged PR 108: Validate function params for libraries.

Validate function params for libraries.
Tex Riddell 7 年之前
父节点
当前提交
7fe628800d

+ 6 - 0
docs/DXIL.rst

@@ -2903,12 +2903,18 @@ CONTAINER.PARTMATCHES                    DXIL Container Parts must match Module
 CONTAINER.PARTMISSING                    DXIL Container requires certain parts, corresponding to module
 CONTAINER.PARTREPEATED                   DXIL Container must have only one of each part type
 CONTAINER.ROOTSIGNATUREINCOMPATIBLE      Root Signature in DXIL Container must be compatible with shader
+DECL.ATTRSTRUCT                          Attributes parameter must be struct type
 DECL.DXILFNEXTERN                        External function must be a DXIL function
 DECL.DXILNSRESERVED                      The DXIL reserved prefixes must only be used by built-in functions and types
+DECL.EXTRAARGS                           Extra arguments not allowed for shader functions
 DECL.FNATTRIBUTE                         Functions should only contain known function attributes
 DECL.FNFLATTENPARAM                      Function parameters must not use struct types
 DECL.FNISCALLED                          Functions can only be used by call instructions
 DECL.NOTUSEDEXTERNAL                     External declaration should not be used
+DECL.PARAMSTRUCT                         Callable function parameter must be struct type
+DECL.PAYLOADSTRUCT                       Payload parameter must be struct type
+DECL.RESOURCEINFNSIG                     Resources not allowed in function signatures
+DECL.SHADERRETURNVOID                    Shader functions must return void
 DECL.USEDEXTERNALFUNCTION                External function must be used
 DECL.USEDINTERNAL                        Internal declaration must be used
 FLOW.DEADLOOP                            Loop must have break

+ 2 - 0
include/dxc/HLSL/DxilUtil.h

@@ -90,6 +90,8 @@ namespace dxilutil {
   std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
     llvm::LLVMContext &Ctx, std::string &DiagStr);
   void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context);
+  // Returns true if type contains HLSL Object type (resource)
+  bool ContainsHLSLObjectType(llvm::Type *Ty);
 }
 
 }

+ 6 - 0
include/dxc/HLSL/DxilValidation.h

@@ -41,12 +41,18 @@ enum class ValidationRule : unsigned {
   ContainerRootSignatureIncompatible, // Root Signature in DXIL Container must be compatible with shader
 
   // Declaration
+  DeclAttrStruct, // Attributes parameter must be struct type
   DeclDxilFnExtern, // External function must be a DXIL function
   DeclDxilNsReserved, // The DXIL reserved prefixes must only be used by built-in functions and types
+  DeclExtraArgs, // Extra arguments not allowed for shader functions
   DeclFnAttribute, // Functions should only contain known function attributes
   DeclFnFlattenParam, // Function parameters must not use struct types
   DeclFnIsCalled, // Functions can only be used by call instructions
   DeclNotUsedExternal, // External declaration should not be used
+  DeclParamStruct, // Callable function parameter must be struct type
+  DeclPayloadStruct, // Payload parameter must be struct type
+  DeclResourceInFnSig, // Resources not allowed in function signatures
+  DeclShaderReturnVoid, // Shader functions must return void
   DeclUsedExternalFunction, // External function must be used
   DeclUsedInternal, // Internal declaration must be used
 

+ 25 - 0
lib/HLSL/DxilUtil.cpp

@@ -14,6 +14,7 @@
 #include "dxc/HLSL/DxilTypeSystem.h"
 #include "dxc/HLSL/DxilUtil.h"
 #include "dxc/HLSL/DxilModule.h"
+#include "dxc/HLSL/HLModule.h"
 #include "llvm/Bitcode/ReaderWriter.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
@@ -370,5 +371,29 @@ llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
     F->getEntryBlock().getFirstInsertionPt());
 }
 
+bool ContainsHLSLObjectType(llvm::Type *Ty) {
+  // Unwrap pointer/array
+  while (llvm::isa<llvm::PointerType>(Ty))
+    Ty = llvm::cast<llvm::PointerType>(Ty)->getPointerElementType();
+  while (llvm::isa<llvm::ArrayType>(Ty))
+    Ty = llvm::cast<llvm::ArrayType>(Ty)->getArrayElementType();
+
+  if (llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty)) {
+    if (ST->getName().startswith("dx.types."))
+      return true;
+    // TODO: How is this suppoed to check for Input/OutputPatch types if
+    // these have already been eliminated in function arguments during CG?
+    if (HLModule::IsHLSLObjectType(Ty))
+      return true;
+    // Otherwise, recurse elements of UDT
+    for (auto ETy : ST->elements()) {
+      if (ContainsHLSLObjectType(ETy))
+        return true;
+    }
+  }
+  return false;
+}
+
+
 }
 }

+ 74 - 23
lib/HLSL/DxilValidation.cpp

@@ -258,6 +258,12 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::DeclFnIsCalled: return "Function '%0' is used for something other than calling";
     case hlsl::ValidationRule::DeclFnFlattenParam: return "Type '%0' is a struct type but is used as a parameter in function '%1'";
     case hlsl::ValidationRule::DeclFnAttribute: return "Function '%0' contains invalid attribute '%1' with value '%2'";
+    case hlsl::ValidationRule::DeclResourceInFnSig: return "Function '%0' uses resource in function signature";
+    case hlsl::ValidationRule::DeclPayloadStruct: return "Argument '%0' must be a struct type for payload in shader function '%1'";
+    case hlsl::ValidationRule::DeclAttrStruct: return "Argument '%0' must be a struct type for attributes in shader function '%1'";
+    case hlsl::ValidationRule::DeclParamStruct: return "Argument '%0' must be a struct type for callable shader function '%1'";
+    case hlsl::ValidationRule::DeclExtraArgs: return "Extra argument '%0' not allowed for shader function '%1'";
+    case hlsl::ValidationRule::DeclShaderReturnVoid: return "Shader function '%0' must have void return type";
   }
   // VALRULE-TEXT:END
   llvm_unreachable("invalid value");
@@ -585,7 +591,7 @@ struct ValidationContext {
   }
 
   void EmitGlobalValueError(GlobalValue *GV, ValidationRule rule) {
-    EmitFormatError(rule, { dxilutil::DemangleFunctionName(GV->getName()) });
+    EmitFormatError(rule, { GV->getName() });
   }
 
   // This is the least desirable mechanism, as it has no context.
@@ -595,19 +601,27 @@ struct ValidationContext {
   }
 
   void FormatRuleText(std::string &ruleText, ArrayRef<StringRef> args) {
+    std::string escapedArg;
     // Consider changing const char * to StringRef
     for (unsigned i = 0; i < args.size(); i++) {
       std::string argIdx = "%" + std::to_string(i);
       StringRef pArg = args[i];
       if (pArg == "")
         pArg = "<null>";
+      if (pArg[0] == 1) {
+        escapedArg = "";
+        raw_string_ostream os(escapedArg);
+        dxilutil::PrintEscapedString(pArg, os);
+        os.flush();
+        pArg = escapedArg;
+      }
 
       std::string::size_type offset = ruleText.find(argIdx);
       if (offset == std::string::npos)
         continue;
 
       unsigned size = argIdx.size();
-      ruleText.replace(offset, size, args[i]);
+      ruleText.replace(offset, size, pArg);
     }
   }
 
@@ -3173,61 +3187,98 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
 static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
   if (F.isDeclaration()) {
     ValidateExternalFunction(&F, ValCtx);
+    if (F.isIntrinsic() || IsDxilFunction(&F))
+      return;
   } else {
-    bool isNoArgEntry = ValCtx.DxilMod.HasDxilFunctionProps(&F);
-    if (isNoArgEntry) {
-      switch (ValCtx.DxilMod.GetDxilFunctionProps(&F).shaderKind) {
+    DXIL::ShaderKind shaderKind = DXIL::ShaderKind::Library;
+    bool isShader = ValCtx.DxilMod.HasDxilFunctionProps(&F);
+    unsigned numUDTShaderArgs = 0;
+    if (isShader) {
+      shaderKind = ValCtx.DxilMod.GetDxilFunctionProps(&F).shaderKind;
+      switch (shaderKind) {
       case DXIL::ShaderKind::AnyHit:
-      case DXIL::ShaderKind::Callable:
       case DXIL::ShaderKind::ClosestHit:
+        numUDTShaderArgs = 2;
+        break;
       case DXIL::ShaderKind::Miss:
-        isNoArgEntry = false;
+      case DXIL::ShaderKind::Callable:
+        numUDTShaderArgs = 1;
         break;
       default:
-        isNoArgEntry = true;
         break;
       }
     } else {
-      isNoArgEntry = &F == ValCtx.DxilMod.GetEntryFunction();
-      isNoArgEntry |= &F == ValCtx.DxilMod.GetPatchConstantFunction();
+      isShader = ValCtx.DxilMod.IsPatchConstantShader(&F);
     }
+
     // Entry function should not have parameter.
-    if (!F.arg_empty() && isNoArgEntry)
-      ValCtx.EmitFormatError(ValidationRule::FlowFunctionCall,
-                             {F.getName().str()});
+    if (isShader && 0 == numUDTShaderArgs && !F.arg_empty())
+      ValCtx.EmitFormatError(ValidationRule::FlowFunctionCall, { F.getName() });
+
+    // Shader functions should return void.
+    if (isShader && !F.getReturnType()->isVoidTy())
+      ValCtx.EmitFormatError(ValidationRule::DeclShaderReturnVoid, { F.getName() });
 
     DxilFunctionAnnotation *funcAnnotation =
         ValCtx.DxilMod.GetTypeSystem().GetFunctionAnnotation(&F);
     if (!funcAnnotation) {
-      ValCtx.EmitFormatError(ValidationRule::MetaFunctionAnnotation,
-                             {F.getName().str()});
+      ValCtx.EmitFormatError(ValidationRule::MetaFunctionAnnotation, { F.getName() });
       return;
     }
 
+    auto ArgFormatError = [&](Argument &arg, ValidationRule rule) {
+      if (arg.hasName())
+        ValCtx.EmitFormatError(rule, { arg.getName().str(), F.getName() });
+      else
+        ValCtx.EmitFormatError(rule, { std::to_string(arg.getArgNo()), F.getName() });
+    };
+
     // Validate parameter type.
     for (auto &arg : F.args()) {
       Type *argTy = arg.getType();
       if (argTy->isPointerTy())
         argTy = argTy->getPointerElementType();
+
+      if (numUDTShaderArgs) {
+        if (arg.getArgNo() >= numUDTShaderArgs) {
+          ArgFormatError(arg, ValidationRule::DeclExtraArgs);
+          break;
+        }
+        if (!argTy->isStructTy()) {
+          ArgFormatError(arg,
+            shaderKind == DXIL::ShaderKind::Callable
+              ? ValidationRule::DeclParamStruct
+              : arg.getArgNo() == 0 ? ValidationRule::DeclPayloadStruct
+                                    : ValidationRule::DeclAttrStruct);
+          break;
+        }
+      }
+
       while (argTy->isArrayTy()) {
         argTy = argTy->getArrayElementType();
       }
 
       if (argTy->isStructTy() && !ValCtx.isLibProfile) {
-        if (arg.hasName())
-          ValCtx.EmitFormatError(
-              ValidationRule::DeclFnFlattenParam,
-              {arg.getName().str(), F.getName().str()});
-        else
-          ValCtx.EmitFormatError(ValidationRule::DeclFnFlattenParam,
-                                 {std::to_string(arg.getArgNo()),
-                                  F.getName().str()});
+        ArgFormatError(arg, ValidationRule::DeclFnFlattenParam);
         break;
       }
     }
 
     ValidateFunctionBody(&F, ValCtx);
   }
+
+  // function params & return type must not contain resources
+  if (dxilutil::ContainsHLSLObjectType(F.getReturnType())) {
+    ValCtx.EmitGlobalValueError(&F, ValidationRule::DeclResourceInFnSig);
+    return;
+  }
+  for (auto &Arg : F.args()) {
+    if (dxilutil::ContainsHLSLObjectType(Arg.getType())) {
+      ValCtx.EmitGlobalValueError(&F, ValidationRule::DeclResourceInFnSig);
+      return;
+    }
+  }
+
   // TODO: Remove attribute for lib?
   if (!ValCtx.isLibProfile)
     ValidateFunctionAttribute(&F, ValCtx);

+ 2 - 25
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -4396,29 +4396,6 @@ void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr(
   
 }
 
-static bool ContainsDisallowedTypeForExport(llvm::Type *Ty) {
-  // Unwrap pointer/array
-  while (llvm::isa<llvm::PointerType>(Ty))
-    Ty = llvm::cast<llvm::PointerType>(Ty)->getPointerElementType();
-  while (llvm::isa<llvm::ArrayType>(Ty))
-    Ty = llvm::cast<llvm::ArrayType>(Ty)->getArrayElementType();
-
-  if (llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty)) {
-    if (ST->getName().startswith("dx.types."))
-      return true;
-    // TODO: How is this suppoed to check for Input/OutputPatch types if
-    // these have already been eliminated in function arguments during CG?
-    if (HLModule::IsHLSLObjectType(Ty))
-      return true;
-    // Otherwise, recurse elements of UDT
-    for (auto ETy : ST->elements()) {
-      if (ContainsDisallowedTypeForExport(ETy))
-        return true;
-    }
-  }
-  return false;
-}
-
 static void ReportDisallowedTypeInExportParam(CodeGenModule &CGM, StringRef name) {
   DiagnosticsEngine &Diags = CGM.getDiags();
   unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
@@ -4625,12 +4602,12 @@ void CGMSHLSLRuntime::FinishCodeGen() {
           !m_pHLModule->IsPatchConstantShader(&f) &&
           GetHLOpcodeGroup(&f) == HLOpcodeGroup::NotHL) {
         // Verify no resources in param/return types
-        if (ContainsDisallowedTypeForExport(f.getReturnType())) {
+        if (dxilutil::ContainsHLSLObjectType(f.getReturnType())) {
           ReportDisallowedTypeInExportParam(CGM, f.getName());
           continue;
         }
         for (auto &Arg : f.args()) {
-          if (ContainsDisallowedTypeForExport(Arg.getType())) {
+          if (dxilutil::ContainsHLSLObjectType(Arg.getType())) {
             ReportDisallowedTypeInExportParam(CGM, f.getName());
             break;
           }

+ 245 - 2
tools/clang/unittests/HLSL/ValidationTest.cpp

@@ -366,6 +366,14 @@ public:
   TEST_METHOD(ViewIDIn60Fail)
   TEST_METHOD(ViewIDNoSpaceFail)
 
+  TEST_METHOD(LibFunctionResInSig)
+  TEST_METHOD(RayPayloadIsStruct)
+  TEST_METHOD(RayAttrIsStruct)
+  TEST_METHOD(CallableParamIsStruct)
+  TEST_METHOD(RayShaderExtraArg)
+  TEST_METHOD(ResInShaderStruct)
+  TEST_METHOD(ShaderFunctionReturnTypeVoid)
+
   dxc::DxcDllSupport m_dllSupport;
   VersionSupportInfo m_ver;
 
@@ -3209,7 +3217,7 @@ TEST_F(ValidationTest, Float32DenormModeAttribute) {
     { "\"fp32-denorm-mode\"=\"ftz\"" },
     { "\"fp32-denorm-mode\"=\"invalid_mode\"" },
     "contains invalid attribute 'fp32-denorm-mode' with value 'invalid_mode'",
-    true);
+    false);
 }
 
 TEST_F(ValidationTest, ResCounter) {
@@ -3232,5 +3240,240 @@ TEST_F(ValidationTest, FunctionAttributes) {
     { "\"fp32-denorm-mode\"=\"ftz\"" },
     { "\"dummy_attribute\"=\"invalid_mode\"" },
     "contains invalid attribute 'dummy_attribute' with value 'invalid_mode'",
-    true);
+    false);
 }// TODO: reject non-zero padding
+
+TEST_F(ValidationTest, LibFunctionResInSig) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  RewriteAssemblyCheckMsg(
+    "Texture2D<float4> T1;\n"
+    "struct ResInStruct { float f; Texture2D<float4> T; };\n"
+    "struct ResStructInStruct { float f; ResInStruct S; };\n"
+    "ResStructInStruct fnResInReturn(float f) : SV_Target {\n"
+    "  ResStructInStruct S1; S1.f = S1.S.f = f; S1.S.T = T1;\n"
+    "  return S1; }\n"
+    "float fnResInArg(ResStructInStruct S1) : SV_Target {\n"
+    "  return S1.f; }\n"
+    "struct Data { float f; };\n"
+    "float fnStreamInArg(float f, inout PointStream<Data> S1) : SV_Target {\n"
+    "  S1.Append((Data)f); return 1.0; }\n"
+    , "lib_6_x",
+    { "!{!\"lib\", i32 6, i32 15}", "!dx.valver = !{!2}" },
+    { "!{!\"lib\", i32 6, i32 3}", "!dx.valver = !{!1002}\n!1002 = !{i32 1, i32 3}" },
+    {  "Function '\\01?fnResInReturn@@YA?AUResStructInStruct@@M@Z' uses resource in function signature"
+      ,"Function '\\01?fnResInArg@@YAMUResStructInStruct@@@Z' uses resource in function signature"
+      ,"Function '\\01?fnStreamInArg@@YAMMV?$PointStream@UData@@@@@Z' uses resource in function signature"
+      // TODO: Unable to lower stream append, since it's used in a non-GS function.
+      // Should we fail to compile earlier (even on lib_6_x), or add lowering to linker?
+      ,"Function 'dx.hl.op..void (i32, %\"class.PointStream<Data>\"*, float*)' uses resource in function signature"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, RayPayloadIsStruct) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  RewriteAssemblyCheckMsg(
+    "struct Payload { float f; }; struct Attributes { float2 b; };\n"
+    "[shader(\"anyhit\")] void AnyHitProto(inout Payload p, in Attributes a) { p.f += a.b.x; }\n"
+    "void BadAnyHit(inout float f, in Attributes a) { f += a.b.x; }\n"
+    "[shader(\"closesthit\")] void ClosestHitProto(inout Payload p, in Attributes a) { p.f += a.b.y; }\n"
+    "void BadClosestHit(inout float f, in Attributes a) { f += a.b.y; }\n"
+    "[shader(\"miss\")] void MissProto(inout Payload p) { p.f += 1.0; }\n"
+    "void BadMiss(inout float f) { f += 1.0; }\n"
+    , "lib_6_3",
+    { "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*)* @\"\\01?MissProto@@YAXUPayload@@@Z\", "
+        "!\"\\01?MissProto@@YAXUPayload@@@Z\","
+    },
+    { "!{void (float*, %struct.Attributes*)* @\"\\01?BadAnyHit@@YAXAIAMUAttributes@@@Z\", "
+        "!\"\\01?BadAnyHit@@YAXAIAMUAttributes@@@Z\",",
+      "!{void (float*, %struct.Attributes*)* @\"\\01?BadClosestHit@@YAXAIAMUAttributes@@@Z\", "
+        "!\"\\01?BadClosestHit@@YAXAIAMUAttributes@@@Z\",",
+      "!{void (float*)* @\"\\01?BadMiss@@YAXAIAM@Z\", "
+        "!\"\\01?BadMiss@@YAXAIAM@Z\","
+    },
+    {  "Argument 'f' must be a struct type for payload in shader function '\\01?BadAnyHit@@YAXAIAMUAttributes@@@Z'"
+      ,"Argument 'f' must be a struct type for payload in shader function '\\01?BadClosestHit@@YAXAIAMUAttributes@@@Z'"
+      ,"Argument 'f' must be a struct type for payload in shader function '\\01?BadMiss@@YAXAIAM@Z'"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, RayAttrIsStruct) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  RewriteAssemblyCheckMsg(
+    "struct Payload { float f; }; struct Attributes { float2 b; };\n"
+    "[shader(\"anyhit\")] void AnyHitProto(inout Payload p, in Attributes a) { p.f += a.b.x; }\n"
+    "void BadAnyHit(inout Payload p, in float a) { p.f += a; }\n"
+    "[shader(\"closesthit\")] void ClosestHitProto(inout Payload p, in Attributes a) { p.f += a.b.y; }\n"
+    "void BadClosestHit(inout Payload p, in float a) { p.f += a; }\n"
+    , "lib_6_3",
+    { "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\","
+    },
+    { "!{void (%struct.Payload*, float)* @\"\\01?BadAnyHit@@YAXUPayload@@M@Z\", "
+        "!\"\\01?BadAnyHit@@YAXUPayload@@M@Z\",",
+      "!{void (%struct.Payload*, float)* @\"\\01?BadClosestHit@@YAXUPayload@@M@Z\", "
+        "!\"\\01?BadClosestHit@@YAXUPayload@@M@Z\","
+    },
+    {  "Argument 'a' must be a struct type for attributes in shader function '\\01?BadAnyHit@@YAXUPayload@@M@Z'"
+      ,"Argument 'a' must be a struct type for attributes in shader function '\\01?BadClosestHit@@YAXUPayload@@M@Z'"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, CallableParamIsStruct) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  RewriteAssemblyCheckMsg(
+    "struct Param { float f; };\n"
+    "[shader(\"callable\")] void CallableProto(inout Param p) { p.f += 1.0; }\n"
+    "void BadCallable(inout float f) { f += 1.0; }\n"
+    , "lib_6_3",
+    { "!{void (%struct.Param*)* @\"\\01?CallableProto@@YAXUParam@@@Z\", "
+        "!\"\\01?CallableProto@@YAXUParam@@@Z\","
+    },
+    { "!{void (float*)* @\"\\01?BadCallable@@YAXAIAM@Z\", "
+        "!\"\\01?BadCallable@@YAXAIAM@Z\","
+    },
+    {  "Argument 'f' must be a struct type for callable shader function '\\01?BadCallable@@YAXAIAM@Z'"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, RayShaderExtraArg) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  RewriteAssemblyCheckMsg(
+    "struct Payload { float f; }; struct Attributes { float2 b; };\n"
+    "struct Param { float f; };\n"
+    "[shader(\"anyhit\")] void AnyHitProto(inout Payload p, in Attributes a) { p.f += a.b.x; }\n"
+    "[shader(\"closesthit\")] void ClosestHitProto(inout Payload p, in Attributes a) { p.f += a.b.y; }\n"
+    "[shader(\"miss\")] void MissProto(inout Payload p) { p.f += 1.0; }\n"
+    "[shader(\"callable\")] void CallableProto(inout Param p) { p.f += 1.0; }\n"
+    "void BadAnyHit(inout Payload p, in Attributes a, float f) { p.f += f; }\n"
+    "void BadClosestHit(inout Payload p, in Attributes a, float f) { p.f += f; }\n"
+    "void BadMiss(inout Payload p, float f) { p.f += f; }\n"
+    "void BadCallable(inout Param p, float f) { p.f += f; }\n"
+    , "lib_6_3",
+    { "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\"",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\"",
+      "!{void (%struct.Payload*)* @\"\\01?MissProto@@YAXUPayload@@@Z\", "
+        "!\"\\01?MissProto@@YAXUPayload@@@Z\"",
+      "!{void (%struct.Param*)* @\"\\01?CallableProto@@YAXUParam@@@Z\", "
+        "!\"\\01?CallableProto@@YAXUParam@@@Z\""
+    },
+    { "!{void (%struct.Payload*, %struct.Attributes*, float)* @\"\\01?BadAnyHit@@YAXUPayload@@UAttributes@@M@Z\", "
+        "!\"\\01?BadAnyHit@@YAXUPayload@@UAttributes@@M@Z\"",
+      "!{void (%struct.Payload*, %struct.Attributes*, float)* @\"\\01?BadClosestHit@@YAXUPayload@@UAttributes@@M@Z\", "
+        "!\"\\01?BadClosestHit@@YAXUPayload@@UAttributes@@M@Z\"",
+      "!{void (%struct.Payload*, float)* @\"\\01?BadMiss@@YAXUPayload@@M@Z\", "
+        "!\"\\01?BadMiss@@YAXUPayload@@M@Z\"",
+      "!{void (%struct.Param*, float)* @\"\\01?BadCallable@@YAXUParam@@M@Z\", "
+        "!\"\\01?BadCallable@@YAXUParam@@M@Z\""
+    },
+    {  "Extra argument 'f' not allowed for shader function '\\01?BadAnyHit@@YAXUPayload@@UAttributes@@M@Z'"
+      ,"Extra argument 'f' not allowed for shader function '\\01?BadClosestHit@@YAXUPayload@@UAttributes@@M@Z'"
+      ,"Extra argument 'f' not allowed for shader function '\\01?BadMiss@@YAXUPayload@@M@Z'"
+      ,"Extra argument 'f' not allowed for shader function '\\01?BadCallable@@YAXUParam@@M@Z'"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, ResInShaderStruct) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  // Verify resource not used in shader argument structure
+  RewriteAssemblyCheckMsg(
+    "struct ResInStruct { float f; Texture2D<float4> T; };\n"
+    "struct ResStructInStruct { float f; ResInStruct S; };\n"
+    "struct Payload { float f; }; struct Attributes { float2 b; };\n"
+    "[shader(\"anyhit\")] void AnyHitProto(inout Payload p, in Attributes a) { p.f += a.b.x; }\n"
+    "void BadAnyHit(inout ResStructInStruct p, in Attributes a) { p.f += a.b.x; }\n"
+    "[shader(\"closesthit\")] void ClosestHitProto(inout Payload p, in Attributes a) { p.f += a.b.y; }\n"
+    "void BadClosestHit(inout ResStructInStruct p, in Attributes a) { p.f += a.b.x; }\n"
+    "[shader(\"miss\")] void MissProto(inout Payload p) { p.f += 1.0; }\n"
+    "void BadMiss(inout ResStructInStruct p) { p.f += 1.0; }\n"
+    "struct Param { float f; };\n"
+    "[shader(\"callable\")] void CallableProto(inout Param p) { p.f += 1.0; }\n"
+    "void BadCallable(inout ResStructInStruct p) { p.f += 1.0; }\n"
+    , "lib_6_x",
+    { "!{!\"lib\", i32 6, i32 15}", "!dx.valver = !{!2}",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*)* @\"\\01?MissProto@@YAXUPayload@@@Z\", "
+        "!\"\\01?MissProto@@YAXUPayload@@@Z\",",
+      "!{void (%struct.Param*)* @\"\\01?CallableProto@@YAXUParam@@@Z\", "
+        "!\"\\01?CallableProto@@YAXUParam@@@Z\","
+    },
+    { "!{!\"lib\", i32 6, i32 3}", "!dx.valver = !{!1002}\n!1002 = !{i32 1, i32 3}",
+      "!{void (%struct.ResStructInStruct*, %struct.Attributes*)* @\"\\01?BadAnyHit@@YAXUResStructInStruct@@UAttributes@@@Z\", "
+        "!\"\\01?BadAnyHit@@YAXUResStructInStruct@@UAttributes@@@Z\",",
+      "!{void (%struct.ResStructInStruct*, %struct.Attributes*)* @\"\\01?BadClosestHit@@YAXUResStructInStruct@@UAttributes@@@Z\", "
+        "!\"\\01?BadClosestHit@@YAXUResStructInStruct@@UAttributes@@@Z\",",
+      "!{void (%struct.ResStructInStruct*)* @\"\\01?BadMiss@@YAXUResStructInStruct@@@Z\", "
+        "!\"\\01?BadMiss@@YAXUResStructInStruct@@@Z\",",
+      "!{void (%struct.ResStructInStruct*)* @\"\\01?BadCallable@@YAXUResStructInStruct@@@Z\", "
+        "!\"\\01?BadCallable@@YAXUResStructInStruct@@@Z\",",
+    },
+    {  "Function '\\01?BadAnyHit@@YAXUResStructInStruct@@UAttributes@@@Z' uses resource in function signature"
+      ,"Function '\\01?BadClosestHit@@YAXUResStructInStruct@@UAttributes@@@Z' uses resource in function signature"
+      ,"Function '\\01?BadMiss@@YAXUResStructInStruct@@@Z' uses resource in function signature"
+      ,"Function '\\01?BadCallable@@YAXUResStructInStruct@@@Z' uses resource in function signature"
+    },
+    false);
+}
+
+TEST_F(ValidationTest, ShaderFunctionReturnTypeVoid) {
+  if (m_ver.SkipDxilVersion(1, 3)) return;
+  // Verify resource not used in shader argument structure
+  RewriteAssemblyCheckMsg(
+    "struct Payload { float f; }; struct Attributes { float2 b; };\n"
+    "struct Param { float f; };\n"
+    "[shader(\"raygeneration\")] void RayGenProto() { return; }\n"
+    "[shader(\"anyhit\")] void AnyHitProto(inout Payload p, in Attributes a) { p.f += a.b.x; }\n"
+    "[shader(\"closesthit\")] void ClosestHitProto(inout Payload p, in Attributes a) { p.f += a.b.y; }\n"
+    "[shader(\"miss\")] void MissProto(inout Payload p) { p.f += 1.0; }\n"
+    "[shader(\"callable\")] void CallableProto(inout Param p) { p.f += 1.0; }\n"
+    "float BadRayGen() { return 1; }\n"
+    "float BadAnyHit(inout Payload p, in Attributes a) { return p.f; }\n"
+    "float BadClosestHit(inout Payload p, in Attributes a) { return p.f; }\n"
+    "float BadMiss(inout Payload p) { return p.f; }\n"
+    "float BadCallable(inout Param p) { return p.f; }\n"
+    , "lib_6_3",
+    { "!{void ()* @\"\\01?RayGenProto@@YAXXZ\", "
+        "!\"\\01?RayGenProto@@YAXXZ\",",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?AnyHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*, %struct.Attributes*)* @\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?ClosestHitProto@@YAXUPayload@@UAttributes@@@Z\",",
+      "!{void (%struct.Payload*)* @\"\\01?MissProto@@YAXUPayload@@@Z\", "
+        "!\"\\01?MissProto@@YAXUPayload@@@Z\",",
+      "!{void (%struct.Param*)* @\"\\01?CallableProto@@YAXUParam@@@Z\", "
+        "!\"\\01?CallableProto@@YAXUParam@@@Z\","
+    },
+    { "!{float ()* @\"\\01?BadRayGen@@YAMXZ\", "
+        "!\"\\01?BadRayGen@@YAMXZ\",",
+      "!{float (%struct.Payload*, %struct.Attributes*)* @\"\\01?BadAnyHit@@YAMUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?BadAnyHit@@YAMUPayload@@UAttributes@@@Z\",",
+      "!{float (%struct.Payload*, %struct.Attributes*)* @\"\\01?BadClosestHit@@YAMUPayload@@UAttributes@@@Z\", "
+        "!\"\\01?BadClosestHit@@YAMUPayload@@UAttributes@@@Z\",",
+      "!{float (%struct.Payload*)* @\"\\01?BadMiss@@YAMUPayload@@@Z\", "
+        "!\"\\01?BadMiss@@YAMUPayload@@@Z\",",
+      "!{float (%struct.Param*)* @\"\\01?BadCallable@@YAMUParam@@@Z\", "
+        "!\"\\01?BadCallable@@YAMUParam@@@Z\","
+    },
+    {  "Shader function '\\01?BadRayGen@@YAMXZ' must have void return type"
+      ,"Shader function '\\01?BadAnyHit@@YAMUPayload@@UAttributes@@@Z' must have void return type"
+      ,"Shader function '\\01?BadClosestHit@@YAMUPayload@@UAttributes@@@Z' must have void return type"
+      ,"Shader function '\\01?BadMiss@@YAMUPayload@@@Z' must have void return type"
+      ,"Shader function '\\01?BadCallable@@YAMUParam@@@Z' must have void return type"
+    },
+    false);
+}

+ 6 - 0
utils/hct/hctdb.py

@@ -2018,6 +2018,12 @@ class db_dxil(object):
         self.add_valrule_msg("Decl.FnIsCalled", "Functions can only be used by call instructions", "Function '%0' is used for something other than calling")
         self.add_valrule_msg("Decl.FnFlattenParam", "Function parameters must not use struct types", "Type '%0' is a struct type but is used as a parameter in function '%1'")
         self.add_valrule_msg("Decl.FnAttribute", "Functions should only contain known function attributes", "Function '%0' contains invalid attribute '%1' with value '%2'")
+        self.add_valrule_msg("Decl.ResourceInFnSig", "Resources not allowed in function signatures", "Function '%0' uses resource in function signature")
+        self.add_valrule_msg("Decl.PayloadStruct", "Payload parameter must be struct type", "Argument '%0' must be a struct type for payload in shader function '%1'")
+        self.add_valrule_msg("Decl.AttrStruct", "Attributes parameter must be struct type", "Argument '%0' must be a struct type for attributes in shader function '%1'")
+        self.add_valrule_msg("Decl.ParamStruct", "Callable function parameter must be struct type", "Argument '%0' must be a struct type for callable shader function '%1'")
+        self.add_valrule_msg("Decl.ExtraArgs", "Extra arguments not allowed for shader functions", "Extra argument '%0' not allowed for shader function '%1'")
+        self.add_valrule_msg("Decl.ShaderReturnVoid", "Shader functions must return void", "Shader function '%0' must have void return type")
 
         # Assign sensible category names and build up an enumeration description
         cat_names = {