  1. //===- DxilConditionalMem2Reg.cpp - Mem2Reg that selectively promotes Allocas ----===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "llvm/Pass.h"
  10. #include "llvm/Analysis/AssumptionCache.h"
  11. #include "llvm/Transforms/Scalar.h"
  12. #include "llvm/Transforms/Utils/PromoteMemToReg.h"
  13. #include "llvm/IR/Instructions.h"
  14. #include "llvm/IR/IntrinsicInst.h"
  15. #include "llvm/IR/Dominators.h"
  16. #include "llvm/IR/Module.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/Support/raw_ostream.h"
  19. #include "llvm/Support/Debug.h"
  20. #include "llvm/IR/LegacyPassManager.h"
  21. #include "llvm/IR/DebugInfo.h"
  22. #include "llvm/IR/DIBuilder.h"
  23. #include "dxc/DXIL/DxilUtil.h"
  24. #include "dxc/HLSL/HLModule.h"
  25. #include "llvm/Analysis/DxilValueCache.h"
  26. using namespace llvm;
  27. using namespace hlsl;
  28. static bool ContainsFloatingPointType(Type *Ty) {
  29. if (Ty->isFloatingPointTy()) {
  30. return true;
  31. }
  32. else if (Ty->isArrayTy()) {
  33. return ContainsFloatingPointType(Ty->getArrayElementType());
  34. }
  35. else if (Ty->isVectorTy()) {
  36. return ContainsFloatingPointType(Ty->getVectorElementType());
  37. }
  38. else if (Ty->isStructTy()) {
  39. for (unsigned i = 0, NumStructElms = Ty->getStructNumElements(); i < NumStructElms; i++) {
  40. if (ContainsFloatingPointType(Ty->getStructElementType(i)))
  41. return true;
  42. }
  43. }
  44. return false;
  45. }
  46. static bool Mem2Reg(Function &F, DominatorTree &DT, AssumptionCache &AC) {
  47. BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
  48. bool Changed = false;
  49. std::vector<AllocaInst*> Allocas;
  50. while (1) {
  51. Allocas.clear();
  52. // Find allocas that are safe to promote, by looking at all instructions in
  53. // the entry node
  54. for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
  55. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) // Is it an alloca?
  56. if (isAllocaPromotable(AI) &&
  57. (!HLModule::HasPreciseAttributeWithMetadata(AI) || !ContainsFloatingPointType(AI->getAllocatedType())))
  58. Allocas.push_back(AI);
  59. if (Allocas.empty()) break;
  60. PromoteMemToReg(Allocas, DT, nullptr, &AC);
  61. Changed = true;
  62. }
  63. return Changed;
  64. }
  65. //
  66. // Special Mem2Reg pass that conditionally promotes or transforms Alloca's.
  67. //
  68. // Anything marked 'dx.precise', will not be promoted because precise markers
  69. // are not propagated to the dxil operations yet and will be lost if alloca
  70. // is removed right now.
  71. //
  72. // Precise Allocas of vectors get scalarized here. It's important we do that
  73. // before Scalarizer pass because promoting the allocas later than that will
  74. // produce vector phi's (disallowed by the validator), which need another
  75. // Scalarizer pass to clean up.
  76. //
class DxilConditionalMem2Reg : public FunctionPass {
public:
  static char ID;

  // Function overrides that resolve options when used for DxOpt.
  void applyOptions(PassOptions O) override {
    GetPassOptionBool(O, "NoOpt", &NoOpt, false);
  }
  void dumpConfig(raw_ostream &OS) override {
    FunctionPass::dumpConfig(OS);
    OS << ",NoOpt=" << NoOpt;
  }

  // Pass option. NOTE(review): recorded and dumped here, but not consulted by
  // runOnFunction in this file — presumably read elsewhere; confirm.
  bool NoOpt = false;

  explicit DxilConditionalMem2Reg(bool NoOpt=false) : FunctionPass(ID), NoOpt(NoOpt)
  {
    initializeDxilConditionalMem2RegPass(*PassRegistry::getPassRegistry());
  }

  // Requires dominators and the assumption cache (both consumed by
  // PromoteMemToReg); only instructions are touched, so the CFG is preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.setPreservesCFG();
  }

  // Collect and remove all instructions that use AI, but
  // give up if there are anything other than store, bitcast,
  // memcpy, or GEP.
  //
  // Returns true if AI and its (write-only) user tree were erased.
  // The worklist doubles as the deletion list: users are appended after
  // their defs, so erasing in reverse order removes uses before defs.
  static bool TryRemoveUnusedAlloca(AllocaInst *AI) {
    std::vector<Instruction *> WorkList;
    WorkList.push_back(AI);
    // Breadth-first walk over the transitive users; WorkList grows as we go.
    for (unsigned i = 0; i < WorkList.size(); i++) {
      Instruction *I = WorkList[i];
      for (User *U : I->users()) {
        Instruction *UI = cast<Instruction>(U);
        unsigned Opcode = UI->getOpcode();
        if (Opcode == Instruction::BitCast ||
            Opcode == Instruction::GetElementPtr ||
            Opcode == Instruction::Store)
        {
          WorkList.push_back(UI);
        }
        else if (MemCpyInst *MC = dyn_cast<MemCpyInst>(UI)) {
          if (MC->getSource() == I) { // MC reads from our alloca
            return false;
          }
          // memcpy only writes into the alloca — safe to delete with it.
          WorkList.push_back(UI);
        }
        else { // Load? PHINode? Assume read.
          return false;
        }
      }
    }
    // Remove all instructions, uses-first (reverse of discovery order).
    for (auto It = WorkList.rbegin(), E = WorkList.rend(); It != E; It++) {
      Instruction *I = *It;
      I->eraseFromParent();
    }
    return true;
  }

  // Erase every entry-block alloca that is only ever written, never read.
  // Returns true if anything was removed.
  static bool RemoveAllUnusedAllocas(Function &F) {
    std::vector<AllocaInst *> Allocas;
    BasicBlock &EntryBB = *F.begin();
    // Snapshot the allocas first, since TryRemoveUnusedAlloca erases
    // instructions from this block while we would be iterating it.
    for (auto It = EntryBB.begin(), E = EntryBB.end(); It != E;) {
      Instruction &I = *(It++);
      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
        Allocas.push_back(AI);
      }
    }
    bool Changed = false;
    for (AllocaInst *AI : Allocas) {
      Changed |= TryRemoveUnusedAlloca(AI);
    }
    return Changed;
  }

  //
  // Turns all allocas of vector types that are marked with 'dx.precise'
  // and turn them into scalars. For example:
  //
  //    x = alloca <f32 x 4> !dx.precise
  //
  // becomes:
  //
  //    x1 = alloca f32 !dx.precise
  //    x2 = alloca f32 !dx.precise
  //    x3 = alloca f32 !dx.precise
  //    x4 = alloca f32 !dx.precise
  //
  // This function also replaces all stores and loads but leaves everything
  // else alone by generating insertelement and extractelement as appropriate.
  //
  // Precondition (enforced by llvm_unreachable below): every user of such an
  // alloca is a plain load or store.
  static bool ScalarizePreciseVectorAlloca(Function &F) {
    BasicBlock *Entry = &*F.begin();
    bool Changed = false;
    for (auto it = Entry->begin(); it != Entry->end();) {
      Instruction *I = &*(it++); // Advance before possibly erasing I.
      AllocaInst *AI = dyn_cast<AllocaInst>(I);
      if (!AI || !AI->getAllocatedType()->isVectorTy()) continue;
      if (!HLModule::HasPreciseAttributeWithMetadata(AI)) continue;

      IRBuilder<> B(AI);
      VectorType *VTy = cast<VectorType>(AI->getAllocatedType());
      Type *ScalarTy = VTy->getVectorElementType();
      const unsigned VectorSize = VTy->getVectorNumElements();

      // One scalar alloca per lane; metadata copy carries the
      // 'dx.precise' marker over to each scalar.
      SmallVector<AllocaInst *, 32> Elements;
      for (unsigned i = 0; i < VectorSize; i++) {
        AllocaInst *Elem = B.CreateAlloca(ScalarTy);
        hlsl::DxilMDHelper::CopyMetadata(*Elem, *AI);
        Elements.push_back(Elem);
      }

      for (auto it = AI->user_begin(); it != AI->user_end();) {
        User *U = *(it++); // Advance before erasing the user.
        if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
          // Rebuild the vector from the scalar slots lane by lane.
          B.SetInsertPoint(LI);
          Value *Vec = UndefValue::get(VTy);
          for (unsigned i = 0; i < VectorSize; i++) {
            LoadInst *Elem = B.CreateLoad(Elements[i]);
            hlsl::DxilMDHelper::CopyMetadata(*Elem, *LI);
            Vec = B.CreateInsertElement(Vec, Elem, i);
          }
          LI->replaceAllUsesWith(Vec);
          LI->eraseFromParent();
        }
        else if (StoreInst *Store = dyn_cast<StoreInst>(U)) {
          // Scatter the stored vector into the scalar slots lane by lane.
          B.SetInsertPoint(Store);
          Value *Vec = Store->getValueOperand();
          for (unsigned i = 0; i < VectorSize; i++) {
            Value *Elem = B.CreateExtractElement(Vec, i);
            StoreInst *ElemStore = B.CreateStore(Elem, Elements[i]);
            hlsl::DxilMDHelper::CopyMetadata(*ElemStore, *Store);
          }
          Store->eraseFromParent();
        }
        else {
          llvm_unreachable("Cannot handle non-store/load on precise vector allocas");
        }
      }

      AI->eraseFromParent();
      Changed = true;
    }
    return Changed;
  }

  // A value reached from an output argument, plus its accumulated
  // bit offset from the start of the argument's pointee.
  struct StoreInfo {
    Value *V;
    unsigned Offset; // In bits.
  };

  // Walks forward from V (a user of an output argument) through constant-index
  // GEPs and collects every StoreInst reached, with its bit offset.
  // Returns false (and may leave partial results in *Stores) when the pattern
  // is not understood: variable GEP indices, GEPs with other than 2 indices,
  // a nonzero leading index, or a base type that is not array/vector.
  // Non-GEP, non-store users are silently skipped.
  static bool FindAllStores(Module &M, Value *V, SmallVectorImpl<StoreInfo> *Stores) {
    SmallVector<StoreInfo, 8> Worklist;
    std::set<Value *> Seen; // Guards against revisiting shared users.
    auto Add = [&](Value *V, unsigned OffsetInBits) {
      if (Seen.insert(V).second)
        Worklist.push_back({ V, OffsetInBits });
    };
    Add(V, 0);
    const DataLayout &DL = M.getDataLayout();
    while (Worklist.size()) {
      auto Info = Worklist.pop_back_val();
      auto *Elem = Info.V;
      if (auto GEP = dyn_cast<GEPOperator>(Elem)) {
        // Only the simple form gep ptr, 0, <const> is handled.
        if (GEP->getNumIndices() != 2)
          continue;
        // Element size (in bits) of the pointee's array/vector element —
        // the stride the second index multiplies.
        unsigned ElemSize = 0;
        Type *GEPPtrType = GEP->getPointerOperand()->getType();
        Type *PtrElemType = GEPPtrType->getPointerElementType();
        if (ArrayType *ArrayTy = dyn_cast<ArrayType>(PtrElemType)) {
          ElemSize = DL.getTypeAllocSizeInBits(ArrayTy->getElementType());
        }
        else if (VectorType *VectorTy = dyn_cast<VectorType>(PtrElemType)) {
          ElemSize = DL.getTypeAllocSizeInBits(VectorTy->getElementType());
        }
        else {
          return false;
        }
        unsigned OffsetInBits = 0;
        for (unsigned i = 0; i < GEP->getNumIndices(); i++) {
          auto IdxOp = dyn_cast<ConstantInt>(GEP->getOperand(i+1));
          if (!IdxOp) {
            return false; // Dynamic index — bail out.
          }
          uint64_t Idx = IdxOp->getLimitedValue();
          if (i == 0) {
            if (Idx != 0) // First index must address the pointee itself.
              return false;
          }
          else {
            OffsetInBits = Idx * ElemSize;
          }
        }
        // Follow the GEP's users with the offset accumulated so far.
        for (User *U : Elem->users())
          Add(U, Info.Offset + OffsetInBits);
      }
      else if (auto *Store = dyn_cast<StoreInst>(Elem)) {
        Stores->push_back({ Store, Info.Offset });
      }
    }
    return true;
  }

  // Function to rewrite debug info for output argument.
  // Sometimes, normal local variables that get returned from functions get rewritten as
  // a pointer argument.
  //
  // Right now, we generally have a single dbg.declare for the Argument, but as we lower
  // it to storeOutput, the dbg.declare and the Argument both get removed, leaving no
  // debug info for the local variable.
  //
  // Solution here is to rewrite the dbg.declare as dbg.value's by finding all the stores
  // and writing a dbg.value immediately before the store. Fairly conservative at the moment
  // about what cases to rewrite (only scalars and vectors, and arrays of scalars and vectors).
  //
  bool RewriteOutputArgsDebugInfo(Function &F) {
    bool Changed = false;
    Module *M = F.getParent();
    DIBuilder DIB(*M);
    SmallVector<StoreInfo, 4> Stores;
    LLVMContext &Ctx = F.getContext();
    for (Argument &Arg : F.args()) {
      if (!Arg.getType()->isPointerTy())
        continue;
      // Only scalars/vectors, and arrays of scalars/vectors (see comment above).
      Type *Ty = Arg.getType()->getPointerElementType();
      bool IsSimpleType =
        Ty->isSingleValueType() ||
        Ty->isVectorTy() ||
        (Ty->isArrayTy() && (Ty->getArrayElementType()->isVectorTy() || Ty->getArrayElementType()->isSingleValueType()));
      if (!IsSimpleType)
        continue;

      // Gather every store into the argument; give up on the whole
      // argument if any user pattern is not understood.
      Stores.clear();
      for (User *U : Arg.users()) {
        if (!FindAllStores(*M, U, &Stores)) {
          Stores.clear();
          break;
        }
      }
      if (Stores.empty())
        continue;

      // Find the argument's dbg.declare, but only if it is the single user
      // of the argument's metadata wrapper.
      DbgDeclareInst *Declare = nullptr;
      if (auto *L = LocalAsMetadata::getIfExists(&Arg)) {
        if (auto *DINode = MetadataAsValue::getIfExists(Ctx, L)) {
          if (!DINode->user_empty() && std::next(DINode->user_begin()) == DINode->user_end()) {
            Declare = dyn_cast<DbgDeclareInst>(*DINode->user_begin());
          }
        }
      }

      if (Declare) {
        DITypeIdentifierMap EmptyMap;
        DILocalVariable *Var = Declare->getVariable();
        DIExpression *Expr = Declare->getExpression();
        DIType *VarTy = Var->getType().resolve(EmptyMap);
        uint64_t VarSize = VarTy->getSizeInBits();
        // If the declare already described a piece of the variable,
        // fold its base offset into the pieces we emit.
        uint64_t Offset = 0;
        if (Expr->isBitPiece())
          Offset = Expr->getBitPieceOffset();
        for (auto &Info : Stores) {
          auto *Store = cast<StoreInst>(Info.V);
          auto Val = Store->getValueOperand();
          auto Loc = Store->getDebugLoc();
          auto &M = *F.getParent();
          unsigned ValSize = M.getDataLayout().getTypeAllocSizeInBits(Val->getType());
          // Emit a bit_piece expression when the store covers only part of
          // the variable; otherwise an empty (direct) expression.
          DIExpression *NewExpr = nullptr;
          if (Offset || VarSize > ValSize) {
            uint64_t Elems[] = { dwarf::DW_OP_bit_piece, Offset + Info.Offset, ValSize };
            NewExpr = DIExpression::get(Ctx, Elems);
          }
          else {
            NewExpr = DIExpression::get(Ctx, {});
          }
          // dbg.value placed immediately before the store it describes.
          DIB.insertDbgValueIntrinsic(Val, 0, Var, NewExpr, Loc, Store);
        }
        Declare->eraseFromParent();
        Changed = true;
      }
    }
    return Changed;
  }

  // Pipeline: fix up output-arg debug info, drop write-only allocas,
  // scalarize precise vector allocas, then run restricted mem2reg.
  bool runOnFunction(Function &F) override {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
    bool Changed = false;

    Changed |= RewriteOutputArgsDebugInfo(F);
    Changed |= RemoveAllUnusedAllocas(F);
    Changed |= ScalarizePreciseVectorAlloca(F);
    Changed |= Mem2Reg(F, *DT, *AC);

    return Changed;
  }
};
// Pass ID storage; its address identifies the pass to the LLVM pass registry.
char DxilConditionalMem2Reg::ID;

// Factory used by the DXC pass pipeline to instantiate this pass.
Pass *llvm::createDxilConditionalMem2RegPass(bool NoOpt) {
  return new DxilConditionalMem2Reg(NoOpt);
}

// Register the pass and declare its analysis dependencies.
INITIALIZE_PASS_BEGIN(DxilConditionalMem2Reg, "dxil-cond-mem2reg", "Dxil Conditional Mem2Reg", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(DxilConditionalMem2Reg, "dxil-cond-mem2reg", "Dxil Conditional Mem2Reg", false, false)