DxilFixConstArrayInitializer.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
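//===- DxilFixConstArrayInitializer.cpp - Fold stores into initializers ---===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Folds constant stores to undef-initialized global arrays, performed at the
// start of the entry function's entry block, into those globals'
// initializers, and erases the folded stores.
//
//===----------------------------------------------------------------------===//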
  9. #include "llvm/Pass.h"
  10. #include "llvm/IR/Module.h"
  11. #include "llvm/IR/GlobalVariable.h"
  12. #include "llvm/IR/Constants.h"
  13. #include "llvm/IR/Instructions.h"
  14. #include "llvm/IR/Operator.h"
  15. #include "llvm/IR/CFG.h"
  16. #include "llvm/Transforms/Scalar.h"
  17. #include "dxc/DXIL/DxilModule.h"
  18. #include "dxc/HLSL/HLModule.h"
  19. #include <unordered_map>
  20. #include <limits>
  21. using namespace llvm;
  22. namespace {
  23. class DxilFixConstArrayInitializer : public ModulePass {
  24. public:
  25. static char ID;
  26. DxilFixConstArrayInitializer() : ModulePass(ID) {
  27. initializeDxilFixConstArrayInitializerPass(*PassRegistry::getPassRegistry());
  28. }
  29. bool runOnModule(Module &M) override;
  30. const char *getPassName() const override { return "Dxil Fix Const Array Initializer"; }
  31. };
  32. char DxilFixConstArrayInitializer::ID;
  33. }
  34. static bool TryFixGlobalVariable(GlobalVariable &GV, BasicBlock *EntryBlock, const std::unordered_map<Instruction *, unsigned> &InstOrder) {
  35. // Only proceed if the variable has an undef initializer
  36. if (!GV.hasInitializer() || !isa<UndefValue>(GV.getInitializer()))
  37. return false;
  38. // Only handle cases when it's an array of scalars.
  39. Type *Ty = GV.getType()->getPointerElementType();
  40. if (!Ty->isArrayTy())
  41. return false;
  42. // Don't handle arrays that are too big
  43. if (Ty->getArrayNumElements() > 1024)
  44. return false;
  45. Type *ElementTy = Ty->getArrayElementType();
  46. // Only handle arrays of scalar types
  47. if (ElementTy->isAggregateType() || ElementTy->isVectorTy())
  48. return false;
  49. // The instruction index at which point we no longer consider it
  50. // safe to fold Stores. It's the earliest store with non-constant index,
  51. // earliest store with non-constant value, or a load
  52. unsigned FirstUnsafeIndex = std::numeric_limits<unsigned>::max();
  53. SmallVector<StoreInst *, 8> PossibleFoldableStores;
  54. // First do a pass to find the boundary for where we could fold stores. Get a
  55. // list of stores that may be folded.
  56. for (User *U : GV.users()) {
  57. if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  58. bool AllConstIndices = GEP->hasAllConstantIndices();
  59. unsigned NumIndices = GEP->getNumIndices();
  60. if (NumIndices != 2)
  61. return false;
  62. for (User *GEPUser : GEP->users()) {
  63. if (StoreInst *Store = dyn_cast<StoreInst>(GEPUser)) {
  64. if (Store->getParent() != EntryBlock)
  65. continue;
  66. unsigned StoreIndex = InstOrder.at(Store);
  67. if (!AllConstIndices || !isa<Constant>(Store->getValueOperand())) {
  68. FirstUnsafeIndex = std::min(StoreIndex, FirstUnsafeIndex);
  69. continue;
  70. }
  71. PossibleFoldableStores.push_back(Store);
  72. }
  73. else if (LoadInst *Load = dyn_cast<LoadInst>(GEPUser)) {
  74. if (Load->getParent() != EntryBlock)
  75. continue;
  76. FirstUnsafeIndex = std::min(FirstUnsafeIndex, InstOrder.at(Load));
  77. }
  78. // If we have something weird like chained GEPS, or bitcasts, give up.
  79. else {
  80. return false;
  81. }
  82. }
  83. }
  84. }
  85. SmallVector<Constant *, 16> InitValue;
  86. SmallVector<unsigned, 16> LatestStores;
  87. SmallVector<StoreInst *, 8> StoresToRemove;
  88. InitValue.resize(Ty->getArrayNumElements());
  89. LatestStores.resize(Ty->getArrayNumElements());
  90. for (StoreInst *Store : PossibleFoldableStores) {
  91. unsigned StoreIndex = InstOrder.at(Store);
  92. // Skip stores that are out of bounds
  93. if (StoreIndex >= FirstUnsafeIndex)
  94. continue;
  95. GEPOperator *GEP = cast<GEPOperator>(Store->getPointerOperand());
  96. uint64_t Index = cast<ConstantInt>(GEP->getOperand(2))->getLimitedValue();
  97. if (LatestStores[Index] <= StoreIndex) {
  98. InitValue[Index] = cast<Constant>(Store->getValueOperand());
  99. LatestStores[Index] = StoreIndex;
  100. }
  101. StoresToRemove.push_back(Store);
  102. }
  103. // Give up if we have missing indices
  104. for (Constant *C : InitValue)
  105. if (!C)
  106. return false;
  107. GV.setInitializer(ConstantArray::get(cast<ArrayType>(Ty), InitValue));
  108. for (StoreInst *Store : StoresToRemove)
  109. Store->eraseFromParent();
  110. return true;
  111. }
  112. bool DxilFixConstArrayInitializer::runOnModule(Module &M) {
  113. BasicBlock *EntryBlock = nullptr;
  114. if (M.HasDxilModule()) {
  115. hlsl::DxilModule &DM = M.GetDxilModule();
  116. if (DM.GetEntryFunction()) {
  117. EntryBlock = &DM.GetEntryFunction()->getEntryBlock();
  118. }
  119. }
  120. else if (M.HasHLModule()) {
  121. hlsl::HLModule &HM = M.GetHLModule();
  122. if (HM.GetEntryFunction())
  123. EntryBlock = &HM.GetEntryFunction()->getEntryBlock();
  124. }
  125. if (!EntryBlock)
  126. return false;
  127. // If some block might branch to the entry for some reason (like if it's a loop header),
  128. // give up now. Have to make sure this block is not preceeded by anything.
  129. if (pred_begin(EntryBlock) != pred_end(EntryBlock))
  130. return false;
  131. // Find the instruction order for everything in the entry block.
  132. std::unordered_map<Instruction *, unsigned> InstOrder;
  133. for (Instruction &I : *EntryBlock) {
  134. InstOrder[&I] = InstOrder.size();
  135. }
  136. bool Changed = false;
  137. for (GlobalVariable &GV : M.globals()) {
  138. Changed = TryFixGlobalVariable(GV, EntryBlock, InstOrder);
  139. }
  140. return Changed;
  141. }
  142. Pass *llvm::createDxilFixConstArrayInitializerPass() {
  143. return new DxilFixConstArrayInitializer();
  144. }
  145. INITIALIZE_PASS(DxilFixConstArrayInitializer, "dxil-fix-array-init", "Dxil Fix Array Initializer", false, false)