Kaynağa Gözat

Fixed some cases where O0 fails compilation (#3205)

- Fixed value tracking for dxil intrinsics
- Fixed some selects holding on to invalid resource uses
- Fixed some cases where unused globals hold on to invalid resource uses
- Fixed some cases where stores of undefs stick around
Adam Yang 4 yıl önce
ebeveyn
işleme
d8dec0efd7

+ 3 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -125,4 +125,7 @@ void initializeHLLegalizeParameterPass(llvm::PassRegistry &);
 
 bool AreDxilResourcesDense(llvm::Module *M, hlsl::DxilResourceBase **ppNonDense);
 
+ModulePass *createDxilNoOptLegalizePass();
+void initializeDxilNoOptLegalizePass(llvm::PassRegistry&);
+
 }

+ 6 - 0
include/dxc/HLSL/DxilNoops.h

@@ -10,9 +10,15 @@
 
 #include "llvm/ADT/StringRef.h"
 
+namespace llvm {
+  class Instruction;
+}
+
 namespace hlsl {
 static const llvm::StringRef kNoopName = "dx.noop";
 static const llvm::StringRef kPreservePrefix = "dx.preserve.";
 static const llvm::StringRef kNothingName = "dx.nothing.a";
 static const llvm::StringRef kPreserveName = "dx.preserve.value.a";
+
+/// Returns true when \p S is a select whose condition is a trunc of a load
+/// through a GEP of the internal-linkage global named kPreserveName; returns
+/// false for every other instruction.
+bool IsPreserve(llvm::Instruction *S);
 }

+ 3 - 1
include/llvm/Analysis/DxilSimplify.h

@@ -27,13 +27,15 @@ class Value;
 } // namespace llvm
 
 namespace hlsl {
+
 /// \brief Given a function and set of arguments, see if we can fold the
 /// result as dxil operation.
 ///
 /// If this call could not be simplified returns null.
 llvm::Value *SimplifyDxilCall(llvm::Function *F,
                               llvm::ArrayRef<llvm::Value *> Args,
-                              llvm::Instruction *I);
+                              llvm::Instruction *I,
+                              bool MayInsert);
 
 /// CanSimplify
 /// Return true on dxil operation function which can be simplified.

+ 32 - 23
lib/Analysis/DxilSimplify.cpp

@@ -49,6 +49,9 @@ bool CanSimplify(const llvm::Function *F) {
     return false;
   }
 
+  if (CanConstantFoldCallTo(F))
+    return true;
+
   // Lookup opcode class in dxil module. Set default value to invalid class.
   OP::OpCodeClass opClass = OP::OpCodeClass::NumOpClasses;
   const bool found =
@@ -72,7 +75,9 @@ bool CanSimplify(const llvm::Function *F) {
 ///
 /// If this call could not be simplified returns null.
 Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
-                        llvm::Instruction *I) {
+                        llvm::Instruction *I,
+                        bool MayInsert)
+{
   if (!F->getParent()->HasDxilModule()) {
     assert(!OP::IsDxilOpFunc(F) && "dx.op function with no dxil module?");
     return nullptr;
@@ -124,21 +129,23 @@ Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
     if (op1 == zero)
       return op2;
 
-    Constant *one = ConstantFP::get(op0->getType(), 1);
-    if (op0 == one) {
-      IRBuilder<> Builder(I);
-      llvm::FastMathFlags FMF;
-      FMF.setUnsafeAlgebraHLSL();
-      Builder.SetFastMathFlags(FMF);
-      return Builder.CreateFAdd(op1, op2);
-    }
-    if (op1 == one) {
-      IRBuilder<> Builder(I);
-      llvm::FastMathFlags FMF;
-      FMF.setUnsafeAlgebraHLSL();
-      Builder.SetFastMathFlags(FMF);
+    if (MayInsert) {
+      Constant *one = ConstantFP::get(op0->getType(), 1);
+      if (op0 == one) {
+        IRBuilder<> Builder(I);
+        llvm::FastMathFlags FMF;
+        FMF.setUnsafeAlgebraHLSL();
+        Builder.SetFastMathFlags(FMF);
+        return Builder.CreateFAdd(op1, op2);
+      }
+      if (op1 == one) {
+        IRBuilder<> Builder(I);
+        llvm::FastMathFlags FMF;
+        FMF.setUnsafeAlgebraHLSL();
+        Builder.SetFastMathFlags(FMF);
 
-      return Builder.CreateFAdd(op0, op2);
+        return Builder.CreateFAdd(op0, op2);
+      }
     }
     return nullptr;
   } break;
@@ -153,14 +160,16 @@ Value *SimplifyDxilCall(llvm::Function *F, ArrayRef<Value *> Args,
     if (op1 == zero)
       return op2;
 
-    Constant *one = ConstantInt::get(op0->getType(), 1);
-    if (op0 == one) {
-      IRBuilder<> Builder(I);
-      return Builder.CreateAdd(op1, op2);
-    }
-    if (op1 == one) {
-      IRBuilder<> Builder(I);
-      return Builder.CreateAdd(op0, op2);
+    if (MayInsert) {
+      Constant *one = ConstantInt::get(op0->getType(), 1);
+      if (op0 == one) {
+        IRBuilder<> Builder(I);
+        return Builder.CreateAdd(op1, op2);
+      }
+      if (op1 == one) {
+        IRBuilder<> Builder(I);
+        return Builder.CreateAdd(op0, op2);
+      }
     }
     return nullptr;
   } break;

+ 16 - 1
lib/Analysis/DxilValueCache.cpp

@@ -13,6 +13,7 @@
 
 #include "llvm/Pass.h"
 #include "dxc/DXIL/DxilConstants.h"
+#include "llvm/Analysis/DxilSimplify.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Constants.h"
@@ -228,10 +229,24 @@ Value *DxilValueCache::SimplifyAndCacheResult(Instruction *I, DominatorTree *DT)
   else if (Instruction::Call == I->getOpcode()) {
     Module *M = I->getModule();
     CallInst *CI = cast<CallInst>(I);
-    if (CI->getCalledFunction()->getName() == hlsl::DXIL::kDxBreakFuncName) {
+    Function *Callee = CI->getCalledFunction();
+    if (Callee->getName() == hlsl::DXIL::kDxBreakFuncName) {
       llvm::Type *i1Ty = llvm::Type::getInt1Ty(M->getContext());
       Simplified = llvm::ConstantInt::get(i1Ty, 1);
     }
+    else {
+      SmallVector<Value *,16> Args;
+      for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
+        Args.push_back(TryGetCachedValue(CI->getArgOperand(i)));
+      }
+
+      if (hlsl::CanSimplify(Callee)) {
+        Simplified = hlsl::SimplifyDxilCall(Callee, Args, CI, /* MayInsert */ false);
+      }
+      else {
+        Simplified = llvm::SimplifyCall(Callee, Args, DL, nullptr, DT);
+      }
+    }
   }
   // The rest of the checks use LLVM stock simplifications
   else if (I->isBinaryOp()) {

+ 1 - 1
lib/Analysis/InstructionSimplify.cpp

@@ -4406,7 +4406,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
     if (Function *Callee = CS.getCalledFunction()) {
       if (hlsl::CanSimplify(Callee)) {
         SmallVector<Value *, 4> Args(CS.arg_begin(), CS.arg_end());
-        if (Value *DxilResult = hlsl::SimplifyDxilCall(CS.getCalledFunction(), Args, I)) {
+        if (Value *DxilResult = hlsl::SimplifyDxilCall(CS.getCalledFunction(), Args, I, /* MayInsert */ true)) {
           Result = DxilResult;
           break;
         }

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -54,6 +54,7 @@ add_llvm_library(LLVMHLSL
   HLUtil.cpp
   PauseResumePasses.cpp
   WaveSensitivityAnalysis.cpp
+  DxilNoOptLegalize.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR

+ 1 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -109,6 +109,7 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDxilLoopDeletionPass(Registry);
     initializeDxilLoopUnrollPass(Registry);
     initializeDxilLowerCreateHandleForLibPass(Registry);
+    initializeDxilNoOptLegalizePass(Registry);
     initializeDxilPrecisePropagatePassPass(Registry);
     initializeDxilPreserveAllOutputsPass(Registry);
     initializeDxilPreserveToSelectPass(Registry);

+ 118 - 0
lib/HLSL/DxilNoOptLegalize.cpp

@@ -0,0 +1,118 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxilNoOptLegalize.cpp                                                     //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "llvm/Pass.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Instructions.h"
+#include "dxc/HLSL/DxilGenerationPass.h"
+#include "dxc/HLSL/DxilNoops.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Analysis/DxilValueCache.h"
+
+using namespace llvm;
+
+/// Module pass registered under the name "dxil-o0-legalize".
+///
+/// runOnModule first erases stores of undef reachable from global variables
+/// and entry-block allocas, then replaces select instructions for which
+/// DxilValueCache reports a value.
+class DxilNoOptLegalize : public ModulePass {
+  // Scratch stack used by RemoveStoreUndefsFromPtr; it is cleared at the
+  // start of every call, so no state carries over between pointers.
+  SmallVector<Value *, 16> Worklist;
+
+public:
+  static char ID;
+  DxilNoOptLegalize() : ModulePass(ID) {
+    initializeDxilNoOptLegalizePass(*PassRegistry::getPassRegistry());
+  }
+
+  // SimplifySelects fetches DxilValueCache through getAnalysis, so the
+  // analysis is declared as required here.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DxilValueCache>();
+  }
+
+  bool runOnModule(Module &M) override;
+  bool RemoveStoreUndefsFromPtr(Value *V);
+  bool RemoveStoreUndefs(Module &M);
+  bool SimplifySelects(Module &M);
+};
+char DxilNoOptLegalize::ID;
+
+/// Erases every store of an undef value found among the transitive users of
+/// \p Ptr.
+///
+/// The walk only continues through allocas, global variables and GEPs; any
+/// other kind of user (other than a store) is neither followed nor modified.
+///
+/// \returns true if at least one store was erased.
+bool DxilNoOptLegalize::RemoveStoreUndefsFromPtr(Value *Ptr) {
+  bool Changed = false;
+  Worklist.clear();
+  Worklist.push_back(Ptr);
+
+  while (Worklist.size()) {
+    Value *V = Worklist.back();
+    Worklist.pop_back();
+    // Address-producing values: queue all of their users for inspection.
+    if (isa<AllocaInst>(V) || isa<GlobalVariable>(V) || isa<GEPOperator>(V)) {
+      for (User *U : V->users())
+        Worklist.push_back(U);
+    }
+    else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
+      // Only stores whose stored value is undef are removed; stores of any
+      // other value are kept.
+      if (isa<UndefValue>(Store->getValueOperand())) {
+        Store->eraseFromParent();
+        Changed = true;
+      }
+    }
+  }
+
+  return Changed;
+}
+
+/// Runs RemoveStoreUndefsFromPtr on every global variable of \p M and on
+/// every alloca located in the entry block of each function that has a body.
+///
+/// \returns true if any store was erased.
+bool DxilNoOptLegalize::RemoveStoreUndefs(Module &M) {
+  bool Changed = false;
+  for (GlobalVariable &GV : M.globals()) {
+    Changed |= RemoveStoreUndefsFromPtr(&GV);
+  }
+
+  for (Function &F : M) {
+    // Functions without any basic block have no entry block to scan.
+    if (F.empty())
+      continue;
+
+    // Only allocas in the entry block are considered; allocas in other
+    // blocks are not visited by this loop.
+    BasicBlock &Entry = F.getEntryBlock();
+    for (Instruction &I : Entry) {
+      if (isa<AllocaInst>(&I))
+        Changed |= RemoveStoreUndefsFromPtr(&I);
+    }
+  }
+
+  return Changed;
+}
+
+/// Replaces select instructions in \p M with the value DxilValueCache returns
+/// for them, skipping selects that hlsl::IsPreserve recognises.
+///
+/// \returns true if at least one select was replaced and erased.
+bool DxilNoOptLegalize::SimplifySelects(Module &M) {
+  bool Changed = false;
+  DxilValueCache *DVC = &getAnalysis<DxilValueCache>();
+  for (Function &F : M) {
+    for (BasicBlock &BB : F) {
+      for (auto it = BB.begin(), end = BB.end(); it != end;) {
+        // Advance the iterator before touching I, because I may be erased
+        // below.
+        Instruction *I = &*(it++);
+        if (I->getOpcode() == Instruction::Select) {
+
+          // Preserve selects are left in place.
+          if (hlsl::IsPreserve(I))
+            continue;
+
+          // A non-null result from the cache replaces the select everywhere.
+          if (Value *C = DVC->GetValue(I)) {
+            I->replaceAllUsesWith(C);
+            I->eraseFromParent();
+            Changed = true;
+          }
+        }
+      }
+    }
+  }
+  return Changed;
+}
+
+/// Pass entry point: removes stores of undef first, then simplifies selects.
+///
+/// \returns true if either step modified the module.
+bool DxilNoOptLegalize::runOnModule(Module &M) {
+  bool Changed = false;
+  Changed |= RemoveStoreUndefs(M);
+  Changed |= SimplifySelects(M);
+  return Changed;
+}
+
+/// Factory for the pass: returns a newly allocated DxilNoOptLegalize
+/// instance.
+ModulePass *llvm::createDxilNoOptLegalizePass() {
+  return new DxilNoOptLegalize();
+}
+
+INITIALIZE_PASS(DxilNoOptLegalize, "dxil-o0-legalize", "DXIL No-Opt Legalize", false, false)

+ 22 - 0
lib/HLSL/DxilNoops.cpp

@@ -229,6 +229,28 @@ static Value *GetOrCreatePreserveCond(Function *F) {
   return B.CreateTrunc(Load, B.getInt1Ty());
 }
 
+/// Returns true when \p I is a select whose condition has the shape
+///   trunc(load(gep(GV)))
+/// with GV an internal-linkage global variable named kPreserveName.
+/// Any instruction that does not match every step of that chain yields false.
+// NOTE(review): this appears to mirror the trunc-of-load condition built by
+// GetOrCreatePreserveCond above -- confirm the two stay in sync.
+bool hlsl::IsPreserve(llvm::Instruction *I) {
+  SelectInst *S = dyn_cast<SelectInst>(I);
+  if (!S)
+    return false;
+
+  TruncInst *Trunc = dyn_cast<TruncInst>(S->getCondition());
+  if (!Trunc)
+    return false;
+
+  LoadInst *Load = dyn_cast<LoadInst>(Trunc->getOperand(0));
+  if (!Load)
+    return false;
+
+  GEPOperator *GEP = dyn_cast<GEPOperator>(Load->getPointerOperand());
+  if (!GEP)
+    return false;
+
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
+
+  // GV is null when the GEP base is not a global variable.
+  return GV && GV->getLinkage() == GlobalVariable::LinkageTypes::InternalLinkage && GV->getName() == kPreserveName;
+}
+
 
 static Function *GetOrCreatePreserveF(Module *M, Type *Ty) {
   std::string str = hlsl::kPreservePrefix;

+ 3 - 1
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -359,8 +359,10 @@ void PassManagerBuilder::populateModulePassManager(
 
     if (!HLSLHighLevel) {
       MPM.add(createDxilConvergentClearPass());
-      MPM.add(createMultiDimArrayToOneDimArrayPass());
       MPM.add(createDxilRemoveDeadBlocksPass());
+      MPM.add(createDxilNoOptLegalizePass());
+      MPM.add(createGlobalOptimizerPass());
+      MPM.add(createMultiDimArrayToOneDimArrayPass());
       MPM.add(createDeadCodeEliminationPass());
       MPM.add(createGlobalDCEPass());
       MPM.add(createDxilLowerCreateHandleForLibPass());

+ 22 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_o0_legalize/selects.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc %s -T ps_6_0 -Od | FileCheck %s
+
+// Regression test for selecting on bad resources.
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+
+struct Foo {
+  Texture2D a, b;
+};
+
+// Loads from foo.a when x is non-zero, otherwise from foo.b, at offset off.
+float4 bar(uint x, int3 off, Foo foo) {
+  return x ? foo.a.Load(off) : foo.b.Load(off);
+}
+
+// Entry point. Only foo.a is assigned (foo.b is never initialized), and bar
+// is called with x == 1, so only the foo.a arm of its conditional is chosen.
+float4 main(int3 off : OFF) : SV_Target {
+  Foo foo;
+  foo.a = t0;
+  return bar(1, off, foo);
+}
+

+ 49 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_o0_legalize/store_undef.hlsl

@@ -0,0 +1,49 @@
+// RUN: %dxc %s -T ps_6_0 -Od | FileCheck %s
+
+// Regression test for validation failure in O0 due to
+// storing structure with uninitialized member.
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+
+struct Foo {
+  float a,b,c,d,e,f,g,h,i;
+};
+
+groupshared Foo foos[4];
+
+// Builds a Foo from x, y and z. Member c is intentionally never assigned, so
+// the returned struct has one uninitialized member.
+Foo make_foo(float x, float y, float z) {
+  Foo foo;
+  foo.a = x;
+  foo.b = y;
+  // foo.c is missing
+  foo.d = x;
+  foo.e = y;
+  foo.f = z;
+  foo.g = x;
+  foo.h = y;
+  foo.i = z;
+  return foo;
+}
+
+// Writes make_foo(x, y, z) into every element of the groupshared array foos,
+// iterating i from 3 down to 0 in an [unroll] loop.
+void foo(float x, float y, float z) {
+ [unroll]
+ for( int i = (4) - 1; i >= 0; --i ) {
+   foos[i] = make_foo( x, y, z );
+  }
+}
+
+// Returns member e of f; e is one of the members make_foo does assign.
+float bar(Foo f) {
+  return f.e;
+}
+
+// Entry point: fills foos via foo(1, 2, 0), then reads member e of foos[3].
+float main(uint3 off : OFF) : SV_Target {
+  foo(1, 2, 0);
+  return bar(foos[3]);
+}
+
+
+
+

+ 40 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/dxil_ops.hlsl

@@ -0,0 +1,40 @@
+// RUN: %dxc %s -T ps_6_0 -Od | FileCheck %s
+
+// Regression test for dxil operations not being evaluated.
+
+// CHECK: @main
+
+Texture2D t0 : register(t0);
+Texture2D t1 : register(t1);
+
+static const uint global = 1;
+static const uint global2 = 2;
+
+static const uint global3[3] = { 0, 1, 1 };
+
+cbuffer cb {
+  float bar, baz;
+};
+
+// Returns t0 or t1 depending on a value computed through mad() calls.
+// The trailing comments give the values for the call foo(1, 0, 2) made by
+// main: i ends up 1, so j is 0 and the else branch (t1) is taken.
+Texture2D foo(float x, float y, float z) {
+  int i;
+  i = mad(bar, 0, y); // 0
+  [branch]
+  if (mad(x, y, 0) == 0) { // true
+    i = mad(bar, 0, x); // 1
+  }
+
+  int j = i - 1;
+
+  if (j) {
+    return t0;
+  }
+  else {
+    return t1;
+  }
+}
+
+// Entry point: loads from the texture returned by foo(1, 0, 2) at off and
+// returns the x component.
+float main(uint3 off : OFF) : SV_Target {
+  return foo(1, 0, 2).Load(off).x;
+}
+

+ 40 - 0
tools/clang/test/HLSLFileCheck/passes/dxil/dxil_remove_dead_pass/store_only.hlsl

@@ -0,0 +1,40 @@
+// RUN: %dxc %s -T vs_6_0 -Od | FileCheck %s
+
+// CHECK: @main
+
+// Regression test to make sure resources used by stores to unused globals
+// are removed.
+
+struct Foo {
+  float4 member[8];
+};
+
+struct Bar {
+  Texture2D t0;
+};
+
+struct Baz {
+    Foo foo;
+    Bar bar;
+};
+
+Texture2D t0 : register(t0 , space6);
+cbuffer Baz_cbuffer : register(b0 , space6 ) {
+  Foo cb_foo;
+};
+
+// Builds a Baz whose foo member is copied from the cbuffer value cb_foo and
+// whose bar.t0 member is the texture t0.
+Baz CreateBaz() {
+  Baz i;
+  i.foo = cb_foo;
+  i.bar.t0 = t0;
+  return i;
+}
+
+static const Baz g_Baz = CreateBaz();
+
+// Entry point with an empty root signature. It returns a constant and never
+// reads g_Baz, so the global initialized by CreateBaz() is unused here.
+[RootSignature("")]
+float4 main() : SV_Position {
+  return float4(0,0,0,0);
+}
+
+

+ 1 - 0
utils/hct/hctdb.py

@@ -2144,6 +2144,7 @@ class db_dxil(object):
         ])
         add_pass('dxil-erase-dead-region', 'DxilEraseDeadRegion', 'DxilEraseDeadRegion', [])
         add_pass('dxil-remove-dead-blocks', 'DxilRemoveDeadBlocks', 'DxilRemoveDeadBlocks', [])
+        add_pass('dxil-o0-legalize', 'DxilNoOptLegalize', 'DXIL No-Opt Legalize', [])
         add_pass('loop-deletion', 'LoopDeletion', "Delete dead loops", [])
         add_pass('loop-interchange', 'LoopInterchange', 'Interchanges loops for cache reuse', [])
         add_pass('loop-unroll', 'LoopUnroll', 'Unroll loops', [