HoistConstantArray.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. //===- HoistConstantArray.cpp - Code to perform constant array hoisting ---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. // Copyright (C) Microsoft Corporation. All rights reserved.
  9. //
  10. //===----------------------------------------------------------------------===//
  11. //
  12. // This file implements hoisting of constant local arrays to global arrays.
  13. // The idea is to change the array initialization from function local memory
  14. // using alloca and stores to global constant memory using a global variable
  15. // and constant initializer. We only hoist arrays that have all constant elements.
  16. // The frontend will hoist the arrays if they are declared static, but we can
  17. // hoist any array that is only ever initialized with constant data.
  18. //
  19. // This transformation was developed to work with the dxil produced from the
  20. // hlsl compiler. Hoisting the array to use a constant initializer should allow
  21. // a dxil backend compiler to generate more efficient code than a local array.
  22. // For example, it could use an immediate constant pool to represent the array.
  23. //
  24. // We limit hoisting to those arrays that are initialized by constant values.
  25. // We still hoist if the array is partially initialized as long as no
  26. // non-constant values are written. The uninitialized values will be hoisted
  27. // as undef values.
  28. //
  29. // Improvements:
  30. // Currently we do not merge arrays that have the same constant values. We
  31. // create the global variables with `unnamed_addr` set which means they
  32. // can be merged with other constants. We should probably use a separate
  33. // pass to merge all the unnamed_addr constants.
  34. //
  35. // Example:
  36. //
  37. // float main(int i : I) : SV_Target{
  38. // float A[] = { 1, 2, 3 };
  39. // return A[i];
  40. // }
  41. //
  42. // Without array hoisting, we generate the following dxil
  43. //
  44. // define void @main() {
  45. // entry:
  46. // %0 = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef)
  47. // %A = alloca[3 x float], align 4
  48. // %1 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 0
  49. // store float 1.000000e+00, float* %1, align 4
  50. // %2 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 1
  51. // store float 2.000000e+00, float* %2, align 4
  52. // %3 = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 2
  53. // store float 3.000000e+00, float* %3, align 4
  54. // %arrayidx = getelementptr inbounds[3 x float], [3 x float] * %A, i32 0, i32 %0
  55. // %4 = load float, float* %arrayidx, align 4, !tbaa !14
  56. // call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %4);
  57. // ret void
  58. // }
  59. //
  60. // With array hoisting enabled we generate this dxil
  61. //
  62. // @A.hca = internal unnamed_addr constant [3 x float] [float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
  63. // define void @main() {
  64. // entry:
  65. // %0 = call i32 @dx.op.loadInput.i32(i32 4, i32 0, i32 0, i8 0, i32 undef)
  66. // %arrayidx = getelementptr inbounds[3 x float], [3 x float] * @A.hca, i32 0, i32 %0
  67. // %1 = load float, float* %arrayidx, align 4, !tbaa !14
  68. // call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %1)
  69. // ret void
  70. // }
  71. //
  72. //===----------------------------------------------------------------------===//
  73. #include "llvm/Transforms/Scalar.h"
  74. #include "llvm/Pass.h"
  75. #include "llvm/IR/Type.h"
  76. #include "llvm/IR/Constant.h"
  77. #include "llvm/IR/Instruction.h"
  78. #include "llvm/IR/Instructions.h"
  79. #include "llvm/IR/GlobalVariable.h"
  80. #include "llvm/IR/Module.h"
  81. #include "llvm/IR/Function.h"
  82. #include "llvm/IR/Operator.h"
  83. #include "llvm/Support/Casting.h"
  84. using namespace llvm;
  85. namespace {
  86. class CandidateArray;
  87. //===--------------------------------------------------------------------===//
  88. // HoistConstantArray pass implementation
  89. //
  90. class HoistConstantArray : public ModulePass {
  91. public:
  92. static char ID; // Pass identification, replacement for typeid
  93. HoistConstantArray() : ModulePass(ID) {
  94. initializeHoistConstantArrayPass(*PassRegistry::getPassRegistry());
  95. }
  96. bool runOnModule(Module &M) override;
  97. void getAnalysisUsage(AnalysisUsage &AU) const override {
  98. AU.setPreservesCFG();
  99. }
  100. private:
  101. bool runOnFunction(Function &F);
  102. std::vector<AllocaInst *> findCandidateAllocas(Function &F);
  103. void hoistArray(const CandidateArray &candidate);
  104. void removeLocalArrayStores(const CandidateArray &candidate);
  105. };
  106. // Represents an array we are considering for hoisting.
  107. // Contains helper routines for analyzing if hoisting is possible
  108. // and creating the global variable for the hoisted array.
  109. class CandidateArray {
  110. public:
  111. explicit CandidateArray(AllocaInst *);
  112. bool IsConstArray() const { return m_IsConstArray; }
  113. void AnalyzeUses();
  114. GlobalVariable *GetGlobalArray() const;
  115. AllocaInst *GetLocalArray() const { return m_Alloca; }
  116. std::vector<StoreInst*> GetArrayStores() const;
  117. private:
  118. AllocaInst *m_Alloca;
  119. ArrayType *m_ArrayType;
  120. std::vector<Constant *> m_Values;
  121. bool m_IsConstArray;
  122. bool AnalyzeStore(StoreInst *SI);
  123. bool StoreConstant(int64_t index, Constant *value);
  124. void EnsureSize();
  125. void GetArrayStores(GEPOperator *gep,
  126. std::vector<StoreInst *> &stores) const;
  127. bool AllArrayUsersAreGEP(std::vector<GEPOperator *> &geps);
  128. bool AllGEPUsersAreValid(GEPOperator *gep);
  129. UndefValue *UndefElement();
  130. };
  131. }
  132. // Returns the ArrayType for the alloca or nullptr if the alloca
  133. // does not allocate an array.
  134. static ArrayType *getAllocaArrayType(AllocaInst *allocaInst) {
  135. return dyn_cast<ArrayType>(allocaInst->getType()->getPointerElementType());
  136. }
  137. // Check if the instruction is an alloca that we should consider for hoisting.
  138. // The alloca must allocate and array of primitive types.
  139. static AllocaInst *isHoistableArrayAlloca(Instruction *I) {
  140. AllocaInst *allocaInst = dyn_cast<AllocaInst>(I);
  141. if (!allocaInst)
  142. return nullptr;
  143. ArrayType *arrayTy = getAllocaArrayType(allocaInst);
  144. if (!arrayTy)
  145. return nullptr;
  146. if (!arrayTy->getElementType()->isSingleValueType())
  147. return nullptr;
  148. return allocaInst;
  149. }
  150. // ----------------------------------------------------------------------------
  151. // CandidateArray implementation
  152. // ----------------------------------------------------------------------------
  153. // Create the candidate array for the alloca.
  154. CandidateArray::CandidateArray(AllocaInst *AI)
  155. : m_Alloca(AI)
  156. , m_Values()
  157. , m_IsConstArray(false)
  158. {
  159. assert(isHoistableArrayAlloca(AI));
  160. m_ArrayType = getAllocaArrayType(AI);
  161. }
  162. // Get the global variable with a constant initializer for the array.
  163. // Only valid to call if the array has been analyzed as a constant array.
  164. GlobalVariable *CandidateArray::GetGlobalArray() const {
  165. assert(IsConstArray());
  166. Constant *initializer = ConstantArray::get(m_ArrayType, m_Values);
  167. Module *M = m_Alloca->getModule();
  168. GlobalVariable *GV = new GlobalVariable(*M, m_ArrayType, true, GlobalVariable::LinkageTypes::InternalLinkage, initializer, Twine(m_Alloca->getName()) + ".hca");
  169. GV->setUnnamedAddr(true);
  170. return GV;
  171. }
  172. // Get a list of all the stores that write to the array through one or more
  173. // GetElementPtrInst operations.
  174. std::vector<StoreInst *> CandidateArray::GetArrayStores() const {
  175. std::vector<StoreInst *> stores;
  176. for (User *U : m_Alloca->users())
  177. if (GEPOperator *gep = dyn_cast<GEPOperator>(U))
  178. GetArrayStores(gep, stores);
  179. return stores;
  180. }
  181. // Recursively collect all the stores that write to the pointer/buffer
  182. // referred to by this GetElementPtrInst.
  183. void CandidateArray::GetArrayStores(GEPOperator *gep,
  184. std::vector<StoreInst *> &stores) const {
  185. for (User *GU : gep->users()) {
  186. if (StoreInst *SI = dyn_cast<StoreInst>(GU)) {
  187. stores.push_back(SI);
  188. }
  189. else if (GEPOperator *GEPI = dyn_cast<GEPOperator>(GU)) {
  190. GetArrayStores(GEPI, stores);
  191. }
  192. }
  193. }
  194. // Check to see that all the users of the array are GEPs.
  195. // If so, populate the `geps` vector with a list of all geps that use the array.
  196. bool CandidateArray::AllArrayUsersAreGEP(std::vector<GEPOperator *> &geps) {
  197. for (User *U : m_Alloca->users()) {
  198. GEPOperator *gep = dyn_cast<GEPOperator>(U);
  199. if (!gep)
  200. return false;
  201. geps.push_back(gep);
  202. }
  203. return true;
  204. }
  205. // Check that all gep uses are valid.
  206. // A valid use is either
  207. // 1. A store of a constant value that does not overwrite an existing constant
  208. // with a different value.
  209. // 2. A load instruction.
  210. // 3. Another GetElementPtrInst that itself only has valid uses (recursively)
  211. // Any other use is considered invalid.
  212. bool CandidateArray::AllGEPUsersAreValid(GEPOperator *gep) {
  213. for (User *U : gep->users()) {
  214. if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  215. if (!AnalyzeStore(SI))
  216. return false;
  217. }
  218. else if (GEPOperator *recursive_gep = dyn_cast<GEPOperator>(U)) {
  219. if (!AllGEPUsersAreValid(recursive_gep))
  220. return false;
  221. }
  222. else if (!isa<LoadInst>(U)) {
  223. return false;
  224. }
  225. }
  226. return true;
  227. }
  228. // Analyze all uses of the array to see if it qualifes as a constant array.
  229. // We check the following conditions:
  230. // 1. Make sure alloca is only used by GEP.
  231. // 2. Make sure GEP is only used in load/store.
  232. // 3. Make sure all stores have constant indicies.
  233. // 4. Make sure all stores are constants.
  234. // 5. Make sure all stores to same location are the same constant.
  235. void CandidateArray::AnalyzeUses() {
  236. m_IsConstArray = false;
  237. std::vector<GEPOperator *> geps;
  238. if (!AllArrayUsersAreGEP(geps))
  239. return;
  240. for (GEPOperator *gep : geps)
  241. if (!AllGEPUsersAreValid(gep))
  242. return;
  243. m_IsConstArray = true;
  244. }
  245. // Analyze a store to see if it is a valid constant store.
  246. // A valid store will write a constant value to a known (constant) location.
  247. bool CandidateArray::AnalyzeStore(StoreInst *SI) {
  248. if (!isa<Constant>(SI->getValueOperand()))
  249. return false;
  250. // Walk up the ladder of GetElementPtr instructions to accumulate the index
  251. int64_t index = 0;
  252. for (auto iter = SI->getPointerOperand(); iter != m_Alloca;) {
  253. GEPOperator *gep = cast<GEPOperator>(iter);
  254. if (!gep->hasAllConstantIndices())
  255. return false;
  256. // Deal with the 'extra 0' index from what might have been a global pointer
  257. // https://www.llvm.org/docs/GetElementPtr.html#why-is-the-extra-0-index-required
  258. if ((gep->getNumIndices() == 2) && (gep->getPointerOperand() == m_Alloca)) {
  259. // Non-zero offset is unexpected, but could occur in the wild. Bail out if
  260. // we see it.
  261. ConstantInt *ptrOffset = cast<ConstantInt>(gep->getOperand(1));
  262. if (!ptrOffset->isZero())
  263. return false;
  264. }
  265. else if (gep->getNumIndices() != 1) {
  266. return false;
  267. }
  268. // Accumulate the index
  269. ConstantInt *c = cast<ConstantInt>(gep->getOperand(gep->getNumIndices()));
  270. index += c->getSExtValue();
  271. iter = gep->getPointerOperand();
  272. }
  273. return StoreConstant(index, cast<Constant>(SI->getValueOperand()));
  274. }
  275. // Check if the store is valid and record the value if so.
  276. // A valid constant store is either:
  277. // 1. A store of a new constant
  278. // 2. A store of the same constant to the same location
  279. bool CandidateArray::StoreConstant(int64_t index, Constant *value) {
  280. EnsureSize();
  281. size_t i = static_cast<size_t>(index);
  282. if (i >= m_Values.size())
  283. return false;
  284. if (m_Values[i] == UndefElement())
  285. m_Values[i] = value;
  286. return m_Values[i] == value;
  287. }
  288. // We lazily create the values array until we have a store of a
  289. // constant that we need to remember. This avoids memory overhead
  290. // for obviously non-constant arrays.
  291. void CandidateArray::EnsureSize() {
  292. if (m_Values.size() == 0) {
  293. m_Values.resize(m_ArrayType->getNumElements(), UndefElement());
  294. }
  295. assert(m_Values.size() == m_ArrayType->getNumElements());
  296. }
  297. // Get an undef value of the correct type for the array.
  298. UndefValue *CandidateArray::UndefElement() {
  299. return UndefValue::get(m_ArrayType->getElementType());
  300. }
  301. // ----------------------------------------------------------------------------
  302. // Pass Implementation
  303. // ----------------------------------------------------------------------------
  304. // Find the allocas that are candidates for array hoisting in the function.
  305. std::vector<AllocaInst*> HoistConstantArray::findCandidateAllocas(Function &F) {
  306. std::vector<AllocaInst*> candidates;
  307. for (Instruction &I : F.getEntryBlock())
  308. if (AllocaInst *allocaInst = isHoistableArrayAlloca(&I))
  309. candidates.push_back(allocaInst);
  310. return candidates;
  311. }
  312. // Remove local stores to the array.
  313. // We remove them explicitly rather than relying on DCE to find they are dead.
  314. // Other uses (e.g. geps) can be easily cleaned up by DCE.
  315. void HoistConstantArray::removeLocalArrayStores(const CandidateArray &candidate) {
  316. std::vector<StoreInst*> stores = candidate.GetArrayStores();
  317. for (StoreInst *store : stores)
  318. store->eraseFromParent();
  319. }
  320. // Hoist an array from a local to a global.
  321. void HoistConstantArray::hoistArray(const CandidateArray &candidate) {
  322. assert(candidate.IsConstArray());
  323. removeLocalArrayStores(candidate);
  324. AllocaInst *local = candidate.GetLocalArray();
  325. GlobalVariable *global = candidate.GetGlobalArray();
  326. local->replaceAllUsesWith(global);
  327. local->eraseFromParent();
  328. }
  329. // Perform array hoisting on a single function.
  330. bool HoistConstantArray::runOnFunction(Function &F) {
  331. bool changed = false;
  332. std::vector<AllocaInst *> candidateAllocas = findCandidateAllocas(F);
  333. for (AllocaInst *AI : candidateAllocas) {
  334. CandidateArray candidate(AI);
  335. candidate.AnalyzeUses();
  336. if (candidate.IsConstArray()) {
  337. hoistArray(candidate);
  338. changed |= true;
  339. }
  340. }
  341. return changed;
  342. }
  343. char HoistConstantArray::ID = 0;
  344. INITIALIZE_PASS(HoistConstantArray, "hlsl-hca", "Hoist constant arrays", false, false)
  345. bool HoistConstantArray::runOnModule(Module &M) {
  346. bool changed = false;
  347. for (Function &F : M) {
  348. if (F.isDeclaration())
  349. continue;
  350. changed |= runOnFunction(F);
  351. }
  352. return changed;
  353. }
  354. ModulePass *llvm::createHoistConstantArrayPass() {
  355. return new HoistConstantArray();
  356. }