HLMatrixLowerPass.cpp 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HlslIntrinsicOp.h"
  16. #include "dxc/Support/Global.h"
  17. #include "llvm/IR/IRBuilder.h"
  18. #include "llvm/IR/Module.h"
  19. #include "llvm/IR/DebugInfo.h"
  20. #include "llvm/IR/IntrinsicInst.h"
  21. #include "llvm/Transforms/Utils/Local.h"
  22. #include "llvm/Pass.h"
  23. #include "llvm/Support/raw_ostream.h"
  24. #include <unordered_set>
  25. #include <vector>
  26. using namespace llvm;
  27. using namespace hlsl;
  28. using namespace hlsl::HLMatrixLower;
  29. namespace hlsl {
  30. namespace HLMatrixLower {
  31. bool IsMatrixType(Type *Ty) {
  32. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  33. Type *EltTy = ST->getElementType(0);
  34. if (!ST->getName().startswith("class.matrix"))
  35. return false;
  36. bool isVecArray = EltTy->isArrayTy() &&
  37. EltTy->getArrayElementType()->isVectorTy();
  38. return isVecArray && EltTy->getArrayNumElements() <= 4;
  39. }
  40. return false;
  41. }
  42. // Translate matrix type to vector type.
  43. Type *LowerMatrixType(Type *Ty) {
  44. // Only translate matrix type and function type which use matrix type.
  45. // Not translate struct has matrix or matrix pointer.
  46. // Struct should be flattened before.
  47. // Pointer could cover by matldst which use vector as value type.
  48. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  49. Type *RetTy = LowerMatrixType(FT->getReturnType());
  50. SmallVector<Type *, 4> params;
  51. for (Type *param : FT->params()) {
  52. params.emplace_back(LowerMatrixType(param));
  53. }
  54. return FunctionType::get(RetTy, params, false);
  55. } else if (IsMatrixType(Ty)) {
  56. unsigned row, col;
  57. Type *EltTy = GetMatrixInfo(Ty, col, row);
  58. return VectorType::get(EltTy, row * col);
  59. } else {
  60. return Ty;
  61. }
  62. }
  63. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  64. DXASSERT(IsMatrixType(Ty), "not matrix type");
  65. StructType *ST = cast<StructType>(Ty);
  66. Type *EltTy = ST->getElementType(0);
  67. Type *ColTy = EltTy->getArrayElementType();
  68. col = EltTy->getArrayNumElements();
  69. row = ColTy->getVectorNumElements();
  70. return ColTy->getVectorElementType();
  71. }
  72. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  73. if (!Ty->isPointerTy())
  74. return false;
  75. Ty = Ty->getPointerElementType();
  76. if (!Ty->isArrayTy())
  77. return false;
  78. while (Ty->isArrayTy())
  79. Ty = Ty->getArrayElementType();
  80. return IsMatrixType(Ty);
  81. }
  82. Type *LowerMatrixArrayPointer(Type *Ty) {
  83. Ty = Ty->getPointerElementType();
  84. std::vector<unsigned> arraySizeList;
  85. while (Ty->isArrayTy()) {
  86. arraySizeList.push_back(Ty->getArrayNumElements());
  87. Ty = Ty->getArrayElementType();
  88. }
  89. Ty = LowerMatrixType(Ty);
  90. for (auto arraySize = arraySizeList.rbegin();
  91. arraySize != arraySizeList.rend(); arraySize++)
  92. Ty = ArrayType::get(Ty, *arraySize);
  93. return PointerType::get(Ty, 0);
  94. }
  95. Value *BuildMatrix(Type *EltTy, unsigned col, unsigned row,
  96. bool colMajor, ArrayRef<Value *> elts,
  97. IRBuilder<> &Builder) {
  98. Value *Result = UndefValue::get(VectorType::get(EltTy, col * row));
  99. if (colMajor) {
  100. for (unsigned i = 0; i < col * row; i++)
  101. Result = Builder.CreateInsertElement(Result, elts[i], i);
  102. } else {
  103. for (unsigned r = 0; r < row; r++)
  104. for (unsigned c = 0; c < col; c++) {
  105. unsigned rowMajorIdx = r * col + c;
  106. unsigned colMajorIdx = c * row + r;
  107. Result =
  108. Builder.CreateInsertElement(Result, elts[rowMajorIdx], colMajorIdx);
  109. }
  110. }
  111. return Result;
  112. }
  113. } // namespace HLMatrixLower
  114. } // namespace hlsl
namespace {
// Module pass that lowers HL matrix types ("class.matrix" structs) and the
// HL intrinsic calls operating on them into plain LLVM vector operations.
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}

  const char *getPassName() const override { return "HL matrix lower"; }

  bool runOnModule(Module &M) override {
    m_pModule = &M;
    m_pHLModule = &m_pModule->GetOrCreateHLModule();
    // Load up debug information, to cross-reference values and the
    // instructions used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
    // Lower matrices inside each function body first.
    for (Function &F : M.functions()) {
      if (F.isDeclaration())
        continue;
      runOnFunction(F);
    }
    // Collect static/groupshared globals before mutating the global list.
    std::vector<GlobalVariable*> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (HLModule::IsStaticGlobal(&GV) ||
          HLModule::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }
    for (GlobalVariable *GV : staticGVs)
      runOnGlobal(GV);
    return true;
  }

private:
  Module *m_pModule;      // Module being processed (set in runOnModule).
  HLModule *m_pHLModule;  // HL metadata wrapper for m_pModule.
  bool m_HasDbgInfo;      // True when the module carries debug metadata.
  // Instructions scheduled for deletion once lowering completes.
  std::vector<Instruction *> m_deadInsts;
  // For instructions like matrix array init that may use more than one
  // matrix alloca, this set avoids putting the same instruction into
  // m_deadInsts more than once.
  std::unordered_set<Instruction *> m_inDeadInstsSet;
  // Most matrix instruction users have only one matrix use, so a plain
  // vector is enough (and cheap) to record dead instructions.
  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  // For instructions with more than one matrix use: only adds to
  // m_deadInsts when the instruction is not already tracked.
  void AddToDeadInstsWithDups(Instruction *I) {
    if (m_inDeadInstsSet.count(I) == 0) {
      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
      m_inDeadInstsSet.insert(I);
      AddToDeadInsts(I);
    }
  }
  void runOnFunction(Function &F);
  void runOnGlobal(GlobalVariable *GV);
  void runOnGlobalMatrixArray(GlobalVariable *GV);
  // Per-opcode "to vector" builders: create the vector-typed clone of a
  // matrix HL call (operands may be placeholder undefs patched later).
  Instruction *MatCastToVec(CallInst *CI);
  Instruction *MatLdStToVec(CallInst *CI);
  Instruction *MatSubscriptToVec(CallInst *CI);
  Instruction *MatIntrinsicToVec(CallInst *CI);
  Instruction *TrivialMatUnOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatUnOpReplace(CallInst *matInst, Instruction *vecInst,
                             CallInst *matUseInst);
  Instruction *TrivialMatBinOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatBinOpReplace(CallInst *matInst, Instruction *vecInst,
                              CallInst *matUseInst);
  // Replace matInst with vecInst on mulInst (matrix/vector mul variants).
  void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst);
  void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst);
  void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst);
  void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst);
  // Replace matInst with vecInst on transposeInst.
  void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
                             CallInst *transposeInst);
  void TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
                               CallInst *determinantInst);
  void MatIntrinsicReplace(CallInst *matInst, Instruction *vecInst,
                           CallInst *matUseInst);
  // Replace matInst with vecInst on castInst (mat<->mat and mat->other).
  void TranslateMatMatCast(CallInst *matInst, Instruction *vecInst,
                           CallInst *castInst);
  void TranslateMatToOtherCast(CallInst *matInst, Instruction *vecInst,
                               CallInst *castInst);
  void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
                        CallInst *castInst);
  // Replace matInst with vecInst in matSubInst.
  void TranslateMatSubscript(Value *matInst, Value *vecInst,
                             CallInst *matSubInst);
  // Lower a matrix initializer call.
  void TranslateMatInit(CallInst *matInitInst);
  // Lower a matrix select call.
  void TranslateMatSelect(CallInst *matSelectInst);
  // Replace matInst with vecInst on matGEP.
  void TranslateMatArrayGEP(Value *matInst, Instruction *vecInst,
                            GetElementPtrInst *matGEP);
  // Global-variable flavors of load/store and subscript lowering.
  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
                                     CallInst *matLdStInst);
  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                                     CallInst *matLdStInst);
  void TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                                     CallInst *matSubInst);
  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatReplace(CallInst *matInst, Instruction *vecInst,
                         CallInst *matUseInst);
  // Lower a matrix type instruction to a vector type instruction.
  void lowerToVec(Instruction *matInst);
  // Lower users of a matrix type instruction.
  void replaceMatWithVec(Instruction *matInst, Instruction *vecInst);
  // Translate mat inst which needs all operands ready.
  void finalMatTranslation(Instruction *matInst);
  // Delete dead insts in m_deadInsts.
  void DeleteDeadInsts();
  // Map from matrix inst to its vector version.
  DenseMap<Instruction *, Value *> matToVecMap;
};
}
char HLMatrixLowerPass::ID = 0;

// Factory consumed by the LLVM pass manager; registered below.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }

INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  237. // All calculation on col major.
  238. static unsigned GetMatIdx(unsigned r, unsigned c, unsigned rowSize) {
  239. return (c * rowSize + r);
  240. }
  241. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  242. IRBuilder<> Builder) {
  243. // Cast to bool.
  244. if (toTy->getScalarType()->isIntegerTy() &&
  245. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  246. Type *fromTy = src->getType();
  247. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  248. Constant *zero;
  249. if (isFloat)
  250. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  251. else
  252. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  253. if (toTy->getScalarType() != toTy) {
  254. // Create constant vector.
  255. unsigned size = toTy->getVectorNumElements();
  256. std::vector<Constant *> zeros(size, zero);
  257. zero = llvm::ConstantVector::get(zeros);
  258. }
  259. if (isFloat)
  260. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  261. else
  262. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  263. }
  264. Type *eltToTy = toTy->getScalarType();
  265. Type *eltFromTy = src->getType()->getScalarType();
  266. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  267. castOp == HLCastOpcode::UnsignedUnsignedCast;
  268. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  269. castOp == HLCastOpcode::UnsignedUnsignedCast;
  270. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  271. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  272. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  273. }
// Lowers an HLCast call involving matrices to its vector form.
// - other -> matrix: fully translated here (shuffle/splat to matrix size,
//   then an element-type cast if the scalar types differ).
// - matrix -> matrix when the source is a function argument: only the return
//   type is lowered; a new HL call with a vector return type is created.
// - every other case falls through to MatIntrinsicToVec, and the real
//   translation happens later during replacement.
Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
  bool ToMat = IsMatrixType(CI->getType());
  bool FromMat = IsMatrixType(op->getType());
  if (ToMat && !FromMat) {
    // Translate OtherToMat here.
    // The rest will be translated during replacement.
    unsigned col, row;
    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
    unsigned toSize = col * row;
    Instruction *sizeCast = nullptr;
    Type *FromTy = op->getType();
    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
    if (FromTy->isVectorTy()) {
      // Vector source: identity-prefix shuffle mask <0, 1, ..., toSize-1>
      // resizes the vector to the matrix element count.
      std::vector<Constant *> MaskVec(toSize);
      for (size_t i = 0; i != toSize; ++i)
        MaskVec[i] = ConstantInt::get(I32Ty, i);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    } else {
      // Scalar source: insert into lane 0 of a 1-wide vector, then splat it
      // across all toSize lanes with an all-zero shuffle mask.
      op = Builder.CreateInsertElement(
          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
      Constant *zero = ConstantInt::get(I32Ty, 0);
      std::vector<Constant *> MaskVec(toSize, zero);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    }
    Instruction *typeCast = sizeCast;
    if (EltTy != FromTy->getScalarType()) {
      // Element types differ: layer a numeric cast on top of the size cast.
      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
                                sizeCast, Builder);
    }
    return typeCast;
  } else if (FromMat && ToMat) {
    if (isa<Argument>(op)) {
      // Cast from mat to mat for an argument.
      IRBuilder<> Builder(CI);
      // Here only lower the return type to vector; the argument keeps its
      // matrix type.
      Type *RetTy = LowerMatrixType(CI->getType());
      SmallVector<Type *, 4> params;
      for (Value *operand : CI->arg_operands()) {
        params.emplace_back(operand->getType());
      }
      Type *FT = FunctionType::get(RetTy, params, false);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      unsigned opcode = GetHLOpcode(CI);
      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
                                             group, opcode);
      SmallVector<Value *, 4> argList;
      for (Value *arg : CI->arg_operands()) {
        argList.emplace_back(arg);
      }
      return Builder.CreateCall(vecF, argList);
    }
  }
  return MatIntrinsicToVec(CI);
}
// Lowers an HL matrix load/store call whose pointer is a local matrix alloca
// into a plain vector load/store on the corresponding vector alloca (looked
// up in matToVecMap). Non-alloca pointers (globals, subscripts, ...) fall
// back to MatIntrinsicToVec and are finished in later translation steps.
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  unsigned opcode = GetHLOpcode(CI);
  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  Instruction *result = nullptr;
  switch (matOpcode) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      // Load straight from the vector alloca created for this matrix alloca.
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      result = Builder.CreateLoad(vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
      // Store a placeholder undef of the lowered vector type; the real value
      // is patched in once the matrix operand itself has been lowered.
      Value *vecVal =
          UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
      result = Builder.CreateStore(vecVal, vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  }
  return result;
}
  365. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  366. IRBuilder<> Builder(CI);
  367. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  368. if (isa<AllocaInst>(matPtr)) {
  369. // Here just create a new matSub call which use vec ptr.
  370. // Later in TranslateMatSubscript will do the real translation.
  371. std::vector<Value *> args(CI->getNumArgOperands());
  372. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  373. args[i] = CI->getArgOperand(i);
  374. }
  375. // Change mat ptr into vec ptr.
  376. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  377. matToVecMap[cast<Instruction>(matPtr)];
  378. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  379. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  380. paramTyList[i] = args[i]->getType();
  381. }
  382. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  383. unsigned opcode = GetHLOpcode(CI);
  384. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  385. return Builder.CreateCall(opFunc, args);
  386. } else
  387. return MatIntrinsicToVec(CI);
  388. }
  389. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  390. IRBuilder<> Builder(CI);
  391. unsigned opcode = GetHLOpcode(CI);
  392. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  393. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  394. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  395. SmallVector<Value *, 4> argList;
  396. for (Value *arg : CI->arg_operands()) {
  397. Type *Ty = arg->getType();
  398. if (IsMatrixType(Ty)) {
  399. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  400. } else
  401. argList.emplace_back(arg);
  402. }
  403. return Builder.CreateCall(vecF, argList);
  404. }
  405. static Value *VectorizeScalarOp(Value *op, Type *dstTy, IRBuilder<> &Builder) {
  406. if (op->getType() == dstTy)
  407. return op;
  408. op = Builder.CreateInsertElement(
  409. UndefValue::get(VectorType::get(op->getType(), 1)), op, (uint64_t)0);
  410. Type *I32Ty = IntegerType::get(dstTy->getContext(), 32);
  411. Constant *zero = ConstantInt::get(I32Ty, 0);
  412. std::vector<Constant *> MaskVec(dstTy->getVectorNumElements(), zero);
  413. Value *castMask = ConstantVector::get(MaskVec);
  414. Value *vecOp = new ShuffleVectorInst(op, op, castMask);
  415. Builder.Insert(cast<Instruction>(vecOp));
  416. return vecOp;
  417. }
// Builds the vector form of an HL unary operator applied to a matrix.
// Operands are placeholder undefs of the lowered vector type; the real
// vector operand is patched in later (TrivialMatUnOpReplace) once the matrix
// operand itself has been lowered.
Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  Type *ResultTy = LowerMatrixType(CI->getType());
  UndefValue *tmp = UndefValue::get(ResultTy);
  IRBuilder<> Builder(CI);
  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  // Builds a constant splat vector of Ty's width holding v (float or int
  // depending on the result element type).
  auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
    Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
                            : ConstantInt::get(Ty->getScalarType(), v);
    std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
    return ConstantVector::get(vals);
  };
  Constant *one = GetVecConst(ResultTy, 1);
  Instruction *Result = nullptr;
  switch (opcode) {
  case HLUnaryOpcode::Minus: {
    // -x lowered as (0 - x).
    Constant *zero = GetVecConst(ResultTy, 0);
    if (isFloat)
      Result = BinaryOperator::CreateFSub(zero, tmp);
    else
      Result = BinaryOperator::CreateSub(zero, tmp);
  } break;
  case HLUnaryOpcode::LNot: {
    Constant *zero = GetVecConst(ResultTy, 0);
    // NOTE(review): logical-not is normally (x == 0), but UNE/NE here compute
    // (x != 0). Confirm whether a later stage inverts the result or this
    // predicate is wrong.
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  } break;
  case HLUnaryOpcode::Not:
    // NOTE(review): bitwise-not is usually (x ^ -1); (tmp ^ tmp) relies on
    // operand replacement producing the intended pattern — verify against
    // TrivialMatUnOpReplace.
    Result = BinaryOperator::CreateXor(tmp, tmp);
    break;
  case HLUnaryOpcode::PostInc:
  case HLUnaryOpcode::PreInc:
    // x++ / ++x lowered as (x + 1); pre/post distinction is handled by the
    // use sites, not here.
    if (isFloat)
      Result = BinaryOperator::CreateFAdd(tmp, one);
    else
      Result = BinaryOperator::CreateAdd(tmp, one);
    break;
  case HLUnaryOpcode::PostDec:
  case HLUnaryOpcode::PreDec:
    if (isFloat)
      Result = BinaryOperator::CreateFSub(tmp, one);
    else
      Result = BinaryOperator::CreateSub(tmp, one);
    break;
  default:
    DXASSERT(0, "not implement");
    return nullptr;
  }
  Builder.Insert(Result);
  return Result;
}
  471. Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  472. Type *ResultTy = LowerMatrixType(CI->getType());
  473. IRBuilder<> Builder(CI);
  474. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  475. Type *OpTy = LowerMatrixType(
  476. CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  477. UndefValue *tmp = UndefValue::get(OpTy);
  478. bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  479. Instruction *Result = nullptr;
  480. switch (opcode) {
  481. case HLBinaryOpcode::Add:
  482. if (isFloat)
  483. Result = BinaryOperator::CreateFAdd(tmp, tmp);
  484. else
  485. Result = BinaryOperator::CreateAdd(tmp, tmp);
  486. break;
  487. case HLBinaryOpcode::Sub:
  488. if (isFloat)
  489. Result = BinaryOperator::CreateFSub(tmp, tmp);
  490. else
  491. Result = BinaryOperator::CreateSub(tmp, tmp);
  492. break;
  493. case HLBinaryOpcode::Mul:
  494. if (isFloat)
  495. Result = BinaryOperator::CreateFMul(tmp, tmp);
  496. else
  497. Result = BinaryOperator::CreateMul(tmp, tmp);
  498. break;
  499. case HLBinaryOpcode::Div:
  500. if (isFloat)
  501. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  502. else
  503. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  504. break;
  505. case HLBinaryOpcode::Rem:
  506. if (isFloat)
  507. Result = BinaryOperator::CreateFRem(tmp, tmp);
  508. else
  509. Result = BinaryOperator::CreateSRem(tmp, tmp);
  510. break;
  511. case HLBinaryOpcode::And:
  512. Result = BinaryOperator::CreateAnd(tmp, tmp);
  513. break;
  514. case HLBinaryOpcode::Or:
  515. Result = BinaryOperator::CreateOr(tmp, tmp);
  516. break;
  517. case HLBinaryOpcode::Xor:
  518. Result = BinaryOperator::CreateXor(tmp, tmp);
  519. break;
  520. case HLBinaryOpcode::Shl: {
  521. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  522. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  523. Result = BinaryOperator::CreateShl(tmp, tmp);
  524. } break;
  525. case HLBinaryOpcode::Shr: {
  526. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  527. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  528. Result = BinaryOperator::CreateAShr(tmp, tmp);
  529. } break;
  530. case HLBinaryOpcode::LT:
  531. if (isFloat)
  532. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
  533. else
  534. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
  535. break;
  536. case HLBinaryOpcode::GT:
  537. if (isFloat)
  538. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
  539. else
  540. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
  541. break;
  542. case HLBinaryOpcode::LE:
  543. if (isFloat)
  544. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
  545. else
  546. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
  547. break;
  548. case HLBinaryOpcode::GE:
  549. if (isFloat)
  550. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
  551. else
  552. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
  553. break;
  554. case HLBinaryOpcode::EQ:
  555. if (isFloat)
  556. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
  557. else
  558. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
  559. break;
  560. case HLBinaryOpcode::NE:
  561. if (isFloat)
  562. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
  563. else
  564. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
  565. break;
  566. case HLBinaryOpcode::UDiv:
  567. Result = BinaryOperator::CreateUDiv(tmp, tmp);
  568. break;
  569. case HLBinaryOpcode::URem:
  570. Result = BinaryOperator::CreateURem(tmp, tmp);
  571. break;
  572. case HLBinaryOpcode::UShr: {
  573. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  574. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  575. Result = BinaryOperator::CreateLShr(tmp, tmp);
  576. } break;
  577. case HLBinaryOpcode::ULT:
  578. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
  579. break;
  580. case HLBinaryOpcode::UGT:
  581. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
  582. break;
  583. case HLBinaryOpcode::ULE:
  584. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
  585. break;
  586. case HLBinaryOpcode::UGE:
  587. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
  588. break;
  589. case HLBinaryOpcode::LAnd:
  590. case HLBinaryOpcode::LOr: {
  591. Constant *zero;
  592. if (isFloat)
  593. zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
  594. else
  595. zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
  596. unsigned size = ResultTy->getVectorNumElements();
  597. std::vector<Constant *> zeros(size, zero);
  598. Value *vecZero = llvm::ConstantVector::get(zeros);
  599. Instruction *cmpL;
  600. if (isFloat)
  601. cmpL =
  602. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  603. else
  604. cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  605. Builder.Insert(cmpL);
  606. Instruction *cmpR;
  607. if (isFloat)
  608. cmpR =
  609. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  610. else
  611. cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  612. Builder.Insert(cmpR);
  613. // How to map l, r back? Need check opcode
  614. if (opcode == HLBinaryOpcode::LOr)
  615. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  616. else
  617. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  618. break;
  619. }
  620. default:
  621. DXASSERT(0, "not implement");
  622. return nullptr;
  623. }
  624. Builder.Insert(Result);
  625. return Result;
  626. }
  627. void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  628. Instruction *vecInst;
  629. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  630. hlsl::HLOpcodeGroup group =
  631. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  632. switch (group) {
  633. case HLOpcodeGroup::HLIntrinsic: {
  634. vecInst = MatIntrinsicToVec(CI);
  635. } break;
  636. case HLOpcodeGroup::HLSelect: {
  637. vecInst = MatIntrinsicToVec(CI);
  638. } break;
  639. case HLOpcodeGroup::HLBinOp: {
  640. vecInst = TrivialMatBinOpToVec(CI);
  641. } break;
  642. case HLOpcodeGroup::HLUnOp: {
  643. vecInst = TrivialMatUnOpToVec(CI);
  644. } break;
  645. case HLOpcodeGroup::HLCast: {
  646. vecInst = MatCastToVec(CI);
  647. } break;
  648. case HLOpcodeGroup::HLInit: {
  649. vecInst = MatIntrinsicToVec(CI);
  650. } break;
  651. case HLOpcodeGroup::HLMatLoadStore: {
  652. vecInst = MatLdStToVec(CI);
  653. } break;
  654. case HLOpcodeGroup::HLSubscript: {
  655. vecInst = MatSubscriptToVec(CI);
  656. } break;
  657. }
  658. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  659. Type *Ty = AI->getAllocatedType();
  660. Type *matTy = Ty;
  661. IRBuilder<> Builder(AI);
  662. if (Ty->isArrayTy()) {
  663. Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
  664. vecTy = vecTy->getPointerElementType();
  665. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  666. } else {
  667. Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
  668. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  669. }
  670. // Update debug info.
  671. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
  672. if (DDI) {
  673. LLVMContext &Context = AI->getContext();
  674. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  675. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  676. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecInst));
  677. IRBuilder<> debugBuilder(DDI);
  678. debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
  679. }
  680. if (HLModule::HasPreciseAttributeWithMetadata(AI))
  681. HLModule::MarkPreciseAttributeWithMetadata(vecInst);
  682. } else {
  683. DXASSERT(0, "invalid inst");
  684. }
  685. matToVecMap[matInst] = vecInst;
  686. }
  687. // Replace matInst with vecInst on matUseInst.
  688. void HLMatrixLowerPass::TrivialMatUnOpReplace(CallInst *matInst,
  689. Instruction *vecInst,
  690. CallInst *matUseInst) {
  691. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  692. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  693. switch (opcode) {
  694. case HLUnaryOpcode::Not:
  695. // Not is xor now
  696. vecUseInst->setOperand(0, vecInst);
  697. vecUseInst->setOperand(1, vecInst);
  698. break;
  699. case HLUnaryOpcode::LNot:
  700. case HLUnaryOpcode::PostInc:
  701. case HLUnaryOpcode::PreInc:
  702. case HLUnaryOpcode::PostDec:
  703. case HLUnaryOpcode::PreDec:
  704. vecUseInst->setOperand(0, vecInst);
  705. break;
  706. }
  707. }
  708. // Replace matInst with vecInst on matUseInst.
  709. void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
  710. Instruction *vecInst,
  711. CallInst *matUseInst) {
  712. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  713. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  714. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  715. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matInst)
  716. vecUseInst->setOperand(0, vecInst);
  717. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matInst)
  718. vecUseInst->setOperand(1, vecInst);
  719. } else {
  720. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  721. matInst) {
  722. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  723. vecCmp->setOperand(0, vecInst);
  724. }
  725. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  726. matInst) {
  727. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  728. vecCmp->setOperand(0, vecInst);
  729. }
  730. }
  731. }
// Lower mul(matrix, matrix) into scalar extract/mul/add chains over the
// vector versions of both operands, then replace the placeholder vec call.
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst) {
  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated (placeholder call was replaced on an earlier visit).
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  // Left is row x col, right is rRow x rCol; inner dimensions must match.
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(mulInst);
  // Vec versions of both matrix operands (must already exist in the map).
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // Emit left(r, lc) * right(lc, c) — one term of result cell (r, c).
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = GetMatIdx(r, lc, row);
    unsigned rMatIdx = GetMatIdx(lc, c, rRow);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  // Each result cell is the dot product of a left row with a right column.
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      for (lc = 1; lc < col; lc++) {
        Value *tmpMul = CreateOneEltMul(r, lc, c);
        tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
                         : Builder.CreateAdd(tmpVal, tmpMul);
      }
      unsigned matIdx = GetMatIdx(r, c, row);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vec mul call with the computed result.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
  779. void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
  780. Instruction *vecInst,
  781. CallInst *mulInst) {
  782. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  783. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  784. unsigned col, row;
  785. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  786. DXASSERT(RVal->getType()->getVectorNumElements() == col, "");
  787. bool isFloat = EltTy->isFloatingPointTy();
  788. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  789. IRBuilder<> Builder(mulInst);
  790. Value *vec = RVal;
  791. Value *mat = vecInst; // vec version of matInst;
  792. for (unsigned r = 0; r < row; r++) {
  793. unsigned c = 0;
  794. Value *vecElt = Builder.CreateExtractElement(vec, c);
  795. uint32_t matIdx = GetMatIdx(r, c, row);
  796. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  797. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  798. : Builder.CreateMul(vecElt, matElt);
  799. for (c = 1; c < col; c++) {
  800. vecElt = Builder.CreateExtractElement(vec, c);
  801. uint32_t matIdx = GetMatIdx(r, c, row);
  802. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  803. Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
  804. : Builder.CreateMul(vecElt, matElt);
  805. tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
  806. : Builder.CreateAdd(tmpVal, tmpMul);
  807. }
  808. retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  809. }
  810. mulInst->replaceAllUsesWith(retVal);
  811. AddToDeadInsts(mulInst);
  812. }
  813. void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
  814. Instruction *vecInst,
  815. CallInst *mulInst) {
  816. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  817. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  818. Value *RVal = vecInst;
  819. unsigned col, row;
  820. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  821. DXASSERT(LVal->getType()->getVectorNumElements() == row, "");
  822. bool isFloat = EltTy->isFloatingPointTy();
  823. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  824. IRBuilder<> Builder(mulInst);
  825. Value *vec = LVal;
  826. Value *mat = RVal;
  827. for (unsigned c = 0; c < col; c++) {
  828. unsigned r = 0;
  829. Value *vecElt = Builder.CreateExtractElement(vec, r);
  830. uint32_t matIdx = GetMatIdx(r, c, row);
  831. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  832. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  833. : Builder.CreateMul(vecElt, matElt);
  834. for (r = 1; r < row; r++) {
  835. vecElt = Builder.CreateExtractElement(vec, r);
  836. uint32_t matIdx = GetMatIdx(r, c, row);
  837. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  838. Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
  839. : Builder.CreateMul(vecElt, matElt);
  840. tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
  841. : Builder.CreateAdd(tmpVal, tmpMul);
  842. }
  843. retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  844. }
  845. mulInst->replaceAllUsesWith(retVal);
  846. AddToDeadInsts(mulInst);
  847. }
  848. void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
  849. CallInst *mulInst) {
  850. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  851. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  852. bool LMat = IsMatrixType(LVal->getType());
  853. bool RMat = IsMatrixType(RVal->getType());
  854. if (LMat && RMat) {
  855. TranslateMatMatMul(matInst, vecInst, mulInst);
  856. } else if (LMat) {
  857. TranslateMatVecMul(matInst, vecInst, mulInst);
  858. } else {
  859. TranslateVecMatMul(matInst, vecInst, mulInst);
  860. }
  861. }
  862. void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
  863. Instruction *vecInst,
  864. CallInst *transposeInst) {
  865. unsigned row, col;
  866. GetMatrixInfo(transposeInst->getType(), col, row);
  867. IRBuilder<> Builder(transposeInst);
  868. std::vector<int> transposeMask(col * row);
  869. unsigned idx = 0;
  870. for (unsigned c = 0; c < col; c++)
  871. for (unsigned r = 0; r < row; r++) {
  872. // change to row major
  873. unsigned matIdx = GetMatIdx(c, r, col);
  874. transposeMask[idx++] = (matIdx);
  875. }
  876. Instruction *shuf = cast<Instruction>(
  877. Builder.CreateShuffleVector(vecInst, vecInst, transposeMask));
  878. // Replace vec transpose function call with shuf.
  879. DXASSERT(matToVecMap.count(transposeInst), "must has vec version");
  880. Instruction *vecUseInst = cast<Instruction>(matToVecMap[transposeInst]);
  881. vecUseInst->replaceAllUsesWith(shuf);
  882. AddToDeadInsts(vecUseInst);
  883. matToVecMap[transposeInst] = shuf;
  884. }
  885. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  886. IRBuilder<> &Builder) {
  887. Value *mul0 = Builder.CreateFMul(m00, m11);
  888. Value *mul1 = Builder.CreateFMul(m01, m10);
  889. return Builder.CreateFSub(mul0, mul1);
  890. }
  891. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  892. Value *m10, Value *m11, Value *m12,
  893. Value *m20, Value *m21, Value *m22,
  894. IRBuilder<> &Builder) {
  895. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  896. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  897. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  898. deter00 = Builder.CreateFMul(m00, deter00);
  899. deter01 = Builder.CreateFMul(m01, deter01);
  900. deter02 = Builder.CreateFMul(m02, deter02);
  901. Value *result = Builder.CreateFSub(deter00, deter01);
  902. result = Builder.CreateFAdd(result, deter02);
  903. return result;
  904. }
  905. static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
  906. Value *m10, Value *m11, Value *m12, Value *m13,
  907. Value *m20, Value *m21, Value *m22, Value *m23,
  908. Value *m30, Value *m31, Value *m32, Value *m33,
  909. IRBuilder<> &Builder) {
  910. Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  911. Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  912. Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  913. Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  914. deter00 = Builder.CreateFMul(m00, deter00);
  915. deter01 = Builder.CreateFMul(m01, deter01);
  916. deter02 = Builder.CreateFMul(m02, deter02);
  917. deter03 = Builder.CreateFMul(m03, deter03);
  918. Value *result = Builder.CreateFSub(deter00, deter01);
  919. result = Builder.CreateFAdd(result, deter02);
  920. result = Builder.CreateFSub(result, deter03);
  921. return result;
  922. }
  923. void HLMatrixLowerPass::TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
  924. CallInst *determinantInst) {
  925. unsigned row, col;
  926. GetMatrixInfo(matInst->getType(), col, row);
  927. IRBuilder<> Builder(determinantInst);
  928. // when row == 1, result is vecInst.
  929. Value *Result = vecInst;
  930. if (row == 2) {
  931. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  932. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  933. Value *m10 = Builder.CreateExtractElement(vecInst, 2);
  934. Value *m11 = Builder.CreateExtractElement(vecInst, 3);
  935. Result = Determinant2x2(m00, m01, m10, m11, Builder);
  936. }
  937. else if (row == 3) {
  938. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  939. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  940. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  941. Value *m10 = Builder.CreateExtractElement(vecInst, 3);
  942. Value *m11 = Builder.CreateExtractElement(vecInst, 4);
  943. Value *m12 = Builder.CreateExtractElement(vecInst, 5);
  944. Value *m20 = Builder.CreateExtractElement(vecInst, 6);
  945. Value *m21 = Builder.CreateExtractElement(vecInst, 7);
  946. Value *m22 = Builder.CreateExtractElement(vecInst, 8);
  947. Result = Determinant3x3(m00, m01, m02,
  948. m10, m11, m12,
  949. m20, m21, m22, Builder);
  950. }
  951. else if (row == 4) {
  952. Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
  953. Value *m01 = Builder.CreateExtractElement(vecInst, 1);
  954. Value *m02 = Builder.CreateExtractElement(vecInst, 2);
  955. Value *m03 = Builder.CreateExtractElement(vecInst, 3);
  956. Value *m10 = Builder.CreateExtractElement(vecInst, 4);
  957. Value *m11 = Builder.CreateExtractElement(vecInst, 5);
  958. Value *m12 = Builder.CreateExtractElement(vecInst, 6);
  959. Value *m13 = Builder.CreateExtractElement(vecInst, 7);
  960. Value *m20 = Builder.CreateExtractElement(vecInst, 8);
  961. Value *m21 = Builder.CreateExtractElement(vecInst, 9);
  962. Value *m22 = Builder.CreateExtractElement(vecInst, 10);
  963. Value *m23 = Builder.CreateExtractElement(vecInst, 11);
  964. Value *m30 = Builder.CreateExtractElement(vecInst, 12);
  965. Value *m31 = Builder.CreateExtractElement(vecInst, 13);
  966. Value *m32 = Builder.CreateExtractElement(vecInst, 14);
  967. Value *m33 = Builder.CreateExtractElement(vecInst, 15);
  968. Result = Determinant4x4(m00, m01, m02, m03,
  969. m10, m11, m12, m13,
  970. m20, m21, m22, m23,
  971. m30, m31, m32, m33,
  972. Builder);
  973. } else {
  974. DXASSERT(row == 1, "invalid matrix type");
  975. Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  976. }
  977. determinantInst->replaceAllUsesWith(Result);
  978. AddToDeadInsts(determinantInst);
  979. }
  980. void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
  981. Instruction *vecInst,
  982. CallInst *matUseInst) {
  983. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  984. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  985. if (matUseInst->getArgOperand(i) == matInst) {
  986. vecUseInst->setArgOperand(i, vecInst);
  987. }
  988. }
  989. void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
  990. Instruction *vecInst,
  991. CallInst *castInst) {
  992. unsigned toCol, toRow;
  993. Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  994. unsigned fromCol, fromRow;
  995. Type *FromEltTy = GetMatrixInfo(matInst->getType(), fromCol, fromRow);
  996. unsigned fromSize = fromCol * fromRow;
  997. unsigned toSize = toCol * toRow;
  998. DXASSERT(fromSize >= toSize, "cannot extend matrix");
  999. IRBuilder<> Builder(castInst);
  1000. Instruction *vecCast = nullptr;
  1001. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1002. if (fromSize == toSize) {
  1003. vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecInst,
  1004. Builder);
  1005. } else {
  1006. // shuf first
  1007. std::vector<int> castMask(toCol * toRow);
  1008. unsigned idx = 0;
  1009. for (unsigned c = 0; c < toCol; c++)
  1010. for (unsigned r = 0; r < toRow; r++) {
  1011. unsigned matIdx = GetMatIdx(r, c, fromRow);
  1012. castMask[idx++] = matIdx;
  1013. }
  1014. Instruction *shuf = cast<Instruction>(
  1015. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1016. if (ToEltTy != FromEltTy)
  1017. vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
  1018. Builder);
  1019. else
  1020. vecCast = shuf;
  1021. }
  1022. // Replace vec cast function call with vecCast.
  1023. DXASSERT(matToVecMap.count(castInst), "must has vec version");
  1024. Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  1025. vecUseInst->replaceAllUsesWith(vecCast);
  1026. AddToDeadInsts(vecUseInst);
  1027. matToVecMap[castInst] = vecCast;
  1028. }
  1029. void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
  1030. Instruction *vecInst,
  1031. CallInst *castInst) {
  1032. unsigned col, row;
  1033. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  1034. unsigned fromSize = col * row;
  1035. IRBuilder<> Builder(castInst);
  1036. Instruction *sizeCast = nullptr;
  1037. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1038. Type *ToTy = castInst->getType();
  1039. if (ToTy->isVectorTy()) {
  1040. unsigned toSize = ToTy->getVectorNumElements();
  1041. if (fromSize != toSize) {
  1042. std::vector<int> castMask(fromSize);
  1043. for (unsigned c = 0; c < toSize; c++)
  1044. castMask[c] = c;
  1045. sizeCast = cast<Instruction>(
  1046. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1047. } else
  1048. sizeCast = vecInst;
  1049. } else {
  1050. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1051. sizeCast =
  1052. cast<Instruction>(Builder.CreateExtractElement(vecInst, (uint64_t)0));
  1053. }
  1054. Instruction *typeCast = sizeCast;
  1055. if (EltTy != ToTy->getScalarType()) {
  1056. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1057. }
  1058. // Replace cast function call with typeCast.
  1059. castInst->replaceAllUsesWith(typeCast);
  1060. AddToDeadInsts(castInst);
  1061. }
  1062. void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
  1063. Instruction *vecInst,
  1064. CallInst *castInst) {
  1065. bool ToMat = IsMatrixType(castInst->getType());
  1066. bool FromMat = IsMatrixType(matInst->getType());
  1067. if (ToMat && FromMat) {
  1068. TranslateMatMatCast(matInst, vecInst, castInst);
  1069. } else if (FromMat)
  1070. TranslateMatToOtherCast(matInst, vecInst, castInst);
  1071. else
  1072. DXASSERT(0, "Not translate as user of matInst");
  1073. }
  1074. void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
  1075. Instruction *vecInst,
  1076. CallInst *matUseInst) {
  1077. IRBuilder<> Builder(matUseInst);
  1078. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1079. switch (opcode) {
  1080. case IntrinsicOp::IOP_mul:
  1081. TranslateMul(matInst, vecInst, matUseInst);
  1082. break;
  1083. case IntrinsicOp::IOP_transpose:
  1084. TranslateMatTranspose(matInst, vecInst, matUseInst);
  1085. break;
  1086. case IntrinsicOp::IOP_determinant:
  1087. TranslateMatDeterminant(matInst, vecInst, matUseInst);
  1088. break;
  1089. default:
  1090. CallInst *vecUseInst = nullptr;
  1091. if (matToVecMap.count(matUseInst))
  1092. vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1093. for (unsigned i = 0; i < matInst->getNumArgOperands(); i++)
  1094. if (matUseInst->getArgOperand(i) == matInst) {
  1095. if (vecUseInst)
  1096. vecUseInst->setArgOperand(i, vecInst);
  1097. else
  1098. matUseInst->setArgOperand(i, vecInst);
  1099. }
  1100. break;
  1101. }
  1102. }
// Lower a matrix subscript call (element access like mat._m00 / whole-row
// access like mat[i]) into loads/stores/GEPs on the lowered vector pointer
// vecInst. All users of the subscript call are rewritten and queued for
// deletion, then the call itself is removed.
void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
                                              CallInst *matSubInst) {
  unsigned opcode = GetHLOpcode(matSubInst);
  HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
         "matrix don't use default subscript");
  Type *matType = matInst->getType()->getPointerElementType();
  unsigned col, row;
  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  // Element subscript (_m00-style swizzle) vs. row subscript (mat[i]).
  bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
                   (matOpcode == HLSubscriptOpcode::RowMatElement);
  Value *mask =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  // For temp matrix, all use col major.
  if (isElement) {
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Build a shuffle mask mapping each selected (row, col) pair to its
    // flattened index. The mask operand stores pairs: [r0, c0, r1, c1, ...].
    std::vector<int> shufMask(resultSize);
    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(mask)) {
      unsigned count = elts->getNumElements();
      for (unsigned i = 0; i < count; i += 2) {
        unsigned rowIdx = elts->getElementAsInteger(i);
        unsigned colIdx = elts->getElementAsInteger(i + 1);
        unsigned matIdx = GetMatIdx(rowIdx, colIdx, row);
        shufMask[i>>1] = matIdx;
      }
    }
    else {
      // An all-zero mask constant folds to ConstantAggregateZero; every
      // selected element is (0, 0) -> flattened index 0.
      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(mask);
      unsigned size = zeros->getNumElements()>>1;
      for (unsigned i=0;i<size;i++)
        shufMask[i] = 0;
    }
    // Rewrite every load/store through the subscript. Iterator is advanced
    // before use because the loop deletes users.
    for (Value::use_iterator CallUI = matSubInst->use_begin(),
                             CallE = matSubInst->use_end();
         CallUI != CallE;) {
      Use &CallUse = *CallUI++;
      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
      IRBuilder<> Builder(CallUser);
      // Read-modify-write on the whole lowered vector.
      Value *vecLd = Builder.CreateLoad(vecInst);
      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
        if (resultSize > 1) {
          Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
          ld->replaceAllUsesWith(shuf);
        } else {
          Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
          ld->replaceAllUsesWith(elt);
        }
      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
        Value *val = st->getValueOperand();
        if (resultSize > 1) {
          // Scatter each stored element into its flattened slot.
          for (unsigned i = 0; i < shufMask.size(); i++) {
            unsigned idx = shufMask[i];
            Value *valElt = Builder.CreateExtractElement(val, i);
            vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
          }
          Builder.CreateStore(vecLd, vecInst);
        } else {
          vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
          Builder.CreateStore(vecLd, vecInst);
        }
      } else
        DXASSERT(0, "matrix element should only used by load/store.");
      AddToDeadInsts(CallUser);
    }
  } else {
    // Subscript.
    // Return a row.
    // Use insertElement and extractElement.
    // Dynamic row indices cannot feed extract/insertelement directly, so a
    // temp scratch array in the entry block is used as a spill slot.
    ArrayType *AT = ArrayType::get(EltTy, col*row);
    IRBuilder<> AllocaBuilder(
        matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
    Value *zero = AllocaBuilder.getInt32(0);
    bool isDynamicIndexing = !isa<ConstantInt>(mask);
    // Rewrite every user; iterator advanced before use since users die.
    for (Value::use_iterator CallUI = matSubInst->use_begin(),
                             CallE = matSubInst->use_end();
         CallUI != CallE;) {
      Use &CallUse = *CallUI++;
      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
      Value *idx = mask;
      IRBuilder<> Builder(CallUser);
      Value *vecLd = Builder.CreateLoad(vecInst);
      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
        Value *sub = UndefValue::get(ld->getType());
        if (!isDynamicIndexing) {
          // Constant row index: gather the row straight from the vector.
          for (unsigned i = 0; i < col; i++) {
            // col major: matIdx = c * row + r;
            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
            Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
            sub = Builder.CreateInsertElement(sub, valElt, i);
          }
        } else {
          // Copy vec to array.
          for (unsigned int i = 0; i < row*col; i++) {
            Value *Elt =
                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                                                   {zero, Builder.getInt32(i)});
            Builder.CreateStore(Elt, Ptr);
          }
          // Gather the row from the array through dynamic GEPs.
          for (unsigned i = 0; i < col; i++) {
            // col major: matIdx = c * row + r;
            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
            Value *valElt = Builder.CreateLoad(Ptr);
            sub = Builder.CreateInsertElement(sub, valElt, i);
          }
        }
        ld->replaceAllUsesWith(sub);
      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
        Value *val = st->getValueOperand();
        if (!isDynamicIndexing) {
          // Constant row index: scatter the stored row into the vector.
          for (unsigned i = 0; i < col; i++) {
            // col major: matIdx = c * row + r;
            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
            Value *valElt = Builder.CreateExtractElement(val, i);
            vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
          }
        } else {
          // Copy vec to array.
          for (unsigned int i = 0; i < row * col; i++) {
            Value *Elt =
                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                                                   {zero, Builder.getInt32(i)});
            Builder.CreateStore(Elt, Ptr);
          }
          // Update array.
          for (unsigned i = 0; i < col; i++) {
            // col major: matIdx = c * row + r;
            Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
            Value *valElt = Builder.CreateExtractElement(val, i);
            Builder.CreateStore(valElt, Ptr);
          }
          // Copy array to vec.
          for (unsigned int i = 0; i < row * col; i++) {
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                                                   {zero, Builder.getInt32(i)});
            Value *Elt = Builder.CreateLoad(Ptr);
            vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
          }
        }
        Builder.CreateStore(vecLd, vecInst);
      } else if (GetElementPtrInst *GEP =
                     dyn_cast<GetElementPtrInst>(CallUser)) {
        // Must be for subscript on vector
        auto idxIter = GEP->idx_begin();
        // skip the zero
        idxIter++;
        Value *gepIdx = *idxIter;
        // Col major matIdx = r + c * row; r is idx, c is gepIdx
        Value *iMulRow = Builder.CreateMul(gepIdx, Builder.getInt32(row));
        Value *vecIdx = Builder.CreateAdd(iMulRow, idx);
        // NOTE: this `zero` intentionally shadows the outer i32 zero; it is
        // typed to match vecIdx.
        llvm::Constant *zero = llvm::ConstantInt::get(vecIdx->getType(), 0);
        Value *NewGEP = Builder.CreateGEP(vecInst, {zero, vecIdx});
        GEP->replaceAllUsesWith(NewGEP);
      } else
        DXASSERT(0, "matrix subscript should only used by load/store.");
      AddToDeadInsts(CallUser);
    }
  }
  // Check vec version.
  DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  // All the user should has been removed.
  matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  AddToDeadInsts(matSubInst);
}
  1274. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1275. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1276. CallInst *matLdStInst) {
  1277. // No dynamic indexing on matrix, flatten matrix to scalars.
  1278. Type *matType = matGlobal->getType()->getPointerElementType();
  1279. unsigned col, row;
  1280. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1281. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1282. IRBuilder<> Builder(matLdStInst);
  1283. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1284. switch (opcode) {
  1285. case HLMatLoadStoreOpcode::ColMatLoad:
  1286. case HLMatLoadStoreOpcode::RowMatLoad: {
  1287. Value *Result = UndefValue::get(vecType);
  1288. for (unsigned c = 0; c < col; c++)
  1289. for (unsigned r = 0; r < row; r++) {
  1290. unsigned matIdx = c * row + r;
  1291. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1292. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1293. }
  1294. matLdStInst->replaceAllUsesWith(Result);
  1295. } break;
  1296. case HLMatLoadStoreOpcode::ColMatStore:
  1297. case HLMatLoadStoreOpcode::RowMatStore: {
  1298. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1299. for (unsigned c = 0; c < col; c++)
  1300. for (unsigned r = 0; r < row; r++) {
  1301. unsigned matIdx = c * row + r;
  1302. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1303. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1304. }
  1305. } break;
  1306. }
  1307. }
  1308. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
  1309. GlobalVariable *scalarArrayGlobal,
  1310. CallInst *matLdStInst) {
  1311. HLMatLoadStoreOpcode opcode =
  1312. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1313. switch (opcode) {
  1314. case HLMatLoadStoreOpcode::ColMatLoad:
  1315. case HLMatLoadStoreOpcode::RowMatLoad: {
  1316. IRBuilder<> Builder(matLdStInst);
  1317. Type *matTy = matGlobal->getType()->getPointerElementType();
  1318. unsigned col, row;
  1319. Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1320. Value *zeroIdx = Builder.getInt32(0);
  1321. std::vector<Value *> matElts(col * row);
  1322. for (unsigned c = 0; c < col; c++)
  1323. for (unsigned r = 0; r < row; r++) {
  1324. unsigned matIdx = c * row + r;
  1325. Value *GEP = Builder.CreateInBoundsGEP(
  1326. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1327. matElts[matIdx] = Builder.CreateLoad(GEP);
  1328. }
  1329. Value *newVec =
  1330. HLMatrixLower::BuildMatrix(EltTy, col, row, false, matElts, Builder);
  1331. matLdStInst->replaceAllUsesWith(newVec);
  1332. matLdStInst->eraseFromParent();
  1333. } break;
  1334. case HLMatLoadStoreOpcode::ColMatStore:
  1335. case HLMatLoadStoreOpcode::RowMatStore: {
  1336. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1337. IRBuilder<> Builder(matLdStInst);
  1338. Type *matTy = matGlobal->getType()->getPointerElementType();
  1339. unsigned col, row;
  1340. HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1341. Value *zeroIdx = Builder.getInt32(0);
  1342. std::vector<Value *> matElts(col * row);
  1343. for (unsigned c = 0; c < col; c++)
  1344. for (unsigned r = 0; r < row; r++) {
  1345. unsigned matIdx = c * row + r;
  1346. Value *GEP = Builder.CreateInBoundsGEP(
  1347. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1348. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1349. Builder.CreateStore(Elt, GEP);
  1350. }
  1351. matLdStInst->eraseFromParent();
  1352. } break;
  1353. }
  1354. }
// Lower a matrix-subscript HL call on a global matrix whose storage was
// flattened into a scalar-array global. The matrix is kept col-major in the
// scalar array, so element (r, c) lives at flat index c * row + r.
// Collects one scalar pointer per subscripted element, then rewrites every
// user of the subscript to go through those scalar pointers.
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
                                                      GlobalVariable *scalarArrayGlobal,
                                                      CallInst *matSubInst) {
  // Subscript index operand: a row index, or a packed (row, col) index list.
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = matGlobal->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // One scalar pointer per element produced by the subscript.
  std::vector<Value *> Ptrs;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Use col major for internal matrix.
    // And subscripts will return a row.
    // Row `idx` is one element per column, at flat index c * row + idx.
    for (unsigned c = 0; c < col; c++) {
      Value *colIdxBase = subBuilder.getInt32(c * row);
      Value *matIdx = subBuilder.CreateAdd(colIdxBase, idx);
      Value *Ptr =
          subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
      Ptrs.emplace_back(Ptr);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    // Use col major for internal matrix.
    // Element access: idx is a constant list of interleaved (row, col) pairs.
    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
      unsigned count = elts->getNumElements();
      for (unsigned i = 0; i < count; i += 2) {
        unsigned rowIdx = elts->getElementAsInteger(i);
        unsigned colIdx = elts->getElementAsInteger(i + 1);
        Value *matIdx = subBuilder.getInt32(colIdx * row + rowIdx);
        Value *Ptr =
            subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
        Ptrs.emplace_back(Ptr);
      }
    } else {
      // All-zero index list: every accessed element is (0, 0).
      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
      unsigned size = zeros->getNumElements() >> 1;
      for (unsigned i = 0; i < size; i++) {
        Value *Ptr =
            subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, zeroIdx});
        Ptrs.emplace_back(Ptr);
      }
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate vector pointer
  // Replace all uses with scalar pointers.
  if (Ptrs.size() == 1)
    matSubInst->replaceAllUsesWith(Ptrs[0]);
  else {
    // Split the use of CI with Ptrs.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // GEP selects a single element of the subscript result.
        DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
        Value *baseIdx = (GEP->idx_begin())->get();
        DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
        Value *idx = (GEP->idx_begin() + 1)->get();
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
              // Constant index: store straight to the selected scalar slot.
              Value *Ptr = Ptrs[immIdx->getSExtValue()];
              gepUserBuilder.CreateStore(subData, Ptr);
            } else {
              // Dynamic index: spill all elements into a temp array, store
              // through the dynamic index, then copy the array back out.
              // Create a temp array.
              IRBuilder<> allocaBuilder(stUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
              Value *tempArray = allocaBuilder.CreateAlloca(
                  ArrayType::get(subData->getType(), Ptrs.size()));
              // Store value to temp array.
              for (unsigned i = 0; i < Ptrs.size(); i++) {
                Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
                gepUserBuilder.CreateStore(Elt, EltGEP);
              }
              // Dynamic indexing.
              Value *subGEP =
                  gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
              gepUserBuilder.CreateStore(subData, subGEP);
              // Store temp array to value.
              for (unsigned i = 0; i < Ptrs.size(); i++) {
                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
                Value *Elt = gepUserBuilder.CreateLoad(EltGEP);
                gepUserBuilder.CreateStore(Elt, Ptrs[i]);
              }
            }
            stUser->eraseFromParent();
          } else {
            // Must be load here;
            LoadInst *ldUser = cast<LoadInst>(gepUser);
            Value *subData = nullptr;
            if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
              // Constant index: load straight from the selected scalar slot.
              Value *Ptr = Ptrs[immIdx->getSExtValue()];
              subData = gepUserBuilder.CreateLoad(Ptr);
            } else {
              // Dynamic index: spill all elements into a temp array, then
              // load through the dynamic index.
              // Create a temp array.
              IRBuilder<> allocaBuilder(ldUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
              Value *tempArray = allocaBuilder.CreateAlloca(
                  ArrayType::get(ldUser->getType(), Ptrs.size()));
              // Store value to temp array.
              for (unsigned i = 0; i < Ptrs.size(); i++) {
                Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
                Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
                gepUserBuilder.CreateStore(Elt, EltGEP);
              }
              // Dynamic indexing.
              Value *subGEP =
                  gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
              subData = gepUserBuilder.CreateLoad(subGEP);
            }
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Whole-result store: scatter each vector lane to its scalar slot.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < Ptrs.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          userBuilder.CreateStore(Elt, Ptrs[i]);
        }
        stUser->eraseFromParent();
      } else {
        // Whole-result load: gather the scalars back into a vector.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < Ptrs.size(); i++) {
          Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
// Lower a matrix subscript through a pointer whose matrix storage was
// replaced by an array of row vectors.
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr) {
  // Just translate into vec array here.
  // DynamicIndexingVectorToArray will change it to scalar array.
  Value *basePtr = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  unsigned opcode = hlsl::GetHLOpcode(matSubInst);
  HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
  // Vector array is inside struct.
  Value *zeroIdx = subBuilder.getInt32(0);
  Value *vecArrayGep = vecPtr;
  Type *matType = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matType, col, row);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  // One pointer per element produced by the subscript.
  std::vector<Value *> Ptrs;
  switch (subOp) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Vector array is row major.
    // And subscripts will return a row.
    // A row subscript maps directly onto one entry of the row-vector array,
    // so a single GEP replaces the whole subscript call.
    Value *rowIdx = idx;
    Value *subPtr = subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx});
    matSubInst->replaceAllUsesWith(subPtr);
    matSubInst->eraseFromParent();
    return;
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    // Vector array is row major.
    // Element access: idx is a constant list of interleaved (row, col) pairs.
    if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
      unsigned count = elts->getNumElements();
      for (unsigned i = 0; i < count; i += 2) {
        Value *rowIdx = subBuilder.getInt32(elts->getElementAsInteger(i));
        Value *colIdx = subBuilder.getInt32(elts->getElementAsInteger(i + 1));
        Value *Ptr =
            subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx, colIdx});
        Ptrs.emplace_back(Ptr);
      }
    } else {
      // All-zero index list: every accessed element is (0, 0).
      ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
      unsigned size = zeros->getNumElements() >> 1;
      for (unsigned i = 0; i < size; i++) {
        Value *Ptr =
            subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, zeroIdx, zeroIdx});
        Ptrs.emplace_back(Ptr);
      }
    }
  } break;
  default:
    DXASSERT(0, "invalid operation for TranslateMatSubscriptOnGlobalPtr");
    break;
  }
  // Replace the subscript with the collected element pointers.
  if (Ptrs.size() == 1)
    matSubInst->replaceAllUsesWith(Ptrs[0]);
  else {
    // Split the use of CI with Ptrs.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // GEP selects a single element of the subscript result.
        DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
        Value *baseIdx = (GEP->idx_begin())->get();
        DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
        Value *idx = (GEP->idx_begin() + 1)->get();
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            // Only element can reach here.
            // So index must be imm.
            ConstantInt *immIdx = cast<ConstantInt>(idx);
            Value *Ptr = Ptrs[immIdx->getSExtValue()];
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else {
            // Must be load here;
            LoadInst *ldUser = cast<LoadInst>(gepUser);
            Value *subData = nullptr;
            // Only element can reach here.
            // So index must be imm.
            ConstantInt *immIdx = cast<ConstantInt>(idx);
            Value *Ptr = Ptrs[immIdx->getSExtValue()];
            subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Whole-result store: scatter each vector lane to its element slot.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < Ptrs.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          userBuilder.CreateStore(Elt, Ptrs[i]);
        }
        stUser->eraseFromParent();
      } else {
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        // reload the value.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < Ptrs.size(); i++) {
          Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
// Lower a matrix load/store through a pointer whose matrix storage was
// replaced by an array of row vectors.
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
    CallInst *matLdStInst, Value *vecPtr) {
  // Just translate into vec array here.
  // DynamicIndexingVectorToArray will change it to scalar array.
  IRBuilder<> Builder(matLdStInst);
  unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  switch (matLdStOp) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    // Load as vector array.
    Value *newLoad = Builder.CreateLoad(vecPtr);
    // Then change to vector.
    // Use col major.
    Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    Type *matTy = Ptr->getType()->getPointerElementType();
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(matTy, col, row);
    Value *NewVal = UndefValue::get(matLdStInst->getType());
    // Vector array is row major: transpose the loaded row vectors into the
    // flat col-major vector that lowered matrix ops expect.
    for (unsigned r = 0; r < row; r++) {
      Value *eltRow = Builder.CreateExtractValue(newLoad, r);
      for (unsigned c = 0; c < col; c++) {
        Value *elt = Builder.CreateExtractElement(eltRow, c);
        // Vector is col major.
        unsigned matIdx = c * row + r;
        NewVal = Builder.CreateInsertElement(NewVal, elt, matIdx);
      }
    }
    matLdStInst->replaceAllUsesWith(NewVal);
    matLdStInst->eraseFromParent();
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    // Change value to vector array, then store.
    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
    Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    Type *matTy = Ptr->getType()->getPointerElementType();
    unsigned col, row;
    Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
    // NOTE(review): rowTy is sized with `row` while the inner loop below
    // inserts `col` elements per row vector — verify this is correct for
    // non-square matrices.
    Type *rowTy = VectorType::get(EltTy, row);
    Value *vecArrayGep = vecPtr;
    Value *NewVal =
        UndefValue::get(vecArrayGep->getType()->getPointerElementType());
    // Vector array val is row major: rebuild each row vector from the flat
    // col-major value, then aggregate them and store once.
    for (unsigned r = 0; r < row; r++) {
      Value *NewElt = UndefValue::get(rowTy);
      for (unsigned c = 0; c < col; c++) {
        // Vector val is col major.
        unsigned matIdx = c * row + r;
        Value *elt = Builder.CreateExtractElement(Val, matIdx);
        NewElt = Builder.CreateInsertElement(NewElt, elt, c);
      }
      NewVal = Builder.CreateInsertValue(NewVal, NewElt, r);
    }
    Builder.CreateStore(NewVal, vecArrayGep);
    matLdStInst->eraseFromParent();
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
}
// Flatten values inside init list to scalar elements.
// Recursively walks `val` (scalars, vectors, matrices, arrays, structs, and
// pointers to any of those) and writes each scalar element into `elts`
// starting at `idx`, advancing `idx` past everything written. Matrix values
// are replaced by their lowered vector form via `matToVecMap`.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Instruction *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      if (matToVecMap.count(cast<Instruction>(val))) {
        // Use the already-lowered vector-array version.
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Loadable pointee: load it and recurse on the loaded value.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointee: recurse into each array element / struct field.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    // Use the lowered vector form of the matrix value.
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all col major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Vector: copy each lane out.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      // Scalar: append directly.
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
  1744. // Store flattened init list elements into matrix array.
  1745. static void GenerateMatArrayInit(ArrayRef<Value *> elts, Value *ptr,
  1746. unsigned &offset, IRBuilder<> &Builder) {
  1747. Type *Ty = ptr->getType()->getPointerElementType();
  1748. if (Ty->isVectorTy()) {
  1749. unsigned vecSize = Ty->getVectorNumElements();
  1750. Type *eltTy = Ty->getVectorElementType();
  1751. Value *result = UndefValue::get(Ty);
  1752. for (unsigned i = 0; i < vecSize; i++) {
  1753. Value *elt = elts[offset + i];
  1754. if (elt->getType() != eltTy) {
  1755. // FIXME: get signed/unsigned info.
  1756. elt = CreateTypeCast(HLCastOpcode::DefaultCast, eltTy, elt, Builder);
  1757. }
  1758. result = Builder.CreateInsertElement(result, elt, i);
  1759. }
  1760. // Update offset.
  1761. offset += vecSize;
  1762. Builder.CreateStore(result, ptr);
  1763. } else {
  1764. DXASSERT(Ty->isArrayTy(), "must be array type");
  1765. Type *i32Ty = Type::getInt32Ty(Ty->getContext());
  1766. Constant *zero = ConstantInt::get(i32Ty, 0);
  1767. unsigned arraySize = Ty->getArrayNumElements();
  1768. for (unsigned i = 0; i < arraySize; i++) {
  1769. Value *GEP =
  1770. Builder.CreateInBoundsGEP(ptr, {zero, ConstantInt::get(i32Ty, i)});
  1771. GenerateMatArrayInit(elts, GEP, offset, Builder);
  1772. }
  1773. }
  1774. }
// Lower a non-array matrix initializer call into a flat col-major vector.
void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  // Array matrix init will be translated in TranslateMatArrayInitReplace.
  if (matInitInst->getType()->isVoidTy())
    return;
  IRBuilder<> Builder(matInitInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  unsigned idx = 0;
  // Flattened scalar elements of the init list, in source (row-major) order.
  std::vector<Value *> elts(vecSize);
  // Skip opcode arg.
  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
    Value *val = matInitInst->getArgOperand(i);
    IterateInitList(elts, idx, val, matToVecMap, Builder);
  }
  Value *newInit = UndefValue::get(vecTy);
  // InitList is row major, the result is col major.
  for (unsigned c = 0; c < col; c++)
    for (unsigned r = 0; r < row; r++) {
      unsigned rowMajorIdx = r * col + c;
      unsigned colMajorIdx = c * row + r;
      Constant *vecIdx = Builder.getInt32(colMajorIdx);
      // Create the insertelement directly rather than via
      // Builder.CreateInsertElement: the cast<Instruction> below requires an
      // Instruction result, which the builder could constant-fold away when
      // all operands are constants.
      newInit = InsertElementInst::Create(newInit, elts[rowMajorIdx], vecIdx);
      Builder.Insert(cast<Instruction>(newInit));
    }
  // Replace matInit function call with matInitInst.
  DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  vecUseInst->replaceAllUsesWith(newInit);
  AddToDeadInsts(vecUseInst);
  matToVecMap[matInitInst] = newInit;
}
// Lower a matrix ternary select (cond ? lhs : rhs) into per-element scalar
// selects over the flattened vectors. The condition is either a scalar
// (shared by every element) or a matrix (per-element conditions).
void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  IRBuilder<> Builder(matSelectInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  // The pre-created vector clone of this select call.
  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  bool isVecCond = Cond->getType()->isVectorTy();
  if (isVecCond) {
    // Matrix condition: use its lowered vector form.
    Instruction *MatCond = cast<Instruction>(
        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
    DXASSERT_NOMSG(matToVecMap.count(MatCond));
    Cond = matToVecMap[MatCond];
  }
  DXASSERT_NOMSG(matToVecMap.count(LHS));
  Value *VLHS = matToVecMap[LHS];
  DXASSERT_NOMSG(matToVecMap.count(RHS));
  Value *VRHS = matToVecMap[RHS];
  // Build the element-wise select chain.
  Value *VecSelect = UndefValue::get(vecTy);
  for (unsigned i = 0; i < vecSize; i++) {
    llvm::Value *EltCond = Cond;
    if (isVecCond)
      EltCond = Builder.CreateExtractElement(Cond, i);
    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  }
  // Retire the placeholder vector call and record the lowered result.
  AddToDeadInsts(vecUseInst);
  vecUseInst->replaceAllUsesWith(VecSelect);
  matToVecMap[matSelectInst] = VecSelect;
}
  1843. void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
  1844. Instruction *vecInst,
  1845. GetElementPtrInst *matGEP) {
  1846. SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  1847. IRBuilder<> GEPBuilder(matGEP);
  1848. Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecInst, idxList);
  1849. // Only used by mat subscript and mat ld/st.
  1850. for (Value::user_iterator user = matGEP->user_begin();
  1851. user != matGEP->user_end();) {
  1852. Instruction *useInst = cast<Instruction>(*(user++));
  1853. IRBuilder<> Builder(useInst);
  1854. // Skip return here.
  1855. if (isa<ReturnInst>(useInst))
  1856. continue;
  1857. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1858. // Function call.
  1859. hlsl::HLOpcodeGroup group =
  1860. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1861. switch (group) {
  1862. case HLOpcodeGroup::HLMatLoadStore: {
  1863. unsigned opcode = GetHLOpcode(useCall);
  1864. HLMatLoadStoreOpcode matOpcode =
  1865. static_cast<HLMatLoadStoreOpcode>(opcode);
  1866. switch (matOpcode) {
  1867. case HLMatLoadStoreOpcode::ColMatLoad:
  1868. case HLMatLoadStoreOpcode::RowMatLoad: {
  1869. // Skip the vector version.
  1870. if (useCall->getType()->isVectorTy())
  1871. continue;
  1872. Value *newLd = Builder.CreateLoad(newGEP);
  1873. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1874. Value *oldLd = matToVecMap[useCall];
  1875. // Delete the oldLd.
  1876. AddToDeadInsts(cast<Instruction>(oldLd));
  1877. oldLd->replaceAllUsesWith(newLd);
  1878. matToVecMap[useCall] = newLd;
  1879. } break;
  1880. case HLMatLoadStoreOpcode::ColMatStore:
  1881. case HLMatLoadStoreOpcode::RowMatStore: {
  1882. Value *vecPtr = newGEP;
  1883. Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1884. // Skip the vector version.
  1885. if (matVal->getType()->isVectorTy()) {
  1886. AddToDeadInsts(useCall);
  1887. continue;
  1888. }
  1889. Instruction *matInst = cast<Instruction>(matVal);
  1890. DXASSERT(matToVecMap.count(matInst), "must has vec version");
  1891. Value *vecVal = matToVecMap[matInst];
  1892. Builder.CreateStore(vecVal, vecPtr);
  1893. } break;
  1894. }
  1895. } break;
  1896. case HLOpcodeGroup::HLSubscript: {
  1897. TranslateMatSubscript(matGEP, newGEP, useCall);
  1898. } break;
  1899. default:
  1900. DXASSERT(0, "invalid operation");
  1901. break;
  1902. }
  1903. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1904. // Just replace the src with vec version.
  1905. useInst->setOperand(0, newGEP);
  1906. } else {
  1907. // Must be GEP.
  1908. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1909. TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
  1910. }
  1911. }
  1912. AddToDeadInsts(matGEP);
  1913. }
  1914. void HLMatrixLowerPass::replaceMatWithVec(Instruction *matInst,
  1915. Instruction *vecInst) {
  1916. for (Value::user_iterator user = matInst->user_begin();
  1917. user != matInst->user_end();) {
  1918. Instruction *useInst = cast<Instruction>(*(user++));
  1919. // Skip return here.
  1920. if (isa<ReturnInst>(useInst))
  1921. continue;
  1922. // User must be function call.
  1923. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1924. hlsl::HLOpcodeGroup group =
  1925. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1926. switch (group) {
  1927. case HLOpcodeGroup::HLIntrinsic: {
  1928. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1929. } break;
  1930. case HLOpcodeGroup::HLSelect: {
  1931. MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
  1932. } break;
  1933. case HLOpcodeGroup::HLBinOp: {
  1934. TrivialMatBinOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1935. } break;
  1936. case HLOpcodeGroup::HLUnOp: {
  1937. TrivialMatUnOpReplace(cast<CallInst>(matInst), vecInst, useCall);
  1938. } break;
  1939. case HLOpcodeGroup::HLCast: {
  1940. TranslateMatCast(cast<CallInst>(matInst), vecInst, useCall);
  1941. } break;
  1942. case HLOpcodeGroup::HLMatLoadStore: {
  1943. DXASSERT(matToVecMap.count(useCall), "must has vec version");
  1944. Value *vecUser = matToVecMap[useCall];
  1945. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  1946. // Load Already translated in lowerToVec.
  1947. // Store val operand will be set by the val use.
  1948. // Do nothing here.
  1949. } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
  1950. stInst->setOperand(0, vecInst);
  1951. else
  1952. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1953. } break;
  1954. case HLOpcodeGroup::HLSubscript: {
  1955. if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst))
  1956. TranslateMatSubscript(AI, vecInst, useCall);
  1957. else
  1958. TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
  1959. } break;
  1960. case HLOpcodeGroup::HLInit: {
  1961. DXASSERT(!isa<AllocaInst>(matInst), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
  1962. TranslateMatInit(useCall);
  1963. } break;
  1964. }
  1965. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
  1966. // Just replace the src with vec version.
  1967. useInst->setOperand(0, vecInst);
  1968. } else {
  1969. // Must be GEP on mat array alloca.
  1970. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  1971. AllocaInst *AI = cast<AllocaInst>(matInst);
  1972. TranslateMatArrayGEP(AI, vecInst, GEP);
  1973. }
  1974. }
  1975. }
  1976. void HLMatrixLowerPass::finalMatTranslation(Instruction *matInst) {
  1977. // Translate matInit.
  1978. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  1979. hlsl::HLOpcodeGroup group =
  1980. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  1981. switch (group) {
  1982. case HLOpcodeGroup::HLInit: {
  1983. TranslateMatInit(CI);
  1984. } break;
  1985. case HLOpcodeGroup::HLSelect: {
  1986. TranslateMatSelect(CI);
  1987. } break;
  1988. }
  1989. }
  1990. }
  1991. void HLMatrixLowerPass::DeleteDeadInsts() {
  1992. // Delete the matrix version insts.
  1993. for (Instruction *deadInst : m_deadInsts) {
  1994. // Replace with undef and remove it.
  1995. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  1996. deadInst->eraseFromParent();
  1997. }
  1998. m_deadInsts.clear();
  1999. m_inDeadInstsSet.clear();
  2000. }
  2001. static bool OnlyUsedByMatrixLdSt(Value *V) {
  2002. bool onlyLdSt = true;
  2003. for (User *user : V->users()) {
  2004. CallInst *CI = cast<CallInst>(user);
  2005. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  2006. HLOpcodeGroup::HLMatLoadStore)
  2007. continue;
  2008. onlyLdSt = false;
  2009. break;
  2010. }
  2011. return onlyLdSt;
  2012. }
  2013. void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  2014. // Lower to array of vector array like float[row][col].
  2015. // DynamicIndexingVectorToArray will change it to scalar array.
  2016. Type *Ty = GV->getType()->getPointerElementType();
  2017. std::vector<unsigned> arraySizeList;
  2018. while (Ty->isArrayTy()) {
  2019. arraySizeList.push_back(Ty->getArrayNumElements());
  2020. Ty = Ty->getArrayElementType();
  2021. }
  2022. unsigned row, col;
  2023. Type *EltTy = GetMatrixInfo(Ty, col, row);
  2024. Ty = VectorType::get(EltTy, col);
  2025. Ty = ArrayType::get(Ty, row);
  2026. for (auto arraySize = arraySizeList.rbegin();
  2027. arraySize != arraySizeList.rend(); arraySize++)
  2028. Ty = ArrayType::get(Ty, *arraySize);
  2029. Type *VecArrayTy = Ty;
  2030. // Matrix will use store to initialize.
  2031. // So set init val to undef.
  2032. Constant *InitVal = UndefValue::get(VecArrayTy);
  2033. bool isConst = GV->isConstant();
  2034. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  2035. unsigned AddressSpace = GV->getType()->getAddressSpace();
  2036. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  2037. Module *M = GV->getParent();
  2038. GlobalVariable *VecGV =
  2039. new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
  2040. /*InitVal*/ InitVal, GV->getName() + ".v",
  2041. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2042. // Add debug info.
  2043. if (m_HasDbgInfo) {
  2044. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  2045. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  2046. }
  2047. DenseMap<Instruction *, Value *> matToVecMap;
  2048. for (User *U : GV->users()) {
  2049. Value *VecGEP = nullptr;
  2050. // Must be GEP or GEPOperator.
  2051. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  2052. IRBuilder<> Builder(GEP);
  2053. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  2054. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  2055. AddToDeadInsts(GEP);
  2056. } else {
  2057. GEPOperator *GEPOP = cast<GEPOperator>(U);
  2058. IRBuilder<> Builder(GV->getContext());
  2059. SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
  2060. VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
  2061. }
  2062. for (auto user = U->user_begin(); user != U->user_end();) {
  2063. CallInst *CI = cast<CallInst>(*(user++));
  2064. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2065. if (group == HLOpcodeGroup::HLMatLoadStore) {
  2066. TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
  2067. } else if (group == HLOpcodeGroup::HLSubscript) {
  2068. TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
  2069. } else {
  2070. DXASSERT(0, "invalid operation");
  2071. }
  2072. }
  2073. }
  2074. DeleteDeadInsts();
  2075. GV->removeDeadConstantUsers();
  2076. GV->eraseFromParent();
  2077. }
  2078. void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  2079. if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
  2080. runOnGlobalMatrixArray(GV);
  2081. return;
  2082. }
  2083. Type *Ty = GV->getType()->getPointerElementType();
  2084. if (!HLMatrixLower::IsMatrixType(Ty))
  2085. return;
  2086. bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  2087. bool isConst = GV->isConstant();
  2088. Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  2089. Module *M = GV->getParent();
  2090. const DataLayout &DL = M->getDataLayout();
  2091. if (onlyLdSt) {
  2092. Type *EltTy = vecTy->getVectorElementType();
  2093. unsigned vecSize = vecTy->getVectorNumElements();
  2094. std::vector<Value *> vecGlobals(vecSize);
  2095. // Matrix will use store to initialize.
  2096. // So set init val to undef.
  2097. Constant *InitVal = UndefValue::get(EltTy);
  2098. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  2099. unsigned AddressSpace = GV->getType()->getAddressSpace();
  2100. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  2101. unsigned debugOffset = 0;
  2102. unsigned size = DL.getTypeAllocSizeInBits(EltTy);
  2103. unsigned align = DL.getPrefTypeAlignment(EltTy);
  2104. for (int i = 0, e = vecSize; i != e; ++i) {
  2105. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2106. *M, EltTy, /*IsConstant*/ isConst, linkage,
  2107. /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
  2108. /*InsertBefore*/nullptr,
  2109. TLMode, AddressSpace);
  2110. // Add debug info.
  2111. if (m_HasDbgInfo) {
  2112. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  2113. HLModule::CreateElementGlobalVariableDebugInfo(
  2114. GV, Finder, EltGV, size, align, debugOffset,
  2115. EltGV->getName().ltrim(GV->getName()));
  2116. debugOffset += size;
  2117. }
  2118. vecGlobals[i] = EltGV;
  2119. }
  2120. for (User *user : GV->users()) {
  2121. CallInst *CI = cast<CallInst>(user);
  2122. TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
  2123. AddToDeadInsts(CI);
  2124. }
  2125. DeleteDeadInsts();
  2126. GV->eraseFromParent();
  2127. }
  2128. else {
  2129. // lower to array of scalar here.
  2130. ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
  2131. GlobalVariable *arrayMat = new llvm::GlobalVariable(
  2132. *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
  2133. /*InitVal*/ UndefValue::get(AT), GV->getName());
  2134. // Add debug info.
  2135. if (m_HasDbgInfo) {
  2136. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  2137. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
  2138. arrayMat);
  2139. }
  2140. for (auto U = GV->user_begin(); U != GV->user_end();) {
  2141. Value *user = *(U++);
  2142. CallInst *CI = cast<CallInst>(user);
  2143. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2144. if (group == HLOpcodeGroup::HLMatLoadStore) {
  2145. TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
  2146. }
  2147. else {
  2148. DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
  2149. TranslateMatSubscriptOnGlobal(GV, arrayMat, CI);
  2150. }
  2151. }
  2152. GV->removeDeadConstantUsers();
  2153. GV->eraseFromParent();
  2154. }
  2155. }
  2156. void HLMatrixLowerPass::runOnFunction(Function &F) {
  2157. // Create vector version of matrix instructions first.
  2158. // The matrix operands will be undefval for these instructions.
  2159. for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
  2160. BasicBlock *BB = BBI;
  2161. for (Instruction &I : BB->getInstList()) {
  2162. if (IsMatrixType(I.getType())) {
  2163. lowerToVec(&I);
  2164. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
  2165. Type *Ty = AI->getAllocatedType();
  2166. if (HLMatrixLower::IsMatrixType(Ty)) {
  2167. lowerToVec(&I);
  2168. } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
  2169. lowerToVec(&I);
  2170. }
  2171. }
  2172. }
  2173. }
  2174. // Update the use of matrix inst with the vector version.
  2175. for (auto matToVecIter = matToVecMap.begin();
  2176. matToVecIter != matToVecMap.end();) {
  2177. auto matToVec = matToVecIter++;
  2178. replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  2179. }
  2180. // Translate mat inst which require all operands ready.
  2181. for (auto matToVecIter = matToVecMap.begin();
  2182. matToVecIter != matToVecMap.end();) {
  2183. auto matToVec = matToVecIter++;
  2184. finalMatTranslation(matToVec->first);
  2185. }
  2186. // Delete the matrix version insts.
  2187. for (auto matToVecIter = matToVecMap.begin();
  2188. matToVecIter != matToVecMap.end();) {
  2189. auto matToVec = matToVecIter++;
  2190. // Add to m_deadInsts.
  2191. Instruction *matInst = matToVec->first;
  2192. AddToDeadInsts(matInst);
  2193. }
  2194. DeleteDeadInsts();
  2195. matToVecMap.clear();
  2196. }