HLMatrixLowerPass.cpp 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  12. #include "dxc/HLSL/HLMatrixLowerPass.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/HLSL/HLModule.h"
  15. #include "dxc/HLSL/DxilUtil.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/HLSL/DxilOperations.h"
  19. #include "dxc/HLSL/DxilTypeSystem.h"
  20. #include "dxc/HLSL/DxilModule.h"
  21. #include "llvm/IR/IRBuilder.h"
  22. #include "llvm/IR/Module.h"
  23. #include "llvm/IR/DebugInfo.h"
  24. #include "llvm/IR/IntrinsicInst.h"
  25. #include "llvm/Transforms/Utils/Local.h"
  26. #include "llvm/Pass.h"
  27. #include "llvm/Support/raw_ostream.h"
  28. #include <unordered_set>
  29. #include <vector>
  30. using namespace llvm;
  31. using namespace hlsl;
  32. using namespace hlsl::HLMatrixLower;
  33. namespace hlsl {
  34. namespace HLMatrixLower {
  35. bool IsMatrixType(Type *Ty) {
  36. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  37. Type *EltTy = ST->getElementType(0);
  38. if (!ST->getName().startswith("class.matrix"))
  39. return false;
  40. bool isVecArray = EltTy->isArrayTy() &&
  41. EltTy->getArrayElementType()->isVectorTy();
  42. return isVecArray && EltTy->getArrayNumElements() <= 4;
  43. }
  44. return false;
  45. }
  46. // If user is function call, return param annotation to get matrix major.
  47. DxilFieldAnnotation *FindAnnotationFromMatUser(Value *Mat,
  48. DxilTypeSystem &typeSys) {
  49. for (User *U : Mat->users()) {
  50. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  51. Function *F = CI->getCalledFunction();
  52. if (DxilFunctionAnnotation *Anno = typeSys.GetFunctionAnnotation(F)) {
  53. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  54. if (CI->getArgOperand(i) == Mat) {
  55. return &Anno->GetParameterAnnotation(i);
  56. }
  57. }
  58. }
  59. }
  60. }
  61. return nullptr;
  62. }
  63. // Translate matrix type to vector type.
  64. Type *LowerMatrixType(Type *Ty) {
  65. // Only translate matrix type and function type which use matrix type.
  66. // Not translate struct has matrix or matrix pointer.
  67. // Struct should be flattened before.
  68. // Pointer could cover by matldst which use vector as value type.
  69. if (FunctionType *FT = dyn_cast<FunctionType>(Ty)) {
  70. Type *RetTy = LowerMatrixType(FT->getReturnType());
  71. SmallVector<Type *, 4> params;
  72. for (Type *param : FT->params()) {
  73. params.emplace_back(LowerMatrixType(param));
  74. }
  75. return FunctionType::get(RetTy, params, false);
  76. } else if (IsMatrixType(Ty)) {
  77. unsigned row, col;
  78. Type *EltTy = GetMatrixInfo(Ty, col, row);
  79. return VectorType::get(EltTy, row * col);
  80. } else {
  81. return Ty;
  82. }
  83. }
  84. // Translate matrix type to array type.
  85. Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
  86. if (IsMatrixType(Ty)) {
  87. unsigned row, col;
  88. Type *EltTy = GetMatrixInfo(Ty, col, row);
  89. return ArrayType::get(EltTy, row * col);
  90. } else {
  91. return Ty;
  92. }
  93. }
  94. Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
  95. DXASSERT(IsMatrixType(Ty), "not matrix type");
  96. StructType *ST = cast<StructType>(Ty);
  97. Type *EltTy = ST->getElementType(0);
  98. Type *RowTy = EltTy->getArrayElementType();
  99. row = EltTy->getArrayNumElements();
  100. col = RowTy->getVectorNumElements();
  101. return RowTy->getVectorElementType();
  102. }
  103. bool IsMatrixArrayPointer(llvm::Type *Ty) {
  104. if (!Ty->isPointerTy())
  105. return false;
  106. Ty = Ty->getPointerElementType();
  107. if (!Ty->isArrayTy())
  108. return false;
  109. while (Ty->isArrayTy())
  110. Ty = Ty->getArrayElementType();
  111. return IsMatrixType(Ty);
  112. }
  113. Type *LowerMatrixArrayPointer(Type *Ty) {
  114. unsigned addrSpace = Ty->getPointerAddressSpace();
  115. Ty = Ty->getPointerElementType();
  116. std::vector<unsigned> arraySizeList;
  117. while (Ty->isArrayTy()) {
  118. arraySizeList.push_back(Ty->getArrayNumElements());
  119. Ty = Ty->getArrayElementType();
  120. }
  121. Ty = LowerMatrixType(Ty);
  122. for (auto arraySize = arraySizeList.rbegin();
  123. arraySize != arraySizeList.rend(); arraySize++)
  124. Ty = ArrayType::get(Ty, *arraySize);
  125. return PointerType::get(Ty, addrSpace);
  126. }
  127. Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
  128. unsigned addrSpace = Ty->getPointerAddressSpace();
  129. Ty = Ty->getPointerElementType();
  130. unsigned arraySize = 1;
  131. while (Ty->isArrayTy()) {
  132. arraySize *= Ty->getArrayNumElements();
  133. Ty = Ty->getArrayElementType();
  134. }
  135. unsigned row, col;
  136. Type *EltTy = GetMatrixInfo(Ty, col, row);
  137. arraySize *= row*col;
  138. Ty = ArrayType::get(EltTy, arraySize);
  139. return PointerType::get(Ty, addrSpace);
  140. }
  141. Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
  142. IRBuilder<> &Builder) {
  143. Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
  144. for (unsigned i = 0; i < size; i++)
  145. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  146. return Vec;
  147. }
// For a 2-level GEP into a matrix (base index 0, then element index), map the
// element index through IdxList and return the remapped index value.
// A constant element index becomes a direct table lookup; a dynamic index is
// resolved at runtime by spilling IdxList into a temporary stack array.
Value *LowerGEPOnMatIndexListToIndex(
    llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
  IRBuilder<> Builder(GEP);
  Value *zero = Builder.getInt32(0);
  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
  Value *baseIdx = (GEP->idx_begin())->get();
  // i32 constants are uniqued per context, so pointer equality with `zero`
  // is a valid way to assert the base index is the constant 0.
  DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
  Value *Idx = (GEP->idx_begin() + 1)->get();

  if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
    // Constant index: select the mapped index at compile time.
    return IdxList[immIdx->getSExtValue()];
  } else {
    // Dynamic index: allocate the mapping table in the function entry block
    // so the alloca dominates all uses.
    IRBuilder<> AllocaBuilder(
        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
    unsigned size = IdxList.size();
    // Store idxList to temp array.
    ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
    for (unsigned i = 0; i < size; i++) {
      Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
      Builder.CreateStore(IdxList[i], EltPtr);
    }
    // Load the idx selected by the dynamic offset.
    Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
    return Builder.CreateLoad(GEPOffset);
  }
}
  174. unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
  175. return c * row + r;
  176. }
  177. unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
  178. return r * col + c;
  179. }
  180. } // namespace HLMatrixLower
  181. } // namespace hlsl
  182. namespace {
// Module pass that lowers HL matrix struct types and the HL matrix
// intrinsics operating on them into plain vector types and operations.
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}

  const char *getPassName() const override { return "HL matrix lower"; }

  bool runOnModule(Module &M) override {
    m_pModule = &M;
    m_pHLModule = &m_pModule->GetOrCreateHLModule();
    // Load up debug information, to cross-reference values and the instructions
    // used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;

    // Lower matrices function by function, skipping declarations.
    for (Function &F : M.functions()) {
      if (F.isDeclaration())
        continue;
      runOnFunction(F);
    }

    // Collect static / groupshared globals first, then process them, so the
    // module's global list is not mutated while being iterated.
    std::vector<GlobalVariable*> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (dxilutil::IsStaticGlobal(&GV) ||
          dxilutil::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }
    for (GlobalVariable *GV : staticGVs)
      runOnGlobal(GV);

    return true;
  }

private:
  Module *m_pModule;
  HLModule *m_pHLModule;
  // True when the module carries debug metadata.
  bool m_HasDbgInfo;
  // Instructions to erase once lowering is complete.
  std::vector<Instruction *> m_deadInsts;
  // For instruction like matrix array init.
  // May use more than 1 matrix alloca inst.
  // This set is here to avoid putting it into deadInsts more than once.
  std::unordered_set<Instruction *> m_inDeadInstsSet;
  // For most matrix instruction users, it will only have one matrix use.
  // Use a vector to save deadInsts because vector is cheap.
  void AddToDeadInsts(Instruction *I) { m_deadInsts.emplace_back(I); }
  // In case an instruction has more than one matrix use, use
  // AddToDeadInstsWithDups to make sure it's not added to deadInsts more than
  // once.
  void AddToDeadInstsWithDups(Instruction *I) {
    if (m_inDeadInstsSet.count(I) == 0) {
      // Only add to deadInsts when it's not inside m_inDeadInstsSet.
      m_inDeadInstsSet.insert(I);
      AddToDeadInsts(I);
    }
  }
  void runOnFunction(Function &F);
  void runOnGlobal(GlobalVariable *GV);
  void runOnGlobalMatrixArray(GlobalVariable *GV);
  // The MatXxxToVec helpers create the vector-typed counterpart of one HL
  // matrix call; operand wiring happens later during replacement.
  Instruction *MatCastToVec(CallInst *CI);
  Instruction *MatLdStToVec(CallInst *CI);
  Instruction *MatSubscriptToVec(CallInst *CI);
  Instruction *MatFrExpToVec(CallInst *CI);
  Instruction *MatIntrinsicToVec(CallInst *CI);
  Instruction *TrivialMatUnOpToVec(CallInst *CI);
  // Replace matVal with vecVal on matUseInst.
  void TrivialMatUnOpReplace(Value *matVal, Value *vecVal,
                             CallInst *matUseInst);
  Instruction *TrivialMatBinOpToVec(CallInst *CI);
  // Replace matVal with vecVal on matUseInst.
  void TrivialMatBinOpReplace(Value *matVal, Value *vecVal,
                              CallInst *matUseInst);
  // Replace matVal with vecVal on mulInst.
  void TranslateMatMatMul(Value *matVal, Value *vecVal,
                          CallInst *mulInst, bool isSigned);
  void TranslateMatVecMul(Value *matVal, Value *vecVal,
                          CallInst *mulInst, bool isSigned);
  void TranslateVecMatMul(Value *matVal, Value *vecVal,
                          CallInst *mulInst, bool isSigned);
  void TranslateMul(Value *matVal, Value *vecVal, CallInst *mulInst,
                    bool isSigned);
  // Replace matVal with vecVal on transposeInst.
  void TranslateMatTranspose(Value *matVal, Value *vecVal,
                             CallInst *transposeInst);
  void TranslateMatDeterminant(Value *matVal, Value *vecVal,
                               CallInst *determinantInst);
  void MatIntrinsicReplace(Value *matVal, Value *vecVal,
                           CallInst *matUseInst);
  // Replace matVal with vecVal on castInst.
  void TranslateMatMatCast(Value *matVal, Value *vecVal,
                           CallInst *castInst);
  void TranslateMatToOtherCast(Value *matVal, Value *vecVal,
                               CallInst *castInst);
  void TranslateMatCast(Value *matVal, Value *vecVal,
                        CallInst *castInst);
  void TranslateMatMajorCast(Value *matVal, Value *vecVal,
                             CallInst *castInst, bool rowToCol, bool transpose);
  // Replace matVal with vecVal in matSubscript.
  void TranslateMatSubscript(Value *matVal, Value *vecVal,
                             CallInst *matSubInst);
  // Replace matInitInst using matToVecMap.
  void TranslateMatInit(CallInst *matInitInst);
  // Replace matSelectInst using matToVecMap.
  void TranslateMatSelect(CallInst *matSelectInst);
  // Replace matVal with vecVal on matGEP.
  void TranslateMatArrayGEP(Value *matVal, Value *vecVal,
                            GetElementPtrInst *matGEP);
  void TranslateMatLoadStoreOnGlobal(Value *matGlobal, ArrayRef<Value *>vecGlobals,
                                     CallInst *matLdStInst);
  void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
                                     CallInst *matLdStInst);
  void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
  void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
  // Get new matrix value corresponding to vecVal.
  Value *GetMatrixForVec(Value *vecVal, Type *matTy);
  // Translate library function input/output to preserve function signatures.
  void TranslateArgForLibFunc(CallInst *CI);
  void TranslateArgsForLibFunc(Function &F);
  // Replace matVal with vecVal on matUseInst.
  void TrivialMatReplace(Value *matVal, Value *vecVal,
                         CallInst *matUseInst);
  // Lower a matrix type instruction to a vector type instruction.
  void lowerToVec(Instruction *matInst);
  // Lower users of a matrix type instruction.
  void replaceMatWithVec(Value *matVal, Value *vecVal);
  // Translate user library function call arguments.
  void castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI);
  // Translate mat inst which needs all operands ready.
  void finalMatTranslation(Value *matVal);
  // Delete dead insts in m_deadInsts.
  void DeleteDeadInsts();
  // Map from matrix value to its vector version.
  DenseMap<Value *, Value *> matToVecMap;
  // Map from new vector version to matrix version needed by user call or return.
  DenseMap<Value *, Value *> vecToMatMap;
};
  311. }
char HLMatrixLowerPass::ID = 0;

// Factory used by the pass manager to create this pass.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }

INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  315. static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
  316. IRBuilder<> Builder) {
  317. // Cast to bool.
  318. if (toTy->getScalarType()->isIntegerTy() &&
  319. toTy->getScalarType()->getIntegerBitWidth() == 1) {
  320. Type *fromTy = src->getType();
  321. bool isFloat = fromTy->getScalarType()->isFloatingPointTy();
  322. Constant *zero;
  323. if (isFloat)
  324. zero = llvm::ConstantFP::get(fromTy->getScalarType(), 0);
  325. else
  326. zero = llvm::ConstantInt::get(fromTy->getScalarType(), 0);
  327. if (toTy->getScalarType() != toTy) {
  328. // Create constant vector.
  329. unsigned size = toTy->getVectorNumElements();
  330. std::vector<Constant *> zeros(size, zero);
  331. zero = llvm::ConstantVector::get(zeros);
  332. }
  333. if (isFloat)
  334. return cast<Instruction>(Builder.CreateFCmpOEQ(src, zero));
  335. else
  336. return cast<Instruction>(Builder.CreateICmpEQ(src, zero));
  337. }
  338. Type *eltToTy = toTy->getScalarType();
  339. Type *eltFromTy = src->getType()->getScalarType();
  340. bool fromUnsigned = castOp == HLCastOpcode::FromUnsignedCast ||
  341. castOp == HLCastOpcode::UnsignedUnsignedCast;
  342. bool toUnsigned = castOp == HLCastOpcode::ToUnsignedCast ||
  343. castOp == HLCastOpcode::UnsignedUnsignedCast;
  344. Instruction::CastOps castOps = static_cast<Instruction::CastOps>(
  345. HLModule::FindCastOp(fromUnsigned, toUnsigned, eltFromTy, eltToTy));
  346. return cast<Instruction>(Builder.CreateCast(castOps, src, toTy));
  347. }
// Lower an HL cast call that involves a matrix type.
// Other->matrix casts are fully translated here (splat or shuffle to the
// matrix's vector size, then an element-type cast); matrix->matrix casts on
// function arguments get a vector-return clone of the call; everything else
// falls back to the generic MatIntrinsicToVec path.
Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  Value *op = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(CI));

  bool ToMat = IsMatrixType(CI->getType());
  bool FromMat = IsMatrixType(op->getType());
  if (ToMat && !FromMat) {
    // Translate OtherToMat here.
    // Rest will be translated when replaced.
    unsigned col, row;
    Type *EltTy = GetMatrixInfo(CI->getType(), col, row);
    unsigned toSize = col * row;
    Instruction *sizeCast = nullptr;
    Type *FromTy = op->getType();
    Type *I32Ty = IntegerType::get(FromTy->getContext(), 32);
    if (FromTy->isVectorTy()) {
      // Vector source: shuffle with an identity mask to resize to the
      // matrix's flattened element count.
      std::vector<Constant *> MaskVec(toSize);
      for (size_t i = 0; i != toSize; ++i)
        MaskVec[i] = ConstantInt::get(I32Ty, i);

      Value *castMask = ConstantVector::get(MaskVec);

      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    } else {
      // Scalar source: wrap it into a 1-element vector, then splat it across
      // all matrix elements with an all-zero shuffle mask.
      op = Builder.CreateInsertElement(
          UndefValue::get(VectorType::get(FromTy, 1)), op, (uint64_t)0);
      Constant *zero = ConstantInt::get(I32Ty, 0);
      std::vector<Constant *> MaskVec(toSize, zero);
      Value *castMask = ConstantVector::get(MaskVec);

      sizeCast = new ShuffleVectorInst(op, op, castMask);
      Builder.Insert(sizeCast);
    }
    // Only emit an element-type cast when source and destination scalar
    // types differ.
    Instruction *typeCast = sizeCast;
    if (EltTy != FromTy->getScalarType()) {
      typeCast = CreateTypeCast(opcode, VectorType::get(EltTy, toSize),
                                sizeCast, Builder);
    }
    return typeCast;
  } else if (FromMat && ToMat) {
    if (isa<Argument>(op)) {
      // Cast from mat to mat for argument.
      IRBuilder<> Builder(CI);

      // Here only lower the return type to vector.
      Type *RetTy = LowerMatrixType(CI->getType());
      SmallVector<Type *, 4> params;
      for (Value *operand : CI->arg_operands()) {
        params.emplace_back(operand->getType());
      }
      Type *FT = FunctionType::get(RetTy, params, false);

      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      unsigned opcode = GetHLOpcode(CI);

      Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT),
                                             group, opcode);

      SmallVector<Value *, 4> argList;
      for (Value *arg : CI->arg_operands()) {
        argList.emplace_back(arg);
      }

      return Builder.CreateCall(vecF, argList);
    }
  }

  return MatIntrinsicToVec(CI);
}
  409. // Return GEP if value is Matrix resulting GEP from UDT alloca
  410. // UDT alloca must be there for library function args
  411. static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
  412. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
  413. if (IsMatrixType(GEP->getResultElementType())) {
  414. Value *ptr = GEP->getPointerOperand();
  415. if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) {
  416. Type *ATy = AI->getAllocatedType();
  417. if (ATy->isStructTy() && !IsMatrixType(ATy)) {
  418. return GEP;
  419. }
  420. }
  421. }
  422. }
  423. return nullptr;
  424. }
  425. // Return GEP if value is Matrix resulting GEP from UDT argument of
  426. // none-graphics functions.
  427. static GetElementPtrInst *GetIfMatrixGEPOfUDTArg(Value *V, HLModule &HM) {
  428. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
  429. if (IsMatrixType(GEP->getResultElementType())) {
  430. Value *ptr = GEP->getPointerOperand();
  431. if (Argument *Arg = dyn_cast<Argument>(ptr)) {
  432. if (!HM.IsGraphicsShader(Arg->getParent()))
  433. return GEP;
  434. }
  435. }
  436. }
  437. return nullptr;
  438. }
// Lower an HL matrix load/store call into a plain vector load/store when the
// pointer is a matrix alloca (or a matrix GEP into a UDT alloca/arg) that
// already has a vector counterpart in matToVecMap; otherwise fall back to the
// generic intrinsic cloning path.
// Stores initially write an undef vector value; the real value is wired in
// later when matrix operands are replaced.
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  unsigned opcode = GetHLOpcode(CI);
  HLMatLoadStoreOpcode matOpcode = static_cast<HLMatLoadStoreOpcode>(opcode);
  Instruction *result = nullptr;
  switch (matOpcode) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
      // The vector alloca/GEP was created earlier; load straight from it.
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      result = Builder.CreateLoad(vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
        GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
      Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
      Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
      // Placeholder value; replaced once the matrix operand is lowered.
      Value *vecVal =
          UndefValue::get(HLMatrixLower::LowerMatrixType(matVal->getType()));
      result = Builder.CreateStore(vecVal, vecPtr);
    } else
      result = MatIntrinsicToVec(CI);
  } break;
  }
  return result;
}
  471. Instruction *HLMatrixLowerPass::MatSubscriptToVec(CallInst *CI) {
  472. IRBuilder<> Builder(CI);
  473. Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  474. if (isa<AllocaInst>(matPtr)) {
  475. // Here just create a new matSub call which use vec ptr.
  476. // Later in TranslateMatSubscript will do the real translation.
  477. std::vector<Value *> args(CI->getNumArgOperands());
  478. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  479. args[i] = CI->getArgOperand(i);
  480. }
  481. // Change mat ptr into vec ptr.
  482. args[HLOperandIndex::kMatSubscriptMatOpIdx] =
  483. matToVecMap[cast<Instruction>(matPtr)];
  484. std::vector<Type *> paramTyList(CI->getNumArgOperands());
  485. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  486. paramTyList[i] = args[i]->getType();
  487. }
  488. FunctionType *funcTy = FunctionType::get(CI->getType(), paramTyList, false);
  489. unsigned opcode = GetHLOpcode(CI);
  490. Function *opFunc = GetOrCreateHLFunction(*m_pModule, funcTy, HLOpcodeGroup::HLSubscript, opcode);
  491. return Builder.CreateCall(opFunc, args);
  492. } else
  493. return MatIntrinsicToVec(CI);
  494. }
// Lower an HL frexp intrinsic call that involves matrix types.
// frexp needs its own path because its exponent output is a pointer
// parameter: the pointee type must be lowered while the pointer's address
// space is preserved.
Instruction *HLMatrixLowerPass::MatFrExpToVec(CallInst *CI) {
  IRBuilder<> Builder(CI);
  FunctionType *FT = CI->getCalledFunction()->getFunctionType();
  Type *RetTy = LowerMatrixType(FT->getReturnType());
  SmallVector<Type *, 4> params;
  for (Type *param : FT->params()) {
    if (!param->isPointerTy()) {
      params.emplace_back(LowerMatrixType(param));
    } else {
      // Lower pointer type for frexp.
      Type *EltTy = LowerMatrixType(param->getPointerElementType());
      params.emplace_back(
          PointerType::get(EltTy, param->getPointerAddressSpace()));
    }
  }
  Type *VecFT = FunctionType::get(RetTy, params, false);
  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  Function *vecF =
      GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(VecFT), group,
                            static_cast<unsigned>(IntrinsicOp::IOP_frexp));
  SmallVector<Value *, 4> argList;
  // Walk arguments and new parameter types in lockstep: any argument whose
  // type was lowered (type mismatch) becomes an undef placeholder, to be
  // wired up later during matrix operand replacement.
  auto paramTyIt = params.begin();
  for (Value *arg : CI->arg_operands()) {
    Type *Ty = arg->getType();
    Type *ParamTy = *(paramTyIt++);

    if (Ty != ParamTy)
      argList.emplace_back(UndefValue::get(ParamTy));
    else
      argList.emplace_back(arg);
  }
  return Builder.CreateCall(vecF, argList);
}
  527. Instruction *HLMatrixLowerPass::MatIntrinsicToVec(CallInst *CI) {
  528. IRBuilder<> Builder(CI);
  529. unsigned opcode = GetHLOpcode(CI);
  530. if (opcode == static_cast<unsigned>(IntrinsicOp::IOP_frexp))
  531. return MatFrExpToVec(CI);
  532. Type *FT = LowerMatrixType(CI->getCalledFunction()->getFunctionType());
  533. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  534. Function *vecF = GetOrCreateHLFunction(*m_pModule, cast<FunctionType>(FT), group, opcode);
  535. SmallVector<Value *, 4> argList;
  536. for (Value *arg : CI->arg_operands()) {
  537. Type *Ty = arg->getType();
  538. if (IsMatrixType(Ty)) {
  539. argList.emplace_back(UndefValue::get(LowerMatrixType(Ty)));
  540. } else
  541. argList.emplace_back(arg);
  542. }
  543. return Builder.CreateCall(vecF, argList);
  544. }
// Create a vector-typed placeholder instruction for a matrix unary operator.
// The real operand is not lowered yet, so an undef of the lowered vector type
// stands in; TrivialMatUnOpReplace patches the operand(s) afterwards, which
// is why the operand positions here are load-bearing.
Instruction *HLMatrixLowerPass::TrivialMatUnOpToVec(CallInst *CI) {
  Type *ResultTy = LowerMatrixType(CI->getType());
  // Placeholder operand, patched later.
  UndefValue *tmp = UndefValue::get(ResultTy);
  IRBuilder<> Builder(CI);
  HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(CI));
  bool isFloat = ResultTy->getVectorElementType()->isFloatingPointTy();
  // Splat the scalar constant v across every lane of Ty.
  auto GetVecConst = [&](Type *Ty, int v) -> Constant * {
    Constant *val = isFloat ? ConstantFP::get(Ty->getScalarType(), v)
                            : ConstantInt::get(Ty->getScalarType(), v);
    std::vector<Constant *> vals(Ty->getVectorNumElements(), val);
    return ConstantVector::get(vals);
  };
  Constant *one = GetVecConst(ResultTy, 1);
  Instruction *Result = nullptr;
  switch (opcode) {
  case HLUnaryOpcode::Minus: {
    // Negation lowered as (0 - x).
    Constant *zero = GetVecConst(ResultTy, 0);
    if (isFloat)
      Result = BinaryOperator::CreateFSub(zero, tmp);
    else
      Result = BinaryOperator::CreateSub(zero, tmp);
  } break;
  case HLUnaryOpcode::LNot: {
    // NOTE(review): emitted as (x != 0), i.e. "x as bool" rather than "!x";
    // presumably the consumer accounts for this — verify before changing.
    Constant *zero = GetVecConst(ResultTy, 0);
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_UNE, tmp, zero);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, zero);
  } break;
  case HLUnaryOpcode::Not:
    // Bitwise-not is represented as an xor placeholder; TrivialMatUnOpReplace
    // patches both of its operands.
    Result = BinaryOperator::CreateXor(tmp, tmp);
    break;
  case HLUnaryOpcode::PostInc:
  case HLUnaryOpcode::PreInc:
    // ++ lowered as (x + 1); pre/post distinction handled by the caller side.
    if (isFloat)
      Result = BinaryOperator::CreateFAdd(tmp, one);
    else
      Result = BinaryOperator::CreateAdd(tmp, one);
    break;
  case HLUnaryOpcode::PostDec:
  case HLUnaryOpcode::PreDec:
    // -- lowered as (x - 1).
    if (isFloat)
      Result = BinaryOperator::CreateFSub(tmp, one);
    else
      Result = BinaryOperator::CreateSub(tmp, one);
    break;
  default:
    DXASSERT(0, "not implement");
    return nullptr;
  }
  Builder.Insert(Result);
  return Result;
}
// Create a vector-typed placeholder instruction for a matrix binary operator.
// Both operands are undef of the lowered *operand* type (which may differ
// from the result type, e.g. for comparisons); TrivialMatBinOpReplace patches
// them once the real matrix operands have been lowered.
Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
  Type *ResultTy = LowerMatrixType(CI->getType());
  IRBuilder<> Builder(CI);
  HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(CI));
  Type *OpTy = LowerMatrixType(
      CI->getOperand(HLOperandIndex::kBinaryOpSrc0Idx)->getType());
  UndefValue *tmp = UndefValue::get(OpTy);
  bool isFloat = OpTy->getVectorElementType()->isFloatingPointTy();
  Instruction *Result = nullptr;
  switch (opcode) {
  case HLBinaryOpcode::Add:
    if (isFloat)
      Result = BinaryOperator::CreateFAdd(tmp, tmp);
    else
      Result = BinaryOperator::CreateAdd(tmp, tmp);
    break;
  case HLBinaryOpcode::Sub:
    if (isFloat)
      Result = BinaryOperator::CreateFSub(tmp, tmp);
    else
      Result = BinaryOperator::CreateSub(tmp, tmp);
    break;
  case HLBinaryOpcode::Mul:
    if (isFloat)
      Result = BinaryOperator::CreateFMul(tmp, tmp);
    else
      Result = BinaryOperator::CreateMul(tmp, tmp);
    break;
  case HLBinaryOpcode::Div:
    // Signed division; the unsigned form is HLBinaryOpcode::UDiv below.
    if (isFloat)
      Result = BinaryOperator::CreateFDiv(tmp, tmp);
    else
      Result = BinaryOperator::CreateSDiv(tmp, tmp);
    break;
  case HLBinaryOpcode::Rem:
    if (isFloat)
      Result = BinaryOperator::CreateFRem(tmp, tmp);
    else
      Result = BinaryOperator::CreateSRem(tmp, tmp);
    break;
  case HLBinaryOpcode::And:
    Result = BinaryOperator::CreateAnd(tmp, tmp);
    break;
  case HLBinaryOpcode::Or:
    Result = BinaryOperator::CreateOr(tmp, tmp);
    break;
  case HLBinaryOpcode::Xor:
    Result = BinaryOperator::CreateXor(tmp, tmp);
    break;
  case HLBinaryOpcode::Shl: {
    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
    Result = BinaryOperator::CreateShl(tmp, tmp);
  } break;
  case HLBinaryOpcode::Shr: {
    // Signed shift-right -> arithmetic shift.
    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
    Result = BinaryOperator::CreateAShr(tmp, tmp);
  } break;
  case HLBinaryOpcode::LT:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLT, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLT, tmp, tmp);
    break;
  case HLBinaryOpcode::GT:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGT, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, tmp, tmp);
    break;
  case HLBinaryOpcode::LE:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OLE, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SLE, tmp, tmp);
    break;
  case HLBinaryOpcode::GE:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OGE, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGE, tmp, tmp);
    break;
  case HLBinaryOpcode::EQ:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, tmp);
    break;
  case HLBinaryOpcode::NE:
    if (isFloat)
      Result = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_ONE, tmp, tmp);
    else
      Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, tmp, tmp);
    break;
  case HLBinaryOpcode::UDiv:
    Result = BinaryOperator::CreateUDiv(tmp, tmp);
    break;
  case HLBinaryOpcode::URem:
    Result = BinaryOperator::CreateURem(tmp, tmp);
    break;
  case HLBinaryOpcode::UShr: {
    // Unsigned shift-right -> logical shift.
    Value *op1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
    DXASSERT_LOCALVAR(op1, IsMatrixType(op1->getType()), "must be matrix type here");
    Result = BinaryOperator::CreateLShr(tmp, tmp);
  } break;
  case HLBinaryOpcode::ULT:
    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, tmp, tmp);
    break;
  case HLBinaryOpcode::UGT:
    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, tmp, tmp);
    break;
  case HLBinaryOpcode::ULE:
    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, tmp, tmp);
    break;
  case HLBinaryOpcode::UGE:
    Result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGE, tmp, tmp);
    break;
  case HLBinaryOpcode::LAnd:
  case HLBinaryOpcode::LOr: {
    // Logical ops: compare each side against zero, then combine. The compare
    // instructions are the patch points for TrivialMatBinOpReplace.
    Constant *zero;
    if (isFloat)
      zero = llvm::ConstantFP::get(ResultTy->getVectorElementType(), 0);
    else
      zero = llvm::ConstantInt::get(ResultTy->getVectorElementType(), 0);
    unsigned size = ResultTy->getVectorNumElements();
    std::vector<Constant *> zeros(size, zero);
    Value *vecZero = llvm::ConstantVector::get(zeros);
    Instruction *cmpL;
    if (isFloat)
      cmpL =
          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
    else
      cmpL = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
    Builder.Insert(cmpL);
    Instruction *cmpR;
    if (isFloat)
      cmpR =
          CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, tmp, vecZero);
    else
      cmpR = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, tmp, vecZero);
    Builder.Insert(cmpR);
    // How to map l, r back? Need check opcode
    // NOTE(review): both branches below are identical (CreateAnd), and the
    // compares above test ==0 rather than !=0 — LOr does not differ from
    // LAnd here. Looks like unfinished logic; verify against downstream
    // consumers of this placeholder before relying on logical-op lowering.
    if (opcode == HLBinaryOpcode::LOr)
      Result = BinaryOperator::CreateAnd(cmpL, cmpR);
    else
      Result = BinaryOperator::CreateAnd(cmpL, cmpR);
    break;
  }
  default:
    DXASSERT(0, "not implement");
    return nullptr;
  }
  Builder.Insert(Result);
  return Result;
}
  754. // Create BitCast if ptr, otherwise, create alloca of new type, write to bitcast of alloca, and return load from alloca
  755. // If bOrigAllocaTy is true: create alloca of old type instead, write to alloca, and return load from bitcast of alloca
  756. static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, bool bOrigAllocaTy = false, const Twine &Name = "") {
  757. IRBuilder<> Builder(Insert);
  758. if (Ty->isPointerTy()) {
  759. // If pointer, we can bitcast directly
  760. return cast<Instruction>(Builder.CreateBitCast(V, Ty, Name));
  761. } else {
  762. // If value, we have to alloca, store to bitcast ptr, and load
  763. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Insert));
  764. Type *allocaTy = bOrigAllocaTy ? V->getType() : Ty;
  765. Type *otherTy = bOrigAllocaTy ? Ty : V->getType();
  766. Instruction *allocaInst = AllocaBuilder.CreateAlloca(allocaTy);
  767. Instruction *bitCast = cast<Instruction>(Builder.CreateBitCast(allocaInst, otherTy->getPointerTo()));
  768. Builder.CreateStore(V, bOrigAllocaTy ? allocaInst : bitCast);
  769. return Builder.CreateLoad(bOrigAllocaTy ? bitCast : allocaInst, Name);
  770. }
  771. }
// Create the vector equivalent of a matrix-producing instruction and record
// it in matToVecMap. Handles HL calls (dispatched by opcode group), matrix
// allocas (including debug-info and precise-metadata transfer), and GEPs
// into non-matrix UDTs (lowered to a pointer bitcast).
void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
  Value *vecVal = nullptr;
  if (CallInst *CI = dyn_cast<CallInst>(matInst)) {
    hlsl::HLOpcodeGroup group =
        hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
    switch (group) {
    case HLOpcodeGroup::HLIntrinsic: {
      vecVal = MatIntrinsicToVec(CI);
    } break;
    case HLOpcodeGroup::HLSelect: {
      // Select has the same call shape as an intrinsic; reuse that lowering.
      vecVal = MatIntrinsicToVec(CI);
    } break;
    case HLOpcodeGroup::HLBinOp: {
      vecVal = TrivialMatBinOpToVec(CI);
    } break;
    case HLOpcodeGroup::HLUnOp: {
      vecVal = TrivialMatUnOpToVec(CI);
    } break;
    case HLOpcodeGroup::HLCast: {
      vecVal = MatCastToVec(CI);
    } break;
    case HLOpcodeGroup::HLInit: {
      // Initializer lists also share the intrinsic call shape.
      vecVal = MatIntrinsicToVec(CI);
    } break;
    case HLOpcodeGroup::HLMatLoadStore: {
      vecVal = MatLdStToVec(CI);
    } break;
    case HLOpcodeGroup::HLSubscript: {
      vecVal = MatSubscriptToVec(CI);
    } break;
    case HLOpcodeGroup::NotHL: {
      // Translate user function return
      vecVal = BitCastValueOrPtr( matInst,
                        matInst->getNextNode(),
                        HLMatrixLower::LowerMatrixType(matInst->getType()),
                        /*bOrigAllocaTy*/ false,
                        matInst->getName());
      // matrix equivalent of this new vector will be the original, retained user call
      vecToMatMap[vecVal] = matInst;
    } break;
    default:
      DXASSERT(0, "invalid inst");
    }
  } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
    Type *Ty = AI->getAllocatedType();
    Type *matTy = Ty;
    IRBuilder<> AllocaBuilder(AI);
    if (Ty->isArrayTy()) {
      // Array-of-matrix alloca: lower the array pointer type, then strip the
      // pointer to get the allocated element type.
      Type *vecTy = HLMatrixLower::LowerMatrixArrayPointer(AI->getType());
      vecTy = vecTy->getPointerElementType();
      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
    } else {
      Type *vecTy = HLMatrixLower::LowerMatrixType(matTy);
      vecVal = AllocaBuilder.CreateAlloca(vecTy, nullptr, AI->getName());
    }
    // Update debug info: re-emit the llvm.dbg.declare against the new alloca,
    // keeping the original variable and expression metadata.
    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
    if (DDI) {
      LLVMContext &Context = AI->getContext();
      Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
      Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
      Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(vecVal));
      IRBuilder<> debugBuilder(DDI);
      debugBuilder.CreateCall(DDI->getCalledFunction(), {VMD, DDIVar, DDIExp});
    }
    // Carry the precise attribute over to the lowered alloca.
    if (HLModule::HasPreciseAttributeWithMetadata(AI))
      HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
  } else if (GetIfMatrixGEPOfUDTAlloca(matInst) ||
             GetIfMatrixGEPOfUDTArg(matInst, *m_pHLModule)) {
    // If GEP from alloca of non-matrix UDT, bitcast the matrix pointer to a
    // pointer to the lowered vector type.
    IRBuilder<> Builder(matInst->getNextNode());
    vecVal = Builder.CreateBitCast(matInst,
            HLMatrixLower::LowerMatrixType(
              matInst->getType()->getPointerElementType() )->getPointerTo());
    // matrix equivalent of this new vector will be the original, retained GEP
    vecToMatMap[vecVal] = matInst;
  } else {
    DXASSERT(0, "invalid inst");
  }
  if (vecVal) {
    matToVecMap[matInst] = vecVal;
  }
}
  855. // Replace matInst with vecVal on matUseInst.
  856. void HLMatrixLowerPass::TrivialMatUnOpReplace(Value *matVal,
  857. Value *vecVal,
  858. CallInst *matUseInst) {
  859. (void)(matVal); // Unused
  860. HLUnaryOpcode opcode = static_cast<HLUnaryOpcode>(GetHLOpcode(matUseInst));
  861. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  862. switch (opcode) {
  863. case HLUnaryOpcode::Not:
  864. // Not is xor now
  865. vecUseInst->setOperand(0, vecVal);
  866. vecUseInst->setOperand(1, vecVal);
  867. break;
  868. case HLUnaryOpcode::LNot:
  869. case HLUnaryOpcode::PostInc:
  870. case HLUnaryOpcode::PreInc:
  871. case HLUnaryOpcode::PostDec:
  872. case HLUnaryOpcode::PreDec:
  873. vecUseInst->setOperand(0, vecVal);
  874. break;
  875. case HLUnaryOpcode::Invalid:
  876. case HLUnaryOpcode::Plus:
  877. case HLUnaryOpcode::Minus:
  878. case HLUnaryOpcode::NumOfUO:
  879. // No VecInst replacements for these.
  880. break;
  881. }
  882. }
  883. // Replace matInst with vecVal on matUseInst.
  884. void HLMatrixLowerPass::TrivialMatBinOpReplace(Value *matVal,
  885. Value *vecVal,
  886. CallInst *matUseInst) {
  887. HLBinaryOpcode opcode = static_cast<HLBinaryOpcode>(GetHLOpcode(matUseInst));
  888. Instruction *vecUseInst = cast<Instruction>(matToVecMap[matUseInst]);
  889. if (opcode != HLBinaryOpcode::LAnd && opcode != HLBinaryOpcode::LOr) {
  890. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) == matVal)
  891. vecUseInst->setOperand(0, vecVal);
  892. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) == matVal)
  893. vecUseInst->setOperand(1, vecVal);
  894. } else {
  895. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx) ==
  896. matVal) {
  897. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(0));
  898. vecCmp->setOperand(0, vecVal);
  899. }
  900. if (matUseInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx) ==
  901. matVal) {
  902. Instruction *vecCmp = cast<Instruction>(vecUseInst->getOperand(1));
  903. vecCmp->setOperand(0, vecVal);
  904. }
  905. }
  906. }
  907. static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) {
  908. llvm::FunctionType *MadFuncTy =
  909. llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false);
  910. Function *MAD =
  911. GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
  912. (unsigned)madOp);
  913. return MAD;
  914. }
// Expand mul(mat, mat) over the two lowered row-major element vectors:
// each result element is a dot product built as one mul followed by a chain
// of mad intrinsic calls. The expansion replaces the placeholder vec call.
void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
                                           Value *vecVal,
                                           CallInst *mulInst, bool isSigned) {
  (void)(matVal); // Unused; retrieved from matToVecMap directly
  DXASSERT(matToVecMap.count(mulInst), "must have vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
  // Already translated.
  if (!isa<CallInst>(vecUseInst))
    return;
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(LVal->getType(), col, row);
  unsigned rCol, rRow;
  GetMatrixInfo(RVal->getType(), rCol, rRow);
  // Inner dimensions must agree: (row x col) * (rRow x rCol).
  DXASSERT_NOMSG(col == rRow);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
  IRBuilder<> Builder(vecUseInst);
  Value *lMat = matToVecMap[cast<Instruction>(LVal)];
  Value *rMat = matToVecMap[cast<Instruction>(RVal)];
  // First term of a dot product: plain element multiply (no accumulator yet).
  auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
                   : Builder.CreateMul(lMatElt, rMatElt);
  };
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // Remaining terms: fold into the accumulator with the mad intrinsic.
  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                             Value *acc) -> Value * {
    unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
    unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < rCol; c++) {
      unsigned lc = 0;
      Value *tmpVal = CreateOneEltMul(r, lc, c);
      for (lc = 1; lc < col; lc++) {
        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
      }
      unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
      retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
    }
  }
  Instruction *matmatMul = cast<Instruction>(retVal);
  // Replace the placeholder vec mul call with the expanded mad sequence.
  vecUseInst->replaceAllUsesWith(matmatMul);
  AddToDeadInsts(vecUseInst);
  matToVecMap[mulInst] = matmatMul;
}
// Expand mul(mat, vec): result[r] = dot(mat row r, vec), built as one plain
// multiply followed by mad-intrinsic accumulation for each row.
void HLMatrixLowerPass::TranslateMatVecMul(Value *matVal,
                                           Value *vecVal,
                                           CallInst *mulInst, bool isSigned) {
  // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
  DXASSERT_NOMSG(RVal->getType()->getVectorNumElements() == col);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(mulInst->getType());
  IRBuilder<> Builder(mulInst);
  Value *vec = RVal;
  Value *mat = vecVal; // vec version of matInst;
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // acc + vec[c] * mat[r][c] via the mad intrinsic.
  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
    Value *vecElt = Builder.CreateExtractElement(vec, c);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  };
  for (unsigned r = 0; r < row; r++) {
    // First column handled with a plain multiply (no accumulator yet).
    unsigned c = 0;
    Value *vecElt = Builder.CreateExtractElement(vec, c);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
                            : Builder.CreateMul(vecElt, matElt);
    for (c = 1; c < col; c++) {
      tmpVal = CreateOneEltMad(r, c, tmpVal);
    }
    retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
  }
  mulInst->replaceAllUsesWith(retVal);
  AddToDeadInsts(mulInst);
}
// Expand mul(vec, mat): result[c] = dot(vec, mat column c), built as one
// plain multiply followed by mad-intrinsic accumulation for each column.
void HLMatrixLowerPass::TranslateVecMatMul(Value *matVal,
                                           Value *vecVal,
                                           CallInst *mulInst, bool isSigned) {
  Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  // matVal should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  Value *RVal = vecVal;
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
  DXASSERT_NOMSG(LVal->getType()->getVectorNumElements() == row);
  bool isFloat = EltTy->isFloatingPointTy();
  Value *retVal = llvm::UndefValue::get(mulInst->getType());
  IRBuilder<> Builder(mulInst);
  Value *vec = LVal;
  Value *mat = RVal;
  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
  Type *opcodeTy = Builder.getInt32Ty();
  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
                                          *m_pHLModule->GetModule());
  Value *madOpArg = Builder.getInt32((unsigned)madOp);
  // acc + vec[r] * mat[r][c] via the mad intrinsic.
  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
    Value *vecElt = Builder.CreateExtractElement(vec, r);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
  };
  for (unsigned c = 0; c < col; c++) {
    // First row handled with a plain multiply (no accumulator yet).
    unsigned r = 0;
    Value *vecElt = Builder.CreateExtractElement(vec, r);
    uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
    Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
                            : Builder.CreateMul(vecElt, matElt);
    for (r = 1; r < row; r++) {
      tmpVal = CreateOneEltMad(r, c, tmpVal);
    }
    retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
  }
  mulInst->replaceAllUsesWith(retVal);
  AddToDeadInsts(mulInst);
}
  1053. void HLMatrixLowerPass::TranslateMul(Value *matVal, Value *vecVal,
  1054. CallInst *mulInst, bool isSigned) {
  1055. Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
  1056. Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
  1057. bool LMat = IsMatrixType(LVal->getType());
  1058. bool RMat = IsMatrixType(RVal->getType());
  1059. if (LMat && RMat) {
  1060. TranslateMatMatMul(matVal, vecVal, mulInst, isSigned);
  1061. } else if (LMat) {
  1062. TranslateMatVecMul(matVal, vecVal, mulInst, isSigned);
  1063. } else {
  1064. TranslateVecMatMul(matVal, vecVal, mulInst, isSigned);
  1065. }
  1066. }
  1067. void HLMatrixLowerPass::TranslateMatTranspose(Value *matVal,
  1068. Value *vecVal,
  1069. CallInst *transposeInst) {
  1070. // Matrix value is row major, transpose is cast it to col major.
  1071. TranslateMatMajorCast(matVal, vecVal, transposeInst,
  1072. /*bRowToCol*/ true, /*bTranspose*/ true);
  1073. }
  1074. static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
  1075. IRBuilder<> &Builder) {
  1076. Value *mul0 = Builder.CreateFMul(m00, m11);
  1077. Value *mul1 = Builder.CreateFMul(m01, m10);
  1078. return Builder.CreateFSub(mul0, mul1);
  1079. }
  1080. static Value *Determinant3x3(Value *m00, Value *m01, Value *m02,
  1081. Value *m10, Value *m11, Value *m12,
  1082. Value *m20, Value *m21, Value *m22,
  1083. IRBuilder<> &Builder) {
  1084. Value *deter00 = Determinant2x2(m11, m12, m21, m22, Builder);
  1085. Value *deter01 = Determinant2x2(m10, m12, m20, m22, Builder);
  1086. Value *deter02 = Determinant2x2(m10, m11, m20, m21, Builder);
  1087. deter00 = Builder.CreateFMul(m00, deter00);
  1088. deter01 = Builder.CreateFMul(m01, deter01);
  1089. deter02 = Builder.CreateFMul(m02, deter02);
  1090. Value *result = Builder.CreateFSub(deter00, deter01);
  1091. result = Builder.CreateFAdd(result, deter02);
  1092. return result;
  1093. }
  1094. static Value *Determinant4x4(Value *m00, Value *m01, Value *m02, Value *m03,
  1095. Value *m10, Value *m11, Value *m12, Value *m13,
  1096. Value *m20, Value *m21, Value *m22, Value *m23,
  1097. Value *m30, Value *m31, Value *m32, Value *m33,
  1098. IRBuilder<> &Builder) {
  1099. Value *deter00 = Determinant3x3(m11, m12, m13, m21, m22, m23, m31, m32, m33, Builder);
  1100. Value *deter01 = Determinant3x3(m10, m12, m13, m20, m22, m23, m30, m32, m33, Builder);
  1101. Value *deter02 = Determinant3x3(m10, m11, m13, m20, m21, m23, m30, m31, m33, Builder);
  1102. Value *deter03 = Determinant3x3(m10, m11, m12, m20, m21, m22, m30, m31, m32, Builder);
  1103. deter00 = Builder.CreateFMul(m00, deter00);
  1104. deter01 = Builder.CreateFMul(m01, deter01);
  1105. deter02 = Builder.CreateFMul(m02, deter02);
  1106. deter03 = Builder.CreateFMul(m03, deter03);
  1107. Value *result = Builder.CreateFSub(deter00, deter01);
  1108. result = Builder.CreateFAdd(result, deter02);
  1109. result = Builder.CreateFSub(result, deter03);
  1110. return result;
  1111. }
  1112. void HLMatrixLowerPass::TranslateMatDeterminant(Value *matVal, Value *vecVal,
  1113. CallInst *determinantInst) {
  1114. unsigned row, col;
  1115. GetMatrixInfo(matVal->getType(), col, row);
  1116. IRBuilder<> Builder(determinantInst);
  1117. // when row == 1, result is vecVal.
  1118. Value *Result = vecVal;
  1119. if (row == 2) {
  1120. Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
  1121. Value *m01 = Builder.CreateExtractElement(vecVal, 1);
  1122. Value *m10 = Builder.CreateExtractElement(vecVal, 2);
  1123. Value *m11 = Builder.CreateExtractElement(vecVal, 3);
  1124. Result = Determinant2x2(m00, m01, m10, m11, Builder);
  1125. }
  1126. else if (row == 3) {
  1127. Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
  1128. Value *m01 = Builder.CreateExtractElement(vecVal, 1);
  1129. Value *m02 = Builder.CreateExtractElement(vecVal, 2);
  1130. Value *m10 = Builder.CreateExtractElement(vecVal, 3);
  1131. Value *m11 = Builder.CreateExtractElement(vecVal, 4);
  1132. Value *m12 = Builder.CreateExtractElement(vecVal, 5);
  1133. Value *m20 = Builder.CreateExtractElement(vecVal, 6);
  1134. Value *m21 = Builder.CreateExtractElement(vecVal, 7);
  1135. Value *m22 = Builder.CreateExtractElement(vecVal, 8);
  1136. Result = Determinant3x3(m00, m01, m02,
  1137. m10, m11, m12,
  1138. m20, m21, m22, Builder);
  1139. }
  1140. else if (row == 4) {
  1141. Value *m00 = Builder.CreateExtractElement(vecVal, (uint64_t)0);
  1142. Value *m01 = Builder.CreateExtractElement(vecVal, 1);
  1143. Value *m02 = Builder.CreateExtractElement(vecVal, 2);
  1144. Value *m03 = Builder.CreateExtractElement(vecVal, 3);
  1145. Value *m10 = Builder.CreateExtractElement(vecVal, 4);
  1146. Value *m11 = Builder.CreateExtractElement(vecVal, 5);
  1147. Value *m12 = Builder.CreateExtractElement(vecVal, 6);
  1148. Value *m13 = Builder.CreateExtractElement(vecVal, 7);
  1149. Value *m20 = Builder.CreateExtractElement(vecVal, 8);
  1150. Value *m21 = Builder.CreateExtractElement(vecVal, 9);
  1151. Value *m22 = Builder.CreateExtractElement(vecVal, 10);
  1152. Value *m23 = Builder.CreateExtractElement(vecVal, 11);
  1153. Value *m30 = Builder.CreateExtractElement(vecVal, 12);
  1154. Value *m31 = Builder.CreateExtractElement(vecVal, 13);
  1155. Value *m32 = Builder.CreateExtractElement(vecVal, 14);
  1156. Value *m33 = Builder.CreateExtractElement(vecVal, 15);
  1157. Result = Determinant4x4(m00, m01, m02, m03,
  1158. m10, m11, m12, m13,
  1159. m20, m21, m22, m23,
  1160. m30, m31, m32, m33,
  1161. Builder);
  1162. } else {
  1163. DXASSERT(row == 1, "invalid matrix type");
  1164. Result = Builder.CreateExtractElement(Result, (uint64_t)0);
  1165. }
  1166. determinantInst->replaceAllUsesWith(Result);
  1167. AddToDeadInsts(determinantInst);
  1168. }
  1169. void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
  1170. Value *vecVal,
  1171. CallInst *matUseInst) {
  1172. CallInst *vecUseInst = cast<CallInst>(matToVecMap[matUseInst]);
  1173. for (unsigned i = 0; i < matUseInst->getNumArgOperands(); i++)
  1174. if (matUseInst->getArgOperand(i) == matVal) {
  1175. vecUseInst->setArgOperand(i, vecVal);
  1176. }
  1177. }
  1178. static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned toRows, unsigned toCols) {
  1179. SmallVector<int, 16> castMask(toCols * toRows);
  1180. unsigned idx = 0;
  1181. for (unsigned r = 0; r < toRows; r++)
  1182. for (unsigned c = 0; c < toCols; c++)
  1183. castMask[idx++] = c * toRows + r;
  1184. return cast<Instruction>(
  1185. Builder.CreateShuffleVector(vecVal, vecVal, castMask));
  1186. }
// Lower a matrix major-ness cast (row<->col major; also reused for
// transpose) into a transpose shuffle on the flattened element vector.
// bRowToCol selects the direction; bTranspose indicates the source and
// destination dimensions are swapped rather than equal.
void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
                                              Value *vecVal,
                                              CallInst *castInst,
                                              bool bRowToCol,
                                              bool bTranspose) {
  unsigned col, row;
  if (!bTranspose) {
    GetMatrixInfo(castInst->getType(), col, row);
    DXASSERT(castInst->getType() == matVal->getType(), "type must match");
  } else {
    // Transpose: destination dims are the source dims swapped.
    unsigned castCol, castRow;
    Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
    unsigned srcCol, srcRow;
    Type *srcTy = GetMatrixInfo(matVal->getType(), srcCol, srcRow);
    DXASSERT_LOCALVAR((castTy == srcTy), srcTy == castTy, "type must match");
    DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
    col = srcCol;
    row = srcRow;
  }
  DXASSERT(matToVecMap.count(castInst), "must have vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  // Create before vecUseInst to prevent instructions being inserted after uses.
  IRBuilder<> Builder(vecUseInst);
  if (bRowToCol)
    std::swap(row, col);
  Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
  // Replace vec cast function call with vecCast.
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
// Lower a matrix-to-matrix cast. Equal element counts become a pure
// element-type cast on the vector; a smaller destination becomes a
// truncating shuffle selecting the top-left toRow x toCol block, optionally
// followed by an element-type cast. Extending casts are not allowed.
void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
                                            Value *vecVal,
                                            CallInst *castInst) {
  unsigned toCol, toRow;
  Type *ToEltTy = GetMatrixInfo(castInst->getType(), toCol, toRow);
  unsigned fromCol, fromRow;
  Type *FromEltTy = GetMatrixInfo(matVal->getType(), fromCol, fromRow);
  unsigned fromSize = fromCol * fromRow;
  unsigned toSize = toCol * toRow;
  DXASSERT(fromSize >= toSize, "cannot extend matrix");
  DXASSERT(matToVecMap.count(castInst), "must have vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
  IRBuilder<> Builder(vecUseInst);
  Instruction *vecCast = nullptr;
  HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  if (fromSize == toSize) {
    // Same element count: only the element type changes.
    vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), vecVal,
                             Builder);
  } else {
    // shuf first: select the top-left toRow x toCol block from the source.
    std::vector<int> castMask(toCol * toRow);
    unsigned idx = 0;
    for (unsigned r = 0; r < toRow; r++)
      for (unsigned c = 0; c < toCol; c++) {
        unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
        castMask[idx++] = matIdx;
      }
    Instruction *shuf = cast<Instruction>(
        Builder.CreateShuffleVector(vecVal, vecVal, castMask));
    // Element types differ: add the scalar cast after the shuffle.
    if (ToEltTy != FromEltTy)
      vecCast = CreateTypeCast(opcode, VectorType::get(ToEltTy, toSize), shuf,
                               Builder);
    else
      vecCast = shuf;
  }
  // Replace vec cast function call with vecCast.
  vecUseInst->replaceAllUsesWith(vecCast);
  AddToDeadInsts(vecUseInst);
  matToVecMap[castInst] = vecCast;
}
  1258. void HLMatrixLowerPass::TranslateMatToOtherCast(Value *matVal,
  1259. Value *vecVal,
  1260. CallInst *castInst) {
  1261. unsigned col, row;
  1262. Type *EltTy = GetMatrixInfo(matVal->getType(), col, row);
  1263. unsigned fromSize = col * row;
  1264. IRBuilder<> Builder(castInst);
  1265. Value *sizeCast = nullptr;
  1266. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1267. Type *ToTy = castInst->getType();
  1268. if (ToTy->isVectorTy()) {
  1269. unsigned toSize = ToTy->getVectorNumElements();
  1270. if (fromSize != toSize) {
  1271. std::vector<int> castMask(fromSize);
  1272. for (unsigned c = 0; c < toSize; c++)
  1273. castMask[c] = c;
  1274. sizeCast = Builder.CreateShuffleVector(vecVal, vecVal, castMask);
  1275. } else
  1276. sizeCast = vecVal;
  1277. } else {
  1278. DXASSERT(ToTy->isSingleValueType(), "must scalar here");
  1279. sizeCast = Builder.CreateExtractElement(vecVal, (uint64_t)0);
  1280. }
  1281. Value *typeCast = sizeCast;
  1282. if (EltTy != ToTy->getScalarType()) {
  1283. typeCast = CreateTypeCast(opcode, ToTy, typeCast, Builder);
  1284. }
  1285. // Replace cast function call with typeCast.
  1286. castInst->replaceAllUsesWith(typeCast);
  1287. AddToDeadInsts(castInst);
  1288. }
  1289. void HLMatrixLowerPass::TranslateMatCast(Value *matVal,
  1290. Value *vecVal,
  1291. CallInst *castInst) {
  1292. HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
  1293. if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
  1294. opcode == HLCastOpcode::RowMatrixToColMatrix) {
  1295. TranslateMatMajorCast(matVal, vecVal, castInst,
  1296. opcode == HLCastOpcode::RowMatrixToColMatrix,
  1297. /*bTranspose*/false);
  1298. } else {
  1299. bool ToMat = IsMatrixType(castInst->getType());
  1300. bool FromMat = IsMatrixType(matVal->getType());
  1301. if (ToMat && FromMat) {
  1302. TranslateMatMatCast(matVal, vecVal, castInst);
  1303. } else if (FromMat)
  1304. TranslateMatToOtherCast(matVal, vecVal, castInst);
  1305. else {
  1306. DXASSERT(0, "Not translate as user of matInst");
  1307. }
  1308. }
  1309. }
  1310. void HLMatrixLowerPass::MatIntrinsicReplace(Value *matVal,
  1311. Value *vecVal,
  1312. CallInst *matUseInst) {
  1313. IRBuilder<> Builder(matUseInst);
  1314. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
  1315. switch (opcode) {
  1316. case IntrinsicOp::IOP_umul:
  1317. TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/false);
  1318. break;
  1319. case IntrinsicOp::IOP_mul:
  1320. TranslateMul(matVal, vecVal, matUseInst, /*isSigned*/true);
  1321. break;
  1322. case IntrinsicOp::IOP_transpose:
  1323. TranslateMatTranspose(matVal, vecVal, matUseInst);
  1324. break;
  1325. case IntrinsicOp::IOP_determinant:
  1326. TranslateMatDeterminant(matVal, vecVal, matUseInst);
  1327. break;
  1328. default:
  1329. CallInst *useInst = matUseInst;
  1330. if (matToVecMap.count(matUseInst))
  1331. useInst = cast<CallInst>(matToVecMap[matUseInst]);
  1332. for (unsigned i = 0; i < useInst->getNumArgOperands(); i++) {
  1333. if (matUseInst->getArgOperand(i) == matVal)
  1334. useInst->setArgOperand(i, vecVal);
  1335. }
  1336. break;
  1337. }
  1338. }
// Lowers a matrix subscript call (m[i] or m._11-style element access) on a
// matrix pointer whose storage has been lowered to the vector pointer
// vecVal. Every load/store user of the subscript is rewritten to operate
// on the full lowered vector; the subscript call itself is then retired.
void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
                                              CallInst *matSubInst) {
  unsigned opcode = GetHLOpcode(matSubInst);
  HLSubscriptOpcode matOpcode = static_cast<HLSubscriptOpcode>(opcode);
  assert(matOpcode != HLSubscriptOpcode::DefaultSubscript &&
         "matrix don't use default subscript");
  Type *matType = matVal->getType()->getPointerElementType();
  unsigned col, row;
  Type *EltTy = HLMatrixLower::GetMatrixInfo(matType, col, row);
  // Element access (m._11 / m[i][j]) vs. whole-row subscript (m[i]).
  bool isElement = (matOpcode == HLSubscriptOpcode::ColMatElement) |
                   (matOpcode == HLSubscriptOpcode::RowMatElement);
  // For element access this is a constant aggregate of element indices;
  // for row subscript it is the (possibly dynamic) first index operand.
  Value *mask =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  if (isElement) {
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Element indices are compile-time constants here; collect them into
    // a shuffle mask over the lowered vector.
    std::vector<int> shufMask(resultSize);
    Constant *EltIdxs = cast<Constant>(mask);
    for (unsigned i = 0; i < resultSize; i++) {
      shufMask[i] =
          cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
    }
    // Iterate with a manually advanced use_iterator since users are
    // rewritten (and queued for deletion) during the walk.
    for (Value::use_iterator CallUI = matSubInst->use_begin(),
                             CallE = matSubInst->use_end();
         CallUI != CallE;) {
      Use &CallUse = *CallUI++;
      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
      IRBuilder<> Builder(CallUser);
      // Load the whole lowered vector at the point of use.
      Value *vecLd = Builder.CreateLoad(vecVal);
      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
        // Load of the element(s): shuffle out several, or extract one.
        if (resultSize > 1) {
          Value *shuf = Builder.CreateShuffleVector(vecLd, vecLd, shufMask);
          ld->replaceAllUsesWith(shuf);
        } else {
          Value *elt = Builder.CreateExtractElement(vecLd, shufMask[0]);
          ld->replaceAllUsesWith(elt);
        }
      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
        // Store into the element(s): read-modify-write the full vector.
        Value *val = st->getValueOperand();
        if (resultSize > 1) {
          for (unsigned i = 0; i < shufMask.size(); i++) {
            unsigned idx = shufMask[i];
            Value *valElt = Builder.CreateExtractElement(val, i);
            vecLd = Builder.CreateInsertElement(vecLd, valElt, idx);
          }
          Builder.CreateStore(vecLd, vecVal);
        } else {
          vecLd = Builder.CreateInsertElement(vecLd, val, shufMask[0]);
          Builder.CreateStore(vecLd, vecVal);
        }
      } else {
        DXASSERT(0, "matrix element should only used by load/store.");
      }
      AddToDeadInsts(CallUser);
    }
  } else {
    // Subscript.
    // Return a row.
    // Use insertElement and extractElement.
    // For dynamic indices, spill the vector to a temp array so it can be
    // addressed with a runtime GEP; the alloca goes in the entry block.
    ArrayType *AT = ArrayType::get(EltTy, col*row);
    IRBuilder<> AllocaBuilder(
        matSubInst->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
    Value *zero = AllocaBuilder.getInt32(0);
    bool isDynamicIndexing = !isa<ConstantInt>(mask);
    // One linearized index per column of the selected row, precomputed by
    // EmitHLSLMatrixSubscript and passed as trailing call operands.
    SmallVector<Value *, 4> idxList;
    for (unsigned i = 0; i < col; i++) {
      idxList.emplace_back(
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
    }
    for (Value::use_iterator CallUI = matSubInst->use_begin(),
                             CallE = matSubInst->use_end();
         CallUI != CallE;) {
      Use &CallUse = *CallUI++;
      Instruction *CallUser = cast<Instruction>(CallUse.getUser());
      IRBuilder<> Builder(CallUser);
      Value *vecLd = Builder.CreateLoad(vecVal);
      if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
        // Row load: assemble the row vector element by element.
        Value *sub = UndefValue::get(ld->getType());
        if (!isDynamicIndexing) {
          for (unsigned i = 0; i < col; i++) {
            Value *matIdx = idxList[i];
            Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
            sub = Builder.CreateInsertElement(sub, valElt, i);
          }
        } else {
          // Copy vec to array.
          for (unsigned int i = 0; i < row*col; i++) {
            Value *Elt =
                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                {zero, Builder.getInt32(i)});
            Builder.CreateStore(Elt, Ptr);
          }
          // Then read the row back through dynamic GEPs.
          for (unsigned i = 0; i < col; i++) {
            Value *matIdx = idxList[i];
            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
            Value *valElt = Builder.CreateLoad(Ptr);
            sub = Builder.CreateInsertElement(sub, valElt, i);
          }
        }
        ld->replaceAllUsesWith(sub);
      } else if (StoreInst *st = dyn_cast<StoreInst>(CallUser)) {
        // Row store: read-modify-write the full vector.
        Value *val = st->getValueOperand();
        if (!isDynamicIndexing) {
          for (unsigned i = 0; i < col; i++) {
            Value *matIdx = idxList[i];
            Value *valElt = Builder.CreateExtractElement(val, i);
            vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
          }
        } else {
          // Copy vec to array.
          for (unsigned int i = 0; i < row * col; i++) {
            Value *Elt =
                Builder.CreateExtractElement(vecLd, Builder.getInt32(i));
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                {zero, Builder.getInt32(i)});
            Builder.CreateStore(Elt, Ptr);
          }
          // Update array.
          for (unsigned i = 0; i < col; i++) {
            Value *matIdx = idxList[i];
            Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
            Value *valElt = Builder.CreateExtractElement(val, i);
            Builder.CreateStore(valElt, Ptr);
          }
          // Copy array to vec.
          for (unsigned int i = 0; i < row * col; i++) {
            Value *Ptr = Builder.CreateInBoundsGEP(tempArray,
                {zero, Builder.getInt32(i)});
            Value *Elt = Builder.CreateLoad(Ptr);
            vecLd = Builder.CreateInsertElement(vecLd, Elt, i);
          }
        }
        Builder.CreateStore(vecLd, vecVal);
      } else if (GetElementPtrInst *GEP =
                     dyn_cast<GetElementPtrInst>(CallUser)) {
        // GEP on the row: fold the row's index list into a single linear
        // index and GEP directly into the lowered vector storage.
        Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *NewGEP = Builder.CreateGEP(vecVal, {zero, GEPOffset});
        GEP->replaceAllUsesWith(NewGEP);
      } else {
        DXASSERT(0, "matrix subscript should only used by load/store.");
      }
      AddToDeadInsts(CallUser);
    }
  }
  // Check vec version.
  DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
  // All the user should have been removed.
  matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
  AddToDeadInsts(matSubInst);
}
  1493. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
  1494. Value *matGlobal, ArrayRef<Value *> vecGlobals,
  1495. CallInst *matLdStInst) {
  1496. // No dynamic indexing on matrix, flatten matrix to scalars.
  1497. // vecGlobals already in correct major.
  1498. Type *matType = matGlobal->getType()->getPointerElementType();
  1499. unsigned col, row;
  1500. HLMatrixLower::GetMatrixInfo(matType, col, row);
  1501. Type *vecType = HLMatrixLower::LowerMatrixType(matType);
  1502. IRBuilder<> Builder(matLdStInst);
  1503. HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  1504. switch (opcode) {
  1505. case HLMatLoadStoreOpcode::ColMatLoad:
  1506. case HLMatLoadStoreOpcode::RowMatLoad: {
  1507. Value *Result = UndefValue::get(vecType);
  1508. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1509. Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
  1510. Result = Builder.CreateInsertElement(Result, Elt, matIdx);
  1511. }
  1512. matLdStInst->replaceAllUsesWith(Result);
  1513. } break;
  1514. case HLMatLoadStoreOpcode::ColMatStore:
  1515. case HLMatLoadStoreOpcode::RowMatStore: {
  1516. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1517. for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
  1518. Value *Elt = Builder.CreateExtractElement(Val, matIdx);
  1519. Builder.CreateStore(Elt, vecGlobals[matIdx]);
  1520. }
  1521. } break;
  1522. }
  1523. }
// Lowers a matrix load/store on a global whose storage was lowered to a
// scalar-array global: loads GEP+load each element and rebuild a vector,
// stores extract each lane and store it through a GEP. The matrix call is
// erased immediately (not deferred to the dead-inst list).
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
                                                      GlobalVariable *scalarArrayGlobal,
                                                      CallInst *matLdStInst) {
  // vecGlobals already in correct major.
  HLMatLoadStoreOpcode opcode =
      static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
  switch (opcode) {
  case HLMatLoadStoreOpcode::ColMatLoad:
  case HLMatLoadStoreOpcode::RowMatLoad: {
    IRBuilder<> Builder(matLdStInst);
    Type *matTy = matGlobal->getType()->getPointerElementType();
    unsigned col, row;
    Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
    Value *zeroIdx = Builder.getInt32(0);
    // Load every element of the scalar array, then build one vector value.
    std::vector<Value *> matElts(col * row);
    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
      Value *GEP = Builder.CreateInBoundsGEP(
          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
      matElts[matIdx] = Builder.CreateLoad(GEP);
    }
    Value *newVec =
        HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
    matLdStInst->replaceAllUsesWith(newVec);
    matLdStInst->eraseFromParent();
  } break;
  case HLMatLoadStoreOpcode::ColMatStore:
  case HLMatLoadStoreOpcode::RowMatStore: {
    Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
    IRBuilder<> Builder(matLdStInst);
    Type *matTy = matGlobal->getType()->getPointerElementType();
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(matTy, col, row);
    Value *zeroIdx = Builder.getInt32(0);
    // Scatter each lane of the stored vector into the scalar array.
    std::vector<Value *> matElts(col * row);
    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
      Value *GEP = Builder.CreateInBoundsGEP(
          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
      Builder.CreateStore(Elt, GEP);
    }
    matLdStInst->eraseFromParent();
  } break;
  }
}
// Lowers a matrix subscript on a pointer to a lowered global (vecPtr):
// computes the linearized element indices for the accessed row/elements,
// then replaces each user of the subscript with scalar GEPs into vecPtr,
// since a pointer to a sub-vector of the storage cannot be formed.
void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
    CallInst *matSubInst, Value *vecPtr) {
  Value *basePtr =
      matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
  IRBuilder<> subBuilder(matSubInst);
  Value *zeroIdx = subBuilder.getInt32(0);
  HLSubscriptOpcode opcode =
      static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
  Type *matTy = basePtr->getType()->getPointerElementType();
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(matTy, col, row);
  // Linearized indices of the accessed elements, one per result element.
  std::vector<Value *> idxList;
  switch (opcode) {
  case HLSubscriptOpcode::ColMatSubscript:
  case HLSubscriptOpcode::RowMatSubscript: {
    // Just use index created in EmitHLSLMatrixSubscript.
    for (unsigned c = 0; c < col; c++) {
      Value *matIdx =
          matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
      idxList.emplace_back(matIdx);
    }
  } break;
  case HLSubscriptOpcode::RowMatElement:
  case HLSubscriptOpcode::ColMatElement: {
    Type *resultType = matSubInst->getType()->getPointerElementType();
    unsigned resultSize = 1;
    if (resultType->isVectorTy())
      resultSize = resultType->getVectorNumElements();
    // Just use index created in EmitHLSLMatrixElement.
    Constant *EltIdxs = cast<Constant>(idx);
    for (unsigned i = 0; i < resultSize; i++) {
      Value *matIdx = EltIdxs->getAggregateElement(i);
      idxList.emplace_back(matIdx);
    }
  } break;
  default:
    DXASSERT(0, "invalid operation");
    break;
  }
  // Cannot generate vector pointer
  // Replace all uses with scalar pointers.
  if (idxList.size() == 1) {
    // Single element: one scalar GEP substitutes for the whole subscript.
    Value *Ptr =
        subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
    matSubInst->replaceAllUsesWith(Ptr);
  } else {
    // Split the use of CI with Ptrs.
    // Users are erased as they are processed, so advance the iterator
    // before touching each user.
    for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
      Instruction *subsUser = cast<Instruction>(*(U++));
      IRBuilder<> userBuilder(subsUser);
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
        // GEP into the row: fold its indices with idxList into one linear
        // index, then rewrite the GEP's load/store/addrspacecast users.
        Value *IndexPtr =
            HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
        Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
            {zeroIdx, IndexPtr});
        for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
          Instruction *gepUser = cast<Instruction>(*(gepU++));
          IRBuilder<> gepUserBuilder(gepUser);
          if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
            Value *subData = stUser->getValueOperand();
            gepUserBuilder.CreateStore(subData, Ptr);
            stUser->eraseFromParent();
          } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
            Value *subData = gepUserBuilder.CreateLoad(Ptr);
            ldUser->replaceAllUsesWith(subData);
            ldUser->eraseFromParent();
          } else {
            // Only other expected user is an addrspacecast of the pointer.
            AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
            Cast->setOperand(0, Ptr);
          }
        }
        GEP->eraseFromParent();
      } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
        // Store of the whole row: scatter each lane through its own GEP.
        Value *val = stUser->getValueOperand();
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Elt = userBuilder.CreateExtractElement(val, i);
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
              {zeroIdx, idxList[i]});
          userBuilder.CreateStore(Elt, Ptr);
        }
        stUser->eraseFromParent();
      } else {
        // Load of the whole row: gather each lane through its own GEP.
        Value *ldVal =
            UndefValue::get(matSubInst->getType()->getPointerElementType());
        for (unsigned i = 0; i < idxList.size(); i++) {
          Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
              {zeroIdx, idxList[i]});
          Value *Elt = userBuilder.CreateLoad(Ptr);
          ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
        }
        // Must be load here.
        LoadInst *ldUser = cast<LoadInst>(subsUser);
        ldUser->replaceAllUsesWith(ldVal);
        ldUser->eraseFromParent();
      }
    }
  }
  matSubInst->eraseFromParent();
}
  1668. void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
  1669. CallInst *matLdStInst, Value *vecPtr) {
  1670. // Just translate into vector here.
  1671. // DynamicIndexingVectorToArray will change it to scalar array.
  1672. IRBuilder<> Builder(matLdStInst);
  1673. unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
  1674. HLMatLoadStoreOpcode matLdStOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  1675. switch (matLdStOp) {
  1676. case HLMatLoadStoreOpcode::ColMatLoad:
  1677. case HLMatLoadStoreOpcode::RowMatLoad: {
  1678. // Load as vector.
  1679. Value *newLoad = Builder.CreateLoad(vecPtr);
  1680. matLdStInst->replaceAllUsesWith(newLoad);
  1681. matLdStInst->eraseFromParent();
  1682. } break;
  1683. case HLMatLoadStoreOpcode::ColMatStore:
  1684. case HLMatLoadStoreOpcode::RowMatStore: {
  1685. // Change value to vector array, then store.
  1686. Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  1687. Value *vecArrayGep = vecPtr;
  1688. Builder.CreateStore(Val, vecArrayGep);
  1689. matLdStInst->eraseFromParent();
  1690. } break;
  1691. default:
  1692. DXASSERT(0, "invalid operation");
  1693. break;
  1694. }
  1695. }
// Flatten values inside init list to scalar elements.
// Recursively walks one initializer value (scalar, vector, matrix, array,
// struct, or pointers to any of these), appending its scalar elements to
// elts starting at idx and advancing idx past them.
static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
                            Value *val,
                            DenseMap<Value *, Value *> &matToVecMap,
                            IRBuilder<> &Builder) {
  Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    if (HLMatrixLower::IsMatrixArrayPointer(valTy)) {
      // Prefer the already-lowered vec-array pointer when one exists.
      if (matToVecMap.count(cast<Instruction>(val))) {
        val = matToVecMap[cast<Instruction>(val)];
      } else {
        // Convert to vec array with bitcast.
        Type *vecArrayPtrTy = HLMatrixLower::LowerMatrixArrayPointer(valTy);
        val = Builder.CreateBitCast(val, vecArrayPtrTy);
      }
    }
    Type *valEltTy = val->getType()->getPointerElementType();
    if (valEltTy->isVectorTy() || HLMatrixLower::IsMatrixType(valEltTy) ||
        valEltTy->isSingleValueType()) {
      // Loadable pointee: load it and recurse on the value.
      Value *ldVal = Builder.CreateLoad(val);
      IterateInitList(elts, idx, ldVal, matToVecMap, Builder);
    } else {
      // Aggregate pointee: recurse element-by-element through GEPs.
      Type *i32Ty = Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (ArrayType *AT = dyn_cast<ArrayType>(valEltTy)) {
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        for (unsigned i = 0; i < ST->getNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          IterateInitList(elts, idx, EltPtr, matToVecMap, Builder);
        }
      }
    }
  } else if (HLMatrixLower::IsMatrixType(valTy)) {
    // Matrix value: use its lowered vector and copy out every element.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(valTy, col, row);
    unsigned matSize = col * row;
    val = matToVecMap[cast<Instruction>(val)];
    // temp matrix all row major
    for (unsigned i = 0; i < matSize; i++) {
      Value *Elt = Builder.CreateExtractElement(val, i);
      elts[idx + i] = Elt;
    }
    idx += matSize;
  } else {
    if (valTy->isVectorTy()) {
      // Vector value: copy out each element.
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts[idx + i] = Elt;
      }
      idx += vecSize;
    } else {
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts[idx++] = val;
    }
  }
}
// Lowers an HL matrix initializer call: flattens the init-list operands to
// scalars (row-major) and builds the lowered vector with a chain of
// insertelement instructions, replacing the vec-version call.
void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
  // Array matrix init will be translated in TranslateMatArrayInitReplace.
  if (matInitInst->getType()->isVoidTy())
    return;
  DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
  IRBuilder<> Builder(vecUseInst);
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  unsigned idx = 0;
  std::vector<Value *> elts(vecSize);
  // Skip opcode arg.
  for (unsigned i = 1; i < matInitInst->getNumArgOperands(); i++) {
    Value *val = matInitInst->getArgOperand(i);
    IterateInitList(elts, idx, val, matToVecMap, Builder);
  }
  Value *newInit = UndefValue::get(vecTy);
  // InitList is row major, the result is row major too.
  // InsertElementInst::Create (rather than Builder.CreateInsertElement) is
  // used so an Instruction is always produced even for constant operands;
  // the cast<Instruction> below relies on this.
  for (unsigned i=0;i< col * row;i++) {
    Constant *vecIdx = Builder.getInt32(i);
    newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
    Builder.Insert(cast<Instruction>(newInit));
  }
  // Replace matInit function call with matInitInst.
  vecUseInst->replaceAllUsesWith(newInit);
  AddToDeadInsts(vecUseInst);
  matToVecMap[matInitInst] = newInit;
}
// Lowers an HL select on matrices to per-element scalar selects over the
// lowered vector operands. The condition may be a scalar or a matrix
// (element-wise) condition; the vec-version select call is replaced.
void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
  unsigned col, row;
  Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
  Type *vecTy = VectorType::get(EltTy, col * row);
  unsigned vecSize = vecTy->getVectorNumElements();
  CallInst *vecUseInst = cast<CallInst>(matToVecMap[matSelectInst]);
  Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
  Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
  IRBuilder<> Builder(vecUseInst);
  Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  bool isVecCond = Cond->getType()->isVectorTy();
  if (isVecCond) {
    // Matrix condition: fetch its lowered vector form.
    Instruction *MatCond = cast<Instruction>(
        matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx));
    DXASSERT_NOMSG(matToVecMap.count(MatCond));
    Cond = matToVecMap[MatCond];
  }
  DXASSERT_NOMSG(matToVecMap.count(LHS));
  Value *VLHS = matToVecMap[LHS];
  DXASSERT_NOMSG(matToVecMap.count(RHS));
  Value *VRHS = matToVecMap[RHS];
  // Build the result one element at a time with scalar selects.
  Value *VecSelect = UndefValue::get(vecTy);
  for (unsigned i = 0; i < vecSize; i++) {
    llvm::Value *EltCond = Cond;
    if (isVecCond)
      EltCond = Builder.CreateExtractElement(Cond, i);
    llvm::Value *EltL = Builder.CreateExtractElement(VLHS, i);
    llvm::Value *EltR = Builder.CreateExtractElement(VRHS, i);
    llvm::Value *EltSelect = Builder.CreateSelect(EltCond, EltL, EltR);
    VecSelect = Builder.CreateInsertElement(VecSelect, EltSelect, i);
  }
  AddToDeadInsts(vecUseInst);
  vecUseInst->replaceAllUsesWith(VecSelect);
  matToVecMap[matSelectInst] = VecSelect;
}
// Lowers a GEP into a matrix array: mirrors the GEP onto the lowered
// vector storage (vecVal) and rewrites each user of the matrix GEP —
// matrix load/store calls, subscripts, bitcasts, and nested GEPs — to use
// the new vector GEP. The original GEP is queued for deletion.
void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
                                             Value *vecVal,
                                             GetElementPtrInst *matGEP) {
  SmallVector<Value *, 4> idxList(matGEP->idx_begin(), matGEP->idx_end());
  IRBuilder<> GEPBuilder(matGEP);
  // Same indices, applied to the lowered storage.
  Value *newGEP = GEPBuilder.CreateInBoundsGEP(vecVal, idxList);
  // Only used by mat subscript and mat ld/st.
  for (Value::user_iterator user = matGEP->user_begin();
       user != matGEP->user_end();) {
    Instruction *useInst = cast<Instruction>(*(user++));
    IRBuilder<> Builder(useInst);
    // Skip return here.
    if (isa<ReturnInst>(useInst))
      continue;
    if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
      // Function call.
      hlsl::HLOpcodeGroup group =
          hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
      switch (group) {
      case HLOpcodeGroup::HLMatLoadStore: {
        unsigned opcode = GetHLOpcode(useCall);
        HLMatLoadStoreOpcode matOpcode =
            static_cast<HLMatLoadStoreOpcode>(opcode);
        switch (matOpcode) {
        case HLMatLoadStoreOpcode::ColMatLoad:
        case HLMatLoadStoreOpcode::RowMatLoad: {
          // Skip the vector version.
          if (useCall->getType()->isVectorTy())
            continue;
          // Replace the previously created vec load with a load through
          // the new vector GEP, and update the map to the new load.
          Value *newLd = Builder.CreateLoad(newGEP);
          DXASSERT(matToVecMap.count(useCall), "must have vec version");
          Value *oldLd = matToVecMap[useCall];
          // Delete the oldLd.
          AddToDeadInsts(cast<Instruction>(oldLd));
          oldLd->replaceAllUsesWith(newLd);
          matToVecMap[useCall] = newLd;
        } break;
        case HLMatLoadStoreOpcode::ColMatStore:
        case HLMatLoadStoreOpcode::RowMatStore: {
          Value *vecPtr = newGEP;
          Value *matVal = useCall->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
          // Skip the vector version.
          if (matVal->getType()->isVectorTy()) {
            AddToDeadInsts(useCall);
            continue;
          }
          // Store the lowered vector of the matrix value through the GEP.
          Instruction *matInst = cast<Instruction>(matVal);
          DXASSERT(matToVecMap.count(matInst), "must have vec version");
          Value *vecVal = matToVecMap[matInst];
          Builder.CreateStore(vecVal, vecPtr);
        } break;
        }
      } break;
      case HLOpcodeGroup::HLSubscript: {
        TranslateMatSubscript(matGEP, newGEP, useCall);
      } break;
      default:
        DXASSERT(0, "invalid operation");
        break;
      }
    } else if (dyn_cast<BitCastInst>(useInst)) {
      // Just replace the src with vec version.
      useInst->setOperand(0, newGEP);
    } else {
      // Must be GEP.
      GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
      // Recurse into nested GEPs with the new base.
      TranslateMatArrayGEP(matGEP, cast<Instruction>(newGEP), GEP);
    }
  }
  AddToDeadInsts(matGEP);
}
  1897. Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
  1898. Value *newMatVal = nullptr;
  1899. if (vecToMatMap.count(vecVal)) {
  1900. newMatVal = vecToMatMap[vecVal];
  1901. } else {
  1902. // create conversion instructions if necessary, caching result for subsequent replacements.
  1903. // do so right after the vecVal def so it's available to all potential uses.
  1904. newMatVal = BitCastValueOrPtr(vecVal,
  1905. cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
  1906. matTy,
  1907. /*bOrigAllocaTy*/true,
  1908. vecVal->getName());
  1909. vecToMatMap[vecVal] = newMatVal;
  1910. }
  1911. return newMatVal;
  1912. }
  1913. void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
  1914. Value *vecVal) {
  1915. for (Value::user_iterator user = matVal->user_begin();
  1916. user != matVal->user_end();) {
  1917. Instruction *useInst = cast<Instruction>(*(user++));
  1918. // User must be function call.
  1919. if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
  1920. hlsl::HLOpcodeGroup group =
  1921. hlsl::GetHLOpcodeGroupByName(useCall->getCalledFunction());
  1922. switch (group) {
  1923. case HLOpcodeGroup::HLIntrinsic: {
  1924. if (CallInst *matCI = dyn_cast<CallInst>(matVal)) {
  1925. MatIntrinsicReplace(matCI, vecVal, useCall);
  1926. } else {
  1927. IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
  1928. DXASSERT_LOCALVAR(opcode, opcode == IntrinsicOp::IOP_frexp,
  1929. "otherwise, unexpected opcode with matrix out parameter");
  1930. // NOTE: because out param use copy out semantic, so the operand of
  1931. // out must be temp alloca.
  1932. DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
  1933. auto it = matToVecMap.find(useCall);
  1934. DXASSERT(it != matToVecMap.end(),
  1935. "else fail to create vec version of useCall");
  1936. CallInst *vecUseInst = cast<CallInst>(it->second);
  1937. for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
  1938. if (useCall->getArgOperand(i) == matVal) {
  1939. vecUseInst->setArgOperand(i, vecVal);
  1940. }
  1941. }
  1942. }
  1943. } break;
  1944. case HLOpcodeGroup::HLSelect: {
  1945. MatIntrinsicReplace(matVal, vecVal, useCall);
  1946. } break;
  1947. case HLOpcodeGroup::HLBinOp: {
  1948. TrivialMatBinOpReplace(matVal, vecVal, useCall);
  1949. } break;
  1950. case HLOpcodeGroup::HLUnOp: {
  1951. TrivialMatUnOpReplace(matVal, vecVal, useCall);
  1952. } break;
  1953. case HLOpcodeGroup::HLCast: {
  1954. TranslateMatCast(matVal, vecVal, useCall);
  1955. } break;
  1956. case HLOpcodeGroup::HLMatLoadStore: {
  1957. DXASSERT(matToVecMap.count(useCall), "must have vec version");
  1958. Value *vecUser = matToVecMap[useCall];
  1959. if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
  1960. GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
  1961. // Load Already translated in lowerToVec.
  1962. // Store val operand will be set by the val use.
  1963. // Do nothing here.
  1964. } else if (StoreInst *stInst = dyn_cast<StoreInst>(vecUser))
  1965. stInst->setOperand(0, vecVal);
  1966. else
  1967. TrivialMatReplace(matVal, vecVal, useCall);
  1968. } break;
  1969. case HLOpcodeGroup::HLSubscript: {
  1970. if (AllocaInst *AI = dyn_cast<AllocaInst>(vecVal))
  1971. TranslateMatSubscript(matVal, vecVal, useCall);
  1972. else if (BitCastInst *BCI = dyn_cast<BitCastInst>(vecVal))
  1973. TranslateMatSubscript(matVal, vecVal, useCall);
  1974. else
  1975. TrivialMatReplace(matVal, vecVal, useCall);
  1976. } break;
  1977. case HLOpcodeGroup::HLInit: {
  1978. DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
  1979. TranslateMatInit(useCall);
  1980. } break;
  1981. case HLOpcodeGroup::NotHL: {
  1982. castMatrixArgs(matVal, vecVal, useCall);
  1983. } break;
  1984. case HLOpcodeGroup::HLExtIntrinsic:
  1985. case HLOpcodeGroup::HLCreateHandle:
  1986. case HLOpcodeGroup::NumOfHLOps:
  1987. // No vector equivalents for these ops.
  1988. break;
  1989. }
  1990. } else if (dyn_cast<BitCastInst>(useInst)) {
  1991. // Just replace the src with vec version.
  1992. if (useInst != vecVal)
  1993. useInst->setOperand(0, vecVal);
  1994. } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
  1995. Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
  1996. RI->setOperand(0, newMatVal);
  1997. } else if (isa<StoreInst>(useInst)) {
  1998. DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
  1999. } else {
  2000. // Must be GEP on mat array alloca.
  2001. GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
  2002. AllocaInst *AI = cast<AllocaInst>(matVal);
  2003. TranslateMatArrayGEP(AI, vecVal, GEP);
  2004. }
  2005. }
  2006. }
  2007. void HLMatrixLowerPass::castMatrixArgs(Value *matVal, Value *vecVal, CallInst *CI) {
  2008. // translate user function parameters as necessary
  2009. Type *Ty = matVal->getType();
  2010. if (Ty->isPointerTy()) {
  2011. IRBuilder<> Builder(CI);
  2012. Value *newMatVal = Builder.CreateBitCast(vecVal, Ty);
  2013. CI->replaceUsesOfWith(matVal, newMatVal);
  2014. } else {
  2015. Value *newMatVal = GetMatrixForVec(vecVal, Ty);
  2016. CI->replaceUsesOfWith(matVal, newMatVal);
  2017. }
  2018. }
  2019. void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
  2020. // Translate matInit.
  2021. if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
  2022. hlsl::HLOpcodeGroup group =
  2023. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  2024. switch (group) {
  2025. case HLOpcodeGroup::HLInit: {
  2026. TranslateMatInit(CI);
  2027. } break;
  2028. case HLOpcodeGroup::HLSelect: {
  2029. TranslateMatSelect(CI);
  2030. } break;
  2031. default:
  2032. // Skip group already translated.
  2033. break;
  2034. }
  2035. }
  2036. }
  2037. void HLMatrixLowerPass::DeleteDeadInsts() {
  2038. // Delete the matrix version insts.
  2039. for (Instruction *deadInst : m_deadInsts) {
  2040. // Replace with undef and remove it.
  2041. deadInst->replaceAllUsesWith(UndefValue::get(deadInst->getType()));
  2042. deadInst->eraseFromParent();
  2043. }
  2044. m_deadInsts.clear();
  2045. m_inDeadInstsSet.clear();
  2046. }
  2047. static bool OnlyUsedByMatrixLdSt(Value *V) {
  2048. bool onlyLdSt = true;
  2049. for (User *user : V->users()) {
  2050. if (isa<Constant>(user) && user->use_empty())
  2051. continue;
  2052. CallInst *CI = cast<CallInst>(user);
  2053. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  2054. HLOpcodeGroup::HLMatLoadStore)
  2055. continue;
  2056. onlyLdSt = false;
  2057. break;
  2058. }
  2059. return onlyLdSt;
  2060. }
  2061. static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
  2062. if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
  2063. std::vector<Constant *> Elts;
  2064. Type *EltResultTy = AT->getElementType();
  2065. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2066. Constant *Elt =
  2067. LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
  2068. Elts.emplace_back(Elt);
  2069. }
  2070. return ConstantArray::get(AT, Elts);
  2071. } else {
  2072. // Cast float[row][col] -> float< row * col>.
  2073. // Get float[row][col] from the struct.
  2074. Constant *rows = MA->getAggregateElement((unsigned)0);
  2075. ArrayType *RowAT = cast<ArrayType>(rows->getType());
  2076. std::vector<Constant *> Elts;
  2077. for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
  2078. Constant *row = rows->getAggregateElement(r);
  2079. VectorType *VT = cast<VectorType>(row->getType());
  2080. for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
  2081. Elts.emplace_back(row->getAggregateElement(c));
  2082. }
  2083. }
  2084. return ConstantVector::get(Elts);
  2085. }
  2086. }
// Lower a global matrix-array variable to an array of vectors
// (…[row * col]), following the major order of the declaration.
// DynamicIndexingVectorToArray will later change it to a scalar array.
void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
  // Peel off all array dimensions down to the matrix type.
  Type *Ty = GV->getType()->getPointerElementType();
  std::vector<unsigned> arraySizeList;
  while (Ty->isArrayTy()) {
    arraySizeList.push_back(Ty->getArrayNumElements());
    Ty = Ty->getArrayElementType();
  }
  unsigned row, col;
  Type *EltTy = GetMatrixInfo(Ty, col, row);
  // Rebuild the type inside-out: vector<elt, col*row>, then re-wrap the
  // original array dimensions (innermost first).
  Ty = VectorType::get(EltTy, col * row);
  for (auto arraySize = arraySizeList.rbegin();
       arraySize != arraySizeList.rend(); arraySize++)
    Ty = ArrayType::get(Ty, *arraySize);
  Type *VecArrayTy = Ty;
  // Lower the initializer to match the new type.
  Constant *OldInitVal = GV->getInitializer();
  Constant *InitVal =
      isa<UndefValue>(OldInitVal)
          ? UndefValue::get(VecArrayTy)
          : LowerMatrixArrayConst(OldInitVal, cast<ArrayType>(VecArrayTy));
  // Create the replacement global, preserving the original's attributes.
  bool isConst = GV->isConstant();
  GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  unsigned AddressSpace = GV->getType()->getAddressSpace();
  GlobalValue::LinkageTypes linkage = GV->getLinkage();
  Module *M = GV->getParent();
  GlobalVariable *VecGV =
      new llvm::GlobalVariable(*M, VecArrayTy, /*IsConstant*/ isConst, linkage,
                               /*InitVal*/ InitVal, GV->getName() + ".v",
                               /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  // Add debug info.
  if (m_HasDbgInfo) {
    DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
    HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, VecGV);
  }
  // NOTE(review): this local shadows the member matToVecMap and appears
  // unused in this function — confirm before removing.
  DenseMap<Instruction *, Value *> matToVecMap;
  // Rewrite all users: each user of the global is a GEP; mirror it on the
  // new global, then translate the HL load/store or subscript calls on it.
  for (User *U : GV->users()) {
    Value *VecGEP = nullptr;
    // Must be GEP or GEPOperator.
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
      IRBuilder<> Builder(GEP);
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
      AddToDeadInsts(GEP);
    } else {
      GEPOperator *GEPOP = cast<GEPOperator>(U);
      IRBuilder<> Builder(GV->getContext());
      SmallVector<Value *, 4> idxList(GEPOP->idx_begin(), GEPOP->idx_end());
      VecGEP = Builder.CreateInBoundsGEP(VecGV, idxList);
    }
    // Advance the iterator before translating: translation erases the call.
    for (auto user = U->user_begin(); user != U->user_end();) {
      CallInst *CI = cast<CallInst>(*(user++));
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobalPtr(CI, VecGEP);
      } else if (group == HLOpcodeGroup::HLSubscript) {
        TranslateMatSubscriptOnGlobalPtr(CI, VecGEP);
      } else {
        DXASSERT(0, "invalid operation");
      }
    }
  }
  DeleteDeadInsts();
  GV->removeDeadConstantUsers();
  GV->eraseFromParent();
}
  2154. static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
  2155. unsigned row, col;
  2156. Type *EltTy = HLMatrixLower::GetMatrixInfo(M->getType(), col, row);
  2157. if (isa<UndefValue>(M)) {
  2158. Constant *Elt = UndefValue::get(EltTy);
  2159. for (unsigned i=0;i<col*row;i++)
  2160. Elts.emplace_back(Elt);
  2161. } else {
  2162. M = M->getAggregateElement((unsigned)0);
  2163. // Initializer is already in correct major.
  2164. // Just read it here.
  2165. // The type is vector<element, col>[row].
  2166. for (unsigned r = 0; r < row; r++) {
  2167. Constant *C = M->getAggregateElement(r);
  2168. for (unsigned c = 0; c < col; c++) {
  2169. Elts.emplace_back(C->getAggregateElement(c));
  2170. }
  2171. }
  2172. }
  2173. }
// Lower a global matrix variable. Matrix arrays are delegated to
// runOnGlobalMatrixArray. A plain matrix global used only by matrix
// load/store is split into one scalar global per element; otherwise it is
// replaced by an internal array-of-scalars global.
void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
  if (HLMatrixLower::IsMatrixArrayPointer(GV->getType())) {
    runOnGlobalMatrixArray(GV);
    return;
  }
  Type *Ty = GV->getType()->getPointerElementType();
  if (!HLMatrixLower::IsMatrixType(Ty))
    return;
  bool onlyLdSt = OnlyUsedByMatrixLdSt(GV);
  bool isConst = GV->isConstant();
  Type *vecTy = HLMatrixLower::LowerMatrixType(Ty);
  Module *M = GV->getParent();
  const DataLayout &DL = M->getDataLayout();
  std::vector<Constant *> Elts;
  // Lower to vector or array for scalar matrix.
  // Make it col major so don't need shuffle when load/store.
  FlattenMatConst(GV->getInitializer(), Elts);
  if (onlyLdSt) {
    // Split into one global per scalar element, preserving the original
    // global's attributes on each piece.
    Type *EltTy = vecTy->getVectorElementType();
    unsigned vecSize = vecTy->getVectorNumElements();
    std::vector<Value *> vecGlobals(vecSize);
    GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
    unsigned AddressSpace = GV->getType()->getAddressSpace();
    GlobalValue::LinkageTypes linkage = GV->getLinkage();
    unsigned debugOffset = 0;
    unsigned size = DL.getTypeAllocSizeInBits(EltTy);
    unsigned align = DL.getPrefTypeAlignment(EltTy);
    for (int i = 0, e = vecSize; i != e; ++i) {
      Constant *InitVal = Elts[i];
      GlobalVariable *EltGV = new llvm::GlobalVariable(
          *M, EltTy, /*IsConstant*/ isConst, linkage,
          /*InitVal*/ InitVal, GV->getName() + "." + Twine(i),
          /*InsertBefore*/nullptr,
          TLMode, AddressSpace);
      // Add debug info: each element global gets a slice of the original
      // variable's debug location, advanced by the element size.
      if (m_HasDbgInfo) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset,
            EltGV->getName().ltrim(GV->getName()));
        debugOffset += size;
      }
      vecGlobals[i] = EltGV;
    }
    // Rewrite each matrix load/store call against the element globals,
    // skipping dead constant users.
    for (User *user : GV->users()) {
      if (isa<Constant>(user) && user->use_empty())
        continue;
      CallInst *CI = cast<CallInst>(user);
      TranslateMatLoadStoreOnGlobal(GV, vecGlobals, CI);
      AddToDeadInsts(CI);
    }
    DeleteDeadInsts();
    GV->eraseFromParent();
  }
  else {
    // lower to array of scalar here.
    ArrayType *AT = ArrayType::get(vecTy->getVectorElementType(), vecTy->getVectorNumElements());
    Constant *InitVal = ConstantArray::get(AT, Elts);
    GlobalVariable *arrayMat = new llvm::GlobalVariable(
        *M, AT, /*IsConstant*/ false, llvm::GlobalValue::InternalLinkage,
        /*InitVal*/ InitVal, GV->getName());
    // Add debug info.
    if (m_HasDbgInfo) {
      DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder,
                                              arrayMat);
    }
    // Advance the iterator before translating: translation erases the call.
    for (auto U = GV->user_begin(); U != GV->user_end();) {
      Value *user = *(U++);
      CallInst *CI = cast<CallInst>(user);
      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
      if (group == HLOpcodeGroup::HLMatLoadStore) {
        TranslateMatLoadStoreOnGlobal(GV, arrayMat, CI);
      }
      else {
        DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
        TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
      }
    }
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
// Per-function driver for matrix lowering:
//   1) create vector versions for every matrix-producing instruction,
//   2) rewrite matrix uses to the vector versions,
//   3) run translations that need all operands ready (init/select),
//   4) queue and erase the now-dead matrix instructions.
void HLMatrixLowerPass::runOnFunction(Function &F) {
  // Skip hl function definition (like createhandle)
  if (hlsl::GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
    return;
  // Create vector version of matrix instructions first.
  // The matrix operands will be undefval for these instructions.
  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
    BasicBlock *BB = BBI;
    // Advance the iterator before lowering: lowerToVec may mutate the list.
    for (auto II = BB->begin(); II != BB->end(); ) {
      Instruction &I = *(II++);
      if (IsMatrixType(I.getType())) {
        lowerToVec(&I);
      } else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
        // Matrix (or matrix-array) allocas also get vector versions.
        Type *Ty = AI->getAllocatedType();
        if (HLMatrixLower::IsMatrixType(Ty)) {
          lowerToVec(&I);
        } else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
          lowerToVec(&I);
        }
      } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
        HLOpcodeGroup group =
            hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
        if (group == HLOpcodeGroup::HLMatLoadStore) {
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
          DXASSERT_LOCALVAR(opcode,
                            opcode == HLMatLoadStoreOpcode::ColMatStore ||
                            opcode == HLMatLoadStoreOpcode::RowMatStore,
                            "Must MatStore here, load will go IsMatrixType path");
          // Lower it here to make sure it is ready before replace.
          lowerToVec(&I);
        }
      } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
                 GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
        lowerToVec(&I);
      }
    }
  }
  // Update the use of matrix inst with the vector version.
  // (Copy the iterator before use: replaceMatWithVec may add map entries.)
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    replaceMatWithVec(matToVec->first, cast<Instruction>(matToVec->second));
  }
  // Translate mat inst which require all operands ready.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    if (isa<Instruction>(matToVec->first))
      finalMatTranslation(matToVec->first);
  }
  // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
  for (auto &it : vecToMatMap) {
    matToVecMap.erase(it.second);
  }
  // Delete the matrix version insts.
  for (auto matToVecIter = matToVecMap.begin();
       matToVecIter != matToVecMap.end();) {
    auto matToVec = matToVecIter++;
    // Add to m_deadInsts.
    if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
      AddToDeadInsts(matInst);
  }
  DeleteDeadInsts();
  matToVecMap.clear();
  vecToMatMap.clear();
  return;
}
  2325. // Matrix Bitcast lower.
// After linking, lower matrix bitcast patterns like:
  2327. // %169 = bitcast [72 x float]* %0 to [6 x %class.matrix.float.4.3]*
  2328. // %conv.i = fptoui float %164 to i32
  2329. // %arrayidx.i = getelementptr inbounds [6 x %class.matrix.float.4.3], [6 x %class.matrix.float.4.3]* %169, i32 0, i32 %conv.i
  2330. // %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
  2331. namespace {
  2332. Type *TryLowerMatTy(Type *Ty) {
  2333. Type *VecTy = nullptr;
  2334. if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
  2335. VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
  2336. } else if (isa<PointerType>(Ty) &&
  2337. HLMatrixLower::IsMatrixType(Ty->getPointerElementType())) {
  2338. VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
  2339. Ty->getPointerElementType());
  2340. VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
  2341. }
  2342. return VecTy;
  2343. }
// Post-linking pass that lowers bitcasts from flat scalar arrays to
// matrix(-array) pointers, rewriting the matrix accesses into direct
// accesses on the flat array.
class MatrixBitcastLowerPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
  const char *getPassName() const override { return "Matrix Bitcast lower"; }
  bool runOnFunction(Function &F) override {
    bool bUpdated = false;
    // Collect every bitcast whose destination is a matrix pointer type.
    std::unordered_set<BitCastInst*> matCastSet;
    for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
      BasicBlock *BB = blkIt;
      for (auto iIt = BB->begin(); iIt != BB->end(); ) {
        Instruction *I = (iIt++);
        if (BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
          // Mutate mat to vec.
          Type *ToTy = BCI->getType();
          if (Type *ToVecTy = TryLowerMatTy(ToTy)) {
            matCastSet.insert(BCI);
            bUpdated = true;
          }
        }
      }
    }
    DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
    // Remove bitcast which has CallInst user.
    // For library shaders the callee still expects matrix types, so those
    // casts must be left in place.
    if (DM.GetShaderModel()->IsLib()) {
      for (auto it = matCastSet.begin(); it != matCastSet.end();) {
        BitCastInst *BCI = *(it++);
        if (hasCallUser(BCI)) {
          matCastSet.erase(BCI);
        }
      }
    }
    // Lower matrix first.
    for (BitCastInst *BCI : matCastSet) {
      lowerMatrix(BCI, BCI->getOperand(0));
    }
    // NOTE(review): bUpdated remains true even when every collected cast was
    // removed from the set above — confirm this is intended.
    return bUpdated;
  }
private:
  // Rewrite all uses of matrix pointer M through the flat array pointer A.
  void lowerMatrix(Instruction *M, Value *A);
  // True when M (through GEPs/bitcasts) eventually feeds a CallInst.
  bool hasCallUser(Instruction *M);
};
  2386. }
  2387. bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
  2388. for (auto it = M->user_begin(); it != M->user_end();) {
  2389. User *U = *(it++);
  2390. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  2391. Type *EltTy = GEP->getType()->getPointerElementType();
  2392. if (HLMatrixLower::IsMatrixType(EltTy)) {
  2393. if (hasCallUser(GEP))
  2394. return true;
  2395. } else {
  2396. DXASSERT(0, "invalid GEP for matrix");
  2397. }
  2398. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  2399. if (hasCallUser(BCI))
  2400. return true;
  2401. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  2402. if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
  2403. } else {
  2404. DXASSERT(0, "invalid load for matrix");
  2405. }
  2406. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  2407. Value *V = ST->getValueOperand();
  2408. if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
  2409. } else {
  2410. DXASSERT(0, "invalid load for matrix");
  2411. }
  2412. } else if (isa<CallInst>(U)) {
  2413. return true;
  2414. } else {
  2415. DXASSERT(0, "invalid use of matrix");
  2416. }
  2417. }
  2418. return false;
  2419. }
  2420. namespace {
  2421. Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
  2422. IRBuilder<> &Builder) {
  2423. Value *GEP = nullptr;
  2424. if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
  2425. // A should be gep oneDimArray, 0, index * matSize
  2426. // Here add eltIdx to index * matSize foreach elt.
  2427. Instruction *EltGEP = GEPA->clone();
  2428. unsigned eltIdx = EltGEP->getNumOperands() - 1;
  2429. Value *NewIdx =
  2430. Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
  2431. EltGEP->setOperand(eltIdx, NewIdx);
  2432. Builder.Insert(EltGEP);
  2433. GEP = EltGEP;
  2434. } else {
  2435. GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
  2436. }
  2437. return GEP;
  2438. }
  2439. } // namespace
  2440. void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
  2441. for (auto it = M->user_begin(); it != M->user_end();) {
  2442. User *U = *(it++);
  2443. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  2444. Type *EltTy = GEP->getType()->getPointerElementType();
  2445. if (HLMatrixLower::IsMatrixType(EltTy)) {
  2446. // Change gep matrixArray, 0, index
  2447. // into
  2448. // gep oneDimArray, 0, index * matSize
  2449. IRBuilder<> Builder(GEP);
  2450. SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
  2451. DXASSERT(idxList.size() == 2,
  2452. "else not one dim matrix array index to matrix");
  2453. unsigned col = 0;
  2454. unsigned row = 0;
  2455. HLMatrixLower::GetMatrixInfo(EltTy, col, row);
  2456. Value *matSize = Builder.getInt32(col * row);
  2457. idxList.back() = Builder.CreateMul(idxList.back(), matSize);
  2458. Value *NewGEP = Builder.CreateGEP(A, idxList);
  2459. lowerMatrix(GEP, NewGEP);
  2460. DXASSERT(GEP->user_empty(), "else lower matrix fail");
  2461. GEP->eraseFromParent();
  2462. } else {
  2463. DXASSERT(0, "invalid GEP for matrix");
  2464. }
  2465. } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  2466. lowerMatrix(BCI, A);
  2467. DXASSERT(BCI->user_empty(), "else lower matrix fail");
  2468. BCI->eraseFromParent();
  2469. } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  2470. if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
  2471. IRBuilder<> Builder(LI);
  2472. Value *zeroIdx = Builder.getInt32(0);
  2473. unsigned vecSize = Ty->getNumElements();
  2474. Value *NewVec = UndefValue::get(LI->getType());
  2475. for (unsigned i = 0; i < vecSize; i++) {
  2476. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  2477. Value *Elt = Builder.CreateLoad(GEP);
  2478. NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
  2479. }
  2480. LI->replaceAllUsesWith(NewVec);
  2481. LI->eraseFromParent();
  2482. } else {
  2483. DXASSERT(0, "invalid load for matrix");
  2484. }
  2485. } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
  2486. Value *V = ST->getValueOperand();
  2487. if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
  2488. IRBuilder<> Builder(LI);
  2489. Value *zeroIdx = Builder.getInt32(0);
  2490. unsigned vecSize = Ty->getNumElements();
  2491. for (unsigned i = 0; i < vecSize; i++) {
  2492. Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
  2493. Value *Elt = Builder.CreateExtractElement(V, i);
  2494. Builder.CreateStore(Elt, GEP);
  2495. }
  2496. ST->eraseFromParent();
  2497. } else {
  2498. DXASSERT(0, "invalid load for matrix");
  2499. }
  2500. } else {
  2501. DXASSERT(0, "invalid use of matrix");
  2502. }
  2503. }
  2504. }
  2505. #include "dxc/HLSL/DxilGenerationPass.h"
// Pass registration boilerplate.
char MatrixBitcastLowerPass::ID = 0;
// Factory used by pass-manager setup code to instantiate this pass.
FunctionPass *llvm::createMatrixBitcastLowerPass() { return new MatrixBitcastLowerPass(); }
INITIALIZE_PASS(MatrixBitcastLowerPass, "matrixbitcastlower", "Matrix Bitcast lower", false, false)