HoistConstantArray.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. //===- HoistConstantArray.cpp - Code to perform constant array hoisting ---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. // This file implements hoisting of constant local arrays to global arrays.
  13. // The idea is to change the array initialization from function local memory
  14. // using alloca and stores to global constant memory using a global variable
  15. // and constant initializer. We only hoist arrays that have all constant elements.
  16. // The frontend will hoist the arrays if they are declared static, but we can
  17. // hoist any array that is only ever initialized with constant data.
  18. //
  19. // This transformation was developed to work with the dxil produced from the
  20. // hlsl compiler. Hoisting the array to use a constant initializer should allow
  21. // a dxil backend compiler to generate more efficient code than a local array.
  22. // For example, it could use an immediate constant pool to represent the array.
  23. //
  24. // We limit hoisting to those arrays that are initialized by constant values.
  25. // We still hoist if the array is partially initialized as long as no
  26. // non-constant values are written. The uninitialized values will be hoisted
  27. // as undef values.
  28. //
  29. // Improvements:
  30. // Currently we do not merge arrays that have the same constant values. We
  31. // create the global variables with `unnamed_addr` set which means they
  32. // can be merged with other constants. We should probably use a separate
  33. // pass to merge all the unnamed_addr constants.
  34. //
  35. // Example:
  36. //
  37. // float main(int i : I) : SV_Target{
  38. // float A[] = { 1, 2, 3 };
  39. // return A[i];
  40. // }
  41. //
  42. // Without array hoisting, we generate the following dxil
  43. //
  44. // define void @main() {
  45. // entry:
  46. // %0 = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef)
  47. // %A = alloca[3 x float], align 4
  48. // %1 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 0
  49. // store float 1.000000e+00, float* %1, align 4
  50. // %2 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 1
  51. // store float 2.000000e+00, float* %2, align 4
  52. // %3 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 2
  53. // store float 3.000000e+00, float* %3, align 4
  54. // %arrayidx = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 %0
  55. // %4 = load float, float* %arrayidx, align 4, !tbaa !14
  56. // call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %4);
  57. // ret void
  58. // }
  59. //
  60. // With array hoisting enabled we generate this dxil
  61. //
  62. // @A.hca = internal unnamed_addr constant [3 x float] [float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
  63. // define void @main() {
  64. // entry:
  65. // %0 = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef)
  66. // %arrayidx = getelementptr inbounds[3 x float], [3 x float] * @A.hca, i32 0, i32 %0
  67. // %1 = load float, float* %arrayidx, align 4, !tbaa !14
  68. // call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %1)
  69. // ret void
  70. // }
  71. //
  72. //===----------------------------------------------------------------------===//
  73. #include "llvm/Transforms/Scalar.h"
  74. #include "llvm/Pass.h"
  75. #include "llvm/IR/Type.h"
  76. #include "llvm/IR/Constant.h"
  77. #include "llvm/IR/Instruction.h"
  78. #include "llvm/IR/Instructions.h"
  79. #include "llvm/IR/GlobalVariable.h"
  80. #include "llvm/IR/Module.h"
  81. #include "llvm/IR/Function.h"
  82. #include "llvm/IR/Operator.h"
  83. #include "llvm/Support/Casting.h"
  84. #include "llvm/Analysis/ValueTracking.h"
  85. using namespace llvm;
  86. namespace {
  87. class CandidateArray;
  88. //===--------------------------------------------------------------------===//
  89. // HoistConstantArray pass implementation
  90. //
  91. class HoistConstantArray : public ModulePass {
  92. public:
  93. static char ID; // Pass identification, replacement for typeid
  94. HoistConstantArray() : ModulePass(ID) {
  95. initializeHoistConstantArrayPass(*PassRegistry::getPassRegistry());
  96. }
  97. bool runOnModule(Module &M) override;
  98. void getAnalysisUsage(AnalysisUsage &AU) const override {
  99. AU.setPreservesCFG();
  100. }
  101. private:
  102. bool runOnFunction(Function &F);
  103. std::vector<AllocaInst *> findCandidateAllocas(Function &F);
  104. void hoistArray(const CandidateArray &candidate);
  105. void removeLocalArrayStores(const CandidateArray &candidate);
  106. };
  107. // Represents an array we are considering for hoisting.
  108. // Contains helper routines for analyzing if hoisting is possible
  109. // and creating the global variable for the hoisted array.
  110. class CandidateArray {
  111. public:
  112. explicit CandidateArray(AllocaInst *);
  113. bool IsConstArray() const { return m_IsConstArray; }
  114. void AnalyzeUses();
  115. GlobalVariable *GetGlobalArray() const;
  116. AllocaInst *GetLocalArray() const { return m_Alloca; }
  117. std::vector<StoreInst*> GetArrayStores() const;
  118. private:
  119. AllocaInst *m_Alloca;
  120. ArrayType *m_ArrayType;
  121. std::vector<Constant *> m_Values;
  122. bool m_IsConstArray;
  123. bool AnalyzeStore(StoreInst *SI);
  124. bool StoreConstant(int64_t index, Constant *value);
  125. void EnsureSize();
  126. void GetArrayStores(GEPOperator *gep,
  127. std::vector<StoreInst *> &stores) const;
  128. bool AllArrayUsersAreGEPOrLifetime(std::vector<GEPOperator *> &geps);
  129. bool AllGEPUsersAreValid(GEPOperator *gep);
  130. UndefValue *UndefElement();
  131. };
  132. }
  133. // Returns the ArrayType for the alloca or nullptr if the alloca
  134. // does not allocate an array.
  135. static ArrayType *getAllocaArrayType(AllocaInst *allocaInst) {
  136. return dyn_cast<ArrayType>(allocaInst->getType()->getPointerElementType());
  137. }
  138. // Check if the instruction is an alloca that we should consider for hoisting.
  139. // The alloca must allocate and array of primitive types.
  140. static AllocaInst *isHoistableArrayAlloca(Instruction *I) {
  141. AllocaInst *allocaInst = dyn_cast<AllocaInst>(I);
  142. if (!allocaInst)
  143. return nullptr;
  144. ArrayType *arrayTy = getAllocaArrayType(allocaInst);
  145. if (!arrayTy)
  146. return nullptr;
  147. if (!arrayTy->getElementType()->isSingleValueType())
  148. return nullptr;
  149. return allocaInst;
  150. }
  151. // ----------------------------------------------------------------------------
  152. // CandidateArray implementation
  153. // ----------------------------------------------------------------------------
  154. // Create the candidate array for the alloca.
  155. CandidateArray::CandidateArray(AllocaInst *AI)
  156. : m_Alloca(AI)
  157. , m_Values()
  158. , m_IsConstArray(false)
  159. {
  160. assert(isHoistableArrayAlloca(AI));
  161. m_ArrayType = getAllocaArrayType(AI);
  162. }
  163. // Get the global variable with a constant initializer for the array.
  164. // Only valid to call if the array has been analyzed as a constant array.
  165. GlobalVariable *CandidateArray::GetGlobalArray() const {
  166. assert(IsConstArray());
  167. Constant *initializer = ConstantArray::get(m_ArrayType, m_Values);
  168. Module *M = m_Alloca->getModule();
  169. GlobalVariable *GV = new GlobalVariable(*M, m_ArrayType, true, GlobalVariable::LinkageTypes::InternalLinkage, initializer, Twine(m_Alloca->getName()) + ".hca");
  170. GV->setUnnamedAddr(true);
  171. return GV;
  172. }
  173. // Get a list of all the stores that write to the array through one or more
  174. // GetElementPtrInst operations.
  175. std::vector<StoreInst *> CandidateArray::GetArrayStores() const {
  176. std::vector<StoreInst *> stores;
  177. for (User *U : m_Alloca->users())
  178. if (GEPOperator *gep = dyn_cast<GEPOperator>(U))
  179. GetArrayStores(gep, stores);
  180. return stores;
  181. }
  182. // Recursively collect all the stores that write to the pointer/buffer
  183. // referred to by this GetElementPtrInst.
  184. void CandidateArray::GetArrayStores(GEPOperator *gep,
  185. std::vector<StoreInst *> &stores) const {
  186. for (User *GU : gep->users()) {
  187. if (StoreInst *SI = dyn_cast<StoreInst>(GU)) {
  188. stores.push_back(SI);
  189. }
  190. else if (GEPOperator *GEPI = dyn_cast<GEPOperator>(GU)) {
  191. GetArrayStores(GEPI, stores);
  192. }
  193. }
  194. }
  195. // Check to see that all the users of the array are GEPs or lifetime intrinsics.
  196. // If so, populate the `geps` vector with a list of all geps that use the array.
  197. bool CandidateArray::AllArrayUsersAreGEPOrLifetime(std::vector<GEPOperator *> &geps) {
  198. for (User *U : m_Alloca->users()) {
  199. // Allow users that are only used by lifetime intrinsics.
  200. if (onlyUsedByLifetimeMarkers(U))
  201. continue;
  202. GEPOperator *gep = dyn_cast<GEPOperator>(U);
  203. if (!gep)
  204. return false;
  205. geps.push_back(gep);
  206. }
  207. return true;
  208. }
  209. // Check that all gep uses are valid.
  210. // A valid use is either
  211. // 1. A store of a constant value that does not overwrite an existing constant
  212. // with a different value.
  213. // 2. A load instruction.
  214. // 3. Another GetElementPtrInst that itself only has valid uses (recursively)
  215. // Any other use is considered invalid.
  216. bool CandidateArray::AllGEPUsersAreValid(GEPOperator *gep) {
  217. for (User *U : gep->users()) {
  218. if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  219. if (!AnalyzeStore(SI))
  220. return false;
  221. }
  222. else if (GEPOperator *recursive_gep = dyn_cast<GEPOperator>(U)) {
  223. if (!AllGEPUsersAreValid(recursive_gep))
  224. return false;
  225. }
  226. else if (!isa<LoadInst>(U)) {
  227. return false;
  228. }
  229. }
  230. return true;
  231. }
  232. // Analyze all uses of the array to see if it qualifes as a constant array.
  233. // We check the following conditions:
  234. // 1. Make sure alloca is only used by GEP and lifetime intrinsics.
  235. // 2. Make sure GEP is only used in load/store.
  236. // 3. Make sure all stores have constant indicies.
  237. // 4. Make sure all stores are constants.
  238. // 5. Make sure all stores to same location are the same constant.
  239. void CandidateArray::AnalyzeUses() {
  240. m_IsConstArray = false;
  241. std::vector<GEPOperator *> geps;
  242. if (!AllArrayUsersAreGEPOrLifetime(geps))
  243. return;
  244. for (GEPOperator *gep : geps)
  245. if (!AllGEPUsersAreValid(gep))
  246. return;
  247. m_IsConstArray = true;
  248. }
  249. // Analyze a store to see if it is a valid constant store.
  250. // A valid store will write a constant value to a known (constant) location.
  251. bool CandidateArray::AnalyzeStore(StoreInst *SI) {
  252. if (!isa<Constant>(SI->getValueOperand()))
  253. return false;
  254. // Walk up the ladder of GetElementPtr instructions to accumulate the index
  255. int64_t index = 0;
  256. for (auto iter = SI->getPointerOperand(); iter != m_Alloca;) {
  257. GEPOperator *gep = cast<GEPOperator>(iter);
  258. if (!gep->hasAllConstantIndices())
  259. return false;
  260. // Deal with the 'extra 0' index from what might have been a global pointer
  261. // https://www.llvm.org/docs/GetElementPtr.html#why-is-the-extra-0-index-required
  262. if ((gep->getNumIndices() == 2) && (gep->getPointerOperand() == m_Alloca)) {
  263. // Non-zero offset is unexpected, but could occur in the wild. Bail out if
  264. // we see it.
  265. ConstantInt *ptrOffset = cast<ConstantInt>(gep->getOperand(1));
  266. if (!ptrOffset->isZero())
  267. return false;
  268. }
  269. else if (gep->getNumIndices() != 1) {
  270. return false;
  271. }
  272. // Accumulate the index
  273. ConstantInt *c = cast<ConstantInt>(gep->getOperand(gep->getNumIndices()));
  274. index += c->getSExtValue();
  275. iter = gep->getPointerOperand();
  276. }
  277. return StoreConstant(index, cast<Constant>(SI->getValueOperand()));
  278. }
  279. // Check if the store is valid and record the value if so.
  280. // A valid constant store is either:
  281. // 1. A store of a new constant
  282. // 2. A store of the same constant to the same location
  283. bool CandidateArray::StoreConstant(int64_t index, Constant *value) {
  284. EnsureSize();
  285. size_t i = static_cast<size_t>(index);
  286. if (i >= m_Values.size())
  287. return false;
  288. if (m_Values[i] == UndefElement())
  289. m_Values[i] = value;
  290. return m_Values[i] == value;
  291. }
  292. // We lazily create the values array until we have a store of a
  293. // constant that we need to remember. This avoids memory overhead
  294. // for obviously non-constant arrays.
  295. void CandidateArray::EnsureSize() {
  296. if (m_Values.size() == 0) {
  297. m_Values.resize(m_ArrayType->getNumElements(), UndefElement());
  298. }
  299. assert(m_Values.size() == m_ArrayType->getNumElements());
  300. }
  301. // Get an undef value of the correct type for the array.
  302. UndefValue *CandidateArray::UndefElement() {
  303. return UndefValue::get(m_ArrayType->getElementType());
  304. }
  305. // ----------------------------------------------------------------------------
  306. // Pass Implementation
  307. // ----------------------------------------------------------------------------
  308. // Find the allocas that are candidates for array hoisting in the function.
  309. std::vector<AllocaInst*> HoistConstantArray::findCandidateAllocas(Function &F) {
  310. std::vector<AllocaInst*> candidates;
  311. for (Instruction &I : F.getEntryBlock())
  312. if (AllocaInst *allocaInst = isHoistableArrayAlloca(&I))
  313. candidates.push_back(allocaInst);
  314. return candidates;
  315. }
  316. // Remove local stores to the array.
  317. // We remove them explicitly rather than relying on DCE to find they are dead.
  318. // Other uses (e.g. geps) can be easily cleaned up by DCE.
  319. void HoistConstantArray::removeLocalArrayStores(const CandidateArray &candidate) {
  320. std::vector<StoreInst*> stores = candidate.GetArrayStores();
  321. for (StoreInst *store : stores)
  322. store->eraseFromParent();
  323. }
  324. // Hoist an array from a local to a global.
  325. void HoistConstantArray::hoistArray(const CandidateArray &candidate) {
  326. assert(candidate.IsConstArray());
  327. removeLocalArrayStores(candidate);
  328. AllocaInst *local = candidate.GetLocalArray();
  329. GlobalVariable *global = candidate.GetGlobalArray();
  330. local->replaceAllUsesWith(global);
  331. local->eraseFromParent();
  332. }
  333. // Perform array hoisting on a single function.
  334. bool HoistConstantArray::runOnFunction(Function &F) {
  335. bool changed = false;
  336. std::vector<AllocaInst *> candidateAllocas = findCandidateAllocas(F);
  337. for (AllocaInst *AI : candidateAllocas) {
  338. CandidateArray candidate(AI);
  339. candidate.AnalyzeUses();
  340. if (candidate.IsConstArray()) {
  341. hoistArray(candidate);
  342. changed |= true;
  343. }
  344. }
  345. return changed;
  346. }
  347. char HoistConstantArray::ID = 0;
  348. INITIALIZE_PASS(HoistConstantArray, "hlsl-hca", "Hoist constant arrays", false, false)
  349. bool HoistConstantArray::runOnModule(Module &M) {
  350. bool changed = false;
  351. for (Function &F : M) {
  352. if (F.isDeclaration())
  353. continue;
  354. changed |= runOnFunction(F);
  355. }
  356. return changed;
  357. }
  358. ModulePass *llvm::createHoistConstantArrayPass() {
  359. return new HoistConstantArray();
  360. }