DxilConditionalMem2Reg.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. //===- DxilConditionalMem2Reg.cpp - Mem2Reg that selectively promotes Allocas ----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "llvm/Pass.h"
  10. #include "llvm/Analysis/AssumptionCache.h"
  11. #include "llvm/Transforms/Scalar.h"
  12. #include "llvm/Transforms/Utils/PromoteMemToReg.h"
  13. #include "llvm/IR/Instructions.h"
  14. #include "llvm/IR/IntrinsicInst.h"
  15. #include "llvm/IR/Dominators.h"
  16. #include "llvm/IR/IRBuilder.h"
  17. #include "llvm/Support/raw_ostream.h"
  18. #include "llvm/Support/Debug.h"
  19. #include "llvm/IR/LegacyPassManager.h"
  20. #include "dxc/DXIL/DxilUtil.h"
  21. #include "dxc/HLSL/HLModule.h"
  22. #include "llvm/Analysis/DxilValueCache.h"
  23. using namespace llvm;
  24. using namespace hlsl;
  25. static bool ContainsFloatingPointType(Type *Ty) {
  26. if (Ty->isFloatingPointTy()) {
  27. return true;
  28. }
  29. else if (Ty->isArrayTy()) {
  30. return ContainsFloatingPointType(Ty->getArrayElementType());
  31. }
  32. else if (Ty->isVectorTy()) {
  33. return ContainsFloatingPointType(Ty->getVectorElementType());
  34. }
  35. else if (Ty->isStructTy()) {
  36. for (unsigned i = 0, NumStructElms = Ty->getStructNumElements(); i < NumStructElms; i++) {
  37. if (ContainsFloatingPointType(Ty->getStructElementType(i)))
  38. return true;
  39. }
  40. }
  41. return false;
  42. }
  43. static bool Mem2Reg(Function &F, DominatorTree &DT, AssumptionCache &AC) {
  44. BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
  45. bool Changed = false;
  46. std::vector<AllocaInst*> Allocas;
  47. while (1) {
  48. Allocas.clear();
  49. // Find allocas that are safe to promote, by looking at all instructions in
  50. // the entry node
  51. for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
  52. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca?
  53. if (isAllocaPromotable(AI) &&
  54. (!HLModule::HasPreciseAttributeWithMetadata(AI) || !ContainsFloatingPointType(AI->getAllocatedType())))
  55. Allocas.push_back(AI);
  56. if (Allocas.empty()) break;
  57. PromoteMemToReg(Allocas, DT, nullptr, &AC);
  58. Changed = true;
  59. }
  60. return Changed;
  61. }
  //
  // Special Mem2Reg pass that conditionally promotes or transforms allocas.
  //
  // Allocas marked 'dx.precise' that contain floating point types are not
  // promoted, because precise markers are not propagated to the dxil
  // operations yet and would be lost if the alloca were removed right now.
  //
  // Precise allocas of vectors get scalarized here. It's important we do that
  // before the Scalarizer pass, because promoting the allocas later than that
  // would produce vector phis (disallowed by the validator), which would need
  // another Scalarizer pass to clean up.
  //
  74. class DxilConditionalMem2Reg : public FunctionPass {
  75. public:
  76. static char ID;
  77. // Function overrides that resolve options when used for DxOpt
  78. void applyOptions(PassOptions O) override {
  79. GetPassOptionBool(O, "NoOpt", &NoOpt, false);
  80. }
  81. void dumpConfig(raw_ostream &OS) override {
  82. FunctionPass::dumpConfig(OS);
  83. OS << ",NoOpt=" << NoOpt;
  84. }
  85. bool NoOpt = false;
  86. explicit DxilConditionalMem2Reg(bool NoOpt=false) : FunctionPass(ID), NoOpt(NoOpt)
  87. {
  88. initializeDxilConditionalMem2RegPass(*PassRegistry::getPassRegistry());
  89. }
  90. void getAnalysisUsage(AnalysisUsage &AU) const override {
  91. AU.addRequired<DominatorTreeWrapperPass>();
  92. AU.addRequired<AssumptionCacheTracker>();
  93. AU.setPreservesCFG();
  94. }
  95. // Collect and remove all instructions that use AI, but
  96. // give up if there are anything other than store, bitcast,
  97. // memcpy, or GEP.
  98. static bool TryRemoveUnusedAlloca(AllocaInst *AI) {
  99. std::vector<Instruction *> WorkList;
  100. WorkList.push_back(AI);
  101. for (unsigned i = 0; i < WorkList.size(); i++) {
  102. Instruction *I = WorkList[i];
  103. for (User *U : I->users()) {
  104. Instruction *UI = cast<Instruction>(U);
  105. unsigned Opcode = UI->getOpcode();
  106. if (Opcode == Instruction::BitCast ||
  107. Opcode == Instruction::GetElementPtr ||
  108. Opcode == Instruction::Store)
  109. {
  110. WorkList.push_back(UI);
  111. }
  112. else if (MemCpyInst *MC = dyn_cast<MemCpyInst>(UI)) {
  113. if (MC->getSource() == I) { // MC reads from our alloca
  114. return false;
  115. }
  116. WorkList.push_back(UI);
  117. }
  118. else { // Load? PHINode? Assume read.
  119. return false;
  120. }
  121. }
  122. }
  123. // Remove all instructions
  124. for (auto It = WorkList.rbegin(), E = WorkList.rend(); It != E; It++) {
  125. Instruction *I = *It;
  126. I->eraseFromParent();
  127. }
  128. return true;
  129. }
  130. static bool RemoveAllUnusedAllocas(Function &F) {
  131. std::vector<AllocaInst *> Allocas;
  132. BasicBlock &EntryBB = *F.begin();
  133. for (auto It = EntryBB.begin(), E = EntryBB.end(); It != E;) {
  134. Instruction &I = *(It++);
  135. if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
  136. Allocas.push_back(AI);
  137. }
  138. }
  139. bool Changed = false;
  140. for (AllocaInst *AI : Allocas) {
  141. Changed |= TryRemoveUnusedAlloca(AI);
  142. }
  143. return Changed;
  144. }
  145. //
  146. // Turns all allocas of vector types that are marked with 'dx.precise'
  147. // and turn them into scalars. For example:
  148. //
  149. // x = alloca <f32 x 4> !dx.precise
  150. //
  151. // becomes:
  152. //
  153. // x1 = alloca f32 !dx.precise
  154. // x2 = alloca f32 !dx.precise
  155. // x3 = alloca f32 !dx.precise
  156. // x4 = alloca f32 !dx.precise
  157. //
  158. // This function also replaces all stores and loads but leaves everything
  159. // else alone by generating insertelement and extractelement as appropriate.
  160. //
  161. static bool ScalarizePreciseVectorAlloca(Function &F) {
  162. BasicBlock *Entry = &*F.begin();
  163. bool Changed = false;
  164. for (auto it = Entry->begin(); it != Entry->end();) {
  165. Instruction *I = &*(it++);
  166. AllocaInst *AI = dyn_cast<AllocaInst>(I);
  167. if (!AI || !AI->getAllocatedType()->isVectorTy()) continue;
  168. if (!HLModule::HasPreciseAttributeWithMetadata(AI)) continue;
  169. IRBuilder<> B(AI);
  170. VectorType *VTy = cast<VectorType>(AI->getAllocatedType());
  171. Type *ScalarTy = VTy->getVectorElementType();
  172. const unsigned VectorSize = VTy->getVectorNumElements();
  173. SmallVector<AllocaInst *, 32> Elements;
  174. for (unsigned i = 0; i < VectorSize; i++) {
  175. AllocaInst *Elem = B.CreateAlloca(ScalarTy);
  176. hlsl::DxilMDHelper::CopyMetadata(*Elem, *AI);
  177. Elements.push_back(Elem);
  178. }
  179. for (auto it = AI->user_begin(); it != AI->user_end();) {
  180. User *U = *(it++);
  181. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  182. B.SetInsertPoint(LI);
  183. Value *Vec = UndefValue::get(VTy);
  184. for (unsigned i = 0; i < VectorSize; i++) {
  185. LoadInst *Elem = B.CreateLoad(Elements[i]);
  186. hlsl::DxilMDHelper::CopyMetadata(*Elem, *LI);
  187. Vec = B.CreateInsertElement(Vec, Elem, i);
  188. }
  189. LI->replaceAllUsesWith(Vec);
  190. LI->eraseFromParent();
  191. }
  192. else if (StoreInst *Store = dyn_cast<StoreInst>(U)) {
  193. B.SetInsertPoint(Store);
  194. Value *Vec = Store->getValueOperand();
  195. for (unsigned i = 0; i < VectorSize; i++) {
  196. Value *Elem = B.CreateExtractElement(Vec, i);
  197. StoreInst *ElemStore = B.CreateStore(Elem, Elements[i]);
  198. hlsl::DxilMDHelper::CopyMetadata(*ElemStore, *Store);
  199. }
  200. Store->eraseFromParent();
  201. }
  202. else {
  203. llvm_unreachable("Cannot handle non-store/load on precise vector allocas");
  204. }
  205. }
  206. AI->eraseFromParent();
  207. Changed = true;
  208. }
  209. return Changed;
  210. }
  211. bool runOnFunction(Function &F) override {
  212. DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  213. AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  214. bool Changed = false;
  215. Changed |= RemoveAllUnusedAllocas(F);
  216. Changed |= ScalarizePreciseVectorAlloca(F);
  217. Changed |= Mem2Reg(F, *DT, *AC);
  218. return Changed;
  219. }
  220. };
  221. char DxilConditionalMem2Reg::ID;
  222. Pass *llvm::createDxilConditionalMem2RegPass(bool NoOpt) {
  223. return new DxilConditionalMem2Reg(NoOpt);
  224. }
  225. INITIALIZE_PASS_BEGIN(DxilConditionalMem2Reg, "dxil-cond-mem2reg", "Dxil Conditional Mem2Reg", false, false)
  226. INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  227. INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
  228. INITIALIZE_PASS_END(DxilConditionalMem2Reg, "dxil-cond-mem2reg", "Dxil Conditional Mem2Reg", false, false)