HLMatrixLowerPass.cpp 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixLowerPass.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // HLMatrixLowerPass implementation. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLMatrixLowerPass.h"
  12. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  13. #include "dxc/HLSL/HLMatrixType.h"
  14. #include "dxc/HLSL/HLOperations.h"
  15. #include "dxc/HLSL/HLModule.h"
  16. #include "dxc/DXIL/DxilUtil.h"
  17. #include "dxc/HlslIntrinsicOp.h"
  18. #include "dxc/Support/Global.h"
  19. #include "dxc/DXIL/DxilOperations.h"
  20. #include "dxc/DXIL/DxilTypeSystem.h"
  21. #include "dxc/DXIL/DxilModule.h"
  22. #include "HLMatrixSubscriptUseReplacer.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/Module.h"
  25. #include "llvm/IR/DebugInfo.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/Transforms/Utils/Local.h"
  28. #include "llvm/Pass.h"
  29. #include "llvm/Support/raw_ostream.h"
  30. #include <unordered_set>
  31. #include <vector>
  32. using namespace llvm;
  33. using namespace hlsl;
  34. using namespace hlsl::HLMatrixLower;
  35. namespace hlsl {
  36. namespace HLMatrixLower {
  37. Value *BuildVector(Type *EltTy, ArrayRef<llvm::Value *> elts, IRBuilder<> &Builder) {
  38. Value *Vec = UndefValue::get(VectorType::get(EltTy, static_cast<unsigned>(elts.size())));
  39. for (unsigned i = 0; i < elts.size(); i++)
  40. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  41. return Vec;
  42. }
  43. } // namespace HLMatrixLower
  44. } // namespace hlsl
  45. namespace {
  46. // Creates and manages a set of temporary overloaded functions keyed on the function type,
  47. // and which should be destroyed when the pool gets out of scope.
  48. class TempOverloadPool {
  49. public:
  50. TempOverloadPool(llvm::Module &Module, const char* BaseName)
  51. : Module(Module), BaseName(BaseName) {}
  52. ~TempOverloadPool() { clear(); }
  53. Function *get(FunctionType *Ty);
  54. bool contains(FunctionType *Ty) const { return Funcs.count(Ty) != 0; }
  55. bool contains(Function *Func) const;
  56. void clear();
  57. private:
  58. llvm::Module &Module;
  59. const char* BaseName;
  60. llvm::DenseMap<FunctionType*, Function*> Funcs;
  61. };
  62. Function *TempOverloadPool::get(FunctionType *Ty) {
  63. auto It = Funcs.find(Ty);
  64. if (It != Funcs.end()) return It->second;
  65. std::string MangledName;
  66. raw_string_ostream MangledNameStream(MangledName);
  67. MangledNameStream << BaseName;
  68. MangledNameStream << '.';
  69. Ty->print(MangledNameStream);
  70. MangledNameStream.flush();
  71. Function* Func = cast<Function>(Module.getOrInsertFunction(MangledName, Ty));
  72. Funcs.insert(std::make_pair(Ty, Func));
  73. return Func;
  74. }
  75. bool TempOverloadPool::contains(Function *Func) const {
  76. auto It = Funcs.find(Func->getFunctionType());
  77. return It != Funcs.end() && It->second == Func;
  78. }
  79. void TempOverloadPool::clear() {
  80. for (auto Entry : Funcs) {
  81. DXASSERT(Entry.second->use_empty(), "Temporary function still used during pool destruction.");
  82. Entry.second->removeFromParent();
  83. }
  84. Funcs.clear();
  85. }
  86. // High-level matrix lowering pass.
  87. //
  88. // This pass converts matrices to their lowered vector representations,
  89. // including global variables, local variables and operations,
  90. // but not function signatures (arguments and return types) - left to HLSignatureLower and HLMatrixBitcastLower,
  91. // nor matrices obtained from resources or constant - left to HLOperationLower.
  92. //
  93. // Algorithm overview:
  94. // 1. Find all matrix and matrix array global variables and lower them to vectors.
  95. // Walk any GEPs and insert vec-to-mat translation stubs so that consuming
  96. // instructions keep dealing with matrix types for the moment.
  97. // 2. For each function
  98. // 2a. Lower all matrix and matrix array allocas, just like global variables.
  99. // 2b. Lower all other instructions producing or consuming matrices
  100. //
  101. // Conversion stubs are used to allow converting instructions in isolation,
  102. // and in an order-independent manner:
  103. //
  104. // Initial: MatInst1(MatInst2(MatInst3))
  105. // After lowering MatInst2: MatInst1(VecToMat(VecInst2(MatToVec(MatInst3))))
  106. // After lowering MatInst1: VecInst1(VecInst2(MatToVec(MatInst3)))
  107. // After lowering MatInst3: VecInst1(VecInst2(VecInst3))
  108. class HLMatrixLowerPass : public ModulePass {
  109. public:
  110. static char ID; // Pass identification, replacement for typeid
  111. explicit HLMatrixLowerPass() : ModulePass(ID) {}
  112. const char *getPassName() const override { return "HL matrix lower"; }
  113. bool runOnModule(Module &M) override;
  114. private:
  115. void runOnFunction(Function &Func);
  116. void addToDeadInsts(Instruction *Inst) { m_deadInsts.emplace_back(Inst); }
  117. void deleteDeadInsts();
  118. void getMatrixAllocasAndOtherInsts(Function &Func,
  119. std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts);
  120. Value *getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub = false);
  121. Value *tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub = false);
  122. Value *bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder);
  123. void replaceAllUsesByLoweredValue(Instruction *MatInst, Value *VecVal);
  124. void replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr);
  125. void replaceAllVariableUses(SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr);
  126. void lowerGlobal(GlobalVariable *Global);
  127. Constant *lowerConstInitVal(Constant *Val);
  128. AllocaInst *lowerAlloca(AllocaInst *MatAlloca);
  129. void lowerInstruction(Instruction* Inst);
  130. void lowerReturn(ReturnInst* Return);
  131. Value *lowerCall(CallInst *Call);
  132. Value *lowerNonHLCall(CallInst *Call);
  133. Value *lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup);
  134. Value *lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode);
  135. Value *lowerHLMulIntrinsic(Value* Lhs, Value *Rhs, bool Unsigned, IRBuilder<> &Builder);
  136. Value *lowerHLTransposeIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  137. Value *lowerHLDeterminantIntrinsic(Value *MatVal, IRBuilder<> &Builder);
  138. Value *lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder);
  139. Value *lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder);
  140. Value *lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode);
  141. Value *lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder);
  142. Value *lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder);
  143. Value *lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder);
  144. Value *lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
  145. Value *lowerHLMatElementSubscript(CallInst *Call, bool RowMajor);
  146. Value *lowerHLMatSubscript(CallInst *Call, bool RowMajor);
  147. void lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices);
  148. Value *lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode);
  149. Value *lowerHLInit(CallInst *Call);
  150. Value *lowerHLSelect(CallInst *Call);
  151. private:
  152. Module *m_pModule;
  153. HLModule *m_pHLModule;
  154. bool m_HasDbgInfo;
  155. // Pools for the translation stubs
  156. TempOverloadPool *m_matToVecStubs = nullptr;
  157. TempOverloadPool *m_vecToMatStubs = nullptr;
  158. std::vector<Instruction *> m_deadInsts;
  159. };
  160. }
  161. char HLMatrixLowerPass::ID = 0;
  162. ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
  163. INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
  164. bool HLMatrixLowerPass::runOnModule(Module &M) {
  165. TempOverloadPool matToVecStubs(M, "hlmatrixlower.mat2vec");
  166. TempOverloadPool vecToMatStubs(M, "hlmatrixlower.vec2mat");
  167. m_pModule = &M;
  168. m_pHLModule = &m_pModule->GetOrCreateHLModule();
  169. // Load up debug information, to cross-reference values and the instructions
  170. // used to load them.
  171. m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  172. m_matToVecStubs = &matToVecStubs;
  173. m_vecToMatStubs = &vecToMatStubs;
  174. // First, lower static global variables.
  175. // We need to accumulate them locally because we'll be creating new ones as we lower them.
  176. std::vector<GlobalVariable*> Globals;
  177. for (GlobalVariable &Global : M.globals()) {
  178. if ((dxilutil::IsStaticGlobal(&Global) || dxilutil::IsSharedMemoryGlobal(&Global))
  179. && HLMatrixType::isMatrixPtrOrArrayPtr(Global.getType())) {
  180. Globals.emplace_back(&Global);
  181. }
  182. }
  183. for (GlobalVariable *Global : Globals)
  184. lowerGlobal(Global);
  185. for (Function &F : M.functions()) {
  186. if (F.isDeclaration()) continue;
  187. runOnFunction(F);
  188. }
  189. m_pModule = nullptr;
  190. m_pHLModule = nullptr;
  191. m_matToVecStubs = nullptr;
  192. m_vecToMatStubs = nullptr;
  193. // If you hit an assert during TempOverloadPool destruction,
  194. // it means that either a matrix producer was lowered,
  195. // causing a translation stub to be created,
  196. // but the consumer of that matrix was never (properly) lowered.
  197. // Or the opposite: a matrix consumer was lowered and not its producer.
  198. return true;
  199. }
  200. void HLMatrixLowerPass::runOnFunction(Function &Func) {
  201. // Skip hl function definition (like createhandle)
  202. if (hlsl::GetHLOpcodeGroupByName(&Func) != HLOpcodeGroup::NotHL)
  203. return;
  204. // Save the matrix instructions first since the translation process
  205. // will temporarily create other instructions consuming/producing matrix types.
  206. std::vector<AllocaInst*> MatAllocas;
  207. std::vector<Instruction*> MatInsts;
  208. getMatrixAllocasAndOtherInsts(Func, MatAllocas, MatInsts);
  209. // First lower all allocas and take care of their GEP chains
  210. for (AllocaInst* MatAlloca : MatAllocas) {
  211. AllocaInst* LoweredAlloca = lowerAlloca(MatAlloca);
  212. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  213. addToDeadInsts(MatAlloca);
  214. }
  215. // Now lower all other matrix instructions
  216. for (Instruction *MatInst : MatInsts)
  217. lowerInstruction(MatInst);
  218. deleteDeadInsts();
  219. }
  220. void HLMatrixLowerPass::deleteDeadInsts() {
  221. while (!m_deadInsts.empty()) {
  222. Instruction *Inst = m_deadInsts.back();
  223. m_deadInsts.pop_back();
  224. DXASSERT_NOMSG(Inst->use_empty());
  225. for (Value *Operand : Inst->operand_values()) {
  226. Instruction *OperandInst = dyn_cast<Instruction>(Operand);
  227. if (OperandInst && ++OperandInst->user_begin() == OperandInst->user_end()) {
  228. // We were its only user, erase recursively.
  229. // This will get rid of translation stubs:
  230. // Original: MatConsumer(MatProducer)
  231. // Producer lowered: MatConsumer(VecToMat(VecProducer)), MatProducer dead
  232. // Consumer lowered: VecConsumer(VecProducer)), MatConsumer(VecToMat) dead
  233. // Only by recursing on MatConsumer's operand do we delete the VecToMat stub.
  234. DXASSERT_NOMSG(*OperandInst->user_begin() == Inst);
  235. m_deadInsts.emplace_back(OperandInst);
  236. }
  237. }
  238. Inst->eraseFromParent();
  239. }
  240. }
  241. // Find all instructions consuming or producing matrices,
  242. // directly or through pointers/arrays.
  243. void HLMatrixLowerPass::getMatrixAllocasAndOtherInsts(Function &Func,
  244. std::vector<AllocaInst*> &MatAllocas, std::vector<Instruction*> &MatInsts){
  245. for (BasicBlock &BasicBlock : Func) {
  246. for (Instruction &Inst : BasicBlock) {
  247. // Don't lower GEPs directly, we'll handle them as we lower the root pointer,
  248. // typically a global variable or alloca.
  249. if (isa<GetElementPtrInst>(&Inst)) continue;
  250. if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst)) {
  251. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Alloca->getType())) {
  252. MatAllocas.emplace_back(Alloca);
  253. }
  254. continue;
  255. }
  256. if (CallInst *Call = dyn_cast<CallInst>(&Inst)) {
  257. // Lowering of global variables will have introduced
  258. // vec-to-mat translation stubs, which we deal with indirectly,
  259. // as we lower the instructions consuming them.
  260. if (m_vecToMatStubs->contains(Call->getCalledFunction()))
  261. continue;
  262. // Mat-to-vec stubs should only be introduced during instruction lowering.
  263. // Globals lowering won't introduce any because their only operand is
  264. // their initializer, which we can fully lower without stubbing since it is constant.
  265. DXASSERT(!m_matToVecStubs->contains(Call->getCalledFunction()),
  266. "Unexpected mat-to-vec stubbing before function instruction lowering.");
  267. // Match matrix producers
  268. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Inst.getType())) {
  269. MatInsts.emplace_back(Call);
  270. continue;
  271. }
  272. // Match matrix consumers
  273. for (Value *Operand : Inst.operand_values()) {
  274. if (HLMatrixType::isMatrixOrPtrOrArrayPtr(Operand->getType())) {
  275. MatInsts.emplace_back(Call);
  276. break;
  277. }
  278. }
  279. continue;
  280. }
  281. if (ReturnInst *Return = dyn_cast<ReturnInst>(&Inst)) {
  282. Value *ReturnValue = Return->getReturnValue();
  283. if (ReturnValue != nullptr && HLMatrixType::isMatrixOrPtrOrArrayPtr(ReturnValue->getType()))
  284. MatInsts.emplace_back(Return);
  285. continue;
  286. }
  287. // Nothing else should produce or consume matrices
  288. }
  289. }
  290. }
  291. // Gets the matrix-lowered representation of a value, potentially adding a translation stub.
  292. // DiscardStub causes any vec-to-mat translation stubs to be deleted,
  293. // it should be true only if the original instruction will be modified and kept alive.
  294. // If a new instruction is created and the original marked as dead,
  295. // then the remove dead instructions pass will take care of removing the stub.
  296. Value* HLMatrixLowerPass::getLoweredByValOperand(Value *Val, IRBuilder<> &Builder, bool DiscardStub) {
  297. Type *Ty = Val->getType();
  298. // We're only lowering byval matrices.
  299. // Since structs and arrays are always accessed by pointer,
  300. // we do not need to worry about a matrix being hidden inside a more complex type.
  301. DXASSERT(!Ty->isPointerTy(), "Value cannot be a pointer.");
  302. HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty);
  303. if (!MatTy) return Val;
  304. Type *LoweredTy = MatTy.getLoweredVectorTypeForReg();
  305. // Check if the value is already a vec-to-mat translation stub
  306. if (CallInst *Call = dyn_cast<CallInst>(Val)) {
  307. if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
  308. if (DiscardStub && Call->getNumUses() == 1) {
  309. Call->use_begin()->set(UndefValue::get(Call->getType()));
  310. addToDeadInsts(Call);
  311. }
  312. Value *LoweredVal = Call->getArgOperand(0);
  313. DXASSERT(LoweredVal->getType() == LoweredTy, "Unexpected already-lowered value type.");
  314. return LoweredVal;
  315. }
  316. }
  317. // Return a mat-to-vec translation stub
  318. FunctionType *TranslationStubTy = FunctionType::get(LoweredTy, { Ty }, /* isVarArg */ false);
  319. Function *TranslationStub = m_matToVecStubs->get(TranslationStubTy);
  320. return Builder.CreateCall(TranslationStub, { Val });
  321. }
  322. // Attempts to retrieve the lowered vector pointer equivalent to a matrix pointer.
  323. // Returns nullptr if the pointed-to matrix lives in memory that cannot be lowered at this time,
  324. // for example a buffer or shader inputs/outputs, which are lowered during signature lowering.
  325. Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Builder, bool DiscardStub) {
  326. if (!HLMatrixType::isMatrixPtrOrArrayPtr(Ptr->getType()))
  327. return nullptr;
  328. // Matrix pointers can only be derived from Allocas, GlobalVariables or resource accesses.
  329. // The first two cases are what this pass must be able to lower, and we should already
  330. // have replaced their uses by vector to matrix pointer translation stubs.
  331. if (CallInst *Call = dyn_cast<CallInst>(Ptr)) {
  332. if (m_vecToMatStubs->contains(Call->getCalledFunction())) {
  333. if (DiscardStub && Call->getNumUses() == 1) {
  334. Call->use_begin()->set(UndefValue::get(Call->getType()));
  335. addToDeadInsts(Call);
  336. }
  337. return Call->getArgOperand(0);
  338. }
  339. }
  340. // There's one more case to handle.
  341. // When compiling shader libraries, signatures won't have been lowered yet.
  342. // So we can have a matrix in a struct as an argument,
  343. // or an alloca'd struct holding the return value of a call and containing a matrix.
  344. Value *RootPtr = Ptr;
  345. while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
  346. RootPtr = GEP->getPointerOperand();
  347. Argument *Arg = dyn_cast<Argument>(RootPtr);
  348. bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
  349. if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
  350. // Bitcast the matrix pointer to its lowered equivalent.
  351. // The HLMatrixBitcast pass will take care of this later.
  352. return Builder.CreateBitCast(Ptr, HLMatrixType::getLoweredType(Ptr->getType()));
  353. }
  354. // The pointer must be derived from a resource, we don't handle it in this pass.
  355. return nullptr;
  356. }
  357. // Bitcasts a value from matrix to vector or vice-versa.
  358. // This is used to convert to/from arguments/return values since we don't
  359. // lower signatures in this pass. The later HLMatrixBitcastLower pass fixes this.
  360. Value *HLMatrixLowerPass::bitCastValue(Value *SrcVal, Type* DstTy, bool DstTyAlloca, IRBuilder<> &Builder) {
  361. Type *SrcTy = SrcVal->getType();
  362. DXASSERT_NOMSG(!SrcTy->isPointerTy());
  363. // We store and load from a temporary alloca, bitcasting either on the store pointer
  364. // or on the load pointer.
  365. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
  366. Value *Alloca = AllocaBuilder.CreateAlloca(DstTyAlloca ? DstTy : SrcTy);
  367. Value *BitCastedAlloca = Builder.CreateBitCast(Alloca, (DstTyAlloca ? SrcTy : DstTy)->getPointerTo());
  368. Builder.CreateStore(SrcVal, DstTyAlloca ? BitCastedAlloca : Alloca);
  369. return Builder.CreateLoad(DstTyAlloca ? Alloca : BitCastedAlloca);
  370. }
  371. // Replaces all uses of a matrix value by its lowered vector form,
  372. // inserting translation stubs for users which still expect a matrix value.
  373. void HLMatrixLowerPass::replaceAllUsesByLoweredValue(Instruction* MatInst, Value* VecVal) {
  374. if (VecVal == nullptr || VecVal == MatInst) return;
  375. DXASSERT(HLMatrixType::getLoweredType(MatInst->getType()) == VecVal->getType(),
  376. "Unexpected lowered value type.");
  377. Instruction *VecToMatStub = nullptr;
  378. while (!MatInst->use_empty()) {
  379. Use &ValUse = *MatInst->use_begin();
  380. // Handle non-matrix cases, just point to the new value.
  381. if (MatInst->getType() == VecVal->getType()) {
  382. ValUse.set(VecVal);
  383. continue;
  384. }
  385. // If the user is already a matrix-to-vector translation stub,
  386. // we can now replace it by the proper vector value.
  387. if (CallInst *Call = dyn_cast<CallInst>(ValUse.getUser())) {
  388. if (m_matToVecStubs->contains(Call->getCalledFunction())) {
  389. Call->replaceAllUsesWith(VecVal);
  390. ValUse.set(UndefValue::get(MatInst->getType()));
  391. addToDeadInsts(Call);
  392. continue;
  393. }
  394. }
  395. // Otherwise, the user should point to a vector-to-matrix translation
  396. // stub of the new vector value.
  397. if (VecToMatStub == nullptr) {
  398. FunctionType *TranslationStubTy = FunctionType::get(
  399. MatInst->getType(), { VecVal->getType() }, /* isVarArg */ false);
  400. Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
  401. Instruction *PrevInst = dyn_cast<Instruction>(VecVal);
  402. if (PrevInst == nullptr) PrevInst = MatInst;
  403. IRBuilder<> Builder(dxilutil::SkipAllocas(PrevInst->getNextNode()));
  404. VecToMatStub = Builder.CreateCall(TranslationStub, { VecVal });
  405. }
  406. ValUse.set(VecToMatStub);
  407. }
  408. }
  409. // Replaces all uses of a matrix or matrix array alloca or global variable by its lowered equivalent.
  410. // This doesn't lower the users, but will insert a translation stub from the lowered value pointer
  411. // back to the matrix value pointer, and recreate any GEPs around the new pointer.
  412. // Before: User(GEP(MatrixArrayAlloca))
  413. // After: User(VecToMatPtrStub(GEP'(VectorArrayAlloca)))
  414. void HLMatrixLowerPass::replaceAllVariableUses(Value* MatPtr, Value* LoweredPtr) {
  415. DXASSERT_NOMSG(HLMatrixType::isMatrixPtrOrArrayPtr(MatPtr->getType()));
  416. DXASSERT_NOMSG(LoweredPtr->getType() == HLMatrixType::getLoweredType(MatPtr->getType()));
  417. SmallVector<Value*, 4> GEPIdxStack;
  418. GEPIdxStack.emplace_back(ConstantInt::get(Type::getInt32Ty(MatPtr->getContext()), 0));
  419. replaceAllVariableUses(GEPIdxStack, MatPtr, LoweredPtr);
  420. }
  421. void HLMatrixLowerPass::replaceAllVariableUses(
  422. SmallVectorImpl<Value*> &GEPIdxStack, Value *StackTopPtr, Value* LoweredPtr) {
  423. while (!StackTopPtr->use_empty()) {
  424. llvm::Use &Use = *StackTopPtr->use_begin();
  425. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Use.getUser())) {
  426. DXASSERT(GEP->getNumIndices() >= 1, "Unexpected degenerate GEP.");
  427. DXASSERT(cast<ConstantInt>(*GEP->idx_begin())->isZero(), "Unexpected non-zero first GEP index.");
  428. // Recurse in GEP to find actual users
  429. for (auto It = GEP->idx_begin() + 1; It != GEP->idx_end(); ++It)
  430. GEPIdxStack.emplace_back(*It);
  431. replaceAllVariableUses(GEPIdxStack, GEP, LoweredPtr);
  432. GEPIdxStack.erase(GEPIdxStack.end() - (GEP->getNumIndices() - 1), GEPIdxStack.end());
  433. // Discard the GEP
  434. DXASSERT_NOMSG(GEP->use_empty());
  435. if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
  436. Use.set(UndefValue::get(Use->getType()));
  437. addToDeadInsts(GEPInst);
  438. } else {
  439. // constant GEP
  440. cast<Constant>(GEP)->destroyConstant();
  441. }
  442. continue;
  443. }
  444. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Use.getUser())) {
  445. DXASSERT(CE->getOpcode() == Instruction::AddrSpaceCast,
  446. "Unexpected constant user");
  447. replaceAllVariableUses(GEPIdxStack, CE, LoweredPtr);
  448. DXASSERT_NOMSG(CE->use_empty());
  449. CE->destroyConstant();
  450. continue;
  451. }
  452. if (AddrSpaceCastInst *CI = dyn_cast<AddrSpaceCastInst>(Use.getUser())) {
  453. replaceAllVariableUses(GEPIdxStack, CI, LoweredPtr);
  454. Use.set(UndefValue::get(Use->getType()));
  455. addToDeadInsts(CI);
  456. continue;
  457. }
  458. // Recreate the same GEP sequence, if any, on the lowered pointer
  459. IRBuilder<> Builder(cast<Instruction>(Use.getUser()));
  460. Value *LoweredStackTopPtr = GEPIdxStack.size() == 1
  461. ? LoweredPtr : Builder.CreateGEP(LoweredPtr, GEPIdxStack);
  462. // Generate a stub translating the vector pointer back to a matrix pointer,
  463. // such that consuming instructions are unaffected.
  464. FunctionType *TranslationStubTy = FunctionType::get(
  465. StackTopPtr->getType(), { LoweredStackTopPtr->getType() }, /* isVarArg */ false);
  466. Function *TranslationStub = m_vecToMatStubs->get(TranslationStubTy);
  467. Use.set(Builder.CreateCall(TranslationStub, { LoweredStackTopPtr }));
  468. }
  469. }
  470. void HLMatrixLowerPass::lowerGlobal(GlobalVariable *Global) {
  471. if (Global->user_empty()) return;
  472. PointerType *LoweredPtrTy = cast<PointerType>(HLMatrixType::getLoweredType(Global->getType()));
  473. DXASSERT_NOMSG(LoweredPtrTy != Global->getType());
  474. Constant *LoweredInitVal = Global->hasInitializer()
  475. ? lowerConstInitVal(Global->getInitializer()) : nullptr;
  476. GlobalVariable *LoweredGlobal = new GlobalVariable(*m_pModule, LoweredPtrTy->getElementType(),
  477. Global->isConstant(), Global->getLinkage(), LoweredInitVal,
  478. Global->getName() + ".v", /*InsertBefore*/ nullptr, Global->getThreadLocalMode(),
  479. Global->getType()->getAddressSpace());
  480. // Add debug info.
  481. if (m_HasDbgInfo) {
  482. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  483. HLModule::UpdateGlobalVariableDebugInfo(Global, Finder, LoweredGlobal);
  484. }
  485. replaceAllVariableUses(Global, LoweredGlobal);
  486. Global->removeDeadConstantUsers();
  487. Global->eraseFromParent();
  488. }
  489. Constant *HLMatrixLowerPass::lowerConstInitVal(Constant *Val) {
  490. Type *Ty = Val->getType();
  491. // If it's an array of matrices, recurse for each element or nested array
  492. if (ArrayType *ArrayTy = dyn_cast<ArrayType>(Ty)) {
  493. SmallVector<Constant*, 4> LoweredElems;
  494. unsigned NumElems = ArrayTy->getNumElements();
  495. LoweredElems.reserve(NumElems);
  496. for (unsigned ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  497. Constant *ArrayElem = Val->getAggregateElement(ElemIdx);
  498. LoweredElems.emplace_back(lowerConstInitVal(ArrayElem));
  499. }
  500. Type *LoweredElemTy = HLMatrixType::getLoweredType(ArrayTy->getElementType());
  501. ArrayType *LoweredArrayTy = ArrayType::get(LoweredElemTy, NumElems);
  502. return ConstantArray::get(LoweredArrayTy, LoweredElems);
  503. }
  504. // Otherwise it's a matrix, lower it to a vector
  505. HLMatrixType MatTy = HLMatrixType::cast(Ty);
  506. DXASSERT_NOMSG(isa<StructType>(Ty));
  507. Constant *RowArrayVal = Val->getAggregateElement((unsigned)0);
  508. // Original initializer should have been produced in row/column-major order
  509. // depending on the qualifiers of the target variable, so preserve the order.
  510. SmallVector<Constant*, 16> MatElems;
  511. for (unsigned RowIdx = 0; RowIdx < MatTy.getNumRows(); ++RowIdx) {
  512. Constant *RowVal = RowArrayVal->getAggregateElement(RowIdx);
  513. for (unsigned ColIdx = 0; ColIdx < MatTy.getNumColumns(); ++ColIdx) {
  514. MatElems.emplace_back(RowVal->getAggregateElement(ColIdx));
  515. }
  516. }
  517. Constant *Vec = ConstantVector::get(MatElems);
  518. // Matrix elements are always in register representation,
  519. // but the lowered global variable is of vector type in
  520. // its memory representation, so we must convert here.
  521. // This will produce a constant so we can use an IRBuilder without a valid insertion point.
  522. IRBuilder<> DummyBuilder(Val->getContext());
  523. return cast<Constant>(MatTy.emitLoweredRegToMem(Vec, DummyBuilder));
  524. }
  525. AllocaInst *HLMatrixLowerPass::lowerAlloca(AllocaInst *MatAlloca) {
  526. PointerType *LoweredAllocaTy = cast<PointerType>(HLMatrixType::getLoweredType(MatAlloca->getType()));
  527. IRBuilder<> Builder(MatAlloca);
  528. AllocaInst *LoweredAlloca = Builder.CreateAlloca(
  529. LoweredAllocaTy->getElementType(), nullptr, MatAlloca->getName());
  530. // Update debug info.
  531. if (DbgDeclareInst *DbgDeclare = llvm::FindAllocaDbgDeclare(MatAlloca)) {
  532. LLVMContext &Context = MatAlloca->getContext();
  533. Value *DbgDeclareVar = MetadataAsValue::get(Context, DbgDeclare->getRawVariable());
  534. Value *DbgDeclareExpr = MetadataAsValue::get(Context, DbgDeclare->getRawExpression());
  535. Value *ValueMetadata = MetadataAsValue::get(Context, ValueAsMetadata::get(LoweredAlloca));
  536. IRBuilder<> DebugBuilder(DbgDeclare);
  537. DebugBuilder.CreateCall(DbgDeclare->getCalledFunction(), { ValueMetadata, DbgDeclareVar, DbgDeclareExpr });
  538. }
  539. if (HLModule::HasPreciseAttributeWithMetadata(MatAlloca))
  540. HLModule::MarkPreciseAttributeWithMetadata(LoweredAlloca);
  541. replaceAllVariableUses(MatAlloca, LoweredAlloca);
  542. return LoweredAlloca;
  543. }
  544. void HLMatrixLowerPass::lowerInstruction(Instruction* Inst) {
  545. if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
  546. Value *LoweredValue = lowerCall(Call);
  547. // lowerCall returns the lowered value iff we should discard
  548. // the original matrix instruction and replace all of its uses
  549. // by the lowered value. It returns nullptr to opt-out of this.
  550. if (LoweredValue != nullptr) {
  551. replaceAllUsesByLoweredValue(Call, LoweredValue);
  552. addToDeadInsts(Inst);
  553. }
  554. }
  555. else if (ReturnInst *Return = dyn_cast<ReturnInst>(Inst)) {
  556. lowerReturn(Return);
  557. }
  558. else
  559. llvm_unreachable("Unexpected matrix instruction type.");
  560. }
  561. void HLMatrixLowerPass::lowerReturn(ReturnInst* Return) {
  562. Value *RetVal = Return->getReturnValue();
  563. Type *RetTy = RetVal->getType();
  564. DXASSERT_LOCALVAR(RetTy, !RetTy->isPointerTy(), "Unexpected matrix returned by pointer.");
  565. IRBuilder<> Builder(Return);
  566. Value *LoweredRetVal = getLoweredByValOperand(RetVal, Builder, /* DiscardStub */ true);
  567. // Since we're not lowering the signature, we can't return the lowered value directly,
  568. // so insert a bitcast, which HLMatrixBitcastLower knows how to eliminate.
  569. Value *BitCastedRetVal = bitCastValue(LoweredRetVal, RetVal->getType(), /* DstTyAlloca */ false, Builder);
  570. Return->setOperand(0, BitCastedRetVal);
  571. }
  572. Value *HLMatrixLowerPass::lowerCall(CallInst *Call) {
  573. HLOpcodeGroup OpcodeGroup = GetHLOpcodeGroupByName(Call->getCalledFunction());
  574. return OpcodeGroup == HLOpcodeGroup::NotHL
  575. ? lowerNonHLCall(Call) : lowerHLOperation(Call, OpcodeGroup);
  576. }
  577. Value *HLMatrixLowerPass::lowerNonHLCall(CallInst *Call) {
  578. // First, handle any operand of matrix-derived type
  579. // We don't lower the callee's signature in this pass,
  580. // so, for any matrix-typed parameter, we create a bitcast from the
  581. // lowered vector back to the matrix type, which the later HLMatrixBitcastLower
  582. // pass knows how to eliminate.
  583. IRBuilder<> PreCallBuilder(Call);
  584. unsigned NumArgs = Call->getNumArgOperands();
  585. for (unsigned ArgIdx = 0; ArgIdx < NumArgs; ++ArgIdx) {
  586. Use &ArgUse = Call->getArgOperandUse(ArgIdx);
  587. if (ArgUse->getType()->isPointerTy()) {
  588. // Byref arg
  589. Value *LoweredArg = tryGetLoweredPtrOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
  590. if (LoweredArg != nullptr) {
  591. // Pointer to a matrix we've lowered, insert a bitcast back to matrix pointer type.
  592. Value *BitCastedArg = PreCallBuilder.CreateBitCast(LoweredArg, ArgUse->getType());
  593. ArgUse.set(BitCastedArg);
  594. }
  595. }
  596. else {
  597. // Byvalue arg
  598. Value *LoweredArg = getLoweredByValOperand(ArgUse.get(), PreCallBuilder, /* DiscardStub */ true);
  599. if (LoweredArg == ArgUse.get()) continue;
  600. Value *BitCastedArg = bitCastValue(LoweredArg, ArgUse->getType(), /* DstTyAlloca */ false, PreCallBuilder);
  601. ArgUse.set(BitCastedArg);
  602. }
  603. }
  604. // Now check the return type
  605. HLMatrixType RetMatTy = HLMatrixType::dyn_cast(Call->getType());
  606. if (!RetMatTy) {
  607. DXASSERT(!HLMatrixType::isMatrixPtrOrArrayPtr(Call->getType()),
  608. "Unexpected user call returning a matrix by pointer.");
  609. // Nothing to replace, other instructions can consume a non-matrix return type.
  610. return nullptr;
  611. }
  612. // The callee returns a matrix, and we don't lower signatures in this pass.
  613. // We perform a sketchy bitcast to the lowered register-representation type,
  614. // which the later HLMatrixBitcastLower pass knows how to eliminate.
  615. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Call));
  616. Value *LoweredAlloca = AllocaBuilder.CreateAlloca(RetMatTy.getLoweredVectorTypeForReg());
  617. IRBuilder<> PostCallBuilder(Call->getNextNode());
  618. Value *BitCastedAlloca = PostCallBuilder.CreateBitCast(LoweredAlloca, Call->getType()->getPointerTo());
  619. // This is slightly tricky
  620. // We want to replace all uses of the matrix-returning call by the bitcasted value,
  621. // but the store to the bitcasted pointer itself is a use of that matrix,
  622. // so we need to create the load, replace the uses, and then insert the store.
  623. LoadInst *LoweredVal = PostCallBuilder.CreateLoad(LoweredAlloca);
  624. replaceAllUsesByLoweredValue(Call, LoweredVal);
  625. // Now we can insert the store. Make sure to do so before the load.
  626. PostCallBuilder.SetInsertPoint(LoweredVal);
  627. PostCallBuilder.CreateStore(Call, BitCastedAlloca);
  628. // Return nullptr since we did our own uses replacement and we don't want
  629. // the matrix instruction to be marked as dead since we're still using it.
  630. return nullptr;
  631. }
  632. Value *HLMatrixLowerPass::lowerHLOperation(CallInst *Call, HLOpcodeGroup OpcodeGroup) {
  633. IRBuilder<> Builder(Call);
  634. switch (OpcodeGroup) {
  635. case HLOpcodeGroup::HLIntrinsic:
  636. return lowerHLIntrinsic(Call, static_cast<IntrinsicOp>(GetHLOpcode(Call)));
  637. case HLOpcodeGroup::HLBinOp:
  638. return lowerHLBinaryOperation(
  639. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  640. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  641. static_cast<HLBinaryOpcode>(GetHLOpcode(Call)), Builder);
  642. case HLOpcodeGroup::HLUnOp:
  643. return lowerHLUnaryOperation(
  644. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx),
  645. static_cast<HLUnaryOpcode>(GetHLOpcode(Call)), Builder);
  646. case HLOpcodeGroup::HLMatLoadStore:
  647. return lowerHLLoadStore(Call, static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(Call)));
  648. case HLOpcodeGroup::HLCast:
  649. return lowerHLCast(
  650. Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Call->getType(),
  651. static_cast<HLCastOpcode>(GetHLOpcode(Call)), Builder);
  652. case HLOpcodeGroup::HLSubscript:
  653. return lowerHLSubscript(Call, static_cast<HLSubscriptOpcode>(GetHLOpcode(Call)));
  654. case HLOpcodeGroup::HLInit:
  655. return lowerHLInit(Call);
  656. case HLOpcodeGroup::HLSelect:
  657. return lowerHLSelect(Call);
  658. default:
  659. llvm_unreachable("Unexpected matrix opcode");
  660. }
  661. }
  662. static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
  663. Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
  664. SmallVector<Type*, 4> ArgTys;
  665. ArgTys.reserve(Args.size());
  666. for (Value *Arg : Args)
  667. ArgTys.emplace_back(Arg->getType());
  668. FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
  669. Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
  670. return Builder.CreateCall(Func, Args);
  671. }
  672. Value *HLMatrixLowerPass::lowerHLIntrinsic(CallInst *Call, IntrinsicOp Opcode) {
  673. IRBuilder<> Builder(Call);
  674. // See if this is a matrix-specific intrinsic which we should expand here
  675. switch (Opcode) {
  676. case IntrinsicOp::IOP_umul:
  677. case IntrinsicOp::IOP_mul:
  678. return lowerHLMulIntrinsic(
  679. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx),
  680. Call->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx),
  681. /* Unsigned */ Opcode == IntrinsicOp::IOP_umul, Builder);
  682. case IntrinsicOp::IOP_transpose:
  683. return lowerHLTransposeIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  684. case IntrinsicOp::IOP_determinant:
  685. return lowerHLDeterminantIntrinsic(Call->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx), Builder);
  686. }
  687. // Delegate to a lowered intrinsic call
  688. SmallVector<Value*, 4> LoweredArgs;
  689. LoweredArgs.reserve(Call->getNumArgOperands());
  690. for (Value *Arg : Call->arg_operands()) {
  691. if (Arg->getType()->isPointerTy()) {
  692. // ByRef parameter (for example, frexp's second parameter)
  693. // If the argument points to a lowered matrix variable, replace it here,
  694. // otherwise preserve the matrix type and let further passes handle the lowering.
  695. Value *LoweredArg = tryGetLoweredPtrOperand(Arg, Builder);
  696. if (LoweredArg == nullptr) LoweredArg = Arg;
  697. LoweredArgs.emplace_back(LoweredArg);
  698. }
  699. else {
  700. LoweredArgs.emplace_back(getLoweredByValOperand(Arg, Builder));
  701. }
  702. }
  703. Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
  704. return callHLFunction(*m_pModule, HLOpcodeGroup::HLIntrinsic, static_cast<unsigned>(Opcode),
  705. LoweredRetTy, LoweredArgs, Builder);
  706. }
  707. Value *HLMatrixLowerPass::lowerHLMulIntrinsic(Value* Lhs, Value *Rhs,
  708. bool Unsigned, IRBuilder<> &Builder) {
  709. HLMatrixType LhsMatTy = HLMatrixType::dyn_cast(Lhs->getType());
  710. HLMatrixType RhsMatTy = HLMatrixType::dyn_cast(Rhs->getType());
  711. Value* LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  712. Value* LoweredRhs = getLoweredByValOperand(Rhs, Builder);
  713. DXASSERT(LoweredLhs->getType()->getScalarType() == LoweredRhs->getType()->getScalarType(),
  714. "Unexpected element type mismatch in mul intrinsic.");
  715. DXASSERT(cast<VectorType>(LoweredLhs->getType()) && cast<VectorType>(LoweredLhs->getType()),
  716. "Unexpected scalar in lowered matrix mul intrinsic operands.");
  717. Type* ElemTy = LoweredLhs->getType()->getScalarType();
  718. // Figure out the dimensions of each side
  719. unsigned LhsNumRows, LhsNumCols, RhsNumRows, RhsNumCols;
  720. if (LhsMatTy && RhsMatTy) {
  721. LhsNumRows = LhsMatTy.getNumRows();
  722. LhsNumCols = LhsMatTy.getNumColumns();
  723. RhsNumRows = RhsMatTy.getNumRows();
  724. RhsNumCols = RhsMatTy.getNumColumns();
  725. }
  726. else if (LhsMatTy) {
  727. LhsNumRows = LhsMatTy.getNumRows();
  728. LhsNumCols = LhsMatTy.getNumColumns();
  729. RhsNumRows = LoweredRhs->getType()->getVectorNumElements();
  730. RhsNumCols = 1;
  731. }
  732. else if (RhsMatTy) {
  733. LhsNumRows = 1;
  734. LhsNumCols = LoweredLhs->getType()->getVectorNumElements();
  735. RhsNumRows = RhsMatTy.getNumRows();
  736. RhsNumCols = RhsMatTy.getNumColumns();
  737. }
  738. else {
  739. llvm_unreachable("mul intrinsic was identified as a matrix operation but neither operand is a matrix.");
  740. }
  741. DXASSERT(LhsNumCols == RhsNumRows, "Matrix mul intrinsic operands dimensions mismatch.");
  742. HLMatrixType ResultMatTy(ElemTy, LhsNumRows, RhsNumCols);
  743. unsigned AccCount = LhsNumCols;
  744. // Get the multiply-and-add intrinsic function, we'll need it
  745. IntrinsicOp MadOpcode = Unsigned ? IntrinsicOp::IOP_umad : IntrinsicOp::IOP_mad;
  746. FunctionType *MadFuncTy = FunctionType::get(ElemTy, { Builder.getInt32Ty(), ElemTy, ElemTy, ElemTy }, false);
  747. Function *MadFunc = GetOrCreateHLFunction(*m_pModule, MadFuncTy, HLOpcodeGroup::HLIntrinsic, (unsigned)MadOpcode);
  748. Constant *MadOpcodeVal = Builder.getInt32((unsigned)MadOpcode);
  749. // Perform the multiplication!
  750. Value *Result = UndefValue::get(VectorType::get(ElemTy, LhsNumRows * RhsNumCols));
  751. for (unsigned ResultRowIdx = 0; ResultRowIdx < ResultMatTy.getNumRows(); ++ResultRowIdx) {
  752. for (unsigned ResultColIdx = 0; ResultColIdx < ResultMatTy.getNumColumns(); ++ResultColIdx) {
  753. unsigned ResultElemIdx = ResultMatTy.getRowMajorIndex(ResultRowIdx, ResultColIdx);
  754. Value *ResultElem = nullptr;
  755. for (unsigned AccIdx = 0; AccIdx < AccCount; ++AccIdx) {
  756. unsigned LhsElemIdx = HLMatrixType::getRowMajorIndex(ResultRowIdx, AccIdx, LhsNumRows, LhsNumCols);
  757. unsigned RhsElemIdx = HLMatrixType::getRowMajorIndex(AccIdx, ResultColIdx, RhsNumRows, RhsNumCols);
  758. Value* LhsElem = Builder.CreateExtractElement(LoweredLhs, static_cast<uint64_t>(LhsElemIdx));
  759. Value* RhsElem = Builder.CreateExtractElement(LoweredRhs, static_cast<uint64_t>(RhsElemIdx));
  760. if (ResultElem == nullptr) {
  761. ResultElem = ElemTy->isFloatingPointTy()
  762. ? Builder.CreateFMul(LhsElem, RhsElem)
  763. : Builder.CreateMul(LhsElem, RhsElem);
  764. }
  765. else {
  766. ResultElem = Builder.CreateCall(MadFunc, { MadOpcodeVal, LhsElem, RhsElem, ResultElem });
  767. }
  768. }
  769. Result = Builder.CreateInsertElement(Result, ResultElem, static_cast<uint64_t>(ResultElemIdx));
  770. }
  771. }
  772. return Result;
  773. }
  774. Value *HLMatrixLowerPass::lowerHLTransposeIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  775. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  776. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  777. return MatTy.emitLoweredVectorRowToCol(LoweredVal, Builder);
  778. }
  779. static Value *determinant2x2(Value *M00, Value *M01, Value *M10, Value *M11, IRBuilder<> &Builder) {
  780. Value *Mul0 = Builder.CreateFMul(M00, M11);
  781. Value *Mul1 = Builder.CreateFMul(M01, M10);
  782. return Builder.CreateFSub(Mul0, Mul1);
  783. }
  784. static Value *determinant3x3(Value *M00, Value *M01, Value *M02,
  785. Value *M10, Value *M11, Value *M12,
  786. Value *M20, Value *M21, Value *M22,
  787. IRBuilder<> &Builder) {
  788. Value *Det00 = determinant2x2(M11, M12, M21, M22, Builder);
  789. Value *Det01 = determinant2x2(M10, M12, M20, M22, Builder);
  790. Value *Det02 = determinant2x2(M10, M11, M20, M21, Builder);
  791. Det00 = Builder.CreateFMul(M00, Det00);
  792. Det01 = Builder.CreateFMul(M01, Det01);
  793. Det02 = Builder.CreateFMul(M02, Det02);
  794. Value *Result = Builder.CreateFSub(Det00, Det01);
  795. Result = Builder.CreateFAdd(Result, Det02);
  796. return Result;
  797. }
  798. static Value *determinant4x4(Value *M00, Value *M01, Value *M02, Value *M03,
  799. Value *M10, Value *M11, Value *M12, Value *M13,
  800. Value *M20, Value *M21, Value *M22, Value *M23,
  801. Value *M30, Value *M31, Value *M32, Value *M33,
  802. IRBuilder<> &Builder) {
  803. Value *Det00 = determinant3x3(M11, M12, M13, M21, M22, M23, M31, M32, M33, Builder);
  804. Value *Det01 = determinant3x3(M10, M12, M13, M20, M22, M23, M30, M32, M33, Builder);
  805. Value *Det02 = determinant3x3(M10, M11, M13, M20, M21, M23, M30, M31, M33, Builder);
  806. Value *Det03 = determinant3x3(M10, M11, M12, M20, M21, M22, M30, M31, M32, Builder);
  807. Det00 = Builder.CreateFMul(M00, Det00);
  808. Det01 = Builder.CreateFMul(M01, Det01);
  809. Det02 = Builder.CreateFMul(M02, Det02);
  810. Det03 = Builder.CreateFMul(M03, Det03);
  811. Value *Result = Builder.CreateFSub(Det00, Det01);
  812. Result = Builder.CreateFAdd(Result, Det02);
  813. Result = Builder.CreateFSub(Result, Det03);
  814. return Result;
  815. }
  816. Value *HLMatrixLowerPass::lowerHLDeterminantIntrinsic(Value* MatVal, IRBuilder<> &Builder) {
  817. HLMatrixType MatTy = HLMatrixType::cast(MatVal->getType());
  818. DXASSERT_NOMSG(MatTy.getNumColumns() == MatTy.getNumRows());
  819. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  820. // Extract all matrix elements
  821. SmallVector<Value*, 16> Elems;
  822. for (unsigned ElemIdx = 0; ElemIdx < MatTy.getNumElements(); ++ElemIdx)
  823. Elems.emplace_back(Builder.CreateExtractElement(LoweredVal, static_cast<uint64_t>(ElemIdx)));
  824. // Delegate to appropriate determinant function
  825. switch (MatTy.getNumColumns()) {
  826. case 1:
  827. return Elems[0];
  828. case 2:
  829. return determinant2x2(
  830. Elems[0], Elems[1],
  831. Elems[2], Elems[3],
  832. Builder);
  833. case 3:
  834. return determinant3x3(
  835. Elems[0], Elems[1], Elems[2],
  836. Elems[3], Elems[4], Elems[5],
  837. Elems[6], Elems[7], Elems[8],
  838. Builder);
  839. case 4:
  840. return determinant4x4(
  841. Elems[0], Elems[1], Elems[2], Elems[3],
  842. Elems[4], Elems[5], Elems[6], Elems[7],
  843. Elems[8], Elems[9], Elems[10], Elems[11],
  844. Elems[12], Elems[13], Elems[14], Elems[15],
  845. Builder);
  846. default:
  847. llvm_unreachable("Unexpected matrix dimensions.");
  848. }
  849. }
  850. Value *HLMatrixLowerPass::lowerHLUnaryOperation(Value *MatVal, HLUnaryOpcode Opcode, IRBuilder<> &Builder) {
  851. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  852. VectorType *VecTy = cast<VectorType>(LoweredVal->getType());
  853. bool IsFloat = VecTy->getElementType()->isFloatingPointTy();
  854. switch (Opcode) {
  855. case HLUnaryOpcode::Plus: return LoweredVal; // No-op
  856. case HLUnaryOpcode::Minus:
  857. return IsFloat
  858. ? Builder.CreateFSub(Constant::getNullValue(VecTy), LoweredVal)
  859. : Builder.CreateSub(Constant::getNullValue(VecTy), LoweredVal);
  860. case HLUnaryOpcode::LNot:
  861. return IsFloat
  862. ? Builder.CreateFCmp(CmpInst::FCMP_UEQ, LoweredVal, Constant::getNullValue(VecTy))
  863. : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredVal, Constant::getNullValue(VecTy));
  864. case HLUnaryOpcode::Not:
  865. return Builder.CreateXor(LoweredVal, Constant::getAllOnesValue(VecTy));
  866. case HLUnaryOpcode::PostInc:
  867. case HLUnaryOpcode::PreInc:
  868. case HLUnaryOpcode::PostDec:
  869. case HLUnaryOpcode::PreDec: {
  870. Constant *ScalarOne = IsFloat
  871. ? ConstantFP::get(VecTy->getElementType(), 1)
  872. : ConstantInt::get(VecTy->getElementType(), 1);
  873. Constant *VecOne = ConstantVector::getSplat(VecTy->getNumElements(), ScalarOne);
  874. // BUGBUG: This implementation has incorrect semantics (GitHub #1780)
  875. if (Opcode == HLUnaryOpcode::PostInc || Opcode == HLUnaryOpcode::PreInc) {
  876. return IsFloat
  877. ? Builder.CreateFAdd(LoweredVal, VecOne)
  878. : Builder.CreateAdd(LoweredVal, VecOne);
  879. }
  880. else {
  881. return IsFloat
  882. ? Builder.CreateFSub(LoweredVal, VecOne)
  883. : Builder.CreateSub(LoweredVal, VecOne);
  884. }
  885. }
  886. default:
  887. llvm_unreachable("Unsupported unary matrix operator");
  888. }
  889. }
  890. Value *HLMatrixLowerPass::lowerHLBinaryOperation(Value *Lhs, Value *Rhs, HLBinaryOpcode Opcode, IRBuilder<> &Builder) {
  891. Value *LoweredLhs = getLoweredByValOperand(Lhs, Builder);
  892. Value *LoweredRhs = getLoweredByValOperand(Rhs, Builder);
  893. DXASSERT(LoweredLhs->getType()->isVectorTy() && LoweredRhs->getType()->isVectorTy(),
  894. "Expected lowered binary operation operands to be vectors");
  895. DXASSERT(LoweredLhs->getType() == LoweredRhs->getType(),
  896. "Expected lowered binary operation operands to have matching types.");
  897. bool IsFloat = LoweredLhs->getType()->getVectorElementType()->isFloatingPointTy();
  898. switch (Opcode) {
  899. case HLBinaryOpcode::Add:
  900. return IsFloat
  901. ? Builder.CreateFAdd(LoweredLhs, LoweredRhs)
  902. : Builder.CreateAdd(LoweredLhs, LoweredRhs);
  903. case HLBinaryOpcode::Sub:
  904. return IsFloat
  905. ? Builder.CreateFSub(LoweredLhs, LoweredRhs)
  906. : Builder.CreateSub(LoweredLhs, LoweredRhs);
  907. case HLBinaryOpcode::Mul:
  908. return IsFloat
  909. ? Builder.CreateFMul(LoweredLhs, LoweredRhs)
  910. : Builder.CreateMul(LoweredLhs, LoweredRhs);
  911. case HLBinaryOpcode::Div:
  912. return IsFloat
  913. ? Builder.CreateFDiv(LoweredLhs, LoweredRhs)
  914. : Builder.CreateSDiv(LoweredLhs, LoweredRhs);
  915. case HLBinaryOpcode::Rem:
  916. return IsFloat
  917. ? Builder.CreateFRem(LoweredLhs, LoweredRhs)
  918. : Builder.CreateSRem(LoweredLhs, LoweredRhs);
  919. case HLBinaryOpcode::And:
  920. return Builder.CreateAnd(LoweredLhs, LoweredRhs);
  921. case HLBinaryOpcode::Or:
  922. return Builder.CreateOr(LoweredLhs, LoweredRhs);
  923. case HLBinaryOpcode::Xor:
  924. return Builder.CreateXor(LoweredLhs, LoweredRhs);
  925. case HLBinaryOpcode::Shl:
  926. return Builder.CreateShl(LoweredLhs, LoweredRhs);
  927. case HLBinaryOpcode::Shr:
  928. return Builder.CreateAShr(LoweredLhs, LoweredRhs);
  929. case HLBinaryOpcode::LT:
  930. return IsFloat
  931. ? Builder.CreateFCmp(CmpInst::FCMP_OLT, LoweredLhs, LoweredRhs)
  932. : Builder.CreateICmp(CmpInst::ICMP_SLT, LoweredLhs, LoweredRhs);
  933. case HLBinaryOpcode::GT:
  934. return IsFloat
  935. ? Builder.CreateFCmp(CmpInst::FCMP_OGT, LoweredLhs, LoweredRhs)
  936. : Builder.CreateICmp(CmpInst::ICMP_SGT, LoweredLhs, LoweredRhs);
  937. case HLBinaryOpcode::LE:
  938. return IsFloat
  939. ? Builder.CreateFCmp(CmpInst::FCMP_OLE, LoweredLhs, LoweredRhs)
  940. : Builder.CreateICmp(CmpInst::ICMP_SLE, LoweredLhs, LoweredRhs);
  941. case HLBinaryOpcode::GE:
  942. return IsFloat
  943. ? Builder.CreateFCmp(CmpInst::FCMP_OGE, LoweredLhs, LoweredRhs)
  944. : Builder.CreateICmp(CmpInst::ICMP_SGE, LoweredLhs, LoweredRhs);
  945. case HLBinaryOpcode::EQ:
  946. return IsFloat
  947. ? Builder.CreateFCmp(CmpInst::FCMP_OEQ, LoweredLhs, LoweredRhs)
  948. : Builder.CreateICmp(CmpInst::ICMP_EQ, LoweredLhs, LoweredRhs);
  949. case HLBinaryOpcode::NE:
  950. return IsFloat
  951. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, LoweredRhs)
  952. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, LoweredRhs);
  953. case HLBinaryOpcode::UDiv:
  954. return Builder.CreateUDiv(LoweredLhs, LoweredRhs);
  955. case HLBinaryOpcode::URem:
  956. return Builder.CreateURem(LoweredLhs, LoweredRhs);
  957. case HLBinaryOpcode::UShr:
  958. return Builder.CreateLShr(LoweredLhs, LoweredRhs);
  959. case HLBinaryOpcode::ULT:
  960. return Builder.CreateICmp(CmpInst::ICMP_ULT, LoweredLhs, LoweredRhs);
  961. case HLBinaryOpcode::UGT:
  962. return Builder.CreateICmp(CmpInst::ICMP_UGT, LoweredLhs, LoweredRhs);
  963. case HLBinaryOpcode::ULE:
  964. return Builder.CreateICmp(CmpInst::ICMP_ULE, LoweredLhs, LoweredRhs);
  965. case HLBinaryOpcode::UGE:
  966. return Builder.CreateICmp(CmpInst::ICMP_UGE, LoweredLhs, LoweredRhs);
  967. case HLBinaryOpcode::LAnd:
  968. case HLBinaryOpcode::LOr: {
  969. Value *Zero = Constant::getNullValue(LoweredLhs->getType());
  970. Value *LhsCmp = IsFloat
  971. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredLhs, Zero)
  972. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredLhs, Zero);
  973. Value *RhsCmp = IsFloat
  974. ? Builder.CreateFCmp(CmpInst::FCMP_ONE, LoweredRhs, Zero)
  975. : Builder.CreateICmp(CmpInst::ICMP_NE, LoweredRhs, Zero);
  976. return Opcode == HLBinaryOpcode::LOr
  977. ? Builder.CreateOr(LhsCmp, RhsCmp)
  978. : Builder.CreateAnd(LhsCmp, RhsCmp);
  979. }
  980. default:
  981. llvm_unreachable("Unsupported binary matrix operator");
  982. }
  983. }
  984. Value *HLMatrixLowerPass::lowerHLLoadStore(CallInst *Call, HLMatLoadStoreOpcode Opcode) {
  985. IRBuilder<> Builder(Call);
  986. switch (Opcode) {
  987. case HLMatLoadStoreOpcode::RowMatLoad:
  988. case HLMatLoadStoreOpcode::ColMatLoad:
  989. return lowerHLLoad(Call->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx),
  990. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatLoad, Builder);
  991. case HLMatLoadStoreOpcode::RowMatStore:
  992. case HLMatLoadStoreOpcode::ColMatStore:
  993. return lowerHLStore(
  994. Call->getArgOperand(HLOperandIndex::kMatStoreValOpIdx),
  995. Call->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx),
  996. /* RowMajor */ Opcode == HLMatLoadStoreOpcode::RowMatStore,
  997. /* Return */ !Call->getType()->isVoidTy(), Builder);
  998. default:
  999. llvm_unreachable("Unsupported matrix load/store operation");
  1000. }
  1001. }
  1002. Value *HLMatrixLowerPass::lowerHLLoad(Value *MatPtr, bool RowMajor, IRBuilder<> &Builder) {
  1003. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1004. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1005. if (LoweredPtr == nullptr) {
  1006. // Can't lower this here, defer to HL signature lower
  1007. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
  1008. return callHLFunction(
  1009. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1010. MatTy.getLoweredVectorTypeForReg(), { Builder.getInt32((uint32_t)Opcode), MatPtr }, Builder);
  1011. }
  1012. return MatTy.emitLoweredLoad(LoweredPtr, Builder);
  1013. }
  1014. Value *HLMatrixLowerPass::lowerHLStore(Value *MatVal, Value *MatPtr, bool RowMajor, bool Return, IRBuilder<> &Builder) {
  1015. DXASSERT(MatVal->getType() == MatPtr->getType()->getPointerElementType(),
  1016. "Matrix store value/pointer type mismatch.");
  1017. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, Builder);
  1018. Value *LoweredVal = getLoweredByValOperand(MatVal, Builder);
  1019. if (LoweredPtr == nullptr) {
  1020. // Can't lower the pointer here, defer to HL signature lower
  1021. HLMatLoadStoreOpcode Opcode = RowMajor ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
  1022. return callHLFunction(
  1023. *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
  1024. Return ? LoweredVal->getType() : Builder.getVoidTy(),
  1025. { Builder.getInt32((uint32_t)Opcode), MatPtr, LoweredVal }, Builder);
  1026. }
  1027. HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
  1028. StoreInst *LoweredStore = MatTy.emitLoweredStore(LoweredVal, LoweredPtr, Builder);
  1029. // If the intrinsic returned a value, return the stored lowered value
  1030. return Return ? LoweredVal : LoweredStore;
  1031. }
  1032. static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> Builder) {
  1033. DXASSERT(SrcVal->getType()->isVectorTy() == DstTy->isVectorTy(),
  1034. "Scalar/vector type mismatch in numerical conversion.");
  1035. Type *SrcTy = SrcVal->getType();
  1036. // Conversions between equivalent types are no-ops,
  1037. // even between signed/unsigned variants.
  1038. if (SrcTy == DstTy) return SrcVal;
  1039. // Conversions to bools are comparisons
  1040. if (DstTy->getScalarSizeInBits() == 1) {
  1041. // fcmp une is what regular clang uses in C++ for (bool)f;
  1042. return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
  1043. ? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
  1044. : Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
  1045. }
  1046. // Cast necessary
  1047. bool SrcIsUnsigned = Opcode == HLCastOpcode::FromUnsignedCast ||
  1048. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1049. bool DstIsUnsigned = Opcode == HLCastOpcode::ToUnsignedCast ||
  1050. Opcode == HLCastOpcode::UnsignedUnsignedCast;
  1051. auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
  1052. SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
  1053. return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
  1054. }
  1055. Value *HLMatrixLowerPass::lowerHLCast(Value *Src, Type *DstTy, HLCastOpcode Opcode, IRBuilder<> &Builder) {
  1056. // The opcode really doesn't mean much here, the types involved are what drive most of the casting.
  1057. DXASSERT(Opcode != HLCastOpcode::HandleToResCast, "Unexpected matrix cast opcode.");
  1058. if (dxilutil::IsIntegerOrFloatingPointType(Src->getType())) {
  1059. // Scalar to matrix splat
  1060. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1061. // Apply element conversion
  1062. Value *Result = convertScalarOrVector(Src,
  1063. MatDstTy.getElementTypeForReg(), Opcode, Builder);
  1064. // Splat to a vector
  1065. Result = Builder.CreateInsertElement(
  1066. UndefValue::get(VectorType::get(Result->getType(), 1)),
  1067. Result, static_cast<uint64_t>(0));
  1068. return Builder.CreateShuffleVector(Result, Result,
  1069. ConstantVector::getSplat(MatDstTy.getNumElements(), Builder.getInt32(0)));
  1070. }
  1071. else if (VectorType *SrcVecTy = dyn_cast<VectorType>(Src->getType())) {
  1072. // Vector to matrix
  1073. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1074. Value *Result = Src;
  1075. // We might need to truncate
  1076. if (MatDstTy.getNumElements() < SrcVecTy->getNumElements()) {
  1077. SmallVector<int, 4> ShuffleIndices;
  1078. for (unsigned Idx = 0; Idx < MatDstTy.getNumElements(); ++Idx)
  1079. ShuffleIndices.emplace_back(static_cast<int>(Idx));
  1080. Result = Builder.CreateShuffleVector(Src, Src, ShuffleIndices);
  1081. }
  1082. // Apply element conversion
  1083. return convertScalarOrVector(Result,
  1084. MatDstTy.getLoweredVectorTypeForReg(), Opcode, Builder);
  1085. }
  1086. // Source must now be a matrix
  1087. HLMatrixType MatSrcTy = HLMatrixType::cast(Src->getType());
  1088. VectorType* LoweredSrcTy = MatSrcTy.getLoweredVectorTypeForReg();
  1089. Value *LoweredSrc;
  1090. if (isa<Argument>(Src)) {
  1091. // Function arguments are lowered in HLSignatureLower.
  1092. // Initial codegen first generates those cast intrinsics to tell us how to lower them into vectors.
  1093. // Preserve them, but change the return type to vector.
  1094. DXASSERT(Opcode == HLCastOpcode::ColMatrixToVecCast || Opcode == HLCastOpcode::RowMatrixToVecCast,
  1095. "Unexpected cast of matrix argument.");
  1096. LoweredSrc = callHLFunction(*m_pModule, HLOpcodeGroup::HLCast, static_cast<unsigned>(Opcode),
  1097. LoweredSrcTy, { Builder.getInt32((uint32_t)Opcode), Src }, Builder);
  1098. }
  1099. else {
  1100. LoweredSrc = getLoweredByValOperand(Src, Builder);
  1101. }
  1102. DXASSERT_NOMSG(LoweredSrc->getType() == LoweredSrcTy);
  1103. Value* Result = LoweredSrc;
  1104. Type* LoweredDstTy = DstTy;
  1105. if (dxilutil::IsIntegerOrFloatingPointType(DstTy)) {
  1106. // Matrix to scalar
  1107. Result = Builder.CreateExtractElement(LoweredSrc, static_cast<uint64_t>(0));
  1108. }
  1109. else if (DstTy->isVectorTy()) {
  1110. // Matrix to vector
  1111. VectorType *DstVecTy = cast<VectorType>(DstTy);
  1112. DXASSERT(DstVecTy->getNumElements() <= LoweredSrcTy->getNumElements(),
  1113. "Cannot cast matrix to a larger vector.");
  1114. // We might have to truncate
  1115. if (DstTy->getVectorNumElements() < LoweredSrcTy->getNumElements()) {
  1116. SmallVector<int, 3> ShuffleIndices;
  1117. for (unsigned Idx = 0; Idx < DstVecTy->getNumElements(); ++Idx)
  1118. ShuffleIndices.emplace_back(static_cast<int>(Idx));
  1119. Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
  1120. }
  1121. }
  1122. else {
  1123. // Destination must now be a matrix too
  1124. HLMatrixType MatDstTy = HLMatrixType::cast(DstTy);
  1125. // Apply any changes at the matrix level: orientation changes and truncation
  1126. if (Opcode == HLCastOpcode::ColMatrixToRowMatrix)
  1127. Result = MatSrcTy.emitLoweredVectorColToRow(Result, Builder);
  1128. else if (Opcode == HLCastOpcode::RowMatrixToColMatrix)
  1129. Result = MatSrcTy.emitLoweredVectorRowToCol(Result, Builder);
  1130. else if (MatDstTy.getNumRows() != MatSrcTy.getNumRows()
  1131. || MatDstTy.getNumColumns() != MatSrcTy.getNumColumns()) {
  1132. // Apply truncation
  1133. DXASSERT(MatDstTy.getNumRows() <= MatSrcTy.getNumRows()
  1134. && MatDstTy.getNumColumns() <= MatSrcTy.getNumColumns(),
  1135. "Unexpected matrix cast between incompatible dimensions.");
  1136. SmallVector<int, 16> ShuffleIndices;
  1137. for (unsigned RowIdx = 0; RowIdx < MatDstTy.getNumRows(); ++RowIdx)
  1138. for (unsigned ColIdx = 0; ColIdx < MatDstTy.getNumColumns(); ++ColIdx)
  1139. ShuffleIndices.emplace_back(static_cast<int>(MatSrcTy.getRowMajorIndex(RowIdx, ColIdx)));
  1140. Result = Builder.CreateShuffleVector(Result, Result, ShuffleIndices);
  1141. }
  1142. LoweredDstTy = MatDstTy.getLoweredVectorTypeForReg();
  1143. DXASSERT(Result->getType()->getVectorNumElements() == LoweredDstTy->getVectorNumElements(),
  1144. "Unexpected matrix src/dst lowered element count mismatch after truncation.");
  1145. }
  1146. // Apply element conversion
  1147. return convertScalarOrVector(Result, LoweredDstTy, Opcode, Builder);
  1148. }
  1149. Value *HLMatrixLowerPass::lowerHLSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
  1150. switch (Opcode) {
  1151. case HLSubscriptOpcode::RowMatElement:
  1152. case HLSubscriptOpcode::ColMatElement:
  1153. return lowerHLMatElementSubscript(Call,
  1154. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatElement);
  1155. case HLSubscriptOpcode::RowMatSubscript:
  1156. case HLSubscriptOpcode::ColMatSubscript:
  1157. return lowerHLMatSubscript(Call,
  1158. /* RowMajor */ Opcode == HLSubscriptOpcode::RowMatSubscript);
  1159. case HLSubscriptOpcode::DefaultSubscript:
  1160. case HLSubscriptOpcode::CBufferSubscript:
  1161. // Those get lowered during HLOperationLower,
  1162. // and the return type must stay unchanged (as a matrix)
  1163. // to provide the metadata to properly emit the loads.
  1164. return nullptr;
  1165. default:
  1166. llvm_unreachable("Unexpected matrix subscript opcode.");
  1167. }
  1168. }
  1169. Value *HLMatrixLowerPass::lowerHLMatElementSubscript(CallInst *Call, bool RowMajor) {
  1170. (void)RowMajor; // It doesn't look like we actually need this?
  1171. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1172. Constant *IdxVec = cast<Constant>(Call->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
  1173. VectorType *IdxVecTy = cast<VectorType>(IdxVec->getType());
  1174. // Get the loaded lowered vector element indices
  1175. SmallVector<Value*, 4> ElemIndices;
  1176. ElemIndices.reserve(IdxVecTy->getNumElements());
  1177. for (unsigned VecIdx = 0; VecIdx < IdxVecTy->getNumElements(); ++VecIdx) {
  1178. ElemIndices.emplace_back(IdxVec->getAggregateElement(VecIdx));
  1179. }
  1180. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1181. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1182. return nullptr;
  1183. }
  1184. Value *HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, bool RowMajor) {
  1185. (void)RowMajor; // It doesn't look like we actually need this?
  1186. Value *MatPtr = Call->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
  1187. // Gather the indices, checking if they are all constant
  1188. SmallVector<Value*, 4> ElemIndices;
  1189. for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < Call->getNumArgOperands(); ++Idx) {
  1190. ElemIndices.emplace_back(Call->getArgOperand(Idx));
  1191. }
  1192. lowerHLMatSubscript(Call, MatPtr, ElemIndices);
  1193. // We did our own replacement of uses, opt-out of having the caller does it for us.
  1194. return nullptr;
  1195. }
  1196. void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, SmallVectorImpl<Value*> &ElemIndices) {
  1197. DXASSERT_NOMSG(HLMatrixType::isMatrixPtr(MatPtr->getType()));
  1198. IRBuilder<> CallBuilder(Call);
  1199. Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
  1200. if (LoweredPtr == nullptr) return;
  1201. // For global variables, we can GEP directly into the lowered vector pointer.
  1202. // This is necessary to support group shared memory atomics and the likes.
  1203. Value *RootPtr = LoweredPtr;
  1204. while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
  1205. RootPtr = GEP->getPointerOperand();
  1206. bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
  1207. // Just constructing this does all the work
  1208. HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
  1209. DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
  1210. addToDeadInsts(Call);
  1211. }
  1212. // Lowers StructuredBuffer<matrix>[index] or similar with constant buffers
  1213. Value *HLMatrixLowerPass::lowerHLMatResourceSubscript(CallInst *Call, HLSubscriptOpcode Opcode) {
  1214. // Just replace the intrinsic by its equivalent with a lowered return type
  1215. IRBuilder<> Builder(Call);
  1216. SmallVector<Value*, 4> Args;
  1217. Args.reserve(Call->getNumArgOperands());
  1218. for (Value *Arg : Call->arg_operands())
  1219. Args.emplace_back(Arg);
  1220. Type *LoweredRetTy = HLMatrixType::getLoweredType(Call->getType());
  1221. return callHLFunction(*m_pModule, HLOpcodeGroup::HLSubscript, static_cast<unsigned>(Opcode),
  1222. LoweredRetTy, Args, Builder);
  1223. }
  1224. Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
  1225. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1226. // Figure out the result type
  1227. HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
  1228. VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
  1229. // Handle case where produced by EmitHLSLFlatConversion where there's one
  1230. // vector argument, instead of scalar arguments.
  1231. if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
  1232. Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
  1233. getType()->isVectorTy()) {
  1234. Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
  1235. DXASSERT(LoweredTy->getNumElements() ==
  1236. LoweredVec->getType()->getVectorNumElements(),
  1237. "Invalid matrix init argument vector element count.");
  1238. return LoweredVec;
  1239. }
  1240. DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
  1241. "Invalid matrix init argument count.");
  1242. // Build the result vector from the init args.
  1243. // Both the args and the result vector are in row-major order, so no shuffling is necessary.
  1244. IRBuilder<> Builder(Call);
  1245. Value *LoweredVec = UndefValue::get(LoweredTy);
  1246. for (unsigned VecElemIdx = 0; VecElemIdx < LoweredTy->getNumElements(); ++VecElemIdx) {
  1247. Value *ArgVal = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx + VecElemIdx);
  1248. DXASSERT(dxilutil::IsIntegerOrFloatingPointType(ArgVal->getType()),
  1249. "Expected only scalars in matrix initialization.");
  1250. LoweredVec = Builder.CreateInsertElement(LoweredVec, ArgVal, static_cast<uint64_t>(VecElemIdx));
  1251. }
  1252. return LoweredVec;
  1253. }
  1254. Value *HLMatrixLowerPass::lowerHLSelect(CallInst *Call) {
  1255. DXASSERT(GetHLOpcode(Call) == 0, "Unexpected matrix init opcode.");
  1256. Value *Cond = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
  1257. Value *TrueMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx);
  1258. Value *FalseMat = Call->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx);
  1259. DXASSERT(TrueMat->getType() == FalseMat->getType(),
  1260. "Unexpected type mismatch between matrix ternary operator values.");
  1261. #ifndef NDEBUG
  1262. // Assert that if the condition is a matrix, it matches the dimensions of the values
  1263. if (HLMatrixType MatCondTy = HLMatrixType::dyn_cast(Cond->getType())) {
  1264. HLMatrixType ValMatTy = HLMatrixType::cast(TrueMat->getType());
  1265. DXASSERT(MatCondTy.getNumRows() == ValMatTy.getNumRows()
  1266. && MatCondTy.getNumColumns() == ValMatTy.getNumColumns(),
  1267. "Unexpected mismatch between ternary operator condition and value matrix dimensions.");
  1268. }
  1269. #endif
  1270. IRBuilder<> Builder(Call);
  1271. Value *LoweredCond = getLoweredByValOperand(Cond, Builder);
  1272. Value *LoweredTrueVec = getLoweredByValOperand(TrueMat, Builder);
  1273. Value *LoweredFalseVec = getLoweredByValOperand(FalseMat, Builder);
  1274. Value *Result = UndefValue::get(LoweredTrueVec->getType());
  1275. bool IsScalarCond = !LoweredCond->getType()->isVectorTy();
  1276. unsigned NumElems = Result->getType()->getVectorNumElements();
  1277. for (uint64_t ElemIdx = 0; ElemIdx < NumElems; ++ElemIdx) {
  1278. Value *ElemCond = IsScalarCond ? LoweredCond
  1279. : Builder.CreateExtractElement(LoweredCond, ElemIdx);
  1280. Value *ElemTrueVal = Builder.CreateExtractElement(LoweredTrueVec, ElemIdx);
  1281. Value *ElemFalseVal = Builder.CreateExtractElement(LoweredFalseVec, ElemIdx);
  1282. Value *ResultElem = Builder.CreateSelect(ElemCond, ElemTrueVal, ElemFalseVal);
  1283. Result = Builder.CreateInsertElement(Result, ResultElem, ElemIdx);
  1284. }
  1285. return Result;
  1286. }