Browse Source

Merged PR 77: Enable validation for lib.

Enable validation for lib.
TODO: add more lib validation rules.
Xiang_Li (XBox) 7 years ago
parent
commit
630879e980

+ 233 - 25
lib/HLSL/DxilValidation.cpp

@@ -15,6 +15,7 @@
 #include "dxc/HLSL/DxilModule.h"
 #include "dxc/HLSL/DxilShaderModel.h"
 #include "dxc/HLSL/DxilContainer.h"
+#include "dxc/hlsl/DxilFunctionProps.h"
 #include "dxc/Support/Global.h"
 #include "dxc/HLSL/DxilUtil.h"
 #include "dxc/HLSL/DxilInstructions.h"
@@ -350,8 +351,10 @@ struct ValidationContext {
   std::unordered_set<Function *> entryFuncCallSet;
   std::unordered_set<Function *> patchConstFuncCallSet;
   std::unordered_map<unsigned, bool> UavCounterIncMap;
+  std::unordered_map<Type *, DxilResourceBase *> ResTypeMap;
   bool hasOutputPosition[DXIL::kNumOutputStreams];
   bool hasViewID;
+  bool isLibProfile;
   unsigned OutputPositionMask[DXIL::kNumOutputStreams];
   std::vector<unsigned> outputCols;
   std::vector<unsigned> patchConstCols;
@@ -382,8 +385,31 @@ struct ValidationContext {
       hasOutputPosition[i] = false;
       OutputPositionMask[i] = 0;
     }
-    outputCols.resize(DxilMod.GetOutputSignature().GetElements().size(), 0);
-    patchConstCols.resize(DxilMod.GetPatchConstantSignature().GetElements().size(), 0);
+    isLibProfile = dxilModule.GetShaderModel()->IsLib();
+    if (!isLibProfile) {
+      outputCols.resize(DxilMod.GetOutputSignature().GetElements().size(), 0);
+      patchConstCols.resize(
+          DxilMod.GetPatchConstantSignature().GetElements().size(), 0);
+    } else {
+      auto collectResTy = [&](auto &ResTab) {
+        for (auto &Res : ResTab) {
+          Type *Ty = Res->GetGlobalSymbol()->getType()->getPointerElementType();
+          Ty = dxilutil::GetArrayEltTy(Ty);
+          ResTypeMap[Ty] = Res.get();
+        }
+      };
+      collectResTy(DxilMod.GetCBuffers());
+      collectResTy(DxilMod.GetUAVs());
+      collectResTy(DxilMod.GetSRVs());
+      collectResTy(DxilMod.GetSamplers());
+    }
+  }
+  DxilResourceBase *GetResFromTy(Type *Ty) {
+    auto it = ResTypeMap.find(Ty);
+    if (it == ResTypeMap.end())
+      return nullptr;
+    else
+      return it->second;
   }
 
   // Provide direct access to the raw_ostream in DiagPrinter.
@@ -686,8 +712,31 @@ static DXIL::SamplerKind GetSamplerKind(Value *samplerHandle, ValidationContext
 
   DxilInst_CreateHandle createHandle(cast<CallInst>(samplerHandle));
   if (!createHandle) {
-    ValCtx.EmitInstrError(cast<CallInst>(samplerHandle), ValidationRule::InstrHandleNotFromCreateHandle);
-    return DXIL::SamplerKind::Invalid;
+    auto EmitError = [&]() -> DXIL::SamplerKind {
+      ValCtx.EmitInstrError(cast<CallInst>(samplerHandle),
+                            ValidationRule::InstrHandleNotFromCreateHandle);
+      return DXIL::SamplerKind::Invalid;
+    };
+    if (!ValCtx.isLibProfile) {
+      return EmitError();
+    }
+
+    DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes(
+        cast<CallInst>(samplerHandle));
+    if (!createHandleFromRes) {
+      return EmitError();
+    }
+
+    DxilResourceBase *Res =
+        ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType());
+    if (!Res) {
+      return EmitError();
+    }
+    if (DxilSampler *S = dynamic_cast<DxilSampler *>(Res)) {
+      return S->GetSamplerKind();
+    } else {
+      return EmitError();
+    }
   }
 
   Value *resClass = createHandle.get_resourceClass();
@@ -747,8 +796,32 @@ static DXIL::ResourceKind GetResourceKindAndCompTy(Value *handle, DXIL::Componen
 
   DxilInst_CreateHandle createHandle(cast<CallInst>(handle));
   if (!createHandle) {
-    ValCtx.EmitInstrError(cast<CallInst>(handle), ValidationRule::InstrHandleNotFromCreateHandle);
-    return DXIL::ResourceKind::Invalid;
+    auto EmitError = [&]() -> DXIL::ResourceKind {
+      ValCtx.EmitInstrError(cast<CallInst>(handle),
+                            ValidationRule::InstrHandleNotFromCreateHandle);
+      return DXIL::ResourceKind::Invalid;
+    };
+    if (!ValCtx.isLibProfile) {
+      return EmitError();
+    }
+    DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes(
+        cast<CallInst>(handle));
+    if (!createHandleFromRes) {
+      return EmitError();
+    }
+    DxilResourceBase *res =
+        ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType());
+    if (!res) {
+      return EmitError();
+    }
+    // TODO: resIndex for Uav Counter.
+    if (DxilResource *Res = dynamic_cast<DxilResource *>(res)) {
+      CompTy = Res->GetCompType().GetKind();
+    } else {
+      return EmitError();
+    }
+    ResClass = res->GetClass();
+    return res->GetKind();
   }
 
   Value *resourceClass = createHandle.get_resourceClass();
@@ -1101,9 +1174,30 @@ static unsigned StoreValueToMask(ArrayRef<Value *> vals) {
 static int GetCBufSize(Value *cbHandle, ValidationContext &ValCtx) {
   DxilInst_CreateHandle createHandle(cast<CallInst>(cbHandle));
   if (!createHandle) {
-    ValCtx.EmitInstrError(cast<CallInst>(cbHandle),
-                          ValidationRule::InstrHandleNotFromCreateHandle);
-    return -1;
+    auto EmitError = [&]() -> int {
+      ValCtx.EmitInstrError(cast<CallInst>(cbHandle),
+                            ValidationRule::InstrHandleNotFromCreateHandle);
+      return -1;
+    };
+    if (!ValCtx.isLibProfile) {
+      return EmitError();
+    }
+    DxilInst_CreateHandleFromResourceStructForLib createHandleFromRes(
+        cast<CallInst>(cbHandle));
+    if (!createHandleFromRes) {
+      return EmitError();
+    }
+
+    DxilResourceBase *Res =
+        ValCtx.GetResFromTy(createHandleFromRes.get_Resource()->getType());
+    if (!Res) {
+      return EmitError();
+    }
+    if (DxilCBuffer *CB = dynamic_cast<DxilCBuffer *>(Res)) {
+      return CB->GetSize();
+    } else {
+      return EmitError();
+    }
   }
 
   Value *resourceClass = createHandle.get_resourceClass();
@@ -1180,10 +1274,21 @@ static unsigned GetNumVertices(DXIL::InputPrimitive inputPrimitive) {
   return InputPrimitiveVertexTab[primitiveIdx];
 }
 
+static void ValidateDxilOperationCallInLibProfile(CallInst *CI,
+                                                  DXIL::OpCode opcode,
+                                                  ValidationContext &ValCtx) {
+  // TODO: validation for lib profile.
+}
+
 static void ValidateDxilOperationCallInProfile(CallInst *CI,
                                                DXIL::OpCode opcode,
                                                const ShaderModel *pSM,
                                                ValidationContext &ValCtx) {
+  if (ValCtx.isLibProfile) {
+    ValidateDxilOperationCallInLibProfile(CI, opcode, ValCtx);
+    return;
+  }
+
   switch (opcode) {
   case DXIL::OpCode::LoadInput: {
     Value *inputID = CI->getArgOperand(DXIL::OperandIndex::kLoadInputIDOpIdx);
@@ -1962,7 +2067,7 @@ static bool IsDxilFunction(llvm::Function *F) {
 }
 
 static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) {
-  if (!IsDxilFunction(F)) {
+  if (!IsDxilFunction(F) && !ValCtx.isLibProfile) {
     ValCtx.EmitGlobalValueError(F, ValidationRule::DeclDxilFnExtern);
     return;
   }
@@ -2028,7 +2133,7 @@ static void ValidateExternalFunction(Function *F, ValidationContext &ValCtx) {
       continue;
     }
 
-    if (!ValidateOpcodeInProfile(dxilOpcode, pSM)) {
+    if (!ValCtx.isLibProfile && !ValidateOpcodeInProfile(dxilOpcode, pSM)) {
       // Opcode not available in profile.
       ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcode,
                                   {hlslOP->GetOpCodeName(dxilOpcode),
@@ -2137,6 +2242,9 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx) {
     }
     return true;
   }
+  // Lib profile allow all types except those hit ValidationRule::InstrDxilStructUser.
+  if (ValCtx.isLibProfile)
+    return true;
 
   if (Ty->isVectorTy()) {
     ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector);
@@ -2430,6 +2538,29 @@ static void ValidateFunctionMetadata(Function *F, ValidationContext &ValCtx) {
   }
 }
 
+static bool IsLLVMInstructionAllowedForLib(Instruction &I, ValidationContext &ValCtx) {
+  if (!ValCtx.isLibProfile)
+    return false;
+  switch (I.getOpcode()) {
+  case Instruction::InsertElement:
+  case Instruction::ExtractElement:
+    return true;
+  case Instruction::Unreachable:
+    if (Instruction *Prev = I.getPrevNode()) {
+      if (CallInst *CI = dyn_cast<CallInst>(Prev)) {
+        Function *F = CI->getCalledFunction();
+        if (IsDxilFunction(F) &&
+            F->hasFnAttribute(Attribute::AttrKind::NoReturn)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  default:
+    return false;
+  }
+}
+
 static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
   bool SupportsMinPrecision =
       ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
@@ -2446,8 +2577,10 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
 
       // Instructions must be allowed.
       if (!IsLLVMInstructionAllowed(I)) {
-        ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed);
-        continue;
+        if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) {
+          ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed);
+          continue;
+        }
       }
 
       // Instructions marked precise may not have minprecision arguments.
@@ -2499,9 +2632,15 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
       }
 
       for (Value *op : I.operands()) {
-        if (!isa<PHINode>(&I) && isa<UndefValue>(op)) {
-          ValCtx.EmitInstrError(&I,
-                                ValidationRule::InstrNoReadingUninitialized);
+        if (isa<UndefValue>(op)) {
+          bool legalUndef = isa<PHINode>(&I);
+          if (InsertElementInst *InsertInst = dyn_cast<InsertElementInst>(&I)) {
+            legalUndef = op == I.getOperand(0);
+          }
+
+          if (!legalUndef)
+            ValCtx.EmitInstrError(&I,
+                                  ValidationRule::InstrNoReadingUninitialized);
         } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(op)) {
           for (Value *opCE : CE->operands()) {
             if (isa<UndefValue>(opCE)) {
@@ -2640,7 +2779,7 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
             ToTy = ToTy->getArrayElementType();
           }
         }
-        if (isa<StructType>(FromTy) || isa<StructType>(ToTy)) {
+        if ((isa<StructType>(FromTy) || isa<StructType>(ToTy)) && !ValCtx.isLibProfile) {
           ValCtx.EmitInstrError(Cast, ValidationRule::InstrStructBitCast);
           continue;
         }
@@ -2688,7 +2827,25 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
   if (F.isDeclaration()) {
     ValidateExternalFunction(&F, ValCtx);
   } else {
-    if (!F.arg_empty())
+    bool isNoArgEntry = ValCtx.DxilMod.HasDxilFunctionProps(&F);
+    if (isNoArgEntry) {
+      switch (ValCtx.DxilMod.GetDxilFunctionProps(&F).shaderKind) {
+      case DXIL::ShaderKind::AnyHit:
+      case DXIL::ShaderKind::Callable:
+      case DXIL::ShaderKind::ClosestHit:
+      case DXIL::ShaderKind::Miss:
+        isNoArgEntry = false;
+        break;
+      default:
+        isNoArgEntry = true;
+        break;
+      }
+    } else {
+      isNoArgEntry = &F == ValCtx.DxilMod.GetEntryFunction();
+      isNoArgEntry |= &F == ValCtx.DxilMod.GetPatchConstantFunction();
+    }
+    // Entry function should not have parameter.
+    if (!F.arg_empty() && isNoArgEntry)
       ValCtx.EmitFormatError(ValidationRule::FlowFunctionCall,
                              {F.getName().str()});
 
@@ -2709,7 +2866,7 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
         argTy = argTy->getArrayElementType();
       }
 
-      if (argTy->isStructTy()) {
+      if (argTy->isStructTy() && !ValCtx.isLibProfile) {
         if (arg.hasName())
           ValCtx.EmitFormatError(
               ValidationRule::DeclFnFlattenParam,
@@ -2724,8 +2881,9 @@ static void ValidateFunction(Function &F, ValidationContext &ValCtx) {
 
     ValidateFunctionBody(&F, ValCtx);
   }
-
-  ValidateFunctionAttribute(&F, ValCtx);
+  // TODO: Remove attribute for lib?
+  if (!ValCtx.isLibProfile)
+    ValidateFunctionAttribute(&F, ValCtx);
 
   if (F.hasMetadata()) {
     ValidateFunctionMetadata(&F, ValCtx);
@@ -2737,6 +2895,22 @@ static void ValidateGlobalVariable(GlobalVariable &GV,
   bool isInternalGV =
       dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV);
 
+  if (ValCtx.isLibProfile) {
+    auto isResourceGlobal = [&](auto &ResTab) -> bool {
+      for (auto &Res : ResTab) {
+        if (Res->GetGlobalSymbol() == &GV)
+          return true;
+      }
+      return false;
+    };
+
+    bool isRes = isResourceGlobal(ValCtx.DxilMod.GetCBuffers());
+    isRes |= isResourceGlobal(ValCtx.DxilMod.GetUAVs());
+    isRes |= isResourceGlobal(ValCtx.DxilMod.GetSRVs());
+    isRes |= isResourceGlobal(ValCtx.DxilMod.GetSamplers());
+    isInternalGV |= isRes;
+  }
+
   if (!isInternalGV) {
     if (!GV.user_empty()) {
       bool hasInstructionUser = false;
@@ -2907,6 +3081,20 @@ static void ValidateTypeAnnotation(ValidationContext &ValCtx) {
   }
 }
 
+static bool IsLibMetadata(ValidationContext &ValCtx, StringRef name) {
+    if (!ValCtx.isLibProfile)
+        return false;
+  // Skip dx.func.props and dx.func.signatures for now.
+  // And these 2 need validation also.
+  // Or we merge them into Entry, and validate as entry.
+  const char * libMetaNames[] = {"dx.func.props","dx.func.signatures"};
+  for (const char *libName : libMetaNames) {
+    if (name.equals(libName))
+      return true;
+  }
+  return false;
+}
+
 static void ValidateMetadata(ValidationContext &ValCtx) {
   Module *pModule = &ValCtx.M;
   const std::string &target = pModule->getTargetTriple();
@@ -2926,8 +3114,11 @@ static void ValidateMetadata(ValidationContext &ValCtx) {
   for (auto &NamedMetaNode : pModule->named_metadata()) {
     if (!DxilModule::IsKnownNamedMetaData(NamedMetaNode)) {
       StringRef name = NamedMetaNode.getName();
-      if (!name.startswith_lower("llvm."))
+      if (IsLibMetadata(ValCtx, name))
+        continue;
+      if (!name.startswith_lower("llvm.")) {
         ValCtx.EmitFormatError(ValidationRule::MetaKnown, {name.str()});
+      }
       else {
         if (llvmNamedMeta.count(name) == 0) {
           ValCtx.EmitFormatError(ValidationRule::MetaKnown,
@@ -2964,6 +3155,10 @@ static void ValidateResourceOverlap(
     SpacesAllocator<unsigned, DxilResourceBase> &spaceAllocator,
     ValidationContext &ValCtx) {
   unsigned base = res.GetLowerBound();
+  if (ValCtx.isLibProfile && !res.IsAllocated()) {
+    // Skip unallocated resource for library.
+    return;
+  }
   unsigned size = res.GetRangeSize();
   unsigned space = res.GetSpaceID();
 
@@ -3005,6 +3200,9 @@ static void ValidateResource(hlsl::DxilResource &res,
   case DXIL::ResourceKind::Texture2DMS:
   case DXIL::ResourceKind::Texture2DMSArray:
     break;
+  case DXIL::ResourceKind::RTAccelerationStructure:
+    // TODO: check profile.
+    break;
   default:
     ValCtx.EmitResourceError(&res, ValidationRule::SmInvalidResourceKind);
     break;
@@ -3167,7 +3365,7 @@ static void ValidateResources(ValidationContext &ValCtx) {
   for (auto &uav : uavs) {
     if (uav->IsROV()) {
       hasROV = true;
-      if (!ValCtx.DxilMod.GetShaderModel()->IsPS()) {
+      if (!ValCtx.DxilMod.GetShaderModel()->IsPS() && !ValCtx.isLibProfile) {
         ValCtx.EmitResourceError(uav.get(), ValidationRule::SmROVOnlyInPS);
       }
     }
@@ -3224,6 +3422,10 @@ static void ValidateResources(ValidationContext &ValCtx) {
 }
 
 static void ValidateShaderFlags(ValidationContext &ValCtx) {
+  // TODO: validate flags foreach entry.
+  if (ValCtx.isLibProfile)
+    return;
+
   ShaderFlags calcFlags;
   ValCtx.DxilMod.CollectShaderFlagsForModule(calcFlags);
   const uint64_t mask = ShaderFlags::GetShaderFlagsRawForCollection();
@@ -3236,7 +3438,6 @@ static void ValidateShaderFlags(ValidationContext &ValCtx) {
   if (declaredFlagsRaw == calcFlagsRaw) {
     return;
   }
-
   ValCtx.EmitError(ValidationRule::MetaFlagsUsage);
   ValCtx.DiagStream() << "Flags declared=" << declaredFlagsRaw
                       << ", actual=" << calcFlagsRaw << "\n";
@@ -4546,6 +4747,11 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
     case DFCC_ShaderDebugName:
       continue;
 
+    // Lib part
+    case DFCC_RuntimeData:
+      // TODO: Validate RuntimeData.
+      break;
+
     case DFCC_Container:
     default:
       ValCtx.EmitFormatError(ValidationRule::ContainerPartInvalid, {szFourCC});
@@ -4586,7 +4792,9 @@ HRESULT ValidateDxilContainerParts(llvm::Module *pModule,
       }
     }
   } else {
-    ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, {"Pipeline State Validation"});
+    // Not for lib.
+    if (!ValCtx.isLibProfile)
+      ValCtx.EmitFormatError(ValidationRule::ContainerPartMissing, {"Pipeline State Validation"});
   }
 
   if (ValCtx.Failed) {

+ 6 - 2
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -388,8 +388,12 @@ public:
       // validator can be used as a fallback.
       bool produceFullContainer = !opts.CodeGenHighLevel && !opts.AstDump && !opts.OptDump && rootSigMajor == 0;
 
-      bool needsValidation = produceFullContainer && !opts.DisableValidation &&
-                             !opts.IsLibraryProfile();
+      bool needsValidation = produceFullContainer && !opts.DisableValidation;
+      // Disable validation for lib_6_1 and lib_6_2.
+      if (compiler.getCodeGenOpts().HLSLProfile == "lib_6_1" ||
+          compiler.getCodeGenOpts().HLSLProfile == "lib_6_2") {
+        needsValidation = false;
+      }
 
       if (needsValidation || (opts.CodeGenHighLevel && !opts.DisableValidation)) {
         UINT32 majorVer, minorVer;

+ 1 - 1
tools/clang/unittests/HLSL/DxilContainerTest.cpp

@@ -685,7 +685,7 @@ TEST_F(DxilContainerTest, CompileWhenOkThenCheckRDAT) {
     "float function1(float x, min12int i) {"
     "  return x + c_buf + b_buf.Load(x) + tex2[i].x; }"
     "float function2(float x) { return x + function_import(x); }"
-    "float function3(int i) {"
+    "void function3(int i) {"
     "  Foo f = consume_buf.Consume();"
     "  f.f2 += 0.5; append_buf.Append(f);"
     "  rov_buf.Store(i, f.i2.x);"