Procházet zdrojové kódy

Move wave sensitive check from validation into DxilValidateWaveSensit… (#2640)

* Move wave sensitive check from validation into DxilValidateWaveSensitivity pass.
Xiang Li před 5 roky
rodič
revize
2f440d0462

+ 0 - 1
docs/DXIL.rst

@@ -3188,7 +3188,6 @@ TYPES.INTWIDTH                            Int type must be of valid width
 TYPES.NOMULTIDIM                          Only one dimension allowed for array type
 TYPES.NOPTRTOPTR                          Pointers to pointers, or pointers in structures are not allowed
 TYPES.NOVECTOR                            Vector types must not be present
-UNI.NOWAVESENSITIVEGRADIENT               Gradient operations are not affected by wave-sensitive data or control flow.
 ========================================= =======================================================================================================================================================================================================================================================================================================
 
 .. VALRULES-RST:END

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -70,6 +70,7 @@ namespace dxilutil {
   bool RemoveUnusedFunctions(llvm::Module &M, llvm::Function *EntryFunc,
                              llvm::Function *PatchConstantFunc, bool IsLib);
   void EmitErrorOnInstruction(llvm::Instruction *I, llvm::StringRef Msg);
+  void EmitWarningOnInstruction(llvm::Instruction *I, llvm::StringRef Msg);
   void EmitResMappingError(llvm::Instruction *Res);
   std::string FormatMessageAtLocation(const llvm::DebugLoc &DL, const llvm::Twine& Msg);
   llvm::Twine FormatMessageWithoutLocation(const llvm::Twine& Msg);

+ 3 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -109,6 +109,9 @@ void initializeResumePassesPass(llvm::PassRegistry&);
 void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry&);
 void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry&);
 
+ModulePass *createDxilValidateWaveSensitivityPass();
+void initializeDxilValidateWaveSensitivityPass(llvm::PassRegistry&);
+
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 
 }

+ 0 - 3
include/dxc/HLSL/DxilValidation.h

@@ -273,9 +273,6 @@ enum class ValidationRule : unsigned {
   TypesNoMultiDim, // Only one dimension allowed for array type
   TypesNoPtrToPtr, // Pointers to pointers, or pointers in structures are not allowed
   TypesNoVector, // Vector types must not be present
-
-  // Uniform analysis
-  UniNoWaveSensitiveGradient, // Gradient operations are not affected by wave-sensitive data or control flow.
 };
 // VALRULE-ENUM:END
 

+ 41 - 20
lib/DXIL/DxilUtil.cpp

@@ -210,50 +210,71 @@ std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
   return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
 }
 
+std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
+  std::string locString;
+  raw_string_ostream os(locString);
+  DL.print(os);
+  os << ": " << Msg;
+  return os.str();
+}
+
+Twine FormatMessageWithoutLocation(const Twine& Msg) {
+  return Msg + " Use /Zi for source location.";
+}
+
+static void EmitWarningOrErrorOnInstruction(Instruction *I, StringRef Msg,
+                                            bool bWarning);
+
 // If we don't have debug location and this is select/phi,
 // try recursing users to find instruction with debug info.
 // Only recurse phi/select and limit depth to prevent doing
 // too much work if no debug location found.
-static bool EmitErrorOnInstructionFollowPhiSelect(
-    Instruction *I, StringRef Msg, unsigned depth=0) {
+static bool EmitWarningOrErrorOnInstructionFollowPhiSelect(Instruction *I,
+                                                           StringRef Msg,
+                                                           bool bWarning,
+                                                           unsigned depth = 0) {
   if (depth > 4)
     return false;
   if (I->getDebugLoc().get()) {
-    EmitErrorOnInstruction(I, Msg);
+    EmitWarningOrErrorOnInstruction(I, Msg, bWarning);
     return true;
   }
   if (isa<PHINode>(I) || isa<SelectInst>(I)) {
     for (auto U : I->users())
       if (Instruction *UI = dyn_cast<Instruction>(U))
-        if (EmitErrorOnInstructionFollowPhiSelect(UI, Msg, depth+1))
+        if (EmitWarningOrErrorOnInstructionFollowPhiSelect(UI, Msg, bWarning,
+                                                           depth + 1))
           return true;
   }
   return false;
 }
 
-std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
-  std::string locString;
-  raw_string_ostream os(locString);
-  DL.print(os);
-  os << ": " << Msg;
-  return os.str();
-}
-
-Twine FormatMessageWithoutLocation(const Twine& Msg) {
-  return Msg + " Use /Zi for source location.";
-}
-
-void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
+static void EmitWarningOrErrorOnInstruction(Instruction *I, StringRef Msg,
+                                            bool bWarning) {
   const DebugLoc &DL = I->getDebugLoc();
   if (DL.get()) {
-    I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
+    if (bWarning)
+      I->getContext().emitWarning(FormatMessageAtLocation(DL, Msg));
+    else
+      I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
     return;
   } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
-    if (EmitErrorOnInstructionFollowPhiSelect(I, Msg))
+    if (EmitWarningOrErrorOnInstructionFollowPhiSelect(I, Msg, bWarning))
       return;
   }
 
-  I->getContext().emitError(FormatMessageWithoutLocation(Msg));
+  if (bWarning)
+    I->getContext().emitWarning(FormatMessageWithoutLocation(Msg));
+  else
+    I->getContext().emitError(FormatMessageWithoutLocation(Msg));
+}
+
+void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
+  EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/false);
+}
+
+void EmitWarningOnInstruction(Instruction *I, StringRef Msg) {
+  EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/true);
 }
 
 const StringRef kResourceMapErrorMsg =

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -114,6 +114,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilPromoteStaticResourcesPass(Registry);
     initializeDxilSimpleGVNHoistPass(Registry);
     initializeDxilTranslateRawBufferPass(Registry);
+    initializeDxilValidateWaveSensitivityPass(Registry);
     initializeDxilValueCachePass(Registry);
     initializeDynamicIndexingVectorToArrayPass(Registry);
     initializeEarlyCSELegacyPassPass(Registry);

+ 108 - 0
lib/HLSL/DxilPreparePasses.cpp

@@ -23,6 +23,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/PassManager.h"
@@ -922,3 +923,110 @@ ModulePass *llvm::createDxilEmitMetadataPass() {
 }
 
 INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", false, false)
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+const StringRef UniNoWaveSensitiveGradientErrMsg =
+    "Gradient operations are not affected by wave-sensitive data or control "
+    "flow.";
+
+class DxilValidateWaveSensitivity : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilValidateWaveSensitivity() : ModulePass(ID) {}
+
+  const char *getPassName() const override {
+    return "HLSL DXIL wave sensitiveity validation";
+  }
+
+  bool runOnModule(Module &M) override {
+    // Only check ps and lib profile.
+    DxilModule &DM = M.GetDxilModule();
+    const ShaderModel *pSM = DM.GetShaderModel();
+    if (!pSM->IsPS() && !pSM->IsLib())
+      return false;
+
+    SmallVector<CallInst *, 16> gradientOps;
+    SmallVector<CallInst *, 16> barriers;
+    SmallVector<CallInst *, 16> waveOps;
+
+    for (auto &F : M) {
+      if (!F.isDeclaration())
+        continue;
+
+      for (User *U : F.users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        if (!CI)
+          continue;
+        Function *FCalled = CI->getCalledFunction();
+        if (!FCalled || !FCalled->isDeclaration())
+          continue;
+
+        if (!hlsl::OP::IsDxilOpFunc(FCalled))
+          continue;
+
+        DXIL::OpCode dxilOpcode = hlsl::OP::GetDxilOpFuncCallInst(CI);
+
+        if (OP::IsDxilOpWave(dxilOpcode)) {
+          waveOps.emplace_back(CI);
+        }
+
+        if (OP::IsDxilOpGradient(dxilOpcode)) {
+          gradientOps.push_back(CI);
+        }
+
+        if (dxilOpcode == DXIL::OpCode::Barrier) {
+          barriers.push_back(CI);
+        }
+      }
+    }
+
+    // Skip if not have wave op.
+    if (waveOps.empty())
+      return false;
+
+    // Skip if no gradient op.
+    if (gradientOps.empty())
+      return false;
+
+    for (auto &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      SmallVector<CallInst *, 16> localGradientOps;
+      for (CallInst *CI : gradientOps) {
+        if (CI->getParent()->getParent() == &F)
+          localGradientOps.emplace_back(CI);
+      }
+
+      if (localGradientOps.empty())
+        continue;
+
+      PostDominatorTree PDT;
+      PDT.runOnFunction(F);
+      std::unique_ptr<WaveSensitivityAnalysis> WaveVal(
+          WaveSensitivityAnalysis::create(PDT));
+
+      WaveVal->Analyze(&F);
+      for (CallInst *op : localGradientOps) {
+        if (WaveVal->IsWaveSensitive(op)) {
+          dxilutil::EmitWarningOnInstruction(op,
+                                             UniNoWaveSensitiveGradientErrMsg);
+        }
+      }
+    }
+    return false;
+  }
+};
+
+}
+
+char DxilValidateWaveSensitivity::ID = 0;
+
+ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
+  return new DxilValidateWaveSensitivity();
+}
+
+INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)

+ 0 - 25
lib/HLSL/DxilValidation.cpp

@@ -274,7 +274,6 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::SmMaxMSSMSize: return "Total Thread Group Shared Memory storage is %0, exceeded %1";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSizeDeclared: return "For shader '%0', payload size %1 is greater than declared size of %2 bytes";
-    case hlsl::ValidationRule::UniNoWaveSensitiveGradient: return "Gradient operations are not affected by wave-sensitive data or control flow.";
     case hlsl::ValidationRule::FlowReducible: return "Execution flow must be reducible";
     case hlsl::ValidationRule::FlowNoRecusion: return "Recursion is not permitted";
     case hlsl::ValidationRule::FlowDeadLoop: return "Loop must have break";
@@ -2827,26 +2826,6 @@ static bool IsValueMinPrec(DxilModule &DxilMod, Value *V) {
   return Ty->isHalfTy();
 }
 
-static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<CallInst *> barriers, ValidationContext &ValCtx) {
-  // In the absence of wave operations, the wave validation effect need not happen.
-  // We haven't verified this is true at this point, but validation will fail
-  // later if the flags don't match in any case. Given that most shaders will
-  // not be using these wave operations, it's a reasonable cost saving.
-  if (!ValCtx.DxilMod.m_ShaderFlags.GetWaveOps()) {
-    return;
-  }
-
-    PostDominatorTree PDT;
-    PDT.runOnFunction(*F);
-  std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create(PDT));
-  WaveVal->Analyze(F);
-  for (CallInst *op : ops) {
-    if (WaveVal->IsWaveSensitive(op)) {
-      ValCtx.EmitInstrError(op, ValidationRule::UniNoWaveSensitiveGradient);
-    }
-  }
-}
-
 static void ValidateMsIntrinsics(Function *F,
                                  ValidationContext &ValCtx,
                                  CallInst *setMeshOutputCounts,
@@ -3533,10 +3512,6 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
     ValidateControlFlowHint(*b, ValCtx);
   }
 
-  if (!gradientOps.empty()) {
-    ValidateGradientOps(F, gradientOps, barriers, ValCtx);
-  }
-
   ValidateMsIntrinsics(F, ValCtx, setMeshOutputCounts, getMeshPayload);
 
   ValidateAsIntrinsics(F, ValCtx, dispatchMesh);

+ 1 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -644,6 +644,7 @@ void PassManagerBuilder::populateModulePassManager(
     MPM.add(createComputeViewIdStatePass());
     MPM.add(createDxilDeadFunctionEliminationPass());
     MPM.add(createNoPausePassesPass());
+    MPM.add(createDxilValidateWaveSensitivityPass());
     MPM.add(createDxilEmitMetadataPass());
   }
   // HLSL Change Ends.

+ 4 - 4
tools/clang/test/CodeGenHLSL/val-wave-failures-ps.hlsl

@@ -1,8 +1,8 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck -input=stderr %s
 
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
 
 float4 main(float4 p: SV_Position) : SV_Target {
   // cannot feed into ddx

+ 4 - 4
tools/clang/test/HLSLFileCheck/validation/val-wave-failures-ps.hlsl

@@ -1,8 +1,8 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck -input=stderr %s
 
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
 
 float4 main(float4 p: SV_Position) : SV_Target {
   // cannot feed into ddx

+ 6 - 1
utils/hct/hctdb.py

@@ -1983,6 +1983,11 @@ class db_dxil(object):
         add_pass('hlsl-translate-dxil-opcode-version', 'DxilTranslateRawBuffer', 'Translates one version of dxil to another', [])
         add_pass('hlsl-dxil-cleanup-addrspacecast', 'DxilCleanupAddrSpaceCast', 'HLSL DXIL Cleanup Address Space Cast (part of hlsl-dxilfinalize)', [])
         add_pass('dxil-fix-array-init', 'DxilFixConstArrayInitializer', 'Dxil Fix Array Initializer', [])
+        add_pass('hlsl-validate-wave-sensitivity', 'DxilValidateWaveSensitivity', 'HLSL DXIL wave sensitiveity validation', [])
+        add_pass('dxil-elim-vector', 'DxilEliminateVector', 'Dxil Eliminate Vectors', [])
+        add_pass('dxil-finalize-noops', 'DxilFinalizeNoops', 'Dxil Finalize Noops', [])
+        add_pass('dxil-insert-noops', 'DxilInsertNoops', 'Dxil Insert Noops', [])
+        add_pass('dxil-value-cache', 'DxilValueCache', 'Dxil Value Cache',[])
 
         category_lib="llvm"
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])
@@ -2479,7 +2484,7 @@ class db_dxil(object):
         #self.add_valrule("Uni.NoUniInDiv", "TODO - No instruction requiring uniform execution can be present in divergent block")
         #self.add_valrule("Uni.GradientFlow", "TODO - No divergent gradient operations inside flow control") # a bit more specific than the prior rule
         #self.add_valrule("Uni.ThreadSync", "TODO - Thread sync operation must be in non-varying flow control due to a potential race condition, adding a sync after reading any values controlling shader execution at this point")
-        self.add_valrule("Uni.NoWaveSensitiveGradient", "Gradient operations are not affected by wave-sensitive data or control flow.")
+        #self.add_valrule("Uni.NoWaveSensitiveGradient", "Gradient operations are not affected by wave-sensitive data or control flow.")
         
         self.add_valrule("Flow.Reducible", "Execution flow must be reducible")
         self.add_valrule("Flow.NoRecusion", "Recursion is not permitted")