HLMatrixLowerPass.cpp 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HLSL/DxilUtil.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/HLSL/DxilOperations.h"
  19. #include "llvm/IR/IRBuilder.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/IR/DebugInfo.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/Transforms/Utils/Local.h"
  24. #include "llvm/Pass.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include <unordered_set>
  27. #include <vector>
  28. using namespace llvm;
  29. using namespace hlsl;
  30. using namespace hlsl::HLMatrixLower;
  31. namespace hlsl {
  32. namespace HLMatrixLower {
  33. bool IsMatrixType(Type *Ty) {
  34. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  35. Type *EltTy = ST->getElementType(0);
  36. if (!ST->getName().startswith("class.matrix"))
  37. return false;
  38. bool isVecArray = EltTy->isArrayTy() &&
  39. EltTy->getArrayElementType()->isVectorTy();
  40. return isVecArray && EltTy->getArrayNumElements() <= 4;
  41. }
  42. return false;
  43. }
  44. // Translate matrix type to vector type.
  45. Type *LowerMatrixType(Type *Ty) {
  46. // Only translate matrix type and function type which use matrix type.
  47. // Not translate struct has matrix or matrix pointer.
  48. // Struct should be flattened before.
  49. // Pointer could cover by matldst which use vector as value type.
  50. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  51. Type *RetTy = LowerMatrixType(FT->getReturnType());
  52. SmallVector<Type *, 4> params;
  53. for (Type *param : FT->params()) {
  54. params.emplace_back(LowerMatrixType(param));
  55. }
  56. return FunctionType::get(RetTy, params, false);
  57. } else if (IsMatrixType(Ty)) {
  58. unsigned row, col;
  59. Type *EltTy = GetMatrixInfo(Ty, col, row);
  60. return VectorType::get(EltTy, row * col);
  61. } else {
  62. return Ty;
  63. }
  64. }
  65. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  66. DXASSERT(IsMatrixType(Ty), "not matrix type");
  67. StructType *ST = cast<StructType>(Ty);
  68. Type *EltTy = ST->getElementType(0);
  69. Type *RowTy = EltTy->getArrayElementType();
  70. row = EltTy->getArrayNumElements();
  71. col = RowTy->getVectorNumElements();
  72. return RowTy->getVectorElementType();
  73. }
  74. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  75. if (!Ty->isPointerTy())
  76. return false;
  77. Ty = Ty->getPointerElementType();
  78. if (!Ty->isArrayTy())
  79. return false;
  80. while (Ty->isArrayTy())
  81. Ty = Ty->getArrayElementType();
  82. return IsMatrixType(Ty);
  83. }
  84. Type *LowerMatrixArrayPointer(Type *Ty) {
  85. Ty = Ty->getPointerElementType();
  86. std::vector<unsigned> arraySizeList;
  87. while (Ty->isArrayTy()) {
  88. arraySizeList.push_back(Ty->getArrayNumElements());
  89. Ty = Ty->getArrayElementType();
  90. }
  91. Ty = LowerMatrixType(Ty);
  92. for (auto arraySize = arraySizeList.rbegin();
  93. arraySize != arraySizeList.rend(); arraySize++)
  94. Ty = ArrayType::get(Ty, *arraySize);
  95. return PointerType::get(Ty, 0);
  96. }
  97. Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
  98. IRBuilder<> &Builder) {
  99. Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
  100. for (unsigned i = 0; i < size; i++)
  101. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  102. return Vec;
  103. }
// Given a two-level GEP into a matrix (base index 0 followed by an element
// index) and a precomputed list of lowered per-element indices, produce the
// index value to use on the lowered vector:
// - constant element index: directly return the matching IdxList entry;
// - dynamic element index: spill IdxList into a temporary stack array and
//   load the entry selected at runtime.
Value *LowerGEPOnMatIndexListToIndex(
    llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
  IRBuilder<> Builder(GEP);
  Value *zero = Builder.getInt32(0);
  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
  Value *baseIdx = (GEP->idx_begin())->get();
  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
  Value *Idx = (GEP->idx_begin() + 1)->get();
  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
    return IdxList[immIdx->getSExtValue()];
  } else {
    // Allocas belong in the function's entry block, not at the GEP.
    IRBuilder<> AllocaBuilder(
        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
    unsigned size = IdxList.size();
    // Store idxList to temp array.
    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
    for (unsigned i = 0; i < size; i++) {
      Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
      Builder.CreateStore(IdxList[i], EltPtr);
    }
    // Load the idx selected by the dynamic index.
    Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
    return Builder.CreateLoad(GEPOffset);
  }
}
  130. unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
  131. return c * row + r;
  132. }
  133. unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
  134. return r * col + c;
  135. }
  136. } // namespace HLMatrixLower
  137. } // namespace hlsl
namespace {
// Module pass that rewrites HL matrix values (the "class.matrix..." struct
// representation) and the HL intrinsics operating on them into plain vector
// form, for every function body and every static/groupshared global.
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}
  const char *getPassName() const override { return "HL matrix lower"; }
  bool runOnModule(Module &M) override {
    m_pModule = &M;
    m_pHLModule = &m_pModule->GetOrCreateHLModule();
    // Load up debug information, to cross-reference values and the instructions
    // used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
    for (Function &F : M.functions()) {
      if (F.isDeclaration())
        continue;
      runOnFunction(F);
    }
    // Collect the interesting globals before rewriting them; presumably
    // runOnGlobal adds/replaces globals, which would invalidate direct
    // iteration over M.globals() — TODO confirm.
    std::vector<GlobalVariable*> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (dxilutil::IsStaticGlobal(&GV) ||
          dxilutil::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }
    for (GlobalVariable *GV : staticGVs)
      runOnGlobal(GV);
    return true;
  }

private:
  Module *m_pModule;
  HLModule *m_pHLModule;
  bool m_HasDbgInfo;
  // Matrix-form instructions scheduled for deletion after translation.
  std::vector<Instruction *> m_deadInsts;
  // For instruction like matrix array init.
  // May use more than 1 matrix alloca inst.
  // This set is here to avoid put it into deadInsts more than once.
  std::unordered_set<Instruction *> m_inDeadInstsSet;
  // For most matrix instruction users, it will only have one matrix use.
  // Use vector to save deadInsts because vector is cheap.
  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  // In case instruction has more than one matrix use.
  // Use AddToDeadInstsWithDups to make sure it's not added to deadInsts more
  // than once.
  void AddToDeadInstsWithDups(Instruction *I) {
    if (m_inDeadInstsSet.count(I) == 0) {
      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
      m_inDeadInstsSet.insert(I);
      AddToDeadInsts(I);
    }
  }
  void runOnFunction(Function &F);
  void runOnGlobal(GlobalVariable *GV);
  void runOnGlobalMatrixArray(GlobalVariable *GV);
  // "XxxToVec" helpers create the vector-form clone of a matrix-form HL call;
  // "TranslateXxx" helpers rewrite uses once operands are available.
  Instruction *MatCastToVec(CallInst *CI);
  Instruction *MatLdStToVec(CallInst *CI);
  Instruction *MatSubscriptToVec(CallInst *CI);
  Instruction *MatFrExpToVec(CallInst *CI);
  Instruction *MatIntrinsicToVec(CallInst *CI);
  Instruction *TrivialMatUnOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatUnOpReplace(CallInst *matInst, Instruction *vecInst,
                             CallInst *matUseInst);
  Instruction *TrivialMatBinOpToVec(CallInst *CI);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatBinOpReplace(CallInst *matInst, Instruction *vecInst,
                              CallInst *matUseInst);
  // Replace matInst with vecInst on mulInst.
  void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
                          CallInst *mulInst, bool isSigned);
  void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
                    bool isSigned);
  // Replace matInst with vecInst on transposeInst.
  void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
                             CallInst *transposeInst);
  void TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
                               CallInst *determinantInst);
  void MatIntrinsicReplace(CallInst *matInst, Instruction *vecInst,
                           CallInst *matUseInst);
  // Replace matInst with vecInst on castInst.
  void TranslateMatMatCast(CallInst *matInst, Instruction *vecInst,
                           CallInst *castInst);
  void TranslateMatToOtherCast(CallInst *matInst, Instruction *vecInst,
                               CallInst *castInst);
  void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
                        CallInst *castInst);
  void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
                             CallInst *castInst, bool rowToCol, bool transpose);
  // Replace matInst with vecInst in matSubscript
  void TranslateMatSubscript(Value *matInst, Value *vecInst,
                             CallInst *matSubInst);
  // Replace matInst with vecInst
  void TranslateMatInit(CallInst *matInitInst);
  // Replace matInst with vecInst.
  void TranslateMatSelect(CallInst *matSelectInst);
  // Replace matInst with vecInst on matInitInst.
  void TranslateMatArrayGEP(Value *matInst, Instruction *vecInst,
                            GetElementPtrInst *matGEP);
  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
                                     CallInst *matLdStInst);
  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                                     CallInst *matLdStInst);
  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  // Replace matInst with vecInst on matUseInst.
  void TrivialMatReplace(CallInst *matInst, Instruction *vecInst,
                         CallInst *matUseInst);
  // Lower a matrix type instruction to a vector type instruction.
  void lowerToVec(Instruction *matInst);
  // Lower users of a matrix type instruction.
  void replaceMatWithVec(Instruction *matInst, Instruction *vecInst);
  // Translate mat inst which need all operands ready.
  void finalMatTranslation(Instruction *matInst);
  // Delete dead insts in m_deadInsts.
  void DeleteDeadInsts();
  // Map from matrix inst to its vector version.
  DenseMap<Instruction *, Value *> matToVecMap;
};
}
char HLMatrixLowerPass::ID = 0;
// Factory hook used by the pass-manager setup code.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  262. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  263. IRBuilder<> Builder) {
  264. // Cast to bool.
  265. if (toTy->getScalarType()->isIntegerTy() &&
  266. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  267. Type *fromTy = src->getType();
  268. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  269. Constant *zero;
  270. if (isFloat)
  271. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  272. else
  273. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  274. if (toTy->getScalarType() != toTy) {
  275. // Create constant vector.
  276. unsigned size = toTy->getVectorNumElements();
  277. std::vector<Constant *> zeros(size, zero);
  278. zero = llvm::ConstantVector::get(zeros);
  279. }
  280. if (isFloat)
  281. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  282. else
  283. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  284. }
  285. Type *eltToTy = toTy->getScalarType();
  286. Type *eltFromTy = src->getType()->getScalarType();
  287. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  288. castOp == HLCastOpcode::UnsignedUnsignedCast;
  289. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  290. castOp == HLCastOpcode::UnsignedUnsignedCast;
  291. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  292. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  293. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  294. }
// Create the vector-form counterpart of a matrix-involving HL cast call.
// - scalar/vector -> matrix: fully translated here (splat or shuffle to the
//   flat element count, then an element-type cast if needed);
// - matrix -> matrix where the source is a function Argument: clone the call
//   with only the return type lowered;
// - everything else: fall through to the generic MatIntrinsicToVec clone.
Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));
  bool ToMat = IsMatrixType(CI->getType());
  bool FromMat = IsMatrixType(op->getType());
  if (ToMat && !FromMat) {
    // Translate OtherToMat here.
    // Rest will be translated when replaced.
    unsigned col, row;
    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
    unsigned toSize = col * row;
    Instruction *sizeCast = nullptr;
    Type *FromTy = op->getType();
    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
    if (FromTy->isVectorTy()) {
      // Vector source: shuffle with an identity mask of the target width.
      std::vector<Constant *> MaskVec(toSize);
      for (size_t i = 0; i != toSize; ++i)
        MaskVec[i] = ConstantInt::get(I32Ty, i);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    } else {
      // Scalar source: insert into a 1-wide vector, then splat lane 0.
      op = Builder.CreateInsertElement(
          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
      Constant *zero = ConstantInt::get(I32Ty, 0);
      std::vector<Constant *> MaskVec(toSize, zero);
      Value *castMask = ConstantVector::get(MaskVec);
      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    }
    // Convert element type if the source scalar type differs.
    Instruction *typeCast = sizeCast;
    if (EltTy != FromTy->getScalarType()) {
      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
                                sizeCast, Builder);
    }
    return typeCast;
  } else if (FromMat && ToMat) {
    if (isa<Argument>(op)) {
      // Cast from mat to mat for an argument.
      IRBuilder<> Builder(CI);
      // Here only lower the return type to vector.
      Type *RetTy = LowerMatrixType(CI->getType());
      SmallVector<Type *, 4> params;
      for (Value *operand : CI->arg_operands()) {
        params.emplace_back(operand->getType());
      }
      Type *FT = FunctionType::get(RetTy, params, false);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      unsigned opcode = GetHLOpcode(CI);
      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
                                             group, opcode);
      SmallVector<Value *, 4> argList;
      for (Value *arg : CI->arg_operands()) {
        argList.emplace_back(arg);
      }
      return Builder.CreateCall(vecF, argList);
    }
  }
  return MatIntrinsicToVec(CI);
}
// Create the vector-form counterpart of a matrix load/store HL call.
// Loads/stores whose pointer is a local matrix alloca use the mapped vector
// alloca directly; stores temporarily write an undef placeholder of the
// lowered type (patched once the stored value has been translated). Any
// other pointer falls back to the generic intrinsic clone.
// NOTE(review): the switch has no default; for an unexpected opcode this
// returns nullptr — confirm callers never pass other opcodes.
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  unsigned opcode = GetHLOpcode(CI);
  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  Instruction *result = nullptr;
  switch (matOpcode) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      // Load straight from the lowered vector alloca.
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      result = Builder.CreateLoad(vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    if (isa<AllocaInst>(matPtr)) {
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
      // Placeholder value; the real lowered operand is wired in later.
      Value *vecVal =
          UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
      result = Builder.CreateStore(vecVal, vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  }
  return result;
}
  386. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  387. IRBuilder<> Builder(CI);
  388. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  389. if (isa<AllocaInst>(matPtr)) {
  390. // Here just create a new matSub call which use vec ptr.
  391. // Later in TranslateMatSubscript will do the real translation.
  392. std::vector<Value *> args(CI->getNumArgOperands());
  393. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  394. args[i] = CI->getArgOperand(i);
  395. }
  396. // Change mat ptr into vec ptr.
  397. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  398. matToVecMap[cast<Instruction>(matPtr)];
  399. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  400. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  401. paramTyList[i] = args[i]->getType();
  402. }
  403. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  404. unsigned opcode = GetHLOpcode(CI);
  405. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  406. return Builder.CreateCall(opFunc, args);
  407. } else
  408. return MatIntrinsicToVec(CI);
  409. }
// Create the vector-form clone of a matrix frexp call. frexp is special
// because it has an out-pointer parameter whose pointee type must be lowered
// while preserving the address space. Arguments whose type changed under
// lowering are passed as undef placeholders and wired up later.
Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  FunctionType *FT = CI->getCalledFunction()->getFunctionType();
  Type *RetTy = LowerMatrixType(FT->getReturnType());
  SmallVector<Type *, 4> params;
  for (Type *param : FT->params()) {
    if (!param->isPointerTy()) {
      params.emplace_back(LowerMatrixType(param));
    } else {
      // Lower pointer type for frexp, keeping its address space.
      Type *EltTy = LowerMatrixType(param->getPointerElementType());
      params.emplace_back(
          PointerType::get(EltTy, param->getPointerAddressSpace()));
    }
  }
  Type *VecFT = FunctionType::get(RetTy, params, false);
  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  Function *vecF =
      GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
                            static_cast<unsigned>(IntrinsicOp::IOP_frexp));
  SmallVector<Value *, 4> argList;
  auto paramTyIt = params.begin();
  for (Value *arg : CI->arg_operands()) {
    Type *Ty = arg->getType();
    Type *ParamTy = *(paramTyIt++);
    // Type changed under lowering -> placeholder, patched later.
    if (Ty != ParamTy)
      argList.emplace_back(UndefValue::get(ParamTy));
    else
      argList.emplace_back(arg);
  }
  return Builder.CreateCall(vecF, argList);
}
  442. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  443. IRBuilder<> Builder(CI);
  444. unsigned opcode = GetHLOpcode(CI);
  445. if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
  446. return MatFrExpToVec(CI);
  447. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  448. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  449. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  450. SmallVector<Value *, 4> argList;
  451. for (Value *arg : CI->arg_operands()) {
  452. Type *Ty = arg->getType();
  453. if (IsMatrixType(Ty)) {
  454. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  455. } else
  456. argList.emplace_back(arg);
  457. }
  458. return Builder.CreateCall(vecF, argList);
  459. }
  460. static Value *VectorizeScalarOp(Value *op, Type *dstTy, IRBuilder<> &Builder) {
  461. if (op->getType() == dstTy)
  462. return op;
  463. op = Builder.CreateInsertElement(
  464. UndefValue::get(VectorType::get(op->getType(), 1)), op, (uint64_t)0);
  465. Type *I32Ty = IntegerType::get(dstTy->getContext(), 32);
  466. Constant *zero = ConstantInt::get(I32Ty, 0);
  467. std::vector<Constant *> MaskVec(dstTy->getVectorNumElements(), zero);
  468. Value *castMask = ConstantVector::get(MaskVec);
  469. Value *vecOp = new ShuffleVectorInst(op, op, castMask);
  470. Builder.Insert(cast<Instruction>(vecOp));
  471. return vecOp;
  472. }
// Create the vector-form skeleton of a trivial (element-wise) matrix unary
// op. Operands are undef placeholders of the lowered result type; presumably
// they are patched with the translated source in TrivialMatUnOpReplace —
// verify against that implementation.
Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  Type *ResultTy = LowerMatrixType(CI->getType());
  UndefValue *tmp = UndefValue::get(ResultTy);
  IRBuilder<> Builder(CI);
  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  // Build a splatted vector constant of value v matching Ty's lane count.
  auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
    Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
                            : ConstantInt::get(Ty->getScalarType(), v);
    std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
    return ConstantVector::get(vals);
  };
  Constant *one = GetVecConst(ResultTy, 1);
  Instruction *Result = nullptr;
  switch (opcode) {
  case HLUnaryOpcode::Minus: {
    // Negate as (0 - x).
    Constant *zero = GetVecConst(ResultTy, 0);
    if (isFloat)
      Result = BinaryOperator::CreateFSub(zero, tmp);
    else
      Result = BinaryOperator::CreateSub(zero, tmp);
  } break;
  case HLUnaryOpcode::LNot: {
    // NOTE(review): this emits (x != 0), i.e. the truth test rather than the
    // logical negation (x == 0); confirm a later step inverts the result.
    Constant *zero = GetVecConst(ResultTy, 0);
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  } break;
  case HLUnaryOpcode::Not:
    // NOTE(review): both operands are the same placeholder; if both get
    // patched with the source this is x ^ x == 0, not bitwise-not
    // (x ^ all-ones). Verify the operand-replacement step's behavior here.
    Result = BinaryOperator::CreateXor(tmp, tmp);
    break;
  case HLUnaryOpcode::PostInc:
  case HLUnaryOpcode::PreInc:
    if (isFloat)
      Result = BinaryOperator::CreateFAdd(tmp, one);
    else
      Result = BinaryOperator::CreateAdd(tmp, one);
    break;
  case HLUnaryOpcode::PostDec:
  case HLUnaryOpcode::PreDec:
    if (isFloat)
      Result = BinaryOperator::CreateFSub(tmp, one);
    else
      Result = BinaryOperator::CreateSub(tmp, one);
    break;
  default:
    DXASSERT(0, "not implement");
    return nullptr;
  }
  Builder.Insert(Result);
  return Result;
}
  526. Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  527. Type *ResultTy = LowerMatrixType(CI->getType());
  528. IRBuilder<> Builder(CI);
  529. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  530. Type *OpTy = LowerMatrixType(
  531. CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  532. UndefValue *tmp = UndefValue::get(OpTy);
  533. bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  534. Instruction *Result = nullptr;
  535. switch (opcode) {
  536. case HLBinaryOpcode::Add:
  537. if (isFloat)
  538. Result = BinaryOperator::CreateFAdd(tmp, tmp);
  539. else
  540. Result = BinaryOperator::CreateAdd(tmp, tmp);
  541. break;
  542. case HLBinaryOpcode::Sub:
  543. if (isFloat)
  544. Result = BinaryOperator::CreateFSub(tmp, tmp);
  545. else
  546. Result = BinaryOperator::CreateSub(tmp, tmp);
  547. break;
  548. case HLBinaryOpcode::Mul:
  549. if (isFloat)
  550. Result = BinaryOperator::CreateFMul(tmp, tmp);
  551. else
  552. Result = BinaryOperator::CreateMul(tmp, tmp);
  553. break;
  554. case HLBinaryOpcode::Div:
  555. if (isFloat)
  556. Result = BinaryOperator::CreateFDiv(tmp, tmp);
  557. else
  558. Result = BinaryOperator::CreateSDiv(tmp, tmp);
  559. break;
  560. case HLBinaryOpcode::Rem:
  561. if (isFloat)
  562. Result = BinaryOperator::CreateFRem(tmp, tmp);
  563. else
  564. Result = BinaryOperator::CreateSRem(tmp, tmp);
  565. break;
  566. case HLBinaryOpcode::And:
  567. Result = BinaryOperator::CreateAnd(tmp, tmp);
  568. break;
  569. case HLBinaryOpcode::Or:
  570. Result = BinaryOperator::CreateOr(tmp, tmp);
  571. break;
  572. case HLBinaryOpcode::Xor:
  573. Result = BinaryOperator::CreateXor(tmp, tmp);
  574. break;
  575. case HLBinaryOpcode::Shl: {
  576. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  577. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  578. Result = BinaryOperator::CreateShl(tmp, tmp);
  579. } break;
  580. case HLBinaryOpcode::Shr: {
  581. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  582. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  583. Result = BinaryOperator::CreateAShr(tmp, tmp);
  584. } break;
  585. case HLBinaryOpcode::LT:
  586. if (isFloat)
  587. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
  588. else
  589. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
  590. break;
  591. case HLBinaryOpcode::GT:
  592. if (isFloat)
  593. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
  594. else
  595. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
  596. break;
  597. case HLBinaryOpcode::LE:
  598. if (isFloat)
  599. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
  600. else
  601. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
  602. break;
  603. case HLBinaryOpcode::GE:
  604. if (isFloat)
  605. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
  606. else
  607. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
  608. break;
  609. case HLBinaryOpcode::EQ:
  610. if (isFloat)
  611. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
  612. else
  613. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
  614. break;
  615. case HLBinaryOpcode::NE:
  616. if (isFloat)
  617. Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
  618. else
  619. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
  620. break;
  621. case HLBinaryOpcode::UDiv:
  622. Result = BinaryOperator::CreateUDiv(tmp, tmp);
  623. break;
  624. case HLBinaryOpcode::URem:
  625. Result = BinaryOperator::CreateURem(tmp, tmp);
  626. break;
  627. case HLBinaryOpcode::UShr: {
  628. Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  629. DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
  630. Result = BinaryOperator::CreateLShr(tmp, tmp);
  631. } break;
  632. case HLBinaryOpcode::ULT:
  633. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
  634. break;
  635. case HLBinaryOpcode::UGT:
  636. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
  637. break;
  638. case HLBinaryOpcode::ULE:
  639. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
  640. break;
  641. case HLBinaryOpcode::UGE:
  642. Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
  643. break;
  644. case HLBinaryOpcode::LAnd:
  645. case HLBinaryOpcode::LOr: {
  646. Constant *zero;
  647. if (isFloat)
  648. zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
  649. else
  650. zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
  651. unsigned size = ResultTy->getVectorNumElements();
  652. std::vector<Constant *> zeros(size, zero);
  653. Value *vecZero = llvm::ConstantVector::get(zeros);
  654. Instruction *cmpL;
  655. if (isFloat)
  656. cmpL =
  657. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  658. else
  659. cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  660. Builder.Insert(cmpL);
  661. Instruction *cmpR;
  662. if (isFloat)
  663. cmpR =
  664. CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
  665. else
  666. cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
  667. Builder.Insert(cmpR);
  668. // How to map l, r back? Need check opcode
  669. if (opcode == HLBinaryOpcode::LOr)
  670. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  671. else
  672. Result = BinaryOperator::CreateAnd(cmpL, cmpR);
  673. break;
  674. }
  675. default:
  676. DXASSERT(0, "not implement");
  677. return nullptr;
  678. }
  679. Builder.Insert(Result);
  680. return Result;
  681. }
  682. void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  683. Instruction *vecInst;
  684. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  685. hlsl::HLOpcodeGroup group =
  686. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  687. switch (group) {
  688. case HLOpcodeGroup::HLIntrinsic: {
  689. vecInst = MatIntrinsicToVec(CI);
  690. } break;
  691. case HLOpcodeGroup::HLSelect: {
  692. vecInst = MatIntrinsicToVec(CI);
  693. } break;
  694. case HLOpcodeGroup::HLBinOp: {
  695. vecInst = TrivialMatBinOpToVec(CI);
  696. } break;
  697. case HLOpcodeGroup::HLUnOp: {
  698. vecInst = TrivialMatUnOpToVec(CI);
  699. } break;
  700. case HLOpcodeGroup::HLCast: {
  701. vecInst = MatCastToVec(CI);
  702. } break;
  703. case HLOpcodeGroup::HLInit: {
  704. vecInst = MatIntrinsicToVec(CI);
  705. } break;
  706. case HLOpcodeGroup::HLMatLoadStore: {
  707. vecInst = MatLdStToVec(CI);
  708. } break;
  709. case HLOpcodeGroup::HLSubscript: {
  710. vecInst = MatSubscriptToVec(CI);
  711. } break;
  712. }
  713. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
  714. Type *Ty = AI->getAllocatedType();
  715. Type *matTy = Ty;
  716. IRBuilder<> Builder(AI);
  717. if (Ty->isArrayTy()) {
  718. Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
  719. vecTy = vecTy->getPointerElementType();
  720. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  721. } else {
  722. Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
  723. vecInst = Builder.CreateAlloca(vecTy, nullptr, AI->getName());
  724. }
  725. // Update debug info.
  726. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
  727. if (DDI) {
  728. LLVMContext &Context = AI->getContext();
  729. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  730. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  731. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecInst));
  732. IRBuilder<> debugBuilder(DDI);
  733. debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
  734. }
  735. if (HLModule::HasPreciseAttributeWithMetadata(AI))
  736. HLModule::MarkPreciseAttributeWithMetadata(vecInst);
  737. } else {
  738. DXASSERT(0, "invalid inst");
  739. }
  740. matToVecMap[matInst] = vecInst;
  741. }
  742. // Replace matInst with vecInst on matUseInst.
  743. void HLMatrixLowerPass::TrivialMatUnOpReplace(CallInst *matInst,
  744. Instruction *vecInst,
  745. CallInst *matUseInst) {
  746. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  747. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  748. switch (opcode) {
  749. case HLUnaryOpcode::Not:
  750. // Not is xor now
  751. vecUseInst->setOperand(0, vecInst);
  752. vecUseInst->setOperand(1, vecInst);
  753. break;
  754. case HLUnaryOpcode::LNot:
  755. case HLUnaryOpcode::PostInc:
  756. case HLUnaryOpcode::PreInc:
  757. case HLUnaryOpcode::PostDec:
  758. case HLUnaryOpcode::PreDec:
  759. vecUseInst->setOperand(0, vecInst);
  760. break;
  761. }
  762. }
  763. // Replace matInst with vecInst on matUseInst.
  764. void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
  765. Instruction *vecInst,
  766. CallInst *matUseInst) {
  767. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  768. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  769. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  770. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matInst)
  771. vecUseInst->setOperand(0, vecInst);
  772. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matInst)
  773. vecUseInst->setOperand(1, vecInst);
  774. } else {
  775. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  776. matInst) {
  777. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  778. vecCmp->setOperand(0, vecInst);
  779. }
  780. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  781. matInst) {
  782. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  783. vecCmp->setOperand(0, vecInst);
  784. }
  785. }
  786. }
  787. static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
  788. llvm::FunctionType *MadFuncTy =
  789. llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
  790. Function *MAD =
  791. GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
  792. (unsigned)madOp);
  793. return MAD;
  794. }
// Expand a matrix*matrix mul() intrinsic into per-element mul/mad chains.
// matInst is one matrix operand, vecInst its lowered vector version, and
// mulInst the HL mul() call being translated.
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst, bool isSigned) {
  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated.
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  // Inner dimensions must agree for the product to exist.
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(mulInst);
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // First term of a dot product: lMat[r][lc] * rMat[lc][c].
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // Subsequent terms accumulate via mad: acc + lMat[r][lc] * rMat[lc][c].
  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                             Value *acc) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      for (lc = 1; lc < col; lc++) {
        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
      }
      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vec instruction with the expanded mul chain.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
// Expand mul(matrix, vector): result[r] = dot(row r of mat, vec).
void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
                                           Instruction *vecInst,
                                           CallInst *mulInst, bool isSigned) {
  // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  DXASSERT(RVal->getType()->getVectorNumElements() == col, "");
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(mulInst->getType());
  IRBuilder<> Builder(mulInst);
  Value *vec = RVal;
  Value *mat = vecInst; // vec version of matInst;
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // Accumulate acc + vec[c] * mat[r][c] through the mad intrinsic.
  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
    Value *vecElt = Builder.CreateExtractElement(vec, c);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    // First term of the dot product is a plain multiply; the rest use mad.
    unsigned c = 0;
    Value *vecElt = Builder.CreateExtractElement(vec, c);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
                            : Builder.CreateMul(vecElt, matElt);
    for (c = 1; c < col; c++) {
      tmpVal = CreateOneEltMad(r, c, tmpVal);
    }
    retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  }
  mulInst->replaceAllUsesWith(retVal);
  AddToDeadInsts(mulInst);
}
  892. void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
  893. Instruction *vecInst,
  894. CallInst *mulInst, bool isSigned) {
  895. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  896. // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  897. Value *RVal = vecInst;
  898. unsigned col, row;
  899. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  900. DXASSERT(LVal->getType()->getVectorNumElements() == row, "");
  901. bool isFloat = EltTy->isFloatingPointTy();
  902. Value *retVal = llvm::UndefValue::get(mulInst->getType());
  903. IRBuilder<> Builder(mulInst);
  904. Value *vec = LVal;
  905. Value *mat = RVal;
  906. IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  907. Type *opcodeTy = Builder.getInt32Ty();
  908. Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
  909. *m_pHLModule->GetModule());
  910. Value *madOpArg = Builder.getInt32((unsigned)madOp);
  911. auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
  912. Value *vecElt = Builder.CreateExtractElement(vec, r);
  913. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  914. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  915. return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  916. };
  917. for (unsigned c = 0; c < col; c++) {
  918. unsigned r = 0;
  919. Value *vecElt = Builder.CreateExtractElement(vec, r);
  920. uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
  921. Value *matElt = Builder.CreateExtractElement(mat, matIdx);
  922. Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
  923. : Builder.CreateMul(vecElt, matElt);
  924. for (r = 1; r < row; r++) {
  925. tmpVal = CreateOneEltMad(r, c, tmpVal);
  926. }
  927. retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  928. }
  929. mulInst->replaceAllUsesWith(retVal);
  930. AddToDeadInsts(mulInst);
  931. }
  932. void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
  933. CallInst *mulInst, bool isSigned) {
  934. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  935. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  936. bool LMat = IsMatrixType(LVal->getType());
  937. bool RMat = IsMatrixType(RVal->getType());
  938. if (LMat && RMat) {
  939. TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
  940. } else if (LMat) {
  941. TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
  942. } else {
  943. TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
  944. }
  945. }
// Lower transpose(): matrix values are kept row major, so a transpose is the
// same element shuffle as a row-major -> col-major cast with swapped dims.
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                              Instruction *vecInst,
                                              CallInst *transposeInst) {
  // Matrix value is row major, transpose is cast it to col major.
  TranslateMatMajorCast(matInst, vecInst, transposeInst,
                        /*bRowToCol*/ true, /*bTranspose*/ true);
}
  953. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  954. IRBuilder<> &Builder) {
  955. Value *mul0 = Builder.CreateFMul(m00, m11);
  956. Value *mul1 = Builder.CreateFMul(m01, m10);
  957. return Builder.CreateFSub(mul0, mul1);
  958. }
  959. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  960. Value *m10, Value *m11, Value *m12,
  961. Value *m20, Value *m21, Value *m22,
  962. IRBuilder<> &Builder) {
  963. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  964. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  965. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  966. deter00 = Builder.CreateFMul(m00, deter00);
  967. deter01 = Builder.CreateFMul(m01, deter01);
  968. deter02 = Builder.CreateFMul(m02, deter02);
  969. Value *result = Builder.CreateFSub(deter00, deter01);
  970. result = Builder.CreateFAdd(result, deter02);
  971. return result;
  972. }
  973. static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
  974. Value *m10, Value *m11, Value *m12, Value *m13,
  975. Value *m20, Value *m21, Value *m22, Value *m23,
  976. Value *m30, Value *m31, Value *m32, Value *m33,
  977. IRBuilder<> &Builder) {
  978. Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  979. Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  980. Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  981. Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  982. deter00 = Builder.CreateFMul(m00, deter00);
  983. deter01 = Builder.CreateFMul(m01, deter01);
  984. deter02 = Builder.CreateFMul(m02, deter02);
  985. deter03 = Builder.CreateFMul(m03, deter03);
  986. Value *result = Builder.CreateFSub(deter00, deter01);
  987. result = Builder.CreateFAdd(result, deter02);
  988. result = Builder.CreateFSub(result, deter03);
  989. return result;
  990. }
// Lower determinant(): extract the (row-major) elements of the lowered
// vector and expand the closed-form determinant for 1x1..4x4 matrices.
void HLMatrixLowerPass::TranslateMatDeterminant(CallInst *matInst, Instruction *vecInst,
                                                CallInst *determinantInst) {
  unsigned row, col;
  GetMatrixInfo(matInst->getType(), col, row);
  IRBuilder<> Builder(determinantInst);
  // when row == 1, result is vecInst.
  Value *Result = vecInst;
  if (row == 2) {
    // Elements are laid out row major: index = r * col + c.
    Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
    Value *m01 = Builder.CreateExtractElement(vecInst, 1);
    Value *m10 = Builder.CreateExtractElement(vecInst, 2);
    Value *m11 = Builder.CreateExtractElement(vecInst, 3);
    Result = Determinant2x2(m00, m01, m10, m11, Builder);
  }
  else if (row == 3) {
    Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
    Value *m01 = Builder.CreateExtractElement(vecInst, 1);
    Value *m02 = Builder.CreateExtractElement(vecInst, 2);
    Value *m10 = Builder.CreateExtractElement(vecInst, 3);
    Value *m11 = Builder.CreateExtractElement(vecInst, 4);
    Value *m12 = Builder.CreateExtractElement(vecInst, 5);
    Value *m20 = Builder.CreateExtractElement(vecInst, 6);
    Value *m21 = Builder.CreateExtractElement(vecInst, 7);
    Value *m22 = Builder.CreateExtractElement(vecInst, 8);
    Result = Determinant3x3(m00, m01, m02,
                            m10, m11, m12,
                            m20, m21, m22, Builder);
  }
  else if (row == 4) {
    Value *m00 = Builder.CreateExtractElement(vecInst, (uint64_t)0);
    Value *m01 = Builder.CreateExtractElement(vecInst, 1);
    Value *m02 = Builder.CreateExtractElement(vecInst, 2);
    Value *m03 = Builder.CreateExtractElement(vecInst, 3);
    Value *m10 = Builder.CreateExtractElement(vecInst, 4);
    Value *m11 = Builder.CreateExtractElement(vecInst, 5);
    Value *m12 = Builder.CreateExtractElement(vecInst, 6);
    Value *m13 = Builder.CreateExtractElement(vecInst, 7);
    Value *m20 = Builder.CreateExtractElement(vecInst, 8);
    Value *m21 = Builder.CreateExtractElement(vecInst, 9);
    Value *m22 = Builder.CreateExtractElement(vecInst, 10);
    Value *m23 = Builder.CreateExtractElement(vecInst, 11);
    Value *m30 = Builder.CreateExtractElement(vecInst, 12);
    Value *m31 = Builder.CreateExtractElement(vecInst, 13);
    Value *m32 = Builder.CreateExtractElement(vecInst, 14);
    Value *m33 = Builder.CreateExtractElement(vecInst, 15);
    Result = Determinant4x4(m00, m01, m02, m03,
                            m10, m11, m12, m13,
                            m20, m21, m22, m23,
                            m30, m31, m32, m33,
                            Builder);
  } else {
    // 1x1: the determinant is the single element.
    DXASSERT(row == 1, "invalid matrix type");
    Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  }
  determinantInst->replaceAllUsesWith(Result);
  AddToDeadInsts(determinantInst);
}
  1048. void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
  1049. Instruction *vecInst,
  1050. CallInst *matUseInst) {
  1051. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1052. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  1053. if (matUseInst->getArgOperand(i) == matInst) {
  1054. vecUseInst->setArgOperand(i, vecInst);
  1055. }
  1056. }
// Emit a shuffle that reorders a lowered matrix between row-major and
// col-major element order.  Also implements transpose (bTranspose), where
// the destination dimensions are the swap of the source's.
void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
                                              Instruction *vecInst,
                                              CallInst *castInst,
                                              bool bRowToCol,
                                              bool bTranspose) {
  unsigned col, row;
  if (!bTranspose) {
    GetMatrixInfo(castInst->getType(), col, row);
    DXASSERT(castInst->getType() == matInst->getType(), "type must match");
  } else {
    // Transpose: validate element type matches and dims are swapped, then
    // build the mask from the source dimensions.
    unsigned castCol, castRow;
    Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
    unsigned srcCol, srcRow;
    Type *srcTy = GetMatrixInfo(matInst->getType(), srcCol, srcRow);
    DXASSERT(srcTy == castTy, "type must match");
    DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
    col = srcCol;
    row = srcRow;
  }
  IRBuilder<> Builder(castInst);
  // shuf to change major.
  SmallVector<int, 16> castMask(col * row);
  unsigned idx = 0;
  if (bRowToCol) {
    // Destination lane order walks columns first; source index is row major.
    for (unsigned c = 0; c < col; c++)
      for (unsigned r = 0; r < row; r++) {
        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
        castMask[idx++] = matIdx;
      }
  } else {
    // Destination lane order walks rows first; source index is col major.
    for (unsigned r = 0; r < row; r++)
      for (unsigned c = 0; c < col; c++) {
        unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
        castMask[idx++] = matIdx;
      }
  }
  Instruction *vecCast = cast<Instruction>(
      Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  // Replace vec cast function call with vecCast.
  DXASSERT(matToVecMap.count(castInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
// Lower a matrix-to-matrix cast: possible element type conversion and/or
// truncation to a smaller matrix (never an extension).
void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
                                            Instruction *vecInst,
                                            CallInst *castInst) {
  unsigned toCol, toRow;
  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  unsigned fromCol, fromRow;
  Type *FromEltTy = GetMatrixInfo(matInst->getType(), fromCol, fromRow);
  unsigned fromSize = fromCol * fromRow;
  unsigned toSize = toCol * toRow;
  DXASSERT(fromSize >= toSize, "cannot extend matrix");
  IRBuilder<> Builder(castInst);
  Instruction *vecCast = nullptr;
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  if (fromSize == toSize) {
    // Same shape: only the element type changes.
    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecInst,
                             Builder);
  } else {
    // Truncation: shuffle out the top-left toRow x toCol sub-matrix first,
    // then convert the element type if needed.
    std::vector<int> castMask(toCol * toRow);
    unsigned idx = 0;
    for (unsigned r = 0; r < toRow; r++)
      for (unsigned c = 0; c < toCol; c++) {
        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
        castMask[idx++] = matIdx;
      }
    Instruction *shuf = cast<Instruction>(
        Builder.CreateShuffleVector(vecInst, vecInst, castMask));
    if (ToEltTy != FromEltTy)
      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
                               Builder);
    else
      vecCast = shuf;
  }
  // Replace vec cast function call with vecCast.
  DXASSERT(matToVecMap.count(castInst), "must has vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
  1142. void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
  1143. Instruction *vecInst,
  1144. CallInst *castInst) {
  1145. unsigned col, row;
  1146. Type *EltTy = GetMatrixInfo(matInst->getType(), col, row);
  1147. unsigned fromSize = col * row;
  1148. IRBuilder<> Builder(castInst);
  1149. Instruction *sizeCast = nullptr;
  1150. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1151. Type *ToTy = castInst->getType();
  1152. if (ToTy->isVectorTy()) {
  1153. unsigned toSize = ToTy->getVectorNumElements();
  1154. if (fromSize != toSize) {
  1155. std::vector<int> castMask(fromSize);
  1156. for (unsigned c = 0; c < toSize; c++)
  1157. castMask[c] = c;
  1158. sizeCast = cast<Instruction>(
  1159. Builder.CreateShuffleVector(vecInst, vecInst, castMask));
  1160. } else
  1161. sizeCast = vecInst;
  1162. } else {
  1163. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1164. sizeCast =
  1165. cast<Instruction>(Builder.CreateExtractElement(vecInst, (uint64_t)0));
  1166. }
  1167. Instruction *typeCast = sizeCast;
  1168. if (EltTy != ToTy->getScalarType()) {
  1169. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1170. }
  1171. // Replace cast function call with typeCast.
  1172. castInst->replaceAllUsesWith(typeCast);
  1173. AddToDeadInsts(castInst);
  1174. }
  1175. void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
  1176. Instruction *vecInst,
  1177. CallInst *castInst) {
  1178. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1179. if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
  1180. opcode == HLCastOpcode::RowMatrixToColMatrix) {
  1181. TranslateMatMajorCast(matInst, vecInst, castInst,
  1182. opcode == HLCastOpcode::RowMatrixToColMatrix,
  1183. /*bTranspose*/false);
  1184. } else {
  1185. bool ToMat = IsMatrixType(castInst->getType());
  1186. bool FromMat = IsMatrixType(matInst->getType());
  1187. if (ToMat && FromMat) {
  1188. TranslateMatMatCast(matInst, vecInst, castInst);
  1189. } else if (FromMat)
  1190. TranslateMatToOtherCast(matInst, vecInst, castInst);
  1191. else
  1192. DXASSERT(0, "Not translate as user of matInst");
  1193. }
  1194. }
  1195. void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
  1196. Instruction *vecInst,
  1197. CallInst *matUseInst) {
  1198. IRBuilder<> Builder(matUseInst);
  1199. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1200. switch (opcode) {
  1201. case IntrinsicOp::IOP_umul:
  1202. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
  1203. break;
  1204. case IntrinsicOp::IOP_mul:
  1205. TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
  1206. break;
  1207. case IntrinsicOp::IOP_transpose:
  1208. TranslateMatTranspose(matInst, vecInst, matUseInst);
  1209. break;
  1210. case IntrinsicOp::IOP_determinant:
  1211. TranslateMatDeterminant(matInst, vecInst, matUseInst);
  1212. break;
  1213. default:
  1214. CallInst *vecUseInst = nullptr;
  1215. if (matToVecMap.count(matUseInst))
  1216. vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1217. for (unsigned i = 0; i < matInst->getNumArgOperands(); i++)
  1218. if (matUseInst->getArgOperand(i) == matInst) {
  1219. if (vecUseInst)
  1220. vecUseInst->setArgOperand(i, vecInst);
  1221. else
  1222. matUseInst->setArgOperand(i, vecInst);
  1223. }
  1224. break;
  1225. }
  1226. }
  1227. void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
  1228. CallInst *matSubInst) {
  1229. unsigned opcode = GetHLOpcode(matSubInst);
  1230. HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  1231. assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
  1232. "matrix don't use default subscript");
  1233. Type *matType = matInst->getType()->getPointerElementType();
  1234. unsigned col, row;
  1235. Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  1236. bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
  1237. (matOpcode == HLSubscriptOpcode::RowMatElement);
  1238. Value *mask =
  1239. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  1240. if (isElement) {
  1241. Type *resultType = matSubInst->getType()->getPointerElementType();
  1242. unsigned resultSize = 1;
  1243. if (resultType->isVectorTy())
  1244. resultSize = resultType->getVectorNumElements();
  1245. std::vector<int> shufMask(resultSize);
  1246. Constant *EltIdxs = cast<Constant>(mask);
  1247. for (unsigned i = 0; i < resultSize; i++) {
  1248. shufMask[i] =
  1249. cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
  1250. }
  1251. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1252. CallE = matSubInst->use_end();
  1253. CallUI != CallE;) {
  1254. Use &CallUse = *CallUI++;
  1255. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1256. IRBuilder<> Builder(CallUser);
  1257. Value *vecLd = Builder.CreateLoad(vecInst);
  1258. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1259. if (resultSize > 1) {
  1260. Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
  1261. ld->replaceAllUsesWith(shuf);
  1262. } else {
  1263. Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
  1264. ld->replaceAllUsesWith(elt);
  1265. }
  1266. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1267. Value *val = st->getValueOperand();
  1268. if (resultSize > 1) {
  1269. for (unsigned i = 0; i < shufMask.size(); i++) {
  1270. unsigned idx = shufMask[i];
  1271. Value *valElt = Builder.CreateExtractElement(val, i);
  1272. vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
  1273. }
  1274. Builder.CreateStore(vecLd, vecInst);
  1275. } else {
  1276. vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
  1277. Builder.CreateStore(vecLd, vecInst);
  1278. }
  1279. } else
  1280. DXASSERT(0, "matrix element should only used by load/store.");
  1281. AddToDeadInsts(CallUser);
  1282. }
  1283. } else {
  1284. // Subscript.
  1285. // Return a row.
  1286. // Use insertElement and extractElement.
  1287. ArrayType *AT = ArrayType::get(EltTy, col*row);
  1288. IRBuilder<> AllocaBuilder(
  1289. matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  1290. Value *tempArray = AllocaBuilder.CreateAlloca(AT);
  1291. Value *zero = AllocaBuilder.getInt32(0);
  1292. bool isDynamicIndexing = !isa<ConstantInt>(mask);
  1293. SmallVector<Value *, 4> idxList;
  1294. for (unsigned i = 0; i < col; i++) {
  1295. idxList.emplace_back(
  1296. matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
  1297. }
  1298. for (Value::use_iterator CallUI = matSubInst->use_begin(),
  1299. CallE = matSubInst->use_end();
  1300. CallUI != CallE;) {
  1301. Use &CallUse = *CallUI++;
  1302. Instruction *CallUser = cast<Instruction>(CallUse.getUser());
  1303. IRBuilder<> Builder(CallUser);
  1304. Value *vecLd = Builder.CreateLoad(vecInst);
  1305. if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
  1306. Value *sub = UndefValue::get(ld->getType());
  1307. if (!isDynamicIndexing) {
  1308. for (unsigned i = 0; i < col; i++) {
  1309. Value *matIdx = idxList[i];
  1310. Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
  1311. sub = Builder.CreateInsertElement(sub, valElt, i);
  1312. }
  1313. } else {
  1314. // Copy vec to array.
  1315. for (unsigned int i = 0; i < row*col; i++) {
  1316. Value *Elt =
  1317. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1318. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1319. {zero, Builder.getInt32(i)});
  1320. Builder.CreateStore(Elt, Ptr);
  1321. }
  1322. for (unsigned i = 0; i < col; i++) {
  1323. Value *matIdx = idxList[i];
  1324. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1325. Value *valElt = Builder.CreateLoad(Ptr);
  1326. sub = Builder.CreateInsertElement(sub, valElt, i);
  1327. }
  1328. }
  1329. ld->replaceAllUsesWith(sub);
  1330. } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
  1331. Value *val = st->getValueOperand();
  1332. if (!isDynamicIndexing) {
  1333. for (unsigned i = 0; i < col; i++) {
  1334. Value *matIdx = idxList[i];
  1335. Value *valElt = Builder.CreateExtractElement(val, i);
  1336. vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
  1337. }
  1338. } else {
  1339. // Copy vec to array.
  1340. for (unsigned int i = 0; i < row * col; i++) {
  1341. Value *Elt =
  1342. Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
  1343. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1344. {zero, Builder.getInt32(i)});
  1345. Builder.CreateStore(Elt, Ptr);
  1346. }
  1347. // Update array.
  1348. for (unsigned i = 0; i < col; i++) {
  1349. Value *matIdx = idxList[i];
  1350. Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
  1351. Value *valElt = Builder.CreateExtractElement(val, i);
  1352. Builder.CreateStore(valElt, Ptr);
  1353. }
  1354. // Copy array to vec.
  1355. for (unsigned int i = 0; i < row * col; i++) {
  1356. Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
  1357. {zero, Builder.getInt32(i)});
  1358. Value *Elt = Builder.CreateLoad(Ptr);
  1359. vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
  1360. }
  1361. }
  1362. Builder.CreateStore(vecLd, vecInst);
  1363. } else if (GetElementPtrInst *GEP =
  1364. dyn_cast<GetElementPtrInst>(CallUser)) {
  1365. Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
  1366. Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
  1367. GEP->replaceAllUsesWith(NewGEP);
  1368. } else
  1369. DXASSERT(0, "matrix subscript should only used by load/store.");
  1370. AddToDeadInsts(CallUser);
  1371. }
  1372. }
  1373. // Check vec version.
  1374. DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  1375. // All the user should has been removed.
  1376. matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  1377. AddToDeadInsts(matSubInst);
  1378. }
  1379. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1380. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1381. CallInst *matLdStInst) {
  1382. // No dynamic indexing on matrix, flatten matrix to scalars.
  1383. // vecGlobals already in correct major.
  1384. Type *matType = matGlobal->getType()->getPointerElementType();
  1385. unsigned col, row;
  1386. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1387. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1388. IRBuilder<> Builder(matLdStInst);
  1389. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1390. switch (opcode) {
  1391. case HLMatLoadStoreOpcode::ColMatLoad:
  1392. case HLMatLoadStoreOpcode::RowMatLoad: {
  1393. Value *Result = UndefValue::get(vecType);
  1394. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1395. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1396. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1397. }
  1398. matLdStInst->replaceAllUsesWith(Result);
  1399. } break;
  1400. case HLMatLoadStoreOpcode::ColMatStore:
  1401. case HLMatLoadStoreOpcode::RowMatStore: {
  1402. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1403. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1404. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1405. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1406. }
  1407. } break;
  1408. }
  1409. }
  1410. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
  1411. GlobalVariable *scalarArrayGlobal,
  1412. CallInst *matLdStInst) {
  1413. // vecGlobals already in correct major.
  1414. const bool bColMajor = true;
  1415. HLMatLoadStoreOpcode opcode =
  1416. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1417. switch (opcode) {
  1418. case HLMatLoadStoreOpcode::ColMatLoad:
  1419. case HLMatLoadStoreOpcode::RowMatLoad: {
  1420. IRBuilder<> Builder(matLdStInst);
  1421. Type *matTy = matGlobal->getType()->getPointerElementType();
  1422. unsigned col, row;
  1423. Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1424. Value *zeroIdx = Builder.getInt32(0);
  1425. std::vector<Value *> matElts(col * row);
  1426. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1427. Value *GEP = Builder.CreateInBoundsGEP(
  1428. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1429. matElts[matIdx] = Builder.CreateLoad(GEP);
  1430. }
  1431. Value *newVec =
  1432. HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
  1433. matLdStInst->replaceAllUsesWith(newVec);
  1434. matLdStInst->eraseFromParent();
  1435. } break;
  1436. case HLMatLoadStoreOpcode::ColMatStore:
  1437. case HLMatLoadStoreOpcode::RowMatStore: {
  1438. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1439. IRBuilder<> Builder(matLdStInst);
  1440. Type *matTy = matGlobal->getType()->getPointerElementType();
  1441. unsigned col, row;
  1442. HLMatrixLower::GetMatrixInfo(matTy, col, row);
  1443. Value *zeroIdx = Builder.getInt32(0);
  1444. std::vector<Value *> matElts(col * row);
  1445. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1446. Value *GEP = Builder.CreateInBoundsGEP(
  1447. scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
  1448. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1449. Builder.CreateStore(Elt, GEP);
  1450. }
  1451. matLdStInst->eraseFromParent();
  1452. } break;
  1453. }
  1454. }
// Lower a matrix subscript/element access whose base is a pointer to global
// matrix storage already lowered to a vector (vecPtr).  Since a pointer to a
// scattered set of vector elements cannot be expressed, each load/store user
// of the subscript is rewritten against per-element scalar GEPs into vecPtr,
// and the subscript call is erased.
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
    CallInst *matSubInst, Value *vecPtr) {
  Value *basePtr =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // Collect the flattened-vector index for each accessed element.
  std::vector<Value *> idxList;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Just use index created in EmitHLSLMatrixSubscript.
    for (unsigned c = 0; c < col; c++) {
      Value *matIdx =
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
      idxList.emplace_back(matIdx);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Just use index created in EmitHLSLMatrixElement.
    Constant *EltIdxs = cast<Constant>(idx);
    for (unsigned i = 0; i < resultSize; i++) {
      Value *matIdx = EltIdxs->getAggregateElement(i);
      idxList.emplace_back(matIdx);
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate vector pointer
  // Replace all uses with scalar pointers.
  if (idxList.size() == 1) {
    // Single element: a direct scalar GEP can stand in for the subscript.
    Value *Ptr =
        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
    matSubInst->replaceAllUsesWith(Ptr);
  } else {
    // Split the use of CI with Ptrs.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // A GEP that indexes into the subscript result: fold it into a
        // single flattened index, then rewrite its load/store/addrspacecast
        // users against the resulting scalar pointer.
        Value *IndexPtr =
            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                   {zeroIdx, IndexPtr});
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
            Value *subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          } else {
            // Remaining user must be an addrspacecast; retarget its source.
            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
            Cast->setOperand(0, Ptr);
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Store of a small vector: scatter each lane through its own GEP.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          userBuilder.CreateStore(Elt, Ptr);
        }
        stUser->eraseFromParent();
      } else {
        // Load of a small vector: gather each lane through its own GEP.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
                                                     {zeroIdx, idxList[i]});
          Value *Elt = userBuilder.CreateLoad(Ptr);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
  1555. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
  1556. CallInst *matLdStInst, Value *vecPtr) {
  1557. // Just translate into vector here.
  1558. // DynamicIndexingVectorToArray will change it to scalar array.
  1559. IRBuilder<> Builder(matLdStInst);
  1560. unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  1561. HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  1562. switch (matLdStOp) {
  1563. case HLMatLoadStoreOpcode::ColMatLoad:
  1564. case HLMatLoadStoreOpcode::RowMatLoad: {
  1565. // Load as vector.
  1566. Value *newLoad = Builder.CreateLoad(vecPtr);
  1567. matLdStInst->replaceAllUsesWith(newLoad);
  1568. matLdStInst->eraseFromParent();
  1569. } break;
  1570. case HLMatLoadStoreOpcode::ColMatStore:
  1571. case HLMatLoadStoreOpcode::RowMatStore: {
  1572. // Change value to vector array, then store.
  1573. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1574. Value *vecArrayGep = vecPtr;
  1575. Builder.CreateStore(Val, vecArrayGep);
  1576. matLdStInst->eraseFromParent();
  1577. } break;
  1578. default:
  1579. DXASSERT(0, "invalid operation");
  1580. break;
  1581. }
  1582. }
// Flatten values inside init list to scalar elements.
// Recursively walks one init-list operand `val` and writes its scalars into
// `elts` starting at `idx`, advancing `idx` past everything consumed.
// Pointers are loaded (recursing into arrays/structs member by member),
// matrices are swapped for their lowered vector via matToVecMap, and vectors
// are split element by element.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Instruction *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      // Prefer the already-lowered vector-array version if one exists.
      if (matToVecMap.count(cast<Instruction>(val))) {
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Leaf pointer: load the value and recurse on the loaded form.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointer: recurse member by member through GEPs.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    // Matrix value: extract every scalar from its lowered vector form.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all row major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Vector value: split into scalars.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      // Scalar value: consume one slot.
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
  1648. // Store flattened init list elements into matrix array.
  1649. static void GenerateMatArrayInit(ArrayRef<Value *> elts, Value *ptr,
  1650. unsigned &offset, IRBuilder<> &Builder) {
  1651. Type *Ty = ptr->getType()->getPointerElementType();
  1652. if (Ty->isVectorTy()) {
  1653. unsigned vecSize = Ty->getVectorNumElements();
  1654. Type *eltTy = Ty->getVectorElementType();
  1655. Value *result = UndefValue::get(Ty);
  1656. for (unsigned i = 0; i < vecSize; i++) {
  1657. Value *elt = elts[offset + i];
  1658. if (elt->getType() != eltTy) {
  1659. // FIXME: get signed/unsigned info.
  1660. elt = CreateTypeCast(HLCastOpcode::DefaultCast, eltTy, elt, Builder);
  1661. }
  1662. result = Builder.CreateInsertElement(result, elt, i);
  1663. }
  1664. // Update offset.
  1665. offset += vecSize;
  1666. Builder.CreateStore(result, ptr);
  1667. } else {
  1668. DXASSERT(Ty->isArrayTy(), "must be array type");
  1669. Type *i32Ty = Type::getInt32Ty(Ty->getContext());
  1670. Constant *zero = ConstantInt::get(i32Ty, 0);
  1671. unsigned arraySize = Ty->getArrayNumElements();
  1672. for (unsigned i = 0; i < arraySize; i++) {
  1673. Value *GEP =
  1674. Builder.CreateInBoundsGEP(ptr, {zero, ConstantInt::get(i32Ty, i)});
  1675. GenerateMatArrayInit(elts, GEP, offset, Builder);
  1676. }
  1677. }
  1678. }
  1679. void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  1680. // Array matrix init will be translated in TranslateMatArrayInitReplace.
  1681. if (matInitInst->getType()->isVoidTy())
  1682. return;
  1683. IRBuilder<> Builder(matInitInst);
  1684. unsigned col, row;
  1685. Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  1686. Type *vecTy = VectorType::get(EltTy, col * row);
  1687. unsigned vecSize = vecTy->getVectorNumElements();
  1688. unsigned idx = 0;
  1689. std::vector<Value *> elts(vecSize);
  1690. // Skip opcode arg.
  1691. for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
  1692. Value *val = matInitInst->getArgOperand(i);
  1693. IterateInitList(elts, idx, val, matToVecMap, Builder);
  1694. }
  1695. Value *newInit = UndefValue::get(vecTy);
  1696. // InitList is row major, the result is row major too.
  1697. for (unsigned i=0;i< col * row;i++) {
  1698. Constant *vecIdx = Builder.getInt32(i);
  1699. newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
  1700. Builder.Insert(cast<Instruction>(newInit));
  1701. }
  1702. // Replace matInit function call with matInitInst.
  1703. DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
  1704. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  1705. vecUseInst->replaceAllUsesWith(newInit);
  1706. AddToDeadInsts(vecUseInst);
  1707. matToVecMap[matInitInst] = newInit;
  1708. }
  1709. void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  1710. IRBuilder<> Builder(matSelectInst);
  1711. unsigned col, row;
  1712. Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  1713. Type *vecTy = VectorType::get(EltTy, col * row);
  1714. unsigned vecSize = vecTy->getVectorNumElements();
  1715. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  1716. Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  1717. Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  1718. Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  1719. bool isVecCond = Cond->getType()->isVectorTy();
  1720. if (isVecCond) {
  1721. Instruction *MatCond = cast<Instruction>(
  1722. matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
  1723. DXASSERT_NOMSG(matToVecMap.count(MatCond));
  1724. Cond = matToVecMap[MatCond];
  1725. }
  1726. DXASSERT_NOMSG(matToVecMap.count(LHS));
  1727. Value *VLHS = matToVecMap[LHS];
  1728. DXASSERT_NOMSG(matToVecMap.count(RHS));
  1729. Value *VRHS = matToVecMap[RHS];
  1730. Value *VecSelect = UndefValue::get(vecTy);
  1731. for (unsigned i = 0; i < vecSize; i++) {
  1732. llvm::Value *EltCond = Cond;
  1733. if (isVecCond)
  1734. EltCond = Builder.CreateExtractElement(Cond, i);
  1735. llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
  1736. llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
  1737. llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
  1738. VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  1739. }
  1740. AddToDeadInsts(vecUseInst);
  1741. vecUseInst->replaceAllUsesWith(VecSelect);
  1742. matToVecMap[matSelectInst] = VecSelect;
  1743. }
// Lower a GEP into a matrix array (matGEP, based on matInst) by mirroring it
// as a GEP on the lowered vector array (vecInst), then rewriting each user:
// matrix load/store calls, matrix subscript calls, bitcasts, and nested GEPs
// (handled recursively).  The original GEP is queued for deletion.
void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
                                             Instruction *vecInst,
                                             GetElementPtrInst *matGEP) {
  SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  IRBuilder<> GEPBuilder(matGEP);
  // Same indices, but on the vector-array base.
  Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecInst, idxList);
  // Only used by mat subscript and mat ld/st.
  for (Value::user_iterator user = matGEP->user_begin();
       user != matGEP->user_end();) {
    Instruction *useInst = cast<Instruction>(*(user++));
    IRBuilder<> Builder(useInst);
    // Skip return here.
    if (isa<ReturnInst>(useInst))
      continue;
    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
      // Function call.
      hlsl::HLOpcodeGroup group =
          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
      switch (group) {
      case HLOpcodeGroup::HLMatLoadStore: {
        unsigned opcode = GetHLOpcode(useCall);
        HLMatLoadStoreOpcode matOpcode =
            static_cast<HLMatLoadStoreOpcode>(opcode);
        switch (matOpcode) {
        case HLMatLoadStoreOpcode::ColMatLoad:
        case HLMatLoadStoreOpcode::RowMatLoad: {
          // Skip the vector version.
          if (useCall->getType()->isVectorTy())
            continue;
          // Replace the placeholder vector load with a load of the new GEP.
          Value *newLd = Builder.CreateLoad(newGEP);
          DXASSERT(matToVecMap.count(useCall), "must has vec version");
          Value *oldLd = matToVecMap[useCall];
          // Delete the oldLd.
          AddToDeadInsts(cast<Instruction>(oldLd));
          oldLd->replaceAllUsesWith(newLd);
          matToVecMap[useCall] = newLd;
        } break;
        case HLMatLoadStoreOpcode::ColMatStore:
        case HLMatLoadStoreOpcode::RowMatStore: {
          Value *vecPtr = newGEP;
          Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
          // Skip the vector version.
          if (matVal->getType()->isVectorTy()) {
            AddToDeadInsts(useCall);
            continue;
          }
          // Store the lowered vector value through the new GEP.
          Instruction *matInst = cast<Instruction>(matVal);
          DXASSERT(matToVecMap.count(matInst), "must has vec version");
          Value *vecVal = matToVecMap[matInst];
          Builder.CreateStore(vecVal, vecPtr);
        } break;
        }
      } break;
      case HLOpcodeGroup::HLSubscript: {
        TranslateMatSubscript(matGEP, newGEP, useCall);
      } break;
      default:
        DXASSERT(0, "invalid operation");
        break;
      }
    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
      // Just replace the src with vec version.
      useInst->setOperand(0, newGEP);
    } else {
      // Must be GEP.
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
      TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
    }
  }
  AddToDeadInsts(matGEP);
}
// Central dispatch: walk every user of a matrix-producing instruction
// (matInst) and rewrite it against the lowered vector form (vecInst),
// delegating to the per-opcode-group translation helpers.
void HLMatrixLowerPass::replaceMatWithVec(Instruction *matInst,
                                          Instruction *vecInst) {
  for (Value::user_iterator user = matInst->user_begin();
       user != matInst->user_end();) {
    Instruction *useInst = cast<Instruction>(*(user++));
    // Skip return here.
    if (isa<ReturnInst>(useInst))
      continue;
    // User must be function call.
    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
      hlsl::HLOpcodeGroup group =
          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
      switch (group) {
      case HLOpcodeGroup::HLIntrinsic: {
        if (CallInst *matCI = dyn_cast<CallInst>(matInst)) {
          // Matrix produced by a call: standard intrinsic operand rewrite.
          MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
        } else {
          // Matrix is an alloca passed as an out parameter; the only
          // intrinsic expected to do that here is frexp.
          IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
          DXASSERT_LOCALVAR(opcode, opcode == IntrinsicOp::IOP_frexp,
                            "otherwise, unexpected opcode with matrix out parameter");
          // NOTE: because out param use copy out semantic, so the operand of
          // out must be temp alloca.
          DXASSERT(isa<AllocaInst>(matInst), "else invalid mat ptr for frexp");
          auto it = matToVecMap.find(useCall);
          DXASSERT(it != matToVecMap.end(),
                   "else fail to create vec version of useCall");
          CallInst *vecUseInst = cast<CallInst>(it->second);
          // Swap the matrix alloca for the vector alloca in the vec call.
          for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
            if (useCall->getArgOperand(i) == matInst) {
              vecUseInst->setArgOperand(i, vecInst);
            }
          }
        }
      } break;
      case HLOpcodeGroup::HLSelect: {
        MatIntrinsicReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLBinOp: {
        TrivialMatBinOpReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLUnOp: {
        TrivialMatUnOpReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLCast: {
        TranslateMatCast(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLMatLoadStore: {
        DXASSERT(matToVecMap.count(useCall), "must has vec version");
        Value *vecUser = matToVecMap[useCall];
        if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
          // Load Already translated in lowerToVec.
          // Store val operand will be set by the val use.
          // Do nothing here.
        } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
          stInst->setOperand(0, vecInst);
        else
          TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLSubscript: {
        if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst))
          TranslateMatSubscript(AI, vecInst, useCall);
        else
          TrivialMatReplace(cast<CallInst>(matInst), vecInst, useCall);
      } break;
      case HLOpcodeGroup::HLInit: {
        DXASSERT(!isa<AllocaInst>(matInst), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
        TranslateMatInit(useCall);
      } break;
      }
    } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
      // Just replace the src with vec version.
      useInst->setOperand(0, vecInst);
    } else {
      // Must be GEP on mat array alloca.
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
      AllocaInst *AI = cast<AllocaInst>(matInst);
      TranslateMatArrayGEP(AI, vecInst, GEP);
    }
  }
}
  1895. void HLMatrixLowerPass::finalMatTranslation(Instruction *matInst) {
  1896. // Translate matInit.
  1897. if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
  1898. hlsl::HLOpcodeGroup group =
  1899. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  1900. switch (group) {
  1901. case HLOpcodeGroup::HLInit: {
  1902. TranslateMatInit(CI);
  1903. } break;
  1904. case HLOpcodeGroup::HLSelect: {
  1905. TranslateMatSelect(CI);
  1906. } break;
  1907. default:
  1908. // Skip group already translated.
  1909. break;
  1910. }
  1911. }
  1912. }
  1913. void HLMatrixLowerPass::DeleteDeadInsts() {
  1914. // Delete the matrix version insts.
  1915. for (Instruction *deadInst : m_deadInsts) {
  1916. // Replace with undef and remove it.
  1917. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  1918. deadInst->eraseFromParent();
  1919. }
  1920. m_deadInsts.clear();
  1921. m_inDeadInstsSet.clear();
  1922. }
  1923. static bool OnlyUsedByMatrixLdSt(Value *V) {
  1924. bool onlyLdSt = true;
  1925. for (User *user : V->users()) {
  1926. if (isa<Constant>(user) && user->use_empty())
  1927. continue;
  1928. CallInst *CI = cast<CallInst>(user);
  1929. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  1930. HLOpcodeGroup::HLMatLoadStore)
  1931. continue;
  1932. onlyLdSt = false;
  1933. break;
  1934. }
  1935. return onlyLdSt;
  1936. }
  1937. static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
  1938. if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
  1939. std::vector<Constant *> Elts;
  1940. Type *EltResultTy = AT->getElementType();
  1941. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  1942. Constant *Elt =
  1943. LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
  1944. Elts.emplace_back(Elt);
  1945. }
  1946. return ConstantArray::get(AT, Elts);
  1947. } else {
  1948. // Cast float[row][col] -> float< row * col>.
  1949. // Get float[row][col] from the struct.
  1950. Constant *rows = MA->getAggregateElement((unsigned)0);
  1951. ArrayType *RowAT = cast<ArrayType>(rows->getType());
  1952. std::vector<Constant *> Elts;
  1953. for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
  1954. Constant *row = rows->getAggregateElement(r);
  1955. VectorType *VT = cast<VectorType>(row->getType());
  1956. for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
  1957. Elts.emplace_back(row->getAggregateElement(c));
  1958. }
  1959. }
  1960. return ConstantVector::get(Elts);
  1961. }
  1962. }
// Lower a global variable holding an array of matrices: create a replacement
// global whose innermost element is the flattened vector type, lower the
// initializer, migrate debug info, and rewrite every (GEP-based) user to the
// new global before erasing the old one.
void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  // Lower to array of vector array like float[row * col].
  // It's follow the major of decl.
  // DynamicIndexingVectorToArray will change it to scalar array.
  Type *Ty = GV->getType()->getPointerElementType();
  // Peel off all array dimensions, remembering their sizes (outermost first).
  std::vector<unsigned> arraySizeList;
  while (Ty->isArrayTy()) {
    arraySizeList.push_back(Ty->getArrayNumElements());
    Ty = Ty->getArrayElementType();
  }
  unsigned row, col;
  Type *EltTy = GetMatrixInfo(Ty, col, row);
  // Rebuild the type with the matrix replaced by its flattened vector,
  // re-applying the array dimensions innermost-first.
  Ty = VectorType::get(EltTy, col * row);
  for (auto arraySize = arraySizeList.rbegin();
       arraySize != arraySizeList.rend(); arraySize++)
    Ty = ArrayType::get(Ty, *arraySize);

  Type *VecArrayTy = Ty;
  // Lower the initializer to match the new type (undef stays undef).
  Constant *OldInitVal = GV->getInitializer();
  Constant *InitVal =
      isa<UndefValue>(OldInitVal)
          ? UndefValue::get(VecArrayTy)
          : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));

  // Clone the old global's attributes onto the replacement.
  bool isConst = GV->isConstant();
  GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  unsigned AddressSpace = GV->getType()->getAddressSpace();
  GlobalValue::LinkageTypes linkage = GV->getLinkage();

  Module *M = GV->getParent();
  GlobalVariable *VecGV =
      new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
                               /*InitVal*/ InitVal, GV->getName() + ".v",
                               /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  // Add debug info.
  if (m_HasDbgInfo) {
    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
    HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  }

  DenseMap<Instruction *, Value *> matToVecMap;
  for (User *U : GV->users()) {
    Value *VecGEP = nullptr;
    // Must be GEP or GEPOperator.
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
      IRBuilder<> Builder(GEP);
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
      AddToDeadInsts(GEP);
    } else {
      // Constant-expression GEP: mirror it on the new global (folds to a
      // constant expression, so no insertion point is needed).
      GEPOperator *GEPOP = cast<GEPOperator>(U);
      IRBuilder<> Builder(GV->getContext());
      SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
    }
    // Rewrite every HL call that goes through this GEP.
    for (auto user = U->user_begin(); user != U->user_end();) {
      CallInst *CI = cast<CallInst>(*(user++));
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
      } else if (group == HLOpcodeGroup::HLSubscript) {
        TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
      } else {
        DXASSERT(0, "invalid operation");
      }
    }
  }
  DeleteDeadInsts();
  GV->removeDeadConstantUsers();
  GV->eraseFromParent();
}
  2030. static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
  2031. unsigned row, col;
  2032. Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
  2033. if (isa<UndefValue>(M)) {
  2034. Constant *Elt = UndefValue::get(EltTy);
  2035. for (unsigned i=0;i<col*row;i++)
  2036. Elts.emplace_back(Elt);
  2037. } else {
  2038. M = M->getAggregateElement((unsigned)0);
  2039. // Initializer is already in correct major.
  2040. // Just read it here.
  2041. // The type is vector<element, col>[row].
  2042. for (unsigned r = 0; r < row; r++) {
  2043. Constant *C = M->getAggregateElement(r);
  2044. for (unsigned c = 0; c < col; c++) {
  2045. Elts.emplace_back(C->getAggregateElement(c));
  2046. }
  2047. }
  2048. }
  2049. }
// Lower a global matrix variable.
// Matrix arrays are delegated to runOnGlobalMatrixArray.  A plain matrix
// global that is only accessed through matrix load/store is split into one
// scalar global per element; otherwise it is lowered to a single scalar
// array global.  The original global is erased in both cases.
void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
    runOnGlobalMatrixArray(GV);
    return;
  }
  // Ignore globals that are not matrices at all.
  Type *Ty = GV->getType()->getPointerElementType();
  if (!HLMatrixLower::IsMatrixType(Ty))
    return;
  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  bool isConst = GV->isConstant();
  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  Module *M = GV->getParent();
  const DataLayout &DL = M->getDataLayout();
  // Flatten the matrix initializer into a list of scalar constants.
  std::vector<Constant *> Elts;
  // Lower to vector or array for scalar matrix.
  // Make it col major so don't need shuffle when load/store.
  FlattenMatConst(GV->getInitializer(), Elts);
  if (onlyLdSt) {
    // Split into one scalar global per matrix element; loads/stores are
    // rewritten to touch the individual element globals.
    Type *EltTy = vecTy->getVectorElementType();
    unsigned vecSize = vecTy->getVectorNumElements();
    std::vector<Value *> vecGlobals(vecSize);
    // Preserve the original global's attributes on each element global.
    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
    unsigned AddressSpace = GV->getType()->getAddressSpace();
    GlobalValue::LinkageTypes linkage = GV->getLinkage();
    unsigned debugOffset = 0;
    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
    unsigned align = DL.getPrefTypeAlignment(EltTy);
    for (int i = 0, e = vecSize; i != e; ++i) {
      Constant *InitVal = Elts[i];
      // Element global named "<matrix-name>.<index>".
      GlobalVariable *EltGV = new llvm::GlobalVariable(
          *M, EltTy, /*IsConstant*/ isConst, linkage,
          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
          /*InsertBefore*/nullptr,
          TLMode, AddressSpace);
      // Add debug info, tracking each element's bit offset into the
      // original matrix global.
      if (m_HasDbgInfo) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset,
            EltGV->getName().ltrim(GV->getName()));
        debugOffset += size;
      }
      vecGlobals[i] = EltGV;
    }
    for (User *user : GV->users()) {
      // Skip constant users with no remaining uses (dead constant exprs).
      if (isa<Constant>(user) && user->use_empty())
        continue;
      // Remaining users are matrix load/store calls (per onlyLdSt above).
      CallInst *CI = cast<CallInst>(user);
      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
      AddToDeadInsts(CI);
    }
    DeleteDeadInsts();
    GV->eraseFromParent();
  }
  else {
    // lower to array of scalar here.
    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
    Constant *InitVal = ConstantArray::get(AT, Elts);
    GlobalVariable *arrayMat = new llvm::GlobalVariable(
        *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
        /*InitVal*/ InitVal, GV->getName());
    // Add debug info.
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
                                              arrayMat);
    }
    // Advance the iterator before translating: translation may erase CI.
    for (auto U = GV->user_begin(); U != GV->user_end();) {
      Value *user = *(U++);
      CallInst *CI = cast<CallInst>(user);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
      }
      else {
        DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
      }
    }
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
  2133. void HLMatrixLowerPass::runOnFunction(Function &F) {
  2134. // Create vector version of matrix instructions first.
  2135. // The matrix operands will be undefval for these instructions.
  2136. for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
  2137. BasicBlock *BB = BBI;
  2138. for (Instruction &I : BB->getInstList()) {
  2139. if (IsMatrixType(I.getType())) {
  2140. lowerToVec(&I);
  2141. } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
  2142. Type *Ty = AI->getAllocatedType();
  2143. if (HLMatrixLower::IsMatrixType(Ty)) {
  2144. lowerToVec(&I);
  2145. } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
  2146. lowerToVec(&I);
  2147. }
  2148. } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
  2149. HLOpcodeGroup group =
  2150. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  2151. if (group == HLOpcodeGroup::HLMatLoadStore) {
  2152. HLMatLoadStoreOpcode opcode =
  2153. static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  2154. DXASSERT_LOCALVAR(opcode,
  2155. opcode == HLMatLoadStoreOpcode::ColMatStore ||
  2156. opcode == HLMatLoadStoreOpcode::RowMatStore,
  2157. "Must MatStore here, load will go IsMatrixType path");
  2158. // Lower it here to make sure it is ready before replace.
  2159. lowerToVec(&I);
  2160. }
  2161. }
  2162. }
  2163. }
  2164. // Update the use of matrix inst with the vector version.
  2165. for (auto matToVecIter = matToVecMap.begin();
  2166. matToVecIter != matToVecMap.end();) {
  2167. auto matToVec = matToVecIter++;
  2168. replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  2169. }
  2170. // Translate mat inst which require all operands ready.
  2171. for (auto matToVecIter = matToVecMap.begin();
  2172. matToVecIter != matToVecMap.end();) {
  2173. auto matToVec = matToVecIter++;
  2174. finalMatTranslation(matToVec->first);
  2175. }
  2176. // Delete the matrix version insts.
  2177. for (auto matToVecIter = matToVecMap.begin();
  2178. matToVecIter != matToVecMap.end();) {
  2179. auto matToVec = matToVecIter++;
  2180. // Add to m_deadInsts.
  2181. Instruction *matInst = matToVec->first;
  2182. AddToDeadInsts(matInst);
  2183. }
  2184. DeleteDeadInsts();
  2185. matToVecMap.clear();
  2186. }