Browse Source

Validation: Remove flawed special case for MeshShader Payload type

Tex Riddell 5 years ago
parent
commit
a9fe6e19f5
1 changed files with 15 additions and 54 deletions
  1. 15 54
      lib/HLSL/DxilValidation.cpp

+ 15 - 54
lib/HLSL/DxilValidation.cpp

@@ -3200,7 +3200,9 @@ static void ValidateFunctionMetadata(Function *F, ValidationContext &ValCtx) {
 }
 
 static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &ValCtx) {
-  if (!ValCtx.isLibProfile)
+  if (!(ValCtx.isLibProfile ||
+        ValCtx.DxilMod.GetShaderModel()->IsMS() ||
+        ValCtx.DxilMod.GetShaderModel()->IsAS()))
     return false;
   switch (I.getOpcode()) {
   case Instruction::InsertElement:
@@ -3223,41 +3225,6 @@ static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &Va
   }
 }
 
-static bool IsFromMeshPayload(Instruction *I) {
-  unsigned opcode = I->getOpcode();
-  switch (opcode) {
-  case Instruction::Alloca: {
-    break;
-  }
-  case Instruction::GetElementPtr: {
-    Value *src0 = I->getOperand(0);
-    if (I = dyn_cast<Instruction>(src0)) {
-      return IsFromMeshPayload(I);
-    }
-    return false;
-  }
-  case Instruction::Store: {
-    Value *src1 = I->getOperand(1);
-    if (I = dyn_cast<Instruction>(src1)) {
-      return IsFromMeshPayload(I);
-    }
-    return false;
-  }
-  default:
-    return false;
-  }
-
-  for (auto user : I->users()) {
-    if (CallInst *CI = dyn_cast<CallInst>(user)) {
-      Function *func = CI->getCalledFunction();
-      StringRef funcName = func->getName();
-      if (funcName.startswith("dx.op.dispatchMesh"))
-        return true;
-    }
-  }
-  return false;
-}
-
 static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   bool SupportsMinPrecision =
       ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
@@ -3397,13 +3364,11 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
       unsigned opcode = I.getOpcode();
       switch (opcode) {
       case Instruction::Alloca: {
-        if (!IsFromMeshPayload(&I)) {
-          AllocaInst *AI = cast<AllocaInst>(&I);
-          // TODO: validate address space and alignment
-          Type *Ty = AI->getAllocatedType();
-          if (!ValidateType(Ty, ValCtx)) {
-            continue;
-          }
+        AllocaInst *AI = cast<AllocaInst>(&I);
+        // TODO: validate address space and alignment
+        Type *Ty = AI->getAllocatedType();
+        if (!ValidateType(Ty, ValCtx)) {
+          continue;
         }
       } break;
       case Instruction::ExtractValue: {
@@ -3426,20 +3391,16 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
         }
       } break;
       case Instruction::Store: {
-        if (!IsFromMeshPayload(&I)) {
-          StoreInst *SI = cast<StoreInst>(&I);
-          Type *Ty = SI->getValueOperand()->getType();
-          if (!ValidateType(Ty, ValCtx)) {
-            continue;
-          }
+        StoreInst *SI = cast<StoreInst>(&I);
+        Type *Ty = SI->getValueOperand()->getType();
+        if (!ValidateType(Ty, ValCtx)) {
+          continue;
         }
       } break;
       case Instruction::GetElementPtr: {
-        if (!IsFromMeshPayload(&I)) {
-          Type *Ty = I.getType()->getPointerElementType();
-          if (!ValidateType(Ty, ValCtx)) {
-            continue;
-          }
+        Type *Ty = I.getType()->getPointerElementType();
+        if (!ValidateType(Ty, ValCtx)) {
+          continue;
         }
         GetElementPtrInst *GEP = cast<GetElementPtrInst>(&I);
         bool allImmIndex = true;