// HLMatrixLowerPass.cpp (66 KB)
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerPass.h"
  12. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  13. #include "dxc/HLSL/HLMatrixType.h"
  14. #include "dxc/HLSL/HLOperations.h"
  15. #include "dxc/HLSL/HLModule.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/DXIL/DxilOperations.h"
  19. #include "dxc/DXIL/DxilTypeSystem.h"
  20. #include "dxc/DXIL/DxilModule.h"
  21. #include "dxc/DXIL/DxilUtil.h"
  22. #include "HLMatrixSubscriptUseReplacer.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/Module.h"
  25. #include "llvm/IR/DebugInfo.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/Transforms/Utils/Local.h"
  28. #include "llvm/Pass.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include "llvm/Analysis/ValueTracking.h"
  31. #include <unordered_set>
  32. #include <vector>
  33. using namespace llvm;
  34. using namespace hlsl;
  35. using namespace hlsl::HLMatrixLower;
  36. namespace hlsl {
  37. namespace HLMatrixLower {
  38. Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
  39. Value *Vec = UndefValue::get(VectorType::get(EltTy, static_cast<unsigned>(elts.size())));
  40. for (unsigned i = 0; i < elts.size(); i++)
  41. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  42. return Vec;
  43. }
  44. } // namespace HLMatrixLower
  45. } // namespace hlsl
  46. namespace {
  47. // Creates and manages a set of temporary overloaded functions keyed on the function type,
  48. // and which should be destroyed when the pool gets out of scope.
  49. class TempOverloadPool {
  50. public:
  51. TempOverloadPool(llvm::Module &Module, const char* BaseName)
  52. : Module(Module), BaseName(BaseName) {}
  53. ~TempOverloadPool() { clear(); }
  54. Function *get(FunctionType *Ty);
  55. bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
  56. bool contains(Function *Func) const;
  57. void clear();
  58. private:
  59. llvm::Module &Module;
  60. const char* BaseName;
  61. llvm::DenseMap<FunctionType*, Function*> Funcs;
  62. };
  63. Function *TempOverloadPool::get(FunctionType *Ty) {
  64. auto It = Funcs.find(Ty);
  65. if (It != Funcs.end()) return It->second;
  66. std::string MangledName;
  67. raw_string_ostream MangledNameStream(MangledName);
  68. MangledNameStream << BaseName;
  69. MangledNameStream << '.';
  70. Ty->print(MangledNameStream);
  71. MangledNameStream.flush();
  72. Function* Func = cast<Function>(Module.getOrInsertFunction(MangledName, Ty));
  73. Funcs.insert(std::make_pair(Ty, Func));
  74. return Func;
  75. }
  76. bool TempOverloadPool::contains(Function *Func) const {
  77. auto It = Funcs.find(Func->getFunctionType());
  78. return It != Funcs.end() && It->second == Func;
  79. }
  80. void TempOverloadPool::clear() {
  81. for (auto Entry : Funcs) {
  82. DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
  83. Entry.second->removeFromParent();
  84. }
  85. Funcs.clear();
  86. }
  87. // High-level matrix lowering pass.
  88. //
  89. // This pass converts matrices to their lowered vector representations,
  90. // including global variables, local variables and operations,
  91. // but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
  92. // nor matrices obtained from resources or constant - left to HLOperationLower.
  93. //
  94. // Algorithm overview:
  95. // 1. Find all matrix and matrix array global variables and lower them to vectors.
  96. // Walk any GEPs and insert vec-to-mat translation stubs so that consuming
  97. // instructions keep dealing with matrix types for the moment.
  98. // 2. For each function
  99. // 2a. Lower all matrix and matrix array allocas, just like global variables.
  100. // 2b. Lower all other instructions producing or consuming matrices
  101. //
  102. // Conversion stubs are used to allow converting instructions in isolation,
  103. // and in an order-independent manner:
  104. //
  105. // Initial: MatInst1(MatInst2(MatInst3))
  106. // After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
  107. // After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
  108. // After lowering MatInst3: VecInst1(VecInst2(VecInst3))
  109. class HLMatrixLowerPass : public ModulePass {
  110. public:
  111. static char ID; // Pass identification, replacement for typeid
  112. explicit HLMatrixLowerPass() : ModulePass(ID) {}
  113. const char *getPassName() const override { return "HL matrix lower"; }
  114. bool runOnModule(Module &M) override;
  115. private:
  116. void runOnFunction(Function &Func);
  117. void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
  118. void deleteDeadInsts();
  119. void getMatrixAllocasAndOtherInsts(Function &Func,
  120. std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
  121. Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
  122. Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
  123. Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
  124. void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
  125. void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
  126. void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
  127. Value *translateScalarMatMul(Value *scalar, Value *mat, IRBuilder<> &Builder, bool isLhsScalar = true);
  128. void lowerGlobal(GlobalVariable *Global);
  129. Constant *lowerConstInitVal(Constant *Val);
  130. AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
  131. void lowerInstruction(Instruction* Inst);
  132. void lowerReturn(ReturnInst* Return);
  133. Value *lowerCall(CallInst *Call);
  134. Value *lowerNonHLCall(CallInst *Call);
  135. void lowerPreciseCall(CallInst *Call, IRBuilder<> Builder);
  136. Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
  137. Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
  138. Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
  139. Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  140. Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  141. Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
  142. Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
  143. Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
  144. Value *lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
  145. Value *lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
  146. Value *lowerHLCast(CallInst *Call, Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
  147. Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
  148. Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
  149. Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
  150. void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
  151. Value *lowerHLInit(CallInst *Call);
  152. Value *lowerHLSelect(CallInst *Call);
  153. private:
  154. Module *m_pModule;
  155. HLModule *m_pHLModule;
  156. bool m_HasDbgInfo;
  157. // Pools for the translation stubs
  158. TempOverloadPool *m_matToVecStubs = nullptr;
  159. TempOverloadPool *m_vecToMatStubs = nullptr;
  160. std::vector<Instruction *> m_deadInsts;
  161. };
  162. }
  163. char HLMatrixLowerPass::ID = 0;
  164. ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
  165. INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  166. bool HLMatrixLowerPass::runOnModule(Module &M) {
  167. TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
  168. TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
  169. m_pModule = &M;
  170. m_pHLModule = &m_pModule->GetOrCreateHLModule();
  171. // Load up debug information, to cross-reference values and the instructions
  172. // used to load them.
  173. m_HasDbgInfo = hasDebugInfo(M);
  174. m_matToVecStubs = &matToVecStubs;
  175. m_vecToMatStubs = &vecToMatStubs;
  176. // First, lower static global variables.
  177. // We need to accumulate them locally because we'll be creating new ones as we lower them.
  178. std::vector<GlobalVariable*> Globals;
  179. for (GlobalVariable &Global : M.globals()) {
  180. if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
  181. && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
  182. Globals.emplace_back(&Global);
  183. }
  184. }
  185. for (GlobalVariable *Global : Globals)
  186. lowerGlobal(Global);
  187. for (Function &F : M.functions()) {
  188. if (F.isDeclaration()) continue;
  189. runOnFunction(F);
  190. }
  191. m_pModule = nullptr;
  192. m_pHLModule = nullptr;
  193. m_matToVecStubs = nullptr;
  194. m_vecToMatStubs = nullptr;
  195. // If you hit an assert during TempOverloadPool destruction,
  196. // it means that either a matrix producer was lowered,
  197. // causing a translation stub to be created,
  198. // but the consumer of that matrix was never (properly) lowered.
  199. // Or the opposite: a matrix consumer was lowered and not its producer.
  200. return true;
  201. }
  202. void HLMatrixLowerPass::runOnFunction(Function &Func) {
  203. // Skip hl function definition (like createhandle)
  204. if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
  205. return;
  206. // Save the matrix instructions first since the translation process
  207. // will temporarily create other instructions consuming/producing matrix types.
  208. std::vector<AllocaInst*> MatAllocas;
  209. std::vector<Instruction*> MatInsts;
  210. getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
  211. // First lower all allocas and take care of their GEP chains
  212. for (AllocaInst* MatAlloca : MatAllocas) {
  213. AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
  214. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  215. addToDeadInsts(MatAlloca);
  216. }
  217. // Now lower all other matrix instructions
  218. for (Instruction *MatInst : MatInsts)
  219. lowerInstruction(MatInst);
  220. deleteDeadInsts();
  221. }
  222. void HLMatrixLowerPass::deleteDeadInsts() {
  223. while (!m_deadInsts.empty()) {
  224. Instruction *Inst = m_deadInsts.back();
  225. m_deadInsts.pop_back();
  226. DXASSERT_NOMSG(Inst->use_empty());
  227. for (Value *Operand : Inst->operand_values()) {
  228. Instruction *OperandInst = dyn_cast<Instruction>(Operand);
  229. if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
  230. // We were its only user, erase recursively.
  231. // This will get rid of translation stubs:
  232. // Original: MatConsumer(MatProducer)
  233. // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
  234. // Consumer lowered: VecConsumer(VecProducer)), MatConsumer(VecToMat) dead
  235. // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
  236. DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
  237. m_deadInsts.emplace_back(OperandInst);
  238. }
  239. }
  240. Inst->eraseFromParent();
  241. }
  242. }
  243. // Find all instructions consuming or producing matrices,
  244. // directly or through pointers/arrays.
  245. void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
  246. std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
  247. for (BasicBlock &BasicBlock : Func) {
  248. for (Instruction &Inst : BasicBlock) {
  249. // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
  250. // typically a global variable or alloca.
  251. if (isa<GetElementPtrInst>(&Inst)) continue;
  252. // Don't lower lifetime intrinsics here, we'll handle them as we lower the alloca.
  253. IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(&Inst);
  254. if (Intrin && Intrin->getIntrinsicID() == Intrinsic::lifetime_start) continue;
  255. if (Intrin && Intrin->getIntrinsicID() == Intrinsic::lifetime_end) continue;
  256. if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
  257. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
  258. MatAllocas.emplace_back(Alloca);
  259. }
  260. continue;
  261. }
  262. if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
  263. // Lowering of global variables will have introduced
  264. // vec-to-mat translation stubs, which we deal with indirectly,
  265. // as we lower the instructions consuming them.
  266. if (m_vecToMatStubs->contains(Call->getCalledFunction()))
  267. continue;
  268. // Mat-to-vec stubs should only be introduced during instruction lowering.
  269. // Globals lowering won't introduce any because their only operand is
  270. // their initializer, which we can fully lower without stubbing since it is constant.
  271. DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
  272. "Unexpected mat-to-vec stubbing before function instruction lowering.");
  273. // Match matrix producers
  274. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
  275. MatInsts.emplace_back(Call);
  276. continue;
  277. }
  278. // Match matrix consumers
  279. for (Value *Operand : Inst.operand_values()) {
  280. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
  281. MatInsts.emplace_back(Call);
  282. break;
  283. }
  284. }
  285. continue;
  286. }
  287. if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
  288. Value *ReturnValue = Return->getReturnValue();
  289. if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
  290. MatInsts.emplace_back(Return);
  291. continue;
  292. }
  293. // Nothing else should produce or consume matrices
  294. }
  295. }
  296. }
  297. // Gets the matrix-lowered representation of a value, potentially adding a translation stub.
  298. // DiscardStub causes any vec-to-mat translation stubs to be deleted,
  299. // it should be true only if the original instruction will be modified and kept alive.
  300. // If a new instruction is created and the original marked as dead,
  301. // then the remove dead instructions pass will take care of removing the stub.
  302. Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
  303. Type *Ty = Val->getType();
  304. // We're only lowering byval matrices.
  305. // Since structs and arrays are always accessed by pointer,
  306. // we do not need to worry about a matrix being hidden inside a more complex type.
  307. DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
  308. HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
  309. if (!MatTy) return Val;
  310. Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();
  311. // Check if the value is already a vec-to-mat translation stub
  312. if (CallInst *Call = dyn_cast<CallInst>(Val)) {
  313. if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
  314. if (DiscardStub && Call->getNumUses() == 1) {
  315. Call->use_begin()->set(UndefValue::get(Call->getType()));
  316. addToDeadInsts(Call);
  317. }
  318. Value *LoweredVal = Call->getArgOperand(0);
  319. DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
  320. return LoweredVal;
  321. }
  322. }
  323. // Lower mat 0 to vec 0.
  324. if (isa<ConstantAggregateZero>(Val))
  325. return ConstantAggregateZero::get(LoweredTy);
  326. // Return a mat-to-vec translation stub
  327. FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
  328. Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
  329. return Builder.CreateCall(TranslationStub, { Val });
  330. }
  331. // Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
  332. // Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
  333. // for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
  334. Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
  335. if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
  336. return nullptr;
  337. // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
  338. // The first two cases are what this pass must be able to lower, and we should already
  339. // have replaced their uses by vector to matrix pointer translation stubs.
  340. if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
  341. if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
  342. if (DiscardStub && Call->getNumUses() == 1) {
  343. Call->use_begin()->set(UndefValue::get(Call->getType()));
  344. addToDeadInsts(Call);
  345. }
  346. return Call->getArgOperand(0);
  347. }
  348. }
  349. // There's one more case to handle.
  350. // When compiling shader libraries, signatures won't have been lowered yet.
  351. // So we can have a matrix in a struct as an argument,
  352. // or an alloca'd struct holding the return value of a call and containing a matrix.
  353. Value *RootPtr = Ptr;
  354. while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
  355. RootPtr = GEP->getPointerOperand();
  356. Argument *Arg = dyn_cast<Argument>(RootPtr);
  357. bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsEntryThatUsesSignatures(Arg->getParent());
  358. if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
  359. // Bitcast the matrix pointer to its lowered equivalent.
  360. // The HLMatrixBitcast pass will take care of this later.
  361. return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
  362. }
  363. // The pointer must be derived from a resource, we don't handle it in this pass.
  364. return nullptr;
  365. }
  366. // Bitcasts a value from matrix to vector or vice-versa.
  367. // This is used to convert to/from arguments/return values since we don't
  368. // lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
  369. Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
  370. Type *SrcTy = SrcVal->getType();
  371. DXASSERT_NOMSG(!SrcTy->isPointerTy());
  372. // We store and load from a temporary alloca, bitcasting either on the store pointer
  373. // or on the load pointer.
  374. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  375. Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
  376. Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
  377. Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
  378. return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
  379. }
  380. // Replaces all uses of a matrix value by its lowered vector form,
  381. // inserting translation stubs for users which still expect a matrix value.
  382. void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
  383. if (VecVal == nullptr || VecVal == MatInst) return;
  384. DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
  385. "Unexpected lowered value type.");
  386. Instruction *VecToMatStub = nullptr;
  387. while (!MatInst->use_empty()) {
  388. Use &ValUse = *MatInst->use_begin();
  389. // Handle non-matrix cases, just point to the new value.
  390. if (MatInst->getType() == VecVal->getType()) {
  391. ValUse.set(VecVal);
  392. continue;
  393. }
  394. // If the user is already a matrix-to-vector translation stub,
  395. // we can now replace it by the proper vector value.
  396. if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
  397. if (m_matToVecStubs->contains(Call->getCalledFunction())) {
  398. Call->replaceAllUsesWith(VecVal);
  399. ValUse.set(UndefValue::get(MatInst->getType()));
  400. addToDeadInsts(Call);
  401. continue;
  402. }
  403. }
  404. // Otherwise, the user should point to a vector-to-matrix translation
  405. // stub of the new vector value.
  406. if (VecToMatStub == nullptr) {
  407. FunctionType *TranslationStubTy = FunctionType::get(
  408. MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
  409. Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
  410. Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
  411. if (PrevInst == nullptr) PrevInst = MatInst;
  412. IRBuilder<> Builder(PrevInst->getNextNode());
  413. VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
  414. }
  415. ValUse.set(VecToMatStub);
  416. }
  417. }
  418. // Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
  419. // This doesn't lower the users, but will insert a translation stub from the lowered value pointer
  420. // back to the matrix value pointer, and recreate any GEPs around the new pointer.
  421. // Before: User(GEP(MatrixArrayAlloca))
  422. // After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
  423. void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
  424. DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
  425. DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
  426. SmallVector<Value*, 4> GEPIdxStack;
  427. GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
  428. replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
  429. }
  430. void HLMatrixLowerPass::replaceAllVariableUses(
  431. SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
  432. while (!StackTopPtr->use_empty()) {
  433. llvm::Use &Use = *StackTopPtr->use_begin();
  434. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
  435. DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
  436. DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
  437. // Recurse in GEP to find actual users
  438. for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
  439. GEPIdxStack.emplace_back(*It);
  440. replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
  441. GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());
  442. // Discard the GEP
  443. DXASSERT_NOMSG(GEP->use_empty());
  444. if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
  445. Use.set(UndefValue::get(Use->getType()));
  446. addToDeadInsts(GEPInst);
  447. } else {
  448. // constant GEP
  449. cast<Constant>(GEP)->destroyConstant();
  450. }
  451. continue;
  452. }
  453. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
  454. DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast ||
  455. CE->use_empty(), "Unexpected constant user");
  456. replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
  457. DXASSERT_NOMSG(CE->use_empty());
  458. CE->destroyConstant();
  459. continue;
  460. }
  461. if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
  462. replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
  463. Use.set(UndefValue::get(Use->getType()));
  464. addToDeadInsts(CI);
  465. continue;
  466. }
  467. if (BitCastInst *BCI = dyn_cast<BitCastInst>(Use.getUser())) {
  468. // Replace bitcasts to i8* for lifetime intrinsics.
  469. if (BCI->getType()->isPointerTy() && BCI->getType()->getPointerElementType()->isIntegerTy(8))
  470. {
  471. DXASSERT(onlyUsedByLifetimeMarkers(BCI),
  472. "bitcast to i8* must only be used by lifetime intrinsics");
  473. Value *NewBCI = IRBuilder<>(BCI).CreateBitCast(LoweredPtr, BCI->getType());
  474. // Replace all uses of the use.
  475. BCI->replaceAllUsesWith(NewBCI);
  476. // Remove the current use to end iteration.
  477. Use.set(UndefValue::get(Use->getType()));
  478. addToDeadInsts(BCI);
  479. continue;
  480. }
  481. }
  482. // Recreate the same GEP sequence, if any, on the lowered pointer
  483. IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
  484. Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
  485. ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);
  486. // Generate a stub translating the vector pointer back to a matrix pointer,
  487. // such that consuming instructions are unaffected.
  488. FunctionType *TranslationStubTy = FunctionType::get(
  489. StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
  490. Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
  491. Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
  492. }
  493. }
  494. void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
  495. if (Global->user_empty()) return;
  496. PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
  497. DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
  498. Constant *LoweredInitVal = Global->hasInitializer()
  499. ? lowerConstInitVal(Global->getInitializer()) : nullptr;
  500. GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
  501. Global->isConstant(), Global->getLinkage(), LoweredInitVal,
  502. Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
  503. Global->getType()->getAddressSpace());
  504. // Add debug info.
  505. if (m_HasDbgInfo) {
  506. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  507. HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
  508. }
  509. replaceAllVariableUses(Global, LoweredGlobal);
  510. Global->removeDeadConstantUsers();
  511. Global->eraseFromParent();
  512. }
  513. Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
  514. Type *Ty = Val->getType();
  515. // If it's an array of matrices, recurse for each element or nested array
  516. if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
  517. SmallVector<Constant*, 4> LoweredElems;
  518. unsigned NumElems = ArrayTy->getNumElements();
  519. LoweredElems.reserve(NumElems);
  520. for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  521. Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
  522. LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
  523. }
  524. Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType(), /*MemRepr*/true);
  525. ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems);
  526. return ConstantArray::get(LoweredArrayTy, LoweredElems);
  527. }
  528. // Otherwise it's a matrix, lower it to a vector
  529. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  530. DXASSERT_NOMSG(isa<StructType>(Ty));
  531. Constant *RowArrayVal = Val->getAggregateElement((unsigned)0);
  532. // Original initializer should have been produced in row/column-major order
  533. // depending on the qualifiers of the target variable, so preserve the order.
  534. SmallVector<Constant*, 16> MatElems;
  535. for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
  536. Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
  537. for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
  538. MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
  539. }
  540. }
  541. Constant *Vec = ConstantVector::get(MatElems);
  542. // Matrix elements are always in register representation,
  543. // but the lowered global variable is of vector type in
  544. // its memory representation, so we must convert here.
  545. // This will produce a constant so we can use an IRBuilder without a valid insertion point.
  546. IRBuilder<> DummyBuilder(Val->getContext());
  547. return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
  548. }
  549. AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
  550. PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType()));
  551. IRBuilder<> Builder(MatAlloca);
  552. AllocaInst *LoweredAlloca = Builder.CreateAlloca(
  553. LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
  554. // Update debug info.
  555. if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
  556. LLVMContext &Context = MatAlloca->getContext();
  557. Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
  558. Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
  559. Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
  560. IRBuilder<> DebugBuilder(DbgDeclare);
  561. DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
  562. }
  563. if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
  564. HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
  565. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  566. return LoweredAlloca;
  567. }
  568. void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) {
  569. if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
  570. Value *LoweredValue = lowerCall(Call);
  571. // lowerCall returns the lowered value iff we should discard
  572. // the original matrix instruction and replace all of its uses
  573. // by the lowered value. It returns nullptr to opt-out of this.
  574. if (LoweredValue != nullptr) {
  575. replaceAllUsesByLoweredValue(Call, LoweredValue);
  576. addToDeadInsts(Inst);
  577. }
  578. }
  579. else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
  580. lowerReturn(Return);
  581. }
  582. else
  583. llvm_unreachable("Unexpected matrix instruction type.");
  584. }
  585. void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) {
  586. Value *RetVal = Return->getReturnValue();
  587. Type *RetTy = RetVal->getType();
  588. DXASSERT_LOCALVAR(RetTy, !RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
  589. IRBuilder<> Builder(Return);
  590. Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
  591. // Since we're not lowering the signature, we can't return the lowered value directly,
  592. // so insert a bitcast, which HLMatrixBitcastLower knows how to eliminate.
  593. Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
  594. Return->setOperand(0, BitCastedRetVal);
  595. }
  596. Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
  597. HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
  598. return OpcodeGroup == HLOpcodeGroup::NotHL
  599. ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup);
  600. }
  601. // Special function to lower precise call applied to a matrix
  602. // The matrix should be lowered and the call regenerated with vector arg
  603. void HLMatrixLowerPass::lowerPreciseCall(CallInst *Call, IRBuilder<> Builder) {
  604. DXASSERT(Call->getNumArgOperands() == 1, "Only one arg expected for precise matrix call");
  605. Value *Arg = Call->getArgOperand(0);
  606. Value *LoweredArg = getLoweredByValOperand(Arg, Builder);
  607. HLModule::MarkPreciseAttributeOnValWithFunctionCall(LoweredArg, Builder, *m_pModule);
  608. addToDeadInsts(Call);
  609. }
  610. Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
  611. // First, handle any operand of matrix-derived type
  612. // We don't lower the callee's signature in this pass,
  613. // so, for any matrix-typed parameter, we create a bitcast from the
  614. // lowered vector back to the matrix type, which the later HLMatrixBitcastLower
  615. // pass knows how to eliminate.
  616. IRBuilder<> PreCallBuilder(Call);
  617. unsigned NumArgs = Call->getNumArgOperands();
  618. Function *Func = Call->getCalledFunction();
  619. if (Func && HLModule::HasPreciseAttribute(Func)) {
  620. lowerPreciseCall(Call, PreCallBuilder);
  621. return nullptr;
  622. }
  623. for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
  624. Use &ArgUse = Call->getArgOperandUse(ArgIdx);
  625. if (ArgUse->getType()->isPointerTy()) {
  626. // Byref arg
  627. Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
  628. if (LoweredArg != nullptr) {
  629. // Pointer to a matrix we've lowered, insert a bitcast back to matrix pointer type.
  630. Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
  631. ArgUse.set(BitCastedArg);
  632. }
  633. }
  634. else {
  635. // Byvalue arg
  636. Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
  637. if (LoweredArg == ArgUse.get()) continue;
  638. Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
  639. ArgUse.set(BitCastedArg);
  640. }
  641. }
  642. // Now check the return type
  643. HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
  644. if (!RetMatTy) {
  645. DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
  646. "Unexpected user call returning a matrix by pointer.");
  647. // Nothing to replace, other instructions can consume a non-matrix return type.
  648. return nullptr;
  649. }
  650. // The callee returns a matrix, and we don't lower signatures in this pass.
  651. // We perform a sketchy bitcast to the lowered register-representation type,
  652. // which the later HLMatrixBitcastLower pass knows how to eliminate.
  653. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
  654. Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
  655. IRBuilder<> PostCallBuilder(Call->getNextNode());
  656. Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
  657. // This is slightly tricky
  658. // We want to replace all uses of the matrix-returning call by the bitcasted value,
  659. // but the store to the bitcasted pointer itself is a use of that matrix,
  660. // so we need to create the load, replace the uses, and then insert the store.
  661. LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
  662. replaceAllUsesByLoweredValue(Call, LoweredVal);
  663. // Now we can insert the store. Make sure to do so before the load.
  664. PostCallBuilder.SetInsertPoint(LoweredVal);
  665. PostCallBuilder.CreateStore(Call, BitCastedAlloca);
  666. // Return nullptr since we did our own uses replacement and we don't want
  667. // the matrix instruction to be marked as dead since we're still using it.
  668. return nullptr;
  669. }
  670. Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) {
  671. IRBuilder<> Builder(Call);
  672. switch (OpcodeGroup) {
  673. case HLOpcodeGroup::HLIntrinsic:
  674. return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
  675. case HLOpcodeGroup::HLBinOp:
  676. return lowerHLBinaryOperation(
  677. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  678. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  679. static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
  680. case HLOpcodeGroup::HLUnOp:
  681. return lowerHLUnaryOperation(
  682. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
  683. static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
  684. case HLOpcodeGroup::HLMatLoadStore:
  685. return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
  686. case HLOpcodeGroup::HLCast:
  687. return lowerHLCast(Call,
  688. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
  689. static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
  690. case HLOpcodeGroup::HLSubscript:
  691. return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
  692. case HLOpcodeGroup::HLInit:
  693. return lowerHLInit(Call);
  694. case HLOpcodeGroup::HLSelect:
  695. return lowerHLSelect(Call);
  696. default:
  697. llvm_unreachable("Unexpected matrix opcode");
  698. }
  699. }
  700. Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
  701. IRBuilder<> Builder(Call);
  702. // See if this is a matrix-specific intrinsic which we should expand here
  703. switch (Opcode) {
  704. case IntrinsicOp::IOP_umul:
  705. case IntrinsicOp::IOP_mul:
  706. return lowerHLMulIntrinsic(
  707. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  708. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  709. /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
  710. case IntrinsicOp::IOP_transpose:
  711. return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  712. case IntrinsicOp::IOP_determinant:
  713. return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  714. }
  715. // Delegate to a lowered intrinsic call
  716. SmallVector<Value*, 4> LoweredArgs;
  717. LoweredArgs.reserve(Call->getNumArgOperands());
  718. for (Value *Arg : Call->arg_operands()) {
  719. if (Arg->getType()->isPointerTy()) {
  720. // ByRef parameter (for example, frexp's second parameter)
  721. // If the argument points to a lowered matrix variable, replace it here,
  722. // otherwise preserve the matrix type and let further passes handle the lowering.
  723. Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
  724. if (LoweredArg == nullptr) LoweredArg = Arg;
  725. LoweredArgs.emplace_back(LoweredArg);
  726. }
  727. else {
  728. LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
  729. }
  730. }
  731. Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
  732. return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
  733. LoweredRetTy, LoweredArgs,
  734. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  735. }
  736. // Handles multiplcation of a scalar with a matrix
  737. Value *HLMatrixLowerPass::translateScalarMatMul(Value *Lhs, Value *Rhs, IRBuilder<> &Builder, bool isLhsScalar) {
  738. Value *Mat = isLhsScalar ? Rhs : Lhs;
  739. Value *Scalar = isLhsScalar ? Lhs : Rhs;
  740. Value* LoweredMat = getLoweredByValOperand(Mat, Builder);
  741. Type *ScalarTy = Scalar->getType();
  742. // Perform the scalar-matrix multiplication!
  743. Type *ElemTy = LoweredMat->getType()->getVectorElementType();
  744. bool isIntMulOp = ScalarTy->isIntegerTy() && ElemTy->isIntegerTy();
  745. bool isFloatMulOp = ScalarTy->isFloatingPointTy() && ElemTy->isFloatingPointTy();
  746. DXASSERT(ScalarTy == ElemTy, "Scalar type must match the matrix component type.");
  747. Value *Result = Builder.CreateVectorSplat(LoweredMat->getType()->getVectorNumElements(), Scalar);
  748. if (isFloatMulOp) {
  749. // Preserve the order of operation for floats
  750. Result = isLhsScalar ? Builder.CreateFMul(Result, LoweredMat) : Builder.CreateFMul(LoweredMat, Result);
  751. }
  752. else if (isIntMulOp) {
  753. // Doesn't matter for integers but still preserve the order of operation
  754. Result = isLhsScalar ? Builder.CreateMul(Result, LoweredMat) : Builder.CreateMul(LoweredMat, Result);
  755. }
  756. else {
  757. DXASSERT(0, "Unknown type encountered when doing scalar-matrix multiplication.");
  758. }
  759. return Result;
  760. }
  761. Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
  762. bool Unsigned, IRBuilder<> &Builder) {
  763. HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
  764. HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
  765. Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  766. Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);
  767. // Translate multiplication of scalar with matrix
  768. bool isLhsScalar = !LoweredLhs->getType()->isVectorTy();
  769. bool isRhsScalar = !LoweredRhs->getType()->isVectorTy();
  770. bool isScalar = isLhsScalar || isRhsScalar;
  771. if (isScalar)
  772. return translateScalarMatMul(Lhs, Rhs, Builder, isLhsScalar);
  773. DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
  774. "Unexpected element type mismatch in mul intrinsic.");
  775. DXASSERT(cast<VectorType>(LoweredLhs->getType()) && cast<VectorType>(LoweredRhs->getType()),
  776. "Unexpected scalar in lowered matrix mul intrinsic operands.");
  777. Type* ElemTy = LoweredLhs->getType()->getScalarType();
  778. // Figure out the dimensions of each side
  779. unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
  780. if (LhsMatTy && RhsMatTy) {
  781. LhsNumRows = LhsMatTy.getNumRows();
  782. LhsNumCols = LhsMatTy.getNumColumns();
  783. RhsNumRows = RhsMatTy.getNumRows();
  784. RhsNumCols = RhsMatTy.getNumColumns();
  785. }
  786. else if (LhsMatTy) {
  787. LhsNumRows = LhsMatTy.getNumRows();
  788. LhsNumCols = LhsMatTy.getNumColumns();
  789. RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
  790. RhsNumCols = 1;
  791. }
  792. else if (RhsMatTy) {
  793. LhsNumRows = 1;
  794. LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
  795. RhsNumRows = RhsMatTy.getNumRows();
  796. RhsNumCols = RhsMatTy.getNumColumns();
  797. }
  798. else {
  799. llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
  800. }
  801. DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
  802. HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
  803. unsigned AccCount = LhsNumCols;
  804. // Get the multiply-and-add intrinsic function, we'll need it
  805. IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
  806. FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
  807. Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
  808. Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);
  809. // Perform the multiplication!
  810. Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
  811. for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
  812. for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
  813. unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
  814. Value *ResultElem = nullptr;
  815. for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
  816. unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
  817. unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
  818. Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
  819. Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
  820. if (ResultElem == nullptr) {
  821. ResultElem = ElemTy->isFloatingPointTy()
  822. ? Builder.CreateFMul(LhsElem, RhsElem)
  823. : Builder.CreateMul(LhsElem, RhsElem);
  824. }
  825. else {
  826. ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
  827. }
  828. }
  829. Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
  830. }
  831. }
  832. return Result;
  833. }
  834. Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  835. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  836. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  837. return MatTy.emitLoweredVectorRowToCol(LoweredVal, Builder);
  838. }
  839. static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
  840. Value *Mul0 = Builder.CreateFMul(M00, M11);
  841. Value *Mul1 = Builder.CreateFMul(M01, M10);
  842. return Builder.CreateFSub(Mul0, Mul1);
  843. }
  844. static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
  845. Value *M10, Value *M11, Value *M12,
  846. Value *M20, Value *M21, Value *M22,
  847. IRBuilder<> &Builder) {
  848. Value *Det00 = determinant2x2(M11, M12, M21, M22, Builder);
  849. Value *Det01 = determinant2x2(M10, M12, M20, M22, Builder);
  850. Value *Det02 = determinant2x2(M10, M11, M20, M21, Builder);
  851. Det00 = Builder.CreateFMul(M00, Det00);
  852. Det01 = Builder.CreateFMul(M01, Det01);
  853. Det02 = Builder.CreateFMul(M02, Det02);
  854. Value *Result = Builder.CreateFSub(Det00, Det01);
  855. Result = Builder.CreateFAdd(Result, Det02);
  856. return Result;
  857. }
  858. static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
  859. Value *M10, Value *M11, Value *M12, Value *M13,
  860. Value *M20, Value *M21, Value *M22, Value *M23,
  861. Value *M30, Value *M31, Value *M32, Value *M33,
  862. IRBuilder<> &Builder) {
  863. Value *Det00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
  864. Value *Det01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
  865. Value *Det02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
  866. Value *Det03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
  867. Det00 = Builder.CreateFMul(M00, Det00);
  868. Det01 = Builder.CreateFMul(M01, Det01);
  869. Det02 = Builder.CreateFMul(M02, Det02);
  870. Det03 = Builder.CreateFMul(M03, Det03);
  871. Value *Result = Builder.CreateFSub(Det00, Det01);
  872. Result = Builder.CreateFAdd(Result, Det02);
  873. Result = Builder.CreateFSub(Result, Det03);
  874. return Result;
  875. }
  876. Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  877. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  878. DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
  879. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  880. // Extract all matrix elements
  881. SmallVector<Value*, 16> Elems;
  882. for (unsigned ElemIdx = 0; ElemIdx < MatTy.getNumElements(); ++ElemIdx)
  883. Elems.emplace_back(Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(ElemIdx)));
  884. // Delegate to appropriate determinant function
  885. switch (MatTy.getNumColumns()) {
  886. case 1:
  887. return Elems[0];
  888. case 2:
  889. return determinant2x2(
  890. Elems[0], Elems[1],
  891. Elems[2], Elems[3],
  892. Builder);
  893. case 3:
  894. return determinant3x3(
  895. Elems[0], Elems[1], Elems[2],
  896. Elems[3], Elems[4], Elems[5],
  897. Elems[6], Elems[7], Elems[8],
  898. Builder);
  899. case 4:
  900. return determinant4x4(
  901. Elems[0], Elems[1], Elems[2], Elems[3],
  902. Elems[4], Elems[5], Elems[6], Elems[7],
  903. Elems[8], Elems[9], Elems[10], Elems[11],
  904. Elems[12], Elems[13], Elems[14], Elems[15],
  905. Builder);
  906. default:
  907. llvm_unreachable("Unexpected matrix dimensions.");
  908. }
  909. }
  910. Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
  911. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  912. VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
  913. bool IsFloat = VecTy->getElementType()->isFloatingPointTy();
  914. switch (Opcode) {
  915. case HLUnaryOpcode::Plus: return LoweredVal; // No-op
  916. case HLUnaryOpcode::Minus:
  917. return IsFloat
  918. ? Builder.CreateFSub(Constant::getNullValue(VecTy), LoweredVal)
  919. : Builder.CreateSub(Constant::getNullValue(VecTy), LoweredVal);
  920. case HLUnaryOpcode::LNot:
  921. return IsFloat
  922. ? Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Constant::getNullValue(VecTy))
  923. : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Constant::getNullValue(VecTy));
  924. case HLUnaryOpcode::Not:
  925. return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));
  926. case HLUnaryOpcode::PostInc:
  927. case HLUnaryOpcode::PreInc:
  928. case HLUnaryOpcode::PostDec:
  929. case HLUnaryOpcode::PreDec: {
  930. Constant *ScalarOne = IsFloat
  931. ? ConstantFP::get(VecTy->getElementType(), 1)
  932. : ConstantInt::get(VecTy->getElementType(), 1);
  933. Constant *VecOne = ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);
  934. // CodeGen already emitted the load and following store, our job is only to produce
  935. // the updated value.
  936. if (Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc) {
  937. return IsFloat
  938. ? Builder.CreateFAdd(LoweredVal, VecOne)
  939. : Builder.CreateAdd(LoweredVal, VecOne);
  940. }
  941. else {
  942. return IsFloat
  943. ? Builder.CreateFSub(LoweredVal, VecOne)
  944. : Builder.CreateSub(LoweredVal, VecOne);
  945. }
  946. }
  947. default:
  948. llvm_unreachable("Unsupported unary matrix operator");
  949. }
  950. }
  951. Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
  952. Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  953. Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);
  954. DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
  955. "Expected lowered binary operation operands to be vectors");
  956. DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
  957. "Expected lowered binary operation operands to have matching types.");
  958. bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();
  959. switch (Opcode) {
  960. case HLBinaryOpcode::Add:
  961. return IsFloat
  962. ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
  963. : Builder.CreateAdd(LoweredLhs, LoweredRhs);
  964. case HLBinaryOpcode::Sub:
  965. return IsFloat
  966. ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
  967. : Builder.CreateSub(LoweredLhs, LoweredRhs);
  968. case HLBinaryOpcode::Mul:
  969. return IsFloat
  970. ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
  971. : Builder.CreateMul(LoweredLhs, LoweredRhs);
  972. case HLBinaryOpcode::Div:
  973. return IsFloat
  974. ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
  975. : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
  976. case HLBinaryOpcode::Rem:
  977. return IsFloat
  978. ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
  979. : Builder.CreateSRem(LoweredLhs, LoweredRhs);
  980. case HLBinaryOpcode::And:
  981. return Builder.CreateAnd(LoweredLhs, LoweredRhs);
  982. case HLBinaryOpcode::Or:
  983. return Builder.CreateOr(LoweredLhs, LoweredRhs);
  984. case HLBinaryOpcode::Xor:
  985. return Builder.CreateXor(LoweredLhs, LoweredRhs);
  986. case HLBinaryOpcode::Shl:
  987. return Builder.CreateShl(LoweredLhs, LoweredRhs);
  988. case HLBinaryOpcode::Shr:
  989. return Builder.CreateAShr(LoweredLhs, LoweredRhs);
  990. case HLBinaryOpcode::LT:
  991. return IsFloat
  992. ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
  993. : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
  994. case HLBinaryOpcode::GT:
  995. return IsFloat
  996. ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
  997. : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
  998. case HLBinaryOpcode::LE:
  999. return IsFloat
  1000. ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
  1001. : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
  1002. case HLBinaryOpcode::GE:
  1003. return IsFloat
  1004. ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
  1005. : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
  1006. case HLBinaryOpcode::EQ:
  1007. return IsFloat
  1008. ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
  1009. : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
  1010. case HLBinaryOpcode::NE:
  1011. return IsFloat
  1012. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
  1013. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
  1014. case HLBinaryOpcode::UDiv:
  1015. return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
  1016. case HLBinaryOpcode::URem:
  1017. return Builder.CreateURem(LoweredLhs, LoweredRhs);
  1018. case HLBinaryOpcode::UShr:
  1019. return Builder.CreateLShr(LoweredLhs, LoweredRhs);
  1020. case HLBinaryOpcode::ULT:
  1021. return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
  1022. case HLBinaryOpcode::UGT:
  1023. return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
  1024. case HLBinaryOpcode::ULE:
  1025. return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
  1026. case HLBinaryOpcode::UGE:
  1027. return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
  1028. case HLBinaryOpcode::LAnd:
  1029. case HLBinaryOpcode::LOr: {
  1030. Value *Zero = Constant::getNullValue(LoweredLhs->getType());
  1031. Value *LhsCmp = IsFloat
  1032. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
  1033. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
  1034. Value *RhsCmp = IsFloat
  1035. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
  1036. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
  1037. return Opcode == HLBinaryOpcode::LOr
  1038. ? Builder.CreateOr(LhsCmp, RhsCmp)
  1039. : Builder.CreateAnd(LhsCmp, RhsCmp);
  1040. }
  1041. default:
  1042. llvm_unreachable("Unsupported binary matrix operator");
  1043. }
  1044. }
  1045. Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) {
  1046. IRBuilder<> Builder(Call);
  1047. switch (Opcode) {
  1048. case HLMatLoadStoreOpcode::RowMatLoad:
  1049. case HLMatLoadStoreOpcode::ColMatLoad:
  1050. return lowerHLLoad(Call, Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
  1051. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
  1052. case HLMatLoadStoreOpcode::RowMatStore:
  1053. case HLMatLoadStoreOpcode::ColMatStore:
  1054. return lowerHLStore(Call,
  1055. Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
  1056. Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
  1057. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
  1058. /* Return */ !Call->getType()->isVoidTy(), Builder);
  1059. default:
  1060. llvm_unreachable("Unsupported matrix load/store operation");
  1061. }
  1062. }
  1063. Value *HLMatrixLowerPass::lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
  1064. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1065. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1066. if (LoweredPtr == nullptr) {
  1067. // Can't lower this here, defer to HL signature lower
  1068. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
  1069. return callHLFunction(
  1070. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1071. MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr },
  1072. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  1073. }
  1074. return MatTy.emitLoweredLoad(LoweredPtr, Builder);
  1075. }
  1076. Value *HLMatrixLowerPass::lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr,
  1077. bool RowMajor, bool Return, IRBuilder<> &Builder) {
  1078. DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
  1079. "Matrix store value/pointer type mismatch.");
  1080. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1081. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  1082. if (LoweredPtr == nullptr) {
  1083. // Can't lower the pointer here, defer to HL signature lower
  1084. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
  1085. return callHLFunction(
  1086. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1087. Return ? LoweredVal->getType() : Builder.getVoidTy(),
  1088. { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal },
  1089. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  1090. }
  1091. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1092. StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
  1093. // If the intrinsic returned a value, return the stored lowered value
  1094. return Return ? LoweredVal : LoweredStore;
  1095. }
  1096. static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) {
  1097. DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
  1098. "Scalar/vector type mismatch in numerical conversion.");
  1099. Type *SrcTy = SrcVal->getType();
  1100. // Conversions between equivalent types are no-ops,
  1101. // even between signed/unsigned variants.
  1102. if (SrcTy == DstTy) return SrcVal;
  1103. // Conversions to bools are comparisons
  1104. if (DstTy->getScalarSizeInBits() == 1) {
  1105. // fcmp une is what regular clang uses in C++ for (bool)f;
  1106. return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
  1107. ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
  1108. : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
  1109. }
  1110. // Cast necessary
  1111. bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
  1112. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1113. bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
  1114. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1115. auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
  1116. SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
  1117. return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
  1118. }
  1119. Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
  1120. HLCastOpcode Opcode, IRBuilder<> &Builder) {
  1121. // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
  1122. DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
  1123. if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
  1124. // Scalar to matrix splat
  1125. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1126. // Apply element conversion
  1127. Value *Result = convertScalarOrVector(Src,
  1128. MatDstTy.getElementTypeForReg(), Opcode, Builder);
  1129. // Splat to a vector
  1130. Result = Builder.CreateInsertElement(
  1131. UndefValue::get(VectorType::get(Result->getType(), 1)),
  1132. Result, static_cast<uint64_t>(0));
  1133. return Builder.CreateShuffleVector(Result, Result,
  1134. ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
  1135. }
  1136. else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
  1137. // Vector to matrix
  1138. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1139. Value *Result = Src;
  1140. // We might need to truncate
  1141. if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
  1142. SmallVector<int, 4> ShuffleIndices;
  1143. for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
  1144. ShuffleIndices.emplace_back(static_cast<int>(Idx));
  1145. Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
  1146. }
  1147. // Apply element conversion
  1148. return convertScalarOrVector(Result,
  1149. MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
  1150. }
  1151. // Source must now be a matrix
  1152. HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
  1153. VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();
  1154. Value *LoweredSrc;
  1155. if (isa<Argument>(Src)) {
  1156. // Function arguments are lowered in HLSignatureLower.
  1157. // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
  1158. // Preserve them, but change the return type to vector.
  1159. DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
  1160. "Unexpected cast of matrix argument.");
  1161. LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
  1162. LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src },
  1163. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  1164. }
  1165. else {
  1166. LoweredSrc = getLoweredByValOperand(Src, Builder);
  1167. }
  1168. DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);
  1169. Value* Result = LoweredSrc;
  1170. Type* LoweredDstTy = DstTy;
  1171. if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
  1172. // Matrix to scalar
  1173. Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
  1174. }
  1175. else if (DstTy->isVectorTy()) {
  1176. // Matrix to vector
  1177. VectorType *DstVecTy = cast<VectorType>(DstTy);
  1178. DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
  1179. "Cannot cast matrix to a larger vector.");
  1180. // We might have to truncate
  1181. if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
  1182. SmallVector<int, 3> ShuffleIndices;
  1183. for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
  1184. ShuffleIndices.emplace_back(static_cast<int>(Idx));
  1185. Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
  1186. }
  1187. }
  1188. else {
  1189. // Destination must now be a matrix too
  1190. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1191. // Apply any changes at the matrix level: orientation changes and truncation
  1192. if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
  1193. Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
  1194. else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
  1195. Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
  1196. else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
  1197. || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
  1198. // Apply truncation
  1199. DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
  1200. && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
  1201. "Unexpected matrix cast between incompatible dimensions.");
  1202. SmallVector<int, 16> ShuffleIndices;
  1203. for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
  1204. for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
  1205. ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
  1206. Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
  1207. }
  1208. LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
  1209. DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
  1210. "Unexpected matrix src/dst lowered element count mismatch after truncation.");
  1211. }
  1212. // Apply element conversion
  1213. return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
  1214. }
  1215. Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
  1216. switch (Opcode) {
  1217. case HLSubscriptOpcode::RowMatElement:
  1218. case HLSubscriptOpcode::ColMatElement:
  1219. return lowerHLMatElementSubscript(Call,
  1220. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
  1221. case HLSubscriptOpcode::RowMatSubscript:
  1222. case HLSubscriptOpcode::ColMatSubscript:
  1223. return lowerHLMatSubscript(Call,
  1224. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
  1225. case HLSubscriptOpcode::DefaultSubscript:
  1226. case HLSubscriptOpcode::CBufferSubscript:
  1227. // Those get lowered during HLOperationLower,
  1228. // and the return type must stay unchanged (as a matrix)
  1229. // to provide the metadata to properly emit the loads.
  1230. return nullptr;
  1231. default:
  1232. llvm_unreachable("Unexpected matrix subscript opcode.");
  1233. }
  1234. }
  1235. Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) {
  1236. (void)RowMajor; // It doesn't look like we actually need this?
  1237. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1238. Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
  1239. VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
  1240. // Get the loaded lowered vector element indices
  1241. SmallVector<Value*, 4> ElemIndices;
  1242. ElemIndices.reserve(IdxVecTy->getNumElements());
  1243. for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
  1244. ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
  1245. }
  1246. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1247. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1248. return nullptr;
  1249. }
  1250. Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
  1251. (void)RowMajor; // It doesn't look like we actually need this?
  1252. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1253. // Gather the indices, checking if they are all constant
  1254. SmallVector<Value*, 4> ElemIndices;
  1255. for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
  1256. ElemIndices.emplace_back(Call->getArgOperand(Idx));
  1257. }
  1258. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1259. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1260. return nullptr;
  1261. }
  1262. void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) {
  1263. DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));
  1264. IRBuilder<> CallBuilder(Call);
  1265. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
  1266. Value *LoweredMatrix = nullptr;
  1267. Value *RootPtr = LoweredPtr? LoweredPtr: MatPtr;
  1268. while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
  1269. RootPtr = GEP->getPointerOperand();
  1270. if (LoweredPtr == nullptr) {
  1271. if (!isa<Argument>(RootPtr))
  1272. return;
  1273. // For a shader input, load the matrix into a lowered ptr
  1274. // The load will be handled by LowerSignature
  1275. HLMatLoadStoreOpcode Opcode = (HLSubscriptOpcode)GetHLOpcode(Call) == HLSubscriptOpcode::RowMatSubscript ?
  1276. HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
  1277. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1278. LoweredMatrix = callHLFunction(
  1279. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1280. MatTy.getLoweredVectorTypeForReg(), { CallBuilder.getInt32((uint32_t)Opcode), MatPtr },
  1281. Call->getCalledFunction()->getAttributes().getFnAttributes(), CallBuilder);
  1282. }
  1283. // For global variables, we can GEP directly into the lowered vector pointer.
  1284. // This is necessary to support group shared memory atomics and the likes.
  1285. bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
  1286. // Just constructing this does all the work
  1287. HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, LoweredMatrix,
  1288. ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
  1289. DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
  1290. addToDeadInsts(Call);
  1291. }
  1292. Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
  1293. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1294. // Figure out the result type
  1295. HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
  1296. VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
  1297. // Handle case where produced by EmitHLSLFlatConversion where there's one
  1298. // vector argument, instead of scalar arguments.
  1299. if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
  1300. Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
  1301. getType()->isVectorTy()) {
  1302. Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
  1303. DXASSERT(LoweredTy->getNumElements() ==
  1304. LoweredVec->getType()->getVectorNumElements(),
  1305. "Invalid matrix init argument vector element count.");
  1306. return LoweredVec;
  1307. }
  1308. DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
  1309. "Invalid matrix init argument count.");
  1310. // Build the result vector from the init args.
  1311. // Both the args and the result vector are in row-major order, so no shuffling is necessary.
  1312. IRBuilder<> Builder(Call);
  1313. Value *LoweredVec = UndefValue::get(LoweredTy);
  1314. for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
  1315. Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
  1316. DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
  1317. "Expected only scalars in matrix initialization.");
  1318. LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
  1319. }
  1320. return LoweredVec;
  1321. }
  1322. Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
  1323. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1324. Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  1325. Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
  1326. Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
  1327. DXASSERT(TrueMat->getType() == FalseMat->getType(),
  1328. "Unexpected type mismatch between matrix ternary operator values.");
  1329. #ifndef NDEBUG
  1330. // Assert that if the condition is a matrix, it matches the dimensions of the values
  1331. if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
  1332. HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
  1333. DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
  1334. && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
  1335. "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
  1336. }
  1337. #endif
  1338. IRBuilder<> Builder(Call);
  1339. Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
  1340. Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
  1341. Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
  1342. Value *Result = UndefValue::get(LoweredTrueVec->getType());
  1343. bool IsScalarCond = !LoweredCond->getType()->isVectorTy();
  1344. unsigned NumElems = Result->getType()->getVectorNumElements();
  1345. for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  1346. Value *ElemCond = IsScalarCond ? LoweredCond
  1347. : Builder.CreateExtractElement(LoweredCond, ElemIdx);
  1348. Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
  1349. Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
  1350. Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
  1351. Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
  1352. }
  1353. return Result;
  1354. }