// HLMatrixLowerPass.cpp
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerPass.h"
  12. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  13. #include "dxc/HLSL/HLMatrixType.h"
  14. #include "dxc/HLSL/HLOperations.h"
  15. #include "dxc/HLSL/HLModule.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/Support/Global.h"
  18. #include "dxc/DXIL/DxilOperations.h"
  19. #include "dxc/DXIL/DxilTypeSystem.h"
  20. #include "dxc/DXIL/DxilModule.h"
  21. #include "dxc/DXIL/DxilUtil.h"
  22. #include "HLMatrixSubscriptUseReplacer.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/Module.h"
  25. #include "llvm/IR/DebugInfo.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/Transforms/Utils/Local.h"
  28. #include "llvm/Pass.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include <unordered_set>
  31. #include <vector>
  32. using namespace llvm;
  33. using namespace hlsl;
  34. using namespace hlsl::HLMatrixLower;
  35. namespace hlsl {
  36. namespace HLMatrixLower {
  37. Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
  38. Value *Vec = UndefValue::get(VectorType::get(EltTy, static_cast<unsigned>(elts.size())));
  39. for (unsigned i = 0; i < elts.size(); i++)
  40. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  41. return Vec;
  42. }
  43. } // namespace HLMatrixLower
  44. } // namespace hlsl
  45. namespace {
  46. // Creates and manages a set of temporary overloaded functions keyed on the function type,
  47. // and which should be destroyed when the pool gets out of scope.
class TempOverloadPool {
public:
  TempOverloadPool(llvm::Module &Module, const char* BaseName)
    : Module(Module), BaseName(BaseName) {}
  ~TempOverloadPool() { clear(); }

  // Returns the pool function for this exact type, creating it on first request.
  Function *get(FunctionType *Ty);
  // True if a function of this exact type has already been created by the pool.
  bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
  // True if Func is one of the functions owned by this pool.
  bool contains(Function *Func) const;
  // Removes all pooled functions from the module; they must be unused by then.
  void clear();

private:
  llvm::Module &Module;   // Module into which stub functions are inserted.
  const char* BaseName;   // Prefix used when mangling stub function names.
  llvm::DenseMap<FunctionType*, Function*> Funcs; // Function type -> stub.
};
  62. Function *TempOverloadPool::get(FunctionType *Ty) {
  63. auto It = Funcs.find(Ty);
  64. if (It != Funcs.end()) return It->second;
  65. std::string MangledName;
  66. raw_string_ostream MangledNameStream(MangledName);
  67. MangledNameStream << BaseName;
  68. MangledNameStream << '.';
  69. Ty->print(MangledNameStream);
  70. MangledNameStream.flush();
  71. Function* Func = cast<Function>(Module.getOrInsertFunction(MangledName, Ty));
  72. Funcs.insert(std::make_pair(Ty, Func));
  73. return Func;
  74. }
  75. bool TempOverloadPool::contains(Function *Func) const {
  76. auto It = Funcs.find(Func->getFunctionType());
  77. return It != Funcs.end() && It->second == Func;
  78. }
  79. void TempOverloadPool::clear() {
  80. for (auto Entry : Funcs) {
  81. DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
  82. Entry.second->removeFromParent();
  83. }
  84. Funcs.clear();
  85. }
  86. // High-level matrix lowering pass.
  87. //
  88. // This pass converts matrices to their lowered vector representations,
  89. // including global variables, local variables and operations,
  90. // but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
  91. // nor matrices obtained from resources or constant - left to HLOperationLower.
  92. //
  93. // Algorithm overview:
  94. // 1. Find all matrix and matrix array global variables and lower them to vectors.
  95. // Walk any GEPs and insert vec-to-mat translation stubs so that consuming
  96. // instructions keep dealing with matrix types for the moment.
  97. // 2. For each function
  98. // 2a. Lower all matrix and matrix array allocas, just like global variables.
  99. // 2b. Lower all other instructions producing or consuming matrices
  100. //
  101. // Conversion stubs are used to allow converting instructions in isolation,
  102. // and in an order-independent manner:
  103. //
  104. // Initial: MatInst1(MatInst2(MatInst3))
  105. // After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
  106. // After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
  107. // After lowering MatInst3: VecInst1(VecInst2(VecInst3))
class HLMatrixLowerPass : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit HLMatrixLowerPass() : ModulePass(ID) {}
  const char *getPassName() const override { return "HL matrix lower"; }
  bool runOnModule(Module &M) override;

private:
  // Per-function driver: lowers matrix allocas first, then all other matrix instructions.
  void runOnFunction(Function &Func);
  // Queues an instruction for deferred deletion by deleteDeadInsts().
  void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
  // Erases queued instructions, recursively following operands that become unused.
  void deleteDeadInsts();
  // Collects a function's matrix allocas and matrix-producing/consuming instructions.
  void getMatrixAllocasAndOtherInsts(Function &Func,
    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
  // Returns the lowered (vector) form of a by-value matrix operand, possibly via a stub.
  Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
  // Returns the lowered pointer for a matrix pointer, or nullptr when not lowerable here.
  Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
  // Bitcasts between matrix and vector forms through a temporary alloca.
  Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
  // Redirects all users of a matrix instruction to its lowered vector value.
  void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
  // Redirects all users of a matrix variable (alloca/global) to its lowered pointer.
  void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
  // Recursive worker tracking the GEP index chain from the root pointer.
  void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
  Value *translateScalarMatMul(Value *scalar, Value *mat, IRBuilder<> &Builder, bool isLhsScalar = true);

  // Lowering entry points for the various matrix producers/consumers.
  void lowerGlobal(GlobalVariable *Global);
  Constant *lowerConstInitVal(Constant *Val);
  AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
  void lowerInstruction(Instruction* Inst);
  void lowerReturn(ReturnInst* Return);
  Value *lowerCall(CallInst *Call);
  Value *lowerNonHLCall(CallInst *Call);
  void lowerPreciseCall(CallInst *Call, IRBuilder<> Builder);
  Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
  Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
  Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
  Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
  Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
  Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
  Value *lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
  Value *lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
  Value *lowerHLCast(CallInst *Call, Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
  Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
  Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
  Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
  void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
  Value *lowerHLInit(CallInst *Call);
  Value *lowerHLSelect(CallInst *Call);

private:
  Module *m_pModule;     // Module currently being processed (valid during runOnModule).
  HLModule *m_pHLModule; // HL metadata wrapper for m_pModule.
  bool m_HasDbgInfo;     // True when the module carries debug metadata to maintain.
  // Pools for the translation stubs
  TempOverloadPool *m_matToVecStubs = nullptr;
  TempOverloadPool *m_vecToMatStubs = nullptr;
  std::vector<Instruction *> m_deadInsts; // Instructions pending deletion.
};
  161. }
char HLMatrixLowerPass::ID = 0;

// Factory used by the LLVM pass manager to instantiate this pass.
ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }

INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  165. bool HLMatrixLowerPass::runOnModule(Module &M) {
  166. TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
  167. TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
  168. m_pModule = &M;
  169. m_pHLModule = &m_pModule->GetOrCreateHLModule();
  170. // Load up debug information, to cross-reference values and the instructions
  171. // used to load them.
  172. m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  173. m_matToVecStubs = &matToVecStubs;
  174. m_vecToMatStubs = &vecToMatStubs;
  175. // First, lower static global variables.
  176. // We need to accumulate them locally because we'll be creating new ones as we lower them.
  177. std::vector<GlobalVariable*> Globals;
  178. for (GlobalVariable &Global : M.globals()) {
  179. if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
  180. && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
  181. Globals.emplace_back(&Global);
  182. }
  183. }
  184. for (GlobalVariable *Global : Globals)
  185. lowerGlobal(Global);
  186. for (Function &F : M.functions()) {
  187. if (F.isDeclaration()) continue;
  188. runOnFunction(F);
  189. }
  190. m_pModule = nullptr;
  191. m_pHLModule = nullptr;
  192. m_matToVecStubs = nullptr;
  193. m_vecToMatStubs = nullptr;
  194. // If you hit an assert during TempOverloadPool destruction,
  195. // it means that either a matrix producer was lowered,
  196. // causing a translation stub to be created,
  197. // but the consumer of that matrix was never (properly) lowered.
  198. // Or the opposite: a matrix consumer was lowered and not its producer.
  199. return true;
  200. }
  201. void HLMatrixLowerPass::runOnFunction(Function &Func) {
  202. // Skip hl function definition (like createhandle)
  203. if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
  204. return;
  205. // Save the matrix instructions first since the translation process
  206. // will temporarily create other instructions consuming/producing matrix types.
  207. std::vector<AllocaInst*> MatAllocas;
  208. std::vector<Instruction*> MatInsts;
  209. getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
  210. // First lower all allocas and take care of their GEP chains
  211. for (AllocaInst* MatAlloca : MatAllocas) {
  212. AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
  213. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  214. addToDeadInsts(MatAlloca);
  215. }
  216. // Now lower all other matrix instructions
  217. for (Instruction *MatInst : MatInsts)
  218. lowerInstruction(MatInst);
  219. deleteDeadInsts();
  220. }
// Drains m_deadInsts, erasing each queued instruction and recursively
// queueing any operand instruction that thereby loses its last user.
// This recursion is what ultimately removes the translation stubs.
void HLMatrixLowerPass::deleteDeadInsts() {
  while (!m_deadInsts.empty()) {
    Instruction *Inst = m_deadInsts.back();
    m_deadInsts.pop_back();
    DXASSERT_NOMSG(Inst->use_empty());
    for (Value *Operand : Inst->operand_values()) {
      Instruction *OperandInst = dyn_cast<Instruction>(Operand);
      // "++user_begin() == user_end()" tests for exactly one user
      // without walking the whole use list.
      if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
        // We were its only user, erase recursively.
        // This will get rid of translation stubs:
        // Original: MatConsumer(MatProducer)
        // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
        // Consumer lowered: VecConsumer(VecProducer)), MatConsumer(VecToMat) dead
        // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
        DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
        m_deadInsts.emplace_back(OperandInst);
      }
    }
    Inst->eraseFromParent();
  }
}
// Find all instructions consuming or producing matrices,
// directly or through pointers/arrays.
// Allocas go to MatAllocas (lowered first); calls and returns go to MatInsts.
void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
    std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
  for (BasicBlock &BasicBlock : Func) {
    for (Instruction &Inst : BasicBlock) {
      // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
      // typically a global variable or alloca.
      if (isa<GetElementPtrInst>(&Inst)) continue;

      // Matrix (or matrix array) allocas are collected separately.
      if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
          MatAllocas.emplace_back(Alloca);
        }
        continue;
      }

      if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
        // Lowering of global variables will have introduced
        // vec-to-mat translation stubs, which we deal with indirectly,
        // as we lower the instructions consuming them.
        if (m_vecToMatStubs->contains(Call->getCalledFunction()))
          continue;

        // Mat-to-vec stubs should only be introduced during instruction lowering.
        // Globals lowering won't introduce any because their only operand is
        // their initializer, which we can fully lower without stubbing since it is constant.
        DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
          "Unexpected mat-to-vec stubbing before function instruction lowering.");

        // Match matrix producers
        if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
          MatInsts.emplace_back(Call);
          continue;
        }

        // Match matrix consumers
        for (Value *Operand : Inst.operand_values()) {
          if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
            MatInsts.emplace_back(Call);
            break;
          }
        }
        continue;
      }

      // A return of a matrix value also requires lowering.
      if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
        Value *ReturnValue = Return->getReturnValue();
        if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
          MatInsts.emplace_back(Return);
        continue;
      }

      // Nothing else should produce or consume matrices
    }
  }
}
// Gets the matrix-lowered representation of a value, potentially adding a translation stub.
// DiscardStub causes any vec-to-mat translation stubs to be deleted,
// it should be true only if the original instruction will be modified and kept alive.
// If a new instruction is created and the original marked as dead,
// then the remove dead instructions pass will take care of removing the stub.
Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
  Type *Ty = Val->getType();

  // We're only lowering byval matrices.
  // Since structs and arrays are always accessed by pointer,
  // we do not need to worry about a matrix being hidden inside a more complex type.
  DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
  HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
  // Non-matrix values pass through unchanged.
  if (!MatTy) return Val;

  Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();

  // Check if the value is already a vec-to-mat translation stub
  if (CallInst *Call = dyn_cast<CallInst>(Val)) {
    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
      // Only discard a single-use stub; other users may still need it.
      if (DiscardStub && Call->getNumUses() == 1) {
        Call->use_begin()->set(UndefValue::get(Call->getType()));
        addToDeadInsts(Call);
      }
      Value *LoweredVal = Call->getArgOperand(0);
      DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
      return LoweredVal;
    }
  }

  // Return a mat-to-vec translation stub
  FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
  Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
  return Builder.CreateCall(TranslationStub, { Val });
}
// Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
// Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
// for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
  if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
    return nullptr;

  // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
  // The first two cases are what this pass must be able to lower, and we should already
  // have replaced their uses by vector to matrix pointer translation stubs.
  if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
    if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
      // Only discard a single-use stub; other users may still need it.
      if (DiscardStub && Call->getNumUses() == 1) {
        Call->use_begin()->set(UndefValue::get(Call->getType()));
        addToDeadInsts(Call);
      }
      return Call->getArgOperand(0);
    }
  }

  // There's one more case to handle.
  // When compiling shader libraries, signatures won't have been lowered yet.
  // So we can have a matrix in a struct as an argument,
  // or an alloca'd struct holding the return value of a call and containing a matrix.
  Value *RootPtr = Ptr;
  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
    RootPtr = GEP->getPointerOperand();

  Argument *Arg = dyn_cast<Argument>(RootPtr);
  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
  if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
    // Bitcast the matrix pointer to its lowered equivalent.
    // The HLMatrixBitcast pass will take care of this later.
    return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
  }

  // The pointer must be derived from a resource, we don't handle it in this pass.
  return nullptr;
}
  358. // Bitcasts a value from matrix to vector or vice-versa.
  359. // This is used to convert to/from arguments/return values since we don't
  360. // lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
  361. Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
  362. Type *SrcTy = SrcVal->getType();
  363. DXASSERT_NOMSG(!SrcTy->isPointerTy());
  364. // We store and load from a temporary alloca, bitcasting either on the store pointer
  365. // or on the load pointer.
  366. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  367. Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
  368. Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
  369. Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
  370. return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
  371. }
// Replaces all uses of a matrix value by its lowered vector form,
// inserting translation stubs for users which still expect a matrix value.
void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
  if (VecVal == nullptr || VecVal == MatInst) return;
  DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
    "Unexpected lowered value type.");
  Instruction *VecToMatStub = nullptr;

  // Detach one use per iteration until the use list is empty.
  while (!MatInst->use_empty()) {
    Use &ValUse = *MatInst->use_begin();

    // Handle non-matrix cases, just point to the new value.
    if (MatInst->getType() == VecVal->getType()) {
      ValUse.set(VecVal);
      continue;
    }

    // If the user is already a matrix-to-vector translation stub,
    // we can now replace it by the proper vector value.
    if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
      if (m_matToVecStubs->contains(Call->getCalledFunction())) {
        Call->replaceAllUsesWith(VecVal);
        ValUse.set(UndefValue::get(MatInst->getType()));
        addToDeadInsts(Call);
        continue;
      }
    }

    // Otherwise, the user should point to a vector-to-matrix translation
    // stub of the new vector value.
    if (VecToMatStub == nullptr) {
      FunctionType *TranslationStubTy = FunctionType::get(
        MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
      Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
      // Insert the stub right after VecVal's definition (or after MatInst
      // when VecVal is not an instruction), past any leading allocas.
      Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
      if (PrevInst == nullptr) PrevInst = MatInst;
      IRBuilder<> Builder(dxilutil::SkipAllocas(PrevInst->getNextNode()));
      VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
    }
    ValUse.set(VecToMatStub);
  }
}
  410. // Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
  411. // This doesn't lower the users, but will insert a translation stub from the lowered value pointer
  412. // back to the matrix value pointer, and recreate any GEPs around the new pointer.
  413. // Before: User(GEP(MatrixArrayAlloca))
  414. // After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
  415. void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
  416. DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
  417. DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
  418. SmallVector<Value*, 4> GEPIdxStack;
  419. GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
  420. replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
  421. }
// Recursive worker for replaceAllVariableUses.
// GEPIdxStack accumulates the GEP indices from the root pointer down to
// StackTopPtr (always starting with a constant zero); LoweredPtr is the
// lowered root pointer on which equivalent GEPs are recreated.
void HLMatrixLowerPass::replaceAllVariableUses(
    SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
  while (!StackTopPtr->use_empty()) {
    llvm::Use &Use = *StackTopPtr->use_begin();
    if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
      DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
      DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
      // Recurse in GEP to find actual users
      for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
        GEPIdxStack.emplace_back(*It);
      replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
      // Pop the indices this GEP contributed before handling sibling uses.
      GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());

      // Discard the GEP
      DXASSERT_NOMSG(GEP->use_empty());
      if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
        Use.set(UndefValue::get(Use->getType()));
        addToDeadInsts(GEPInst);
      } else {
        // constant GEP
        cast<Constant>(GEP)->destroyConstant();
      }
      continue;
    }
    // Constant addrspace casts: recurse through, then destroy the constant.
    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
      DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast,
        "Unexpected constant user");
      replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
      DXASSERT_NOMSG(CE->use_empty());
      CE->destroyConstant();
      continue;
    }
    // Addrspace cast instructions: recurse through, then mark dead.
    if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
      replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
      Use.set(UndefValue::get(Use->getType()));
      addToDeadInsts(CI);
      continue;
    }

    // Recreate the same GEP sequence, if any, on the lowered pointer
    IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
    Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
      ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);

    // Generate a stub translating the vector pointer back to a matrix pointer,
    // such that consuming instructions are unaffected.
    FunctionType *TranslationStubTy = FunctionType::get(
      StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
    Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
    Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
  }
}
  471. void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
  472. if (Global->user_empty()) return;
  473. PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
  474. DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
  475. Constant *LoweredInitVal = Global->hasInitializer()
  476. ? lowerConstInitVal(Global->getInitializer()) : nullptr;
  477. GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
  478. Global->isConstant(), Global->getLinkage(), LoweredInitVal,
  479. Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
  480. Global->getType()->getAddressSpace());
  481. // Add debug info.
  482. if (m_HasDbgInfo) {
  483. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  484. HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
  485. }
  486. replaceAllVariableUses(Global, LoweredGlobal);
  487. Global->removeDeadConstantUsers();
  488. Global->eraseFromParent();
  489. }
  490. Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
  491. Type *Ty = Val->getType();
  492. // If it's an array of matrices, recurse for each element or nested array
  493. if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
  494. SmallVector<Constant*, 4> LoweredElems;
  495. unsigned NumElems = ArrayTy->getNumElements();
  496. LoweredElems.reserve(NumElems);
  497. for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  498. Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
  499. LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
  500. }
  501. Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType(), /*MemRepr*/true);
  502. ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems);
  503. return ConstantArray::get(LoweredArrayTy, LoweredElems);
  504. }
  505. // Otherwise it's a matrix, lower it to a vector
  506. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  507. DXASSERT_NOMSG(isa<StructType>(Ty));
  508. Constant *RowArrayVal = Val->getAggregateElement((unsigned)0);
  509. // Original initializer should have been produced in row/column-major order
  510. // depending on the qualifiers of the target variable, so preserve the order.
  511. SmallVector<Constant*, 16> MatElems;
  512. for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
  513. Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
  514. for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
  515. MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
  516. }
  517. }
  518. Constant *Vec = ConstantVector::get(MatElems);
  519. // Matrix elements are always in register representation,
  520. // but the lowered global variable is of vector type in
  521. // its memory representation, so we must convert here.
  522. // This will produce a constant so we can use an IRBuilder without a valid insertion point.
  523. IRBuilder<> DummyBuilder(Val->getContext());
  524. return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
  525. }
  526. AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
  527. PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType()));
  528. IRBuilder<> Builder(MatAlloca);
  529. AllocaInst *LoweredAlloca = Builder.CreateAlloca(
  530. LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
  531. // Update debug info.
  532. if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
  533. LLVMContext &Context = MatAlloca->getContext();
  534. Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
  535. Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
  536. Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
  537. IRBuilder<> DebugBuilder(DbgDeclare);
  538. DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
  539. }
  540. if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
  541. HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
  542. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  543. return LoweredAlloca;
  544. }
  545. void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) {
  546. if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
  547. Value *LoweredValue = lowerCall(Call);
  548. // lowerCall returns the lowered value iff we should discard
  549. // the original matrix instruction and replace all of its uses
  550. // by the lowered value. It returns nullptr to opt-out of this.
  551. if (LoweredValue != nullptr) {
  552. replaceAllUsesByLoweredValue(Call, LoweredValue);
  553. addToDeadInsts(Inst);
  554. }
  555. }
  556. else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
  557. lowerReturn(Return);
  558. }
  559. else
  560. llvm_unreachable("Unexpected matrix instruction type.");
  561. }
  562. void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) {
  563. Value *RetVal = Return->getReturnValue();
  564. Type *RetTy = RetVal->getType();
  565. DXASSERT_LOCALVAR(RetTy, !RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
  566. IRBuilder<> Builder(Return);
  567. Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
  568. // Since we're not lowering the signature, we can't return the lowered value directly,
  569. // so insert a bitcast, which HLMatrixBitcastLower knows how to eliminate.
  570. Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
  571. Return->setOperand(0, BitCastedRetVal);
  572. }
  573. Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
  574. HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
  575. return OpcodeGroup == HLOpcodeGroup::NotHL
  576. ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup);
  577. }
  578. // Special function to lower precise call applied to a matrix
  579. // The matrix should be lowered and the call regenerated with vector arg
  580. void HLMatrixLowerPass::lowerPreciseCall(CallInst *Call, IRBuilder<> Builder) {
  581. DXASSERT(Call->getNumArgOperands() == 1, "Only one arg expected for precise matrix call");
  582. Value *Arg = Call->getArgOperand(0);
  583. Value *LoweredArg = getLoweredByValOperand(Arg, Builder);
  584. HLModule::MarkPreciseAttributeOnValWithFunctionCall(LoweredArg, Builder, *m_pModule);
  585. addToDeadInsts(Call);
  586. }
// Lowers a call to a user (non-HL) function that involves matrix values.
// The callee's signature is NOT rewritten by this pass, so lowered vectors
// are bitcast back to matrix types around the call; the later
// HLMatrixBitcastLower pass eliminates those casts.
// Returns nullptr in all paths: either there was nothing to replace, or this
// function performed its own use replacement and the call must stay alive.
Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
  // First, handle any operand of matrix-derived type
  // We don't lower the callee's signature in this pass,
  // so, for any matrix-typed parameter, we create a bitcast from the
  // lowered vector back to the matrix type, which the later HLMatrixBitcastLower
  // pass knows how to eliminate.
  IRBuilder<> PreCallBuilder(Call);
  unsigned NumArgs = Call->getNumArgOperands();
  Function *Func = Call->getCalledFunction();
  // Precise markers get dedicated handling: regenerate the marker call on the
  // lowered vector and mark the original call as dead.
  if (Func && HLModule::HasPreciseAttribute(Func)) {
    lowerPreciseCall(Call, PreCallBuilder);
    return nullptr;
  }
  for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
    Use &ArgUse = Call->getArgOperandUse(ArgIdx);
    if (ArgUse->getType()->isPointerTy()) {
      // Byref arg
      Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
      if (LoweredArg != nullptr) {
        // Pointer to a matrix we've lowered, insert a bitcast back to matrix pointer type.
        Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
        ArgUse.set(BitCastedArg);
      }
    }
    else {
      // Byvalue arg
      Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
      // Lowering was a no-op (non-matrix arg): leave the operand untouched.
      if (LoweredArg == ArgUse.get()) continue;
      Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
      ArgUse.set(BitCastedArg);
    }
  }
  // Now check the return type
  HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
  if (!RetMatTy) {
    DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
      "Unexpected user call returning a matrix by pointer.");
    // Nothing to replace, other instructions can consume a non-matrix return type.
    return nullptr;
  }
  // The callee returns a matrix, and we don't lower signatures in this pass.
  // We perform a sketchy bitcast to the lowered register-representation type,
  // which the later HLMatrixBitcastLower pass knows how to eliminate.
  IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
  Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
  IRBuilder<> PostCallBuilder(Call->getNextNode());
  Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
  // This is slightly tricky
  // We want to replace all uses of the matrix-returning call by the bitcasted value,
  // but the store to the bitcasted pointer itself is a use of that matrix,
  // so we need to create the load, replace the uses, and then insert the store.
  LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
  replaceAllUsesByLoweredValue(Call, LoweredVal);
  // Now we can insert the store. Make sure to do so before the load.
  PostCallBuilder.SetInsertPoint(LoweredVal);
  PostCallBuilder.CreateStore(Call, BitCastedAlloca);
  // Return nullptr since we did our own uses replacement and we don't want
  // the matrix instruction to be marked as dead since we're still using it.
  return nullptr;
}
  647. Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) {
  648. IRBuilder<> Builder(Call);
  649. switch (OpcodeGroup) {
  650. case HLOpcodeGroup::HLIntrinsic:
  651. return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
  652. case HLOpcodeGroup::HLBinOp:
  653. return lowerHLBinaryOperation(
  654. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  655. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  656. static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
  657. case HLOpcodeGroup::HLUnOp:
  658. return lowerHLUnaryOperation(
  659. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
  660. static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
  661. case HLOpcodeGroup::HLMatLoadStore:
  662. return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
  663. case HLOpcodeGroup::HLCast:
  664. return lowerHLCast(Call,
  665. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
  666. static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
  667. case HLOpcodeGroup::HLSubscript:
  668. return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
  669. case HLOpcodeGroup::HLInit:
  670. return lowerHLInit(Call);
  671. case HLOpcodeGroup::HLSelect:
  672. return lowerHLSelect(Call);
  673. default:
  674. llvm_unreachable("Unexpected matrix opcode");
  675. }
  676. }
  677. Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
  678. IRBuilder<> Builder(Call);
  679. // See if this is a matrix-specific intrinsic which we should expand here
  680. switch (Opcode) {
  681. case IntrinsicOp::IOP_umul:
  682. case IntrinsicOp::IOP_mul:
  683. return lowerHLMulIntrinsic(
  684. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  685. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  686. /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
  687. case IntrinsicOp::IOP_transpose:
  688. return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  689. case IntrinsicOp::IOP_determinant:
  690. return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  691. }
  692. // Delegate to a lowered intrinsic call
  693. SmallVector<Value*, 4> LoweredArgs;
  694. LoweredArgs.reserve(Call->getNumArgOperands());
  695. for (Value *Arg : Call->arg_operands()) {
  696. if (Arg->getType()->isPointerTy()) {
  697. // ByRef parameter (for example, frexp's second parameter)
  698. // If the argument points to a lowered matrix variable, replace it here,
  699. // otherwise preserve the matrix type and let further passes handle the lowering.
  700. Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
  701. if (LoweredArg == nullptr) LoweredArg = Arg;
  702. LoweredArgs.emplace_back(LoweredArg);
  703. }
  704. else {
  705. LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
  706. }
  707. }
  708. Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
  709. return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
  710. LoweredRetTy, LoweredArgs,
  711. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  712. }
  713. // Handles multiplcation of a scalar with a matrix
  714. Value *HLMatrixLowerPass::translateScalarMatMul(Value *Lhs, Value *Rhs, IRBuilder<> &Builder, bool isLhsScalar) {
  715. Value *Mat = isLhsScalar ? Rhs : Lhs;
  716. Value *Scalar = isLhsScalar ? Lhs : Rhs;
  717. Value* LoweredMat = getLoweredByValOperand(Mat, Builder);
  718. Type *ScalarTy = Scalar->getType();
  719. // Perform the scalar-matrix multiplication!
  720. Type *ElemTy = LoweredMat->getType()->getVectorElementType();
  721. bool isIntMulOp = ScalarTy->isIntegerTy() && ElemTy->isIntegerTy();
  722. bool isFloatMulOp = ScalarTy->isFloatingPointTy() && ElemTy->isFloatingPointTy();
  723. DXASSERT(ScalarTy == ElemTy, "Scalar type must match the matrix component type.");
  724. Value *Result = Builder.CreateVectorSplat(LoweredMat->getType()->getVectorNumElements(), Scalar);
  725. if (isFloatMulOp) {
  726. // Preserve the order of operation for floats
  727. Result = isLhsScalar ? Builder.CreateFMul(Result, LoweredMat) : Builder.CreateFMul(LoweredMat, Result);
  728. }
  729. else if (isIntMulOp) {
  730. // Doesn't matter for integers but still preserve the order of operation
  731. Result = isLhsScalar ? Builder.CreateMul(Result, LoweredMat) : Builder.CreateMul(LoweredMat, Result);
  732. }
  733. else {
  734. DXASSERT(0, "Unknown type encountered when doing scalar-matrix multiplication.");
  735. }
  736. return Result;
  737. }
// Lowers mul()/umul() involving at least one matrix operand. The product is
// expanded over the lowered row-major vectors: the first partial product of
// each result element is a plain mul and subsequent terms are accumulated
// through the HL mad/umad intrinsic. Vector operands are treated as 1xN (lhs)
// or Nx1 (rhs) matrices; scalar-matrix products are delegated to
// translateScalarMatMul.
Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
    bool Unsigned, IRBuilder<> &Builder) {
  HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
  HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
  Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);

  // Translate multiplication of scalar with matrix
  bool isLhsScalar = !LoweredLhs->getType()->isVectorTy();
  bool isRhsScalar = !LoweredRhs->getType()->isVectorTy();
  bool isScalar = isLhsScalar || isRhsScalar;
  if (isScalar)
    return translateScalarMatMul(Lhs, Rhs, Builder, isLhsScalar);

  DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
    "Unexpected element type mismatch in mul intrinsic.");
  DXASSERT(cast<VectorType>(LoweredLhs->getType()) && cast<VectorType>(LoweredRhs->getType()),
    "Unexpected scalar in lowered matrix mul intrinsic operands.");
  Type* ElemTy = LoweredLhs->getType()->getScalarType();

  // Figure out the dimensions of each side.
  // A non-matrix (vector) operand acts as a single-row or single-column matrix.
  unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
  if (LhsMatTy && RhsMatTy) {
    LhsNumRows = LhsMatTy.getNumRows();
    LhsNumCols = LhsMatTy.getNumColumns();
    RhsNumRows = RhsMatTy.getNumRows();
    RhsNumCols = RhsMatTy.getNumColumns();
  }
  else if (LhsMatTy) {
    LhsNumRows = LhsMatTy.getNumRows();
    LhsNumCols = LhsMatTy.getNumColumns();
    RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
    RhsNumCols = 1;
  }
  else if (RhsMatTy) {
    LhsNumRows = 1;
    LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
    RhsNumRows = RhsMatTy.getNumRows();
    RhsNumCols = RhsMatTy.getNumColumns();
  }
  else {
    llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
  }

  DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
  HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
  unsigned AccCount = LhsNumCols; // number of terms summed per result element

  // Get the multiply-and-add intrinsic function, we'll need it
  IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
  FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
  Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
  Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);

  // Perform the multiplication!
  Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
  for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
    for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
      unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
      Value *ResultElem = nullptr;
      for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
        unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
        unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
        Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
        Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
        if (ResultElem == nullptr) {
          // First term: a plain (f)mul.
          ResultElem = ElemTy->isFloatingPointTy()
            ? Builder.CreateFMul(LhsElem, RhsElem)
            : Builder.CreateMul(LhsElem, RhsElem);
        }
        else {
          // Subsequent terms: fold into a mad/umad HL call.
          ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
        }
      }
      Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
    }
  }
  return Result;
}
  811. Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  812. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  813. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  814. return MatTy.emitLoweredVectorRowToCol(LoweredVal, Builder);
  815. }
  816. static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
  817. Value *Mul0 = Builder.CreateFMul(M00, M11);
  818. Value *Mul1 = Builder.CreateFMul(M01, M10);
  819. return Builder.CreateFSub(Mul0, Mul1);
  820. }
  821. static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
  822. Value *M10, Value *M11, Value *M12,
  823. Value *M20, Value *M21, Value *M22,
  824. IRBuilder<> &Builder) {
  825. Value *Det00 = determinant2x2(M11, M12, M21, M22, Builder);
  826. Value *Det01 = determinant2x2(M10, M12, M20, M22, Builder);
  827. Value *Det02 = determinant2x2(M10, M11, M20, M21, Builder);
  828. Det00 = Builder.CreateFMul(M00, Det00);
  829. Det01 = Builder.CreateFMul(M01, Det01);
  830. Det02 = Builder.CreateFMul(M02, Det02);
  831. Value *Result = Builder.CreateFSub(Det00, Det01);
  832. Result = Builder.CreateFAdd(Result, Det02);
  833. return Result;
  834. }
  835. static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
  836. Value *M10, Value *M11, Value *M12, Value *M13,
  837. Value *M20, Value *M21, Value *M22, Value *M23,
  838. Value *M30, Value *M31, Value *M32, Value *M33,
  839. IRBuilder<> &Builder) {
  840. Value *Det00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
  841. Value *Det01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
  842. Value *Det02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
  843. Value *Det03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
  844. Det00 = Builder.CreateFMul(M00, Det00);
  845. Det01 = Builder.CreateFMul(M01, Det01);
  846. Det02 = Builder.CreateFMul(M02, Det02);
  847. Det03 = Builder.CreateFMul(M03, Det03);
  848. Value *Result = Builder.CreateFSub(Det00, Det01);
  849. Result = Builder.CreateFAdd(Result, Det02);
  850. Result = Builder.CreateFSub(Result, Det03);
  851. return Result;
  852. }
  853. Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  854. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  855. DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
  856. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  857. // Extract all matrix elements
  858. SmallVector<Value*, 16> Elems;
  859. for (unsigned ElemIdx = 0; ElemIdx < MatTy.getNumElements(); ++ElemIdx)
  860. Elems.emplace_back(Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(ElemIdx)));
  861. // Delegate to appropriate determinant function
  862. switch (MatTy.getNumColumns()) {
  863. case 1:
  864. return Elems[0];
  865. case 2:
  866. return determinant2x2(
  867. Elems[0], Elems[1],
  868. Elems[2], Elems[3],
  869. Builder);
  870. case 3:
  871. return determinant3x3(
  872. Elems[0], Elems[1], Elems[2],
  873. Elems[3], Elems[4], Elems[5],
  874. Elems[6], Elems[7], Elems[8],
  875. Builder);
  876. case 4:
  877. return determinant4x4(
  878. Elems[0], Elems[1], Elems[2], Elems[3],
  879. Elems[4], Elems[5], Elems[6], Elems[7],
  880. Elems[8], Elems[9], Elems[10], Elems[11],
  881. Elems[12], Elems[13], Elems[14], Elems[15],
  882. Builder);
  883. default:
  884. llvm_unreachable("Unexpected matrix dimensions.");
  885. }
  886. }
  887. Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
  888. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  889. VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
  890. bool IsFloat = VecTy->getElementType()->isFloatingPointTy();
  891. switch (Opcode) {
  892. case HLUnaryOpcode::Plus: return LoweredVal; // No-op
  893. case HLUnaryOpcode::Minus:
  894. return IsFloat
  895. ? Builder.CreateFSub(Constant::getNullValue(VecTy), LoweredVal)
  896. : Builder.CreateSub(Constant::getNullValue(VecTy), LoweredVal);
  897. case HLUnaryOpcode::LNot:
  898. return IsFloat
  899. ? Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Constant::getNullValue(VecTy))
  900. : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Constant::getNullValue(VecTy));
  901. case HLUnaryOpcode::Not:
  902. return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));
  903. case HLUnaryOpcode::PostInc:
  904. case HLUnaryOpcode::PreInc:
  905. case HLUnaryOpcode::PostDec:
  906. case HLUnaryOpcode::PreDec: {
  907. Constant *ScalarOne = IsFloat
  908. ? ConstantFP::get(VecTy->getElementType(), 1)
  909. : ConstantInt::get(VecTy->getElementType(), 1);
  910. Constant *VecOne = ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);
  911. // CodeGen already emitted the load and following store, our job is only to produce
  912. // the updated value.
  913. if (Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc) {
  914. return IsFloat
  915. ? Builder.CreateFAdd(LoweredVal, VecOne)
  916. : Builder.CreateAdd(LoweredVal, VecOne);
  917. }
  918. else {
  919. return IsFloat
  920. ? Builder.CreateFSub(LoweredVal, VecOne)
  921. : Builder.CreateSub(LoweredVal, VecOne);
  922. }
  923. }
  924. default:
  925. llvm_unreachable("Unsupported unary matrix operator");
  926. }
  927. }
// Lowers a matrix binary operator to the element-wise equivalent on the
// lowered vectors. The float or signed-integer variant of each LLVM
// instruction is chosen from the element type; the explicitly-unsigned HL
// opcodes (UDiv, URem, UShr, ULT, ...) map directly to the unsigned
// instructions/predicates.
Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
  Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);

  DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
    "Expected lowered binary operation operands to be vectors");
  DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
    "Expected lowered binary operation operands to have matching types.");
  bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();

  switch (Opcode) {
  case HLBinaryOpcode::Add:
    return IsFloat
      ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
      : Builder.CreateAdd(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Sub:
    return IsFloat
      ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
      : Builder.CreateSub(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Mul:
    return IsFloat
      ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
      : Builder.CreateMul(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Div:
    return IsFloat
      ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
      : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Rem:
    return IsFloat
      ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
      : Builder.CreateSRem(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::And:
    return Builder.CreateAnd(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Or:
    return Builder.CreateOr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Xor:
    return Builder.CreateXor(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Shl:
    return Builder.CreateShl(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::Shr:
    // Signed right shift is arithmetic; the unsigned case is UShr below.
    return Builder.CreateAShr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LT:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::GT:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LE:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::GE:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::EQ:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::NE:
    return IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UDiv:
    return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::URem:
    return Builder.CreateURem(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UShr:
    return Builder.CreateLShr(LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::ULT:
    return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UGT:
    return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::ULE:
    return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::UGE:
    return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
  case HLBinaryOpcode::LAnd:
  case HLBinaryOpcode::LOr: {
    // Logical and/or: compare each side against zero, then combine the
    // resulting boolean vectors bitwise.
    Value *Zero = Constant::getNullValue(LoweredLhs->getType());
    Value *LhsCmp = IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
    Value *RhsCmp = IsFloat
      ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
      : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
    return Opcode == HLBinaryOpcode::LOr
      ? Builder.CreateOr(LhsCmp, RhsCmp)
      : Builder.CreateAnd(LhsCmp, RhsCmp);
  }
  default:
    llvm_unreachable("Unsupported binary matrix operator");
  }
}
  1022. Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) {
  1023. IRBuilder<> Builder(Call);
  1024. switch (Opcode) {
  1025. case HLMatLoadStoreOpcode::RowMatLoad:
  1026. case HLMatLoadStoreOpcode::ColMatLoad:
  1027. return lowerHLLoad(Call, Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
  1028. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
  1029. case HLMatLoadStoreOpcode::RowMatStore:
  1030. case HLMatLoadStoreOpcode::ColMatStore:
  1031. return lowerHLStore(Call,
  1032. Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
  1033. Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
  1034. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
  1035. /* Return */ !Call->getType()->isVoidTy(), Builder);
  1036. default:
  1037. llvm_unreachable("Unsupported matrix load/store operation");
  1038. }
  1039. }
  1040. Value *HLMatrixLowerPass::lowerHLLoad(CallInst *Call, Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
  1041. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1042. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1043. if (LoweredPtr == nullptr) {
  1044. // Can't lower this here, defer to HL signature lower
  1045. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
  1046. return callHLFunction(
  1047. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1048. MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr },
  1049. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  1050. }
  1051. return MatTy.emitLoweredLoad(LoweredPtr, Builder);
  1052. }
  1053. Value *HLMatrixLowerPass::lowerHLStore(CallInst *Call, Value *MatVal, Value *MatPtr,
  1054. bool RowMajor, bool Return, IRBuilder<> &Builder) {
  1055. DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
  1056. "Matrix store value/pointer type mismatch.");
  1057. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1058. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  1059. if (LoweredPtr == nullptr) {
  1060. // Can't lower the pointer here, defer to HL signature lower
  1061. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
  1062. return callHLFunction(
  1063. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1064. Return ? LoweredVal->getType() : Builder.getVoidTy(),
  1065. { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal },
  1066. Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  1067. }
  1068. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1069. StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
  1070. // If the intrinsic returned a value, return the stored lowered value
  1071. return Return ? LoweredVal : LoweredStore;
  1072. }
  1073. static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) {
  1074. DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
  1075. "Scalar/vector type mismatch in numerical conversion.");
  1076. Type *SrcTy = SrcVal->getType();
  1077. // Conversions between equivalent types are no-ops,
  1078. // even between signed/unsigned variants.
  1079. if (SrcTy == DstTy) return SrcVal;
  1080. // Conversions to bools are comparisons
  1081. if (DstTy->getScalarSizeInBits() == 1) {
  1082. // fcmp une is what regular clang uses in C++ for (bool)f;
  1083. return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
  1084. ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
  1085. : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
  1086. }
  1087. // Cast necessary
  1088. bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
  1089. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1090. bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
  1091. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1092. auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
  1093. SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
  1094. return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
  1095. }
// Lowers an HLCast intrinsic where the source and/or destination is a matrix.
// Handles scalar->matrix splat, vector->matrix, matrix->scalar, matrix->vector,
// and matrix->matrix (orientation change / truncation) casts, then applies the
// element-level numeric conversion.
Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
    HLCastOpcode Opcode, IRBuilder<> &Builder) {
  // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
  DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");

  if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
    // Scalar to matrix splat
    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);

    // Apply element conversion
    Value *Result = convertScalarOrVector(Src,
      MatDstTy.getElementTypeForReg(), Opcode, Builder);

    // Splat to a vector: insert into a one-element vector, then shuffle with an
    // all-zero mask to replicate the element across the lowered vector width.
    Result = Builder.CreateInsertElement(
      UndefValue::get(VectorType::get(Result->getType(), 1)),
      Result, static_cast<uint64_t>(0));
    return Builder.CreateShuffleVector(Result, Result,
      ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
  }
  else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
    // Vector to matrix
    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
    Value *Result = Src;

    // We might need to truncate: keep only the first getNumElements() lanes
    // of the source vector.
    if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
      SmallVector<int, 4> ShuffleIndices;
      for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
        ShuffleIndices.emplace_back(static_cast<int>(Idx));
      Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
    }

    // Apply element conversion
    return convertScalarOrVector(Result,
      MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
  }

  // Source must now be a matrix
  HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
  VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();

  Value *LoweredSrc;
  if (isa<Argument>(Src)) {
    // Function arguments are lowered in HLSignatureLower.
    // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
    // Preserve them, but change the return type to vector.
    DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
      "Unexpected cast of matrix argument.");
    LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
      LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src },
      Call->getCalledFunction()->getAttributes().getFnAttributes(), Builder);
  }
  else {
    LoweredSrc = getLoweredByValOperand(Src, Builder);
  }
  DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);

  Value* Result = LoweredSrc;
  Type* LoweredDstTy = DstTy;
  if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
    // Matrix to scalar: take the first element of the lowered vector.
    Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
  }
  else if (DstTy->isVectorTy()) {
    // Matrix to vector
    VectorType *DstVecTy = cast<VectorType>(DstTy);
    DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
      "Cannot cast matrix to a larger vector.");

    // We might have to truncate
    if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
      SmallVector<int, 3> ShuffleIndices;
      for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
        ShuffleIndices.emplace_back(static_cast<int>(Idx));
      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
    }
  }
  else {
    // Destination must now be a matrix too
    HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);

    // Apply any changes at the matrix level: orientation changes and truncation
    if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
      Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
    else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
      Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
    else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
      || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
      // Apply truncation: select the destination-sized top-left submatrix of
      // the source, gathering elements in row-major lowered order.
      DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
        && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
        "Unexpected matrix cast between incompatible dimensions.");

      SmallVector<int, 16> ShuffleIndices;
      for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
        for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
          ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
      Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
    }

    LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
    DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
      "Unexpected matrix src/dst lowered element count mismatch after truncation.");
  }

  // Apply element conversion
  return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
}
  1192. Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
  1193. switch (Opcode) {
  1194. case HLSubscriptOpcode::RowMatElement:
  1195. case HLSubscriptOpcode::ColMatElement:
  1196. return lowerHLMatElementSubscript(Call,
  1197. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
  1198. case HLSubscriptOpcode::RowMatSubscript:
  1199. case HLSubscriptOpcode::ColMatSubscript:
  1200. return lowerHLMatSubscript(Call,
  1201. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
  1202. case HLSubscriptOpcode::DefaultSubscript:
  1203. case HLSubscriptOpcode::CBufferSubscript:
  1204. // Those get lowered during HLOperationLower,
  1205. // and the return type must stay unchanged (as a matrix)
  1206. // to provide the metadata to properly emit the loads.
  1207. return nullptr;
  1208. default:
  1209. llvm_unreachable("Unexpected matrix subscript opcode.");
  1210. }
  1211. }
  1212. Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) {
  1213. (void)RowMajor; // It doesn't look like we actually need this?
  1214. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1215. Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
  1216. VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
  1217. // Get the loaded lowered vector element indices
  1218. SmallVector<Value*, 4> ElemIndices;
  1219. ElemIndices.reserve(IdxVecTy->getNumElements());
  1220. for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
  1221. ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
  1222. }
  1223. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1224. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1225. return nullptr;
  1226. }
  1227. Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
  1228. (void)RowMajor; // It doesn't look like we actually need this?
  1229. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1230. // Gather the indices, checking if they are all constant
  1231. SmallVector<Value*, 4> ElemIndices;
  1232. for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
  1233. ElemIndices.emplace_back(Call->getArgOperand(Idx));
  1234. }
  1235. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1236. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1237. return nullptr;
  1238. }
// Replaces all uses of a matrix subscript intrinsic with accesses to the
// lowered vector pointer. ElemIndices holds one lowered-vector element index
// per subscripted component. The call itself is queued for deletion.
void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) {
  DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));

  IRBuilder<> CallBuilder(Call);
  Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
  // No lowered counterpart for this pointer; nothing to replace here.
  if (LoweredPtr == nullptr) return;

  // For global variables, we can GEP directly into the lowered vector pointer.
  // This is necessary to support group shared memory atomics and the likes.
  // Strip any chain of GEPs to find the root pointer being addressed.
  Value *RootPtr = LoweredPtr;
  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
    RootPtr = GEP->getPointerOperand();
  bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);

  // Just constructing this does all the work
  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);

  DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
  addToDeadInsts(Call);
}
  1255. Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
  1256. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1257. // Figure out the result type
  1258. HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
  1259. VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
  1260. // Handle case where produced by EmitHLSLFlatConversion where there's one
  1261. // vector argument, instead of scalar arguments.
  1262. if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
  1263. Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
  1264. getType()->isVectorTy()) {
  1265. Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
  1266. DXASSERT(LoweredTy->getNumElements() ==
  1267. LoweredVec->getType()->getVectorNumElements(),
  1268. "Invalid matrix init argument vector element count.");
  1269. return LoweredVec;
  1270. }
  1271. DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
  1272. "Invalid matrix init argument count.");
  1273. // Build the result vector from the init args.
  1274. // Both the args and the result vector are in row-major order, so no shuffling is necessary.
  1275. IRBuilder<> Builder(Call);
  1276. Value *LoweredVec = UndefValue::get(LoweredTy);
  1277. for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
  1278. Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
  1279. DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
  1280. "Expected only scalars in matrix initialization.");
  1281. LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
  1282. }
  1283. return LoweredVec;
  1284. }
  1285. Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
  1286. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1287. Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  1288. Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
  1289. Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
  1290. DXASSERT(TrueMat->getType() == FalseMat->getType(),
  1291. "Unexpected type mismatch between matrix ternary operator values.");
  1292. #ifndef NDEBUG
  1293. // Assert that if the condition is a matrix, it matches the dimensions of the values
  1294. if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
  1295. HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
  1296. DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
  1297. && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
  1298. "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
  1299. }
  1300. #endif
  1301. IRBuilder<> Builder(Call);
  1302. Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
  1303. Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
  1304. Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
  1305. Value *Result = UndefValue::get(LoweredTrueVec->getType());
  1306. bool IsScalarCond = !LoweredCond->getType()->isVectorTy();
  1307. unsigned NumElems = Result->getType()->getVectorNumElements();
  1308. for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  1309. Value *ElemCond = IsScalarCond ? LoweredCond
  1310. : Builder.CreateExtractElement(LoweredCond, ElemIdx);
  1311. Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
  1312. Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
  1313. Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
  1314. Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
  1315. }
  1316. return Result;
  1317. }