Explorar el Código

Only propagate WaveSensitive when target BB not post dom current BB. (#1648)

Xiang Li hace 6 años
padre
commit
0df5e31b43

+ 2 - 1
include/dxc/HLSL/DxilGenerationPass.h

@@ -19,13 +19,14 @@ class FunctionPass;
 class Instruction;
 class PassRegistry;
 class StringRef;
+struct PostDominatorTree;
 }
 
 namespace hlsl {
 class DxilResourceBase;
 class WaveSensitivityAnalysis {
 public:
-  static WaveSensitivityAnalysis* create();
+  static WaveSensitivityAnalysis* create(llvm::PostDominatorTree &PDT);
   virtual ~WaveSensitivityAnalysis() { }
   virtual void Analyze(llvm::Function *F) = 0;
   virtual bool IsWaveSensitive(llvm::Instruction *op) = 0;

+ 3 - 1
lib/HLSL/DxilValidation.cpp

@@ -2679,7 +2679,9 @@ static void ValidateGradientOps(Function *F, ArrayRef<CallInst *> ops, ArrayRef<
     return;
   }
 
-  std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create());
+    PostDominatorTree PDT;
+    PDT.runOnFunction(*F);
+  std::unique_ptr<WaveSensitivityAnalysis> WaveVal(WaveSensitivityAnalysis::create(PDT));
   WaveVal->Analyze(F);
   for (CallInst *op : ops) {
     if (WaveVal->IsWaveSensitive(op)) {

+ 33 - 3
lib/HLSL/WaveSensitivityAnalysis.cpp

@@ -31,6 +31,8 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/ADT/BitVector.h"
+#include "llvm/Analysis/PostDominators.h"
+
 #ifdef _WIN32
 #include <winerror.h>
 #endif
@@ -42,6 +44,14 @@ using namespace std;
 
 namespace hlsl {
 
+// WaveSensitivityAnalysis is created to validate Gradient operations.
+// Gradient operations require all neighbor lanes to be active when calculated,
+// compiler will enable lanes to meet this requirement. If a wave operation
+// contributed to gradient operation, it will get unexpected result because the
+// active lanes are modified.
+// To avoid unexpected result, validation will fail if gradient operations
+// are dependent on wave-sensitive data or control flow.
+
 class WaveSensitivityAnalyzer : public WaveSensitivityAnalysis {
 private:
   enum WaveSensitivity {
@@ -49,6 +59,7 @@ private:
     KnownNotSensitive,
     Unknown
   };
+  PostDominatorTree *pPDT;
   map<Instruction *, WaveSensitivity> InstState;
   map<BasicBlock *, WaveSensitivity> BBState;
   std::vector<Instruction *> InstWorkList;
@@ -59,12 +70,13 @@ private:
   void UpdateInst(Instruction *I, WaveSensitivity WS);
   void VisitInst(Instruction *I);
 public:
+  WaveSensitivityAnalyzer(PostDominatorTree &PDT) : pPDT(&PDT) {}
   void Analyze(Function *F);
   bool IsWaveSensitive(Instruction *op);
 };
 
-WaveSensitivityAnalysis* WaveSensitivityAnalysis::create() {
-  return new WaveSensitivityAnalyzer();
+WaveSensitivityAnalysis* WaveSensitivityAnalysis::create(PostDominatorTree &PDT) {
+  return new WaveSensitivityAnalyzer(PDT);
 }
 
 void WaveSensitivityAnalyzer::Analyze(Function *F) {
@@ -132,9 +144,14 @@ void WaveSensitivityAnalyzer::UpdateInst(Instruction *I, WaveSensitivity WS) {
     InstState[I] = WS;
     InstWorkList.push_back(I);
     if (TerminatorInst * TI = dyn_cast<TerminatorInst>(I)) {
+      BasicBlock *CurBB = TI->getParent();
       for (unsigned i = 0; i < TI->getNumSuccessors(); ++i) {
         BasicBlock *BB = TI->getSuccessor(i);
-        UpdateBlock(BB, WS);
+        // Only propagate WS when BB not post dom CurBB.
+        WaveSensitivity TmpWS = pPDT->properlyDominates(BB, CurBB)
+                                    ? WaveSensitivity::KnownNotSensitive
+                                    : WS;
+        UpdateBlock(BB, TmpWS);
       }
     }
   }
@@ -153,11 +170,24 @@ void WaveSensitivityAnalyzer::VisitInst(Instruction *I) {
     }
   }
 
+
   if (CheckBBState(I->getParent(), KnownSensitive)) {
     UpdateInst(I, KnownSensitive);
     return;
   }
 
+  // Catch control flow wave sensitive for phi.
+  if (PHINode *Phi = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) {
+      BasicBlock *BB = Phi->getIncomingBlock(i);
+      WaveSensitivity WS = GetInstState(BB->getTerminator());
+      if (WS == KnownSensitive) {
+        UpdateInst(I, KnownSensitive);
+        return;
+      }
+    }
+  }
+
   bool allKnownNotSensitive = true;
   for (unsigned i = firstArg; i < I->getNumOperands(); ++i) {
     Value *V = I->getOperand(i);

+ 15 - 0
tools/clang/test/CodeGenHLSL/quick-test/NotWaveSensitive.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: Sin
+
+float main ( uint mask:M, float a:A) : SV_Target 
+{ 
+   float r = a;
+   mask = WaveActiveBitOr ( mask ) ;
+    if (mask & 0xf) {
+       r += sin(r);
+    }
+    
+    float dd = ddx(a);
+    return r + dd; 
+}