HLMatrixLowerPass.cpp 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HLSL/DxilUtil.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/HLSL/DxilOperations.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/IR/DebugInfo.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/Transforms/Utils/Local.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include <unordered_set>
  27. #include <vector>
  28. using namespace llvm;
  29. using namespace hlsl;
  30. using namespace hlsl::HLMatrixLower;
  31. namespace hlsl {
  32. namespace HLMatrixLower {
  33. bool IsMatrixType(Type *Ty) {
  34. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  35. Type *EltTy = ST->getElementType(0);
  36. if (!ST->getName().startswith("class.matrix"))
  37. return false;
  38. bool isVecArray = EltTy->isArrayTy() &&
  39. EltTy->getArrayElementType()->isVectorTy();
  40. return isVecArray && EltTy->getArrayNumElements() <= 4;
  41. }
  42. return false;
  43. }
  44. // Translate matrix type to vector type.
  45. Type *LowerMatrixType(Type *Ty) {
  46. // Only translate matrix type and function type which use matrix type.
  47. // Not translate struct has matrix or matrix pointer.
  48. // Struct should be flattened before.
  49. // Pointer could cover by matldst which use vector as value type.
  50. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  51. Type *RetTy = LowerMatrixType(FT->getReturnType());
  52. SmallVector<Type *, 4> params;
  53. for (Type *param : FT->params()) {
  54. params.emplace_back(LowerMatrixType(param));
  55. }
  56. return FunctionType::get(RetTy, params, false);
  57. } else if (IsMatrixType(Ty)) {
  58. unsigned row, col;
  59. Type *EltTy = GetMatrixInfo(Ty, col, row);
  60. return VectorType::get(EltTy, row * col);
  61. } else {
  62. return Ty;
  63. }
  64. }
  65. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  66. DXASSERT(IsMatrixType(Ty), "not matrix type");
  67. StructType *ST = cast<StructType>(Ty);
  68. Type *EltTy = ST->getElementType(0);
  69. Type *RowTy = EltTy->getArrayElementType();
  70. row = EltTy->getArrayNumElements();
  71. col = RowTy->getVectorNumElements();
  72. return RowTy->getVectorElementType();
  73. }
  74. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  75. if (!Ty->isPointerTy())
  76. return false;
  77. Ty = Ty->getPointerElementType();
  78. if (!Ty->isArrayTy())
  79. return false;
  80. while (Ty->isArrayTy())
  81. Ty = Ty->getArrayElementType();
  82. return IsMatrixType(Ty);
  83. }
  84. Type *LowerMatrixArrayPointer(Type *Ty) {
  85. Ty = Ty->getPointerElementType();
  86. std::vector<unsigned> arraySizeList;
  87. while (Ty->isArrayTy()) {
  88. arraySizeList.push_back(Ty->getArrayNumElements());
  89. Ty = Ty->getArrayElementType();
  90. }
  91. Ty = LowerMatrixType(Ty);
  92. for (auto arraySize = arraySizeList.rbegin();
  93. arraySize != arraySizeList.rend(); arraySize++)
  94. Ty = ArrayType::get(Ty, *arraySize);
  95. return PointerType::get(Ty, 0);
  96. }
  97. Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
  98. IRBuilder<> &Builder) {
  99. Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
  100. for (unsigned i = 0; i < size; i++)
  101. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  102. return Vec;
  103. }
  104. Value *LowerGEPOnMatIndexListToIndex(
  105. llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
  106. IRBuilder<> Builder(GEP);
  107. Value *zero = Builder.getInt32(0);
  108. DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
  109. Value *baseIdx = (GEP->idx_begin())->get();
  110. DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
  111. Value *Idx = (GEP->idx_begin() + 1)->get();
  112. if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
  113. return IdxList[immIdx->getSExtValue()];
  114. } else {
  115. IRBuilder<> AllocaBuilder(
  116. GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  117. unsigned size = IdxList.size();
  118. // Store idxList to temp array.
  119. ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
  120. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  121. for (unsigned i = 0; i < size; i++) {
  122. Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
  123. Builder.CreateStore(IdxList[i], EltPtr);
  124. }
  125. // Load the idx.
  126. Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
  127. return Builder.CreateLoad(GEPOffset);
  128. }
  129. }
  130. unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
  131. return c * row + r;
  132. }
  133. unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
  134. return r * col + c;
  135. }
  136. } // namespace HLMatrixLower
  137. } // namespace hlsl
namespace {
// Module pass that lowers HLSL matrix values (class.matrix.* structs and the
// HL intrinsics operating on them) into plain vector operations.
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}

  const char *getPassName() const override { return "HL matrix lower"; }

  bool runOnModule(Module &M) override {
    m_pModule = &M;
    m_pHLModule = &m_pModule->GetOrCreateHLModule();
    // Load up debug information, to cross-reference values and the instructions
    // used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
    // Lower matrices function by function, skipping declarations.
    for (Function &F : M.functions()) {
      if (F.isDeclaration())
        continue;
      runOnFunction(F);
    }
    // Snapshot static and groupshared globals before processing them, so we
    // never iterate M.globals() while it may be modified by runOnGlobal.
    std::vector<GlobalVariable*> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (dxilutil::IsStaticGlobal(&GV) ||
          dxilutil::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }
    for (GlobalVariable *GV : staticGVs)
      runOnGlobal(GV);
    return true;
  }

private:
  Module *m_pModule;
  HLModule *m_pHLModule;
  bool m_HasDbgInfo;
  // Instructions scheduled for deletion once lowering is done.
  std::vector<Instruction *> m_deadInsts;
  // For instructions like matrix array init, more than one matrix alloca may
  // be used. This set is here to avoid putting an instruction into deadInsts
  // more than once.
  std::unordered_set<Instruction *> m_inDeadInstsSet;
  // For most matrix instruction users, there will be only one matrix use.
  // Use a vector to save deadInsts because a vector is cheap.
  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  // In case an instruction has more than one matrix use, use
  // AddToDeadInstsWithDups to make sure it's not added to deadInsts more than
  // once.
  void AddToDeadInstsWithDups(Instruction *I) {
    if (m_inDeadInstsSet.count(I) == 0) {
      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
      m_inDeadInstsSet.insert(I);
      AddToDeadInsts(I);
    }
  }

  void runOnFunction(Function &F);
  void runOnGlobal(GlobalVariable *GV);
  void runOnGlobalMatrixArray(GlobalVariable *GV);

  // Per-opcode-group lowering of a matrix-typed HL call into its vector form.
  Instruction *MatCastToVec(CallInst *CI);
  Instruction *MatLdStToVec(CallInst *CI);
  Instruction *MatSubscriptToVec(CallInst *CI);
  Instruction *MatIntrinsicToVec(CallInst *CI);
  Instruction *TrivialMatUnOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatUnOpReplace(CallInst *matInst, Instruction *vecInst,
                             CallInst *matUseInst);
  Instruction *TrivialMatBinOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatBinOpReplace(CallInst *matInst, Instruction *vecInst,
                              CallInst *matUseInst);
  // Replace matInst with vecInst on mulInst.
  void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
                    bool isSigned);
  // Replace matInst with vecInst on transposeInst.
  void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
                             CallInst *transposeInst);
  void TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
                               CallInst *determinantInst);
  void MatIntrinsicReplace(CallInst *matInst, Instruction *vecInst,
                           CallInst *matUseInst);
  // Replace matInst with vecInst on castInst.
  void TranslateMatMatCast(CallInst *matInst, Instruction *vecInst,
                           CallInst *castInst);
  void TranslateMatToOtherCast(CallInst *matInst, Instruction *vecInst,
                               CallInst *castInst);
  void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
                        CallInst *castInst);
  void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
                             CallInst *castInst, bool rowToCol);
  // Replace matInst with vecInst in matSubscript.
  void TranslateMatSubscript(Value *matInst, Value *vecInst,
                             CallInst *matSubInst);
  // Lower a matrix initializer into its vector form.
  void TranslateMatInit(CallInst *matInitInst);
  // Lower a matrix select into its vector form.
  void TranslateMatSelect(CallInst *matSelectInst);
  // Replace matInst with vecInst on matGEP.
  void TranslateMatArrayGEP(Value *matInst, Instruction *vecInst,
                            GetElementPtrInst *matGEP);
  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
                                     CallInst *matLdStInst);
  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                                     CallInst *matLdStInst);
  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatReplace(CallInst *matInst, Instruction *vecInst,
                         CallInst *matUseInst);
  // Lower a matrix type instruction to a vector type instruction.
  void lowerToVec(Instruction *matInst);
  // Lower users of a matrix type instruction.
  void replaceMatWithVec(Instruction *matInst, Instruction *vecInst);
  // Translate mat inst which needs all operands ready.
  void finalMatTranslation(Instruction *matInst);
  // Delete dead insts in m_deadInsts.
  void DeleteDeadInsts();
  // Map from matrix inst to its vector version.
  DenseMap<Instruction *, Value *> matToVecMap;
};
} // namespace
char HLMatrixLowerPass::ID = 0;

// Factory used by the pass manager to instantiate this pass.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }

INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  261. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  262. IRBuilder<> Builder) {
  263. // Cast to bool.
  264. if (toTy->getScalarType()->isIntegerTy() &&
  265. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  266. Type *fromTy = src->getType();
  267. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  268. Constant *zero;
  269. if (isFloat)
  270. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  271. else
  272. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  273. if (toTy->getScalarType() != toTy) {
  274. // Create constant vector.
  275. unsigned size = toTy->getVectorNumElements();
  276. std::vector<Constant *> zeros(size, zero);
  277. zero = llvm::ConstantVector::get(zeros);
  278. }
  279. if (isFloat)
  280. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  281. else
  282. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  283. }
  284. Type *eltToTy = toTy->getScalarType();
  285. Type *eltFromTy = src->getType()->getScalarType();
  286. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  287. castOp == HLCastOpcode::UnsignedUnsignedCast;
  288. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  289. castOp == HLCastOpcode::UnsignedUnsignedCast;
  290. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  291. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  292. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  293. }
// Lower an HL cast intrinsic that involves a matrix on either side.
// Other-to-matrix casts are fully translated here (size change + element
// type change); mat-to-mat casts on arguments get a vector-typed clone of
// the call; everything else falls back to MatIntrinsicToVec and is finished
// when the operands are replaced.
Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
  bool ToMat = IsMatrixType(CI->getType());
  bool FromMat = IsMatrixType(op->getType());
  if (ToMat && !FromMat) {
    // Translate OtherToMat here.
    // The rest will be translated when operands are replaced.
    unsigned col, row;
    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
    unsigned toSize = col * row;
    Instruction *sizeCast = nullptr;
    Type *FromTy = op->getType();
    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
    if (FromTy->isVectorTy()) {
      // Vector source: shuffle the first toSize lanes into the result.
      std::vector<Constant *> MaskVec(toSize);
      for (size_t i = 0; i != toSize; ++i)
        MaskVec[i] = ConstantInt::get(I32Ty, i);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    } else {
      // Scalar source: insert into lane 0, then splat across all lanes
      // with a zero shuffle mask.
      op = Builder.CreateInsertElement(
          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
      Constant *zero = ConstantInt::get(I32Ty, 0);
      std::vector<Constant *> MaskVec(toSize, zero);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    }
    // Convert the element type as well when it differs from the source.
    Instruction *typeCast = sizeCast;
    if (EltTy != FromTy->getScalarType()) {
      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
                                sizeCast, Builder);
    }
    return typeCast;
  } else if (FromMat && ToMat) {
    if (isa<Argument>(op)) {
      // Cast from mat to mat for an argument.
      IRBuilder<> Builder(CI);
      // Here only lower the return type to vector; argument types are kept.
      Type *RetTy = LowerMatrixType(CI->getType());
      SmallVector<Type *, 4> params;
      for (Value *operand : CI->arg_operands()) {
        params.emplace_back(operand->getType());
      }
      Type *FT = FunctionType::get(RetTy, params, false);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      unsigned opcode = GetHLOpcode(CI);
      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
                                             group, opcode);
      SmallVector<Value *, 4> argList;
      for (Value *arg : CI->arg_operands()) {
        argList.emplace_back(arg);
      }
      return Builder.CreateCall(vecF, argList);
    }
  }
  return MatIntrinsicToVec(CI);
}
// Lower an HL matrix load/store on an alloca into a plain vector load/store
// on its vector twin (looked up in matToVecMap). Stores temporarily write an
// undef vector value, which is patched once the stored matrix value itself
// has been translated. Non-alloca pointers fall back to MatIntrinsicToVec.
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  unsigned opcode = GetHLOpcode(CI);
  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  Instruction *result = nullptr;
  switch (matOpcode) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      result = Builder.CreateLoad(vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
      // Placeholder value; replaced with the lowered vector later.
      Value *vecVal =
          UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
      result = Builder.CreateStore(vecVal, vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  }
  // NOTE(review): any other opcode leaves result == nullptr — callers
  // presumably only pass matrix load/store opcodes; confirm before relying
  // on a non-null return.
  return result;
}
  385. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  386. IRBuilder<> Builder(CI);
  387. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  388. if (isa<AllocaInst>(matPtr)) {
  389. // Here just create a new matSub call which use vec ptr.
  390. // Later in TranslateMatSubscript will do the real translation.
  391. std::vector<Value *> args(CI->getNumArgOperands());
  392. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  393. args[i] = CI->getArgOperand(i);
  394. }
  395. // Change mat ptr into vec ptr.
  396. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  397. matToVecMap[cast<Instruction>(matPtr)];
  398. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  399. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  400. paramTyList[i] = args[i]->getType();
  401. }
  402. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  403. unsigned opcode = GetHLOpcode(CI);
  404. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  405. return Builder.CreateCall(opFunc, args);
  406. } else
  407. return MatIntrinsicToVec(CI);
  408. }
  409. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  410. IRBuilder<> Builder(CI);
  411. unsigned opcode = GetHLOpcode(CI);
  412. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  413. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  414. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  415. SmallVector<Value *, 4> argList;
  416. for (Value *arg : CI->arg_operands()) {
  417. Type *Ty = arg->getType();
  418. if (IsMatrixType(Ty)) {
  419. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  420. } else
  421. argList.emplace_back(arg);
  422. }
  423. return Builder.CreateCall(vecF, argList);
  424. }
  425. static Value *VectorizeScalarOp(Value *op, Type *dstTy, IRBuilder<> &Builder) {
  426. if (op->getType() == dstTy)
  427. return op;
  428. op = Builder.CreateInsertElement(
  429. UndefValue::get(VectorType::get(op->getType(), 1)), op, (uint64_t)0);
  430. Type *I32Ty = IntegerType::get(dstTy->getContext(), 32);
  431. Constant *zero = ConstantInt::get(I32Ty, 0);
  432. std::vector<Constant *> MaskVec(dstTy->getVectorNumElements(), zero);
  433. Value *castMask = ConstantVector::get(MaskVec);
  434. Value *vecOp = new ShuffleVectorInst(op, op, castMask);
  435. Builder.Insert(cast<Instruction>(vecOp));
  436. return vecOp;
  437. }
  438. Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  439. Type *ResultTy = LowerMatrixType(CI->getType());
  440. UndefValue *tmp = UndefValue::get(ResultTy);
  441. IRBuilder<> Builder(CI);
  442. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  443. bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  444. auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
  445. Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
  446. : ConstantInt::get(Ty->getScalarType(), v);
  447. std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
  448. return ConstantVector::get(vals);
  449. };
  450. Constant *one = GetVecConst(ResultTy, 1);
  451. Instruction *Result = nullptr;
  452. switch (opcode) {
  453. case HLUnaryOpcode::Minus: {
  454. Constant *zero = GetVecConst(ResultTy, 0);
  455. if (isFloat)
  456. Result = BinaryOperator::CreateFSub(zero, tmp);
  457. else
  458. Result = BinaryOperator::CreateSub(zero, tmp);
  459. } break;
  460. case HLUnaryOpcode::LNot: {
  461. Constant *zero = GetVecConst(ResultTy, 0);
  462. if (isFloat)
  463. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
  464. else
  465. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  466. } break;
  467. case HLUnaryOpcode::Not:
  468. Result = BinaryOperator::CreateXor(tmp, tmp);
  469. break;
  470. case HLUnaryOpcode::PostInc:
  471. case HLUnaryOpcode::PreInc:
  472. if (isFloat)
  473. Result = BinaryOperator::CreateFAdd(tmp, one);
  474. else
  475. Result = BinaryOperator::CreateAdd(tmp, one);
  476. break;
  477. case HLUnaryOpcode::PostDec:
  478. case HLUnaryOpcode::PreDec:
  479. if (isFloat)
  480. Result = BinaryOperator::CreateFSub(tmp, one);
  481. else
  482. Result = BinaryOperator::CreateSub(tmp, one);
  483. break;
  484. default:
  485. DXASSERT(0, "not implement");
  486. return nullptr;
  487. }
  488. Builder.Insert(Result);
  489. return Result;
  490. }
  491. Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  492. Type *ResultTy = LowerMatrixType(CI->getType());
  493. IRBuilder<> Builder(CI);
  494. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  495. Type *OpTy = LowerMatrixType(
  496. CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  497. UndefValue *tmp = UndefValue::get(OpTy);
  498. bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  499. Instruction *Result = nullptr;
  500. switch (opcode) {
  501. case HLBinaryOpcode::Add:
  502. if (isFloat)
  503. Result = BinaryOperator::CreateFAdd(tmp, tmp);
  504. else
  505. Result = BinaryOperator::CreateAdd(tmp, tmp);
  506. break;
  507. case HLBinaryOpcode::Sub:
  508. if (isFloat)
  509. Result = BinaryOperator::CreateFSub(tmp, tmp);
  510. else
  511. Result = BinaryOperator::CreateSub(tmp, tmp);
  512. break;
  513. case HLBinaryOpcode::Mul:
  514. if (isFloat)
  515. Result = BinaryOperator::CreateFMul(tmp, tmp);
  516. else
  517. Result = BinaryOperator::CreateMul(tmp, tmp);
  518. break;
  519. case HLBinaryOpcode::Div:
  520. if (isFloat)
  521. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  522. else
  523. Result = BinaryOperator::CreateSDiv(tmp, tmp);
  524. break;
  525. case HLBinaryOpcode::Rem:
  526. if (isFloat)
  527. Result = BinaryOperator::CreateFRem(tmp, tmp);
  528. else
  529. Result = BinaryOperator::CreateSRem(tmp, tmp);
  530. break;
  531. case HLBinaryOpcode::And:
  532. Result = BinaryOperator::CreateAnd(tmp, tmp);
  533. break;
  534. case HLBinaryOpcode::Or:
  535. Result = BinaryOperator::CreateOr(tmp, tmp);
  536. break;
  537. case HLBinaryOpcode::Xor:
  538. Result = BinaryOperator::CreateXor(tmp, tmp);
  539. break;
  540. case HLBinaryOpcode::Shl: {
  541. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  542. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  543. Result = BinaryOperator::CreateShl(tmp, tmp);
  544. } break;
  545. case HLBinaryOpcode::Shr: {
  546. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  547. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  548. Result = BinaryOperator::CreateAShr(tmp, tmp);
  549. } break;
  550. case HLBinaryOpcode::LT:
  551. if (isFloat)
  552. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
  553. else
  554. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
  555. break;
  556. case HLBinaryOpcode::GT:
  557. if (isFloat)
  558. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
  559. else
  560. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
  561. break;
  562. case HLBinaryOpcode::LE:
  563. if (isFloat)
  564. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
  565. else
  566. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
  567. break;
  568. case HLBinaryOpcode::GE:
  569. if (isFloat)
  570. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
  571. else
  572. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
  573. break;
  574. case HLBinaryOpcode::EQ:
  575. if (isFloat)
  576. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
  577. else
  578. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
  579. break;
  580. case HLBinaryOpcode::NE:
  581. if (isFloat)
  582. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
  583. else
  584. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
  585. break;
  586. case HLBinaryOpcode::UDiv:
  587. Result = BinaryOperator::CreateUDiv(tmp, tmp);
  588. break;
  589. case HLBinaryOpcode::URem:
  590. Result = BinaryOperator::CreateURem(tmp, tmp);
  591. break;
  592. case HLBinaryOpcode::UShr: {
  593. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  594. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  595. Result = BinaryOperator::CreateLShr(tmp, tmp);
  596. } break;
  597. case HLBinaryOpcode::ULT:
  598. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
  599. break;
  600. case HLBinaryOpcode::UGT:
  601. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
  602. break;
  603. case HLBinaryOpcode::ULE:
  604. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
  605. break;
  606. case HLBinaryOpcode::UGE:
  607. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
  608. break;
  609. case HLBinaryOpcode::LAnd:
  610. case HLBinaryOpcode::LOr: {
  611. Constant *zero;
  612. if (isFloat)
  613. zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
  614. else
  615. zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
  616. unsigned size = ResultTy->getVectorNumElements();
  617. std::vector<Constant *> zeros(size, zero);
  618. Value *vecZero = llvm::ConstantVector::get(zeros);
  619. Instruction *cmpL;
  620. if (isFloat)
  621. cmpL =
  622. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  623. else
  624. cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  625. Builder.Insert(cmpL);
  626. Instruction *cmpR;
  627. if (isFloat)
  628. cmpR =
  629. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  630. else
  631. cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  632. Builder.Insert(cmpR);
  633. // How to map l, r back? Need check opcode
  634. if (opcode == HLBinaryOpcode::LOr)
  635. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  636. else
  637. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  638. break;
  639. }
  640. default:
  641. DXASSERT(0, "not implement");
  642. return nullptr;
  643. }
  644. Builder.Insert(Result);
  645. return Result;
  646. }
  647. void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  648. Instruction *vecInst;
  649. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  650. hlsl::HLOpcodeGroup group =
  651. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  652. switch (group) {
  653. case HLOpcodeGroup::HLIntrinsic: {
  654. vecInst = MatIntrinsicToVec(CI);
  655. } break;
  656. case HLOpcodeGroup::HLSelect: {
  657. vecInst = MatIntrinsicToVec(CI);
  658. } break;
  659. case HLOpcodeGroup::HLBinOp: {
  660. vecInst = TrivialMatBinOpToVec(CI);
  661. } break;
  662. case HLOpcodeGroup::HLUnOp: {
  663. vecInst = TrivialMatUnOpToVec(CI);
  664. } break;
  665. case HLOpcodeGroup::HLCast: {
  666. vecInst = MatCastToVec(CI);
  667. } break;
  668. case HLOpcodeGroup::HLInit: {
  669. vecInst = MatIntrinsicToVec(CI);
  670. } break;
  671. case HLOpcodeGroup::HLMatLoadStore: {
  672. vecInst = MatLdStToVec(CI);
  673. } break;
  674. case HLOpcodeGroup::HLSubscript: {
  675. vecInst = MatSubscriptToVec(CI);
  676. } break;
  677. }
  678. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  679. Type *Ty = AI->getAllocatedType();
  680. Type *matTy = Ty;
  681. IRBuilder<> Builder(AI);
  682. if (Ty->isArrayTy()) {
  683. Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
  684. vecTy = vecTy->getPointerElementType();
  685. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  686. } else {
  687. Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
  688. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  689. }
  690. // Update debug info.
  691. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
  692. if (DDI) {
  693. LLVMContext &Context = AI->getContext();
  694. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  695. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  696. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecInst));
  697. IRBuilder<> debugBuilder(DDI);
  698. debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
  699. }
  700. if (HLModule::HasPreciseAttributeWithMetadata(AI))
  701. HLModule::MarkPreciseAttributeWithMetadata(vecInst);
  702. } else {
  703. DXASSERT(0, "invalid inst");
  704. }
  705. matToVecMap[matInst] = vecInst;
  706. }
  707. // Replace matInst with vecInst on matUseInst.
  708. void HLMatrixLowerPass::TrivialMatUnOpReplace(CallInst *matInst,
  709. Instruction *vecInst,
  710. CallInst *matUseInst) {
  711. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  712. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  713. switch (opcode) {
  714. case HLUnaryOpcode::Not:
  715. // Not is xor now
  716. vecUseInst->setOperand(0, vecInst);
  717. vecUseInst->setOperand(1, vecInst);
  718. break;
  719. case HLUnaryOpcode::LNot:
  720. case HLUnaryOpcode::PostInc:
  721. case HLUnaryOpcode::PreInc:
  722. case HLUnaryOpcode::PostDec:
  723. case HLUnaryOpcode::PreDec:
  724. vecUseInst->setOperand(0, vecInst);
  725. break;
  726. }
  727. }
  728. // Replace matInst with vecInst on matUseInst.
  729. void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
  730. Instruction *vecInst,
  731. CallInst *matUseInst) {
  732. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  733. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  734. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  735. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matInst)
  736. vecUseInst->setOperand(0, vecInst);
  737. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matInst)
  738. vecUseInst->setOperand(1, vecInst);
  739. } else {
  740. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  741. matInst) {
  742. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  743. vecCmp->setOperand(0, vecInst);
  744. }
  745. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  746. matInst) {
  747. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  748. vecCmp->setOperand(0, vecInst);
  749. }
  750. }
  751. }
  752. static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
  753. llvm::FunctionType *MadFuncTy =
  754. llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
  755. Function *MAD =
  756. GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
  757. (unsigned)madOp);
  758. return MAD;
  759. }
// Lowers an HLSL mul(matrix, matrix) call to scalar IR: element (r, c) of
// the result is the dot product of row r of the left matrix with column c of
// the right one, emitted as one mul followed by a chain of mad intrinsics.
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst, bool isSigned) {
  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated.
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  // Left matrix is row x col, right is rRow x rCol; inner dims must agree.
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(mulInst);
  // Flat (lowered) vector versions of both matrix operands.
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // lhs[r][lc] * rhs[lc][c] -- the first term of a dot product.
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // acc + lhs[r][lc] * rhs[lc][c] -- folds one more term into the product.
  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                             Value *acc) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      for (lc = 1; lc < col; lc++) {
        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
      }
      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vector call with the expanded computation.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
  818. void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
  819. Instruction *vecInst,
  820. CallInst *mulInst, bool isSigned) {
  821. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  822. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  823. unsigned col, row;
  824. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  825. DXASSERT(RVal->getType()->getVectorNumElements() == col, "");
  826. bool isFloat = EltTy->isFloatingPointTy();
  827. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  828. IRBuilder<> Builder(mulInst);
  829. Value *vec = RVal;
  830. Value *mat = vecInst; // vec version of matInst;
  831. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  832. Type *opcodeTy = Builder.getInt32Ty();
  833. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  834. *m_pHLModule->GetModule());
  835. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  836. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  837. Value *vecElt = Builder.CreateExtractElement(vec, c);
  838. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  839. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  840. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  841. };
  842. for (unsigned r = 0; r < row; r++) {
  843. unsigned c = 0;
  844. Value *vecElt = Builder.CreateExtractElement(vec, c);
  845. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  846. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  847. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  848. : Builder.CreateMul(vecElt, matElt);
  849. for (c = 1; c < col; c++) {
  850. tmpVal = CreateOneEltMad(r, c, tmpVal);
  851. }
  852. retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  853. }
  854. mulInst->replaceAllUsesWith(retVal);
  855. AddToDeadInsts(mulInst);
  856. }
  857. void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
  858. Instruction *vecInst,
  859. CallInst *mulInst, bool isSigned) {
  860. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  861. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  862. Value *RVal = vecInst;
  863. unsigned col, row;
  864. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  865. DXASSERT(LVal->getType()->getVectorNumElements() == row, "");
  866. bool isFloat = EltTy->isFloatingPointTy();
  867. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  868. IRBuilder<> Builder(mulInst);
  869. Value *vec = LVal;
  870. Value *mat = RVal;
  871. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  872. Type *opcodeTy = Builder.getInt32Ty();
  873. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  874. *m_pHLModule->GetModule());
  875. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  876. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  877. Value *vecElt = Builder.CreateExtractElement(vec, r);
  878. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  879. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  880. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  881. };
  882. for (unsigned c = 0; c < col; c++) {
  883. unsigned r = 0;
  884. Value *vecElt = Builder.CreateExtractElement(vec, r);
  885. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  886. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  887. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  888. : Builder.CreateMul(vecElt, matElt);
  889. for (r = 1; r < row; r++) {
  890. tmpVal = CreateOneEltMad(r, c, tmpVal);
  891. }
  892. retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  893. }
  894. mulInst->replaceAllUsesWith(retVal);
  895. AddToDeadInsts(mulInst);
  896. }
  897. void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
  898. CallInst *mulInst, bool isSigned) {
  899. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  900. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  901. bool LMat = IsMatrixType(LVal->getType());
  902. bool RMat = IsMatrixType(RVal->getType());
  903. if (LMat && RMat) {
  904. TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
  905. } else if (LMat) {
  906. TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
  907. } else {
  908. TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
  909. }
  910. }
// Lowers an HLSL transpose() call. The lowered matrix is a flat row-major
// vector, so transposing is exactly the row-major -> column-major shuffle.
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                              Instruction *vecInst,
                                              CallInst *transposeInst) {
  // Matrix value is row major, transpose is cast it to col major.
  TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
}
  917. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  918. IRBuilder<> &Builder) {
  919. Value *mul0 = Builder.CreateFMul(m00, m11);
  920. Value *mul1 = Builder.CreateFMul(m01, m10);
  921. return Builder.CreateFSub(mul0, mul1);
  922. }
  923. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  924. Value *m10, Value *m11, Value *m12,
  925. Value *m20, Value *m21, Value *m22,
  926. IRBuilder<> &Builder) {
  927. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  928. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  929. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  930. deter00 = Builder.CreateFMul(m00, deter00);
  931. deter01 = Builder.CreateFMul(m01, deter01);
  932. deter02 = Builder.CreateFMul(m02, deter02);
  933. Value *result = Builder.CreateFSub(deter00, deter01);
  934. result = Builder.CreateFAdd(result, deter02);
  935. return result;
  936. }
// Emits IR for a 4x4 determinant by cofactor (Laplace) expansion along the
// first row: m00*M00 - m01*M01 + m02*M02 - m03*M03, where each Mij is a
// 3x3 minor determinant.
static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
                             Value *m10, Value *m11, Value *m12, Value *m13,
                             Value *m20, Value *m21, Value *m22, Value *m23,
                             Value *m30, Value *m31, Value *m32, Value *m33,
                             IRBuilder<> &Builder) {
  // 3x3 minors obtained by deleting row 0 and column j.
  Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  deter00 = Builder.CreateFMul(m00, deter00);
  deter01 = Builder.CreateFMul(m01, deter01);
  deter02 = Builder.CreateFMul(m02, deter02);
  deter03 = Builder.CreateFMul(m03, deter03);
  // Alternate signs: + - + -.
  Value *result = Builder.CreateFSub(deter00, deter01);
  result = Builder.CreateFAdd(result, deter02);
  result = Builder.CreateFSub(result, deter03);
  return result;
}
  955. void HLMatrixLowerPass::TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  956. CallInst *determinantInst) {
  957. unsigned row, col;
  958. GetMatrixInfo(matInst->getType(), col, row);
  959. IRBuilder<> Builder(determinantInst);
  960. // when row == 1, result is vecInst.
  961. Value *Result = vecInst;
  962. if (row == 2) {
  963. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  964. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  965. Value *m10 = Builder.CreateExtractElement(vecInst, 2);
  966. Value *m11 = Builder.CreateExtractElement(vecInst, 3);
  967. Result = Determinant2x2(m00, m01, m10, m11, Builder);
  968. }
  969. else if (row == 3) {
  970. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  971. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  972. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  973. Value *m10 = Builder.CreateExtractElement(vecInst, 3);
  974. Value *m11 = Builder.CreateExtractElement(vecInst, 4);
  975. Value *m12 = Builder.CreateExtractElement(vecInst, 5);
  976. Value *m20 = Builder.CreateExtractElement(vecInst, 6);
  977. Value *m21 = Builder.CreateExtractElement(vecInst, 7);
  978. Value *m22 = Builder.CreateExtractElement(vecInst, 8);
  979. Result = Determinant3x3(m00, m01, m02,
  980. m10, m11, m12,
  981. m20, m21, m22, Builder);
  982. }
  983. else if (row == 4) {
  984. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  985. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  986. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  987. Value *m03 = Builder.CreateExtractElement(vecInst, 3);
  988. Value *m10 = Builder.CreateExtractElement(vecInst, 4);
  989. Value *m11 = Builder.CreateExtractElement(vecInst, 5);
  990. Value *m12 = Builder.CreateExtractElement(vecInst, 6);
  991. Value *m13 = Builder.CreateExtractElement(vecInst, 7);
  992. Value *m20 = Builder.CreateExtractElement(vecInst, 8);
  993. Value *m21 = Builder.CreateExtractElement(vecInst, 9);
  994. Value *m22 = Builder.CreateExtractElement(vecInst, 10);
  995. Value *m23 = Builder.CreateExtractElement(vecInst, 11);
  996. Value *m30 = Builder.CreateExtractElement(vecInst, 12);
  997. Value *m31 = Builder.CreateExtractElement(vecInst, 13);
  998. Value *m32 = Builder.CreateExtractElement(vecInst, 14);
  999. Value *m33 = Builder.CreateExtractElement(vecInst, 15);
  1000. Result = Determinant4x4(m00, m01, m02, m03,
  1001. m10, m11, m12, m13,
  1002. m20, m21, m22, m23,
  1003. m30, m31, m32, m33,
  1004. Builder);
  1005. } else {
  1006. DXASSERT(row == 1, "invalid matrix type");
  1007. Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  1008. }
  1009. determinantInst->replaceAllUsesWith(Result);
  1010. AddToDeadInsts(determinantInst);
  1011. }
  1012. void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
  1013. Instruction *vecInst,
  1014. CallInst *matUseInst) {
  1015. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1016. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  1017. if (matUseInst->getArgOperand(i) == matInst) {
  1018. vecUseInst->setArgOperand(i, vecInst);
  1019. }
  1020. }
  1021. void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
  1022. Instruction *vecInst,
  1023. CallInst *castInst,
  1024. bool bRowToCol) {
  1025. unsigned col, row;
  1026. GetMatrixInfo(castInst->getType(), col, row);
  1027. DXASSERT(castInst->getType() == matInst->getType(), "type must match");
  1028. IRBuilder<> Builder(castInst);
  1029. // shuf to change major.
  1030. SmallVector<int, 16> castMask(col * row);
  1031. unsigned idx = 0;
  1032. if (bRowToCol) {
  1033. for (unsigned c = 0; c < col; c++)
  1034. for (unsigned r = 0; r < row; r++) {
  1035. unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  1036. castMask[idx++] = matIdx;
  1037. }
  1038. } else {
  1039. for (unsigned r = 0; r < row; r++)
  1040. for (unsigned c = 0; c < col; c++) {
  1041. unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
  1042. castMask[idx++] = matIdx;
  1043. }
  1044. }
  1045. Instruction *vecCast = cast<Instruction>(
  1046. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1047. // Replace vec cast function call with vecCast.
  1048. DXASSERT(matToVecMap.count(castInst), "must has vec version");
  1049. Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  1050. vecUseInst->replaceAllUsesWith(vecCast);
  1051. AddToDeadInsts(vecUseInst);
  1052. matToVecMap[castInst] = vecCast;
  1053. }
// Lowers a matrix-to-matrix cast. Equal element counts become a pure element
// type cast; a shrinking cast first shuffles out the kept top-left toRow x
// toCol elements, then converts the element type if it changed. Extending
// casts are rejected by assertion.
void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
                                            Instruction *vecInst,
                                            CallInst *castInst) {
  unsigned toCol, toRow;
  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  unsigned fromCol, fromRow;
  Type *FromEltTy = GetMatrixInfo(matInst->getType(), fromCol, fromRow);
  unsigned fromSize = fromCol * fromRow;
  unsigned toSize = toCol * toRow;
  DXASSERT(fromSize >= toSize, "cannot extend matrix");
  IRBuilder<> Builder(castInst);
  Instruction *vecCast = nullptr;
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  if (fromSize == toSize) {
    // Same shape: only the element type changes.
    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecInst,
                             Builder);
  } else {
    // Shrinking cast: shuffle out the kept elements first. Source index of
    // destination element (r, c) is its row-major index in the from matrix.
    std::vector<int> castMask(toCol * toRow);
    unsigned idx = 0;
    for (unsigned r = 0; r < toRow; r++)
      for (unsigned c = 0; c < toCol; c++) {
        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
        castMask[idx++] = matIdx;
      }
    Instruction *shuf = cast<Instruction>(
        Builder.CreateShuffleVector(vecInst, vecInst, castMask));
    // Then convert the element type only if needed.
    if (ToEltTy != FromEltTy)
      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
                               Builder);
    else
      vecCast = shuf;
  }
  // Replace the placeholder vector cast call with the real sequence.
  DXASSERT(matToVecMap.count(castInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
  1094. void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
  1095. Instruction *vecInst,
  1096. CallInst *castInst) {
  1097. unsigned col, row;
  1098. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  1099. unsigned fromSize = col * row;
  1100. IRBuilder<> Builder(castInst);
  1101. Instruction *sizeCast = nullptr;
  1102. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1103. Type *ToTy = castInst->getType();
  1104. if (ToTy->isVectorTy()) {
  1105. unsigned toSize = ToTy->getVectorNumElements();
  1106. if (fromSize != toSize) {
  1107. std::vector<int> castMask(fromSize);
  1108. for (unsigned c = 0; c < toSize; c++)
  1109. castMask[c] = c;
  1110. sizeCast = cast<Instruction>(
  1111. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1112. } else
  1113. sizeCast = vecInst;
  1114. } else {
  1115. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1116. sizeCast =
  1117. cast<Instruction>(Builder.CreateExtractElement(vecInst, (uint64_t)0));
  1118. }
  1119. Instruction *typeCast = sizeCast;
  1120. if (EltTy != ToTy->getScalarType()) {
  1121. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1122. }
  1123. // Replace cast function call with typeCast.
  1124. castInst->replaceAllUsesWith(typeCast);
  1125. AddToDeadInsts(castInst);
  1126. }
  1127. void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
  1128. Instruction *vecInst,
  1129. CallInst *castInst) {
  1130. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1131. if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
  1132. opcode == HLCastOpcode::RowMatrixToColMatrix) {
  1133. TranslateMatMajorCast(matInst, vecInst, castInst,
  1134. opcode == HLCastOpcode::RowMatrixToColMatrix);
  1135. } else {
  1136. bool ToMat = IsMatrixType(castInst->getType());
  1137. bool FromMat = IsMatrixType(matInst->getType());
  1138. if (ToMat && FromMat) {
  1139. TranslateMatMatCast(matInst, vecInst, castInst);
  1140. } else if (FromMat)
  1141. TranslateMatToOtherCast(matInst, vecInst, castInst);
  1142. else
  1143. DXASSERT(0, "Not translate as user of matInst");
  1144. }
  1145. }
  1146. void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
  1147. Instruction *vecInst,
  1148. CallInst *matUseInst) {
  1149. IRBuilder<> Builder(matUseInst);
  1150. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1151. switch (opcode) {
  1152. case IntrinsicOp::IOP_umul:
  1153. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
  1154. break;
  1155. case IntrinsicOp::IOP_mul:
  1156. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
  1157. break;
  1158. case IntrinsicOp::IOP_transpose:
  1159. TranslateMatTranspose(matInst, vecInst, matUseInst);
  1160. break;
  1161. case IntrinsicOp::IOP_determinant:
  1162. TranslateMatDeterminant(matInst, vecInst, matUseInst);
  1163. break;
  1164. default:
  1165. CallInst *vecUseInst = nullptr;
  1166. if (matToVecMap.count(matUseInst))
  1167. vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1168. for (unsigned i = 0; i < matInst->getNumArgOperands(); i++)
  1169. if (matUseInst->getArgOperand(i) == matInst) {
  1170. if (vecUseInst)
  1171. vecUseInst->setArgOperand(i, vecInst);
  1172. else
  1173. matUseInst->setArgOperand(i, vecInst);
  1174. }
  1175. break;
  1176. }
  1177. }
  1178. void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
  1179. CallInst *matSubInst) {
  1180. unsigned opcode = GetHLOpcode(matSubInst);
  1181. HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  1182. assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
  1183. "matrix don't use default subscript");
  1184. Type *matType = matInst->getType()->getPointerElementType();
  1185. unsigned col, row;
  1186. Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  1187. bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
  1188. (matOpcode == HLSubscriptOpcode::RowMatElement);
  1189. Value *mask =
  1190. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  1191. if (isElement) {
  1192. Type *resultType = matSubInst->getType()->getPointerElementType();
  1193. unsigned resultSize = 1;
  1194. if (resultType->isVectorTy())
  1195. resultSize = resultType->getVectorNumElements();
  1196. std::vector<int> shufMask(resultSize);
  1197. Constant *EltIdxs = cast<Constant>(mask);
  1198. for (unsigned i = 0; i < resultSize; i++) {
  1199. shufMask[i] =
  1200. cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
  1201. }
  1202. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1203. CallE = matSubInst->use_end();
  1204. CallUI != CallE;) {
  1205. Use &CallUse = *CallUI++;
  1206. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1207. IRBuilder<> Builder(CallUser);
  1208. Value *vecLd = Builder.CreateLoad(vecInst);
  1209. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1210. if (resultSize > 1) {
  1211. Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
  1212. ld->replaceAllUsesWith(shuf);
  1213. } else {
  1214. Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
  1215. ld->replaceAllUsesWith(elt);
  1216. }
  1217. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1218. Value *val = st->getValueOperand();
  1219. if (resultSize > 1) {
  1220. for (unsigned i = 0; i < shufMask.size(); i++) {
  1221. unsigned idx = shufMask[i];
  1222. Value *valElt = Builder.CreateExtractElement(val, i);
  1223. vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
  1224. }
  1225. Builder.CreateStore(vecLd, vecInst);
  1226. } else {
  1227. vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
  1228. Builder.CreateStore(vecLd, vecInst);
  1229. }
  1230. } else
  1231. DXASSERT(0, "matrix element should only used by load/store.");
  1232. AddToDeadInsts(CallUser);
  1233. }
  1234. } else {
  1235. // Subscript.
  1236. // Return a row.
  1237. // Use insertElement and extractElement.
  1238. ArrayType *AT = ArrayType::get(EltTy, col*row);
  1239. IRBuilder<> AllocaBuilder(
  1240. matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  1241. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  1242. Value *zero = AllocaBuilder.getInt32(0);
  1243. bool isDynamicIndexing = !isa<ConstantInt>(mask);
  1244. SmallVector<Value *, 4> idxList;
  1245. for (unsigned i = 0; i < col; i++) {
  1246. idxList.emplace_back(
  1247. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
  1248. }
  1249. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1250. CallE = matSubInst->use_end();
  1251. CallUI != CallE;) {
  1252. Use &CallUse = *CallUI++;
  1253. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1254. IRBuilder<> Builder(CallUser);
  1255. Value *vecLd = Builder.CreateLoad(vecInst);
  1256. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1257. Value *sub = UndefValue::get(ld->getType());
  1258. if (!isDynamicIndexing) {
  1259. for (unsigned i = 0; i < col; i++) {
  1260. Value *matIdx = idxList[i];
  1261. Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
  1262. sub = Builder.CreateInsertElement(sub, valElt, i);
  1263. }
  1264. } else {
  1265. // Copy vec to array.
  1266. for (unsigned int i = 0; i < row*col; i++) {
  1267. Value *Elt =
  1268. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1269. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1270. {zero, Builder.getInt32(i)});
  1271. Builder.CreateStore(Elt, Ptr);
  1272. }
  1273. for (unsigned i = 0; i < col; i++) {
  1274. Value *matIdx = idxList[i];
  1275. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1276. Value *valElt = Builder.CreateLoad(Ptr);
  1277. sub = Builder.CreateInsertElement(sub, valElt, i);
  1278. }
  1279. }
  1280. ld->replaceAllUsesWith(sub);
  1281. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1282. Value *val = st->getValueOperand();
  1283. if (!isDynamicIndexing) {
  1284. for (unsigned i = 0; i < col; i++) {
  1285. Value *matIdx = idxList[i];
  1286. Value *valElt = Builder.CreateExtractElement(val, i);
  1287. vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
  1288. }
  1289. } else {
  1290. // Copy vec to array.
  1291. for (unsigned int i = 0; i < row * col; i++) {
  1292. Value *Elt =
  1293. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1294. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1295. {zero, Builder.getInt32(i)});
  1296. Builder.CreateStore(Elt, Ptr);
  1297. }
  1298. // Update array.
  1299. for (unsigned i = 0; i < col; i++) {
  1300. Value *matIdx = idxList[i];
  1301. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1302. Value *valElt = Builder.CreateExtractElement(val, i);
  1303. Builder.CreateStore(valElt, Ptr);
  1304. }
  1305. // Copy array to vec.
  1306. for (unsigned int i = 0; i < row * col; i++) {
  1307. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1308. {zero, Builder.getInt32(i)});
  1309. Value *Elt = Builder.CreateLoad(Ptr);
  1310. vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
  1311. }
  1312. }
  1313. Builder.CreateStore(vecLd, vecInst);
  1314. } else if (GetElementPtrInst *GEP =
  1315. dyn_cast<GetElementPtrInst>(CallUser)) {
  1316. Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
  1317. Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
  1318. GEP->replaceAllUsesWith(NewGEP);
  1319. } else
  1320. DXASSERT(0, "matrix subscript should only used by load/store.");
  1321. AddToDeadInsts(CallUser);
  1322. }
  1323. }
  1324. // Check vec version.
  1325. DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  1326. // All the user should has been removed.
  1327. matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  1328. AddToDeadInsts(matSubInst);
  1329. }
  1330. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1331. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1332. CallInst *matLdStInst) {
  1333. // No dynamic indexing on matrix, flatten matrix to scalars.
  1334. // vecGlobals already in correct major.
  1335. Type *matType = matGlobal->getType()->getPointerElementType();
  1336. unsigned col, row;
  1337. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1338. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1339. IRBuilder<> Builder(matLdStInst);
  1340. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1341. switch (opcode) {
  1342. case HLMatLoadStoreOpcode::ColMatLoad:
  1343. case HLMatLoadStoreOpcode::RowMatLoad: {
  1344. Value *Result = UndefValue::get(vecType);
  1345. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1346. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1347. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1348. }
  1349. matLdStInst->replaceAllUsesWith(Result);
  1350. } break;
  1351. case HLMatLoadStoreOpcode::ColMatStore:
  1352. case HLMatLoadStoreOpcode::RowMatStore: {
  1353. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1354. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1355. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1356. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1357. }
  1358. } break;
  1359. }
  1360. }
  1361. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
  1362. GlobalVariable *scalarArrayGlobal,
  1363. CallInst *matLdStInst) {
  1364. // vecGlobals already in correct major.
  1365. const bool bColMajor = true;
  1366. HLMatLoadStoreOpcode opcode =
  1367. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1368. switch (opcode) {
  1369. case HLMatLoadStoreOpcode::ColMatLoad:
  1370. case HLMatLoadStoreOpcode::RowMatLoad: {
  1371. IRBuilder<> Builder(matLdStInst);
  1372. Type *matTy = matGlobal->getType()->getPointerElementType();
  1373. unsigned col, row;
  1374. Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1375. Value *zeroIdx = Builder.getInt32(0);
  1376. std::vector<Value *> matElts(col * row);
  1377. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1378. Value *GEP = Builder.CreateInBoundsGEP(
  1379. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1380. matElts[matIdx] = Builder.CreateLoad(GEP);
  1381. }
  1382. Value *newVec =
  1383. HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
  1384. matLdStInst->replaceAllUsesWith(newVec);
  1385. matLdStInst->eraseFromParent();
  1386. } break;
  1387. case HLMatLoadStoreOpcode::ColMatStore:
  1388. case HLMatLoadStoreOpcode::RowMatStore: {
  1389. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1390. IRBuilder<> Builder(matLdStInst);
  1391. Type *matTy = matGlobal->getType()->getPointerElementType();
  1392. unsigned col, row;
  1393. HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1394. Value *zeroIdx = Builder.getInt32(0);
  1395. std::vector<Value *> matElts(col * row);
  1396. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1397. Value *GEP = Builder.CreateInBoundsGEP(
  1398. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1399. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1400. Builder.CreateStore(Elt, GEP);
  1401. }
  1402. matLdStInst->eraseFromParent();
  1403. } break;
  1404. }
  1405. }
// Lower an HL subscript on a matrix global pointer by rewriting every use of
// the subscript into scalar GEPs on the lowered vector global (vecPtr).
// A vector-typed pointer cannot be materialized for a multi-element
// subscript, so loads/stores through it are scalarized element by element.
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
    CallInst *matSubInst, Value *vecPtr) {
  Value *basePtr =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // Collect the flat element indices this subscript selects.
  std::vector<Value *> idxList;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Just use index created in EmitHLSLMatrixSubscript.
    // A row subscript yields `col` element indices, one per column.
    for (unsigned c = 0; c < col; c++) {
      Value *matIdx =
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
      idxList.emplace_back(matIdx);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Just use index created in EmitHLSLMatrixElement. Element subscripts
    // always carry constant indices, so idx is a Constant aggregate here.
    Constant *EltIdxs = cast<Constant>(idx);
    for (unsigned i = 0; i < resultSize; i++) {
      Value *matIdx = EltIdxs->getAggregateElement(i);
      idxList.emplace_back(matIdx);
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate vector pointer
  // Replace all uses with scalar pointers.
  if (idxList.size() == 1) {
    // Single element: one scalar GEP replaces the subscript directly.
    Value *Ptr =
        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
    matSubInst->replaceAllUsesWith(Ptr);
  } else {
    // Split the use of CI with Ptrs.
    // Users are erased while iterating, so the iterator is advanced before
    // each user is processed.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // A GEP on the subscript selects a single element: fold the
        // subscript's index list plus the GEP index into one flat index.
        Value *IndexPtr =
            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                   {zeroIdx, IndexPtr});
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
            Value *subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          } else {
            // The only other expected user is an addrspace cast; retarget
            // it at the new scalar pointer.
            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
            Cast->setOperand(0, Ptr);
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Store of the whole subscripted vector: scatter per-element stores.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          userBuilder.CreateStore(Elt, Ptr);
        }
        stUser->eraseFromParent();
      } else {
        // Load of the whole subscripted vector: gather per-element loads
        // and rebuild the vector value.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          Value *Elt = userBuilder.CreateLoad(Ptr);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
  1506. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
  1507. CallInst *matLdStInst, Value *vecPtr) {
  1508. // Just translate into vector here.
  1509. // DynamicIndexingVectorToArray will change it to scalar array.
  1510. IRBuilder<> Builder(matLdStInst);
  1511. unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  1512. HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  1513. switch (matLdStOp) {
  1514. case HLMatLoadStoreOpcode::ColMatLoad:
  1515. case HLMatLoadStoreOpcode::RowMatLoad: {
  1516. // Load as vector.
  1517. Value *newLoad = Builder.CreateLoad(vecPtr);
  1518. matLdStInst->replaceAllUsesWith(newLoad);
  1519. matLdStInst->eraseFromParent();
  1520. } break;
  1521. case HLMatLoadStoreOpcode::ColMatStore:
  1522. case HLMatLoadStoreOpcode::RowMatStore: {
  1523. // Change value to vector array, then store.
  1524. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1525. Value *vecArrayGep = vecPtr;
  1526. Builder.CreateStore(Val, vecArrayGep);
  1527. matLdStInst->eraseFromParent();
  1528. } break;
  1529. default:
  1530. DXASSERT(0, "invalid operation");
  1531. break;
  1532. }
  1533. }
// Flatten values inside init list to scalar elements.
// Recursively walks `val` — which may be a pointer to an aggregate, a matrix
// value, a vector, or a scalar — and writes its scalar pieces into `elts`
// starting at `idx`; `idx` is advanced past everything consumed.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Instruction *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      if (matToVecMap.count(cast<Instruction>(val))) {
        // Already lowered: switch to the vector-array version.
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Loadable pointee: load the value and recurse on it.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointee: recurse into each member through a GEP.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    // Use the lowered vector form of the matrix value.
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all row major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Plain vector: copy out each lane.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      // Scalar leaf.
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
  1599. // Store flattened init list elements into matrix array.
  1600. static void GenerateMatArrayInit(ArrayRef<Value *> elts, Value *ptr,
  1601. unsigned &offset, IRBuilder<> &Builder) {
  1602. Type *Ty = ptr->getType()->getPointerElementType();
  1603. if (Ty->isVectorTy()) {
  1604. unsigned vecSize = Ty->getVectorNumElements();
  1605. Type *eltTy = Ty->getVectorElementType();
  1606. Value *result = UndefValue::get(Ty);
  1607. for (unsigned i = 0; i < vecSize; i++) {
  1608. Value *elt = elts[offset + i];
  1609. if (elt->getType() != eltTy) {
  1610. // FIXME: get signed/unsigned info.
  1611. elt = CreateTypeCast(HLCastOpcode::DefaultCast, eltTy, elt, Builder);
  1612. }
  1613. result = Builder.CreateInsertElement(result, elt, i);
  1614. }
  1615. // Update offset.
  1616. offset += vecSize;
  1617. Builder.CreateStore(result, ptr);
  1618. } else {
  1619. DXASSERT(Ty->isArrayTy(), "must be array type");
  1620. Type *i32Ty = Type::getInt32Ty(Ty->getContext());
  1621. Constant *zero = ConstantInt::get(i32Ty, 0);
  1622. unsigned arraySize = Ty->getArrayNumElements();
  1623. for (unsigned i = 0; i < arraySize; i++) {
  1624. Value *GEP =
  1625. Builder.CreateInBoundsGEP(ptr, {zero, ConstantInt::get(i32Ty, i)});
  1626. GenerateMatArrayInit(elts, GEP, offset, Builder);
  1627. }
  1628. }
  1629. }
// Lower an HLInit call that constructs a (non-array) matrix value: flatten
// the init-list arguments to scalars and build the replacement vector.
void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  // Array matrix init will be translated in TranslateMatArrayInitReplace.
  if (matInitInst->getType()->isVoidTy())
    return;
  IRBuilder<> Builder(matInitInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  unsigned idx = 0;
  std::vector<Value *> elts(vecSize);
  // Skip opcode arg.
  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
    Value *val = matInitInst->getArgOperand(i);
    // Flatten each init argument into scalar elts (advances idx).
    IterateInitList(elts, idx, val, matToVecMap, Builder);
  }
  Value *newInit = UndefValue::get(vecTy);
  // InitList is row major, the result is row major too.
  // InsertElementInst::Create + Builder.Insert is used instead of
  // Builder.CreateInsertElement so the result is always an Instruction
  // (CreateInsertElement would constant-fold when all operands are
  // constants, breaking the cast<Instruction> below).
  for (unsigned i = 0; i < col * row; i++) {
    Constant *vecIdx = Builder.getInt32(i);
    newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
    Builder.Insert(cast<Instruction>(newInit));
  }
  // Replace matInit function call with matInitInst.
  DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  vecUseInst->replaceAllUsesWith(newInit);
  AddToDeadInsts(vecUseInst);
  matToVecMap[matInitInst] = newInit;
}
  1660. void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  1661. IRBuilder<> Builder(matSelectInst);
  1662. unsigned col, row;
  1663. Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  1664. Type *vecTy = VectorType::get(EltTy, col * row);
  1665. unsigned vecSize = vecTy->getVectorNumElements();
  1666. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  1667. Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  1668. Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  1669. Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  1670. bool isVecCond = Cond->getType()->isVectorTy();
  1671. if (isVecCond) {
  1672. Instruction *MatCond = cast<Instruction>(
  1673. matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
  1674. DXASSERT_NOMSG(matToVecMap.count(MatCond));
  1675. Cond = matToVecMap[MatCond];
  1676. }
  1677. DXASSERT_NOMSG(matToVecMap.count(LHS));
  1678. Value *VLHS = matToVecMap[LHS];
  1679. DXASSERT_NOMSG(matToVecMap.count(RHS));
  1680. Value *VRHS = matToVecMap[RHS];
  1681. Value *VecSelect = UndefValue::get(vecTy);
  1682. for (unsigned i = 0; i < vecSize; i++) {
  1683. llvm::Value *EltCond = Cond;
  1684. if (isVecCond)
  1685. EltCond = Builder.CreateExtractElement(Cond, i);
  1686. llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
  1687. llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
  1688. llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
  1689. VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  1690. }
  1691. AddToDeadInsts(vecUseInst);
  1692. vecUseInst->replaceAllUsesWith(VecSelect);
  1693. matToVecMap[matSelectInst] = VecSelect;
  1694. }
  1695. void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
  1696. Instruction *vecInst,
  1697. GetElementPtrInst *matGEP) {
  1698. SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  1699. IRBuilder<> GEPBuilder(matGEP);
  1700. Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecInst, idxList);
  1701. // Only used by mat subscript and mat ld/st.
  1702. for (Value::user_iterator user = matGEP->user_begin();
  1703. user != matGEP->user_end();) {
  1704. Instruction *useInst = cast<Instruction>(*(user++));
  1705. IRBuilder<> Builder(useInst);
  1706. // Skip return here.
  1707. if (isa<ReturnInst>(useInst))
  1708. continue;
  1709. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1710. // Function call.
  1711. hlsl::HLOpcodeGroup group =
  1712. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1713. switch (group) {
  1714. case HLOpcodeGroup::HLMatLoadStore: {
  1715. unsigned opcode = GetHLOpcode(useCall);
  1716. HLMatLoadStoreOpcode matOpcode =
  1717. static_cast<HLMatLoadStoreOpcode>(opcode);
  1718. switch (matOpcode) {
  1719. case HLMatLoadStoreOpcode::ColMatLoad:
  1720. case HLMatLoadStoreOpcode::RowMatLoad: {
  1721. // Skip the vector version.
  1722. if (useCall->getType()->isVectorTy())
  1723. continue;
  1724. Value *newLd = Builder.CreateLoad(newGEP);
  1725. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1726. Value *oldLd = matToVecMap[useCall];
  1727. // Delete the oldLd.
  1728. AddToDeadInsts(cast<Instruction>(oldLd));
  1729. oldLd->replaceAllUsesWith(newLd);
  1730. matToVecMap[useCall] = newLd;
  1731. } break;
  1732. case HLMatLoadStoreOpcode::ColMatStore:
  1733. case HLMatLoadStoreOpcode::RowMatStore: {
  1734. Value *vecPtr = newGEP;
  1735. Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1736. // Skip the vector version.
  1737. if (matVal->getType()->isVectorTy()) {
  1738. AddToDeadInsts(useCall);
  1739. continue;
  1740. }
  1741. Instruction *matInst = cast<Instruction>(matVal);
  1742. DXASSERT(matToVecMap.count(matInst), "must has vec version");
  1743. Value *vecVal = matToVecMap[matInst];
  1744. Builder.CreateStore(vecVal, vecPtr);
  1745. } break;
  1746. }
  1747. } break;
  1748. case HLOpcodeGroup::HLSubscript: {
  1749. TranslateMatSubscript(matGEP, newGEP, useCall);
  1750. } break;
  1751. default:
  1752. DXASSERT(0, "invalid operation");
  1753. break;
  1754. }
  1755. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1756. // Just replace the src with vec version.
  1757. useInst->setOperand(0, newGEP);
  1758. } else {
  1759. // Must be GEP.
  1760. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1761. TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
  1762. }
  1763. }
  1764. AddToDeadInsts(matGEP);
  1765. }
  1766. void HLMatrixLowerPass::replaceMatWithVec(Instruction *matInst,
  1767. Instruction *vecInst) {
  1768. for (Value::user_iterator user = matInst->user_begin();
  1769. user != matInst->user_end();) {
  1770. Instruction *useInst = cast<Instruction>(*(user++));
  1771. // Skip return here.
  1772. if (isa<ReturnInst>(useInst))
  1773. continue;
  1774. // User must be function call.
  1775. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1776. hlsl::HLOpcodeGroup group =
  1777. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1778. switch (group) {
  1779. case HLOpcodeGroup::HLIntrinsic: {
  1780. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1781. } break;
  1782. case HLOpcodeGroup::HLSelect: {
  1783. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1784. } break;
  1785. case HLOpcodeGroup::HLBinOp: {
  1786. TrivialMatBinOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1787. } break;
  1788. case HLOpcodeGroup::HLUnOp: {
  1789. TrivialMatUnOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1790. } break;
  1791. case HLOpcodeGroup::HLCast: {
  1792. TranslateMatCast(cast<CallInst>(matInst), vecInst, useCall);
  1793. } break;
  1794. case HLOpcodeGroup::HLMatLoadStore: {
  1795. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1796. Value *vecUser = matToVecMap[useCall];
  1797. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  1798. // Load Already translated in lowerToVec.
  1799. // Store val operand will be set by the val use.
  1800. // Do nothing here.
  1801. } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
  1802. stInst->setOperand(0, vecInst);
  1803. else
  1804. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1805. } break;
  1806. case HLOpcodeGroup::HLSubscript: {
  1807. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst))
  1808. TranslateMatSubscript(AI, vecInst, useCall);
  1809. else
  1810. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1811. } break;
  1812. case HLOpcodeGroup::HLInit: {
  1813. DXASSERT(!isa<AllocaInst>(matInst), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
  1814. TranslateMatInit(useCall);
  1815. } break;
  1816. }
  1817. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1818. // Just replace the src with vec version.
  1819. useInst->setOperand(0, vecInst);
  1820. } else {
  1821. // Must be GEP on mat array alloca.
  1822. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1823. AllocaInst *AI = cast<AllocaInst>(matInst);
  1824. TranslateMatArrayGEP(AI, vecInst, GEP);
  1825. }
  1826. }
  1827. }
  1828. void HLMatrixLowerPass::finalMatTranslation(Instruction *matInst) {
  1829. // Translate matInit.
  1830. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  1831. hlsl::HLOpcodeGroup group =
  1832. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  1833. switch (group) {
  1834. case HLOpcodeGroup::HLInit: {
  1835. TranslateMatInit(CI);
  1836. } break;
  1837. case HLOpcodeGroup::HLSelect: {
  1838. TranslateMatSelect(CI);
  1839. } break;
  1840. default:
  1841. // Skip group already translated.
  1842. break;
  1843. }
  1844. }
  1845. }
  1846. void HLMatrixLowerPass::DeleteDeadInsts() {
  1847. // Delete the matrix version insts.
  1848. for (Instruction *deadInst : m_deadInsts) {
  1849. // Replace with undef and remove it.
  1850. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  1851. deadInst->eraseFromParent();
  1852. }
  1853. m_deadInsts.clear();
  1854. m_inDeadInstsSet.clear();
  1855. }
  1856. static bool OnlyUsedByMatrixLdSt(Value *V) {
  1857. bool onlyLdSt = true;
  1858. for (User *user : V->users()) {
  1859. if (isa<Constant>(user) && user->use_empty())
  1860. continue;
  1861. CallInst *CI = cast<CallInst>(user);
  1862. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  1863. HLOpcodeGroup::HLMatLoadStore)
  1864. continue;
  1865. onlyLdSt = false;
  1866. break;
  1867. }
  1868. return onlyLdSt;
  1869. }
  1870. static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
  1871. if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
  1872. std::vector<Constant *> Elts;
  1873. Type *EltResultTy = AT->getElementType();
  1874. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  1875. Constant *Elt =
  1876. LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
  1877. Elts.emplace_back(Elt);
  1878. }
  1879. return ConstantArray::get(AT, Elts);
  1880. } else {
  1881. // Cast float[row][col] -> float< row * col>.
  1882. // Get float[row][col] from the struct.
  1883. Constant *rows = MA->getAggregateElement((unsigned)0);
  1884. ArrayType *RowAT = cast<ArrayType>(rows->getType());
  1885. std::vector<Constant *> Elts;
  1886. for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
  1887. Constant *row = rows->getAggregateElement(r);
  1888. VectorType *VT = cast<VectorType>(row->getType());
  1889. for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
  1890. Elts.emplace_back(row->getAggregateElement(c));
  1891. }
  1892. }
  1893. return ConstantVector::get(Elts);
  1894. }
  1895. }
  1896. void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  1897. // Lower to array of vector array like float[row * col].
  1898. // It's follow the major of decl.
  1899. // DynamicIndexingVectorToArray will change it to scalar array.
  1900. Type *Ty = GV->getType()->getPointerElementType();
  1901. std::vector<unsigned> arraySizeList;
  1902. while (Ty->isArrayTy()) {
  1903. arraySizeList.push_back(Ty->getArrayNumElements());
  1904. Ty = Ty->getArrayElementType();
  1905. }
  1906. unsigned row, col;
  1907. Type *EltTy = GetMatrixInfo(Ty, col, row);
  1908. Ty = VectorType::get(EltTy, col * row);
  1909. for (auto arraySize = arraySizeList.rbegin();
  1910. arraySize != arraySizeList.rend(); arraySize++)
  1911. Ty = ArrayType::get(Ty, *arraySize);
  1912. Type *VecArrayTy = Ty;
  1913. Constant *OldInitVal = GV->getInitializer();
  1914. Constant *InitVal =
  1915. isa<UndefValue>(OldInitVal)
  1916. ? UndefValue::get(VecArrayTy)
  1917. : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
  1918. bool isConst = GV->isConstant();
  1919. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  1920. unsigned AddressSpace = GV->getType()->getAddressSpace();
  1921. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  1922. Module *M = GV->getParent();
  1923. GlobalVariable *VecGV =
  1924. new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
  1925. /*InitVal*/ InitVal, GV->getName() + ".v",
  1926. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  1927. // Add debug info.
  1928. if (m_HasDbgInfo) {
  1929. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  1930. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  1931. }
  1932. DenseMap<Instruction *, Value *> matToVecMap;
  1933. for (User *U : GV->users()) {
  1934. Value *VecGEP = nullptr;
  1935. // Must be GEP or GEPOperator.
  1936. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  1937. IRBuilder<> Builder(GEP);
  1938. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  1939. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1940. AddToDeadInsts(GEP);
  1941. } else {
  1942. GEPOperator *GEPOP = cast<GEPOperator>(U);
  1943. IRBuilder<> Builder(GV->getContext());
  1944. SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
  1945. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1946. }
  1947. for (auto user = U->user_begin(); user != U->user_end();) {
  1948. CallInst *CI = cast<CallInst>(*(user++));
  1949. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1950. if (group == HLOpcodeGroup::HLMatLoadStore) {
  1951. TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
  1952. } else if (group == HLOpcodeGroup::HLSubscript) {
  1953. TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
  1954. } else {
  1955. DXASSERT(0, "invalid operation");
  1956. }
  1957. }
  1958. }
  1959. DeleteDeadInsts();
  1960. GV->removeDeadConstantUsers();
  1961. GV->eraseFromParent();
  1962. }
  1963. static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
  1964. unsigned row, col;
  1965. Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
  1966. if (isa<UndefValue>(M)) {
  1967. Constant *Elt = UndefValue::get(EltTy);
  1968. for (unsigned i=0;i<col*row;i++)
  1969. Elts.emplace_back(Elt);
  1970. } else {
  1971. M = M->getAggregateElement((unsigned)0);
  1972. // Initializer is already in correct major.
  1973. // Just read it here.
  1974. // The type is vector<element, col>[row].
  1975. for (unsigned r = 0; r < row; r++) {
  1976. Constant *C = M->getAggregateElement(r);
  1977. for (unsigned c = 0; c < col; c++) {
  1978. Elts.emplace_back(C->getAggregateElement(c));
  1979. }
  1980. }
  1981. }
  1982. }
// Lower one global matrix variable to scalar form.
// Matrix arrays take a dedicated path; non-matrix globals are ignored.
void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
    runOnGlobalMatrixArray(GV);
    return;
  }

  Type *Ty = GV->getType()->getPointerElementType();
  if (!HLMatrixLower::IsMatrixType(Ty))
    return;

  // If every user is a matrix load/store intrinsic, the matrix can be
  // split into one scalar global per element; otherwise (e.g. subscript
  // users) it is lowered to a single array-of-scalars global instead.
  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  bool isConst = GV->isConstant();

  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  Module *M = GV->getParent();
  const DataLayout &DL = M->getDataLayout();

  std::vector<Constant *> Elts;
  // Lower to vector or array for scalar matrix.
  // Make it col major so don't need shuffle when load/store.
  FlattenMatConst(GV->getInitializer(), Elts);

  if (onlyLdSt) {
    Type *EltTy = vecTy->getVectorElementType();
    unsigned vecSize = vecTy->getVectorNumElements();
    std::vector<Value *> vecGlobals(vecSize);
    // Element globals inherit the original global's storage attributes.
    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
    unsigned AddressSpace = GV->getType()->getAddressSpace();
    GlobalValue::LinkageTypes linkage = GV->getLinkage();
    unsigned debugOffset = 0;
    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
    unsigned align = DL.getPrefTypeAlignment(EltTy);
    // Create one scalar global per matrix element, named "<GV>.<i>".
    for (int i = 0, e = vecSize; i != e; ++i) {
      Constant *InitVal = Elts[i];
      GlobalVariable *EltGV = new llvm::GlobalVariable(
          *M, EltTy, /*IsConstant*/ isConst, linkage,
          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
          /*InsertBefore*/ nullptr,
          TLMode, AddressSpace);
      // Add debug info.
      if (m_HasDbgInfo) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset,
            EltGV->getName().ltrim(GV->getName()));
        debugOffset += size;
      }
      vecGlobals[i] = EltGV;
    }
    for (User *user : GV->users()) {
      // Skip dead constant users (constant expressions with no uses).
      if (isa<Constant>(user) && user->use_empty())
        continue;
      // OnlyUsedByMatrixLdSt guarantees the remaining users are
      // matrix load/store call instructions.
      CallInst *CI = cast<CallInst>(user);
      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
      AddToDeadInsts(CI);
    }
    DeleteDeadInsts();
    GV->eraseFromParent();
  } else {
    // lower to array of scalar here.
    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(),
                                   vecTy->getVectorNumElements());
    Constant *InitVal = ConstantArray::get(AT, Elts);
    GlobalVariable *arrayMat = new llvm::GlobalVariable(
        *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
        /*InitVal*/ InitVal, GV->getName());
    // Add debug info.
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
                                              arrayMat);
    }
    // Advance the iterator before translating: the translate calls
    // rewrite uses, which mutates GV's use list mid-iteration.
    for (auto U = GV->user_begin(); U != GV->user_end();) {
      Value *user = *(U++);
      CallInst *CI = cast<CallInst>(user);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
      } else {
        DXASSERT(group == HLOpcodeGroup::HLSubscript,
                 "Must be subscript operation");
        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
      }
    }
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
  2066. void HLMatrixLowerPass::runOnFunction(Function &F) {
  2067. // Create vector version of matrix instructions first.
  2068. // The matrix operands will be undefval for these instructions.
  2069. for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
  2070. BasicBlock *BB = BBI;
  2071. for (Instruction &I : BB->getInstList()) {
  2072. if (IsMatrixType(I.getType())) {
  2073. lowerToVec(&I);
  2074. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
  2075. Type *Ty = AI->getAllocatedType();
  2076. if (HLMatrixLower::IsMatrixType(Ty)) {
  2077. lowerToVec(&I);
  2078. } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
  2079. lowerToVec(&I);
  2080. }
  2081. } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
  2082. HLOpcodeGroup group =
  2083. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  2084. if (group == HLOpcodeGroup::HLMatLoadStore) {
  2085. HLMatLoadStoreOpcode opcode =
  2086. static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  2087. DXASSERT_LOCALVAR(opcode,
  2088. opcode == HLMatLoadStoreOpcode::ColMatStore ||
  2089. opcode == HLMatLoadStoreOpcode::RowMatStore,
  2090. "Must MatStore here, load will go IsMatrixType path");
  2091. // Lower it here to make sure it is ready before replace.
  2092. lowerToVec(&I);
  2093. }
  2094. }
  2095. }
  2096. }
  2097. // Update the use of matrix inst with the vector version.
  2098. for (auto matToVecIter = matToVecMap.begin();
  2099. matToVecIter != matToVecMap.end();) {
  2100. auto matToVec = matToVecIter++;
  2101. replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  2102. }
  2103. // Translate mat inst which require all operands ready.
  2104. for (auto matToVecIter = matToVecMap.begin();
  2105. matToVecIter != matToVecMap.end();) {
  2106. auto matToVec = matToVecIter++;
  2107. finalMatTranslation(matToVec->first);
  2108. }
  2109. // Delete the matrix version insts.
  2110. for (auto matToVecIter = matToVecMap.begin();
  2111. matToVecIter != matToVecMap.end();) {
  2112. auto matToVec = matToVecIter++;
  2113. // Add to m_deadInsts.
  2114. Instruction *matInst = matToVec->first;
  2115. AddToDeadInsts(matInst);
  2116. }
  2117. DeleteDeadInsts();
  2118. matToVecMap.clear();
  2119. }