HLMatrixLowerPass.cpp 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HlslIntrinsicOp.h"
  16. #include "dxc/Support/Global.h"
  17. #include "dxc/HLSL/DxilOperations.h"
  18. #include "llvm/IR/IRBuilder.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/IR/DebugInfo.h"
  21. #include "llvm/IR/IntrinsicInst.h"
  22. #include "llvm/Transforms/Utils/Local.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Support/raw_ostream.h"
  25. #include <unordered_set>
  26. #include <vector>
  27. using namespace llvm;
  28. using namespace hlsl;
  29. using namespace hlsl::HLMatrixLower;
  30. namespace hlsl {
  31. namespace HLMatrixLower {
  32. bool IsMatrixType(Type *Ty) {
  33. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  34. Type *EltTy = ST->getElementType(0);
  35. if (!ST->getName().startswith("class.matrix"))
  36. return false;
  37. bool isVecArray = EltTy->isArrayTy() &&
  38. EltTy->getArrayElementType()->isVectorTy();
  39. return isVecArray && EltTy->getArrayNumElements() <= 4;
  40. }
  41. return false;
  42. }
  43. // Translate matrix type to vector type.
  44. Type *LowerMatrixType(Type *Ty) {
  45. // Only translate matrix type and function type which use matrix type.
  46. // Not translate struct has matrix or matrix pointer.
  47. // Struct should be flattened before.
  48. // Pointer could cover by matldst which use vector as value type.
  49. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  50. Type *RetTy = LowerMatrixType(FT->getReturnType());
  51. SmallVector<Type *, 4> params;
  52. for (Type *param : FT->params()) {
  53. params.emplace_back(LowerMatrixType(param));
  54. }
  55. return FunctionType::get(RetTy, params, false);
  56. } else if (IsMatrixType(Ty)) {
  57. unsigned row, col;
  58. Type *EltTy = GetMatrixInfo(Ty, col, row);
  59. return VectorType::get(EltTy, row * col);
  60. } else {
  61. return Ty;
  62. }
  63. }
  64. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  65. DXASSERT(IsMatrixType(Ty), "not matrix type");
  66. StructType *ST = cast<StructType>(Ty);
  67. Type *EltTy = ST->getElementType(0);
  68. Type *RowTy = EltTy->getArrayElementType();
  69. row = EltTy->getArrayNumElements();
  70. col = RowTy->getVectorNumElements();
  71. return RowTy->getVectorElementType();
  72. }
  73. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  74. if (!Ty->isPointerTy())
  75. return false;
  76. Ty = Ty->getPointerElementType();
  77. if (!Ty->isArrayTy())
  78. return false;
  79. while (Ty->isArrayTy())
  80. Ty = Ty->getArrayElementType();
  81. return IsMatrixType(Ty);
  82. }
  83. Type *LowerMatrixArrayPointer(Type *Ty) {
  84. Ty = Ty->getPointerElementType();
  85. std::vector<unsigned> arraySizeList;
  86. while (Ty->isArrayTy()) {
  87. arraySizeList.push_back(Ty->getArrayNumElements());
  88. Ty = Ty->getArrayElementType();
  89. }
  90. Ty = LowerMatrixType(Ty);
  91. for (auto arraySize = arraySizeList.rbegin();
  92. arraySize != arraySizeList.rend(); arraySize++)
  93. Ty = ArrayType::get(Ty, *arraySize);
  94. return PointerType::get(Ty, 0);
  95. }
  96. Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
  97. IRBuilder<> &Builder) {
  98. Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
  99. for (unsigned i = 0; i < size; i++)
  100. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  101. return Vec;
  102. }
  103. Value *LowerGEPOnMatIndexListToIndex(
  104. llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
  105. IRBuilder<> Builder(GEP);
  106. Value *zero = Builder.getInt32(0);
  107. DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
  108. Value *baseIdx = (GEP->idx_begin())->get();
  109. DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
  110. Value *Idx = (GEP->idx_begin() + 1)->get();
  111. if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
  112. return IdxList[immIdx->getSExtValue()];
  113. } else {
  114. IRBuilder<> AllocaBuilder(
  115. GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  116. unsigned size = IdxList.size();
  117. // Store idxList to temp array.
  118. ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
  119. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  120. for (unsigned i = 0; i < size; i++) {
  121. Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
  122. Builder.CreateStore(IdxList[i], EltPtr);
  123. }
  124. // Load the idx.
  125. Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
  126. return Builder.CreateLoad(GEPOffset);
  127. }
  128. }
  129. unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
  130. return c * row + r;
  131. }
  132. unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
  133. return r * col + c;
  134. }
  135. } // namespace HLMatrixLower
  136. } // namespace hlsl
  137. namespace {
  138. class HLMatrixLowerPass : public ModulePass {
  139. public:
  140. static char ID; // Pass identification, replacement for typeid
  141. explicit HLMatrixLowerPass() : ModulePass(ID) {}
  142. const char *getPassName() const override { return "HL matrix lower"; }
  143. bool runOnModule(Module &M) override {
  144. m_pModule = &M;
  145. m_pHLModule = &m_pModule->GetOrCreateHLModule();
  146. // Load up debug information, to cross-reference values and the instructions
  147. // used to load them.
  148. m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  149. for (Function &F : M.functions()) {
  150. if (F.isDeclaration())
  151. continue;
  152. runOnFunction(F);
  153. }
  154. std::vector<GlobalVariable*> staticGVs;
  155. for (GlobalVariable &GV : M.globals()) {
  156. if (HLModule::IsStaticGlobal(&GV) ||
  157. HLModule::IsSharedMemoryGlobal(&GV)) {
  158. staticGVs.emplace_back(&GV);
  159. }
  160. }
  161. for (GlobalVariable *GV : staticGVs)
  162. runOnGlobal(GV);
  163. return true;
  164. }
  165. private:
  166. Module *m_pModule;
  167. HLModule *m_pHLModule;
  168. bool m_HasDbgInfo;
  169. std::vector<Instruction *> m_deadInsts;
  170. // For instruction like matrix array init.
  171. // May use more than 1 matrix alloca inst.
  172. // This set is here to avoid put it into deadInsts more than once.
  173. std::unordered_set<Instruction *> m_inDeadInstsSet;
  174. // For most matrix insturction users, it will only have one matrix use.
  175. // Use vector so save deadInsts because vector is cheap.
  176. void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  177. // In case instruction has more than one matrix use.
  178. // Use AddToDeadInstsWithDups to make sure it's not add to deadInsts more than once.
  179. void AddToDeadInstsWithDups(Instruction *I) {
  180. if (m_inDeadInstsSet.count(I) == 0) {
  181. // Only add to deadInsts when it's not inside m_inDeadInstsSet.
  182. m_inDeadInstsSet.insert(I);
  183. AddToDeadInsts(I);
  184. }
  185. }
  186. void runOnFunction(Function &F);
  187. void runOnGlobal(GlobalVariable *GV);
  188. void runOnGlobalMatrixArray(GlobalVariable *GV);
  189. Instruction *MatCastToVec(CallInst *CI);
  190. Instruction *MatLdStToVec(CallInst *CI);
  191. Instruction *MatSubscriptToVec(CallInst *CI);
  192. Instruction *MatIntrinsicToVec(CallInst *CI);
  193. Instruction *TrivialMatUnOpToVec(CallInst *CI);
  194. // Replace matInst with vecInst on matUseInst.
  195. void TrivialMatUnOpReplace(CallInst *matInst, Instruction *vecInst,
  196. CallInst *matUseInst);
  197. Instruction *TrivialMatBinOpToVec(CallInst *CI);
  198. // Replace matInst with vecInst on matUseInst.
  199. void TrivialMatBinOpReplace(CallInst *matInst, Instruction *vecInst,
  200. CallInst *matUseInst);
  201. // Replace matInst with vecInst on mulInst.
  202. void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
  203. CallInst *mulInst, bool isSigned);
  204. void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
  205. CallInst *mulInst, bool isSigned);
  206. void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
  207. CallInst *mulInst, bool isSigned);
  208. void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
  209. bool isSigned);
  210. // Replace matInst with vecInst on transposeInst.
  211. void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
  212. CallInst *transposeInst);
  213. void TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  214. CallInst *determinantInst);
  215. void MatIntrinsicReplace(CallInst *matInst, Instruction *vecInst,
  216. CallInst *matUseInst);
  217. // Replace matInst with vecInst on castInst.
  218. void TranslateMatMatCast(CallInst *matInst, Instruction *vecInst,
  219. CallInst *castInst);
  220. void TranslateMatToOtherCast(CallInst *matInst, Instruction *vecInst,
  221. CallInst *castInst);
  222. void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
  223. CallInst *castInst);
  224. void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
  225. CallInst *castInst, bool rowToCol);
  226. // Replace matInst with vecInst in matSubscript
  227. void TranslateMatSubscript(Value *matInst, Value *vecInst,
  228. CallInst *matSubInst);
  229. // Replace matInst with vecInst
  230. void TranslateMatInit(CallInst *matInitInst);
  231. // Replace matInst with vecInst.
  232. void TranslateMatSelect(CallInst *matSelectInst);
  233. // Replace matInst with vecInst on matInitInst.
  234. void TranslateMatArrayGEP(Value *matInst, Instruction *vecInst,
  235. GetElementPtrInst *matGEP);
  236. void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
  237. CallInst *matLdStInst);
  238. void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
  239. CallInst *matLdStInst);
  240. void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  241. void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  242. // Replace matInst with vecInst on matUseInst.
  243. void TrivialMatReplace(CallInst *matInst, Instruction *vecInst,
  244. CallInst *matUseInst);
  245. // Lower a matrix type instruction to a vector type instruction.
  246. void lowerToVec(Instruction *matInst);
  247. // Lower users of a matrix type instruction.
  248. void replaceMatWithVec(Instruction *matInst, Instruction *vecInst);
  249. // Translate mat inst which need all operands ready.
  250. void finalMatTranslation(Instruction *matInst);
  251. // Delete dead insts in m_deadInsts.
  252. void DeleteDeadInsts();
  253. // Map from matrix inst to its vector version.
  254. DenseMap<Instruction *, Value *> matToVecMap;
  255. };
  256. }
  257. char HLMatrixLowerPass::ID = 0;
  258. ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
  259. INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  260. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  261. IRBuilder<> Builder) {
  262. // Cast to bool.
  263. if (toTy->getScalarType()->isIntegerTy() &&
  264. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  265. Type *fromTy = src->getType();
  266. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  267. Constant *zero;
  268. if (isFloat)
  269. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  270. else
  271. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  272. if (toTy->getScalarType() != toTy) {
  273. // Create constant vector.
  274. unsigned size = toTy->getVectorNumElements();
  275. std::vector<Constant *> zeros(size, zero);
  276. zero = llvm::ConstantVector::get(zeros);
  277. }
  278. if (isFloat)
  279. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  280. else
  281. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  282. }
  283. Type *eltToTy = toTy->getScalarType();
  284. Type *eltFromTy = src->getType()->getScalarType();
  285. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  286. castOp == HLCastOpcode::UnsignedUnsignedCast;
  287. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  288. castOp == HLCastOpcode::UnsignedUnsignedCast;
  289. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  290. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  291. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  292. }
  293. Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  294. IRBuilder<> Builder(CI);
  295. Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  296. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
  297. bool ToMat = IsMatrixType(CI->getType());
  298. bool FromMat = IsMatrixType(op->getType());
  299. if (ToMat && !FromMat) {
  300. // Translate OtherToMat here.
  301. // Rest will translated when replace.
  302. unsigned col, row;
  303. Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
  304. unsigned toSize = col * row;
  305. Instruction *sizeCast = nullptr;
  306. Type *FromTy = op->getType();
  307. Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
  308. if (FromTy->isVectorTy()) {
  309. std::vector<Constant *> MaskVec(toSize);
  310. for (size_t i = 0; i != toSize; ++i)
  311. MaskVec[i] = ConstantInt::get(I32Ty, i);
  312. Value *castMask = ConstantVector::get(MaskVec);
  313. sizeCast = new ShuffleVectorInst(op, op, castMask);
  314. Builder.Insert(sizeCast);
  315. } else {
  316. op = Builder.CreateInsertElement(
  317. UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
  318. Constant *zero = ConstantInt::get(I32Ty, 0);
  319. std::vector<Constant *> MaskVec(toSize, zero);
  320. Value *castMask = ConstantVector::get(MaskVec);
  321. sizeCast = new ShuffleVectorInst(op, op, castMask);
  322. Builder.Insert(sizeCast);
  323. }
  324. Instruction *typeCast = sizeCast;
  325. if (EltTy != FromTy->getScalarType()) {
  326. typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
  327. sizeCast, Builder);
  328. }
  329. return typeCast;
  330. } else if (FromMat && ToMat) {
  331. if (isa<Argument>(op)) {
  332. // Cast From mat to mat for arugment.
  333. IRBuilder<> Builder(CI);
  334. // Here only lower the return type to vector.
  335. Type *RetTy = LowerMatrixType(CI->getType());
  336. SmallVector<Type *, 4> params;
  337. for (Value *operand : CI->arg_operands()) {
  338. params.emplace_back(operand->getType());
  339. }
  340. Type *FT = FunctionType::get(RetTy, params, false);
  341. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  342. unsigned opcode = GetHLOpcode(CI);
  343. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
  344. group, opcode);
  345. SmallVector<Value *, 4> argList;
  346. for (Value *arg : CI->arg_operands()) {
  347. argList.emplace_back(arg);
  348. }
  349. return Builder.CreateCall(vecF, argList);
  350. }
  351. }
  352. return MatIntrinsicToVec(CI);
  353. }
  354. Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  355. IRBuilder<> Builder(CI);
  356. unsigned opcode = GetHLOpcode(CI);
  357. HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  358. Instruction *result = nullptr;
  359. switch (matOpcode) {
  360. case HLMatLoadStoreOpcode::ColMatLoad:
  361. case HLMatLoadStoreOpcode::RowMatLoad: {
  362. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
  363. if (isa<AllocaInst>(matPtr)) {
  364. Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
  365. result = Builder.CreateLoad(vecPtr);
  366. } else
  367. result = MatIntrinsicToVec(CI);
  368. } break;
  369. case HLMatLoadStoreOpcode::ColMatStore:
  370. case HLMatLoadStoreOpcode::RowMatStore: {
  371. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
  372. if (isa<AllocaInst>(matPtr)) {
  373. Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
  374. Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  375. Value *vecVal =
  376. UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
  377. result = Builder.CreateStore(vecVal, vecPtr);
  378. } else
  379. result = MatIntrinsicToVec(CI);
  380. } break;
  381. }
  382. return result;
  383. }
  384. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  385. IRBuilder<> Builder(CI);
  386. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  387. if (isa<AllocaInst>(matPtr)) {
  388. // Here just create a new matSub call which use vec ptr.
  389. // Later in TranslateMatSubscript will do the real translation.
  390. std::vector<Value *> args(CI->getNumArgOperands());
  391. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  392. args[i] = CI->getArgOperand(i);
  393. }
  394. // Change mat ptr into vec ptr.
  395. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  396. matToVecMap[cast<Instruction>(matPtr)];
  397. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  398. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  399. paramTyList[i] = args[i]->getType();
  400. }
  401. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  402. unsigned opcode = GetHLOpcode(CI);
  403. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  404. return Builder.CreateCall(opFunc, args);
  405. } else
  406. return MatIntrinsicToVec(CI);
  407. }
  408. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  409. IRBuilder<> Builder(CI);
  410. unsigned opcode = GetHLOpcode(CI);
  411. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  412. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  413. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  414. SmallVector<Value *, 4> argList;
  415. for (Value *arg : CI->arg_operands()) {
  416. Type *Ty = arg->getType();
  417. if (IsMatrixType(Ty)) {
  418. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  419. } else
  420. argList.emplace_back(arg);
  421. }
  422. return Builder.CreateCall(vecF, argList);
  423. }
  424. static Value *VectorizeScalarOp(Value *op, Type *dstTy, IRBuilder<> &Builder) {
  425. if (op->getType() == dstTy)
  426. return op;
  427. op = Builder.CreateInsertElement(
  428. UndefValue::get(VectorType::get(op->getType(), 1)), op, (uint64_t)0);
  429. Type *I32Ty = IntegerType::get(dstTy->getContext(), 32);
  430. Constant *zero = ConstantInt::get(I32Ty, 0);
  431. std::vector<Constant *> MaskVec(dstTy->getVectorNumElements(), zero);
  432. Value *castMask = ConstantVector::get(MaskVec);
  433. Value *vecOp = new ShuffleVectorInst(op, op, castMask);
  434. Builder.Insert(cast<Instruction>(vecOp));
  435. return vecOp;
  436. }
  437. Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  438. Type *ResultTy = LowerMatrixType(CI->getType());
  439. UndefValue *tmp = UndefValue::get(ResultTy);
  440. IRBuilder<> Builder(CI);
  441. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  442. bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  443. auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
  444. Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
  445. : ConstantInt::get(Ty->getScalarType(), v);
  446. std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
  447. return ConstantVector::get(vals);
  448. };
  449. Constant *one = GetVecConst(ResultTy, 1);
  450. Instruction *Result = nullptr;
  451. switch (opcode) {
  452. case HLUnaryOpcode::Minus: {
  453. Constant *zero = GetVecConst(ResultTy, 0);
  454. if (isFloat)
  455. Result = BinaryOperator::CreateFSub(zero, tmp);
  456. else
  457. Result = BinaryOperator::CreateSub(zero, tmp);
  458. } break;
  459. case HLUnaryOpcode::LNot: {
  460. Constant *zero = GetVecConst(ResultTy, 0);
  461. if (isFloat)
  462. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
  463. else
  464. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  465. } break;
  466. case HLUnaryOpcode::Not:
  467. Result = BinaryOperator::CreateXor(tmp, tmp);
  468. break;
  469. case HLUnaryOpcode::PostInc:
  470. case HLUnaryOpcode::PreInc:
  471. if (isFloat)
  472. Result = BinaryOperator::CreateFAdd(tmp, one);
  473. else
  474. Result = BinaryOperator::CreateAdd(tmp, one);
  475. break;
  476. case HLUnaryOpcode::PostDec:
  477. case HLUnaryOpcode::PreDec:
  478. if (isFloat)
  479. Result = BinaryOperator::CreateFSub(tmp, one);
  480. else
  481. Result = BinaryOperator::CreateSub(tmp, one);
  482. break;
  483. default:
  484. DXASSERT(0, "not implement");
  485. return nullptr;
  486. }
  487. Builder.Insert(Result);
  488. return Result;
  489. }
  490. Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  491. Type *ResultTy = LowerMatrixType(CI->getType());
  492. IRBuilder<> Builder(CI);
  493. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  494. Type *OpTy = LowerMatrixType(
  495. CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  496. UndefValue *tmp = UndefValue::get(OpTy);
  497. bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  498. Instruction *Result = nullptr;
  499. switch (opcode) {
  500. case HLBinaryOpcode::Add:
  501. if (isFloat)
  502. Result = BinaryOperator::CreateFAdd(tmp, tmp);
  503. else
  504. Result = BinaryOperator::CreateAdd(tmp, tmp);
  505. break;
  506. case HLBinaryOpcode::Sub:
  507. if (isFloat)
  508. Result = BinaryOperator::CreateFSub(tmp, tmp);
  509. else
  510. Result = BinaryOperator::CreateSub(tmp, tmp);
  511. break;
  512. case HLBinaryOpcode::Mul:
  513. if (isFloat)
  514. Result = BinaryOperator::CreateFMul(tmp, tmp);
  515. else
  516. Result = BinaryOperator::CreateMul(tmp, tmp);
  517. break;
  518. case HLBinaryOpcode::Div:
  519. if (isFloat)
  520. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  521. else
  522. Result = BinaryOperator::CreateSDiv(tmp, tmp);
  523. break;
  524. case HLBinaryOpcode::Rem:
  525. if (isFloat)
  526. Result = BinaryOperator::CreateFRem(tmp, tmp);
  527. else
  528. Result = BinaryOperator::CreateSRem(tmp, tmp);
  529. break;
  530. case HLBinaryOpcode::And:
  531. Result = BinaryOperator::CreateAnd(tmp, tmp);
  532. break;
  533. case HLBinaryOpcode::Or:
  534. Result = BinaryOperator::CreateOr(tmp, tmp);
  535. break;
  536. case HLBinaryOpcode::Xor:
  537. Result = BinaryOperator::CreateXor(tmp, tmp);
  538. break;
  539. case HLBinaryOpcode::Shl: {
  540. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  541. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  542. Result = BinaryOperator::CreateShl(tmp, tmp);
  543. } break;
  544. case HLBinaryOpcode::Shr: {
  545. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  546. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  547. Result = BinaryOperator::CreateAShr(tmp, tmp);
  548. } break;
  549. case HLBinaryOpcode::LT:
  550. if (isFloat)
  551. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
  552. else
  553. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
  554. break;
  555. case HLBinaryOpcode::GT:
  556. if (isFloat)
  557. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
  558. else
  559. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
  560. break;
  561. case HLBinaryOpcode::LE:
  562. if (isFloat)
  563. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
  564. else
  565. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
  566. break;
  567. case HLBinaryOpcode::GE:
  568. if (isFloat)
  569. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
  570. else
  571. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
  572. break;
  573. case HLBinaryOpcode::EQ:
  574. if (isFloat)
  575. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
  576. else
  577. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
  578. break;
  579. case HLBinaryOpcode::NE:
  580. if (isFloat)
  581. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
  582. else
  583. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
  584. break;
  585. case HLBinaryOpcode::UDiv:
  586. Result = BinaryOperator::CreateUDiv(tmp, tmp);
  587. break;
  588. case HLBinaryOpcode::URem:
  589. Result = BinaryOperator::CreateURem(tmp, tmp);
  590. break;
  591. case HLBinaryOpcode::UShr: {
  592. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  593. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  594. Result = BinaryOperator::CreateLShr(tmp, tmp);
  595. } break;
  596. case HLBinaryOpcode::ULT:
  597. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
  598. break;
  599. case HLBinaryOpcode::UGT:
  600. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
  601. break;
  602. case HLBinaryOpcode::ULE:
  603. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
  604. break;
  605. case HLBinaryOpcode::UGE:
  606. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
  607. break;
  608. case HLBinaryOpcode::LAnd:
  609. case HLBinaryOpcode::LOr: {
  610. Constant *zero;
  611. if (isFloat)
  612. zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
  613. else
  614. zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
  615. unsigned size = ResultTy->getVectorNumElements();
  616. std::vector<Constant *> zeros(size, zero);
  617. Value *vecZero = llvm::ConstantVector::get(zeros);
  618. Instruction *cmpL;
  619. if (isFloat)
  620. cmpL =
  621. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  622. else
  623. cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  624. Builder.Insert(cmpL);
  625. Instruction *cmpR;
  626. if (isFloat)
  627. cmpR =
  628. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  629. else
  630. cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  631. Builder.Insert(cmpR);
  632. // How to map l, r back? Need check opcode
  633. if (opcode == HLBinaryOpcode::LOr)
  634. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  635. else
  636. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  637. break;
  638. }
  639. default:
  640. DXASSERT(0, "not implement");
  641. return nullptr;
  642. }
  643. Builder.Insert(Result);
  644. return Result;
  645. }
  646. void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  647. Instruction *vecInst;
  648. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  649. hlsl::HLOpcodeGroup group =
  650. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  651. switch (group) {
  652. case HLOpcodeGroup::HLIntrinsic: {
  653. vecInst = MatIntrinsicToVec(CI);
  654. } break;
  655. case HLOpcodeGroup::HLSelect: {
  656. vecInst = MatIntrinsicToVec(CI);
  657. } break;
  658. case HLOpcodeGroup::HLBinOp: {
  659. vecInst = TrivialMatBinOpToVec(CI);
  660. } break;
  661. case HLOpcodeGroup::HLUnOp: {
  662. vecInst = TrivialMatUnOpToVec(CI);
  663. } break;
  664. case HLOpcodeGroup::HLCast: {
  665. vecInst = MatCastToVec(CI);
  666. } break;
  667. case HLOpcodeGroup::HLInit: {
  668. vecInst = MatIntrinsicToVec(CI);
  669. } break;
  670. case HLOpcodeGroup::HLMatLoadStore: {
  671. vecInst = MatLdStToVec(CI);
  672. } break;
  673. case HLOpcodeGroup::HLSubscript: {
  674. vecInst = MatSubscriptToVec(CI);
  675. } break;
  676. }
  677. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  678. Type *Ty = AI->getAllocatedType();
  679. Type *matTy = Ty;
  680. IRBuilder<> Builder(AI);
  681. if (Ty->isArrayTy()) {
  682. Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
  683. vecTy = vecTy->getPointerElementType();
  684. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  685. } else {
  686. Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
  687. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  688. }
  689. // Update debug info.
  690. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
  691. if (DDI) {
  692. LLVMContext &Context = AI->getContext();
  693. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  694. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  695. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecInst));
  696. IRBuilder<> debugBuilder(DDI);
  697. debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
  698. }
  699. if (HLModule::HasPreciseAttributeWithMetadata(AI))
  700. HLModule::MarkPreciseAttributeWithMetadata(vecInst);
  701. } else {
  702. DXASSERT(0, "invalid inst");
  703. }
  704. matToVecMap[matInst] = vecInst;
  705. }
  706. // Replace matInst with vecInst on matUseInst.
  707. void HLMatrixLowerPass::TrivialMatUnOpReplace(CallInst *matInst,
  708. Instruction *vecInst,
  709. CallInst *matUseInst) {
  710. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  711. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  712. switch (opcode) {
  713. case HLUnaryOpcode::Not:
  714. // Not is xor now
  715. vecUseInst->setOperand(0, vecInst);
  716. vecUseInst->setOperand(1, vecInst);
  717. break;
  718. case HLUnaryOpcode::LNot:
  719. case HLUnaryOpcode::PostInc:
  720. case HLUnaryOpcode::PreInc:
  721. case HLUnaryOpcode::PostDec:
  722. case HLUnaryOpcode::PreDec:
  723. vecUseInst->setOperand(0, vecInst);
  724. break;
  725. }
  726. }
  727. // Replace matInst with vecInst on matUseInst.
  728. void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
  729. Instruction *vecInst,
  730. CallInst *matUseInst) {
  731. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  732. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  733. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  734. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matInst)
  735. vecUseInst->setOperand(0, vecInst);
  736. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matInst)
  737. vecUseInst->setOperand(1, vecInst);
  738. } else {
  739. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  740. matInst) {
  741. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  742. vecCmp->setOperand(0, vecInst);
  743. }
  744. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  745. matInst) {
  746. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  747. vecCmp->setOperand(0, vecInst);
  748. }
  749. }
  750. }
  751. static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
  752. llvm::FunctionType *MadFuncTy =
  753. llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
  754. Function *MAD =
  755. GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
  756. (unsigned)madOp);
  757. return MAD;
  758. }
// Lower an HL mul(matrix, matrix) intrinsic into scalar IR.
// Result element (r, c) = dot(LHS row r, RHS column c): the first product is
// a plain mul, the remaining terms are folded in with the HL mad intrinsic.
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst, bool isSigned) {
  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated.
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  // LHS is row x col, RHS is rRow x rCol; the inner dimensions must agree.
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  // Build the row x rCol result by inserting each scalar dot product into an
  // undef vector.
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(mulInst);
  // Operate on the already-lowered vector versions of both operands.
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // lhs[r][lc] * rhs[lc][c] as a plain multiply (seeds the accumulator).
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // acc + lhs[r][lc] * rhs[lc][c] via the mad intrinsic.
  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                             Value *acc) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      // Accumulate the remaining terms of the dot product with mad.
      for (lc = 1; lc < col; lc++) {
        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
      }
      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vec call with the computed product and retire it.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
  817. void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
  818. Instruction *vecInst,
  819. CallInst *mulInst, bool isSigned) {
  820. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  821. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  822. unsigned col, row;
  823. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  824. DXASSERT(RVal->getType()->getVectorNumElements() == col, "");
  825. bool isFloat = EltTy->isFloatingPointTy();
  826. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  827. IRBuilder<> Builder(mulInst);
  828. Value *vec = RVal;
  829. Value *mat = vecInst; // vec version of matInst;
  830. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  831. Type *opcodeTy = Builder.getInt32Ty();
  832. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  833. *m_pHLModule->GetModule());
  834. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  835. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  836. Value *vecElt = Builder.CreateExtractElement(vec, c);
  837. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  838. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  839. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  840. };
  841. for (unsigned r = 0; r < row; r++) {
  842. unsigned c = 0;
  843. Value *vecElt = Builder.CreateExtractElement(vec, c);
  844. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  845. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  846. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  847. : Builder.CreateMul(vecElt, matElt);
  848. for (c = 1; c < col; c++) {
  849. tmpVal = CreateOneEltMad(r, c, tmpVal);
  850. }
  851. retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  852. }
  853. mulInst->replaceAllUsesWith(retVal);
  854. AddToDeadInsts(mulInst);
  855. }
  856. void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
  857. Instruction *vecInst,
  858. CallInst *mulInst, bool isSigned) {
  859. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  860. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  861. Value *RVal = vecInst;
  862. unsigned col, row;
  863. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  864. DXASSERT(LVal->getType()->getVectorNumElements() == row, "");
  865. bool isFloat = EltTy->isFloatingPointTy();
  866. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  867. IRBuilder<> Builder(mulInst);
  868. Value *vec = LVal;
  869. Value *mat = RVal;
  870. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  871. Type *opcodeTy = Builder.getInt32Ty();
  872. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  873. *m_pHLModule->GetModule());
  874. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  875. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  876. Value *vecElt = Builder.CreateExtractElement(vec, r);
  877. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  878. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  879. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  880. };
  881. for (unsigned c = 0; c < col; c++) {
  882. unsigned r = 0;
  883. Value *vecElt = Builder.CreateExtractElement(vec, r);
  884. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  885. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  886. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  887. : Builder.CreateMul(vecElt, matElt);
  888. for (r = 1; r < row; r++) {
  889. tmpVal = CreateOneEltMad(r, c, tmpVal);
  890. }
  891. retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  892. }
  893. mulInst->replaceAllUsesWith(retVal);
  894. AddToDeadInsts(mulInst);
  895. }
  896. void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
  897. CallInst *mulInst, bool isSigned) {
  898. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  899. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  900. bool LMat = IsMatrixType(LVal->getType());
  901. bool RMat = IsMatrixType(RVal->getType());
  902. if (LMat && RMat) {
  903. TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
  904. } else if (LMat) {
  905. TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
  906. } else {
  907. TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
  908. }
  909. }
// Lower an HL transpose intrinsic. The lowered matrix value is stored row
// major, so transposing is identical to a row-major -> col-major cast, which
// TranslateMatMajorCast emits as a single element permutation (shuffle).
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                              Instruction *vecInst,
                                              CallInst *transposeInst) {
  // Matrix value is row major, transpose is cast it to col major.
  TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
}
  916. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  917. IRBuilder<> &Builder) {
  918. Value *mul0 = Builder.CreateFMul(m00, m11);
  919. Value *mul1 = Builder.CreateFMul(m01, m10);
  920. return Builder.CreateFSub(mul0, mul1);
  921. }
  922. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  923. Value *m10, Value *m11, Value *m12,
  924. Value *m20, Value *m21, Value *m22,
  925. IRBuilder<> &Builder) {
  926. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  927. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  928. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  929. deter00 = Builder.CreateFMul(m00, deter00);
  930. deter01 = Builder.CreateFMul(m01, deter01);
  931. deter02 = Builder.CreateFMul(m02, deter02);
  932. Value *result = Builder.CreateFSub(deter00, deter01);
  933. result = Builder.CreateFAdd(result, deter02);
  934. return result;
  935. }
  936. static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
  937. Value *m10, Value *m11, Value *m12, Value *m13,
  938. Value *m20, Value *m21, Value *m22, Value *m23,
  939. Value *m30, Value *m31, Value *m32, Value *m33,
  940. IRBuilder<> &Builder) {
  941. Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  942. Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  943. Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  944. Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  945. deter00 = Builder.CreateFMul(m00, deter00);
  946. deter01 = Builder.CreateFMul(m01, deter01);
  947. deter02 = Builder.CreateFMul(m02, deter02);
  948. deter03 = Builder.CreateFMul(m03, deter03);
  949. Value *result = Builder.CreateFSub(deter00, deter01);
  950. result = Builder.CreateFAdd(result, deter02);
  951. result = Builder.CreateFSub(result, deter03);
  952. return result;
  953. }
  954. void HLMatrixLowerPass::TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  955. CallInst *determinantInst) {
  956. unsigned row, col;
  957. GetMatrixInfo(matInst->getType(), col, row);
  958. IRBuilder<> Builder(determinantInst);
  959. // when row == 1, result is vecInst.
  960. Value *Result = vecInst;
  961. if (row == 2) {
  962. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  963. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  964. Value *m10 = Builder.CreateExtractElement(vecInst, 2);
  965. Value *m11 = Builder.CreateExtractElement(vecInst, 3);
  966. Result = Determinant2x2(m00, m01, m10, m11, Builder);
  967. }
  968. else if (row == 3) {
  969. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  970. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  971. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  972. Value *m10 = Builder.CreateExtractElement(vecInst, 3);
  973. Value *m11 = Builder.CreateExtractElement(vecInst, 4);
  974. Value *m12 = Builder.CreateExtractElement(vecInst, 5);
  975. Value *m20 = Builder.CreateExtractElement(vecInst, 6);
  976. Value *m21 = Builder.CreateExtractElement(vecInst, 7);
  977. Value *m22 = Builder.CreateExtractElement(vecInst, 8);
  978. Result = Determinant3x3(m00, m01, m02,
  979. m10, m11, m12,
  980. m20, m21, m22, Builder);
  981. }
  982. else if (row == 4) {
  983. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  984. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  985. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  986. Value *m03 = Builder.CreateExtractElement(vecInst, 3);
  987. Value *m10 = Builder.CreateExtractElement(vecInst, 4);
  988. Value *m11 = Builder.CreateExtractElement(vecInst, 5);
  989. Value *m12 = Builder.CreateExtractElement(vecInst, 6);
  990. Value *m13 = Builder.CreateExtractElement(vecInst, 7);
  991. Value *m20 = Builder.CreateExtractElement(vecInst, 8);
  992. Value *m21 = Builder.CreateExtractElement(vecInst, 9);
  993. Value *m22 = Builder.CreateExtractElement(vecInst, 10);
  994. Value *m23 = Builder.CreateExtractElement(vecInst, 11);
  995. Value *m30 = Builder.CreateExtractElement(vecInst, 12);
  996. Value *m31 = Builder.CreateExtractElement(vecInst, 13);
  997. Value *m32 = Builder.CreateExtractElement(vecInst, 14);
  998. Value *m33 = Builder.CreateExtractElement(vecInst, 15);
  999. Result = Determinant4x4(m00, m01, m02, m03,
  1000. m10, m11, m12, m13,
  1001. m20, m21, m22, m23,
  1002. m30, m31, m32, m33,
  1003. Builder);
  1004. } else {
  1005. DXASSERT(row == 1, "invalid matrix type");
  1006. Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  1007. }
  1008. determinantInst->replaceAllUsesWith(Result);
  1009. AddToDeadInsts(determinantInst);
  1010. }
  1011. void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
  1012. Instruction *vecInst,
  1013. CallInst *matUseInst) {
  1014. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1015. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  1016. if (matUseInst->getArgOperand(i) == matInst) {
  1017. vecUseInst->setArgOperand(i, vecInst);
  1018. }
  1019. }
  1020. void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
  1021. Instruction *vecInst,
  1022. CallInst *castInst,
  1023. bool bRowToCol) {
  1024. unsigned col, row;
  1025. GetMatrixInfo(castInst->getType(), col, row);
  1026. DXASSERT(castInst->getType() == matInst->getType(), "type must match");
  1027. IRBuilder<> Builder(castInst);
  1028. // shuf to change major.
  1029. SmallVector<int, 16> castMask(col * row);
  1030. unsigned idx = 0;
  1031. if (bRowToCol) {
  1032. for (unsigned c = 0; c < col; c++)
  1033. for (unsigned r = 0; r < row; r++) {
  1034. unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  1035. castMask[idx++] = matIdx;
  1036. }
  1037. } else {
  1038. for (unsigned r = 0; r < row; r++)
  1039. for (unsigned c = 0; c < col; c++) {
  1040. unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
  1041. castMask[idx++] = matIdx;
  1042. }
  1043. }
  1044. Instruction *vecCast = cast<Instruction>(
  1045. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1046. // Replace vec cast function call with vecCast.
  1047. DXASSERT(matToVecMap.count(castInst), "must has vec version");
  1048. Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  1049. vecUseInst->replaceAllUsesWith(vecCast);
  1050. AddToDeadInsts(vecUseInst);
  1051. matToVecMap[castInst] = vecCast;
  1052. }
  1053. void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
  1054. Instruction *vecInst,
  1055. CallInst *castInst) {
  1056. unsigned toCol, toRow;
  1057. Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  1058. unsigned fromCol, fromRow;
  1059. Type *FromEltTy = GetMatrixInfo(matInst->getType(), fromCol, fromRow);
  1060. unsigned fromSize = fromCol * fromRow;
  1061. unsigned toSize = toCol * toRow;
  1062. DXASSERT(fromSize >= toSize, "cannot extend matrix");
  1063. IRBuilder<> Builder(castInst);
  1064. Instruction *vecCast = nullptr;
  1065. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1066. if (fromSize == toSize) {
  1067. vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecInst,
  1068. Builder);
  1069. } else {
  1070. // shuf first
  1071. std::vector<int> castMask(toCol * toRow);
  1072. unsigned idx = 0;
  1073. for (unsigned r = 0; r < toRow; r++)
  1074. for (unsigned c = 0; c < toCol; c++) {
  1075. unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
  1076. castMask[idx++] = matIdx;
  1077. }
  1078. Instruction *shuf = cast<Instruction>(
  1079. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1080. if (ToEltTy != FromEltTy)
  1081. vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
  1082. Builder);
  1083. else
  1084. vecCast = shuf;
  1085. }
  1086. // Replace vec cast function call with vecCast.
  1087. DXASSERT(matToVecMap.count(castInst), "must has vec version");
  1088. Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  1089. vecUseInst->replaceAllUsesWith(vecCast);
  1090. AddToDeadInsts(vecUseInst);
  1091. matToVecMap[castInst] = vecCast;
  1092. }
  1093. void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
  1094. Instruction *vecInst,
  1095. CallInst *castInst) {
  1096. unsigned col, row;
  1097. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  1098. unsigned fromSize = col * row;
  1099. IRBuilder<> Builder(castInst);
  1100. Instruction *sizeCast = nullptr;
  1101. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1102. Type *ToTy = castInst->getType();
  1103. if (ToTy->isVectorTy()) {
  1104. unsigned toSize = ToTy->getVectorNumElements();
  1105. if (fromSize != toSize) {
  1106. std::vector<int> castMask(fromSize);
  1107. for (unsigned c = 0; c < toSize; c++)
  1108. castMask[c] = c;
  1109. sizeCast = cast<Instruction>(
  1110. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1111. } else
  1112. sizeCast = vecInst;
  1113. } else {
  1114. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1115. sizeCast =
  1116. cast<Instruction>(Builder.CreateExtractElement(vecInst, (uint64_t)0));
  1117. }
  1118. Instruction *typeCast = sizeCast;
  1119. if (EltTy != ToTy->getScalarType()) {
  1120. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1121. }
  1122. // Replace cast function call with typeCast.
  1123. castInst->replaceAllUsesWith(typeCast);
  1124. AddToDeadInsts(castInst);
  1125. }
  1126. void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
  1127. Instruction *vecInst,
  1128. CallInst *castInst) {
  1129. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1130. if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
  1131. opcode == HLCastOpcode::RowMatrixToColMatrix) {
  1132. TranslateMatMajorCast(matInst, vecInst, castInst,
  1133. opcode == HLCastOpcode::RowMatrixToColMatrix);
  1134. } else {
  1135. bool ToMat = IsMatrixType(castInst->getType());
  1136. bool FromMat = IsMatrixType(matInst->getType());
  1137. if (ToMat && FromMat) {
  1138. TranslateMatMatCast(matInst, vecInst, castInst);
  1139. } else if (FromMat)
  1140. TranslateMatToOtherCast(matInst, vecInst, castInst);
  1141. else
  1142. DXASSERT(0, "Not translate as user of matInst");
  1143. }
  1144. }
  1145. void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
  1146. Instruction *vecInst,
  1147. CallInst *matUseInst) {
  1148. IRBuilder<> Builder(matUseInst);
  1149. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1150. switch (opcode) {
  1151. case IntrinsicOp::IOP_umul:
  1152. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
  1153. break;
  1154. case IntrinsicOp::IOP_mul:
  1155. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
  1156. break;
  1157. case IntrinsicOp::IOP_transpose:
  1158. TranslateMatTranspose(matInst, vecInst, matUseInst);
  1159. break;
  1160. case IntrinsicOp::IOP_determinant:
  1161. TranslateMatDeterminant(matInst, vecInst, matUseInst);
  1162. break;
  1163. default:
  1164. CallInst *vecUseInst = nullptr;
  1165. if (matToVecMap.count(matUseInst))
  1166. vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1167. for (unsigned i = 0; i < matInst->getNumArgOperands(); i++)
  1168. if (matUseInst->getArgOperand(i) == matInst) {
  1169. if (vecUseInst)
  1170. vecUseInst->setArgOperand(i, vecInst);
  1171. else
  1172. matUseInst->setArgOperand(i, vecInst);
  1173. }
  1174. break;
  1175. }
  1176. }
  1177. void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
  1178. CallInst *matSubInst) {
  1179. unsigned opcode = GetHLOpcode(matSubInst);
  1180. HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  1181. assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
  1182. "matrix don't use default subscript");
  1183. Type *matType = matInst->getType()->getPointerElementType();
  1184. unsigned col, row;
  1185. Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  1186. bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
  1187. (matOpcode == HLSubscriptOpcode::RowMatElement);
  1188. Value *mask =
  1189. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  1190. if (isElement) {
  1191. Type *resultType = matSubInst->getType()->getPointerElementType();
  1192. unsigned resultSize = 1;
  1193. if (resultType->isVectorTy())
  1194. resultSize = resultType->getVectorNumElements();
  1195. std::vector<int> shufMask(resultSize);
  1196. Constant *EltIdxs = cast<Constant>(mask);
  1197. for (unsigned i = 0; i < resultSize; i++) {
  1198. shufMask[i] =
  1199. cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
  1200. }
  1201. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1202. CallE = matSubInst->use_end();
  1203. CallUI != CallE;) {
  1204. Use &CallUse = *CallUI++;
  1205. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1206. IRBuilder<> Builder(CallUser);
  1207. Value *vecLd = Builder.CreateLoad(vecInst);
  1208. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1209. if (resultSize > 1) {
  1210. Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
  1211. ld->replaceAllUsesWith(shuf);
  1212. } else {
  1213. Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
  1214. ld->replaceAllUsesWith(elt);
  1215. }
  1216. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1217. Value *val = st->getValueOperand();
  1218. if (resultSize > 1) {
  1219. for (unsigned i = 0; i < shufMask.size(); i++) {
  1220. unsigned idx = shufMask[i];
  1221. Value *valElt = Builder.CreateExtractElement(val, i);
  1222. vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
  1223. }
  1224. Builder.CreateStore(vecLd, vecInst);
  1225. } else {
  1226. vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
  1227. Builder.CreateStore(vecLd, vecInst);
  1228. }
  1229. } else
  1230. DXASSERT(0, "matrix element should only used by load/store.");
  1231. AddToDeadInsts(CallUser);
  1232. }
  1233. } else {
  1234. // Subscript.
  1235. // Return a row.
  1236. // Use insertElement and extractElement.
  1237. ArrayType *AT = ArrayType::get(EltTy, col*row);
  1238. IRBuilder<> AllocaBuilder(
  1239. matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  1240. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  1241. Value *zero = AllocaBuilder.getInt32(0);
  1242. bool isDynamicIndexing = !isa<ConstantInt>(mask);
  1243. SmallVector<Value *, 4> idxList;
  1244. for (unsigned i = 0; i < col; i++) {
  1245. idxList.emplace_back(
  1246. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
  1247. }
  1248. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1249. CallE = matSubInst->use_end();
  1250. CallUI != CallE;) {
  1251. Use &CallUse = *CallUI++;
  1252. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1253. IRBuilder<> Builder(CallUser);
  1254. Value *vecLd = Builder.CreateLoad(vecInst);
  1255. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1256. Value *sub = UndefValue::get(ld->getType());
  1257. if (!isDynamicIndexing) {
  1258. for (unsigned i = 0; i < col; i++) {
  1259. Value *matIdx = idxList[i];
  1260. Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
  1261. sub = Builder.CreateInsertElement(sub, valElt, i);
  1262. }
  1263. } else {
  1264. // Copy vec to array.
  1265. for (unsigned int i = 0; i < row*col; i++) {
  1266. Value *Elt =
  1267. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1268. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1269. {zero, Builder.getInt32(i)});
  1270. Builder.CreateStore(Elt, Ptr);
  1271. }
  1272. for (unsigned i = 0; i < col; i++) {
  1273. Value *matIdx = idxList[i];
  1274. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1275. Value *valElt = Builder.CreateLoad(Ptr);
  1276. sub = Builder.CreateInsertElement(sub, valElt, i);
  1277. }
  1278. }
  1279. ld->replaceAllUsesWith(sub);
  1280. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1281. Value *val = st->getValueOperand();
  1282. if (!isDynamicIndexing) {
  1283. for (unsigned i = 0; i < col; i++) {
  1284. Value *matIdx = idxList[i];
  1285. Value *valElt = Builder.CreateExtractElement(val, i);
  1286. vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
  1287. }
  1288. } else {
  1289. // Copy vec to array.
  1290. for (unsigned int i = 0; i < row * col; i++) {
  1291. Value *Elt =
  1292. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1293. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1294. {zero, Builder.getInt32(i)});
  1295. Builder.CreateStore(Elt, Ptr);
  1296. }
  1297. // Update array.
  1298. for (unsigned i = 0; i < col; i++) {
  1299. Value *matIdx = idxList[i];
  1300. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1301. Value *valElt = Builder.CreateExtractElement(val, i);
  1302. Builder.CreateStore(valElt, Ptr);
  1303. }
  1304. // Copy array to vec.
  1305. for (unsigned int i = 0; i < row * col; i++) {
  1306. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1307. {zero, Builder.getInt32(i)});
  1308. Value *Elt = Builder.CreateLoad(Ptr);
  1309. vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
  1310. }
  1311. }
  1312. Builder.CreateStore(vecLd, vecInst);
  1313. } else if (GetElementPtrInst *GEP =
  1314. dyn_cast<GetElementPtrInst>(CallUser)) {
  1315. Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
  1316. Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
  1317. GEP->replaceAllUsesWith(NewGEP);
  1318. } else
  1319. DXASSERT(0, "matrix subscript should only used by load/store.");
  1320. AddToDeadInsts(CallUser);
  1321. }
  1322. }
  1323. // Check vec version.
  1324. DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  1325. // All the user should has been removed.
  1326. matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  1327. AddToDeadInsts(matSubInst);
  1328. }
  1329. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1330. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1331. CallInst *matLdStInst) {
  1332. // No dynamic indexing on matrix, flatten matrix to scalars.
  1333. // vecGlobals already in correct major.
  1334. Type *matType = matGlobal->getType()->getPointerElementType();
  1335. unsigned col, row;
  1336. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1337. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1338. IRBuilder<> Builder(matLdStInst);
  1339. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1340. switch (opcode) {
  1341. case HLMatLoadStoreOpcode::ColMatLoad:
  1342. case HLMatLoadStoreOpcode::RowMatLoad: {
  1343. Value *Result = UndefValue::get(vecType);
  1344. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1345. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1346. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1347. }
  1348. matLdStInst->replaceAllUsesWith(Result);
  1349. } break;
  1350. case HLMatLoadStoreOpcode::ColMatStore:
  1351. case HLMatLoadStoreOpcode::RowMatStore: {
  1352. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1353. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1354. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1355. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1356. }
  1357. } break;
  1358. }
  1359. }
  1360. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
  1361. GlobalVariable *scalarArrayGlobal,
  1362. CallInst *matLdStInst) {
  1363. // vecGlobals already in correct major.
  1364. const bool bColMajor = true;
  1365. HLMatLoadStoreOpcode opcode =
  1366. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1367. switch (opcode) {
  1368. case HLMatLoadStoreOpcode::ColMatLoad:
  1369. case HLMatLoadStoreOpcode::RowMatLoad: {
  1370. IRBuilder<> Builder(matLdStInst);
  1371. Type *matTy = matGlobal->getType()->getPointerElementType();
  1372. unsigned col, row;
  1373. Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1374. Value *zeroIdx = Builder.getInt32(0);
  1375. std::vector<Value *> matElts(col * row);
  1376. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1377. Value *GEP = Builder.CreateInBoundsGEP(
  1378. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1379. matElts[matIdx] = Builder.CreateLoad(GEP);
  1380. }
  1381. Value *newVec =
  1382. HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
  1383. matLdStInst->replaceAllUsesWith(newVec);
  1384. matLdStInst->eraseFromParent();
  1385. } break;
  1386. case HLMatLoadStoreOpcode::ColMatStore:
  1387. case HLMatLoadStoreOpcode::RowMatStore: {
  1388. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1389. IRBuilder<> Builder(matLdStInst);
  1390. Type *matTy = matGlobal->getType()->getPointerElementType();
  1391. unsigned col, row;
  1392. HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1393. Value *zeroIdx = Builder.getInt32(0);
  1394. std::vector<Value *> matElts(col * row);
  1395. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1396. Value *GEP = Builder.CreateInBoundsGEP(
  1397. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1398. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1399. Builder.CreateStore(Elt, GEP);
  1400. }
  1401. matLdStInst->eraseFromParent();
  1402. } break;
  1403. }
  1404. }
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
    CallInst *matSubInst, Value *vecPtr) {
  // Lower a matrix subscript call whose base points into a lowered global:
  // compute the list of flattened element indices addressed by the
  // subscript, then rewrite every user of the subscript in terms of
  // scalar GEPs on vecPtr.
  Value *basePtr =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // Flattened element indices this subscript addresses.
  std::vector<Value *> idxList;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Whole-row subscript: one index operand per column.
    // Just use index created in EmitHLSLMatrixSubscript.
    for (unsigned c = 0; c < col; c++) {
      Value *matIdx =
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
      idxList.emplace_back(matIdx);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    // Element subscript (e.g. mat._m00_m11): idx is a constant scalar or
    // vector holding the flattened element indices.
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Just use index created in EmitHLSLMatrixElement.
    Constant *EltIdxs = cast<Constant>(idx);
    for (unsigned i = 0; i < resultSize; i++) {
      Value *matIdx = EltIdxs->getAggregateElement(i);
      idxList.emplace_back(matIdx);
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate vector pointer.
  // Replace all uses with scalar pointers.
  if (idxList.size() == 1) {
    // Single element: one scalar GEP can stand in for the subscript.
    Value *Ptr =
        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
    matSubInst->replaceAllUsesWith(Ptr);
  } else {
    // Multi-element result: there is no single pointer to hand out, so
    // split the use of CI with per-element Ptrs at each user.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      // Advance before rewriting so the iterator survives user removal.
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // GEP into the subscript result: fold its index with idxList into
        // one direct index on vecPtr.
        Value *IndexPtr =
            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                   {zeroIdx, IndexPtr});
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
            Value *subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          } else {
            // Remaining user must be an addrspace cast; retarget it.
            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
            Cast->setOperand(0, Ptr);
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Store of the whole subscript result: scatter element stores.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          userBuilder.CreateStore(Elt, Ptr);
        }
        stUser->eraseFromParent();
      } else {
        // Load of the whole subscript result: gather element loads.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          Value *Elt = userBuilder.CreateLoad(Ptr);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
  1505. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
  1506. CallInst *matLdStInst, Value *vecPtr) {
  1507. // Just translate into vector here.
  1508. // DynamicIndexingVectorToArray will change it to scalar array.
  1509. IRBuilder<> Builder(matLdStInst);
  1510. unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  1511. HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  1512. switch (matLdStOp) {
  1513. case HLMatLoadStoreOpcode::ColMatLoad:
  1514. case HLMatLoadStoreOpcode::RowMatLoad: {
  1515. // Load as vector.
  1516. Value *newLoad = Builder.CreateLoad(vecPtr);
  1517. matLdStInst->replaceAllUsesWith(newLoad);
  1518. matLdStInst->eraseFromParent();
  1519. } break;
  1520. case HLMatLoadStoreOpcode::ColMatStore:
  1521. case HLMatLoadStoreOpcode::RowMatStore: {
  1522. // Change value to vector array, then store.
  1523. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1524. Value *vecArrayGep = vecPtr;
  1525. Builder.CreateStore(Val, vecArrayGep);
  1526. matLdStInst->eraseFromParent();
  1527. } break;
  1528. default:
  1529. DXASSERT(0, "invalid operation");
  1530. break;
  1531. }
  1532. }
// Flatten values inside init list to scalar elements.
// Recursively walks one init-list argument `val` and writes its scalar
// elements into `elts` starting at `idx`; `idx` is advanced past every
// element written. Matrix values are read through their lowered vector
// twins in matToVecMap.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Instruction *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    // Pointer argument: load through it, possibly element by element.
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      if (matToVecMap.count(cast<Instruction>(val))) {
        // Prefer the already-lowered vector-array version.
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Loadable as a single value: load and recurse on the loaded value.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointee: recurse into each element through a GEP.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    // Matrix value: extract every element from its lowered vector twin.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all row major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Vector value: extract each lane.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      // Plain scalar: copy it through directly.
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
// Store flattened init list elements into matrix array.
// Recursively stores elts[offset...] through `ptr`, which points at either
// a lowered matrix (a vector) or a (possibly nested) array of them;
// `offset` advances past every element consumed.
static void GenerateMatArrayInit(ArrayRef<Value *> elts, Value *ptr,
                                 unsigned &offset, IRBuilder<> &Builder) {
  Type *Ty = ptr->getType()->getPointerElementType();
  if (Ty->isVectorTy()) {
    // Leaf: assemble the vector from elts and store it whole.
    unsigned vecSize = Ty->getVectorNumElements();
    Type *eltTy = Ty->getVectorElementType();
    Value *result = UndefValue::get(Ty);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *elt = elts[offset + i];
      if (elt->getType() != eltTy) {
        // FIXME: get signed/unsigned info.
        elt = CreateTypeCast(HLCastOpcode::DefaultCast, eltTy, elt, Builder);
      }
      result = Builder.CreateInsertElement(result, elt, i);
    }
    // Update offset.
    offset += vecSize;
    Builder.CreateStore(result, ptr);
  } else {
    // Interior node: recurse into each array element through a GEP.
    DXASSERT(Ty->isArrayTy(), "must be array type");
    Type *i32Ty = Type::getInt32Ty(Ty->getContext());
    Constant *zero = ConstantInt::get(i32Ty, 0);
    unsigned arraySize = Ty->getArrayNumElements();
    for (unsigned i = 0; i < arraySize; i++) {
      Value *GEP =
          Builder.CreateInBoundsGEP(ptr, {zero, ConstantInt::get(i32Ty, i)});
      GenerateMatArrayInit(elts, GEP, offset, Builder);
    }
  }
}
void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  // Lower a matrix initializer call: flatten all init-list arguments into
  // scalars, build the lowered vector, and substitute it for the
  // placeholder vector instruction recorded in matToVecMap.
  // Array matrix init will be translated in TranslateMatArrayInitReplace.
  if (matInitInst->getType()->isVoidTy())
    return;
  IRBuilder<> Builder(matInitInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  unsigned idx = 0;
  std::vector<Value *> elts(vecSize);
  // Skip opcode arg.
  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
    Value *val = matInitInst->getArgOperand(i);
    IterateInitList(elts, idx, val, matToVecMap, Builder);
  }
  // Assemble the vector one insertelement at a time.
  Value *newInit = UndefValue::get(vecTy);
  // InitList is row major, the result is row major too.
  for (unsigned i = 0; i < col * row; i++) {
    Constant *vecIdx = Builder.getInt32(i);
    newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
    Builder.Insert(cast<Instruction>(newInit));
  }
  // Replace matInit function call with matInitInst.
  DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  vecUseInst->replaceAllUsesWith(newInit);
  AddToDeadInsts(vecUseInst);
  matToVecMap[matInitInst] = newInit;
}
void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  // Lower a matrix ternary select: build an element-wise select over the
  // lowered vector operands and substitute it for the placeholder vector
  // call recorded in matToVecMap.
  IRBuilder<> Builder(matSelectInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  // Placeholder vector call created when the matrix call was first seen.
  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  Instruction *LHS = cast<Instruction>(
      matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  Instruction *RHS = cast<Instruction>(
      matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  bool isVecCond = Cond->getType()->isVectorTy();
  if (isVecCond) {
    // Matrix-typed condition: use its lowered vector form instead.
    Instruction *MatCond = cast<Instruction>(
        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
    DXASSERT_NOMSG(matToVecMap.count(MatCond));
    Cond = matToVecMap[MatCond];
  }
  DXASSERT_NOMSG(matToVecMap.count(LHS));
  Value *VLHS = matToVecMap[LHS];
  DXASSERT_NOMSG(matToVecMap.count(RHS));
  Value *VRHS = matToVecMap[RHS];
  // Select lane by lane; a scalar condition is reused for every lane.
  Value *VecSelect = UndefValue::get(vecTy);
  for (unsigned i = 0; i < vecSize; i++) {
    llvm::Value *EltCond = Cond;
    if (isVecCond)
      EltCond = Builder.CreateExtractElement(Cond, i);
    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  }
  AddToDeadInsts(vecUseInst);
  vecUseInst->replaceAllUsesWith(VecSelect);
  matToVecMap[matSelectInst] = VecSelect;
}
  1694. void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
  1695. Instruction *vecInst,
  1696. GetElementPtrInst *matGEP) {
  1697. SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  1698. IRBuilder<> GEPBuilder(matGEP);
  1699. Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecInst, idxList);
  1700. // Only used by mat subscript and mat ld/st.
  1701. for (Value::user_iterator user = matGEP->user_begin();
  1702. user != matGEP->user_end();) {
  1703. Instruction *useInst = cast<Instruction>(*(user++));
  1704. IRBuilder<> Builder(useInst);
  1705. // Skip return here.
  1706. if (isa<ReturnInst>(useInst))
  1707. continue;
  1708. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1709. // Function call.
  1710. hlsl::HLOpcodeGroup group =
  1711. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1712. switch (group) {
  1713. case HLOpcodeGroup::HLMatLoadStore: {
  1714. unsigned opcode = GetHLOpcode(useCall);
  1715. HLMatLoadStoreOpcode matOpcode =
  1716. static_cast<HLMatLoadStoreOpcode>(opcode);
  1717. switch (matOpcode) {
  1718. case HLMatLoadStoreOpcode::ColMatLoad:
  1719. case HLMatLoadStoreOpcode::RowMatLoad: {
  1720. // Skip the vector version.
  1721. if (useCall->getType()->isVectorTy())
  1722. continue;
  1723. Value *newLd = Builder.CreateLoad(newGEP);
  1724. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1725. Value *oldLd = matToVecMap[useCall];
  1726. // Delete the oldLd.
  1727. AddToDeadInsts(cast<Instruction>(oldLd));
  1728. oldLd->replaceAllUsesWith(newLd);
  1729. matToVecMap[useCall] = newLd;
  1730. } break;
  1731. case HLMatLoadStoreOpcode::ColMatStore:
  1732. case HLMatLoadStoreOpcode::RowMatStore: {
  1733. Value *vecPtr = newGEP;
  1734. Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1735. // Skip the vector version.
  1736. if (matVal->getType()->isVectorTy()) {
  1737. AddToDeadInsts(useCall);
  1738. continue;
  1739. }
  1740. Instruction *matInst = cast<Instruction>(matVal);
  1741. DXASSERT(matToVecMap.count(matInst), "must has vec version");
  1742. Value *vecVal = matToVecMap[matInst];
  1743. Builder.CreateStore(vecVal, vecPtr);
  1744. } break;
  1745. }
  1746. } break;
  1747. case HLOpcodeGroup::HLSubscript: {
  1748. TranslateMatSubscript(matGEP, newGEP, useCall);
  1749. } break;
  1750. default:
  1751. DXASSERT(0, "invalid operation");
  1752. break;
  1753. }
  1754. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1755. // Just replace the src with vec version.
  1756. useInst->setOperand(0, newGEP);
  1757. } else {
  1758. // Must be GEP.
  1759. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1760. TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
  1761. }
  1762. }
  1763. AddToDeadInsts(matGEP);
  1764. }
  1765. void HLMatrixLowerPass::replaceMatWithVec(Instruction *matInst,
  1766. Instruction *vecInst) {
  1767. for (Value::user_iterator user = matInst->user_begin();
  1768. user != matInst->user_end();) {
  1769. Instruction *useInst = cast<Instruction>(*(user++));
  1770. // Skip return here.
  1771. if (isa<ReturnInst>(useInst))
  1772. continue;
  1773. // User must be function call.
  1774. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1775. hlsl::HLOpcodeGroup group =
  1776. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1777. switch (group) {
  1778. case HLOpcodeGroup::HLIntrinsic: {
  1779. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1780. } break;
  1781. case HLOpcodeGroup::HLSelect: {
  1782. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1783. } break;
  1784. case HLOpcodeGroup::HLBinOp: {
  1785. TrivialMatBinOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1786. } break;
  1787. case HLOpcodeGroup::HLUnOp: {
  1788. TrivialMatUnOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1789. } break;
  1790. case HLOpcodeGroup::HLCast: {
  1791. TranslateMatCast(cast<CallInst>(matInst), vecInst, useCall);
  1792. } break;
  1793. case HLOpcodeGroup::HLMatLoadStore: {
  1794. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1795. Value *vecUser = matToVecMap[useCall];
  1796. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  1797. // Load Already translated in lowerToVec.
  1798. // Store val operand will be set by the val use.
  1799. // Do nothing here.
  1800. } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
  1801. stInst->setOperand(0, vecInst);
  1802. else
  1803. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1804. } break;
  1805. case HLOpcodeGroup::HLSubscript: {
  1806. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst))
  1807. TranslateMatSubscript(AI, vecInst, useCall);
  1808. else
  1809. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1810. } break;
  1811. case HLOpcodeGroup::HLInit: {
  1812. DXASSERT(!isa<AllocaInst>(matInst), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
  1813. TranslateMatInit(useCall);
  1814. } break;
  1815. }
  1816. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1817. // Just replace the src with vec version.
  1818. useInst->setOperand(0, vecInst);
  1819. } else {
  1820. // Must be GEP on mat array alloca.
  1821. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1822. AllocaInst *AI = cast<AllocaInst>(matInst);
  1823. TranslateMatArrayGEP(AI, vecInst, GEP);
  1824. }
  1825. }
  1826. }
  1827. void HLMatrixLowerPass::finalMatTranslation(Instruction *matInst) {
  1828. // Translate matInit.
  1829. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  1830. hlsl::HLOpcodeGroup group =
  1831. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  1832. switch (group) {
  1833. case HLOpcodeGroup::HLInit: {
  1834. TranslateMatInit(CI);
  1835. } break;
  1836. case HLOpcodeGroup::HLSelect: {
  1837. TranslateMatSelect(CI);
  1838. } break;
  1839. default:
  1840. // Skip group already translated.
  1841. break;
  1842. }
  1843. }
  1844. }
  1845. void HLMatrixLowerPass::DeleteDeadInsts() {
  1846. // Delete the matrix version insts.
  1847. for (Instruction *deadInst : m_deadInsts) {
  1848. // Replace with undef and remove it.
  1849. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  1850. deadInst->eraseFromParent();
  1851. }
  1852. m_deadInsts.clear();
  1853. m_inDeadInstsSet.clear();
  1854. }
  1855. static bool OnlyUsedByMatrixLdSt(Value *V) {
  1856. bool onlyLdSt = true;
  1857. for (User *user : V->users()) {
  1858. if (isa<Constant>(user) && user->use_empty())
  1859. continue;
  1860. CallInst *CI = cast<CallInst>(user);
  1861. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  1862. HLOpcodeGroup::HLMatLoadStore)
  1863. continue;
  1864. onlyLdSt = false;
  1865. break;
  1866. }
  1867. return onlyLdSt;
  1868. }
  1869. static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
  1870. if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
  1871. std::vector<Constant *> Elts;
  1872. Type *EltResultTy = AT->getElementType();
  1873. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  1874. Constant *Elt =
  1875. LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
  1876. Elts.emplace_back(Elt);
  1877. }
  1878. return ConstantArray::get(AT, Elts);
  1879. } else {
  1880. // Cast float[row][col] -> float< row * col>.
  1881. // Get float[row][col] from the struct.
  1882. Constant *rows = MA->getAggregateElement((unsigned)0);
  1883. ArrayType *RowAT = cast<ArrayType>(rows->getType());
  1884. std::vector<Constant *> Elts;
  1885. for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
  1886. Constant *row = rows->getAggregateElement(r);
  1887. VectorType *VT = cast<VectorType>(row->getType());
  1888. for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
  1889. Elts.emplace_back(row->getAggregateElement(c));
  1890. }
  1891. }
  1892. return ConstantVector::get(Elts);
  1893. }
  1894. }
  1895. void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  1896. // Lower to array of vector array like float[row * col].
  1897. // It's follow the major of decl.
  1898. // DynamicIndexingVectorToArray will change it to scalar array.
  1899. Type *Ty = GV->getType()->getPointerElementType();
  1900. std::vector<unsigned> arraySizeList;
  1901. while (Ty->isArrayTy()) {
  1902. arraySizeList.push_back(Ty->getArrayNumElements());
  1903. Ty = Ty->getArrayElementType();
  1904. }
  1905. unsigned row, col;
  1906. Type *EltTy = GetMatrixInfo(Ty, col, row);
  1907. Ty = VectorType::get(EltTy, col * row);
  1908. for (auto arraySize = arraySizeList.rbegin();
  1909. arraySize != arraySizeList.rend(); arraySize++)
  1910. Ty = ArrayType::get(Ty, *arraySize);
  1911. Type *VecArrayTy = Ty;
  1912. Constant *OldInitVal = GV->getInitializer();
  1913. Constant *InitVal =
  1914. isa<UndefValue>(OldInitVal)
  1915. ? UndefValue::get(VecArrayTy)
  1916. : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
  1917. bool isConst = GV->isConstant();
  1918. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  1919. unsigned AddressSpace = GV->getType()->getAddressSpace();
  1920. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  1921. Module *M = GV->getParent();
  1922. GlobalVariable *VecGV =
  1923. new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
  1924. /*InitVal*/ InitVal, GV->getName() + ".v",
  1925. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  1926. // Add debug info.
  1927. if (m_HasDbgInfo) {
  1928. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  1929. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  1930. }
  1931. DenseMap<Instruction *, Value *> matToVecMap;
  1932. for (User *U : GV->users()) {
  1933. Value *VecGEP = nullptr;
  1934. // Must be GEP or GEPOperator.
  1935. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  1936. IRBuilder<> Builder(GEP);
  1937. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  1938. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1939. AddToDeadInsts(GEP);
  1940. } else {
  1941. GEPOperator *GEPOP = cast<GEPOperator>(U);
  1942. IRBuilder<> Builder(GV->getContext());
  1943. SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
  1944. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1945. }
  1946. for (auto user = U->user_begin(); user != U->user_end();) {
  1947. CallInst *CI = cast<CallInst>(*(user++));
  1948. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1949. if (group == HLOpcodeGroup::HLMatLoadStore) {
  1950. TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
  1951. } else if (group == HLOpcodeGroup::HLSubscript) {
  1952. TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
  1953. } else {
  1954. DXASSERT(0, "invalid operation");
  1955. }
  1956. }
  1957. }
  1958. DeleteDeadInsts();
  1959. GV->removeDeadConstantUsers();
  1960. GV->eraseFromParent();
  1961. }
  1962. static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
  1963. unsigned row, col;
  1964. Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
  1965. if (isa<UndefValue>(M)) {
  1966. Constant *Elt = UndefValue::get(EltTy);
  1967. for (unsigned i=0;i<col*row;i++)
  1968. Elts.emplace_back(Elt);
  1969. } else {
  1970. M = M->getAggregateElement((unsigned)0);
  1971. // Initializer is already in correct major.
  1972. // Just read it here.
  1973. // The type is vector<element, col>[row].
  1974. for (unsigned r = 0; r < row; r++) {
  1975. Constant *C = M->getAggregateElement(r);
  1976. for (unsigned c = 0; c < col; c++) {
  1977. Elts.emplace_back(C->getAggregateElement(c));
  1978. }
  1979. }
  1980. }
  1981. }
// Lower one global matrix variable. Matrix arrays are delegated to
// runOnGlobalMatrixArray; non-matrix globals are left untouched. A matrix
// global whose only uses are matrix load/store calls is split into one
// scalar global per element; otherwise it is replaced by an addressable
// array of scalars. In both cases the original global is erased.
void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
    runOnGlobalMatrixArray(GV);
    return;
  }

  Type *Ty = GV->getType()->getPointerElementType();
  if (!HLMatrixLower::IsMatrixType(Ty))
    return;

  // Splitting into per-element globals is only valid when every use is a
  // whole-matrix load/store; subscript uses need a real indexable array.
  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  bool isConst = GV->isConstant();
  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  Module *M = GV->getParent();
  const DataLayout &DL = M->getDataLayout();

  std::vector<Constant *> Elts;
  // Lower to vector or array for scalar matrix.
  // Make it col major so don't need shuffle when load/store.
  FlattenMatConst(GV->getInitializer(), Elts);

  if (onlyLdSt) {
    Type *EltTy = vecTy->getVectorElementType();
    unsigned vecSize = vecTy->getVectorNumElements();
    std::vector<Value *> vecGlobals(vecSize);
    // New element globals inherit the original's TLS mode, address space,
    // linkage, and constness.
    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
    unsigned AddressSpace = GV->getType()->getAddressSpace();
    GlobalValue::LinkageTypes linkage = GV->getLinkage();
    unsigned debugOffset = 0;
    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
    unsigned align = DL.getPrefTypeAlignment(EltTy);
    for (int i = 0, e = vecSize; i != e; ++i) {
      Constant *InitVal = Elts[i];
      // Element globals are named "<matrix>.<i>".
      GlobalVariable *EltGV = new llvm::GlobalVariable(
          *M, EltTy, /*IsConstant*/ isConst, linkage,
          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
          /*InsertBefore*/ nullptr, TLMode, AddressSpace);
      // Add debug info.
      if (m_HasDbgInfo) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        // debugOffset tracks this element's bit offset within the original
        // matrix; the ltrim yields the ".<i>" suffix as the element name.
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset,
            EltGV->getName().ltrim(GV->getName()));
        debugOffset += size;
      }
      vecGlobals[i] = EltGV;
    }
    for (User *user : GV->users()) {
      // Skip dead constant users (e.g. leftover constant expressions);
      // everything else must be a matrix load/store call.
      if (isa<Constant>(user) && user->use_empty())
        continue;
      CallInst *CI = cast<CallInst>(user);
      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
      // Deletion is deferred so iteration over users stays valid.
      AddToDeadInsts(CI);
    }
    DeleteDeadInsts();
    GV->eraseFromParent();
  } else {
    // lower to array of scalar here.
    // NOTE(review): the array is created non-constant and with internal
    // linkage regardless of the original global's attributes — presumably
    // because subscript stores may write through it; confirm.
    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(),
                                   vecTy->getVectorNumElements());
    Constant *InitVal = ConstantArray::get(AT, Elts);
    GlobalVariable *arrayMat = new llvm::GlobalVariable(
        *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
        /*InitVal*/ InitVal, GV->getName());
    // Add debug info.
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, arrayMat);
    }
    // Advance the iterator before translating: translation may remove the
    // current use from GV's user list.
    for (auto U = GV->user_begin(); U != GV->user_end();) {
      Value *user = *(U++);
      CallInst *CI = cast<CallInst>(user);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
      } else {
        DXASSERT(group == HLOpcodeGroup::HLSubscript,
                 "Must be subscript operation");
        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
      }
    }
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
// Lower all matrix instructions in F to their vector form, in three phases:
// (1) create vector versions (with undef matrix operands) and record them in
// matToVecMap, (2) rewrite uses of each matrix instruction to its vector
// counterpart, (3) finalize translations that need all operands lowered,
// then delete the dead matrix-form instructions.
void HLMatrixLowerPass::runOnFunction(Function &F) {
  // Create vector version of matrix instructions first.
  // The matrix operands will be undefval for these instructions.
  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
    BasicBlock *BB = BBI;
    for (Instruction &I : BB->getInstList()) {
      if (IsMatrixType(I.getType())) {
        // Instruction produces a matrix value directly.
        lowerToVec(&I);
      } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
        // Matrix (or matrix-array) allocas get vector-typed replacements.
        Type *Ty = AI->getAllocatedType();
        if (HLMatrixLower::IsMatrixType(Ty)) {
          lowerToVec(&I);
        } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
          lowerToVec(&I);
        }
      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
        HLOpcodeGroup group =
            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
        if (group == HLOpcodeGroup::HLMatLoadStore) {
          // Stores return void, so they are not caught by the
          // IsMatrixType(I.getType()) check above.
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
          DXASSERT(opcode == HLMatLoadStoreOpcode::ColMatStore ||
                       opcode == HLMatLoadStoreOpcode::RowMatStore,
                   "Must MatStore here, load will go IsMatrixType path");
          // Lower it here to make sure it is ready before replace.
          lowerToVec(&I);
        }
      }
    }
  }
  // Update the use of matrix inst with the vector version.
  // The iterator is advanced before use in each loop below — presumably
  // because processing an entry can mutate matToVecMap; confirm before
  // simplifying to a range-for.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  }
  // Translate mat inst which require all operands ready.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    finalMatTranslation(matToVec->first);
  }
  // Delete the matrix version insts.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    // Add to m_deadInsts.
    Instruction *matInst = matToVec->first;
    AddToDeadInsts(matInst);
  }
  DeleteDeadInsts();
  matToVecMap.clear();
}