HLMatrixBitcastLowerPass.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixBitcastLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLMatrixLowerPass.h"
  10. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  11. #include "dxc/HLSL/HLMatrixType.h"
  12. #include "dxc/DXIL/DxilUtil.h"
  13. #include "dxc/Support/Global.h"
  14. #include "dxc/DXIL/DxilOperations.h"
  15. #include "dxc/DXIL/DxilModule.h"
  16. #include "dxc/HLSL/DxilGenerationPass.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/Module.h"
  19. #include "llvm/Pass.h"
  20. #include <unordered_set>
  21. #include <vector>
  22. using namespace llvm;
  23. using namespace hlsl;
  24. using namespace hlsl::HLMatrixLower;
  // Matrix bitcast lowering.
  // After linking, lower matrix bitcast patterns like:
  27. // %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
  28. // %conv.i = fptoui float %164 to i32
  29. // %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
  30. // %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
  31. namespace {
  32. // Translate matrix type to array type.
  33. Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
  34. if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
  35. Type *EltTy = MatTy.getElementTypeForReg();
  36. return ArrayType::get(EltTy, MatTy.getNumElements());
  37. }
  38. else {
  39. return Ty;
  40. }
  41. }
  42. Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
  43. unsigned addrSpace = Ty->getPointerAddressSpace();
  44. Ty = Ty->getPointerElementType();
  45. unsigned arraySize = 1;
  46. while (Ty->isArrayTy()) {
  47. arraySize *= Ty->getArrayNumElements();
  48. Ty = Ty->getArrayElementType();
  49. }
  50. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  51. arraySize *= MatTy.getNumElements();
  52. Ty = ArrayType::get(MatTy.getElementTypeForReg(), arraySize);
  53. return PointerType::get(Ty, addrSpace);
  54. }
  55. Type *TryLowerMatTy(Type *Ty) {
  56. Type *VecTy = nullptr;
  57. if (HLMatrixType::isMatrixArrayPtr(Ty)) {
  58. VecTy = LowerMatrixArrayPointerToOneDimArray(Ty);
  59. } else if (isa<PointerType>(Ty) &&
  60. dxilutil::IsHLSLMatrixType(Ty->getPointerElementType())) {
  61. VecTy = LowerMatrixTypeToOneDimArray(
  62. Ty->getPointerElementType());
  63. VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
  64. }
  65. return VecTy;
  66. }
  67. class MatrixBitcastLowerPass : public FunctionPass {
  68. public:
  69. static char ID; // Pass identification, replacement for typeid
  70. explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
  71. const char *getPassName() const override { return "Matrix Bitcast lower"; }
  72. bool runOnFunction(Function &F) override {
  73. bool bUpdated = false;
  74. std::unordered_set<BitCastInst*> matCastSet;
  75. for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
  76. BasicBlock *BB = blkIt;
  77. for (auto iIt = BB->begin(); iIt != BB->end(); ) {
  78. Instruction *I = (iIt++);
  79. if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
  80. // Mutate mat to vec.
  81. Type *ToTy = BCI->getType();
  82. if (TryLowerMatTy(ToTy)) {
  83. matCastSet.insert(BCI);
  84. bUpdated = true;
  85. }
  86. }
  87. }
  88. }
  89. DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
  90. // Remove bitcast which has CallInst user.
  91. if (DM.GetShaderModel()->IsLib()) {
  92. for (auto it = matCastSet.begin(); it != matCastSet.end();) {
  93. BitCastInst *BCI = *(it++);
  94. if (hasCallUser(BCI)) {
  95. matCastSet.erase(BCI);
  96. }
  97. }
  98. }
  99. // Lower matrix first.
  100. for (BitCastInst *BCI : matCastSet) {
  101. lowerMatrix(BCI, BCI->getOperand(0));
  102. }
  103. return bUpdated;
  104. }
  105. private:
  106. void lowerMatrix(Instruction *M, Value *A);
  107. bool hasCallUser(Instruction *M);
  108. };
  109. }
  110. bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
  111. for (auto it = M->user_begin(); it != M->user_end();) {
  112. User *U = *(it++);
  113. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  114. Type *EltTy = GEP->getType()->getPointerElementType();
  115. if (dxilutil::IsHLSLMatrixType(EltTy)) {
  116. if (hasCallUser(GEP))
  117. return true;
  118. } else {
  119. DXASSERT(0, "invalid GEP for matrix");
  120. }
  121. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  122. if (hasCallUser(BCI))
  123. return true;
  124. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  125. if (isa<VectorType>(LI->getType())) {
  126. } else {
  127. DXASSERT(0, "invalid load for matrix");
  128. }
  129. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  130. Value *V = ST->getValueOperand();
  131. if (isa<VectorType>(V->getType())) {
  132. } else {
  133. DXASSERT(0, "invalid load for matrix");
  134. }
  135. } else if (isa<CallInst>(U)) {
  136. return true;
  137. } else {
  138. DXASSERT(0, "invalid use of matrix");
  139. }
  140. }
  141. return false;
  142. }
  143. namespace {
  144. Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
  145. IRBuilder<> &Builder) {
  146. Value *GEP = nullptr;
  147. if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
  148. // A should be gep oneDimArray, 0, index * matSize
  149. // Here add eltIdx to index * matSize foreach elt.
  150. Instruction *EltGEP = GEPA->clone();
  151. unsigned eltIdx = EltGEP->getNumOperands() - 1;
  152. Value *NewIdx =
  153. Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
  154. EltGEP->setOperand(eltIdx, NewIdx);
  155. Builder.Insert(EltGEP);
  156. GEP = EltGEP;
  157. } else {
  158. GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
  159. }
  160. return GEP;
  161. }
  162. } // namespace
  163. void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
  164. for (auto it = M->user_begin(); it != M->user_end();) {
  165. User *U = *(it++);
  166. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  167. Type *EltTy = GEP->getType()->getPointerElementType();
  168. if (dxilutil::IsHLSLMatrixType(EltTy)) {
  169. // Change gep matrixArray, 0, index
  170. // into
  171. // gep oneDimArray, 0, index * matSize
  172. IRBuilder<> Builder(GEP);
  173. SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
  174. DXASSERT(idxList.size() == 2,
  175. "else not one dim matrix array index to matrix");
  176. HLMatrixType MatTy = HLMatrixType::cast(EltTy);
  177. Value *matSize = Builder.getInt32(MatTy.getNumElements());
  178. idxList.back() = Builder.CreateMul(idxList.back(), matSize);
  179. Value *NewGEP = Builder.CreateGEP(A, idxList);
  180. lowerMatrix(GEP, NewGEP);
  181. DXASSERT(GEP->user_empty(), "else lower matrix fail");
  182. GEP->eraseFromParent();
  183. } else {
  184. DXASSERT(0, "invalid GEP for matrix");
  185. }
  186. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  187. lowerMatrix(BCI, A);
  188. DXASSERT(BCI->user_empty(), "else lower matrix fail");
  189. BCI->eraseFromParent();
  190. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  191. if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
  192. IRBuilder<> Builder(LI);
  193. Value *zeroIdx = Builder.getInt32(0);
  194. unsigned vecSize = Ty->getNumElements();
  195. Value *NewVec = UndefValue::get(LI->getType());
  196. for (unsigned i = 0; i < vecSize; i++) {
  197. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  198. Value *Elt = Builder.CreateLoad(GEP);
  199. NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
  200. }
  201. LI->replaceAllUsesWith(NewVec);
  202. LI->eraseFromParent();
  203. } else {
  204. DXASSERT(0, "invalid load for matrix");
  205. }
  206. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  207. Value *V = ST->getValueOperand();
  208. if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
  209. IRBuilder<> Builder(LI);
  210. Value *zeroIdx = Builder.getInt32(0);
  211. unsigned vecSize = Ty->getNumElements();
  212. for (unsigned i = 0; i < vecSize; i++) {
  213. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  214. Value *Elt = Builder.CreateExtractElement(V, i);
  215. Builder.CreateStore(Elt, GEP);
  216. }
  217. ST->eraseFromParent();
  218. } else {
  219. DXASSERT(0, "invalid load for matrix");
  220. }
  221. } else {
  222. DXASSERT(0, "invalid use of matrix");
  223. }
  224. }
  225. }
  226. char MatrixBitcastLowerPass::ID = 0;
  227. FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
  228. INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)