HLMatrixLowerPass.cpp 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HLSL/DxilUtil.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/HLSL/DxilOperations.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/IR/DebugInfo.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/Transforms/Utils/Local.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include <unordered_set>
  27. #include <vector>
  28. using namespace llvm;
  29. using namespace hlsl;
  30. using namespace hlsl::HLMatrixLower;
  31. namespace hlsl {
  32. namespace HLMatrixLower {
  33. bool IsMatrixType(Type *Ty) {
  34. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  35. Type *EltTy = ST->getElementType(0);
  36. if (!ST->getName().startswith("class.matrix"))
  37. return false;
  38. bool isVecArray = EltTy->isArrayTy() &&
  39. EltTy->getArrayElementType()->isVectorTy();
  40. return isVecArray && EltTy->getArrayNumElements() <= 4;
  41. }
  42. return false;
  43. }
  44. // Translate matrix type to vector type.
  45. Type *LowerMatrixType(Type *Ty) {
  46. // Only translate matrix type and function type which use matrix type.
  47. // Not translate struct has matrix or matrix pointer.
  48. // Struct should be flattened before.
  49. // Pointer could cover by matldst which use vector as value type.
  50. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  51. Type *RetTy = LowerMatrixType(FT->getReturnType());
  52. SmallVector<Type *, 4> params;
  53. for (Type *param : FT->params()) {
  54. params.emplace_back(LowerMatrixType(param));
  55. }
  56. return FunctionType::get(RetTy, params, false);
  57. } else if (IsMatrixType(Ty)) {
  58. unsigned row, col;
  59. Type *EltTy = GetMatrixInfo(Ty, col, row);
  60. return VectorType::get(EltTy, row * col);
  61. } else {
  62. return Ty;
  63. }
  64. }
  65. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  66. DXASSERT(IsMatrixType(Ty), "not matrix type");
  67. StructType *ST = cast<StructType>(Ty);
  68. Type *EltTy = ST->getElementType(0);
  69. Type *RowTy = EltTy->getArrayElementType();
  70. row = EltTy->getArrayNumElements();
  71. col = RowTy->getVectorNumElements();
  72. return RowTy->getVectorElementType();
  73. }
  74. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  75. if (!Ty->isPointerTy())
  76. return false;
  77. Ty = Ty->getPointerElementType();
  78. if (!Ty->isArrayTy())
  79. return false;
  80. while (Ty->isArrayTy())
  81. Ty = Ty->getArrayElementType();
  82. return IsMatrixType(Ty);
  83. }
  84. Type *LowerMatrixArrayPointer(Type *Ty) {
  85. Ty = Ty->getPointerElementType();
  86. std::vector<unsigned> arraySizeList;
  87. while (Ty->isArrayTy()) {
  88. arraySizeList.push_back(Ty->getArrayNumElements());
  89. Ty = Ty->getArrayElementType();
  90. }
  91. Ty = LowerMatrixType(Ty);
  92. for (auto arraySize = arraySizeList.rbegin();
  93. arraySize != arraySizeList.rend(); arraySize++)
  94. Ty = ArrayType::get(Ty, *arraySize);
  95. return PointerType::get(Ty, 0);
  96. }
  97. Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
  98. IRBuilder<> &Builder) {
  99. Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
  100. for (unsigned i = 0; i < size; i++)
  101. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  102. return Vec;
  103. }
  104. Value *LowerGEPOnMatIndexListToIndex(
  105. llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
  106. IRBuilder<> Builder(GEP);
  107. Value *zero = Builder.getInt32(0);
  108. DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
  109. Value *baseIdx = (GEP->idx_begin())->get();
  110. DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
  111. Value *Idx = (GEP->idx_begin() + 1)->get();
  112. if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
  113. return IdxList[immIdx->getSExtValue()];
  114. } else {
  115. IRBuilder<> AllocaBuilder(
  116. GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  117. unsigned size = IdxList.size();
  118. // Store idxList to temp array.
  119. ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
  120. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  121. for (unsigned i = 0; i < size; i++) {
  122. Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
  123. Builder.CreateStore(IdxList[i], EltPtr);
  124. }
  125. // Load the idx.
  126. Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
  127. return Builder.CreateLoad(GEPOffset);
  128. }
  129. }
  130. unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
  131. return c * row + r;
  132. }
  133. unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
  134. return r * col + c;
  135. }
  136. } // namespace HLMatrixLower
  137. } // namespace hlsl
  138. namespace {
  139. class HLMatrixLowerPass : public ModulePass {
  140. public:
  141. static char ID; // Pass identification, replacement for typeid
  142. explicit HLMatrixLowerPass() : ModulePass(ID) {}
  143. const char *getPassName() const override { return "HL matrix lower"; }
  144. bool runOnModule(Module &M) override {
  145. m_pModule = &M;
  146. m_pHLModule = &m_pModule->GetOrCreateHLModule();
  147. // Load up debug information, to cross-reference values and the instructions
  148. // used to load them.
  149. m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  150. for (Function &F : M.functions()) {
  151. if (F.isDeclaration())
  152. continue;
  153. runOnFunction(F);
  154. }
  155. std::vector<GlobalVariable*> staticGVs;
  156. for (GlobalVariable &GV : M.globals()) {
  157. if (dxilutil::IsStaticGlobal(&GV) ||
  158. dxilutil::IsSharedMemoryGlobal(&GV)) {
  159. staticGVs.emplace_back(&GV);
  160. }
  161. }
  162. for (GlobalVariable *GV : staticGVs)
  163. runOnGlobal(GV);
  164. return true;
  165. }
  166. private:
  167. Module *m_pModule;
  168. HLModule *m_pHLModule;
  169. bool m_HasDbgInfo;
  170. std::vector<Instruction *> m_deadInsts;
  171. // For instruction like matrix array init.
  172. // May use more than 1 matrix alloca inst.
  173. // This set is here to avoid put it into deadInsts more than once.
  174. std::unordered_set<Instruction *> m_inDeadInstsSet;
  175. // For most matrix insturction users, it will only have one matrix use.
  176. // Use vector so save deadInsts because vector is cheap.
  177. void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  178. // In case instruction has more than one matrix use.
  179. // Use AddToDeadInstsWithDups to make sure it's not add to deadInsts more than once.
  180. void AddToDeadInstsWithDups(Instruction *I) {
  181. if (m_inDeadInstsSet.count(I) == 0) {
  182. // Only add to deadInsts when it's not inside m_inDeadInstsSet.
  183. m_inDeadInstsSet.insert(I);
  184. AddToDeadInsts(I);
  185. }
  186. }
  187. void runOnFunction(Function &F);
  188. void runOnGlobal(GlobalVariable *GV);
  189. void runOnGlobalMatrixArray(GlobalVariable *GV);
  190. Instruction *MatCastToVec(CallInst *CI);
  191. Instruction *MatLdStToVec(CallInst *CI);
  192. Instruction *MatSubscriptToVec(CallInst *CI);
  193. Instruction *MatFrExpToVec(CallInst *CI);
  194. Instruction *MatIntrinsicToVec(CallInst *CI);
  195. Instruction *TrivialMatUnOpToVec(CallInst *CI);
  196. // Replace matInst with vecInst on matUseInst.
  197. void TrivialMatUnOpReplace(CallInst *matInst, Instruction *vecInst,
  198. CallInst *matUseInst);
  199. Instruction *TrivialMatBinOpToVec(CallInst *CI);
  200. // Replace matInst with vecInst on matUseInst.
  201. void TrivialMatBinOpReplace(CallInst *matInst, Instruction *vecInst,
  202. CallInst *matUseInst);
  203. // Replace matInst with vecInst on mulInst.
  204. void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
  205. CallInst *mulInst, bool isSigned);
  206. void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
  207. CallInst *mulInst, bool isSigned);
  208. void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
  209. CallInst *mulInst, bool isSigned);
  210. void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
  211. bool isSigned);
  212. // Replace matInst with vecInst on transposeInst.
  213. void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
  214. CallInst *transposeInst);
  215. void TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  216. CallInst *determinantInst);
  217. void MatIntrinsicReplace(CallInst *matInst, Instruction *vecInst,
  218. CallInst *matUseInst);
  219. // Replace matInst with vecInst on castInst.
  220. void TranslateMatMatCast(CallInst *matInst, Instruction *vecInst,
  221. CallInst *castInst);
  222. void TranslateMatToOtherCast(CallInst *matInst, Instruction *vecInst,
  223. CallInst *castInst);
  224. void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
  225. CallInst *castInst);
  226. void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
  227. CallInst *castInst, bool rowToCol);
  228. // Replace matInst with vecInst in matSubscript
  229. void TranslateMatSubscript(Value *matInst, Value *vecInst,
  230. CallInst *matSubInst);
  231. // Replace matInst with vecInst
  232. void TranslateMatInit(CallInst *matInitInst);
  233. // Replace matInst with vecInst.
  234. void TranslateMatSelect(CallInst *matSelectInst);
  235. // Replace matInst with vecInst on matInitInst.
  236. void TranslateMatArrayGEP(Value *matInst, Instruction *vecInst,
  237. GetElementPtrInst *matGEP);
  238. void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
  239. CallInst *matLdStInst);
  240. void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
  241. CallInst *matLdStInst);
  242. void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  243. void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  244. // Replace matInst with vecInst on matUseInst.
  245. void TrivialMatReplace(CallInst *matInst, Instruction *vecInst,
  246. CallInst *matUseInst);
  247. // Lower a matrix type instruction to a vector type instruction.
  248. void lowerToVec(Instruction *matInst);
  249. // Lower users of a matrix type instruction.
  250. void replaceMatWithVec(Instruction *matInst, Instruction *vecInst);
  251. // Translate mat inst which need all operands ready.
  252. void finalMatTranslation(Instruction *matInst);
  253. // Delete dead insts in m_deadInsts.
  254. void DeleteDeadInsts();
  255. // Map from matrix inst to its vector version.
  256. DenseMap<Instruction *, Value *> matToVecMap;
  257. };
  258. }
  259. char HLMatrixLowerPass::ID = 0;
  260. ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
  261. INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  262. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  263. IRBuilder<> Builder) {
  264. // Cast to bool.
  265. if (toTy->getScalarType()->isIntegerTy() &&
  266. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  267. Type *fromTy = src->getType();
  268. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  269. Constant *zero;
  270. if (isFloat)
  271. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  272. else
  273. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  274. if (toTy->getScalarType() != toTy) {
  275. // Create constant vector.
  276. unsigned size = toTy->getVectorNumElements();
  277. std::vector<Constant *> zeros(size, zero);
  278. zero = llvm::ConstantVector::get(zeros);
  279. }
  280. if (isFloat)
  281. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  282. else
  283. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  284. }
  285. Type *eltToTy = toTy->getScalarType();
  286. Type *eltFromTy = src->getType()->getScalarType();
  287. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  288. castOp == HLCastOpcode::UnsignedUnsignedCast;
  289. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  290. castOp == HLCastOpcode::UnsignedUnsignedCast;
  291. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  292. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  293. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  294. }
  295. Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  296. IRBuilder<> Builder(CI);
  297. Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  298. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
  299. bool ToMat = IsMatrixType(CI->getType());
  300. bool FromMat = IsMatrixType(op->getType());
  301. if (ToMat && !FromMat) {
  302. // Translate OtherToMat here.
  303. // Rest will translated when replace.
  304. unsigned col, row;
  305. Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
  306. unsigned toSize = col * row;
  307. Instruction *sizeCast = nullptr;
  308. Type *FromTy = op->getType();
  309. Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
  310. if (FromTy->isVectorTy()) {
  311. std::vector<Constant *> MaskVec(toSize);
  312. for (size_t i = 0; i != toSize; ++i)
  313. MaskVec[i] = ConstantInt::get(I32Ty, i);
  314. Value *castMask = ConstantVector::get(MaskVec);
  315. sizeCast = new ShuffleVectorInst(op, op, castMask);
  316. Builder.Insert(sizeCast);
  317. } else {
  318. op = Builder.CreateInsertElement(
  319. UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
  320. Constant *zero = ConstantInt::get(I32Ty, 0);
  321. std::vector<Constant *> MaskVec(toSize, zero);
  322. Value *castMask = ConstantVector::get(MaskVec);
  323. sizeCast = new ShuffleVectorInst(op, op, castMask);
  324. Builder.Insert(sizeCast);
  325. }
  326. Instruction *typeCast = sizeCast;
  327. if (EltTy != FromTy->getScalarType()) {
  328. typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
  329. sizeCast, Builder);
  330. }
  331. return typeCast;
  332. } else if (FromMat && ToMat) {
  333. if (isa<Argument>(op)) {
  334. // Cast From mat to mat for arugment.
  335. IRBuilder<> Builder(CI);
  336. // Here only lower the return type to vector.
  337. Type *RetTy = LowerMatrixType(CI->getType());
  338. SmallVector<Type *, 4> params;
  339. for (Value *operand : CI->arg_operands()) {
  340. params.emplace_back(operand->getType());
  341. }
  342. Type *FT = FunctionType::get(RetTy, params, false);
  343. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  344. unsigned opcode = GetHLOpcode(CI);
  345. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
  346. group, opcode);
  347. SmallVector<Value *, 4> argList;
  348. for (Value *arg : CI->arg_operands()) {
  349. argList.emplace_back(arg);
  350. }
  351. return Builder.CreateCall(vecF, argList);
  352. }
  353. }
  354. return MatIntrinsicToVec(CI);
  355. }
  356. Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  357. IRBuilder<> Builder(CI);
  358. unsigned opcode = GetHLOpcode(CI);
  359. HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  360. Instruction *result = nullptr;
  361. switch (matOpcode) {
  362. case HLMatLoadStoreOpcode::ColMatLoad:
  363. case HLMatLoadStoreOpcode::RowMatLoad: {
  364. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
  365. if (isa<AllocaInst>(matPtr)) {
  366. Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
  367. result = Builder.CreateLoad(vecPtr);
  368. } else
  369. result = MatIntrinsicToVec(CI);
  370. } break;
  371. case HLMatLoadStoreOpcode::ColMatStore:
  372. case HLMatLoadStoreOpcode::RowMatStore: {
  373. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
  374. if (isa<AllocaInst>(matPtr)) {
  375. Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
  376. Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  377. Value *vecVal =
  378. UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
  379. result = Builder.CreateStore(vecVal, vecPtr);
  380. } else
  381. result = MatIntrinsicToVec(CI);
  382. } break;
  383. }
  384. return result;
  385. }
  386. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  387. IRBuilder<> Builder(CI);
  388. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  389. if (isa<AllocaInst>(matPtr)) {
  390. // Here just create a new matSub call which use vec ptr.
  391. // Later in TranslateMatSubscript will do the real translation.
  392. std::vector<Value *> args(CI->getNumArgOperands());
  393. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  394. args[i] = CI->getArgOperand(i);
  395. }
  396. // Change mat ptr into vec ptr.
  397. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  398. matToVecMap[cast<Instruction>(matPtr)];
  399. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  400. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  401. paramTyList[i] = args[i]->getType();
  402. }
  403. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  404. unsigned opcode = GetHLOpcode(CI);
  405. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  406. return Builder.CreateCall(opFunc, args);
  407. } else
  408. return MatIntrinsicToVec(CI);
  409. }
  410. Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
  411. IRBuilder<> Builder(CI);
  412. FunctionType *FT = CI->getCalledFunction()->getFunctionType();
  413. Type *RetTy = LowerMatrixType(FT->getReturnType());
  414. SmallVector<Type *, 4> params;
  415. for (Type *param : FT->params()) {
  416. if (!param->isPointerTy()) {
  417. params.emplace_back(LowerMatrixType(param));
  418. } else {
  419. // Lower pointer type for frexp.
  420. Type *EltTy = LowerMatrixType(param->getPointerElementType());
  421. params.emplace_back(
  422. PointerType::get(EltTy, param->getPointerAddressSpace()));
  423. }
  424. }
  425. Type *VecFT = FunctionType::get(RetTy, params, false);
  426. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  427. Function *vecF =
  428. GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
  429. static_cast<unsigned>(IntrinsicOp::IOP_frexp));
  430. SmallVector<Value *, 4> argList;
  431. auto paramTyIt = params.begin();
  432. for (Value *arg : CI->arg_operands()) {
  433. Type *Ty = arg->getType();
  434. Type *ParamTy = *(paramTyIt++);
  435. if (Ty != ParamTy)
  436. argList.emplace_back(UndefValue::get(ParamTy));
  437. else
  438. argList.emplace_back(arg);
  439. }
  440. return Builder.CreateCall(vecF, argList);
  441. }
  442. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  443. IRBuilder<> Builder(CI);
  444. unsigned opcode = GetHLOpcode(CI);
  445. if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
  446. return MatFrExpToVec(CI);
  447. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  448. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  449. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  450. SmallVector<Value *, 4> argList;
  451. for (Value *arg : CI->arg_operands()) {
  452. Type *Ty = arg->getType();
  453. if (IsMatrixType(Ty)) {
  454. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  455. } else
  456. argList.emplace_back(arg);
  457. }
  458. return Builder.CreateCall(vecF, argList);
  459. }
  460. static Value *VectorizeScalarOp(Value *op, Type *dstTy, IRBuilder<> &Builder) {
  461. if (op->getType() == dstTy)
  462. return op;
  463. op = Builder.CreateInsertElement(
  464. UndefValue::get(VectorType::get(op->getType(), 1)), op, (uint64_t)0);
  465. Type *I32Ty = IntegerType::get(dstTy->getContext(), 32);
  466. Constant *zero = ConstantInt::get(I32Ty, 0);
  467. std::vector<Constant *> MaskVec(dstTy->getVectorNumElements(), zero);
  468. Value *castMask = ConstantVector::get(MaskVec);
  469. Value *vecOp = new ShuffleVectorInst(op, op, castMask);
  470. Builder.Insert(cast<Instruction>(vecOp));
  471. return vecOp;
  472. }
  473. Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  474. Type *ResultTy = LowerMatrixType(CI->getType());
  475. UndefValue *tmp = UndefValue::get(ResultTy);
  476. IRBuilder<> Builder(CI);
  477. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  478. bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  479. auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
  480. Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
  481. : ConstantInt::get(Ty->getScalarType(), v);
  482. std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
  483. return ConstantVector::get(vals);
  484. };
  485. Constant *one = GetVecConst(ResultTy, 1);
  486. Instruction *Result = nullptr;
  487. switch (opcode) {
  488. case HLUnaryOpcode::Minus: {
  489. Constant *zero = GetVecConst(ResultTy, 0);
  490. if (isFloat)
  491. Result = BinaryOperator::CreateFSub(zero, tmp);
  492. else
  493. Result = BinaryOperator::CreateSub(zero, tmp);
  494. } break;
  495. case HLUnaryOpcode::LNot: {
  496. Constant *zero = GetVecConst(ResultTy, 0);
  497. if (isFloat)
  498. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
  499. else
  500. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  501. } break;
  502. case HLUnaryOpcode::Not:
  503. Result = BinaryOperator::CreateXor(tmp, tmp);
  504. break;
  505. case HLUnaryOpcode::PostInc:
  506. case HLUnaryOpcode::PreInc:
  507. if (isFloat)
  508. Result = BinaryOperator::CreateFAdd(tmp, one);
  509. else
  510. Result = BinaryOperator::CreateAdd(tmp, one);
  511. break;
  512. case HLUnaryOpcode::PostDec:
  513. case HLUnaryOpcode::PreDec:
  514. if (isFloat)
  515. Result = BinaryOperator::CreateFSub(tmp, one);
  516. else
  517. Result = BinaryOperator::CreateSub(tmp, one);
  518. break;
  519. default:
  520. DXASSERT(0, "not implement");
  521. return nullptr;
  522. }
  523. Builder.Insert(Result);
  524. return Result;
  525. }
  526. Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  527. Type *ResultTy = LowerMatrixType(CI->getType());
  528. IRBuilder<> Builder(CI);
  529. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  530. Type *OpTy = LowerMatrixType(
  531. CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  532. UndefValue *tmp = UndefValue::get(OpTy);
  533. bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  534. Instruction *Result = nullptr;
  535. switch (opcode) {
  536. case HLBinaryOpcode::Add:
  537. if (isFloat)
  538. Result = BinaryOperator::CreateFAdd(tmp, tmp);
  539. else
  540. Result = BinaryOperator::CreateAdd(tmp, tmp);
  541. break;
  542. case HLBinaryOpcode::Sub:
  543. if (isFloat)
  544. Result = BinaryOperator::CreateFSub(tmp, tmp);
  545. else
  546. Result = BinaryOperator::CreateSub(tmp, tmp);
  547. break;
  548. case HLBinaryOpcode::Mul:
  549. if (isFloat)
  550. Result = BinaryOperator::CreateFMul(tmp, tmp);
  551. else
  552. Result = BinaryOperator::CreateMul(tmp, tmp);
  553. break;
  554. case HLBinaryOpcode::Div:
  555. if (isFloat)
  556. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  557. else
  558. Result = BinaryOperator::CreateSDiv(tmp, tmp);
  559. break;
  560. case HLBinaryOpcode::Rem:
  561. if (isFloat)
  562. Result = BinaryOperator::CreateFRem(tmp, tmp);
  563. else
  564. Result = BinaryOperator::CreateSRem(tmp, tmp);
  565. break;
  566. case HLBinaryOpcode::And:
  567. Result = BinaryOperator::CreateAnd(tmp, tmp);
  568. break;
  569. case HLBinaryOpcode::Or:
  570. Result = BinaryOperator::CreateOr(tmp, tmp);
  571. break;
  572. case HLBinaryOpcode::Xor:
  573. Result = BinaryOperator::CreateXor(tmp, tmp);
  574. break;
  575. case HLBinaryOpcode::Shl: {
  576. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  577. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  578. Result = BinaryOperator::CreateShl(tmp, tmp);
  579. } break;
  580. case HLBinaryOpcode::Shr: {
  581. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  582. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  583. Result = BinaryOperator::CreateAShr(tmp, tmp);
  584. } break;
  585. case HLBinaryOpcode::LT:
  586. if (isFloat)
  587. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
  588. else
  589. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
  590. break;
  591. case HLBinaryOpcode::GT:
  592. if (isFloat)
  593. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
  594. else
  595. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
  596. break;
  597. case HLBinaryOpcode::LE:
  598. if (isFloat)
  599. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
  600. else
  601. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
  602. break;
  603. case HLBinaryOpcode::GE:
  604. if (isFloat)
  605. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
  606. else
  607. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
  608. break;
  609. case HLBinaryOpcode::EQ:
  610. if (isFloat)
  611. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
  612. else
  613. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
  614. break;
  615. case HLBinaryOpcode::NE:
  616. if (isFloat)
  617. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
  618. else
  619. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
  620. break;
  621. case HLBinaryOpcode::UDiv:
  622. Result = BinaryOperator::CreateUDiv(tmp, tmp);
  623. break;
  624. case HLBinaryOpcode::URem:
  625. Result = BinaryOperator::CreateURem(tmp, tmp);
  626. break;
  627. case HLBinaryOpcode::UShr: {
  628. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  629. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  630. Result = BinaryOperator::CreateLShr(tmp, tmp);
  631. } break;
  632. case HLBinaryOpcode::ULT:
  633. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
  634. break;
  635. case HLBinaryOpcode::UGT:
  636. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
  637. break;
  638. case HLBinaryOpcode::ULE:
  639. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
  640. break;
  641. case HLBinaryOpcode::UGE:
  642. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
  643. break;
  644. case HLBinaryOpcode::LAnd:
  645. case HLBinaryOpcode::LOr: {
  646. Constant *zero;
  647. if (isFloat)
  648. zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
  649. else
  650. zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
  651. unsigned size = ResultTy->getVectorNumElements();
  652. std::vector<Constant *> zeros(size, zero);
  653. Value *vecZero = llvm::ConstantVector::get(zeros);
  654. Instruction *cmpL;
  655. if (isFloat)
  656. cmpL =
  657. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  658. else
  659. cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  660. Builder.Insert(cmpL);
  661. Instruction *cmpR;
  662. if (isFloat)
  663. cmpR =
  664. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  665. else
  666. cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  667. Builder.Insert(cmpR);
  668. // How to map l, r back? Need check opcode
  669. if (opcode == HLBinaryOpcode::LOr)
  670. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  671. else
  672. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  673. break;
  674. }
  675. default:
  676. DXASSERT(0, "not implement");
  677. return nullptr;
  678. }
  679. Builder.Insert(Result);
  680. return Result;
  681. }
  682. void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  683. Instruction *vecInst;
  684. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  685. hlsl::HLOpcodeGroup group =
  686. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  687. switch (group) {
  688. case HLOpcodeGroup::HLIntrinsic: {
  689. vecInst = MatIntrinsicToVec(CI);
  690. } break;
  691. case HLOpcodeGroup::HLSelect: {
  692. vecInst = MatIntrinsicToVec(CI);
  693. } break;
  694. case HLOpcodeGroup::HLBinOp: {
  695. vecInst = TrivialMatBinOpToVec(CI);
  696. } break;
  697. case HLOpcodeGroup::HLUnOp: {
  698. vecInst = TrivialMatUnOpToVec(CI);
  699. } break;
  700. case HLOpcodeGroup::HLCast: {
  701. vecInst = MatCastToVec(CI);
  702. } break;
  703. case HLOpcodeGroup::HLInit: {
  704. vecInst = MatIntrinsicToVec(CI);
  705. } break;
  706. case HLOpcodeGroup::HLMatLoadStore: {
  707. vecInst = MatLdStToVec(CI);
  708. } break;
  709. case HLOpcodeGroup::HLSubscript: {
  710. vecInst = MatSubscriptToVec(CI);
  711. } break;
  712. }
  713. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  714. Type *Ty = AI->getAllocatedType();
  715. Type *matTy = Ty;
  716. IRBuilder<> Builder(AI);
  717. if (Ty->isArrayTy()) {
  718. Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
  719. vecTy = vecTy->getPointerElementType();
  720. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  721. } else {
  722. Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
  723. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  724. }
  725. // Update debug info.
  726. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
  727. if (DDI) {
  728. LLVMContext &Context = AI->getContext();
  729. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  730. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  731. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecInst));
  732. IRBuilder<> debugBuilder(DDI);
  733. debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
  734. }
  735. if (HLModule::HasPreciseAttributeWithMetadata(AI))
  736. HLModule::MarkPreciseAttributeWithMetadata(vecInst);
  737. } else {
  738. DXASSERT(0, "invalid inst");
  739. }
  740. matToVecMap[matInst] = vecInst;
  741. }
  742. // Replace matInst with vecInst on matUseInst.
  743. void HLMatrixLowerPass::TrivialMatUnOpReplace(CallInst *matInst,
  744. Instruction *vecInst,
  745. CallInst *matUseInst) {
  746. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  747. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  748. switch (opcode) {
  749. case HLUnaryOpcode::Not:
  750. // Not is xor now
  751. vecUseInst->setOperand(0, vecInst);
  752. vecUseInst->setOperand(1, vecInst);
  753. break;
  754. case HLUnaryOpcode::LNot:
  755. case HLUnaryOpcode::PostInc:
  756. case HLUnaryOpcode::PreInc:
  757. case HLUnaryOpcode::PostDec:
  758. case HLUnaryOpcode::PreDec:
  759. vecUseInst->setOperand(0, vecInst);
  760. break;
  761. }
  762. }
  763. // Replace matInst with vecInst on matUseInst.
  764. void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
  765. Instruction *vecInst,
  766. CallInst *matUseInst) {
  767. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  768. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  769. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  770. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matInst)
  771. vecUseInst->setOperand(0, vecInst);
  772. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matInst)
  773. vecUseInst->setOperand(1, vecInst);
  774. } else {
  775. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  776. matInst) {
  777. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  778. vecCmp->setOperand(0, vecInst);
  779. }
  780. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  781. matInst) {
  782. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  783. vecCmp->setOperand(0, vecInst);
  784. }
  785. }
  786. }
  787. static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
  788. llvm::FunctionType *MadFuncTy =
  789. llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
  790. Function *MAD =
  791. GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
  792. (unsigned)madOp);
  793. return MAD;
  794. }
// Lower mul(matrix, matrix). Expands the product into scalar extract /
// mul / mad / insert IR over the row-major lowered vectors of both operands:
// result[r][c] = sum over lc of L[r][lc] * R[lc][c], computed as one FMul/Mul
// followed by a chain of HL mad intrinsic calls.
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst, bool isSigned) {
  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated: once expanded, the map holds the insertelement chain
  // rather than the placeholder call.
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  // Inner dimensions must agree for the product to exist.
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(mulInst);
  // Lowered (vector) versions of the two matrix operands.
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // First term of a dot product: L[r][lc] * R[lc][c].
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // Subsequent terms: acc = mad(L[r][lc], R[lc][c], acc).
  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                             Value *acc) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      for (lc = 1; lc < col; lc++) {
        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
      }
      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vec mul call with the expanded mul/mad sequence.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
  853. void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
  854. Instruction *vecInst,
  855. CallInst *mulInst, bool isSigned) {
  856. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  857. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  858. unsigned col, row;
  859. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  860. DXASSERT(RVal->getType()->getVectorNumElements() == col, "");
  861. bool isFloat = EltTy->isFloatingPointTy();
  862. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  863. IRBuilder<> Builder(mulInst);
  864. Value *vec = RVal;
  865. Value *mat = vecInst; // vec version of matInst;
  866. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  867. Type *opcodeTy = Builder.getInt32Ty();
  868. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  869. *m_pHLModule->GetModule());
  870. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  871. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  872. Value *vecElt = Builder.CreateExtractElement(vec, c);
  873. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  874. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  875. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  876. };
  877. for (unsigned r = 0; r < row; r++) {
  878. unsigned c = 0;
  879. Value *vecElt = Builder.CreateExtractElement(vec, c);
  880. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  881. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  882. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  883. : Builder.CreateMul(vecElt, matElt);
  884. for (c = 1; c < col; c++) {
  885. tmpVal = CreateOneEltMad(r, c, tmpVal);
  886. }
  887. retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  888. }
  889. mulInst->replaceAllUsesWith(retVal);
  890. AddToDeadInsts(mulInst);
  891. }
  892. void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
  893. Instruction *vecInst,
  894. CallInst *mulInst, bool isSigned) {
  895. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  896. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  897. Value *RVal = vecInst;
  898. unsigned col, row;
  899. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  900. DXASSERT(LVal->getType()->getVectorNumElements() == row, "");
  901. bool isFloat = EltTy->isFloatingPointTy();
  902. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  903. IRBuilder<> Builder(mulInst);
  904. Value *vec = LVal;
  905. Value *mat = RVal;
  906. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  907. Type *opcodeTy = Builder.getInt32Ty();
  908. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  909. *m_pHLModule->GetModule());
  910. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  911. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  912. Value *vecElt = Builder.CreateExtractElement(vec, r);
  913. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  914. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  915. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  916. };
  917. for (unsigned c = 0; c < col; c++) {
  918. unsigned r = 0;
  919. Value *vecElt = Builder.CreateExtractElement(vec, r);
  920. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  921. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  922. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  923. : Builder.CreateMul(vecElt, matElt);
  924. for (r = 1; r < row; r++) {
  925. tmpVal = CreateOneEltMad(r, c, tmpVal);
  926. }
  927. retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  928. }
  929. mulInst->replaceAllUsesWith(retVal);
  930. AddToDeadInsts(mulInst);
  931. }
  932. void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
  933. CallInst *mulInst, bool isSigned) {
  934. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  935. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  936. bool LMat = IsMatrixType(LVal->getType());
  937. bool RMat = IsMatrixType(RVal->getType());
  938. if (LMat && RMat) {
  939. TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
  940. } else if (LMat) {
  941. TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
  942. } else {
  943. TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
  944. }
  945. }
// Lower transpose(mat). The lowered matrix value is stored row major, so a
// transpose is the same element shuffle as a row-major-to-col-major cast.
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                              Instruction *vecInst,
                                              CallInst *transposeInst) {
  // Matrix value is row major, transpose is cast it to col major.
  TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
}
  952. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  953. IRBuilder<> &Builder) {
  954. Value *mul0 = Builder.CreateFMul(m00, m11);
  955. Value *mul1 = Builder.CreateFMul(m01, m10);
  956. return Builder.CreateFSub(mul0, mul1);
  957. }
  958. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  959. Value *m10, Value *m11, Value *m12,
  960. Value *m20, Value *m21, Value *m22,
  961. IRBuilder<> &Builder) {
  962. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  963. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  964. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  965. deter00 = Builder.CreateFMul(m00, deter00);
  966. deter01 = Builder.CreateFMul(m01, deter01);
  967. deter02 = Builder.CreateFMul(m02, deter02);
  968. Value *result = Builder.CreateFSub(deter00, deter01);
  969. result = Builder.CreateFAdd(result, deter02);
  970. return result;
  971. }
  972. static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
  973. Value *m10, Value *m11, Value *m12, Value *m13,
  974. Value *m20, Value *m21, Value *m22, Value *m23,
  975. Value *m30, Value *m31, Value *m32, Value *m33,
  976. IRBuilder<> &Builder) {
  977. Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  978. Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  979. Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  980. Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  981. deter00 = Builder.CreateFMul(m00, deter00);
  982. deter01 = Builder.CreateFMul(m01, deter01);
  983. deter02 = Builder.CreateFMul(m02, deter02);
  984. deter03 = Builder.CreateFMul(m03, deter03);
  985. Value *result = Builder.CreateFSub(deter00, deter01);
  986. result = Builder.CreateFAdd(result, deter02);
  987. result = Builder.CreateFSub(result, deter03);
  988. return result;
  989. }
  990. void HLMatrixLowerPass::TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  991. CallInst *determinantInst) {
  992. unsigned row, col;
  993. GetMatrixInfo(matInst->getType(), col, row);
  994. IRBuilder<> Builder(determinantInst);
  995. // when row == 1, result is vecInst.
  996. Value *Result = vecInst;
  997. if (row == 2) {
  998. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  999. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  1000. Value *m10 = Builder.CreateExtractElement(vecInst, 2);
  1001. Value *m11 = Builder.CreateExtractElement(vecInst, 3);
  1002. Result = Determinant2x2(m00, m01, m10, m11, Builder);
  1003. }
  1004. else if (row == 3) {
  1005. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  1006. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  1007. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  1008. Value *m10 = Builder.CreateExtractElement(vecInst, 3);
  1009. Value *m11 = Builder.CreateExtractElement(vecInst, 4);
  1010. Value *m12 = Builder.CreateExtractElement(vecInst, 5);
  1011. Value *m20 = Builder.CreateExtractElement(vecInst, 6);
  1012. Value *m21 = Builder.CreateExtractElement(vecInst, 7);
  1013. Value *m22 = Builder.CreateExtractElement(vecInst, 8);
  1014. Result = Determinant3x3(m00, m01, m02,
  1015. m10, m11, m12,
  1016. m20, m21, m22, Builder);
  1017. }
  1018. else if (row == 4) {
  1019. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  1020. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  1021. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  1022. Value *m03 = Builder.CreateExtractElement(vecInst, 3);
  1023. Value *m10 = Builder.CreateExtractElement(vecInst, 4);
  1024. Value *m11 = Builder.CreateExtractElement(vecInst, 5);
  1025. Value *m12 = Builder.CreateExtractElement(vecInst, 6);
  1026. Value *m13 = Builder.CreateExtractElement(vecInst, 7);
  1027. Value *m20 = Builder.CreateExtractElement(vecInst, 8);
  1028. Value *m21 = Builder.CreateExtractElement(vecInst, 9);
  1029. Value *m22 = Builder.CreateExtractElement(vecInst, 10);
  1030. Value *m23 = Builder.CreateExtractElement(vecInst, 11);
  1031. Value *m30 = Builder.CreateExtractElement(vecInst, 12);
  1032. Value *m31 = Builder.CreateExtractElement(vecInst, 13);
  1033. Value *m32 = Builder.CreateExtractElement(vecInst, 14);
  1034. Value *m33 = Builder.CreateExtractElement(vecInst, 15);
  1035. Result = Determinant4x4(m00, m01, m02, m03,
  1036. m10, m11, m12, m13,
  1037. m20, m21, m22, m23,
  1038. m30, m31, m32, m33,
  1039. Builder);
  1040. } else {
  1041. DXASSERT(row == 1, "invalid matrix type");
  1042. Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  1043. }
  1044. determinantInst->replaceAllUsesWith(Result);
  1045. AddToDeadInsts(determinantInst);
  1046. }
  1047. void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
  1048. Instruction *vecInst,
  1049. CallInst *matUseInst) {
  1050. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1051. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  1052. if (matUseInst->getArgOperand(i) == matInst) {
  1053. vecUseInst->setArgOperand(i, vecInst);
  1054. }
  1055. }
  1056. void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
  1057. Instruction *vecInst,
  1058. CallInst *castInst,
  1059. bool bRowToCol) {
  1060. unsigned col, row;
  1061. GetMatrixInfo(castInst->getType(), col, row);
  1062. DXASSERT(castInst->getType() == matInst->getType(), "type must match");
  1063. IRBuilder<> Builder(castInst);
  1064. // shuf to change major.
  1065. SmallVector<int, 16> castMask(col * row);
  1066. unsigned idx = 0;
  1067. if (bRowToCol) {
  1068. for (unsigned c = 0; c < col; c++)
  1069. for (unsigned r = 0; r < row; r++) {
  1070. unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  1071. castMask[idx++] = matIdx;
  1072. }
  1073. } else {
  1074. for (unsigned r = 0; r < row; r++)
  1075. for (unsigned c = 0; c < col; c++) {
  1076. unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
  1077. castMask[idx++] = matIdx;
  1078. }
  1079. }
  1080. Instruction *vecCast = cast<Instruction>(
  1081. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1082. // Replace vec cast function call with vecCast.
  1083. DXASSERT(matToVecMap.count(castInst), "must has vec version");
  1084. Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  1085. vecUseInst->replaceAllUsesWith(vecCast);
  1086. AddToDeadInsts(vecUseInst);
  1087. matToVecMap[castInst] = vecCast;
  1088. }
// Lower a matrix-to-matrix cast. Equal element counts are lowered to a pure
// element-type cast; a shrinking cast first shuffles out the kept elements
// (top-left toRow x toCol sub-matrix, row-major order), then converts the
// element type if it differs. Growing casts are not supported.
void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
                                            Instruction *vecInst,
                                            CallInst *castInst) {
  unsigned toCol, toRow;
  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  unsigned fromCol, fromRow;
  Type *FromEltTy = GetMatrixInfo(matInst->getType(), fromCol, fromRow);
  unsigned fromSize = fromCol * fromRow;
  unsigned toSize = toCol * toRow;
  DXASSERT(fromSize >= toSize, "cannot extend matrix");
  IRBuilder<> Builder(castInst);
  Instruction *vecCast = nullptr;
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  if (fromSize == toSize) {
    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecInst,
                             Builder);
  } else {
    // shuf first
    // Mask selects element (r, c) of the source, indexed with the SOURCE
    // column count, so the kept sub-matrix stays row-major contiguous.
    std::vector<int> castMask(toCol * toRow);
    unsigned idx = 0;
    for (unsigned r = 0; r < toRow; r++)
      for (unsigned c = 0; c < toCol; c++) {
        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
        castMask[idx++] = matIdx;
      }
    Instruction *shuf = cast<Instruction>(
        Builder.CreateShuffleVector(vecInst, vecInst, castMask));
    if (ToEltTy != FromEltTy)
      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
                               Builder);
    else
      vecCast = shuf;
  }
  // Replace vec cast function call with vecCast.
  DXASSERT(matToVecMap.count(castInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
  1129. void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
  1130. Instruction *vecInst,
  1131. CallInst *castInst) {
  1132. unsigned col, row;
  1133. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  1134. unsigned fromSize = col * row;
  1135. IRBuilder<> Builder(castInst);
  1136. Instruction *sizeCast = nullptr;
  1137. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1138. Type *ToTy = castInst->getType();
  1139. if (ToTy->isVectorTy()) {
  1140. unsigned toSize = ToTy->getVectorNumElements();
  1141. if (fromSize != toSize) {
  1142. std::vector<int> castMask(fromSize);
  1143. for (unsigned c = 0; c < toSize; c++)
  1144. castMask[c] = c;
  1145. sizeCast = cast<Instruction>(
  1146. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1147. } else
  1148. sizeCast = vecInst;
  1149. } else {
  1150. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1151. sizeCast =
  1152. cast<Instruction>(Builder.CreateExtractElement(vecInst, (uint64_t)0));
  1153. }
  1154. Instruction *typeCast = sizeCast;
  1155. if (EltTy != ToTy->getScalarType()) {
  1156. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1157. }
  1158. // Replace cast function call with typeCast.
  1159. castInst->replaceAllUsesWith(typeCast);
  1160. AddToDeadInsts(castInst);
  1161. }
  1162. void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
  1163. Instruction *vecInst,
  1164. CallInst *castInst) {
  1165. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1166. if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
  1167. opcode == HLCastOpcode::RowMatrixToColMatrix) {
  1168. TranslateMatMajorCast(matInst, vecInst, castInst,
  1169. opcode == HLCastOpcode::RowMatrixToColMatrix);
  1170. } else {
  1171. bool ToMat = IsMatrixType(castInst->getType());
  1172. bool FromMat = IsMatrixType(matInst->getType());
  1173. if (ToMat && FromMat) {
  1174. TranslateMatMatCast(matInst, vecInst, castInst);
  1175. } else if (FromMat)
  1176. TranslateMatToOtherCast(matInst, vecInst, castInst);
  1177. else
  1178. DXASSERT(0, "Not translate as user of matInst");
  1179. }
  1180. }
  1181. void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
  1182. Instruction *vecInst,
  1183. CallInst *matUseInst) {
  1184. IRBuilder<> Builder(matUseInst);
  1185. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1186. switch (opcode) {
  1187. case IntrinsicOp::IOP_umul:
  1188. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
  1189. break;
  1190. case IntrinsicOp::IOP_mul:
  1191. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
  1192. break;
  1193. case IntrinsicOp::IOP_transpose:
  1194. TranslateMatTranspose(matInst, vecInst, matUseInst);
  1195. break;
  1196. case IntrinsicOp::IOP_determinant:
  1197. TranslateMatDeterminant(matInst, vecInst, matUseInst);
  1198. break;
  1199. default:
  1200. CallInst *vecUseInst = nullptr;
  1201. if (matToVecMap.count(matUseInst))
  1202. vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1203. for (unsigned i = 0; i < matInst->getNumArgOperands(); i++)
  1204. if (matUseInst->getArgOperand(i) == matInst) {
  1205. if (vecUseInst)
  1206. vecUseInst->setArgOperand(i, vecInst);
  1207. else
  1208. matUseInst->setArgOperand(i, vecInst);
  1209. }
  1210. break;
  1211. }
  1212. }
  1213. void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
  1214. CallInst *matSubInst) {
  1215. unsigned opcode = GetHLOpcode(matSubInst);
  1216. HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  1217. assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
  1218. "matrix don't use default subscript");
  1219. Type *matType = matInst->getType()->getPointerElementType();
  1220. unsigned col, row;
  1221. Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  1222. bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
  1223. (matOpcode == HLSubscriptOpcode::RowMatElement);
  1224. Value *mask =
  1225. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  1226. if (isElement) {
  1227. Type *resultType = matSubInst->getType()->getPointerElementType();
  1228. unsigned resultSize = 1;
  1229. if (resultType->isVectorTy())
  1230. resultSize = resultType->getVectorNumElements();
  1231. std::vector<int> shufMask(resultSize);
  1232. Constant *EltIdxs = cast<Constant>(mask);
  1233. for (unsigned i = 0; i < resultSize; i++) {
  1234. shufMask[i] =
  1235. cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
  1236. }
  1237. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1238. CallE = matSubInst->use_end();
  1239. CallUI != CallE;) {
  1240. Use &CallUse = *CallUI++;
  1241. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1242. IRBuilder<> Builder(CallUser);
  1243. Value *vecLd = Builder.CreateLoad(vecInst);
  1244. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1245. if (resultSize > 1) {
  1246. Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
  1247. ld->replaceAllUsesWith(shuf);
  1248. } else {
  1249. Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
  1250. ld->replaceAllUsesWith(elt);
  1251. }
  1252. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1253. Value *val = st->getValueOperand();
  1254. if (resultSize > 1) {
  1255. for (unsigned i = 0; i < shufMask.size(); i++) {
  1256. unsigned idx = shufMask[i];
  1257. Value *valElt = Builder.CreateExtractElement(val, i);
  1258. vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
  1259. }
  1260. Builder.CreateStore(vecLd, vecInst);
  1261. } else {
  1262. vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
  1263. Builder.CreateStore(vecLd, vecInst);
  1264. }
  1265. } else
  1266. DXASSERT(0, "matrix element should only used by load/store.");
  1267. AddToDeadInsts(CallUser);
  1268. }
  1269. } else {
  1270. // Subscript.
  1271. // Return a row.
  1272. // Use insertElement and extractElement.
  1273. ArrayType *AT = ArrayType::get(EltTy, col*row);
  1274. IRBuilder<> AllocaBuilder(
  1275. matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  1276. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  1277. Value *zero = AllocaBuilder.getInt32(0);
  1278. bool isDynamicIndexing = !isa<ConstantInt>(mask);
  1279. SmallVector<Value *, 4> idxList;
  1280. for (unsigned i = 0; i < col; i++) {
  1281. idxList.emplace_back(
  1282. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
  1283. }
  1284. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1285. CallE = matSubInst->use_end();
  1286. CallUI != CallE;) {
  1287. Use &CallUse = *CallUI++;
  1288. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1289. IRBuilder<> Builder(CallUser);
  1290. Value *vecLd = Builder.CreateLoad(vecInst);
  1291. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1292. Value *sub = UndefValue::get(ld->getType());
  1293. if (!isDynamicIndexing) {
  1294. for (unsigned i = 0; i < col; i++) {
  1295. Value *matIdx = idxList[i];
  1296. Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
  1297. sub = Builder.CreateInsertElement(sub, valElt, i);
  1298. }
  1299. } else {
  1300. // Copy vec to array.
  1301. for (unsigned int i = 0; i < row*col; i++) {
  1302. Value *Elt =
  1303. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1304. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1305. {zero, Builder.getInt32(i)});
  1306. Builder.CreateStore(Elt, Ptr);
  1307. }
  1308. for (unsigned i = 0; i < col; i++) {
  1309. Value *matIdx = idxList[i];
  1310. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1311. Value *valElt = Builder.CreateLoad(Ptr);
  1312. sub = Builder.CreateInsertElement(sub, valElt, i);
  1313. }
  1314. }
  1315. ld->replaceAllUsesWith(sub);
  1316. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1317. Value *val = st->getValueOperand();
  1318. if (!isDynamicIndexing) {
  1319. for (unsigned i = 0; i < col; i++) {
  1320. Value *matIdx = idxList[i];
  1321. Value *valElt = Builder.CreateExtractElement(val, i);
  1322. vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
  1323. }
  1324. } else {
  1325. // Copy vec to array.
  1326. for (unsigned int i = 0; i < row * col; i++) {
  1327. Value *Elt =
  1328. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1329. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1330. {zero, Builder.getInt32(i)});
  1331. Builder.CreateStore(Elt, Ptr);
  1332. }
  1333. // Update array.
  1334. for (unsigned i = 0; i < col; i++) {
  1335. Value *matIdx = idxList[i];
  1336. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1337. Value *valElt = Builder.CreateExtractElement(val, i);
  1338. Builder.CreateStore(valElt, Ptr);
  1339. }
  1340. // Copy array to vec.
  1341. for (unsigned int i = 0; i < row * col; i++) {
  1342. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1343. {zero, Builder.getInt32(i)});
  1344. Value *Elt = Builder.CreateLoad(Ptr);
  1345. vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
  1346. }
  1347. }
  1348. Builder.CreateStore(vecLd, vecInst);
  1349. } else if (GetElementPtrInst *GEP =
  1350. dyn_cast<GetElementPtrInst>(CallUser)) {
  1351. Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
  1352. Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
  1353. GEP->replaceAllUsesWith(NewGEP);
  1354. } else
  1355. DXASSERT(0, "matrix subscript should only used by load/store.");
  1356. AddToDeadInsts(CallUser);
  1357. }
  1358. }
  1359. // Check vec version.
  1360. DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  1361. // All the user should has been removed.
  1362. matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  1363. AddToDeadInsts(matSubInst);
  1364. }
  1365. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1366. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1367. CallInst *matLdStInst) {
  1368. // No dynamic indexing on matrix, flatten matrix to scalars.
  1369. // vecGlobals already in correct major.
  1370. Type *matType = matGlobal->getType()->getPointerElementType();
  1371. unsigned col, row;
  1372. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1373. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1374. IRBuilder<> Builder(matLdStInst);
  1375. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1376. switch (opcode) {
  1377. case HLMatLoadStoreOpcode::ColMatLoad:
  1378. case HLMatLoadStoreOpcode::RowMatLoad: {
  1379. Value *Result = UndefValue::get(vecType);
  1380. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1381. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1382. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1383. }
  1384. matLdStInst->replaceAllUsesWith(Result);
  1385. } break;
  1386. case HLMatLoadStoreOpcode::ColMatStore:
  1387. case HLMatLoadStoreOpcode::RowMatStore: {
  1388. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1389. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1390. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1391. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1392. }
  1393. } break;
  1394. }
  1395. }
  1396. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
  1397. GlobalVariable *scalarArrayGlobal,
  1398. CallInst *matLdStInst) {
  1399. // vecGlobals already in correct major.
  1400. const bool bColMajor = true;
  1401. HLMatLoadStoreOpcode opcode =
  1402. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1403. switch (opcode) {
  1404. case HLMatLoadStoreOpcode::ColMatLoad:
  1405. case HLMatLoadStoreOpcode::RowMatLoad: {
  1406. IRBuilder<> Builder(matLdStInst);
  1407. Type *matTy = matGlobal->getType()->getPointerElementType();
  1408. unsigned col, row;
  1409. Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1410. Value *zeroIdx = Builder.getInt32(0);
  1411. std::vector<Value *> matElts(col * row);
  1412. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1413. Value *GEP = Builder.CreateInBoundsGEP(
  1414. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1415. matElts[matIdx] = Builder.CreateLoad(GEP);
  1416. }
  1417. Value *newVec =
  1418. HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
  1419. matLdStInst->replaceAllUsesWith(newVec);
  1420. matLdStInst->eraseFromParent();
  1421. } break;
  1422. case HLMatLoadStoreOpcode::ColMatStore:
  1423. case HLMatLoadStoreOpcode::RowMatStore: {
  1424. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1425. IRBuilder<> Builder(matLdStInst);
  1426. Type *matTy = matGlobal->getType()->getPointerElementType();
  1427. unsigned col, row;
  1428. HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1429. Value *zeroIdx = Builder.getInt32(0);
  1430. std::vector<Value *> matElts(col * row);
  1431. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1432. Value *GEP = Builder.CreateInBoundsGEP(
  1433. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1434. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1435. Builder.CreateStore(Elt, GEP);
  1436. }
  1437. matLdStInst->eraseFromParent();
  1438. } break;
  1439. }
  1440. }
// Lower a matrix subscript on a pointer into a global by rewriting every use
// of the subscript into GEP-based scalar accesses on the lowered vector-array
// global (vecPtr). Handles both row subscripts (m[i]) and element accesses
// (m._11 / m[i][j] element forms).
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
    CallInst *matSubInst, Value *vecPtr) {
  Value *basePtr =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // Collect the flat element indices this subscript selects.
  std::vector<Value *> idxList;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Row subscript: one index per column, already computed and passed as
    // extra call operands by EmitHLSLMatrixSubscript.
    for (unsigned c = 0; c < col; c++) {
      Value *matIdx =
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
      idxList.emplace_back(matIdx);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    // Element access: indices arrive as a constant aggregate created in
    // EmitHLSLMatrixElement.
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    Constant *EltIdxs = cast<Constant>(idx);
    for (unsigned i = 0; i < resultSize; i++) {
      Value *matIdx = EltIdxs->getAggregateElement(i);
      idxList.emplace_back(matIdx);
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate a vector pointer over the lowered storage, so replace all
  // uses with scalar pointers instead.
  if (idxList.size() == 1) {
    // Single element selected: one scalar GEP substitutes for the subscript.
    Value *Ptr =
        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
    matSubInst->replaceAllUsesWith(Ptr);
  } else {
    // Multiple elements selected: split each use of the subscript pointer
    // into per-element scalar accesses.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // A GEP on the subscript result picks one element; fold the GEP index
        // through idxList to a single flat index into the lowered array.
        Value *IndexPtr =
            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                   {zeroIdx, IndexPtr});
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
            Value *subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          } else {
            // Only remaining user kind is an addrspace cast; retarget it at
            // the scalar pointer.
            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
            Cast->setOperand(0, Ptr);
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Store of the whole subscripted value: scatter per-element stores.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          userBuilder.CreateStore(Elt, Ptr);
        }
        stUser->eraseFromParent();
      } else {
        // Load of the whole subscripted value: gather per-element loads into
        // a fresh vector.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          Value *Elt = userBuilder.CreateLoad(Ptr);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
  1541. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
  1542. CallInst *matLdStInst, Value *vecPtr) {
  1543. // Just translate into vector here.
  1544. // DynamicIndexingVectorToArray will change it to scalar array.
  1545. IRBuilder<> Builder(matLdStInst);
  1546. unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  1547. HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  1548. switch (matLdStOp) {
  1549. case HLMatLoadStoreOpcode::ColMatLoad:
  1550. case HLMatLoadStoreOpcode::RowMatLoad: {
  1551. // Load as vector.
  1552. Value *newLoad = Builder.CreateLoad(vecPtr);
  1553. matLdStInst->replaceAllUsesWith(newLoad);
  1554. matLdStInst->eraseFromParent();
  1555. } break;
  1556. case HLMatLoadStoreOpcode::ColMatStore:
  1557. case HLMatLoadStoreOpcode::RowMatStore: {
  1558. // Change value to vector array, then store.
  1559. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1560. Value *vecArrayGep = vecPtr;
  1561. Builder.CreateStore(Val, vecArrayGep);
  1562. matLdStInst->eraseFromParent();
  1563. } break;
  1564. default:
  1565. DXASSERT(0, "invalid operation");
  1566. break;
  1567. }
  1568. }
// Flatten values inside an init list to scalar elements, writing them into
// elts starting at idx; idx is advanced past everything written.
// matToVecMap supplies the lowered vector version of matrix temporaries.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Instruction *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      // Prefer the already-lowered vec-array version when one exists.
      if (matToVecMap.count(cast<Instruction>(val))) {
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Loadable pointee: load the value and recurse on it.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointee: recurse into each array element or struct field
      // through a GEP, in declaration order.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    // Matrix value: expand its lowered vector element by element.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all row major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Vector value: copy each lane.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      // Scalar leaf: store as-is.
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
// Store flattened init list elements into matrix array.
// Consumes elts from offset onward; offset is advanced past what was used.
// ptr points at either a (possibly nested) array level or a vector leaf.
static void GenerateMatArrayInit(ArrayRef<Value *> elts, Value *ptr,
                                 unsigned &offset, IRBuilder<> &Builder) {
  Type *Ty = ptr->getType()->getPointerElementType();
  if (Ty->isVectorTy()) {
    // Vector leaf: assemble the next vecSize scalars into a vector, then
    // store it in one shot.
    unsigned vecSize = Ty->getVectorNumElements();
    Type *eltTy = Ty->getVectorElementType();
    Value *result = UndefValue::get(Ty);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *elt = elts[offset + i];
      if (elt->getType() != eltTy) {
        // FIXME: get signed/unsigned info.
        elt = CreateTypeCast(HLCastOpcode::DefaultCast, eltTy, elt, Builder);
      }
      result = Builder.CreateInsertElement(result, elt, i);
    }
    // Update offset.
    offset += vecSize;
    Builder.CreateStore(result, ptr);
  } else {
    // Interior array level: recurse into each element in order, letting
    // offset thread through the recursion.
    DXASSERT(Ty->isArrayTy(), "must be array type");
    Type *i32Ty = Type::getInt32Ty(Ty->getContext());
    Constant *zero = ConstantInt::get(i32Ty, 0);
    unsigned arraySize = Ty->getArrayNumElements();
    for (unsigned i = 0; i < arraySize; i++) {
      Value *GEP =
          Builder.CreateInBoundsGEP(ptr, {zero, ConstantInt::get(i32Ty, i)});
      GenerateMatArrayInit(elts, GEP, offset, Builder);
    }
  }
}
// Lower a (non-array) matrix initializer call: flatten its init-list
// arguments to scalars, build the lowered vector, and replace the placeholder
// vec version of the call.
void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  // Array matrix init will be translated in TranslateMatArrayInitReplace.
  if (matInitInst->getType()->isVoidTy())
    return;
  IRBuilder<> Builder(matInitInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  unsigned idx = 0;
  std::vector<Value *> elts(vecSize);
  // Skip opcode arg.
  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
    Value *val = matInitInst->getArgOperand(i);
    IterateInitList(elts, idx, val, matToVecMap, Builder);
  }
  Value *newInit = UndefValue::get(vecTy);
  // InitList is row major, the result is row major too.
  for (unsigned i=0;i< col * row;i++) {
    Constant *vecIdx = Builder.getInt32(i);
    // NOTE(review): InsertElementInst::Create is used instead of
    // Builder.CreateInsertElement — presumably to guarantee an Instruction
    // result (the builder could constant-fold when all operands are
    // constant), which the cast<Instruction> below relies on. Confirm before
    // changing.
    newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
    Builder.Insert(cast<Instruction>(newInit));
  }
  // Replace matInit function call with matInitInst.
  DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  vecUseInst->replaceAllUsesWith(newInit);
  AddToDeadInsts(vecUseInst);
  matToVecMap[matInitInst] = newInit;
}
// Lower a matrix ternary select into per-element selects over the vector
// versions of its operands, then replace the placeholder vec-version call.
void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  IRBuilder<> Builder(matSelectInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  // Placeholder vector version previously created for this select.
  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  // A vector-typed condition indicates a matrix condition; use its lowered
  // vector version from the map.
  bool isVecCond = Cond->getType()->isVectorTy();
  if (isVecCond) {
    Instruction *MatCond = cast<Instruction>(
        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
    DXASSERT_NOMSG(matToVecMap.count(MatCond));
    Cond = matToVecMap[MatCond];
  }
  DXASSERT_NOMSG(matToVecMap.count(LHS));
  Value *VLHS = matToVecMap[LHS];
  DXASSERT_NOMSG(matToVecMap.count(RHS));
  Value *VRHS = matToVecMap[RHS];
  // Select element by element; a scalar condition is reused for every lane.
  Value *VecSelect = UndefValue::get(vecTy);
  for (unsigned i = 0; i < vecSize; i++) {
    llvm::Value *EltCond = Cond;
    if (isVecCond)
      EltCond = Builder.CreateExtractElement(Cond, i);
    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  }
  AddToDeadInsts(vecUseInst);
  vecUseInst->replaceAllUsesWith(VecSelect);
  matToVecMap[matSelectInst] = VecSelect;
}
// Translate a GEP into a matrix array: build the equivalent GEP over the
// lowered vector version (vecInst) and rewrite every user of the matrix GEP
// against it. Valid users are HL matrix load/store calls, HL subscript calls,
// bitcasts, and nested GEPs (handled recursively).
void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
                                             Instruction *vecInst,
                                             GetElementPtrInst *matGEP) {
  SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  IRBuilder<> GEPBuilder(matGEP);
  Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecInst, idxList);
  // Only used by mat subscript and mat ld/st.
  for (Value::user_iterator user = matGEP->user_begin();
       user != matGEP->user_end();) {
    Instruction *useInst = cast<Instruction>(*(user++));
    IRBuilder<> Builder(useInst);
    // Skip return here.
    if (isa<ReturnInst>(useInst))
      continue;
    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
      // Function call.
      hlsl::HLOpcodeGroup group =
          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
      switch (group) {
      case HLOpcodeGroup::HLMatLoadStore: {
        unsigned opcode = GetHLOpcode(useCall);
        HLMatLoadStoreOpcode matOpcode =
            static_cast<HLMatLoadStoreOpcode>(opcode);
        switch (matOpcode) {
        case HLMatLoadStoreOpcode::ColMatLoad:
        case HLMatLoadStoreOpcode::RowMatLoad: {
          // Skip the vector version.
          if (useCall->getType()->isVectorTy())
            continue;
          // Replace the placeholder vec load with a load through the new GEP.
          Value *newLd = Builder.CreateLoad(newGEP);
          DXASSERT(matToVecMap.count(useCall), "must has vec version");
          Value *oldLd = matToVecMap[useCall];
          // Delete the oldLd.
          AddToDeadInsts(cast<Instruction>(oldLd));
          oldLd->replaceAllUsesWith(newLd);
          matToVecMap[useCall] = newLd;
        } break;
        case HLMatLoadStoreOpcode::ColMatStore:
        case HLMatLoadStoreOpcode::RowMatStore: {
          Value *vecPtr = newGEP;
          Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
          // Skip the vector version.
          if (matVal->getType()->isVectorTy()) {
            AddToDeadInsts(useCall);
            continue;
          }
          // Store the lowered vector value through the new GEP.
          Instruction *matInst = cast<Instruction>(matVal);
          DXASSERT(matToVecMap.count(matInst), "must has vec version");
          Value *vecVal = matToVecMap[matInst];
          Builder.CreateStore(vecVal, vecPtr);
        } break;
        }
      } break;
      case HLOpcodeGroup::HLSubscript: {
        TranslateMatSubscript(matGEP, newGEP, useCall);
      } break;
      default:
        DXASSERT(0, "invalid operation");
        break;
      }
    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
      // Just replace the src with vec version.
      useInst->setOperand(0, newGEP);
    } else {
      // Must be GEP; recurse with the new GEP as the lowered base.
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
      TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
    }
  }
  AddToDeadInsts(matGEP);
}
// Rewrite every user of matInst against its lowered vector version vecInst,
// dispatching on the HL opcode group of each call user. Non-call users may
// only be bitcasts or GEPs on a matrix-array alloca.
void HLMatrixLowerPass::replaceMatWithVec(Instruction *matInst,
                                          Instruction *vecInst) {
  for (Value::user_iterator user = matInst->user_begin();
       user != matInst->user_end();) {
    Instruction *useInst = cast<Instruction>(*(user++));
    // Skip return here.
    if (isa<ReturnInst>(useInst))
      continue;
    // User must be function call.
    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
      hlsl::HLOpcodeGroup group =
          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
      switch (group) {
      case HLOpcodeGroup::HLIntrinsic: {
        if (CallInst *matCI = dyn_cast<CallInst>(matInst)) {
          MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
        } else {
          // Non-call matInst feeding an intrinsic: only frexp's matrix out
          // parameter is expected here (per the assert below).
          IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
          DXASSERT_LOCALVAR(opcode, opcode == IntrinsicOp::IOP_frexp,
                            "otherwise, unexpected opcode with matrix out parameter");
          // NOTE: because out param use copy out semantic, so the operand of
          // out must be temp alloca.
          DXASSERT(isa<AllocaInst>(matInst), "else invalid mat ptr for frexp");
          auto it = matToVecMap.find(useCall);
          DXASSERT(it != matToVecMap.end(),
                   "else fail to create vec version of useCall");
          CallInst *vecUseInst = cast<CallInst>(it->second);
          // Point every matching argument of the vec version at vecInst.
          for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
            if (useCall->getArgOperand(i) == matInst) {
              vecUseInst->setArgOperand(i, vecInst);
            }
          }
        }
      } break;
      case HLOpcodeGroup::HLSelect: {
        MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLBinOp: {
        TrivialMatBinOpReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLUnOp: {
        TrivialMatUnOpReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLCast: {
        TranslateMatCast(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLMatLoadStore: {
        DXASSERT(matToVecMap.count(useCall), "must has vec version");
        Value *vecUser = matToVecMap[useCall];
        if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
          // Load Already translated in lowerToVec.
          // Store val operand will be set by the val use.
          // Do nothing here.
        } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
          stInst->setOperand(0, vecInst);
        else
          TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLSubscript: {
        if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst))
          TranslateMatSubscript(AI, vecInst, useCall);
        else
          TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLInit: {
        DXASSERT(!isa<AllocaInst>(matInst), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
        TranslateMatInit(useCall);
      } break;
      }
    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
      // Just replace the src with vec version.
      useInst->setOperand(0, vecInst);
    } else {
      // Must be GEP on mat array alloca.
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
      AllocaInst *AI = cast<AllocaInst>(matInst);
      TranslateMatArrayGEP(AI, vecInst, GEP);
    }
  }
}
  1881. void HLMatrixLowerPass::finalMatTranslation(Instruction *matInst) {
  1882. // Translate matInit.
  1883. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  1884. hlsl::HLOpcodeGroup group =
  1885. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  1886. switch (group) {
  1887. case HLOpcodeGroup::HLInit: {
  1888. TranslateMatInit(CI);
  1889. } break;
  1890. case HLOpcodeGroup::HLSelect: {
  1891. TranslateMatSelect(CI);
  1892. } break;
  1893. default:
  1894. // Skip group already translated.
  1895. break;
  1896. }
  1897. }
  1898. }
  1899. void HLMatrixLowerPass::DeleteDeadInsts() {
  1900. // Delete the matrix version insts.
  1901. for (Instruction *deadInst : m_deadInsts) {
  1902. // Replace with undef and remove it.
  1903. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  1904. deadInst->eraseFromParent();
  1905. }
  1906. m_deadInsts.clear();
  1907. m_inDeadInstsSet.clear();
  1908. }
  1909. static bool OnlyUsedByMatrixLdSt(Value *V) {
  1910. bool onlyLdSt = true;
  1911. for (User *user : V->users()) {
  1912. if (isa<Constant>(user) && user->use_empty())
  1913. continue;
  1914. CallInst *CI = cast<CallInst>(user);
  1915. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  1916. HLOpcodeGroup::HLMatLoadStore)
  1917. continue;
  1918. onlyLdSt = false;
  1919. break;
  1920. }
  1921. return onlyLdSt;
  1922. }
// Recursively lower a constant matrix-array initializer to match the lowered
// type: array levels are rebuilt element-wise, and each matrix struct
// constant is flattened into a single vector constant.
static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
  if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
    // Still inside the array dimensions: lower each element.
    std::vector<Constant *> Elts;
    Type *EltResultTy = AT->getElementType();
    for (unsigned i = 0; i < AT->getNumElements(); i++) {
      Constant *Elt =
          LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
      Elts.emplace_back(Elt);
    }
    return ConstantArray::get(AT, Elts);
  } else {
    // Cast float[row][col] -> float< row * col>.
    // Get float[row][col] from the struct.
    Constant *rows = MA->getAggregateElement((unsigned)0);
    ArrayType *RowAT = cast<ArrayType>(rows->getType());
    std::vector<Constant *> Elts;
    // Concatenate the rows into one flat element list.
    for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
      Constant *row = rows->getAggregateElement(r);
      VectorType *VT = cast<VectorType>(row->getType());
      for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
        Elts.emplace_back(row->getAggregateElement(c));
      }
    }
    return ConstantVector::get(Elts);
  }
}
  1949. void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  1950. // Lower to array of vector array like float[row * col].
  1951. // It's follow the major of decl.
  1952. // DynamicIndexingVectorToArray will change it to scalar array.
  1953. Type *Ty = GV->getType()->getPointerElementType();
  1954. std::vector<unsigned> arraySizeList;
  1955. while (Ty->isArrayTy()) {
  1956. arraySizeList.push_back(Ty->getArrayNumElements());
  1957. Ty = Ty->getArrayElementType();
  1958. }
  1959. unsigned row, col;
  1960. Type *EltTy = GetMatrixInfo(Ty, col, row);
  1961. Ty = VectorType::get(EltTy, col * row);
  1962. for (auto arraySize = arraySizeList.rbegin();
  1963. arraySize != arraySizeList.rend(); arraySize++)
  1964. Ty = ArrayType::get(Ty, *arraySize);
  1965. Type *VecArrayTy = Ty;
  1966. Constant *OldInitVal = GV->getInitializer();
  1967. Constant *InitVal =
  1968. isa<UndefValue>(OldInitVal)
  1969. ? UndefValue::get(VecArrayTy)
  1970. : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
  1971. bool isConst = GV->isConstant();
  1972. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  1973. unsigned AddressSpace = GV->getType()->getAddressSpace();
  1974. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  1975. Module *M = GV->getParent();
  1976. GlobalVariable *VecGV =
  1977. new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
  1978. /*InitVal*/ InitVal, GV->getName() + ".v",
  1979. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  1980. // Add debug info.
  1981. if (m_HasDbgInfo) {
  1982. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  1983. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  1984. }
  1985. DenseMap<Instruction *, Value *> matToVecMap;
  1986. for (User *U : GV->users()) {
  1987. Value *VecGEP = nullptr;
  1988. // Must be GEP or GEPOperator.
  1989. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  1990. IRBuilder<> Builder(GEP);
  1991. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  1992. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1993. AddToDeadInsts(GEP);
  1994. } else {
  1995. GEPOperator *GEPOP = cast<GEPOperator>(U);
  1996. IRBuilder<> Builder(GV->getContext());
  1997. SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
  1998. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  1999. }
  2000. for (auto user = U->user_begin(); user != U->user_end();) {
  2001. CallInst *CI = cast<CallInst>(*(user++));
  2002. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2003. if (group == HLOpcodeGroup::HLMatLoadStore) {
  2004. TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
  2005. } else if (group == HLOpcodeGroup::HLSubscript) {
  2006. TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
  2007. } else {
  2008. DXASSERT(0, "invalid operation");
  2009. }
  2010. }
  2011. }
  2012. DeleteDeadInsts();
  2013. GV->removeDeadConstantUsers();
  2014. GV->eraseFromParent();
  2015. }
  2016. static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
  2017. unsigned row, col;
  2018. Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
  2019. if (isa<UndefValue>(M)) {
  2020. Constant *Elt = UndefValue::get(EltTy);
  2021. for (unsigned i=0;i<col*row;i++)
  2022. Elts.emplace_back(Elt);
  2023. } else {
  2024. M = M->getAggregateElement((unsigned)0);
  2025. // Initializer is already in correct major.
  2026. // Just read it here.
  2027. // The type is vector<element, col>[row].
  2028. for (unsigned r = 0; r < row; r++) {
  2029. Constant *C = M->getAggregateElement(r);
  2030. for (unsigned c = 0; c < col; c++) {
  2031. Elts.emplace_back(C->getAggregateElement(c));
  2032. }
  2033. }
  2034. }
  2035. }
// Lower a global variable of matrix type. Depending on usage, the matrix is
// split into one scalar global per element (load/store-only users) or a
// single scalar-array global (subscript users need an addressable object).
// The original global is erased in either case.
void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  // Matrix *arrays* take a separate lowering path.
  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
    runOnGlobalMatrixArray(GV);
    return;
  }

  Type *Ty = GV->getType()->getPointerElementType();
  if (!HLMatrixLower::IsMatrixType(Ty))
    return;

  // If every user is a matrix load/store intrinsic the matrix can be split
  // into independent scalar globals; otherwise it must stay a single object.
  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  bool isConst = GV->isConstant();
  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  Module *M = GV->getParent();
  const DataLayout &DL = M->getDataLayout();

  std::vector<Constant *> Elts;
  // Lower to vector or array for scalar matrix.
  // Make it col major so don't need shuffle when load/store.
  // NOTE(review): assumes GV has an initializer — confirm callers guarantee
  // this (getInitializer asserts otherwise).
  FlattenMatConst(GV->getInitializer(), Elts);

  if (onlyLdSt) {
    Type *EltTy = vecTy->getVectorElementType();
    unsigned vecSize = vecTy->getVectorNumElements();
    std::vector<Value *> vecGlobals(vecSize);
    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
    unsigned AddressSpace = GV->getType()->getAddressSpace();
    GlobalValue::LinkageTypes linkage = GV->getLinkage();
    unsigned debugOffset = 0;
    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
    unsigned align = DL.getPrefTypeAlignment(EltTy);
    // Create one scalar global per matrix element, named "<GV>.<i>",
    // preserving constness, linkage, TLS mode and address space.
    for (int i = 0, e = vecSize; i != e; ++i) {
      Constant *InitVal = Elts[i];
      GlobalVariable *EltGV = new llvm::GlobalVariable(
          *M, EltTy, /*IsConstant*/ isConst, linkage,
          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
          /*InsertBefore*/ nullptr, TLMode, AddressSpace);
      // Add debug info: each element global covers `size` bits at
      // `debugOffset` within the original matrix variable.
      if (m_HasDbgInfo) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset,
            EltGV->getName().ltrim(GV->getName()));
        debugOffset += size;
      }
      vecGlobals[i] = EltGV;
    }
    // Rewrite every load/store intrinsic against the new element globals.
    // Dead constant users (with no uses of their own) are skipped.
    for (User *user : GV->users()) {
      if (isa<Constant>(user) && user->use_empty())
        continue;
      CallInst *CI = cast<CallInst>(user);
      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
      AddToDeadInsts(CI);
    }
    DeleteDeadInsts();
    GV->eraseFromParent();
  } else {
    // lower to array of scalar here.
    // NOTE(review): constness and the original linkage are not preserved on
    // this path (always non-const, internal linkage) — presumably because
    // subscript users may store through the pointer; confirm.
    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(),
                                   vecTy->getVectorNumElements());
    Constant *InitVal = ConstantArray::get(AT, Elts);
    GlobalVariable *arrayMat = new llvm::GlobalVariable(
        *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
        /*InitVal*/ InitVal, GV->getName());
    // Add debug info.
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, arrayMat);
    }
    // Advance the iterator before translating: translation replaces the use,
    // which would otherwise invalidate U mid-walk.
    for (auto U = GV->user_begin(); U != GV->user_end();) {
      Value *user = *(U++);
      CallInst *CI = cast<CallInst>(user);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
      } else {
        DXASSERT(group == HLOpcodeGroup::HLSubscript,
                 "Must be subscript operation");
        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
      }
    }
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
// Lower all matrix instructions in F to their vector equivalents, in phases:
//   1. create a vector counterpart for each matrix-typed instruction,
//      matrix alloca, and matrix store intrinsic (operands left undef);
//   2. patch the real vector operands into those counterparts;
//   3. finish translations that need every operand ready;
//   4. erase the original matrix instructions.
void HLMatrixLowerPass::runOnFunction(Function &F) {
  // Create vector version of matrix instructions first.
  // The matrix operands will be undefval for these instructions.
  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
    BasicBlock *BB = BBI;
    for (Instruction &I : BB->getInstList()) {
      if (IsMatrixType(I.getType())) {
        // Any instruction producing a matrix value.
        lowerToVec(&I);
      } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
        // Allocas of a matrix, or of an array of matrices.
        Type *Ty = AI->getAllocatedType();
        if (HLMatrixLower::IsMatrixType(Ty)) {
          lowerToVec(&I);
        } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
          lowerToVec(&I);
        }
      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
        HLOpcodeGroup group =
            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
        if (group == HLOpcodeGroup::HLMatLoadStore) {
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
          DXASSERT_LOCALVAR(
              opcode,
              opcode == HLMatLoadStoreOpcode::ColMatStore ||
                  opcode == HLMatLoadStoreOpcode::RowMatStore,
              "Must MatStore here, load will go IsMatrixType path");
          // Lower it here to make sure it is ready before replace.
          lowerToVec(&I);
        }
      }
    }
  }
  // Update the use of matrix inst with the vector version.
  // NOTE(review): each loop below advances the iterator before acting on the
  // entry — presumably because the callee can add entries to matToVecMap;
  // confirm before simplifying to a range-for.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  }
  // Translate mat inst which require all operands ready.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    finalMatTranslation(matToVec->first);
  }
  // Delete the matrix version insts.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    // Add to m_deadInsts.
    Instruction *matInst = matToVec->first;
    AddToDeadInsts(matInst);
  }
  DeleteDeadInsts();
  matToVecMap.clear();
}