DxilPrecisePropagatePass.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPrecisePropagatePass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/DXIL/DxilModule.h"
  10. #include "dxc/HLSL/DxilGenerationPass.h"
  11. #include "dxc/HLSL/HLModule.h"
  12. #include "dxc/HLSL/HLOperations.h"
  13. #include "dxc/HLSL/ControlDependence.h"
  14. #include "llvm/Pass.h"
  15. #include "llvm/IR/Function.h"
  16. #include "llvm/IR/Instruction.h"
  17. #include "llvm/IR/Instructions.h"
  18. #include "llvm/IR/Operator.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/Casting.h"
  21. #include <unordered_set>
  22. #include <vector>
  23. using namespace llvm;
  24. using namespace hlsl;
  25. namespace {
  26. typedef std::unordered_set<Value *> ValueSet;
  27. struct FuncInfo {
  28. ControlDependence CtrlDep;
  29. std::unique_ptr<llvm::DominatorTreeBase<llvm::BasicBlock>> pPostDom;
  30. void Init(Function *F);
  31. void Clear();
  32. };
  33. typedef std::unordered_map<llvm::Function *, std::unique_ptr<FuncInfo>> FuncInfoMap;
  34. class DxilPrecisePropagatePass : public ModulePass {
  35. public:
  36. static char ID; // Pass identification, replacement for typeid
  37. explicit DxilPrecisePropagatePass() : ModulePass(ID) {}
  38. const char *getPassName() const override { return "DXIL Precise Propagate"; }
  39. bool runOnModule(Module &M) override {
  40. m_pDM = &(M.GetOrCreateDxilModule());
  41. std::vector<Function*> deadList;
  42. for (Function &F : M.functions()) {
  43. if (HLModule::HasPreciseAttribute(&F)) {
  44. PropagatePreciseOnFunctionUser(F);
  45. deadList.emplace_back(&F);
  46. }
  47. }
  48. for (Function *F : deadList)
  49. F->eraseFromParent();
  50. return true;
  51. }
  52. private:
  53. void PropagatePreciseOnFunctionUser(Function &F);
  54. void AddToWorkList(Value *V);
  55. void ProcessWorkList();
  56. void Propagate(Instruction *I);
  57. void PropagateOnPointer(Value *Ptr);
  58. void PropagateOnPointerUsers(Value *Ptr);
  59. void PropagateThroughGEPs(Value *Ptr, ArrayRef<Value *> idxList,
  60. ValueSet &processedGEPs);
  61. void PropagateOnPointerUsedInCall(Value *Ptr, CallInst *CI);
  62. void PropagateCtrlDep(FuncInfo &FI, BasicBlock *BB);
  63. void PropagateCtrlDep(BasicBlock *BB);
  64. void PropagateCtrlDep(Instruction *I);
  65. // Add to m_ProcessedSet, return true if already in set.
  66. bool Processed(Value *V) {
  67. return !m_ProcessedSet.insert(V).second;
  68. }
  69. FuncInfo &GetFuncInfo(Function *F);
  70. DxilModule *m_pDM;
  71. std::vector<Value*> m_WorkList;
  72. ValueSet m_ProcessedSet;
  73. FuncInfoMap m_FuncInfo;
  74. };
  75. char DxilPrecisePropagatePass::ID = 0;
  76. }
  77. void DxilPrecisePropagatePass::PropagatePreciseOnFunctionUser(Function &F) {
  78. for (auto U = F.user_begin(), E = F.user_end(); U != E;) {
  79. CallInst *CI = cast<CallInst>(*(U++));
  80. Value *V = CI->getArgOperand(0);
  81. AddToWorkList(V);
  82. ProcessWorkList();
  83. CI->eraseFromParent();
  84. }
  85. }
  86. void DxilPrecisePropagatePass::AddToWorkList(Value *V) {
  87. // Skip values already marked.
  88. if (Processed(V))
  89. return;
  90. m_WorkList.emplace_back(V);
  91. }
  92. void DxilPrecisePropagatePass::ProcessWorkList() {
  93. while (!m_WorkList.empty()) {
  94. Value *V = m_WorkList.back();
  95. m_WorkList.pop_back();
  96. if (V->getType()->isPointerTy()) {
  97. PropagateOnPointer(V);
  98. }
  99. Instruction *I = dyn_cast<Instruction>(V);
  100. if (!I)
  101. continue;
  102. // Set precise fast math on those instructions that support it.
  103. if (DxilModule::PreservesFastMathFlags(I))
  104. DxilModule::SetPreciseFastMathFlags(I);
  105. // Fast math not work on call, use metadata.
  106. if (isa<FPMathOperator>(I) && isa<CallInst>(I))
  107. HLModule::MarkPreciseAttributeWithMetadata(cast<CallInst>(I));
  108. Propagate(I);
  109. PropagateCtrlDep(I);
  110. }
  111. }
  112. void DxilPrecisePropagatePass::Propagate(Instruction *I) {
  113. if (CallInst *CI = dyn_cast<CallInst>(I)) {
  114. for (unsigned i = 0; i < CI->getNumArgOperands(); i++)
  115. AddToWorkList(CI->getArgOperand(i));
  116. } else {
  117. for (Value *src : I->operands())
  118. AddToWorkList(src);
  119. }
  120. if (PHINode *Phi = dyn_cast<PHINode>(I)) {
  121. // Use pred for control dependence when constant (for now)
  122. FuncInfo &FI = GetFuncInfo(I->getParent()->getParent());
  123. for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) {
  124. if (isa<Constant>(Phi->getIncomingValue(i)))
  125. PropagateCtrlDep(FI, Phi->getIncomingBlock(i));
  126. }
  127. }
  128. }
  129. // TODO: This could be a util function
  130. // TODO: Should this tunnel through addrspace cast?
  131. // And how could bitcast be handled?
  132. static Value *GetRootAndIndicesForGEP(
  133. GEPOperator *GEP, SmallVectorImpl<Value*> &idxList) {
  134. Value *Ptr = GEP;
  135. SmallVector<GEPOperator*, 4> GEPs;
  136. GEPs.emplace_back(GEP);
  137. while ((GEP = dyn_cast<GEPOperator>(Ptr = GEP->getPointerOperand())))
  138. GEPs.emplace_back(GEP);
  139. while (!GEPs.empty()) {
  140. GEP = GEPs.back();
  141. GEPs.pop_back();
  142. auto idx = GEP->idx_begin();
  143. idx++;
  144. while (idx != GEP->idx_end())
  145. idxList.emplace_back(*(idx++));
  146. }
  147. return Ptr;
  148. }
  149. void DxilPrecisePropagatePass::PropagateOnPointer(Value *Ptr) {
  150. PropagateOnPointerUsers(Ptr);
  151. // GetElementPointer gets special treatment since different GEPs may be used
  152. // at different points on the same root pointer to load or store data. We
  153. // need to find any stores that could have written data to the pointer we are
  154. // marking, so we need to search through all GEPs from the root pointer for
  155. // ones that may write to the same location.
  156. //
  157. // In addition, there may be multiple GEPs between the root pointer and loads
  158. // or stores, so we need to accumulate all the indices between the root and
  159. // the leaf pointer we are marking.
  160. //
  161. // Starting at the root pointer, we follow users, looking for GEPs with
  162. // indices that could "match", or calls that may write to the pointer along
  163. // the way. A "match" to the reference index is one that matches with constant
  164. // values, or if either index is non-constant, since the compiler doesn't know
  165. // what index may be read or written in that case.
  166. //
  167. // This still doesn't handle addrspace cast or bitcast, so propagation through
  168. // groupshared aggregates will not work, as one example.
  169. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) {
  170. // Get root Ptr, gather index list, and mark matching stores
  171. SmallVector<Value*, 8> idxList;
  172. Ptr = GetRootAndIndicesForGEP(GEP, idxList);
  173. ValueSet processedGEPs;
  174. PropagateThroughGEPs(Ptr, idxList, processedGEPs);
  175. }
  176. }
  177. void DxilPrecisePropagatePass::PropagateOnPointerUsers(Value *Ptr) {
  178. // Find all store and propagate on the val operand of store.
  179. // For CallInst, if Ptr is used as out parameter, mark it.
  180. for (User *U : Ptr->users()) {
  181. if (StoreInst *stInst = dyn_cast<StoreInst>(U)) {
  182. Value *val = stInst->getValueOperand();
  183. AddToWorkList(val);
  184. } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
  185. PropagateOnPointerUsedInCall(Ptr, CI);
  186. }
  187. }
  188. }
  189. void DxilPrecisePropagatePass::PropagateThroughGEPs(
  190. Value *Ptr, ArrayRef<Value*> idxList, ValueSet &processedGEPs) {
  191. // recurse to matching GEP users
  192. for (User *U : Ptr->users()) {
  193. if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  194. // skip visited GEPs
  195. // These are separate from processedSet because while we don't need to
  196. // visit an intermediate GEP multiple times while marking a single value
  197. // precise, we are not necessarily marking every value reachable from
  198. // the GEP as precise, so we may need to revisit when marking a different
  199. // value as precise.
  200. if (!processedGEPs.insert(GEP).second)
  201. continue;
  202. // Mismatch if both constant and unequal, otherwise be conservative.
  203. bool bMismatch = false;
  204. auto idx = GEP->idx_begin();
  205. idx++;
  206. unsigned i = 0;
  207. while (idx != GEP->idx_end()) {
  208. if (ConstantInt *C = dyn_cast<ConstantInt>(*idx)) {
  209. if (ConstantInt *CRef = dyn_cast<ConstantInt>(idxList[i])) {
  210. if (CRef->getLimitedValue() != C->getLimitedValue()) {
  211. bMismatch = true;
  212. break;
  213. }
  214. }
  215. }
  216. idx++;
  217. i++;
  218. }
  219. if (bMismatch)
  220. continue;
  221. if ((unsigned)idxList.size() == i) {
  222. // Mark leaf users
  223. if (Processed(GEP))
  224. continue;
  225. PropagateOnPointerUsers(GEP);
  226. } else {
  227. // Recurse GEP users
  228. PropagateThroughGEPs(
  229. GEP, ArrayRef<Value*>(idxList.data() + i, idxList.end()),
  230. processedGEPs);
  231. }
  232. } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
  233. // Root pointer or intermediate GEP used in call.
  234. // If it may write to the pointer, we must mark the call and recurse
  235. // arguments.
  236. // This also widens the precise propagation to the entire aggregate
  237. // pointed to by the root ptr or intermediate GEP.
  238. PropagateOnPointerUsedInCall(Ptr, CI);
  239. }
  240. }
  241. }
  242. void DxilPrecisePropagatePass::PropagateOnPointerUsedInCall(
  243. Value *Ptr, CallInst *CI) {
  244. bool bReadOnly = true;
  245. Function *F = CI->getCalledFunction();
  246. // skip starting points (dx.attribute.precise calls)
  247. if (HLModule::HasPreciseAttribute(F))
  248. return;
  249. const DxilFunctionAnnotation *funcAnnotation =
  250. m_pDM->GetTypeSystem().GetFunctionAnnotation(F);
  251. if (funcAnnotation) {
  252. for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
  253. if (Ptr != CI->getArgOperand(i))
  254. continue;
  255. const DxilParameterAnnotation &paramAnnotation =
  256. funcAnnotation->GetParameterAnnotation(i);
  257. // OutputPatch and OutputStream will be checked after scalar repl.
  258. // Here only check out/inout
  259. if (paramAnnotation.GetParamInputQual() == DxilParamInputQual::Out ||
  260. paramAnnotation.GetParamInputQual() == DxilParamInputQual::Inout) {
  261. bReadOnly = false;
  262. break;
  263. }
  264. }
  265. } else {
  266. bReadOnly = false;
  267. }
  268. if (!bReadOnly) {
  269. AddToWorkList(CI);
  270. }
  271. }
  272. void FuncInfo::Init(Function *F) {
  273. if (!pPostDom) {
  274. pPostDom = make_unique<DominatorTreeBase<BasicBlock> >(true);
  275. pPostDom->recalculate(*F);
  276. CtrlDep.Compute(F, *pPostDom);
  277. }
  278. }
  279. void FuncInfo::Clear() {
  280. CtrlDep.Clear();
  281. pPostDom.reset();
  282. }
  283. FuncInfo &DxilPrecisePropagatePass::GetFuncInfo(Function *F) {
  284. auto &FI = m_FuncInfo[F];
  285. if (!FI) {
  286. FI = make_unique<FuncInfo>();
  287. FI->Init(F);
  288. }
  289. return *FI.get();
  290. }
  291. void DxilPrecisePropagatePass::PropagateCtrlDep(FuncInfo &FI, BasicBlock *BB) {
  292. if (Processed(BB))
  293. return;
  294. const BasicBlockSet &CtrlDepSet = FI.CtrlDep.GetCDBlocks(BB);
  295. for (BasicBlock *B : CtrlDepSet) {
  296. AddToWorkList(B->getTerminator());
  297. }
  298. }
  299. void DxilPrecisePropagatePass::PropagateCtrlDep(BasicBlock *BB) {
  300. FuncInfo &FI = GetFuncInfo(BB->getParent());
  301. PropagateCtrlDep(FI, BB);
  302. }
  303. void DxilPrecisePropagatePass::PropagateCtrlDep(Instruction *I) {
  304. PropagateCtrlDep(I->getParent());
  305. }
  306. ModulePass *llvm::createDxilPrecisePropagatePass() {
  307. return new DxilPrecisePropagatePass();
  308. }
  309. INITIALIZE_PASS(DxilPrecisePropagatePass, "hlsl-dxil-precise", "DXIL precise attribute propagate", false, false)