HLMatrixBitcastLowerPass.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixBitcastLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  #include "dxc/HLSL/HLMatrixLowerPass.h"
  #include "dxc/HLSL/HLMatrixLowerHelper.h"
  #include "dxc/HLSL/HLMatrixType.h"
  #include "dxc/DXIL/DxilUtil.h"
  #include "dxc/Support/Global.h"
  #include "dxc/DXIL/DxilOperations.h"
  #include "dxc/DXIL/DxilModule.h"
  #include "dxc/HLSL/DxilGenerationPass.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/Module.h"
  #include "llvm/IR/Operator.h"
  #include "llvm/Pass.h"
  #include <unordered_set>
  #include <vector>
  22. using namespace llvm;
  23. using namespace hlsl;
  24. using namespace hlsl::HLMatrixLower;
  // Matrix Bitcast lower.
  // After linking, lower matrix bitcast patterns like:
  // %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
  // %conv.i = fptoui float %164 to i32
  // %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
  // %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
  // by rewriting the matrix GEPs into GEPs on the original one-dimensional
  // array and splitting the vector loads/stores into per-element loads/stores.
  31. namespace {
  32. // Translate matrix type to array type.
  33. Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
  34. if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
  35. Type *EltTy = MatTy.getElementTypeForReg();
  36. return ArrayType::get(EltTy, MatTy.getNumElements());
  37. }
  38. else {
  39. return Ty;
  40. }
  41. }
  42. Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
  43. unsigned addrSpace = Ty->getPointerAddressSpace();
  44. Ty = Ty->getPointerElementType();
  45. unsigned arraySize = 1;
  46. while (Ty->isArrayTy()) {
  47. arraySize *= Ty->getArrayNumElements();
  48. Ty = Ty->getArrayElementType();
  49. }
  50. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  51. arraySize *= MatTy.getNumElements();
  52. Ty = ArrayType::get(MatTy.getElementTypeForReg(), arraySize);
  53. return PointerType::get(Ty, addrSpace);
  54. }
  55. Type *TryLowerMatTy(Type *Ty) {
  56. Type *VecTy = nullptr;
  57. if (HLMatrixType::isMatrixArrayPtr(Ty)) {
  58. VecTy = LowerMatrixArrayPointerToOneDimArray(Ty);
  59. } else if (isa<PointerType>(Ty) && HLMatrixType::isa(Ty->getPointerElementType())) {
  60. VecTy = LowerMatrixTypeToOneDimArray(
  61. Ty->getPointerElementType());
  62. VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
  63. }
  64. return VecTy;
  65. }
  66. class MatrixBitcastLowerPass : public FunctionPass {
  67. public:
  68. static char ID; // Pass identification, replacement for typeid
  69. explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
  70. const char *getPassName() const override { return "Matrix Bitcast lower"; }
  71. bool runOnFunction(Function &F) override {
  72. bool bUpdated = false;
  73. std::unordered_set<BitCastInst*> matCastSet;
  74. for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
  75. BasicBlock *BB = blkIt;
  76. for (auto iIt = BB->begin(); iIt != BB->end(); ) {
  77. Instruction *I = (iIt++);
  78. if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
  79. // Mutate mat to vec.
  80. Type *ToTy = BCI->getType();
  81. if (TryLowerMatTy(ToTy)) {
  82. matCastSet.insert(BCI);
  83. bUpdated = true;
  84. }
  85. }
  86. }
  87. }
  88. DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
  89. // Remove bitcast which has CallInst user.
  90. if (DM.GetShaderModel()->IsLib()) {
  91. for (auto it = matCastSet.begin(); it != matCastSet.end();) {
  92. BitCastInst *BCI = *(it++);
  93. if (hasCallUser(BCI)) {
  94. matCastSet.erase(BCI);
  95. }
  96. }
  97. }
  98. // Lower matrix first.
  99. for (BitCastInst *BCI : matCastSet) {
  100. lowerMatrix(BCI, BCI->getOperand(0));
  101. }
  102. return bUpdated;
  103. }
  104. private:
  105. void lowerMatrix(Instruction *M, Value *A);
  106. bool hasCallUser(Instruction *M);
  107. };
  108. }
  109. bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
  110. for (auto it = M->user_begin(); it != M->user_end();) {
  111. User *U = *(it++);
  112. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  113. Type *EltTy = GEP->getType()->getPointerElementType();
  114. if (HLMatrixType::isa(EltTy)) {
  115. if (hasCallUser(GEP))
  116. return true;
  117. } else {
  118. DXASSERT(0, "invalid GEP for matrix");
  119. }
  120. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  121. if (hasCallUser(BCI))
  122. return true;
  123. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  124. if (isa<VectorType>(LI->getType())) {
  125. } else {
  126. DXASSERT(0, "invalid load for matrix");
  127. }
  128. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  129. Value *V = ST->getValueOperand();
  130. if (isa<VectorType>(V->getType())) {
  131. } else {
  132. DXASSERT(0, "invalid load for matrix");
  133. }
  134. } else if (isa<CallInst>(U)) {
  135. return true;
  136. } else {
  137. DXASSERT(0, "invalid use of matrix");
  138. }
  139. }
  140. return false;
  141. }
  142. namespace {
  143. Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
  144. IRBuilder<> &Builder) {
  145. Value *GEP = nullptr;
  146. if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
  147. // A should be gep oneDimArray, 0, index * matSize
  148. // Here add eltIdx to index * matSize foreach elt.
  149. Instruction *EltGEP = GEPA->clone();
  150. unsigned eltIdx = EltGEP->getNumOperands() - 1;
  151. Value *NewIdx =
  152. Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
  153. EltGEP->setOperand(eltIdx, NewIdx);
  154. Builder.Insert(EltGEP);
  155. GEP = EltGEP;
  156. } else {
  157. GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
  158. }
  159. return GEP;
  160. }
  161. } // namespace
  162. void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
  163. for (auto it = M->user_begin(); it != M->user_end();) {
  164. User *U = *(it++);
  165. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  166. Type *EltTy = GEP->getType()->getPointerElementType();
  167. if (HLMatrixType::isa(EltTy)) {
  168. // Change gep matrixArray, 0, index
  169. // into
  170. // gep oneDimArray, 0, index * matSize
  171. IRBuilder<> Builder(GEP);
  172. SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
  173. DXASSERT(idxList.size() == 2,
  174. "else not one dim matrix array index to matrix");
  175. HLMatrixType MatTy = HLMatrixType::cast(EltTy);
  176. Value *matSize = Builder.getInt32(MatTy.getNumElements());
  177. idxList.back() = Builder.CreateMul(idxList.back(), matSize);
  178. Value *NewGEP = Builder.CreateGEP(A, idxList);
  179. lowerMatrix(GEP, NewGEP);
  180. DXASSERT(GEP->user_empty(), "else lower matrix fail");
  181. GEP->eraseFromParent();
  182. } else {
  183. DXASSERT(0, "invalid GEP for matrix");
  184. }
  185. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  186. lowerMatrix(BCI, A);
  187. DXASSERT(BCI->user_empty(), "else lower matrix fail");
  188. BCI->eraseFromParent();
  189. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  190. if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
  191. IRBuilder<> Builder(LI);
  192. Value *zeroIdx = Builder.getInt32(0);
  193. unsigned vecSize = Ty->getNumElements();
  194. Value *NewVec = UndefValue::get(LI->getType());
  195. for (unsigned i = 0; i < vecSize; i++) {
  196. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  197. Value *Elt = Builder.CreateLoad(GEP);
  198. NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
  199. }
  200. LI->replaceAllUsesWith(NewVec);
  201. LI->eraseFromParent();
  202. } else {
  203. DXASSERT(0, "invalid load for matrix");
  204. }
  205. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  206. Value *V = ST->getValueOperand();
  207. if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
  208. IRBuilder<> Builder(LI);
  209. Value *zeroIdx = Builder.getInt32(0);
  210. unsigned vecSize = Ty->getNumElements();
  211. for (unsigned i = 0; i < vecSize; i++) {
  212. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  213. Value *Elt = Builder.CreateExtractElement(V, i);
  214. Builder.CreateStore(Elt, GEP);
  215. }
  216. ST->eraseFromParent();
  217. } else {
  218. DXASSERT(0, "invalid load for matrix");
  219. }
  220. } else {
  221. DXASSERT(0, "invalid use of matrix");
  222. }
  223. }
  224. }
  225. char MatrixBitcastLowerPass::ID = 0;
  226. FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
  227. INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)