123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- ///////////////////////////////////////////////////////////////////////////////
- // //
- // DxilConvergent.cpp //
- // Copyright (C) Microsoft Corporation. All rights reserved. //
- // This file is distributed under the University of Illinois Open Source //
- // License. See LICENSE.TXT for details. //
- // //
- // Mark convergent for hlsl. //
- // //
- ///////////////////////////////////////////////////////////////////////////////
- #include "llvm/IR/BasicBlock.h"
- #include "llvm/IR/Dominators.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Intrinsics.h"
- #include "llvm/IR/Module.h"
- #include "llvm/Support/GenericDomTree.h"
- #include "llvm/Support/raw_os_ostream.h"
- #include "dxc/DXIL/DxilConstants.h"
- #include "dxc/HLSL/DxilGenerationPass.h"
- #include "dxc/HLSL/HLOperations.h"
- #include "dxc/HLSL/HLModule.h"
- #include "dxc/HlslIntrinsicOp.h"
- #include "dxc/HLSL/DxilConvergentName.h"
- using namespace llvm;
- using namespace hlsl;
- ///////////////////////////////////////////////////////////////////////////////
- // DxilConvergent.
- // Mark values as convergent so that sample coordinate calculations are not sunk into control flow.
- //
- namespace {
- class DxilConvergentMark : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit DxilConvergentMark() : ModulePass(ID) {}
- const char *getPassName() const override {
- return "DxilConvergentMark";
- }
- bool runOnModule(Module &M) override {
- if (M.HasHLModule()) {
- const ShaderModel *SM = M.GetHLModule().GetShaderModel();
- if (!SM->IsPS() && !SM->IsLib() && (!SM->IsSM66Plus() || (!SM->IsCS() && !SM->IsMS() && !SM->IsAS())))
- return false;
- }
- bool bUpdated = false;
- for (Function &F : M.functions()) {
- if (F.isDeclaration())
- continue;
- // Compute postdominator relation.
- DominatorTreeBase<BasicBlock> PDR(true);
- PDR.recalculate(F);
- for (BasicBlock &bb : F.getBasicBlockList()) {
- for (auto it = bb.begin(); it != bb.end();) {
- Instruction *I = (it++);
- if (Value *V = FindConvergentOperand(I)) {
- if (PropagateConvergent(V, &F, PDR)) {
- // TODO: emit warning here.
- }
- bUpdated = true;
- }
- }
- }
- }
- return bUpdated;
- }
- private:
- void MarkConvergent(Value *V, IRBuilder<> &Builder, Module &M);
- Value *FindConvergentOperand(Instruction *I);
- bool PropagateConvergent(Value *V, Function *F,
- DominatorTreeBase<BasicBlock> &PostDom);
- bool PropagateConvergentImpl(Value *V, Function *F,
- DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited);
- };
- char DxilConvergentMark::ID = 0;
- void DxilConvergentMark::MarkConvergent(Value *V, IRBuilder<> &Builder,
- Module &M) {
- Type *Ty = V->getType()->getScalarType();
- // Only work on vector/scalar types.
- if (Ty->isAggregateType() ||
- Ty->isPointerTy())
- return;
- FunctionType *FT = FunctionType::get(Ty, Ty, false);
- std::string str = kConvergentFunctionPrefix;
- raw_string_ostream os(str);
- Ty->print(os);
- os.flush();
- Function *ConvF = cast<Function>(M.getOrInsertFunction(str, FT));
- ConvF->addFnAttr(Attribute::AttrKind::Convergent);
- if (VectorType *VT = dyn_cast<VectorType>(V->getType())) {
- Value *ConvV = UndefValue::get(V->getType());
- std::vector<ExtractElementInst *> extractList(VT->getNumElements());
- for (unsigned i = 0; i < VT->getNumElements(); i++) {
- ExtractElementInst *EltV =
- cast<ExtractElementInst>(Builder.CreateExtractElement(V, i));
- extractList[i] = EltV;
- Value *EltC = Builder.CreateCall(ConvF, {EltV});
- ConvV = Builder.CreateInsertElement(ConvV, EltC, i);
- }
- V->replaceAllUsesWith(ConvV);
- for (ExtractElementInst *E : extractList) {
- E->setOperand(0, V);
- }
- } else {
- CallInst *ConvV = Builder.CreateCall(ConvF, {V});
- V->replaceAllUsesWith(ConvV);
- ConvV->setOperand(0, V);
- }
- }
- bool DxilConvergentMark::PropagateConvergent(
- Value *V, Function *F, DominatorTreeBase<BasicBlock> &PostDom) {
- std::set<Value *> visited;
- return PropagateConvergentImpl(V, F, PostDom, visited);
- }
- bool DxilConvergentMark::PropagateConvergentImpl(Value *V, Function *F,
- DominatorTreeBase<BasicBlock> &PostDom, std::set<Value*>& visited) {
- // Don't go through already visted nodes
- if (visited.find(V) != visited.end())
- return false;
- // Mark as visited
- visited.insert(V);
- // Skip constant.
- if (isa<Constant>(V))
- return false;
- // Skip phi which cannot sink.
- if (isa<PHINode>(V))
- return false;
- if (Instruction *I = dyn_cast<Instruction>(V)) {
- BasicBlock *BB = I->getParent();
- if (PostDom.dominates(BB, &F->getEntryBlock())) {
- IRBuilder<> Builder(I->getNextNode());
- MarkConvergent(I, Builder, *F->getParent());
- return false;
- } else {
- // Propagete to each operand of I.
- for (Use &U : I->operands()) {
- PropagateConvergentImpl(U.get(), F, PostDom, visited);
- }
- // return true for report warning.
- // TODO: static indexing cbuffer is fine.
- return true;
- }
- } else {
- IRBuilder<> EntryBuilder(F->getEntryBlock().getFirstInsertionPt());
- MarkConvergent(V, EntryBuilder, *F->getParent());
- return false;
- }
- }
- Value *DxilConvergentMark::FindConvergentOperand(Instruction *I) {
- if (CallInst *CI = dyn_cast<CallInst>(I)) {
- if (hlsl::GetHLOpcodeGroup(CI->getCalledFunction()) ==
- HLOpcodeGroup::HLIntrinsic) {
- IntrinsicOp IOP = static_cast<IntrinsicOp>(GetHLOpcode(CI));
- switch (IOP) {
- case IntrinsicOp::IOP_ddx:
- case IntrinsicOp::IOP_ddx_fine:
- case IntrinsicOp::IOP_ddx_coarse:
- case IntrinsicOp::IOP_ddy:
- case IntrinsicOp::IOP_ddy_fine:
- case IntrinsicOp::IOP_ddy_coarse:
- return CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
- case IntrinsicOp::MOP_Sample:
- case IntrinsicOp::MOP_SampleBias:
- case IntrinsicOp::MOP_SampleCmp:
- case IntrinsicOp::MOP_CalculateLevelOfDetail:
- case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
- return CI->getArgOperand(HLOperandIndex::kSampleCoordArgIndex);
- case IntrinsicOp::MOP_WriteSamplerFeedback:
- case IntrinsicOp::MOP_WriteSamplerFeedbackBias:
- return CI->getArgOperand(HLOperandIndex::kWriteSamplerFeedbackCoordArgIndex);
- default:
- // No other ops have convergent operands.
- break;
- }
- }
- }
- return nullptr;
- }
- } // namespace
- INITIALIZE_PASS(DxilConvergentMark, "hlsl-dxil-convergent-mark",
- "Mark convergent", false, false)
- ModulePass *llvm::createDxilConvergentMarkPass() {
- return new DxilConvergentMark();
- }
- namespace {
- class DxilConvergentClear : public ModulePass {
- public:
- static char ID; // Pass identification, replacement for typeid
- explicit DxilConvergentClear() : ModulePass(ID) {}
- const char *getPassName() const override {
- return "DxilConvergentClear";
- }
- bool runOnModule(Module &M) override {
- std::vector<Function *> convergentList;
- for (Function &F : M.functions()) {
- if (F.getName().startswith(kConvergentFunctionPrefix)) {
- convergentList.emplace_back(&F);
- }
- }
- for (Function *F : convergentList) {
- ClearConvergent(F);
- }
- return convergentList.size();
- }
- private:
- void ClearConvergent(Function *F);
- };
- char DxilConvergentClear::ID = 0;
- void DxilConvergentClear::ClearConvergent(Function *F) {
- // Replace all users with arg.
- for (auto it = F->user_begin(); it != F->user_end();) {
- CallInst *CI = cast<CallInst>(*(it++));
- Value *arg = CI->getArgOperand(0);
- CI->replaceAllUsesWith(arg);
- CI->eraseFromParent();
- }
- F->eraseFromParent();
- }
- } // namespace
- INITIALIZE_PASS(DxilConvergentClear, "hlsl-dxil-convergent-clear",
- "Clear convergent before dxil emit", false, false)
- ModulePass *llvm::createDxilConvergentClearPass() {
- return new DxilConvergentClear();
- }
|