///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// DxilCleanup.cpp                                                           //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source    //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// Optimization of DXIL after conversion from DXBC.                          //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
//===----------------------------------------------------------------------===//
// DXIL Cleanup Transformation
//===----------------------------------------------------------------------===//
//
// The pass cleans up DXIL obtained after conversion from DXBC.
// Essentially, the pass constructs efficient SSA for DXBC r-registers and
// performs the following:
// 1. Removes TempRegStore/TempRegLoad calls, replacing DXBC registers with
//    either temporary or global LLVM values.
// 2. Minimizes the number of bitcasts induced by the lack of types in DXBC.
// 3. Removes helper operations that support DXBC conditionals translated
//    to i1.
// 4. Recovers doubles from pairs of 32-bit DXBC registers.
// 5. Removes MinPrecXRegLoad and MinPrecXRegStore for DXBC indexable,
//    min-precision x-registers.
//
// Clarification of important algorithmic decisions:
// 1. A live range (LR) is all defs connected via phi-nodes. A straightforward
//    recursive algorithm is used to collect an LR's set of defs.
// 2. Live ranges are "connected" to other live ranges via DXIL bitcasts.
//    This creates a bitcast graph.
// 3. Live ranges are assigned types based on the number of float (F) or
//    integer (I) defs. A bitcast def initially has an unknown type (U).
//    Each LR is assigned a type only once. LRs are processed in a dynamic
//    order biased towards LRs with known types, e.g., numF > numI + numU.
//    When an LR is assigned its final type, emanating bitcasts become
//    "resolved" and contribute the desired type to the neighboring LRs.
// 4. After all LRs are processed, each LR is assigned its final type based on
//    the number of F and I defs. If the type changed from the initial
//    assumption, the code is rewritten accordingly: new bitcasts are
//    inserted for correctness.
// 5. After every LR type is finalized, chains of bitcasts are cleaned up.
// 6. The algorithm splits 16- and 32-bit LRs.
// 7. Registers that are used in both an entry function and another
//    subroutine are represented as global variables.
//
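// As a simplified, hypothetical illustration of step 1, a DXBC register
// round-trip such as
//
//   call void @dx.op.tempRegStore.f32(i32 1, i32 0, float %x)
//   %y = call float @dx.op.tempRegLoad.f32(i32 0, i32 0)
//
// first becomes a store/load through a per-register variable (e.g., an
// alloca %dx.v32.r0), which SSA construction (mem2reg) then promotes so that
// uses of %y become direct uses of %x.
//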
#include "DxilConvPasses/DxilCleanup.h"

#include "dxc/Support/Global.h"
#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/DXIL/DxilInstructions.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include <cstdio> // std::puts
#include <utility>
#include <vector>
#include <set>
#include <queue>
#include <algorithm>

using namespace llvm;
using namespace llvm::legacy;
using namespace hlsl;
using std::string;
using std::vector;
using std::set;
using std::pair;

#define DXILCLEANUP_DBG 0

#define DEBUG_TYPE "dxilcleanup"

#if DXILCLEANUP_DBG
static void debugprint(const char *banner, Module &M) {
  std::string buf;
  raw_string_ostream os(buf);
  os << banner << "\n";
  M.print(os, nullptr);
  os.flush();
  std::puts(buf.c_str());
}
#endif
namespace DxilCleanupNS {

/// Use this class to optimize DXIL after conversion from DXBC.
class DxilCleanup : public ModulePass {
public:
  static char ID;

  DxilCleanup() : ModulePass(ID), m_pCtx(nullptr), m_pModule(nullptr) {
    initializeDxilCleanupPass(*PassRegistry::getPassRegistry());
  }

  virtual bool runOnModule(Module &M);
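
  // A live range groups every def (instruction) connected through phi-nodes.
  // numF/numI count uses of the defs as float/integer values; numU counts
  // uses that are DXIL bitcasts, whose desired type is not yet known.
  // bitcastMap records, per neighboring live-range id, how many bitcast edges
  // connect the two ranges.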
  struct LiveRange {
    unsigned id;
    SmallVector<Value *, 4> defs;
    SmallDenseMap<unsigned, unsigned, 4> bitcastMap;
    unsigned numI;
    unsigned numF;
    unsigned numU;
    Type *pNewType;

    LiveRange() : id(0), numI(0), numF(0), numU(0), pNewType(nullptr) {}
    LiveRange operator=(const LiveRange &) = delete;
    // These constructors cannot be deleted because std::vector requires them,
    // even though they should never be invoked, so assert if they are hit.
    LiveRange(const LiveRange &other)
        : id(other.id), numI(other.numI), numF(other.numF), numU(other.numU),
          pNewType(other.pNewType), defs(other.defs),
          bitcastMap(other.bitcastMap) {
      DXASSERT_NOMSG(false);
    }
    LiveRange(LiveRange &&other)
        : id(other.id), numI(other.numI), numF(other.numF), numU(other.numU),
          pNewType(other.pNewType), defs(std::move(other.defs)),
          bitcastMap(std::move(other.bitcastMap)) {
      DXASSERT_NOMSG(false);
    }

    unsigned GetCaseNumber() const;
    void GuessType(LLVMContext &Ctx);
    bool operator<(const LiveRange &other) const;
  };
private:
  const unsigned kRegCompAlignment = 4;
  LLVMContext *m_pCtx;
  Module *m_pModule;
  DxilModule *m_pDxilModule;
  vector<LiveRange> m_LiveRanges;
  DenseMap<Value *, unsigned> m_LiveRangeMap;

  void OptimizeIdxRegDecls();
  bool OptimizeIdxRegDecls_CollectUsage(Value *pDecl, unsigned &numF, unsigned &numI);
  bool OptimizeIdxRegDecls_CollectUsageForUser(User *U, bool bFlt, bool bInt,
                                               unsigned &numF, unsigned &numI);
  Type *OptimizeIdxRegDecls_DeclareType(Type *pOldType);
  void OptimizeIdxRegDecls_ReplaceDecl(Value *pOldDecl, Value *pNewDecl,
                                       vector<Instruction *> &InstrToErase);
  void OptimizeIdxRegDecls_ReplaceGEPUse(Value *pOldGEPUser, Value *pNewGEP,
                                         Value *pOldDecl, Value *pNewDecl,
                                         vector<Instruction *> &InstrToErase);
  void RemoveRegLoadStore();
  void ConstructSSA();
  void CollectLiveRanges();
  void CountLiveRangeRec(unsigned LRId, Instruction *pInst);
  void RecoverLiveRangeRec(LiveRange &LR, Instruction *pInst);
  void InferLiveRangeTypes();
  void ChangeLiveRangeTypes();
  void CleanupPatterns();
  void RemoveDeadCode();

  Value *CastValue(Value *pValue, Type *pToType, Instruction *pOrigInst);
  bool IsDxilBitcast(Value *pValue);
  ArrayType *GetDeclArrayType(Type *pSrcType);
  Type *GetDeclScalarType(Type *pSrcType);
};

char DxilCleanup::ID = 0;

//------------------------------------------------------------------------------
//
// DxilCleanup methods.
//
bool DxilCleanup::runOnModule(Module &M) {
  m_pModule = &M;
  m_pCtx = &M.getContext();
  m_pDxilModule = &m_pModule->GetOrCreateDxilModule();

  OptimizeIdxRegDecls();
  RemoveRegLoadStore();
  ConstructSSA();
  CollectLiveRanges();
  InferLiveRangeTypes();
  ChangeLiveRangeTypes();
  CleanupPatterns();
  RemoveDeadCode();

  return true;
}
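
// Rewrites x-register (indexable register) declarations whose scalar type
// disagrees with how they are predominantly used: globals used by a single
// function are first demoted to allocas, then float declarations that are
// mostly accessed as integers (and vice versa) are re-declared with the
// opposite element type, moving the bitcasts to the access sites where later
// phases can fold them.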
void DxilCleanup::OptimizeIdxRegDecls() {
  // 1. Convert global x-register decl into alloca if used only in one function.
  for (auto itGV = m_pModule->global_begin(), endGV = m_pModule->global_end(); itGV != endGV; ) {
    GlobalVariable *GV = itGV;
    ++itGV;
    if (GV->isConstant() || GV->getLinkage() != GlobalValue::InternalLinkage)
      continue;
    PointerType *pPtrType = dyn_cast<PointerType>(GV->getType());
    if (!pPtrType || pPtrType->getAddressSpace() != DXIL::kDefaultAddrSpace)
      continue;
    Type *pElemType = pPtrType->getElementType();
    Function *F = nullptr;
    for (User *U : GV->users()) {
      Instruction *I = dyn_cast<Instruction>(U);
      if (!I || (F && I->getParent()->getParent() != F)) {
        F = nullptr;
        break;
      }
      F = cast<Function>(I->getParent()->getParent());
    }
    if (F) {
      // Promote to alloca.
      Instruction *pAnchor = F->getEntryBlock().begin();
      AllocaInst *AI = new AllocaInst(pElemType, nullptr, GV->getName(), pAnchor);
      AI->setAlignment(GV->getAlignment());
      GV->replaceAllUsesWith(AI);
      GV->eraseFromParent();
    }
  }

  // 2. Collect x-register alloca usage stats and change type, if profitable.
  for (auto itF = m_pModule->begin(), endFn = m_pModule->end(); itF != endFn; ++itF) {
    Function *F = itF;
    if (F->empty())
      continue;
    BasicBlock *pEntryBB = &F->getEntryBlock();
    vector<Instruction *> InstrToErase;
    for (auto itInst = pEntryBB->begin(), endInst = pEntryBB->end(); itInst != endInst; ++itInst) {
      AllocaInst *AI = dyn_cast<AllocaInst>(itInst);
      if (!AI)
        continue;
      Type *pScalarType = GetDeclScalarType(AI->getType());
      if (pScalarType != Type::getFloatTy(*m_pCtx) && pScalarType != Type::getHalfTy(*m_pCtx) &&
          pScalarType != Type::getInt32Ty(*m_pCtx) && pScalarType != Type::getInt16Ty(*m_pCtx)) {
        continue;
      }
      // Collect usage stats and potentially change decl type.
      unsigned numF, numI;
      if (OptimizeIdxRegDecls_CollectUsage(AI, numF, numI)) {
        Type *pScalarType = GetDeclScalarType(AI->getType());
        if ((pScalarType->isFloatingPointTy() && numI > numF) ||
            (pScalarType->isIntegerTy() && numF >= numI)) {
          Type *pNewType = OptimizeIdxRegDecls_DeclareType(AI->getType());
          if (pNewType) {
            // Replace alloca.
            AllocaInst *AI2 = new AllocaInst(pNewType, nullptr, AI->getName(), AI);
            AI2->setAlignment(AI->getAlignment());
            OptimizeIdxRegDecls_ReplaceDecl(AI, AI2, InstrToErase);
            InstrToErase.emplace_back(AI);
          }
        }
      }
    }
    for (auto *I : InstrToErase) {
      I->eraseFromParent();
    }
  }

  // 3. Collect x-register global decl usage stats and change type, if profitable.
  llvm::SmallVector<GlobalVariable *, 4> GVWorklist;
  for (auto itGV = m_pModule->global_begin(), endGV = m_pModule->global_end(); itGV != endGV; ) {
    GlobalVariable *pOldGV = itGV;
    ++itGV;
    if (pOldGV->isConstant())
      continue;
    PointerType *pOldPtrType = dyn_cast<PointerType>(pOldGV->getType());
    if (!pOldPtrType || pOldPtrType->getAddressSpace() != DXIL::kDefaultAddrSpace)
      continue;
    unsigned numF, numI;
    if (OptimizeIdxRegDecls_CollectUsage(pOldGV, numF, numI)) {
      Type *pScalarType = GetDeclScalarType(pOldGV->getType());
      if ((pScalarType->isFloatingPointTy() && numI > numF) ||
          (pScalarType->isIntegerTy() && numF >= numI)) {
        GVWorklist.push_back(pOldGV);
      }
    }
  }
  for (auto pOldGV : GVWorklist) {
    if (Type *pNewType = OptimizeIdxRegDecls_DeclareType(pOldGV->getType())) {
      // Replace global decl.
      PointerType *pOldPtrType = dyn_cast<PointerType>(pOldGV->getType());
      GlobalVariable *pNewGV = new GlobalVariable(*m_pModule, pNewType, false, pOldGV->getLinkage(),
                                                  UndefValue::get(pNewType), pOldGV->getName(),
                                                  nullptr, pOldGV->getThreadLocalMode(),
                                                  pOldPtrType->getAddressSpace());
      vector<Instruction *> InstrToErase;
      OptimizeIdxRegDecls_ReplaceDecl(pOldGV, pNewGV, InstrToErase);
      for (auto *I : InstrToErase) {
        I->eraseFromParent();
      }
      pOldGV->eraseFromParent();
    }
  }
}

ArrayType *DxilCleanup::GetDeclArrayType(Type *pSrcType) {
  PointerType *pPtrType = dyn_cast<PointerType>(pSrcType);
  if (!pPtrType)
    return nullptr;
  if (ArrayType *pArrayType = dyn_cast<ArrayType>(pPtrType->getElementType())) {
    return pArrayType;
  }
  return nullptr;
}

Type *DxilCleanup::GetDeclScalarType(Type *pSrcType) {
  PointerType *pPtrType = dyn_cast<PointerType>(pSrcType);
  if (!pPtrType)
    return nullptr;
  Type *pScalarType = pPtrType->getElementType();
  if (ArrayType *pArrayType = dyn_cast<ArrayType>(pScalarType)) {
    pScalarType = pArrayType->getArrayElementType();
  }
  return pScalarType;
}
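
// Maps a decl's type to its counterpart in the opposite domain (float <-> i32,
// half <-> i16), preserving the array shape for indexable registers. Any other
// scalar type fails with DXC_E_OPTIMIZATION_FAILED.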
Type *DxilCleanup::OptimizeIdxRegDecls_DeclareType(Type *pOldType) {
  Type *pNewType = nullptr;
  Type *pScalarType = GetDeclScalarType(pOldType);
  if (ArrayType *pArrayType = GetDeclArrayType(pOldType)) {
    uint64_t ArraySize = pArrayType->getArrayNumElements();
    if (pScalarType == Type::getFloatTy(*m_pCtx)) {
      pNewType = ArrayType::get(Type::getInt32Ty(*m_pCtx), ArraySize);
    } else if (pScalarType == Type::getHalfTy(*m_pCtx)) {
      pNewType = ArrayType::get(Type::getInt16Ty(*m_pCtx), ArraySize);
    } else if (pScalarType == Type::getInt32Ty(*m_pCtx)) {
      pNewType = ArrayType::get(Type::getFloatTy(*m_pCtx), ArraySize);
    } else if (pScalarType == Type::getInt16Ty(*m_pCtx)) {
      pNewType = ArrayType::get(Type::getHalfTy(*m_pCtx), ArraySize);
    } else {
      IFT(DXC_E_OPTIMIZATION_FAILED);
    }
  } else {
    if (pScalarType == Type::getFloatTy(*m_pCtx)) {
      pNewType = Type::getInt32Ty(*m_pCtx);
    } else if (pScalarType == Type::getHalfTy(*m_pCtx)) {
      pNewType = Type::getInt16Ty(*m_pCtx);
    } else if (pScalarType == Type::getInt32Ty(*m_pCtx)) {
      pNewType = Type::getFloatTy(*m_pCtx);
    } else if (pScalarType == Type::getInt16Ty(*m_pCtx)) {
      pNewType = Type::getHalfTy(*m_pCtx);
    } else {
      IFT(DXC_E_OPTIMIZATION_FAILED);
    }
  }
  return pNewType;
}

bool DxilCleanup::OptimizeIdxRegDecls_CollectUsage(Value *pDecl, unsigned &numF, unsigned &numI) {
  numF = numI = 0;
  Type *pScalarType = GetDeclScalarType(pDecl->getType());
  if (!pScalarType)
    return false;
  bool bFlt = pScalarType == Type::getFloatTy(*m_pCtx) || pScalarType == Type::getHalfTy(*m_pCtx);
  bool bInt = pScalarType == Type::getInt32Ty(*m_pCtx) || pScalarType == Type::getInt16Ty(*m_pCtx);
  if (!(bFlt || bInt))
    return false;
  for (User *U : pDecl->users()) {
    if (GetElementPtrInst *pGEP = dyn_cast<GetElementPtrInst>(U)) {
      for (User *U2 : pGEP->users()) {
        if (!OptimizeIdxRegDecls_CollectUsageForUser(U2, bFlt, bInt, numF, numI))
          return false;
      }
    } else if (GEPOperator *pGEP = dyn_cast<GEPOperator>(U)) {
      for (User *U2 : pGEP->users()) {
        if (!OptimizeIdxRegDecls_CollectUsageForUser(U2, bFlt, bInt, numF, numI))
          return false;
      }
    } else if (BitCastInst *pBC = dyn_cast<BitCastInst>(U)) {
      if (pBC->getType() != Type::getDoublePtrTy(*m_pCtx))
        return false;
    } else {
      return false;
    }
  }
  return true;
}

bool DxilCleanup::OptimizeIdxRegDecls_CollectUsageForUser(User *U, bool bFlt, bool bInt,
                                                          unsigned &numF, unsigned &numI) {
  if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
    for (User *U2 : LI->users()) {
      if (!IsDxilBitcast(U2)) {
        if (bFlt) numF++;
        if (bInt) numI++;
      } else {
        if (bFlt) numI++;
        if (bInt) numF++;
      }
    }
  } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
    Value *pValue = SI->getValueOperand();
    if (!IsDxilBitcast(pValue)) {
      if (bFlt) numF++;
      if (bInt) numI++;
    } else {
      if (bFlt) numI++;
      if (bInt) numF++;
    }
  } else {
    return false;
  }
  return true;
}

void DxilCleanup::OptimizeIdxRegDecls_ReplaceDecl(Value *pOldDecl, Value *pNewDecl,
                                                  vector<Instruction *> &InstrToErase) {
  for (auto itU = pOldDecl->use_begin(), endU = pOldDecl->use_end(); itU != endU; ++itU) {
    User *I = itU->getUser();
    if (GetElementPtrInst *pOldGEP = dyn_cast<GetElementPtrInst>(I)) {
      // Case 1. Load.
      //   %44 = getelementptr [24 x float], [24 x float]* %dx.v32.x0, i32 0, i32 %43
      //   %45 = load float, float* %44, align 4
      //   %46 = add float %45, ...
      // becomes
      //   %44 = getelementptr [24 x i32], [24 x i32]* %dx.v32.x0, i32 0, i32 %43
      //   %45 = load i32, i32* %44, align 4
      //   %t1 = call float @dx.op.bitcastI32toF32 i32 %45
      //   %46 = add float %t1, ...
      //
      // Case 2. Store.
      //   %31 = add float ...
      //   %32 = getelementptr [24 x float], [24 x float]* %dx.v32.x0, i32 0, i32 16
      //   store float %31, float* %32, align 4
      // becomes
      //   %31 = add float ...
      //   %32 = getelementptr [24 x i32], [24 x i32]* %dx.v32.x0, i32 0, i32 16
      //   %t1 = call i32 @dx.op.bitcastF32toI32 float %31
      //   store i32 %t1, i32* %32, align 4
      //
      SmallVector<Value *, 4> GEPIndices;
      for (auto i = pOldGEP->idx_begin(), e = pOldGEP->idx_end(); i != e; i++) {
        GEPIndices.push_back(*i);
      }
      GetElementPtrInst *pNewGEP = GetElementPtrInst::Create(nullptr, pNewDecl, GEPIndices,
                                                             pOldGEP->getName(), pOldGEP->getNextNode());
      for (auto itU2 = pOldGEP->use_begin(), endU2 = pOldGEP->use_end(); itU2 != endU2; ++itU2) {
        Value *pOldGEPUser = itU2->getUser();
        OptimizeIdxRegDecls_ReplaceGEPUse(pOldGEPUser, pNewGEP, pOldDecl, pNewDecl, InstrToErase);
      }
      InstrToErase.emplace_back(pOldGEP);
    } else if (GEPOperator *pOldGEP = dyn_cast<GEPOperator>(I)) {
      // The cases are the same as for the GetElementPtrInst above.
      SmallVector<Value *, 4> GEPIndices;
      for (auto i = pOldGEP->idx_begin(), e = pOldGEP->idx_end(); i != e; i++) {
        GEPIndices.push_back(*i);
      }
      Type *pNewGEPElemType = cast<PointerType>(pNewDecl->getType())->getElementType();
      Constant *pNewGEPOp = ConstantExpr::getGetElementPtr(pNewGEPElemType, cast<Constant>(pNewDecl),
                                                           GEPIndices, pOldGEP->isInBounds());
      GEPOperator *pNewGEP = cast<GEPOperator>(pNewGEPOp);
      for (auto itU2 = pOldGEP->use_begin(), endU2 = pOldGEP->use_end(); itU2 != endU2; ++itU2) {
        Value *pOldGEPUser = itU2->getUser();
        OptimizeIdxRegDecls_ReplaceGEPUse(pOldGEPUser, pNewGEP, pOldDecl, pNewDecl, InstrToErase);
      }
    } else if (BitCastInst *pOldBC = dyn_cast<BitCastInst>(I)) {
      //   %1 = bitcast [24 x float]* %dx.v32.x0 to double*
      // becomes
      //   %1 = bitcast [24 x i32]* %dx.v32.x0 to double*
      BitCastInst *pNewBC = new BitCastInst(pNewDecl, pOldBC->getType(), pOldBC->getName(),
                                            pOldBC->getNextNode());
      pOldBC->replaceAllUsesWith(pNewBC);
      InstrToErase.emplace_back(pOldBC);
    } else {
      IFT(DXC_E_OPTIMIZATION_FAILED);
    }
  }
}

void DxilCleanup::OptimizeIdxRegDecls_ReplaceGEPUse(Value *pOldGEPUser, Value *pNewGEP,
                                                    Value *pOldDecl, Value *pNewDecl,
                                                    vector<Instruction *> &InstrToErase) {
  if (LoadInst *pOldLI = dyn_cast<LoadInst>(pOldGEPUser)) {
    LoadInst *pNewLI = new LoadInst(pNewGEP, pOldLI->getName(), pOldLI->getNextNode());
    pNewLI->setAlignment(pOldLI->getAlignment());
    Value *pNewValue = CastValue(pNewLI, GetDeclScalarType(pOldDecl->getType()), pNewLI->getNextNode());
    pOldLI->replaceAllUsesWith(pNewValue);
    InstrToErase.emplace_back(pOldLI);
  } else if (StoreInst *pOldSI = dyn_cast<StoreInst>(pOldGEPUser)) {
    Value *pOldValue = pOldSI->getValueOperand();
    Value *pNewValue = CastValue(pOldValue, GetDeclScalarType(pNewDecl->getType()), pOldSI);
    StoreInst *pNewSI = new StoreInst(pNewValue, pNewGEP, pOldSI->getNextNode());
    pNewSI->setAlignment(pOldSI->getAlignment());
    InstrToErase.emplace_back(pOldSI);
  } else {
    IFT(DXC_E_OPTIMIZATION_FAILED);
  }
}
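
// Replaces TempRegLoad/TempRegStore DXIL intrinsics with plain loads and
// stores of per-register variables. Each r-register gets a 32-bit and/or
// 16-bit variable, typed float or int by majority use: an alloca when the
// register stays within one function, or an internal global when it is shared
// between an entry function and a subroutine (or across subroutines).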
void DxilCleanup::RemoveRegLoadStore() {
  struct RegRec {
    unsigned numI32;
    unsigned numF32;
    unsigned numI16;
    unsigned numF16;
    Value *pDecl32;
    Value *pDecl16;
    RegRec() : numI32(0), numF32(0), numI16(0), numF16(0), pDecl32(nullptr), pDecl16(nullptr) {}
  };
  struct FuncRec {
    MapVector<unsigned, RegRec> RegMap;
    bool bEntry;
    bool bCallsOtherFunc;
    FuncRec() : bEntry(false), bCallsOtherFunc(false) {}
  };
  MapVector<Function *, FuncRec> FuncMap;

  // 1. For each r-register, collect usage stats.
  for (auto itF = m_pModule->begin(), endFn = m_pModule->end(); itF != endFn; ++itF) {
    Function *F = itF;
    if (F->empty())
      continue;
    DXASSERT_NOMSG(FuncMap.find(F) == FuncMap.end());
    FuncRec &FR = FuncMap[F];
    // Detect entry.
    if (F == m_pDxilModule->GetEntryFunction() ||
        F == m_pDxilModule->GetPatchConstantFunction()) {
      FR.bEntry = true;
    }
    for (auto itBB = F->begin(), endBB = F->end(); itBB != endBB; ++itBB) {
      BasicBlock *BB = itBB;
      for (auto itInst = BB->begin(), endInst = BB->end(); itInst != endInst; ++itInst) {
        CallInst *CI = dyn_cast<CallInst>(itInst);
        if (!CI)
          continue;
        if (!OP::IsDxilOpFuncCallInst(CI)) {
          FuncMap[F].bCallsOtherFunc = true;
          continue;
        }
        // Obtain register index for TempRegLoad/TempRegStore.
        unsigned regIdx = 0;
        Type *pValType = nullptr;
        if (DxilInst_TempRegLoad TRL = DxilInst_TempRegLoad(CI)) {
          regIdx = dyn_cast<ConstantInt>(TRL.get_index())->getZExtValue();
          pValType = CI->getType();
        } else if (DxilInst_TempRegStore TRS = DxilInst_TempRegStore(CI)) {
          regIdx = dyn_cast<ConstantInt>(TRS.get_index())->getZExtValue();
          pValType = TRS.get_value()->getType();
        } else {
          continue;
        }
        // Update register usage.
        RegRec &reg = FR.RegMap[regIdx];
        if (pValType == Type::getFloatTy(*m_pCtx)) {
          reg.numF32++;
        } else if (pValType == Type::getInt32Ty(*m_pCtx)) {
          reg.numI32++;
        } else if (pValType == Type::getHalfTy(*m_pCtx)) {
          reg.numF16++;
        } else if (pValType == Type::getInt16Ty(*m_pCtx)) {
          reg.numI16++;
        } else {
          IFT(DXC_E_OPTIMIZATION_FAILED);
        }
      }
    }
  }

  // 2. Declare local and global variables to represent each r-register.
  for (auto &itF : FuncMap) {
    Function *F = itF.first;
    FuncRec &FR = itF.second;
    for (auto &itReg : FR.RegMap) {
      unsigned regIdx = itReg.first;
      RegRec &reg = itReg.second;
      DXASSERT_NOMSG(reg.pDecl16 == nullptr && reg.pDecl32 == nullptr);
      enum class DeclKind { None, Alloca, Global };
      DeclKind Decl32Kind = (reg.numF32 + reg.numI32) == 0 ? DeclKind::None : DeclKind::Alloca;
      DeclKind Decl16Kind = (reg.numF16 + reg.numI16) == 0 ? DeclKind::None : DeclKind::Alloca;
      DXASSERT_NOMSG(Decl32Kind == DeclKind::Alloca || Decl16Kind == DeclKind::Alloca);
      unsigned numF32 = reg.numF32, numI32 = reg.numI32, numF16 = reg.numF16, numI16 = reg.numI16;
      if (!FR.bEntry || FR.bCallsOtherFunc) {
        // Check if register is used in another function.
        for (auto &itF2 : FuncMap) {
          Function *F2 = itF2.first;
          FuncRec &FR2 = itF2.second;
          if (F2 == F || (FR.bEntry && FR2.bEntry))
            continue;
          auto itReg2 = FR2.RegMap.find(regIdx);
          if (itReg2 == FR2.RegMap.end())
            continue;
          RegRec &reg2 = itReg2->second;
          if (Decl32Kind == DeclKind::Alloca && (reg2.numF32 + reg2.numI32) > 0) {
            Decl32Kind = DeclKind::Global;
          }
          if (Decl16Kind == DeclKind::Alloca && (reg2.numF16 + reg2.numI16) > 0) {
            Decl16Kind = DeclKind::Global;
          }
          numF32 += reg2.numF32;
          numI32 += reg2.numI32;
          numF16 += reg2.numF16;
          numI16 += reg2.numI16;
        }
      }
      // Declare variables.
      if (Decl32Kind == DeclKind::Alloca) {
        Twine regName = Twine("dx.v32.r") + Twine(regIdx);
        Type *pDeclType = numF32 >= numI32 ? Type::getFloatTy(*m_pCtx) : Type::getInt32Ty(*m_pCtx);
        Instruction *pAnchor = F->getEntryBlock().begin();
        AllocaInst *AI = new AllocaInst(pDeclType, nullptr, regName, pAnchor);
        AI->setAlignment(kRegCompAlignment);
        reg.pDecl32 = AI;
      }
      if (Decl16Kind == DeclKind::Alloca) {
        Twine regName = Twine("dx.v16.r") + Twine(regIdx);
        Type *pDeclType = numF16 >= numI16 ? Type::getHalfTy(*m_pCtx) : Type::getInt16Ty(*m_pCtx);
        Instruction *pAnchor = F->getEntryBlock().begin();
        AllocaInst *AI = new AllocaInst(pDeclType, nullptr, regName, pAnchor);
        AI->setAlignment(kRegCompAlignment);
        reg.pDecl16 = AI;
      }
      if (Decl32Kind == DeclKind::Global) {
        SmallVector<char, 16> regName;
        (Twine("dx.v32.r") + Twine(regIdx)).toStringRef(regName);
        Type *pDeclType = numF32 >= numI32 ? Type::getFloatTy(*m_pCtx) : Type::getInt32Ty(*m_pCtx);
        GlobalVariable *GV = m_pModule->getGlobalVariable(StringRef(regName.data(), regName.size()), true);
        if (!GV) {
          GV = new GlobalVariable(*m_pModule, pDeclType,
                                  false, GlobalValue::InternalLinkage,
                                  UndefValue::get(pDeclType),
                                  regName, nullptr,
                                  GlobalVariable::NotThreadLocal, DXIL::kDefaultAddrSpace);
        }
        GV->setAlignment(kRegCompAlignment);
        reg.pDecl32 = GV;
      }
      if (Decl16Kind == DeclKind::Global) {
        SmallVector<char, 16> regName;
        (Twine("dx.v16.r") + Twine(regIdx)).toStringRef(regName);
        Type *pDeclType = numF16 >= numI16 ? Type::getHalfTy(*m_pCtx) : Type::getInt16Ty(*m_pCtx);
        GlobalVariable *GV = m_pModule->getGlobalVariable(StringRef(regName.data(), regName.size()), true);
        if (!GV) {
          GV = new GlobalVariable(*m_pModule, pDeclType,
                                  false, GlobalValue::InternalLinkage,
                                  UndefValue::get(pDeclType),
                                  regName, nullptr,
                                  GlobalVariable::NotThreadLocal, DXIL::kDefaultAddrSpace);
        }
        GV->setAlignment(kRegCompAlignment);
        reg.pDecl16 = GV;
      }
    }
  }

  // 3. Replace TempRegLoad/Store with load/store to declared variables.
  for (auto itFn = m_pModule->begin(), endFn = m_pModule->end(); itFn != endFn; ++itFn) {
    Function *F = itFn;
    if (F->empty())
      continue;
    DXASSERT_NOMSG(FuncMap.find(F) != FuncMap.end());
    FuncRec &FR = FuncMap[F];
    for (auto itBB = F->begin(), endBB = F->end(); itBB != endBB; ++itBB) {
      BasicBlock *BB = itBB;
      for (auto itInst = BB->begin(), endInst = BB->end(); itInst != endInst; ) {
        Instruction *CI = itInst;
        if (DxilInst_TempRegLoad TRL = DxilInst_TempRegLoad(CI)) {
          // Replace TempRegLoad intrinsic with a load.
          unsigned regIdx = dyn_cast<ConstantInt>(TRL.get_index())->getZExtValue();
          RegRec &reg = FR.RegMap[regIdx];
          Type *pValType = CI->getType();
          Value *pDecl = (pValType == Type::getFloatTy(*m_pCtx) ||
                          pValType == Type::getInt32Ty(*m_pCtx)) ? reg.pDecl32 : reg.pDecl16;
          DXASSERT_NOMSG(pValType != nullptr);
          LoadInst *LI = new LoadInst(pDecl, nullptr, CI);
          Value *pBitcastLI = CastValue(LI, pValType, CI);
          CI->replaceAllUsesWith(pBitcastLI);
          ++itInst;
          CI->eraseFromParent();
        } else if (DxilInst_TempRegStore TRS = DxilInst_TempRegStore(CI)) {
          // Replace TempRegStore with a store.
          unsigned regIdx = dyn_cast<ConstantInt>(TRS.get_index())->getZExtValue();
          RegRec &reg = FR.RegMap[regIdx];
          Value *pValue = TRS.get_value();
          Type *pValType = pValue->getType();
          Value *pDecl = (pValType == Type::getFloatTy(*m_pCtx) ||
                          pValType == Type::getInt32Ty(*m_pCtx)) ? reg.pDecl32 : reg.pDecl16;
          DXASSERT_NOMSG(pValType != nullptr);
          Type *pDeclType = cast<PointerType>(pDecl->getType())->getElementType();
          Value *pBitcastValueToStore = CastValue(pValue, pDeclType, CI);
          StoreInst *SI = new StoreInst(pBitcastValueToStore, pDecl, CI);
          CI->replaceAllUsesWith(SI);
          ++itInst;
          CI->eraseFromParent();
        } else {
          ++itInst;
        }
      }
    }
  }
}

void DxilCleanup::ConstructSSA() {
  // Construct SSA for r-register live ranges.
#if DXILCLEANUP_DBG
  DXASSERT_NOMSG(!verifyModule(*m_pModule));
#endif
  PassManager PM;
  PM.add(createPromoteMemoryToRegisterPass());
  PM.run(*m_pModule);
}

// Note: this two-pass initialization scheme limits the algorithm to handling
// 2^31 live ranges, instead of 2^32.
#define LIVE_RANGE_UNINITIALIZED (((unsigned)1 << 31))
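
// Builds m_LiveRanges and m_LiveRangeMap in two passes over every scalar
// instruction: the first pass counts live ranges, tagging each instruction's
// map entry with LIVE_RANGE_UNINITIALIZED; the second pass clears the tag
// while recording defs and use counts. Finally, bitcast edges between live
// ranges are accumulated into each range's bitcastMap.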
void DxilCleanup::CollectLiveRanges() {
  // 0. Count and allocate live ranges.
  unsigned LiveRangeCount = 0;
  for (auto itFn = m_pModule->begin(), endFn = m_pModule->end(); itFn != endFn; ++itFn) {
    Function *F = itFn;
    for (auto itBB = F->begin(), endBB = F->end(); itBB != endBB; ++itBB) {
      BasicBlock *BB = &*itBB;
      for (auto itInst = BB->begin(), endInst = BB->end(); itInst != endInst; ++itInst) {
        Instruction *I = &*itInst;
        Type *pType = I->getType();
        if (!pType->isFloatingPointTy() && !pType->isIntegerTy())
          continue;
        if (m_LiveRangeMap.find(I) != m_LiveRangeMap.end())
          continue;
        // Count live range.
        if (LiveRangeCount & LIVE_RANGE_UNINITIALIZED) {
          // Too many live ranges for our two-pass initialization scheme.
          DXASSERT(false, "otherwise, more than 2^31 live ranges!");
          return;
        }
        CountLiveRangeRec(LiveRangeCount, I);
        LiveRangeCount++;
      }
    }
  }
  m_LiveRanges.resize(LiveRangeCount);

  // 1. Recover live ranges.
  unsigned LRId = 0;
  for (auto itFn = m_pModule->begin(), endFn = m_pModule->end(); itFn != endFn; ++itFn) {
    Function *F = itFn;
    for (auto itBB = F->begin(), endBB = F->end(); itBB != endBB; ++itBB) {
      BasicBlock *BB = &*itBB;
      for (auto itInst = BB->begin(), endInst = BB->end(); itInst != endInst; ++itInst) {
        Instruction *I = &*itInst;
        Type *pType = I->getType();
        if (!pType->isFloatingPointTy() && !pType->isIntegerTy())
          continue;
        auto it = m_LiveRangeMap.find(I);
        DXASSERT(it != m_LiveRangeMap.end(),
                 "otherwise, instruction not added to m_LiveRangeMap during counting stage");
        if (!(it->second & LIVE_RANGE_UNINITIALIZED)) {
          continue;
        }
        // Recover a live range.
        LiveRange &LR = m_LiveRanges[LRId];
        LR.id = LRId++;
        RecoverLiveRangeRec(LR, I);
      }
    }
  }

  // 2. Add bitcast edges.
  for (LiveRange &LR : m_LiveRanges) {
    for (Value *def : LR.defs) {
      for (User *U : def->users()) {
        if (IsDxilBitcast(U)) {
          DXASSERT_NOMSG(m_LiveRangeMap.find(U) != m_LiveRangeMap.end());
          DXASSERT(!(m_LiveRangeMap.find(U)->second & LIVE_RANGE_UNINITIALIZED),
                   "otherwise, live range not initialized!");
          unsigned userLRId = m_LiveRangeMap[U];
          LR.bitcastMap[userLRId]++;
        }
      }
    }
  }

#if DXILCLEANUP_DBG
  // Print live ranges.
  size_t NumDefs = 0;
  dbgs() << "Live ranges:\n";
  for (LiveRange &LR : m_LiveRanges) {
    NumDefs += LR.defs.size();
    dbgs() << "id=" << LR.id << ", F=" << LR.numF
           << ", I=" << LR.numI << ", U=" << LR.numU << ", defs = {";
    for (Value *D : LR.defs) {
      dbgs() << "\n";
      D->dump();
    }
    dbgs() << "}, edges = { ";
    bool bFirst = true;
    for (auto it : LR.bitcastMap) {
      if (!bFirst) {
        dbgs() << ", ";
      }
      dbgs() << "<" << it.first << "," << it.second << ">";
      bFirst = false;
    }
    dbgs() << "}\n";
  }
  DXASSERT_NOMSG(NumDefs == m_LiveRangeMap.size());
#endif
}
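
// Marking pass: tags pInst, and everything reachable from it through phi
// users and phi operands, with LRId | LIVE_RANGE_UNINITIALIZED so that
// RecoverLiveRangeRec can later claim exactly this connected component.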
void DxilCleanup::CountLiveRangeRec(unsigned LRId, Instruction *pInst) {
  if (m_LiveRangeMap.find(pInst) != m_LiveRangeMap.end()) {
    DXASSERT_NOMSG(m_LiveRangeMap[pInst] == (LRId | LIVE_RANGE_UNINITIALIZED));
    return;
  }
  m_LiveRangeMap[pInst] = LRId | LIVE_RANGE_UNINITIALIZED;
  for (User *U : pInst->users()) {
    if (PHINode *phi = dyn_cast<PHINode>(U)) {
      CountLiveRangeRec(LRId, phi);
    }
  }
  if (PHINode *phi = dyn_cast<PHINode>(pInst)) {
    for (Use &U : phi->operands()) {
      if (Instruction *I = dyn_cast<Instruction>(U.get())) {
        CountLiveRangeRec(LRId, I);
      }
    }
  }
}

void DxilCleanup::RecoverLiveRangeRec(LiveRange &LR, Instruction *pInst) {
  auto it = m_LiveRangeMap.find(pInst);
  DXASSERT_NOMSG(it != m_LiveRangeMap.end());
  if (!(it->second & LIVE_RANGE_UNINITIALIZED)) {
    return;
  }
  it->second &= ~LIVE_RANGE_UNINITIALIZED;
  LR.defs.push_back(pInst);
  for (User *U : pInst->users()) {
    if (PHINode *phi = dyn_cast<PHINode>(U)) {
      RecoverLiveRangeRec(LR, phi);
    } else if (IsDxilBitcast(U)) {
      LR.numU++;
    } else {
      Type *pType = pInst->getType();
      if (pType->isFloatingPointTy()) {
        LR.numF++;
      } else if (pType->isIntegerTy()) {
        LR.numI++;
      } else {
        DXASSERT_NOMSG(false);
      }
    }
  }
  if (PHINode *phi = dyn_cast<PHINode>(pInst)) {
    for (Use &U : phi->operands()) {
      Instruction *I = dyn_cast<Instruction>(U.get());
      if (I) {
        RecoverLiveRangeRec(LR, I);
      } else {
        DXASSERT_NOMSG(dyn_cast<Constant>(U.get()));
      }
    }
  }
}
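
// Priority classes for type inference: case 1 ranges already have a clear
// majority type, case 2 ranges are borderline, and case 3 ranges are
// dominated by unresolved bitcasts and are decided last.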
unsigned DxilCleanup::LiveRange::GetCaseNumber() const {
  if (numI > (numF + numU) || numF > (numI + numU))
    return 1; // Type is known.
  if (numI == (numF + numU) || numF == (numI + numU))
    return 2; // Type may change, but unlikely.
  return 3; // Type is unknown yet. Postpone the decision until more live
            // ranges have types.
}
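
// Heuristically picks the live range's type: prefer the clear F/I majority,
// falling back to the type of the first def when counts tie. The chosen type
// keeps the defs' bit width (half <-> i16, float <-> i32, double <-> i64).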
void DxilCleanup::LiveRange::GuessType(LLVMContext &Ctx) {
  DXASSERT_NOMSG(pNewType == nullptr);
  bool bFlt = false;
  bool bInt = false;
  if (numU == 0) {
    bFlt = numF > numI;
    bInt = numI > numF;
  } else {
    if (numF >= numI + numU) {
      bFlt = true;
    } else if (numI >= numF + numU) {
      bInt = true;
    } else if (numF > numI) {
      bFlt = true;
    } else if (numI > numF) {
      bInt = true;
    }
  }
  Type *pDefType = (*defs.begin())->getType();
  if (!bFlt && !bInt) {
    bFlt = pDefType->isFloatingPointTy();
    bInt = pDefType->isIntegerTy();
  }
  if ((bFlt && pDefType->isFloatingPointTy()) ||
      (bInt && pDefType->isIntegerTy())) {
    pNewType = pDefType;
    return;
  }
  if (bFlt) {
    if (pDefType == Type::getInt16Ty(Ctx)) {
      pNewType = Type::getHalfTy(Ctx);
    } else if (pDefType == Type::getInt32Ty(Ctx)) {
      pNewType = Type::getFloatTy(Ctx);
    } else if (pDefType == Type::getInt64Ty(Ctx)) {
      pNewType = Type::getDoubleTy(Ctx);
    } else {
      DXASSERT_NOMSG(false);
    }
  } else if (bInt) {
    if (pDefType == Type::getHalfTy(Ctx)) {
      pNewType = Type::getInt16Ty(Ctx);
    } else if (pDefType == Type::getFloatTy(Ctx)) {
      pNewType = Type::getInt32Ty(Ctx);
    } else if (pDefType == Type::getDoubleTy(Ctx)) {
      pNewType = Type::getInt64Ty(Ctx);
    } else {
      DXASSERT_NOMSG(false);
    }
  } else {
    DXASSERT_NOMSG(false);
  }
}
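
// Orders live ranges for inference: known-type ranges first (case number),
// then ranges with more typed uses; for case 3, ranges with a higher ratio of
// typed uses to unresolved bitcasts come first. The id breaks remaining ties
// so the ordering is total and deterministic for use in a std::set.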
bool DxilCleanup::LiveRange::operator<(const LiveRange &o) const {
  unsigned case1 = GetCaseNumber();
  unsigned case2 = o.GetCaseNumber();
  if (case1 != case2)
    return case1 < case2;
  switch (case1) {
  case 1:
  case 2: {
    unsigned n1 = std::max(numI, numF);
    unsigned n2 = std::max(o.numI, o.numF);
    if (n1 != n2)
      return n2 < n1;
    break;
  }
  case 3: {
    double r1 = (double)(numI + numF) / (double)numU;
    double r2 = (double)(o.numI + o.numF) / (double)o.numU;
    if (r1 != r2)
      return r2 < r1;
    if (numU != o.numU)
      return numU < o.numU;
    break;
  }
  default:
    DXASSERT_NOMSG(false);
    break;
  }
  return id < o.id;
}

struct LiveRangeLT {
  LiveRangeLT(const vector<DxilCleanup::LiveRange> &LiveRanges) : m_LiveRanges(LiveRanges) {}

  bool operator()(const unsigned i1, const unsigned i2) const {
    const DxilCleanup::LiveRange &lr1 = m_LiveRanges[i1];
    const DxilCleanup::LiveRange &lr2 = m_LiveRanges[i2];
    return lr1 < lr2;
  }

private:
  const vector<DxilCleanup::LiveRange> &m_LiveRanges;
};
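
// Processes live ranges in the priority order above. Assigning a type to one
// range adds resolved-bitcast evidence (numF/numI) to its neighbors, so each
// affected neighbor is removed from the set, updated, and re-inserted to keep
// its position consistent with the new counts.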
void DxilCleanup::InferLiveRangeTypes() {
  set<unsigned, LiveRangeLT> LiveRangeSet{LiveRangeLT(m_LiveRanges)};
  // TODO: Evaluate as candidate for optimization.
  // Initialize queue.
  for (LiveRange &LR : m_LiveRanges) {
    LiveRangeSet.insert(LR.id);
  }
  while (!LiveRangeSet.empty()) {
    unsigned LRId = *LiveRangeSet.cbegin();
    LiveRange &LR = m_LiveRanges[LRId];
    LiveRangeSet.erase(LRId);
    // Assign type.
    LR.GuessType(*m_pCtx);
    // Propagate type assignment to neighboring live ranges.
    for (auto itp : LR.bitcastMap) {
      if (LiveRangeSet.find(itp.first) == LiveRangeSet.end())
        continue;
      unsigned neighborId = itp.first;
      unsigned numLinks = itp.second;
      LiveRangeSet.erase(neighborId);
      LiveRange &neighbor = m_LiveRanges[neighborId];
      if (LR.pNewType->isFloatingPointTy()) {
        neighbor.numF += numLinks;
      } else {
        neighbor.numI += numLinks;
      }
      LiveRangeSet.insert(neighborId);
    }
  }
}
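
// Rewrites every live range whose inferred type differs from its current one:
// each def gets a same-width replacement (a new phi, or a DXIL bitcast of the
// def), phi incoming values are remapped to the new defs, and every non-phi,
// non-bitcast use receives a bitcast back to the original type.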
void DxilCleanup::ChangeLiveRangeTypes() {
  for (LiveRange &LR : m_LiveRanges) {
    Type *pType = (*LR.defs.begin())->getType();
    if (pType == LR.pNewType)
      continue;
    // Change live range type.
    SmallDenseMap<Value *, Value *, 4> DefMap;
    // a. Create new defs.
    for (Value *D : LR.defs) {
      Instruction *pInst = dyn_cast<Instruction>(D);
      if (PHINode *phi = dyn_cast<PHINode>(pInst)) {
        PHINode *pNewPhi = PHINode::Create(LR.pNewType, phi->getNumIncomingValues(),
                                           phi->getName(), phi->getNextNode());
        DefMap[D] = pNewPhi;
      } else {
        DefMap[D] = CastValue(pInst, LR.pNewType, pInst);
      }
    }
    // b. Fix phi uses.
    for (Value *D : LR.defs) {
      if (PHINode *phi = dyn_cast<PHINode>(D)) {
        DXASSERT_NOMSG(DefMap.find(phi) != DefMap.end());
        PHINode *pNewPhi = dyn_cast<PHINode>(DefMap[phi]);
        for (unsigned i = 0; i < phi->getNumIncomingValues(); i++) {
          Value *pVal = phi->getIncomingValue(i);
          BasicBlock *BB = phi->getIncomingBlock(i);
          Value *pNewVal = nullptr;
          if (!isa<Constant>(pVal)) {
            DXASSERT_NOMSG(DefMap.find(pVal) != DefMap.end());
            pNewVal = DefMap[pVal];
          } else {
            pNewVal = CastValue(pVal, pNewPhi->getType(), BB->getTerminator());
          }
          pNewPhi->addIncoming(pNewVal, BB);
        }
      }
    }
    // c. Fix other uses.
    for (Value *D : LR.defs) {
      for (User *U : D->users()) {
        if (isa<PHINode>(U) || IsDxilBitcast(U))
          continue;
        Instruction *pNewInst = dyn_cast<Instruction>(DefMap[D]);
        Value *pRevBitcast = CastValue(pNewInst, pType, pNewInst);
        U->replaceUsesOfWith(D, pRevBitcast);
        // If the new def is a phi, we need to be careful about where we place
        // the bitcast: it must come after all the phi nodes of the block.
        if (isa<PHINode>(pNewInst) && isa<Instruction>(pRevBitcast) && pRevBitcast != pNewInst) {
          PHINode *pPhi = cast<PHINode>(pNewInst);
          Instruction *pInst = cast<Instruction>(pRevBitcast);
          pInst->removeFromParent();
          pInst->insertBefore(pPhi->getParent()->getFirstInsertionPt());
        }
      }
    }
  }
}
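
// Folds a bitcast of an inverse bitcast: if I1 is a DxilBitcast1 whose operand
// is a DxilBitcast2, uses of I1 are replaced with the original value. Returns
// true whenever I1 matched DxilBitcast1 at all, so the caller can stop
// pattern-matching on this instruction.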
template <typename DxilBitcast1, typename DxilBitcast2>
static bool CleanupBitcastPattern(Instruction *I1) {
  if (DxilBitcast1 BC1 = DxilBitcast1(I1)) {
    Instruction *I2 = dyn_cast<Instruction>(BC1.get_value());
    if (I2) {
      if (DxilBitcast2 BC2 = DxilBitcast2(I2)) {
        I1->replaceAllUsesWith(BC2.get_value());
      }
    }
    return true;
  }
  return false;
}

void DxilCleanup::CleanupPatterns() {
  for (auto itFn = m_pModule->begin(), endFn = m_pModule->end(); itFn != endFn; ++itFn) {
    Function *F = itFn;
    for (auto itBB = F->begin(), endBB = F->end(); itBB != endBB; ++itBB) {
      BasicBlock *BB = &*itBB;
      for (auto itInst = BB->begin(), endInst = BB->end(); itInst != endInst; ++itInst) {
        Instruction *I1 = &*itInst;
        // Cleanup i1 pattern:
        //   %1 = icmp eq i32 %0, 1
        //   %2 = sext i1 %1 to i32
        //   %3 = icmp ne i32 %2, 0
        //   br i1 %3, ...
        // becomes
        //   ...
        //   br i1 %1, ...
        //
        if (ICmpInst *pICmp = dyn_cast<ICmpInst>(I1)) {
          if (pICmp->getPredicate() != CmpInst::Predicate::ICMP_NE)
            continue;
          Value *O1 = pICmp->getOperand(0);
          if (O1->getType() != Type::getInt32Ty(*m_pCtx))
            continue;
          Value *O2 = pICmp->getOperand(1);
          if (dyn_cast<ConstantInt>(O1))
            std::swap(O1, O2);
          ConstantInt *C = dyn_cast<ConstantInt>(O2);
          if (!C || C->getZExtValue() != 0)
            continue;
          SExtInst *SE = dyn_cast<SExtInst>(O1);
          DXASSERT_NOMSG(!SE || SE->getType() == Type::getInt32Ty(*m_pCtx));
          if (!SE || SE->getSrcTy() != Type::getInt1Ty(*m_pCtx))
            continue;
          I1->replaceAllUsesWith(SE->getOperand(0));
          continue;
        }
        // Cleanup chains of bitcasts:
        //   %1 = call float @dx.op.bitcastI32toF32(i32 126, i32 %0)
        //   %2 = call i32 @dx.op.bitcastF32toI32(i32 127, float %1)
        //   %3 = iadd i32 %2, ...
        // becomes
        //   ...
        //   %3 = iadd i32 %0, ...
        //
        if (CleanupBitcastPattern<DxilInst_BitcastI32toF32, DxilInst_BitcastF32toI32>(I1)) continue;
        if (CleanupBitcastPattern<DxilInst_BitcastF32toI32, DxilInst_BitcastI32toF32>(I1)) continue;
        if (CleanupBitcastPattern<DxilInst_BitcastI16toF16, DxilInst_BitcastF16toI16>(I1)) continue;
        if (CleanupBitcastPattern<DxilInst_BitcastF16toI16, DxilInst_BitcastI16toF16>(I1)) continue;
        if (CleanupBitcastPattern<DxilInst_BitcastI64toF64, DxilInst_BitcastF64toI64>(I1)) continue;
        if (CleanupBitcastPattern<DxilInst_BitcastF64toI64, DxilInst_BitcastI64toF64>(I1)) continue;
        // Cleanup chains of doubles:
        //   %7 = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %6)
        //   %8 = extractvalue %dx.types.splitdouble %7, 0
        //   %9 = extractvalue %dx.types.splitdouble %7, 1
        //   ...
        //   %15 = call double @dx.op.makeDouble.f64(i32 101, i32 %8, i32 %9)
        //   %16 = call double @dx.op.binary.f64(i32 36, double %15, double 0x3FFC51EB80000000)
        // becomes (%15 -> %6)
        //   ...
        //   %16 = call double @dx.op.binary.f64(i32 36, double %6, double 0x3FFC51EB80000000)
        //
        if (DxilInst_MakeDouble MD = DxilInst_MakeDouble(I1)) {
          ExtractValueInst *V1 = dyn_cast<ExtractValueInst>(MD.get_hi());
          ExtractValueInst *V2 = dyn_cast<ExtractValueInst>(MD.get_lo());
          if (V1 && V2 && V1->getAggregateOperand() == V2->getAggregateOperand() &&
              V1->getNumIndices() == 1 && V2->getNumIndices() == 1 &&
              *V1->idx_begin() == 1 && *V2->idx_begin() == 0) {
            Instruction *pSDInst = dyn_cast<Instruction>(V1->getAggregateOperand());
            if (!pSDInst)
              continue;
            if (DxilInst_SplitDouble SD = DxilInst_SplitDouble(pSDInst)) {
              I1->replaceAllUsesWith(SD.get_value());
            }
          }
          continue;
        }
      }
    }
  }
}

void DxilCleanup::RemoveDeadCode() {
#if DXILCLEANUP_DBG
  DXASSERT_NOMSG(!verifyModule(*m_pModule));
#endif
  PassManager PM;
  PM.add(createDeadCodeEliminationPass());
  PM.run(*m_pModule);
}
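
// Emits a DXIL bitcast call that reinterprets pValue as pToType (same bit
// width only), creating the dx.op.bitcast* function declaration on first use.
// The call is inserted right after pValue when pValue is an instruction, and
// before pOrigInst otherwise.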
Value *DxilCleanup::CastValue(Value *pValue, Type *pToType, Instruction *pOrigInst) {
  Type *pType = pValue->getType();
  if (pType == pToType)
    return pValue;

  const unsigned kNumTypeArgs = 3;
  Type *ArgTypes[kNumTypeArgs];
  DXIL::OpCode OpCode;
  if (pType == Type::getFloatTy(*m_pCtx)) {
    IFTBOOL(pToType == Type::getInt32Ty(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastF32toI32;
    ArgTypes[0] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getFloatTy(*m_pCtx);
  } else if (pType == Type::getInt32Ty(*m_pCtx)) {
    IFTBOOL(pToType == Type::getFloatTy(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastI32toF32;
    ArgTypes[0] = Type::getFloatTy(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getInt32Ty(*m_pCtx);
  } else if (pType == Type::getHalfTy(*m_pCtx)) {
    IFTBOOL(pToType == Type::getInt16Ty(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastF16toI16;
    ArgTypes[0] = Type::getInt16Ty(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getHalfTy(*m_pCtx);
  } else if (pType == Type::getInt16Ty(*m_pCtx)) {
    IFTBOOL(pToType == Type::getHalfTy(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastI16toF16;
    ArgTypes[0] = Type::getHalfTy(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getInt16Ty(*m_pCtx);
  } else if (pType == Type::getDoubleTy(*m_pCtx)) {
    IFTBOOL(pToType == Type::getInt64Ty(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastF64toI64;
    ArgTypes[0] = Type::getInt64Ty(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getDoubleTy(*m_pCtx);
  } else if (pType == Type::getInt64Ty(*m_pCtx)) {
    IFTBOOL(pToType == Type::getDoubleTy(*m_pCtx), DXC_E_OPTIMIZATION_FAILED);
    OpCode = DXIL::OpCode::BitcastI64toF64;
    ArgTypes[0] = Type::getDoubleTy(*m_pCtx);
    ArgTypes[1] = Type::getInt32Ty(*m_pCtx);
    ArgTypes[2] = Type::getInt64Ty(*m_pCtx);
  } else {
    IFT(DXC_E_OPTIMIZATION_FAILED);
  }

  // Get the function. Try to find an existing function with the same name in
  // the module first.
  std::string funcName = (Twine("dx.op.") + Twine(OP::GetOpCodeClassName(OpCode))).str();
  Function *F = m_pModule->getFunction(funcName);
  if (!F) {
    FunctionType *pFT;
    pFT = FunctionType::get(ArgTypes[0], ArrayRef<Type *>(&ArgTypes[1], kNumTypeArgs - 1), false);
    F = Function::Create(pFT, GlobalValue::LinkageTypes::ExternalLinkage, funcName, m_pModule);
    F->setCallingConv(CallingConv::C);
    F->addFnAttr(Attribute::NoUnwind);
    F->addFnAttr(Attribute::ReadNone);
  }
  // Create bitcast call.
  const unsigned kNumArgs = 2;
  Value *Args[kNumArgs];
  Args[0] = Constant::getIntegerValue(IntegerType::get(*m_pCtx, 32), APInt(32, (int)OpCode));
  Args[1] = pValue;
  CallInst *pBitcast = nullptr;
  if (Instruction *pInsertAfter = dyn_cast<Instruction>(pValue)) {
    pBitcast = CallInst::Create(F, ArrayRef<Value *>(&Args[0], kNumArgs), "", pInsertAfter->getNextNode());
  } else {
    pBitcast = CallInst::Create(F, ArrayRef<Value *>(&Args[0], kNumArgs), "", pOrigInst);
  }
  return pBitcast;
}

bool DxilCleanup::IsDxilBitcast(Value *pValue) {
  if (Instruction *pInst = dyn_cast<Instruction>(pValue)) {
    if (OP::IsDxilOpFuncCallInst(pInst)) {
      OP::OpCode opcode = OP::GetDxilOpFuncCallInst(pInst);
      switch (opcode) {
      case OP::OpCode::BitcastF16toI16:
      case OP::OpCode::BitcastF32toI32:
      case OP::OpCode::BitcastF64toI64:
      case OP::OpCode::BitcastI16toF16:
      case OP::OpCode::BitcastI32toF32:
      case OP::OpCode::BitcastI64toF64:
        return true;
      default:
        break;
      }
    }
  }
  return false;
}

} // namespace DxilCleanupNS

using namespace DxilCleanupNS;

// Publicly exposed interface to the pass.
char &llvm::DxilCleanupID = DxilCleanup::ID;

INITIALIZE_PASS_BEGIN(DxilCleanup, "dxil-cleanup", "Optimize DXIL after conversion from DXBC", true, false)
INITIALIZE_PASS_END(DxilCleanup, "dxil-cleanup", "Optimize DXIL after conversion from DXBC", true, false)

namespace llvm {
ModulePass *createDxilCleanupPass() {
  return new DxilCleanup();
}
} // namespace llvm