Browse Source

Don't sink sample coordinate into control flow. (#1188)

Xiang Li 7 years ago
parent
commit
357803d342

+ 4 - 0
include/dxc/HLSL/DxilGenerationPass.h

@@ -51,6 +51,8 @@ ModulePass *createHLEnsureMetadataPass();
 ModulePass *createDxilFinalizeModulePass();
 ModulePass *createDxilEmitMetadataPass();
 FunctionPass *createDxilExpandTrigIntrinsicsPass();
+ModulePass *createDxilConvergentMarkPass();
+ModulePass *createDxilConvergentClearPass();
 ModulePass *createDxilLoadMetadataPass();
 ModulePass *createDxilDeadFunctionEliminationPass();
 ModulePass *createHLDeadFunctionEliminationPass();
@@ -79,6 +81,8 @@ void initializeDxilLoadMetadataPass(llvm::PassRegistry&);
 void initializeDxilDeadFunctionEliminationPass(llvm::PassRegistry&);
 void initializeHLDeadFunctionEliminationPass(llvm::PassRegistry&);
 void initializeHLPreprocessPass(llvm::PassRegistry&);
+void initializeDxilConvergentMarkPass(llvm::PassRegistry&);
+void initializeDxilConvergentClearPass(llvm::PassRegistry&);
 void initializeDxilPrecisePropagatePassPass(llvm::PassRegistry&);
 void initializeDxilPreserveAllOutputsPass(llvm::PassRegistry&);
 void initializeDxilLegalizeResourceUsePassPass(llvm::PassRegistry&);

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -10,6 +10,7 @@ add_llvm_library(LLVMHLSL
   DxilContainer.cpp
   DxilContainerAssembler.cpp
   DxilContainerReflection.cpp
+  DxilConvergent.cpp
   DxilDebugInstrumentation.cpp
   DxilEliminateOutputDynamicIndexing.cpp
   DxilExpandTrigIntrinsics.cpp

+ 2 - 0
lib/HLSL/DxcOptimizer.cpp

@@ -87,6 +87,8 @@ HRESULT SetupRegistryPassForHLSL() {
     initializeDeadInstEliminationPass(Registry);
     initializeDxilAddPixelHitInstrumentationPass(Registry);
     initializeDxilCondenseResourcesPass(Registry);
+    initializeDxilConvergentClearPass(Registry);
+    initializeDxilConvergentMarkPass(Registry);
     initializeDxilDeadFunctionEliminationPass(Registry);
     initializeDxilDebugInstrumentationPass(Registry);
     initializeDxilEliminateOutputDynamicIndexingPass(Registry);

+ 249 - 0
lib/HLSL/DxilConvergent.cpp

@@ -0,0 +1,249 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// DxilConvergent.cpp                                                        //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Mark convergent for hlsl.                                                 //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/GenericDomTree.h"
+#include "llvm/Support/raw_os_ostream.h"
+
+#include "dxc/HLSL/DxilConstants.h"
+#include "dxc/HLSL/DxilGenerationPass.h"
+#include "dxc/HLSL/HLOperations.h"
+#include "dxc/HLSL/HLModule.h"
+#include "dxc/HlslIntrinsicOp.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+namespace { // File-local constant used by both the mark and the clear pass.
+const StringRef kConvergentFunctionPrefix = "dxil.convergent.marker."; // Name prefix of every marker function; MarkConvergent appends the printed scalar type.
+} // namespace
+
+///////////////////////////////////////////////////////////////////////////////
+// DxilConvergent.
+// Mark convergent so sample coordinate calculation is not sunk into control flow.
+//
+namespace {
+
+/// Module pass that wraps the operands reported by FindConvergentOperand
+/// (sample/gather coordinates and derivative sources) in calls to marker
+/// functions carrying the Convergent attribute.
+class DxilConvergentMark : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilConvergentMark() : ModulePass(ID) {}
+
+  /// Name reported for this pass.
+  const char *getPassName() const override { return "DxilConvergentMark"; }
+
+  /// Visits every instruction of every defined function and propagates the
+  /// convergent marking for each operand FindConvergentOperand reports.
+  /// Returns true as soon as at least one such operand was found.
+  bool runOnModule(Module &M) override {
+    // When an HLModule is attached, only pixel shaders are processed.
+    if (M.HasHLModule() && !M.GetHLModule().GetShaderModel()->IsPS())
+      return false;
+
+    bool bChanged = false;
+    for (Function &Fn : M.functions()) {
+      if (Fn.isDeclaration())
+        continue;
+
+      // Compute the postdominator relation for this function.
+      DominatorTreeBase<BasicBlock> PostDomTree(true);
+      PostDomTree.recalculate(Fn);
+
+      for (BasicBlock &Block : Fn.getBasicBlockList()) {
+        auto Cur = Block.begin();
+        while (Cur != Block.end()) {
+          // Step the iterator before handling the instruction, because
+          // PropagateConvergent may insert new instructions.
+          Instruction *Inst = (Cur++);
+          Value *Operand = FindConvergentOperand(Inst);
+          if (!Operand)
+            continue;
+          if (PropagateConvergent(Operand, &Fn, PostDomTree)) {
+            // TODO: emit warning here.
+          }
+          bChanged = true;
+        }
+      }
+    }
+
+    return bChanged;
+  }
+
+private:
+  void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
+  Value *FindConvergentOperand(Instruction *I);
+  bool PropagateConvergent(Value *V, Function *F,
+                           DominatorTreeBase<BasicBlock> &PostDom);
+};
+
+char DxilConvergentMark::ID = 0; // Definition of the static pass identifier declared in DxilConvergentMark.
+
+/// Routes every existing use of \p V through a call to a marker function that
+/// carries the Convergent attribute. New instructions are created with
+/// \p Builder; the marker function is looked up or created in \p M.
+/// Vectors are marked element by element and rebuilt with insertelement.
+/// Values whose scalar type is a pointer or an aggregate are left untouched.
+void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
+                                        Module &M) {
+  Type *ValueTy = V->getType();
+  Type *EltTy = ValueTy->getScalarType();
+  // Only work on vector/scalar types.
+  if (EltTy->isPointerTy() || EltTy->isAggregateType())
+    return;
+
+  // The marker name is the shared prefix followed by the printed scalar type.
+  std::string MarkerName = kConvergentFunctionPrefix;
+  raw_string_ostream NameStream(MarkerName);
+  EltTy->print(NameStream);
+  NameStream.flush();
+
+  FunctionType *MarkerTy = FunctionType::get(EltTy, EltTy, false);
+  Function *Marker =
+      cast<Function>(M.getOrInsertFunction(MarkerName, MarkerTy));
+  Marker->addFnAttr(Attribute::AttrKind::Convergent);
+
+  VectorType *VecTy = dyn_cast<VectorType>(ValueTy);
+  if (!VecTy) {
+    CallInst *Marked = Builder.CreateCall(Marker, {V});
+    V->replaceAllUsesWith(Marked);
+    // The replacement above also rewrote the marker's own argument; point it
+    // back at V.
+    Marked->setOperand(0, V);
+    return;
+  }
+
+  const unsigned NumElts = VecTy->getNumElements();
+  std::vector<ExtractElementInst *> Extracts(NumElts);
+  Value *Rebuilt = UndefValue::get(ValueTy);
+  for (unsigned Idx = 0; Idx < NumElts; Idx++) {
+    ExtractElementInst *Elt =
+        cast<ExtractElementInst>(Builder.CreateExtractElement(V, Idx));
+    Extracts[Idx] = Elt;
+    Value *MarkedElt = Builder.CreateCall(Marker, {Elt});
+    Rebuilt = Builder.CreateInsertElement(Rebuilt, MarkedElt, Idx);
+  }
+  V->replaceAllUsesWith(Rebuilt);
+  // The replacement above also redirected the extracts to the rebuilt vector;
+  // restore V as their source.
+  for (ExtractElementInst *Elt : Extracts)
+    Elt->setOperand(0, V);
+}
+
+/// Marks \p V convergent, or walks up its operands when \p V is defined in a
+/// block that does not post-dominate the entry block of \p F.
+/// Returns true only when \p V itself is such an instruction (the caller may
+/// use this to report a warning); returns false in every other case.
+bool DxilConvergentMark::PropagateConvergent(
+    Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
+  // Constants are skipped, and so are phis, which cannot sink.
+  if (isa<Constant>(V) || isa<PHINode>(V))
+    return false;
+
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst) {
+    // Not an instruction: insert the marker at the first insertion point of
+    // the entry block.
+    IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
+    MarkConvergent(V, EntryBuilder, *F->getParent());
+    return false;
+  }
+
+  if (PostDom.dominates(Inst->getParent(), &F->getEntryBlock())) {
+    // The defining block post-dominates the entry block: mark right after the
+    // defining instruction.
+    IRBuilder<> Builder(Inst->getNextNode());
+    MarkConvergent(Inst, Builder, *F->getParent());
+    return false;
+  }
+
+  // Propagate to each operand of the instruction; the nested results are
+  // intentionally ignored.
+  for (Use &Op : Inst->operands())
+    PropagateConvergent(Op.get(), F, PostDom);
+
+  // return true for report warning.
+  // TODO: static indexing cbuffer is fine.
+  return true;
+}
+
+/// Returns the operand of \p I that the pass should mark convergent, or
+/// nullptr when \p I is not one of the handled HL intrinsic calls.
+///
+/// Handled calls are the ddx/ddy derivative intrinsics (their source operand),
+/// the Sample*/CalculateLevelOfDetail* methods (their coordinate argument) and
+/// the Gather* methods (their coordinate argument).
+Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
+  if (CallInst *CI = dyn_cast<CallInst>(I)) {
+    if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
+        HLOpcodeGroup::HLIntrinsic) {
+      IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
+      switch (IOP) {
+      case IntrinsicOp::IOP_ddx:
+      case IntrinsicOp::IOP_ddx_fine:
+      case IntrinsicOp::IOP_ddx_coarse:
+      case IntrinsicOp::IOP_ddy:
+      case IntrinsicOp::IOP_ddy_fine:
+      case IntrinsicOp::IOP_ddy_coarse:
+        return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
+      case IntrinsicOp::MOP_Sample:
+      case IntrinsicOp::MOP_SampleBias:
+      case IntrinsicOp::MOP_SampleCmp:
+      case IntrinsicOp::MOP_SampleCmpLevelZero:
+      case IntrinsicOp::MOP_CalculateLevelOfDetail:
+      case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
+        return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
+      case IntrinsicOp::MOP_Gather:
+      case IntrinsicOp::MOP_GatherAlpha:
+      case IntrinsicOp::MOP_GatherBlue:
+      case IntrinsicOp::MOP_GatherCmp:
+      case IntrinsicOp::MOP_GatherCmpAlpha:
+      case IntrinsicOp::MOP_GatherCmpBlue:
+      case IntrinsicOp::MOP_GatherCmpGreen:
+      case IntrinsicOp::MOP_GatherCmpRed:
+      case IntrinsicOp::MOP_GatherGreen:
+      case IntrinsicOp::MOP_GatherRed:
+        return CI->getArgOperand(HLOperandIndex::kGatherCoordArgIndex);
+      default:
+        // Every other intrinsic has no operand to mark; the explicit default
+        // also keeps the partially covered enum switch free of -Wswitch
+        // warnings.
+        break;
+      }
+    }
+  }
+  return nullptr;
+}
+
+} // namespace
+
+INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
+                "Mark convergent", false, false) // Provides initializeDxilConvergentMarkPass, declared in DxilGenerationPass.h.
+
+ModulePass *llvm::createDxilConvergentMarkPass() {
+  return new DxilConvergentMark(); // Hands back a newly allocated pass instance.
+}
+
+namespace {
+
+/// Module pass that removes every function whose name starts with
+/// kConvergentFunctionPrefix, after forwarding the argument of each call to
+/// the users of that call.
+class DxilConvergentClear : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  explicit DxilConvergentClear() : ModulePass(ID) {}
+
+  /// Name reported for this pass.
+  const char *getPassName() const override {
+    return "DxilConvergentClear";
+  }
+
+  /// Erases all marker functions in \p M together with their calls.
+  /// Returns true when at least one marker function was found and removed.
+  bool runOnModule(Module &M) override {
+    // Collect the marker functions first so that the module's function list
+    // is not modified while it is being iterated.
+    std::vector<Function *> convergentList;
+    for (Function &F : M.functions()) {
+      if (F.getName().startswith(kConvergentFunctionPrefix)) {
+        convergentList.emplace_back(&F);
+      }
+    }
+
+    for (Function *F : convergentList) {
+      ClearConvergent(F);
+    }
+    // Explicit bool instead of an implicit size_t -> bool conversion.
+    return !convergentList.empty();
+  }
+
+private:
+  void ClearConvergent(Function *F);
+};
+
+char DxilConvergentClear::ID = 0; // Definition of the static pass identifier declared in DxilConvergentClear.
+
+/// Removes the marker function \p F: every call to it is replaced by the
+/// call's first argument and erased, then \p F itself is erased.
+/// Every user of \p F is expected to be a CallInst.
+void DxilConvergentClear::ClearConvergent(Function *F) {
+  auto UserIt = F->user_begin();
+  while (UserIt != F->user_end()) {
+    // Move to the next user before the current call is erased.
+    CallInst *MarkerCall = cast<CallInst>(*(UserIt++));
+    Value *Original = MarkerCall->getArgOperand(0);
+    MarkerCall->replaceAllUsesWith(Original);
+    MarkerCall->eraseFromParent();
+  }
+
+  F->eraseFromParent();
+}
+
+} // namespace
+
+INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
+                "Clear convergent before dxil emit", false, false) // Provides initializeDxilConvergentClearPass, declared in DxilGenerationPass.h.
+
+ModulePass *llvm::createDxilConvergentClearPass() {
+  return new DxilConvergentClear(); // Hands back a newly allocated pass instance.
+}

+ 4 - 0
lib/Transforms/IPO/PassManagerBuilder.cpp

@@ -242,6 +242,8 @@ static void addHLSLPasses(bool HLSLHighLevel, unsigned OptLevel, hlsl::HLSLExten
     MPM.add(createLowerStaticGlobalIntoAlloca());
     // mem2reg
     MPM.add(createPromoteMemoryToRegisterPass());
+
+    MPM.add(createDxilConvergentMarkPass());
   }
 
   if (OptLevel > 2) {
@@ -301,6 +303,7 @@ void PassManagerBuilder::populateModulePassManager(
     // HLSL Change Begins.
     addHLSLPasses(HLSLHighLevel, OptLevel, HLSLExtensionsCodeGen, MPM);
     if (!HLSLHighLevel) {
+      MPM.add(createDxilConvergentClearPass());
       MPM.add(createMultiDimArrayToOneDimArrayPass());
       MPM.add(createDxilCondenseResourcesPass());
       MPM.add(createDxilLegalizeSampleOffsetPass());
@@ -573,6 +576,7 @@ void PassManagerBuilder::populateModulePassManager(
 
   // HLSL Change Begins.
   if (!HLSLHighLevel) {
+    MPM.add(createDxilConvergentClearPass());
     MPM.add(createMultiDimArrayToOneDimArrayPass());
     MPM.add(createDxilCondenseResourcesPass());
     MPM.add(createDeadCodeEliminationPass());

+ 19 - 0
tools/clang/test/CodeGenHLSL/quick-test/convergent.hlsl

@@ -0,0 +1,19 @@
+// RUN: %dxc -T ps_6_1 -E main %s | FileCheck %s
+
+// Make sure the add is not sunk into the if.
+// CHECK: fadd
+// CHECK: fadd
+// CHECK: if.then
+
+Texture2D<float4> tex;
+SamplerState s;
+float4 main(float2 a:A, float b:B) : SV_Target {
+
+  float2 coord = a + b; // Sample coordinate, computed before the branch (the two expected fadds).
+  float4 c = b;
+  if (b > 2) {
+    c += tex.Sample(s, coord); // The only use of coord sits inside the conditional.
+  }
+  return c;
+
+}

+ 2 - 0
utils/hct/hctdb.py

@@ -1306,6 +1306,8 @@ class db_dxil(object):
         add_pass('hlsl-passes-pause', 'PausePasses', 'Prepare to pause passes', [])
         add_pass('hlsl-passes-resume', 'ResumePasses', 'Prepare to resume passes', [])
         add_pass('hlsl-dxil-condense', 'DxilCondenseResources', 'DXIL Condense Resources', [])
+        add_pass('hlsl-dxil-convergent-mark', 'DxilConvergentMark', 'Mark convergent', [])
+        add_pass('hlsl-dxil-convergent-clear', 'DxilConvergentClear', 'Clear convergent before dxil emit', [])
         add_pass('hlsl-dxil-eliminate-output-dynamic', 'DxilEliminateOutputDynamicIndexing', 'DXIL eliminate ouptut dynamic indexing', [])
         add_pass('hlsl-dxil-add-pixel-hit-instrmentation', 'DxilAddPixelHitInstrumentation', 'DXIL Count completed PS invocations and costs', [
             {'n':'force-early-z','t':'int','c':1},