|
@@ -0,0 +1,249 @@
|
|
|
|
+///////////////////////////////////////////////////////////////////////////////
|
|
|
|
+// //
|
|
|
|
+// DxilConvergent.cpp //
|
|
|
|
+// Copyright (C) Microsoft Corporation. All rights reserved. //
|
|
|
|
+// This file is distributed under the University of Illinois Open Source //
|
|
|
|
+// License. See LICENSE.TXT for details. //
|
|
|
|
+// //
|
|
|
|
+// Mark convergent for hlsl. //
|
|
|
|
+// //
|
|
|
|
+///////////////////////////////////////////////////////////////////////////////
|
|
|
|
+
|
|
|
|
+#include "llvm/IR/BasicBlock.h"
|
|
|
|
+#include "llvm/IR/Dominators.h"
|
|
|
|
+#include "llvm/IR/Function.h"
|
|
|
|
+#include "llvm/IR/IRBuilder.h"
|
|
|
|
+#include "llvm/IR/Intrinsics.h"
|
|
|
|
+#include "llvm/IR/Module.h"
|
|
|
|
+#include "llvm/Support/GenericDomTree.h"
|
|
|
|
+#include "llvm/Support/raw_os_ostream.h"
|
|
|
|
+
|
|
|
|
+#include "dxc/HLSL/DxilConstants.h"
|
|
|
|
+#include "dxc/HLSL/DxilGenerationPass.h"
|
|
|
|
+#include "dxc/HLSL/HLOperations.h"
|
|
|
|
+#include "dxc/HLSL/HLModule.h"
|
|
|
|
+#include "dxc/HlslIntrinsicOp.h"
|
|
|
|
+
|
|
|
|
+using namespace llvm;
|
|
|
|
+using namespace hlsl;
|
|
|
|
+
|
|
|
|
+namespace {
|
|
|
|
+const StringRef kConvergentFunctionPrefix = "dxil.convergent.marker.";
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+///////////////////////////////////////////////////////////////////////////////
|
|
|
|
+// DxilConvergent.
|
|
|
|
+// Mark convergent to keep sample coordinate calculation from sinking into
+// control flow.
|
|
|
|
+//
|
|
|
|
+namespace {
|
|
|
|
+
|
|
|
|
+class DxilConvergentMark : public ModulePass {
|
|
|
|
+public:
|
|
|
|
+ static char ID; // Pass identification, replacement for typeid
|
|
|
|
+ explicit DxilConvergentMark() : ModulePass(ID) {}
|
|
|
|
+
|
|
|
|
+ const char *getPassName() const override {
|
|
|
|
+ return "DxilConvergentMark";
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ bool runOnModule(Module &M) override {
|
|
|
|
+ if (M.HasHLModule()) {
|
|
|
|
+ if (!M.GetHLModule().GetShaderModel()->IsPS())
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ bool bUpdated = false;
|
|
|
|
+
|
|
|
|
+ for (Function &F : M.functions()) {
|
|
|
|
+ if (F.isDeclaration())
|
|
|
|
+ continue;
|
|
|
|
+
|
|
|
|
+ // Compute postdominator relation.
|
|
|
|
+ DominatorTreeBase<BasicBlock> PDR(true);
|
|
|
|
+ PDR.recalculate(F);
|
|
|
|
+ for (BasicBlock &bb : F.getBasicBlockList()) {
|
|
|
|
+ for (auto it = bb.begin(); it != bb.end();) {
|
|
|
|
+ Instruction *I = (it++);
|
|
|
|
+ if (Value *V = FindConvergentOperand(I)) {
|
|
|
|
+ if (PropagateConvergent(V, &F, PDR)) {
|
|
|
|
+ // TODO: emit warning here.
|
|
|
|
+ }
|
|
|
|
+ bUpdated = true;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return bUpdated;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+private:
|
|
|
|
+ void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
|
|
|
|
+ Value *FindConvergentOperand(Instruction *I);
|
|
|
|
+ bool PropagateConvergent(Value *V, Function *F,
|
|
|
|
+ DominatorTreeBase<BasicBlock> &PostDom);
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+char DxilConvergentMark::ID = 0;
|
|
|
|
+
|
|
|
|
+void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
|
|
|
|
+ Module &M) {
|
|
|
|
+ Type *Ty = V->getType()->getScalarType();
|
|
|
|
+ // Only work on vector/scalar types.
|
|
|
|
+ if (Ty->isAggregateType() ||
|
|
|
|
+ Ty->isPointerTy())
|
|
|
|
+ return;
|
|
|
|
+ FunctionType *FT = FunctionType::get(Ty, Ty, false);
|
|
|
|
+ std::string str = kConvergentFunctionPrefix;
|
|
|
|
+ raw_string_ostream os(str);
|
|
|
|
+ Ty->print(os);
|
|
|
|
+ os.flush();
|
|
|
|
+ Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
|
|
|
|
+ ConvF->addFnAttr(Attribute::AttrKind::Convergent);
|
|
|
|
+ if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
|
|
|
|
+ Value *ConvV = UndefValue::get(V->getType());
|
|
|
|
+ std::vector<ExtractElementInst *> extractList(VT->getNumElements());
|
|
|
|
+ for (unsigned i = 0; i < VT->getNumElements(); i++) {
|
|
|
|
+ ExtractElementInst *EltV =
|
|
|
|
+ cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
|
|
|
|
+ extractList[i] = EltV;
|
|
|
|
+ Value *EltC = Builder.CreateCall(ConvF, {EltV});
|
|
|
|
+ ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
|
|
|
|
+ }
|
|
|
|
+ V->replaceAllUsesWith(ConvV);
|
|
|
|
+ for (ExtractElementInst *E : extractList) {
|
|
|
|
+ E->setOperand(0, V);
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ CallInst *ConvV = Builder.CreateCall(ConvF, {V});
|
|
|
|
+ V->replaceAllUsesWith(ConvV);
|
|
|
|
+ ConvV->setOperand(0, V);
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+bool DxilConvergentMark::PropagateConvergent(
|
|
|
|
+ Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
|
|
|
|
+ // Skip constant.
|
|
|
|
+ if (isa<Constant>(V))
|
|
|
|
+ return false;
|
|
|
|
+ // Skip phi which cannot sink.
|
|
|
|
+ if (isa<PHINode>(V))
|
|
|
|
+ return false;
|
|
|
|
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
|
|
|
|
+ BasicBlock *BB = I->getParent();
|
|
|
|
+ if (PostDom.dominates(BB, &F->getEntryBlock())) {
|
|
|
|
+ IRBuilder<> Builder(I->getNextNode());
|
|
|
|
+ MarkConvergent(I, Builder, *F->getParent());
|
|
|
|
+ return false;
|
|
|
|
+ } else {
|
|
|
|
+ // Propagete to each operand of I.
|
|
|
|
+ for (Use &U : I->operands()) {
|
|
|
|
+ PropagateConvergent(U.get(), F, PostDom);
|
|
|
|
+ }
|
|
|
|
+ // return true for report warning.
|
|
|
|
+ // TODO: static indexing cbuffer is fine.
|
|
|
|
+ return true;
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
|
|
|
|
+ MarkConvergent(V, EntryBuilder, *F->getParent());
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
|
|
|
|
+ if (CallInst *CI = dyn_cast<CallInst>(I)) {
|
|
|
|
+ if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
|
|
|
|
+ HLOpcodeGroup::HLIntrinsic) {
|
|
|
|
+ IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
|
|
|
|
+ switch (IOP) {
|
|
|
|
+ case IntrinsicOp::IOP_ddx:
|
|
|
|
+ case IntrinsicOp::IOP_ddx_fine:
|
|
|
|
+ case IntrinsicOp::IOP_ddx_coarse:
|
|
|
|
+ case IntrinsicOp::IOP_ddy:
|
|
|
|
+ case IntrinsicOp::IOP_ddy_fine:
|
|
|
|
+ case IntrinsicOp::IOP_ddy_coarse:
|
|
|
|
+ return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
|
|
|
|
+ case IntrinsicOp::MOP_Sample:
|
|
|
|
+ case IntrinsicOp::MOP_SampleBias:
|
|
|
|
+ case IntrinsicOp::MOP_SampleCmp:
|
|
|
|
+ case IntrinsicOp::MOP_SampleCmpLevelZero:
|
|
|
|
+ case IntrinsicOp::MOP_CalculateLevelOfDetail:
|
|
|
|
+ case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
|
|
|
|
+ return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
|
|
|
|
+ case IntrinsicOp::MOP_Gather:
|
|
|
|
+ case IntrinsicOp::MOP_GatherAlpha:
|
|
|
|
+ case IntrinsicOp::MOP_GatherBlue:
|
|
|
|
+ case IntrinsicOp::MOP_GatherCmp:
|
|
|
|
+ case IntrinsicOp::MOP_GatherCmpAlpha:
|
|
|
|
+ case IntrinsicOp::MOP_GatherCmpBlue:
|
|
|
|
+ case IntrinsicOp::MOP_GatherCmpGreen:
|
|
|
|
+ case IntrinsicOp::MOP_GatherCmpRed:
|
|
|
|
+ case IntrinsicOp::MOP_GatherGreen:
|
|
|
|
+ case IntrinsicOp::MOP_GatherRed:
|
|
|
|
+ return CI->getArgOperand(HLOperandIndex::kGatherCoordArgIndex);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return nullptr;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+} // namespace
|
|
|
|
+
|
|
|
|
+INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
|
|
|
|
+ "Mark convergent", false, false)
|
|
|
|
+
|
|
|
|
+ModulePass *llvm::createDxilConvergentMarkPass() {
|
|
|
|
+ return new DxilConvergentMark();
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+namespace {
|
|
|
|
+
|
|
|
|
+class DxilConvergentClear : public ModulePass {
|
|
|
|
+public:
|
|
|
|
+ static char ID; // Pass identification, replacement for typeid
|
|
|
|
+ explicit DxilConvergentClear() : ModulePass(ID) {}
|
|
|
|
+
|
|
|
|
+ const char *getPassName() const override {
|
|
|
|
+ return "DxilConvergentClear";
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ bool runOnModule(Module &M) override {
|
|
|
|
+ std::vector<Function *> convergentList;
|
|
|
|
+ for (Function &F : M.functions()) {
|
|
|
|
+ if (F.getName().startswith(kConvergentFunctionPrefix)) {
|
|
|
|
+ convergentList.emplace_back(&F);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for (Function *F : convergentList) {
|
|
|
|
+ ClearConvergent(F);
|
|
|
|
+ }
|
|
|
|
+ return convergentList.size();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+private:
|
|
|
|
+ void ClearConvergent(Function *F);
|
|
|
|
+};
|
|
|
|
+
|
|
|
|
+char DxilConvergentClear::ID = 0;
|
|
|
|
+
|
|
|
|
+void DxilConvergentClear::ClearConvergent(Function *F) {
|
|
|
|
+ // Replace all users with arg.
|
|
|
|
+ for (auto it = F->user_begin(); it != F->user_end();) {
|
|
|
|
+ CallInst *CI = cast<CallInst>(*(it++));
|
|
|
|
+ Value *arg = CI->getArgOperand(0);
|
|
|
|
+ CI->replaceAllUsesWith(arg);
|
|
|
|
+ CI->eraseFromParent();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ F->eraseFromParent();
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+} // namespace
|
|
|
|
+
|
|
|
|
+INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
|
|
|
|
+ "Clear convergent before dxil emit", false, false)
|
|
|
|
+
|
|
|
|
+ModulePass *llvm::createDxilConvergentClearPass() {
|
|
|
|
+ return new DxilConvergentClear();
|
|
|
|
+}
|