Browse Source

Move wave sensitive check from validation into DxilValidateWaveSensit… (#2640)

* Move wave sensitive check from validation into DxilValidateWaveSensitivity pass.
Xiang Li 5 years ago
parent
commit
2f440d0462

+ 0 - 1
docs/DXIL.rst

@@ -3188,7 +3188,6 @@ TYPES.INTWIDTH                            Int type must be of valid width
 TYPES.NOMULTIDIM                          Only one dimension allowed for array type
 TYPES.NOMULTIDIM                          Only one dimension allowed for array type
 TYPES.NOPTRTOPTR                          Pointers to pointers, or pointers in structures are not allowed
 TYPES.NOPTRTOPTR                          Pointers to pointers, or pointers in structures are not allowed
 TYPES.NOVECTOR                            Vector types must not be present
 TYPES.NOVECTOR                            Vector types must not be present
-UNI.NOWAVESENSITIVEGRADIENT               Gradient operations are not affected by wave-sensitive data or control flow.
 ========================================= =======================================================================================================================================================================================================================================================================================================
 ========================================= =======================================================================================================================================================================================================================================================================================================
 
 
 .. VALRULES-RST:END
 .. VALRULES-RST:END

+ 1 - 0
include/dxc/DXIL/DxilUtil.h

@@ -70,6 +70,7 @@ namespace dxilutil {
   bool RemoveUnusedFunctions(llvm::Module &M, llvm::Function *EntryFunc,
   bool RemoveUnusedFunctions(llvm::Module &M, llvm::Function *EntryFunc,
                              llvm::Function *PatchConstantFunc, bool IsLib);
                              llvm::Function *PatchConstantFunc, bool IsLib);
   void EmitErrorOnInstruction(llvm::Instruction *I, llvm::StringRef Msg);
   void EmitErrorOnInstruction(llvm::Instruction *I, llvm::StringRef Msg);
+  void EmitWarningOnInstruction(llvm::Instruction *I, llvm::StringRef Msg);
   void EmitResMappingError(llvm::Instruction *Res);
   void EmitResMappingError(llvm::Instruction *Res);
   std::string FormatMessageAtLocation(const llvm::DebugLoc &DL, const llvm::Twine& Msg);
   std::string FormatMessageAtLocation(const llvm::DebugLoc &DL, const llvm::Twine& Msg);
   llvm::Twine FormatMessageWithoutLocation(const llvm::Twine& Msg);
   llvm::Twine FormatMessageWithoutLocation(const llvm::Twine& Msg);

+ 3 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -109,6 +109,9 @@ void initializeResumePassesPass(llvm::PassRegistry&);
 void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry&);
 void initializeMatrixBitcastLowerPassPass(llvm::PassRegistry&);
 void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry&);
 void initializeDxilCleanupAddrSpaceCastPass(llvm::PassRegistry&);
 
 
+ModulePass *createDxilValidateWaveSensitivityPass();
+void initializeDxilValidateWaveSensitivityPass(llvm::PassRegistry&);
+
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 
 
 }
 }

+ 0 - 3
include/dxc/HLSL/DxilValidation.h

@@ -273,9 +273,6 @@ enum class ValidationRule : unsigned {
   TypesNoMultiDim, // Only one dimension allowed for array type
   TypesNoMultiDim, // Only one dimension allowed for array type
   TypesNoPtrToPtr, // Pointers to pointers, or pointers in structures are not allowed
   TypesNoPtrToPtr, // Pointers to pointers, or pointers in structures are not allowed
   TypesNoVector, // Vector types must not be present
   TypesNoVector, // Vector types must not be present
-
-  // Uniform analysis
-  UniNoWaveSensitiveGradient, // Gradient operations are not affected by wave-sensitive data or control flow.
 };
 };
 // VALRULE-ENUM:END
 // VALRULE-ENUM:END
 
 

+ 41 - 20
lib/DXIL/DxilUtil.cpp

@@ -210,50 +210,71 @@ std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
   return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
   return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
 }
 }
 
 
+std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
+  std::string locString;
+  raw_string_ostream os(locString);
+  DL.print(os);
+  os << ": " << Msg;
+  return os.str();
+}
+
+Twine FormatMessageWithoutLocation(const Twine& Msg) {
+  return Msg + " Use /Zi for source location.";
+}
+
+static void EmitWarningOrErrorOnInstruction(Instruction *I, StringRef Msg,
+                                            bool bWarning);
+
 // If we don't have debug location and this is select/phi,
 // If we don't have debug location and this is select/phi,
 // try recursing users to find instruction with debug info.
 // try recursing users to find instruction with debug info.
 // Only recurse phi/select and limit depth to prevent doing
 // Only recurse phi/select and limit depth to prevent doing
 // too much work if no debug location found.
 // too much work if no debug location found.
-static bool EmitErrorOnInstructionFollowPhiSelect(
-    Instruction *I, StringRef Msg, unsigned depth=0) {
+static bool EmitWarningOrErrorOnInstructionFollowPhiSelect(Instruction *I,
+                                                           StringRef Msg,
+                                                           bool bWarning,
+                                                           unsigned depth = 0) {
   if (depth > 4)
   if (depth > 4)
     return false;
     return false;
   if (I->getDebugLoc().get()) {
   if (I->getDebugLoc().get()) {
-    EmitErrorOnInstruction(I, Msg);
+    EmitWarningOrErrorOnInstruction(I, Msg, bWarning);
     return true;
     return true;
   }
   }
   if (isa<PHINode>(I) || isa<SelectInst>(I)) {
   if (isa<PHINode>(I) || isa<SelectInst>(I)) {
     for (auto U : I->users())
     for (auto U : I->users())
       if (Instruction *UI = dyn_cast<Instruction>(U))
       if (Instruction *UI = dyn_cast<Instruction>(U))
-        if (EmitErrorOnInstructionFollowPhiSelect(UI, Msg, depth+1))
+        if (EmitWarningOrErrorOnInstructionFollowPhiSelect(UI, Msg, bWarning,
+                                                           depth + 1))
           return true;
           return true;
   }
   }
   return false;
   return false;
 }
 }
 
 
-std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
-  std::string locString;
-  raw_string_ostream os(locString);
-  DL.print(os);
-  os << ": " << Msg;
-  return os.str();
-}
-
-Twine FormatMessageWithoutLocation(const Twine& Msg) {
-  return Msg + " Use /Zi for source location.";
-}
-
-void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
+static void EmitWarningOrErrorOnInstruction(Instruction *I, StringRef Msg,
+                                            bool bWarning) {
   const DebugLoc &DL = I->getDebugLoc();
   const DebugLoc &DL = I->getDebugLoc();
   if (DL.get()) {
   if (DL.get()) {
-    I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
+    if (bWarning)
+      I->getContext().emitWarning(FormatMessageAtLocation(DL, Msg));
+    else
+      I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
     return;
     return;
   } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
   } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
-    if (EmitErrorOnInstructionFollowPhiSelect(I, Msg))
+    if (EmitWarningOrErrorOnInstructionFollowPhiSelect(I, Msg, bWarning))
       return;
       return;
   }
   }
 
 
-  I->getContext().emitError(FormatMessageWithoutLocation(Msg));
+  if (bWarning)
+    I->getContext().emitWarning(FormatMessageWithoutLocation(Msg));
+  else
+    I->getContext().emitError(FormatMessageWithoutLocation(Msg));
+}
+
+void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
+  EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/false);
+}
+
+void EmitWarningOnInstruction(Instruction *I, StringRef Msg) {
+  EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/true);
 }
 }
 
 
 const StringRef kResourceMapErrorMsg =
 const StringRef kResourceMapErrorMsg =

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -114,6 +114,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilPromoteStaticResourcesPass(Registry);
     initializeDxilPromoteStaticResourcesPass(Registry);
     initializeDxilSimpleGVNHoistPass(Registry);
     initializeDxilSimpleGVNHoistPass(Registry);
     initializeDxilTranslateRawBufferPass(Registry);
     initializeDxilTranslateRawBufferPass(Registry);
+    initializeDxilValidateWaveSensitivityPass(Registry);
     initializeDxilValueCachePass(Registry);
     initializeDxilValueCachePass(Registry);
     initializeDynamicIndexingVectorToArrayPass(Registry);
     initializeDynamicIndexingVectorToArrayPass(Registry);
     initializeEarlyCSELegacyPassPass(Registry);
     initializeEarlyCSELegacyPassPass(Registry);

+ 108 - 0
lib/HLSL/DxilPreparePasses.cpp

@@ -23,6 +23,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PassManager.h"
@@ -922,3 +923,110 @@ ModulePass *llvm::createDxilEmitMetadataPass() {
 }
 }
 
 
 INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", false, false)
 INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", false, false)
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+const StringRef UniNoWaveSensitiveGradientErrMsg =
+    "Gradient operations are not affected by wave-sensitive data or control "
+    "flow.";
+
+class DxilValidateWaveSensitivity : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilValidateWaveSensitivity() : ModulePass(ID) {}
+
+  const char *getPassName() const override {
+    return "HLSL DXIL wave sensitiveity validation";
+  }
+
+  bool runOnModule(Module &M) override {
+    // Only check ps and lib profile.
+    DxilModule &DM = M.GetDxilModule();
+    const ShaderModel *pSM = DM.GetShaderModel();
+    if (!pSM->IsPS() && !pSM->IsLib())
+      return false;
+
+    SmallVector<CallInst *, 16> gradientOps;
+    SmallVector<CallInst *, 16> barriers;
+    SmallVector<CallInst *, 16> waveOps;
+
+    for (auto &F : M) {
+      if (!F.isDeclaration())
+        continue;
+
+      for (User *U : F.users()) {
+        CallInst *CI = dyn_cast<CallInst>(U);
+        if (!CI)
+          continue;
+        Function *FCalled = CI->getCalledFunction();
+        if (!FCalled || !FCalled->isDeclaration())
+          continue;
+
+        if (!hlsl::OP::IsDxilOpFunc(FCalled))
+          continue;
+
+        DXIL::OpCode dxilOpcode = hlsl::OP::GetDxilOpFuncCallInst(CI);
+
+        if (OP::IsDxilOpWave(dxilOpcode)) {
+          waveOps.emplace_back(CI);
+        }
+
+        if (OP::IsDxilOpGradient(dxilOpcode)) {
+          gradientOps.push_back(CI);
+        }
+
+        if (dxilOpcode == DXIL::OpCode::Barrier) {
+          barriers.push_back(CI);
+        }
+      }
+    }
+
+    // Skip if not have wave op.
+    if (waveOps.empty())
+      return false;
+
+    // Skip if no gradient op.
+    if (gradientOps.empty())
+      return false;
+
+    for (auto &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      SmallVector<CallInst *, 16> localGradientOps;
+      for (CallInst *CI : gradientOps) {
+        if (CI->getParent()->getParent() == &F)
+          localGradientOps.emplace_back(CI);
+      }
+
+      if (localGradientOps.empty())
+        continue;
+
+      PostDominatorTree PDT;
+      PDT.runOnFunction(F);
+      std::unique_ptr<WaveSensitivityAnalysis> WaveVal(
+          WaveSensitivityAnalysis::create(PDT));
+
+      WaveVal->Analyze(&F);
+      for (CallInst *op : localGradientOps) {
+        if (WaveVal->IsWaveSensitive(op)) {
+          dxilutil::EmitWarningOnInstruction(op,
+                                             UniNoWaveSensitiveGradientErrMsg);
+        }
+      }
+    }
+    return false;
+  }
+};
+
+}
+
+char DxilValidateWaveSensitivity::ID = 0;
+
+ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
+  return new DxilValidateWaveSensitivity();
+}
+
+INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)

+ 0 - 25
lib/HLSL/DxilValidation.cpp

@@ -274,7 +274,6 @@ const char *hlsl::GetValidationRuleText(ValidationRule value) {
     case hlsl::ValidationRule::SmMaxMSSMSize: return "Total Thread Group Shared Memory storage is %0, exceeded %1";
     case hlsl::ValidationRule::SmMaxMSSMSize: return "Total Thread Group Shared Memory storage is %0, exceeded %1";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSize: return "For shader '%0', payload size is greater than %1";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSizeDeclared: return "For shader '%0', payload size %1 is greater than declared size of %2 bytes";
     case hlsl::ValidationRule::SmAmplificationShaderPayloadSizeDeclared: return "For shader '%0', payload size %1 is greater than declared size of %2 bytes";
-    case hlsl::ValidationRule::UniNoWaveSensitiveGradient: return "Gradient operations are not affected by wave-sensitive data or control flow.";
     case hlsl::ValidationRule::FlowReducible: return "Execution flow must be reducible";
     case hlsl::ValidationRule::FlowReducible: return "Execution flow must be reducible";
     case hlsl::ValidationRule::FlowNoRecusion: return "Recursion is not permitted";
     case hlsl::ValidationRule::FlowNoRecusion: return "Recursion is not permitted";
     case hlsl::ValidationRule::FlowDeadLoop: return "Loop must have break";
     case hlsl::ValidationRule::FlowDeadLoop: return "Loop must have break";
@@ -2827,26 +2826,6 @@ static bool IsValueMinPrec(DxilModule &DxilMod, Value *V) {
   return Ty->isHalfTy();
   return Ty->isHalfTy();
 }
 }
 
 
-static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<CallInst *> barriers, ValidationContext &ValCtx) {
-  // In the absence of wave operations, the wave validation effect need not happen.
-  // We haven't verified this is true at this point, but validation will fail
-  // later if the flags don't match in any case. Given that most shaders will
-  // not be using these wave operations, it's a reasonable cost saving.
-  if (!ValCtx.DxilMod.m_ShaderFlags.GetWaveOps()) {
-    return;
-  }
-
-    PostDominatorTree PDT;
-    PDT.runOnFunction(*F);
-  std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create(PDT));
-  WaveVal->Analyze(F);
-  for (CallInst *op : ops) {
-    if (WaveVal->IsWaveSensitive(op)) {
-      ValCtx.EmitInstrError(op, ValidationRule::UniNoWaveSensitiveGradient);
-    }
-  }
-}
-
 static void ValidateMsIntrinsics(Function *F,
 static void ValidateMsIntrinsics(Function *F,
                                  ValidationContext &ValCtx,
                                  ValidationContext &ValCtx,
                                  CallInst *setMeshOutputCounts,
                                  CallInst *setMeshOutputCounts,
@@ -3533,10 +3512,6 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
     ValidateControlFlowHint(*b, ValCtx);
     ValidateControlFlowHint(*b, ValCtx);
   }
   }
 
 
-  if (!gradientOps.empty()) {
-    ValidateGradientOps(F, gradientOps, barriers, ValCtx);
-  }
-
   ValidateMsIntrinsics(F, ValCtx, setMeshOutputCounts, getMeshPayload);
   ValidateMsIntrinsics(F, ValCtx, setMeshOutputCounts, getMeshPayload);
 
 
   ValidateAsIntrinsics(F, ValCtx, dispatchMesh);
   ValidateAsIntrinsics(F, ValCtx, dispatchMesh);

+ 1 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -644,6 +644,7 @@ void PassManagerBuilder::populateModulePassManager(
     MPM.add(createComputeViewIdStatePass());
     MPM.add(createComputeViewIdStatePass());
     MPM.add(createDxilDeadFunctionEliminationPass());
     MPM.add(createDxilDeadFunctionEliminationPass());
     MPM.add(createNoPausePassesPass());
     MPM.add(createNoPausePassesPass());
+    MPM.add(createDxilValidateWaveSensitivityPass());
     MPM.add(createDxilEmitMetadataPass());
     MPM.add(createDxilEmitMetadataPass());
   }
   }
   // HLSL Change Ends.
   // HLSL Change Ends.

+ 4 - 4
tools/clang/test/CodeGenHLSL/val-wave-failures-ps.hlsl

@@ -1,8 +1,8 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck -input=stderr %s
 
 
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
 
 
 float4 main(float4 p: SV_Position) : SV_Target {
 float4 main(float4 p: SV_Position) : SV_Target {
   // cannot feed into ddx
   // cannot feed into ddx

+ 4 - 4
tools/clang/test/HLSLFileCheck/validation/val-wave-failures-ps.hlsl

@@ -1,8 +1,8 @@
-// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck -input=stderr %s
 
 
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
-// CHECK: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
+// CHECK:warning: Gradient operations are not affected by wave-sensitive data or control flow.
 
 
 float4 main(float4 p: SV_Position) : SV_Target {
 float4 main(float4 p: SV_Position) : SV_Target {
   // cannot feed into ddx
   // cannot feed into ddx

+ 6 - 1
utils/hct/hctdb.py

@@ -1983,6 +1983,11 @@ class db_dxil(object):
         add_pass('hlsl-translate-dxil-opcode-version', 'DxilTranslateRawBuffer', 'Translates one version of dxil to another', [])
         add_pass('hlsl-translate-dxil-opcode-version', 'DxilTranslateRawBuffer', 'Translates one version of dxil to another', [])
         add_pass('hlsl-dxil-cleanup-addrspacecast', 'DxilCleanupAddrSpaceCast', 'HLSL DXIL Cleanup Address Space Cast (part of hlsl-dxilfinalize)', [])
         add_pass('hlsl-dxil-cleanup-addrspacecast', 'DxilCleanupAddrSpaceCast', 'HLSL DXIL Cleanup Address Space Cast (part of hlsl-dxilfinalize)', [])
         add_pass('dxil-fix-array-init', 'DxilFixConstArrayInitializer', 'Dxil Fix Array Initializer', [])
         add_pass('dxil-fix-array-init', 'DxilFixConstArrayInitializer', 'Dxil Fix Array Initializer', [])
+        add_pass('hlsl-validate-wave-sensitivity', 'DxilValidateWaveSensitivity', 'HLSL DXIL wave sensitiveity validation', [])
+        add_pass('dxil-elim-vector', 'DxilEliminateVector', 'Dxil Eliminate Vectors', [])
+        add_pass('dxil-finalize-noops', 'DxilFinalizeNoops', 'Dxil Finalize Noops', [])
+        add_pass('dxil-insert-noops', 'DxilInsertNoops', 'Dxil Insert Noops', [])
+        add_pass('dxil-value-cache', 'DxilValueCache', 'Dxil Value Cache',[])
 
 
         category_lib="llvm"
         category_lib="llvm"
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])
         add_pass('ipsccp', 'IPSCCP', 'Interprocedural Sparse Conditional Constant Propagation', [])
@@ -2479,7 +2484,7 @@ class db_dxil(object):
         #self.add_valrule("Uni.NoUniInDiv", "TODO - No instruction requiring uniform execution can be present in divergent block")
         #self.add_valrule("Uni.NoUniInDiv", "TODO - No instruction requiring uniform execution can be present in divergent block")
         #self.add_valrule("Uni.GradientFlow", "TODO - No divergent gradient operations inside flow control") # a bit more specific than the prior rule
         #self.add_valrule("Uni.GradientFlow", "TODO - No divergent gradient operations inside flow control") # a bit more specific than the prior rule
         #self.add_valrule("Uni.ThreadSync", "TODO - Thread sync operation must be in non-varying flow control due to a potential race condition, adding a sync after reading any values controlling shader execution at this point")
         #self.add_valrule("Uni.ThreadSync", "TODO - Thread sync operation must be in non-varying flow control due to a potential race condition, adding a sync after reading any values controlling shader execution at this point")
-        self.add_valrule("Uni.NoWaveSensitiveGradient", "Gradient operations are not affected by wave-sensitive data or control flow.")
+        #self.add_valrule("Uni.NoWaveSensitiveGradient", "Gradient operations are not affected by wave-sensitive data or control flow.")
         
         
         self.add_valrule("Flow.Reducible", "Execution flow must be reducible")
         self.add_valrule("Flow.Reducible", "Execution flow must be reducible")
         self.add_valrule("Flow.NoRecusion", "Recursion is not permitted")
         self.add_valrule("Flow.NoRecusion", "Recursion is not permitted")