DxilPreparePasses.cpp 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPreparePasses.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Passes to prepare DxilModule. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/DXIL/DxilModule.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "dxc/DXIL/DxilUtil.h"
  18. #include "dxc/DXIL/DxilFunctionProps.h"
  19. #include "dxc/DXIL/DxilInstructions.h"
  20. #include "dxc/DXIL/DxilConstants.h"
  21. #include "dxc/HlslIntrinsicOp.h"
  22. #include "llvm/IR/GetElementPtrTypeIterator.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/Instructions.h"
  25. #include "llvm/IR/InstIterator.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/Analysis/PostDominators.h"
  28. #include "llvm/IR/Module.h"
  29. #include "llvm/IR/DebugInfo.h"
  30. #include "llvm/IR/PassManager.h"
  31. #include "llvm/ADT/BitVector.h"
  32. #include "llvm/Pass.h"
  33. #include "llvm/Transforms/Utils/Local.h"
  34. #include "llvm/Analysis/AssumptionCache.h"
  35. #include "llvm/Analysis/DxilValueCache.h"
  36. #include "llvm/Analysis/LoopInfo.h"
  37. #include <memory>
  38. #include <unordered_set>
  39. using namespace llvm;
  40. using namespace hlsl;
  41. namespace {
  42. class InvalidateUndefResources : public ModulePass {
  43. public:
  44. static char ID;
  45. explicit InvalidateUndefResources() : ModulePass(ID) {
  46. initializeScalarizerPass(*PassRegistry::getPassRegistry());
  47. }
  48. const char *getPassName() const override { return "Invalidate undef resources"; }
  49. bool runOnModule(Module &M) override;
  50. };
  51. }
  52. char InvalidateUndefResources::ID = 0;
  53. ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
  54. INITIALIZE_PASS(InvalidateUndefResources, "invalidate-undef-resource", "Invalidate undef resources", false, false)
  55. bool InvalidateUndefResources::runOnModule(Module &M) {
  56. // Undef resources typically indicate uninitialized locals being used
  57. // in some code path, which we should catch and report. However, some
  58. // code patterns in large shaders cause dead undef resources to momentarily,
  59. // which is not an error. We must wait until cleanup passes
  60. // have run to know whether we must produce an error.
  61. // However, we can't leave the undef values in because they could eliminated,
  62. // such as by reading from resources seen in a code path that was not taken.
  63. // We avoid the problem by replacing undef values by another invalid
  64. // value that we can identify later.
  65. for (auto &F : M.functions()) {
  66. if (GetHLOpcodeGroupByName(&F) == HLOpcodeGroup::HLCreateHandle) {
  67. Type *ResTy = F.getFunctionType()->getParamType(
  68. HLOperandIndex::kCreateHandleResourceOpIdx);
  69. UndefValue *UndefRes = UndefValue::get(ResTy);
  70. if (!UndefRes->use_empty()) {
  71. Constant *InvalidRes = ConstantAggregateZero::get(ResTy);
  72. UndefRes->replaceAllUsesWith(InvalidRes);
  73. }
  74. }
  75. }
  76. return false;
  77. }
  78. ///////////////////////////////////////////////////////////////////////////////
  79. namespace {
  80. class SimplifyInst : public FunctionPass {
  81. public:
  82. static char ID;
  83. SimplifyInst() : FunctionPass(ID) {
  84. initializeScalarizerPass(*PassRegistry::getPassRegistry());
  85. }
  86. bool runOnFunction(Function &F) override;
  87. private:
  88. };
  89. }
  90. char SimplifyInst::ID = 0;
  91. FunctionPass *llvm::createSimplifyInstPass() { return new SimplifyInst(); }
  92. INITIALIZE_PASS(SimplifyInst, "simplify-inst", "Simplify Instructions", false, false)
  93. bool SimplifyInst::runOnFunction(Function &F) {
  94. for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
  95. BasicBlock *BB = BBI;
  96. llvm::SimplifyInstructionsInBlock(BB, nullptr);
  97. }
  98. return true;
  99. }
  100. ///////////////////////////////////////////////////////////////////////////////
  101. namespace {
  102. class DxilDeadFunctionElimination : public ModulePass {
  103. public:
  104. static char ID; // Pass identification, replacement for typeid
  105. explicit DxilDeadFunctionElimination () : ModulePass(ID) {}
  106. const char *getPassName() const override { return "Remove all unused function except entry from DxilModule"; }
  107. bool runOnModule(Module &M) override {
  108. if (M.HasDxilModule()) {
  109. DxilModule &DM = M.GetDxilModule();
  110. bool IsLib = DM.GetShaderModel()->IsLib();
  111. // Remove unused functions except entry and patch constant func.
  112. // For library profile, only remove unused external functions.
  113. Function *EntryFunc = DM.GetEntryFunction();
  114. Function *PatchConstantFunc = DM.GetPatchConstantFunction();
  115. return dxilutil::RemoveUnusedFunctions(M, EntryFunc, PatchConstantFunc,
  116. IsLib);
  117. }
  118. return false;
  119. }
  120. };
  121. }
  122. char DxilDeadFunctionElimination::ID = 0;
  123. ModulePass *llvm::createDxilDeadFunctionEliminationPass() {
  124. return new DxilDeadFunctionElimination();
  125. }
  126. INITIALIZE_PASS(DxilDeadFunctionElimination, "dxil-dfe", "Remove all unused function except entry from DxilModule", false, false)
  127. ///////////////////////////////////////////////////////////////////////////////
  // Forward declaration; the definition appears later in this file.
  // DxilFinalizeModule::runOnModule calls it to remove unnecessary address
  // space casts on shared-memory globals.
  bool CleanupSharedMemoryAddrSpaceCast(Module &M);
  129. namespace {
  130. static void TransferEntryFunctionAttributes(Function *F, Function *NewFunc) {
  131. // Keep necessary function attributes
  132. AttributeSet attributeSet = F->getAttributes();
  133. StringRef attrKind, attrValue;
  134. if (attributeSet.hasAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString)) {
  135. Attribute attribute = attributeSet.getAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString);
  136. DXASSERT(attribute.isStringAttribute(), "otherwise we have wrong fp-denorm-mode attribute.");
  137. attrKind = attribute.getKindAsString();
  138. attrValue = attribute.getValueAsString();
  139. }
  140. if (F == NewFunc) {
  141. NewFunc->removeAttributes(AttributeSet::FunctionIndex, attributeSet);
  142. }
  143. if (!attrKind.empty() && !attrValue.empty())
  144. NewFunc->addFnAttr(attrKind, attrValue);
  145. }
  146. // If this returns non-null, the old function F has been stripped and can be deleted.
  147. static Function *StripFunctionParameter(Function *F, DxilModule &DM,
  148. DenseMap<const Function *, DISubprogram *> &FunctionDIs) {
  149. if (F->arg_empty() && F->getReturnType()->isVoidTy()) {
  150. // This will strip non-entry function attributes
  151. TransferEntryFunctionAttributes(F, F);
  152. return nullptr;
  153. }
  154. Module &M = *DM.GetModule();
  155. Type *VoidTy = Type::getVoidTy(M.getContext());
  156. FunctionType *FT = FunctionType::get(VoidTy, false);
  157. for (auto &arg : F->args()) {
  158. if (!arg.user_empty())
  159. return nullptr;
  160. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&arg);
  161. if (DDI) {
  162. DDI->eraseFromParent();
  163. }
  164. }
  165. Function *NewFunc = Function::Create(FT, F->getLinkage());
  166. M.getFunctionList().insert(F, NewFunc);
  167. // Splice the body of the old function right into the new function.
  168. NewFunc->getBasicBlockList().splice(NewFunc->begin(), F->getBasicBlockList());
  169. TransferEntryFunctionAttributes(F, NewFunc);
  170. // Patch the pointer to LLVM function in debug info descriptor.
  171. auto DI = FunctionDIs.find(F);
  172. if (DI != FunctionDIs.end()) {
  173. DISubprogram *SP = DI->second;
  174. SP->replaceFunction(NewFunc);
  175. // Ensure the map is updated so it can be reused on subsequent argument
  176. // promotions of the same function.
  177. FunctionDIs.erase(DI);
  178. FunctionDIs[NewFunc] = SP;
  179. }
  180. NewFunc->takeName(F);
  181. if (DM.HasDxilFunctionProps(F)) {
  182. DM.ReplaceDxilEntryProps(F, NewFunc);
  183. }
  184. DM.GetTypeSystem().EraseFunctionAnnotation(F);
  185. DM.GetTypeSystem().AddFunctionAnnotation(NewFunc);
  186. return NewFunc;
  187. }
  188. void CheckInBoundForTGSM(GlobalVariable &GV, const DataLayout &DL) {
  189. for (User *U : GV.users()) {
  190. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  191. bool allImmIndex = true;
  192. for (auto Idx = GEP->idx_begin(), E = GEP->idx_end(); Idx != E; Idx++) {
  193. if (!isa<ConstantInt>(Idx)) {
  194. allImmIndex = false;
  195. break;
  196. }
  197. }
  198. if (!allImmIndex)
  199. GEP->setIsInBounds(false);
  200. else {
  201. Value *Ptr = GEP->getPointerOperand();
  202. unsigned size =
  203. DL.getTypeAllocSize(Ptr->getType()->getPointerElementType());
  204. unsigned valSize =
  205. DL.getTypeAllocSize(GEP->getType()->getPointerElementType());
  206. SmallVector<Value *, 8> Indices(GEP->idx_begin(), GEP->idx_end());
  207. unsigned offset =
  208. DL.getIndexedOffset(GEP->getPointerOperandType(), Indices);
  209. if ((offset + valSize) > size)
  210. GEP->setIsInBounds(false);
  211. }
  212. }
  213. }
  214. }
  215. static bool GetUnsignedVal(Value *V, uint32_t *pValue) {
  216. ConstantInt *CI = dyn_cast<ConstantInt>(V);
  217. if (!CI) return false;
  218. uint64_t u = CI->getZExtValue();
  219. if (u > UINT32_MAX) return false;
  220. *pValue = (uint32_t)u;
  221. return true;
  222. }
  223. static void MarkUsedSignatureElements(Function *F, DxilModule &DM) {
  224. DXASSERT_NOMSG(F != nullptr);
  225. // For every loadInput/storeOutput, update the corresponding ReadWriteMask.
  226. // F is a pointer to a Function instance
  227. for (llvm::inst_iterator I = llvm::inst_begin(F), E = llvm::inst_end(F); I != E; ++I) {
  228. DxilInst_LoadInput LI(&*I);
  229. DxilInst_StoreOutput SO(&*I);
  230. DxilInst_LoadPatchConstant LPC(&*I);
  231. DxilInst_StorePatchConstant SPC(&*I);
  232. DxilInst_StoreVertexOutput SVO(&*I);
  233. DxilInst_StorePrimitiveOutput SPO(&*I);
  234. DxilSignature *pSig;
  235. uint32_t col, row, sigId;
  236. bool bDynIdx = false;
  237. if (LI) {
  238. if (!GetUnsignedVal(LI.get_inputSigId(), &sigId)) continue;
  239. if (!GetUnsignedVal(LI.get_colIndex(), &col)) continue;
  240. if (!GetUnsignedVal(LI.get_rowIndex(), &row)) bDynIdx = true;
  241. pSig = &DM.GetInputSignature();
  242. }
  243. else if (SO) {
  244. if (!GetUnsignedVal(SO.get_outputSigId(), &sigId)) continue;
  245. if (!GetUnsignedVal(SO.get_colIndex(), &col)) continue;
  246. if (!GetUnsignedVal(SO.get_rowIndex(), &row)) bDynIdx = true;
  247. pSig = &DM.GetOutputSignature();
  248. }
  249. else if (SPC) {
  250. if (!GetUnsignedVal(SPC.get_outputSigID(), &sigId)) continue;
  251. if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
  252. if (!GetUnsignedVal(SPC.get_row(), &row)) bDynIdx = true;
  253. pSig = &DM.GetPatchConstOrPrimSignature();
  254. }
  255. else if (LPC) {
  256. if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
  257. if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
  258. if (!GetUnsignedVal(LPC.get_row(), &row)) bDynIdx = true;
  259. pSig = &DM.GetPatchConstOrPrimSignature();
  260. }
  261. else if (SVO) {
  262. if (!GetUnsignedVal(SVO.get_outputSigId(), &sigId)) continue;
  263. if (!GetUnsignedVal(SVO.get_colIndex(), &col)) continue;
  264. if (!GetUnsignedVal(SVO.get_rowIndex(), &row)) bDynIdx = true;
  265. pSig = &DM.GetOutputSignature();
  266. }
  267. else if (SPO) {
  268. if (!GetUnsignedVal(SPO.get_outputSigId(), &sigId)) continue;
  269. if (!GetUnsignedVal(SPO.get_colIndex(), &col)) continue;
  270. if (!GetUnsignedVal(SPO.get_rowIndex(), &row)) bDynIdx = true;
  271. pSig = &DM.GetPatchConstOrPrimSignature();
  272. }
  273. else {
  274. continue;
  275. }
  276. // Consider being more fine-grained about masks.
  277. // We report sometimes-read on input as always-read.
  278. auto &El = pSig->GetElement(sigId);
  279. unsigned UsageMask = El.GetUsageMask();
  280. unsigned colBit = 1 << col;
  281. if (!(colBit & UsageMask)) {
  282. El.SetUsageMask(UsageMask | colBit);
  283. }
  284. if (bDynIdx && (El.GetDynIdxCompMask() & colBit) == 0) {
  285. El.SetDynIdxCompMask(El.GetDynIdxCompMask() | colBit);
  286. }
  287. }
  288. }
  289. class DxilFinalizeModule : public ModulePass {
  290. public:
  291. static char ID; // Pass identification, replacement for typeid
  292. explicit DxilFinalizeModule() : ModulePass(ID) {}
  293. const char *getPassName() const override { return "HLSL DXIL Finalize Module"; }
  294. void patchValidation_1_1(Module &M) {
  295. for (iplist<Function>::iterator F : M.getFunctionList()) {
  296. for (Function::iterator BBI = F->begin(), BBE = F->end(); BBI != BBE;
  297. ++BBI) {
  298. BasicBlock *BB = BBI;
  299. for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;
  300. ++II) {
  301. Instruction *I = II;
  302. if (I->hasMetadataOtherThanDebugLoc()) {
  303. SmallVector<std::pair<unsigned, MDNode*>, 2> MDs;
  304. I->getAllMetadataOtherThanDebugLoc(MDs);
  305. for (auto &MD : MDs) {
  306. unsigned kind = MD.first;
  307. // Remove Metadata which validation_1_0 not allowed.
  308. bool bNeedPatch = kind == LLVMContext::MD_tbaa ||
  309. kind == LLVMContext::MD_prof ||
  310. (kind > LLVMContext::MD_fpmath &&
  311. kind <= LLVMContext::MD_dereferenceable_or_null);
  312. if (bNeedPatch)
  313. I->setMetadata(kind, nullptr);
  314. }
  315. }
  316. }
  317. }
  318. }
  319. }
  320. void patchDxil_1_6(Module &M, hlsl::OP *hlslOP) {
  321. for (auto it : hlslOP->GetOpFuncList(DXIL::OpCode::AnnotateHandle)) {
  322. Function *F = it.second;
  323. if (!F)
  324. continue;
  325. for (auto uit = F->user_begin(); uit != F->user_end();) {
  326. CallInst *CI = cast<CallInst>(*(uit++));
  327. DxilInst_AnnotateHandle annoteHdl(CI);
  328. Value *hdl = annoteHdl.get_res();
  329. CI->replaceAllUsesWith(hdl);
  330. CI->eraseFromParent();
  331. }
  332. }
  333. }
  334. bool runOnModule(Module &M) override {
  335. if (M.HasDxilModule()) {
  336. DxilModule &DM = M.GetDxilModule();
  337. unsigned ValMajor = 0;
  338. unsigned ValMinor = 0;
  339. M.GetDxilModule().GetValidatorVersion(ValMajor, ValMinor);
  340. unsigned DxilMajor = 0;
  341. unsigned DxilMinor = 0;
  342. M.GetDxilModule().GetDxilVersion(DxilMajor, DxilMinor);
  343. bool IsLib = DM.GetShaderModel()->IsLib();
  344. // Skip validation patch for lib.
  345. if (!IsLib) {
  346. if (ValMajor == 1 && ValMinor <= 1) {
  347. patchValidation_1_1(M);
  348. }
  349. // Set used masks for signature elements
  350. MarkUsedSignatureElements(DM.GetEntryFunction(), DM);
  351. if (DM.GetShaderModel()->IsHS())
  352. MarkUsedSignatureElements(DM.GetPatchConstantFunction(), DM);
  353. }
  354. // Remove store undef output.
  355. hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
  356. if (DxilMinor < 6) {
  357. patchDxil_1_6(M, hlslOP);
  358. }
  359. RemoveStoreUndefOutput(M, hlslOP);
  360. // Turn dx.break() conditional into global
  361. LowerDxBreak(M);
  362. RemoveUnusedStaticGlobal(M);
  363. // Remove unnecessary address space casts.
  364. CleanupSharedMemoryAddrSpaceCast(M);
  365. // Clear inbound for GEP which has none-const index.
  366. LegalizeSharedMemoryGEPInbound(M);
  367. // Strip parameters of entry function.
  368. StripEntryParameters(M, DM, IsLib);
  369. // Update flags to reflect any changes.
  370. DM.CollectShaderFlagsForModule();
  371. // Update Validator Version
  372. DM.UpgradeToMinValidatorVersion();
  373. // Clear intermediate options that shouldn't be in the final DXIL
  374. DM.ClearIntermediateOptions();
  375. // Remove unused AllocateRayQuery calls
  376. RemoveUnusedRayQuery(M);
  377. if (IsLib && DXIL::CompareVersions(ValMajor, ValMinor, 1, 4) <= 0) {
  378. // 1.4 validator requires function annotations for all functions
  379. AddFunctionAnnotationForInitializers(M, DM);
  380. }
  381. return true;
  382. }
  383. return false;
  384. }
  385. private:
  386. void RemoveUnusedStaticGlobal(Module &M) {
  387. // Remove unused internal global.
  388. std::vector<GlobalVariable *> staticGVs;
  389. for (GlobalVariable &GV : M.globals()) {
  390. if (dxilutil::IsStaticGlobal(&GV) ||
  391. dxilutil::IsSharedMemoryGlobal(&GV)) {
  392. staticGVs.emplace_back(&GV);
  393. }
  394. }
  395. for (GlobalVariable *GV : staticGVs) {
  396. bool onlyStoreUse = true;
  397. for (User *user : GV->users()) {
  398. if (isa<StoreInst>(user))
  399. continue;
  400. if (isa<ConstantExpr>(user) && user->user_empty())
  401. continue;
  402. onlyStoreUse = false;
  403. break;
  404. }
  405. if (onlyStoreUse) {
  406. for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
  407. Value *User = *(UserIt++);
  408. if (Instruction *I = dyn_cast<Instruction>(User)) {
  409. I->eraseFromParent();
  410. } else {
  411. ConstantExpr *CE = cast<ConstantExpr>(User);
  412. CE->dropAllReferences();
  413. }
  414. }
  415. GV->eraseFromParent();
  416. }
  417. }
  418. }
  419. void RemoveStoreUndefOutput(Module &M, hlsl::OP *hlslOP) {
  420. for (iplist<Function>::iterator F : M.getFunctionList()) {
  421. if (!hlslOP->IsDxilOpFunc(F))
  422. continue;
  423. DXIL::OpCodeClass opClass;
  424. bool bHasOpClass = hlslOP->GetOpCodeClass(F, opClass);
  425. DXASSERT_LOCALVAR(bHasOpClass, bHasOpClass, "else not a dxil op func");
  426. if (opClass != DXIL::OpCodeClass::StoreOutput)
  427. continue;
  428. for (auto it = F->user_begin(); it != F->user_end();) {
  429. CallInst *CI = dyn_cast<CallInst>(*(it++));
  430. if (!CI)
  431. continue;
  432. Value *V = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
  433. // Remove the store of undef.
  434. if (isa<UndefValue>(V))
  435. CI->eraseFromParent();
  436. }
  437. }
  438. }
  439. void LegalizeSharedMemoryGEPInbound(Module &M) {
  440. const DataLayout &DL = M.getDataLayout();
  441. // Clear inbound for GEP which has none-const index.
  442. for (GlobalVariable &GV : M.globals()) {
  443. if (dxilutil::IsSharedMemoryGlobal(&GV)) {
  444. CheckInBoundForTGSM(GV, DL);
  445. }
  446. }
  447. }
  448. void StripEntryParameters(Module &M, DxilModule &DM, bool IsLib) {
  449. DenseMap<const Function *, DISubprogram *> FunctionDIs =
  450. makeSubprogramMap(M);
  451. // Strip parameters of entry function.
  452. if (!IsLib) {
  453. if (Function *OldPatchConstantFunc = DM.GetPatchConstantFunction()) {
  454. Function *NewPatchConstantFunc =
  455. StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
  456. if (NewPatchConstantFunc) {
  457. DM.SetPatchConstantFunction(NewPatchConstantFunc);
  458. // Erase once the DxilModule doesn't track the old function anymore
  459. DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
  460. "Error while migrating to parameter-stripped patch constant function.");
  461. OldPatchConstantFunc->eraseFromParent();
  462. }
  463. }
  464. if (Function *OldEntryFunc = DM.GetEntryFunction()) {
  465. StringRef Name = DM.GetEntryFunctionName();
  466. OldEntryFunc->setName(Name);
  467. Function *NewEntryFunc = StripFunctionParameter(OldEntryFunc, DM, FunctionDIs);
  468. if (NewEntryFunc) {
  469. DM.SetEntryFunction(NewEntryFunc);
  470. OldEntryFunc->eraseFromParent();
  471. }
  472. }
  473. } else {
  474. std::vector<Function *> entries;
  475. // Handle when multiple hull shaders point to the same patch constant function
  476. MapVector<Function*, llvm::SmallVector<Function*, 2>> PatchConstantFuncUsers;
  477. for (iplist<Function>::iterator F : M.getFunctionList()) {
  478. if (DM.IsEntryThatUsesSignatures(F)) {
  479. auto *FT = F->getFunctionType();
  480. // Only do this when has parameters.
  481. if (FT->getNumParams() > 0 || !FT->getReturnType()->isVoidTy()) {
  482. entries.emplace_back(F);
  483. }
  484. DxilFunctionProps& props = DM.GetDxilFunctionProps(F);
  485. if (props.IsHS() && props.ShaderProps.HS.patchConstantFunc) {
  486. FunctionType* PatchConstantFuncTy = props.ShaderProps.HS.patchConstantFunc->getFunctionType();
  487. if (PatchConstantFuncTy->getNumParams() > 0 || !PatchConstantFuncTy->getReturnType()->isVoidTy()) {
  488. // Accumulate all hull shaders using a given patch constant function,
  489. // so we can update it once and fix all hull shaders, without having an intermediary
  490. // state where some hull shaders point to a destroyed patch constant function.
  491. PatchConstantFuncUsers[props.ShaderProps.HS.patchConstantFunc].emplace_back(F);
  492. }
  493. }
  494. }
  495. }
  496. // Strip patch constant functions first
  497. for (auto &PatchConstantFuncEntry : PatchConstantFuncUsers) {
  498. Function* OldPatchConstantFunc = PatchConstantFuncEntry.first;
  499. Function* NewPatchConstantFunc = StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
  500. if (NewPatchConstantFunc) {
  501. // Update all user hull shaders
  502. for (Function *HullShaderFunc : PatchConstantFuncEntry.second)
  503. DM.SetPatchConstantFunctionForHS(HullShaderFunc, NewPatchConstantFunc);
  504. // Erase once the DxilModule doesn't track the old function anymore
  505. DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
  506. "Error while migrating to parameter-stripped patch constant function.");
  507. OldPatchConstantFunc->eraseFromParent();
  508. }
  509. }
  510. for (Function *OldEntry : entries) {
  511. Function *NewEntry = StripFunctionParameter(OldEntry, DM, FunctionDIs);
  512. if (NewEntry) OldEntry->eraseFromParent();
  513. }
  514. }
  515. }
  516. void AddFunctionAnnotationForInitializers(Module &M, DxilModule &DM) {
  517. if (GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors")) {
  518. if (isa<ConstantAggregateZero>(GV->getInitializer())) {
  519. DXASSERT_NOMSG(GV->user_empty());
  520. GV->eraseFromParent();
  521. return;
  522. }
  523. ConstantArray *init = cast<ConstantArray>(GV->getInitializer());
  524. for (auto V : init->operand_values()) {
  525. if (isa<ConstantAggregateZero>(V))
  526. continue;
  527. ConstantStruct *CS = cast<ConstantStruct>(V);
  528. if (isa<ConstantPointerNull>(CS->getOperand(1)))
  529. continue;
  530. Function *F = cast<Function>(CS->getOperand(1));
  531. if (DM.GetTypeSystem().GetFunctionAnnotation(F) == nullptr)
  532. DM.GetTypeSystem().AddFunctionAnnotation(F);
  533. }
  534. }
  535. }
  536. void RemoveUnusedRayQuery(Module &M) {
  537. hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
  538. llvm::Function *AllocFn = hlslOP->GetOpFunc(
  539. DXIL::OpCode::AllocateRayQuery, Type::getVoidTy(M.getContext()));
  540. SmallVector<CallInst*, 4> DeadInsts;
  541. for (auto U : AllocFn->users()) {
  542. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  543. if (CI->user_empty()) {
  544. DeadInsts.emplace_back(CI);
  545. }
  546. }
  547. }
  548. for (auto CI : DeadInsts) {
  549. CI->eraseFromParent();
  550. }
  551. if (AllocFn->user_empty()) {
  552. AllocFn->eraseFromParent();
  553. }
  554. }
  555. // Convert all uses of dx.break() into per-function load/cmp of dx.break.cond global constant
  556. void LowerDxBreak(Module &M) {
  557. if (Function *BreakFunc = M.getFunction(DXIL::kDxBreakFuncName)) {
  558. if (!BreakFunc->use_empty()) {
  559. llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
  560. Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
  561. unsigned int Values[1] = { 0 };
  562. Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
  563. Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
  564. GlobalValue::InternalLinkage,
  565. InitialValue, DXIL::kDxBreakCondName);
  566. Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
  567. Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
  568. SmallDenseMap<llvm::Function*, llvm::ICmpInst*, 16> DxBreakCmpMap;
  569. // Replace all uses of dx.break with references to the constant global
  570. for (auto I = BreakFunc->user_begin(), E = BreakFunc->user_end(); I != E;) {
  571. User *U = *I++;
  572. CallInst *CI = cast<CallInst>(U);
  573. Function *F = CI->getParent()->getParent();
  574. ICmpInst *Cmp = DxBreakCmpMap.lookup(F);
  575. if (!Cmp) {
  576. Instruction *IP = dxilutil::FirstNonAllocaInsertionPt(F);
  577. LoadInst *LI = new LoadInst(Gep, nullptr, false, IP);
  578. Cmp = new ICmpInst(IP, ICmpInst::ICMP_EQ, LI, llvm::ConstantInt::get(i32Ty,0));
  579. DxBreakCmpMap[F] = Cmp;
  580. }
  581. CI->replaceAllUsesWith(Cmp);
  582. CI->eraseFromParent();
  583. }
  584. }
  585. BreakFunc->eraseFromParent();
  586. }
  587. }
  588. };
  589. }
  590. char DxilFinalizeModule::ID = 0;
  591. ModulePass *llvm::createDxilFinalizeModulePass() {
  592. return new DxilFinalizeModule();
  593. }
  594. INITIALIZE_PASS(DxilFinalizeModule, "hlsl-dxilfinalize", "HLSL DXIL Finalize Module", false, false)
  595. ///////////////////////////////////////////////////////////////////////////////
  596. namespace {
  597. typedef MapVector< PHINode*, SmallVector<Value*,8> > PHIReplacementMap;
  /// Recursively re-roots the users of \p Val onto \p NewVal.
  ///
  /// addrspacecast users are folded away: their users are redirected to
  /// NewVal, and a cast instruction is erased once it has no uses left.
  /// GEP and bitcast users are rebuilt on top of NewVal when Val != NewVal.
  /// PHI users are not rewritten here; the replacement incoming values are
  /// recorded in \p phiReplacements for the caller to apply. Call users are
  /// left untouched. Any other instruction user has its operand set to NewVal.
  ///
  /// \param Val             value whose uses are being rewritten.
  /// \param NewVal          replacement for Val. When Val == NewVal, no new
  ///                        GEPs or bitcasts are built; the walk only folds
  ///                        addrspacecasts and records PHI incoming values.
  /// \param phiReplacements receives the per-PHI incoming-value replacements.
  /// \param valueMap        memoizes the replacement already built for a user.
  /// \returns true if any use was rewritten, any instruction erased, or any
  ///          PHI replacement recorded.
  bool RemoveAddrSpaceCasts(Value *Val, Value *NewVal,
                            PHIReplacementMap &phiReplacements,
                            DenseMap<Value*, Value*> &valueMap) {
    bool bChanged = false;
    // The iterator is advanced before the current use is handled, so the
    // current use can be rewritten without invalidating the loop.
    for (auto itU = Val->use_begin(), itEnd = Val->use_end(); itU != itEnd; ) {
      Use &use = *(itU++);
      User *user = use.getUser();
      // What `user` becomes once re-rooted on NewVal. Defaults to the user
      // itself; a replacement already recorded in valueMap takes precedence.
      Value *userReplacement = user;
      bool bConstructReplacement = false;
      bool bCleanupInst = false;
      auto valueMapIter = valueMap.find(user);
      if (valueMapIter != valueMap.end())
        userReplacement = valueMapIter->second;
      else if (Val != NewVal)
        // No replacement recorded yet and the base pointer changed: build one.
        bConstructReplacement = true;
      if (ConstantExpr* CE = dyn_cast<ConstantExpr>(user)) {
        if (CE->getOpcode() == Instruction::BitCast) {
          if (bConstructReplacement) {
            // Replicate bitcast in target address space
            Type* NewTy = PointerType::get(
              CE->getType()->getPointerElementType(),
              NewVal->getType()->getPointerAddressSpace());
            userReplacement = ConstantExpr::getBitCast(cast<Constant>(NewVal), NewTy);
          }
        } else if (CE->getOpcode() == Instruction::GetElementPtr) {
          if (bConstructReplacement) {
            // Replicate GEP in target address space
            GEPOperator *GEP = cast<GEPOperator>(CE);
            SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
            userReplacement = ConstantExpr::getGetElementPtr(
              nullptr, cast<Constant>(NewVal), idxList, GEP->isInBounds());
          }
        } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
          // Fold the constant cast: its users are redirected to NewVal. The
          // constant expression itself is not destroyed here.
          userReplacement = NewVal;
          bConstructReplacement = false;
        } else {
          // In builds where DXASSERT is inactive, execution falls through and
          // recurses on `user` with the replacement chosen above.
          DXASSERT(false, "RemoveAddrSpaceCasts: unhandled pointer ConstantExpr");
        }
      } else if (Instruction *I = dyn_cast<Instruction>(user)) {
        if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
          if (bConstructReplacement) {
            // Rebuild the GEP (keeping name and inbounds) on top of NewVal.
            IRBuilder<> Builder(GEP);
            SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
            if (GEP->isInBounds())
              userReplacement = Builder.CreateInBoundsGEP(NewVal, idxList, GEP->getName());
            else
              userReplacement = Builder.CreateGEP(NewVal, idxList, GEP->getName());
          }
        } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
          if (bConstructReplacement) {
            // Rebuild the bitcast in NewVal's address space.
            IRBuilder<> Builder(BC);
            Type* NewTy = PointerType::get(
              BC->getType()->getPointerElementType(),
              NewVal->getType()->getPointerAddressSpace());
            userReplacement = Builder.CreateBitCast(NewVal, NewTy);
          }
        } else if (PHINode *PHI = dyn_cast<PHINode>(user)) {
          // set replacement phi values for PHI pass
          // Only record the replacement for incoming slots equal to Val that
          // have none yet; the PHI is rewritten later by the caller.
          unsigned numValues = PHI->getNumIncomingValues();
          auto &phiValues = phiReplacements[PHI];
          if (phiValues.empty())
            phiValues.resize(numValues, nullptr);
          for (unsigned idx = 0; idx < numValues; ++idx) {
            if (phiValues[idx] == nullptr &&
                PHI->getIncomingValue(idx) == Val) {
              phiValues[idx] = NewVal;
              bChanged = true;
            }
          }
          continue;
        } else if (isa<AddrSpaceCastInst>(user)) {
          // Fold the cast: its users are redirected to NewVal, and the cast
          // instruction is erased below once it has no uses left.
          userReplacement = NewVal;
          bConstructReplacement = false;
          bCleanupInst = true;
        } else if (isa<CallInst>(user)) {
          // Calls are left untouched.
          continue;
        } else {
          // Any other instruction user: point this operand straight at NewVal.
          if (Val != NewVal) {
            use.set(NewVal);
            bChanged = true;
          }
          continue;
        }
      }
      // Memoize a newly built replacement, then continue with the users of
      // `user`, re-rooted on its replacement.
      if (bConstructReplacement && user != userReplacement)
        valueMap[user] = userReplacement;
      bChanged |= RemoveAddrSpaceCasts(user, userReplacement, phiReplacements,
                                        valueMap);
      if (bCleanupInst && user->use_empty()) {
        // Clean up old instruction if it's now unused.
        // Safe during this use iteration when only one use of V in instruction.
        if (Instruction *I = dyn_cast<Instruction>(user))
          I->eraseFromParent();
        bChanged = true;
      }
    }
    return bChanged;
  }
  696. }
/// Removes address space casts applied to shared-memory globals (those for
/// which dxilutil::IsSharedMemoryGlobal returns true).
///
/// Runs in three phases:
///  1. For every shared-memory global, RemoveAddrSpaceCasts rewrites its
///     transitive users. PHI users are not rewritten directly; instead the
///     replacement for each matching incoming value is recorded in
///     phiReplacements.
///  2. PHIs are resolved iteratively. Once every incoming value of a PHI has a
///     recorded replacement and all replacements share one address space, a
///     new PHI in that address space is created and the old PHI's users are
///     rewritten in turn. This can record more PHI replacements, so the scan
///     repeats until a full pass makes no progress.
///  3. Original and replacement instructions recorded in valueMap are deleted
///     if they ended up dead.
///
/// \returns true if RemoveAddrSpaceCasts reported a change or a PHI became
/// unused; the dead-instruction cleanup of phase 3 does not affect the result.
bool CleanupSharedMemoryAddrSpaceCast(Module &M) {
  bool bChanged = false;
  // Eliminate address space casts if possible
  // Collect phi nodes so we can replace iteratively after pass over GVs
  // phiReplacements: PHI -> one replacement per incoming value (nullptr until
  // known). valueMap: original value -> replacement value created for it.
  PHIReplacementMap phiReplacements;
  DenseMap<Value*, Value*> valueMap;
  // Phase 1: rewrite the users of every shared-memory global.
  for (GlobalVariable &GV : M.globals()) {
    if (dxilutil::IsSharedMemoryGlobal(&GV)) {
      bChanged |= RemoveAddrSpaceCasts(&GV, &GV, phiReplacements,
                                       valueMap);
    }
  }
  // Phase 2: replace PHIs whose incoming replacements are all known. Every
  // modification restarts the scan (bConverged = false; break); the loop ends
  // once a full pass makes no progress or no PHIs remain.
  bool bConverged = false;
  while (!phiReplacements.empty() && !bConverged) {
    bConverged = true;
    for (auto &phiReplacement : phiReplacements) {
      PHINode *PHI = phiReplacement.first;
      unsigned origAddrSpace = PHI->getType()->getPointerAddressSpace();
      // UINT_MAX means no incoming replacement has been examined yet.
      unsigned incomingAddrSpace = UINT_MAX;
      bool bReplacePHI = true;
      bool bRemovePHI = false;
      for (auto V : phiReplacement.second) {
        if (nullptr == V) {
          // cannot replace phi (yet)
          bReplacePHI = false;
          break;
        }
        unsigned addrSpace = V->getType()->getPointerAddressSpace();
        if (incomingAddrSpace == UINT_MAX) {
          incomingAddrSpace = addrSpace;
        } else if (addrSpace != incomingAddrSpace) {
          // Replacements disagree on address space, so a single new PHI
          // cannot hold them all.
          bRemovePHI = true;
          break;
        }
      }
      // The (first-seen) replacement address space equals the PHI's own
      // address space, so there is no cast to remove.
      if (origAddrSpace == incomingAddrSpace)
        bRemovePHI = true;
      if (bRemovePHI) {
        // Cannot replace phi. Remove it and restart.
        // Erasing inside the range-for is followed immediately by break.
        phiReplacements.erase(PHI);
        bConverged = false;
        break;
      }
      if (!bReplacePHI)
        continue;
      // Reuse a replacement PHI created in an earlier round if there is one;
      // otherwise create it (inserted before PHI) in the incoming address
      // space, with the recorded replacements as incoming values.
      auto &NewVal = valueMap[PHI];
      PHINode *NewPHI = nullptr;
      if (NewVal) {
        NewPHI = cast<PHINode>(NewVal);
      } else {
        IRBuilder<> Builder(PHI);
        NewPHI = Builder.CreatePHI(
            PointerType::get(PHI->getType()->getPointerElementType(),
                             incomingAddrSpace),
            PHI->getNumIncomingValues(),
            PHI->getName());
        NewVal = NewPHI;
        for (unsigned idx = 0; idx < PHI->getNumIncomingValues(); idx++) {
          NewPHI->addIncoming(phiReplacement.second[idx],
                              PHI->getIncomingBlock(idx));
        }
      }
      // Propagate the new PHI to the old PHI's users. Doing so may add entries
      // to phiReplacements, so restart the scan rather than keep iterating.
      if (RemoveAddrSpaceCasts(PHI, NewPHI, phiReplacements,
                               valueMap)) {
        bConverged = false;
        bChanged = true;
        break;
      }
      // The old PHI has been fully replaced; drop it from the worklist.
      if (PHI->use_empty()) {
        phiReplacements.erase(PHI);
        bConverged = false;
        bChanged = true;
        break;
      }
    }
  }
  // Cleanup unused replacement instructions
  // WeakVH handles become null when an instruction is deleted by an earlier
  // cleanup step, which is why null entries are skipped below.
  SmallVector<WeakVH, 8> cleanupInsts;
  for (auto it : valueMap) {
    if (isa<Instruction>(it.first))
      cleanupInsts.push_back(it.first);
    if (isa<Instruction>(it.second))
      cleanupInsts.push_back(it.second);
  }
  for (auto V : cleanupInsts) {
    if (!V)
      continue;
    if (PHINode *PHI = dyn_cast<PHINode>(V))
      RecursivelyDeleteDeadPHINode(PHI);
    else if (Instruction *I = dyn_cast<Instruction>(V))
      RecursivelyDeleteTriviallyDeadInstructions(I);
  }
  return bChanged;
}
  791. class DxilCleanupAddrSpaceCast : public ModulePass {
  792. public:
  793. static char ID; // Pass identification, replacement for typeid
  794. explicit DxilCleanupAddrSpaceCast() : ModulePass(ID) {}
  795. const char *getPassName() const override { return "HLSL DXIL Cleanup Address Space Cast"; }
  796. bool runOnModule(Module &M) override {
  797. return CleanupSharedMemoryAddrSpaceCast(M);
  798. }
  799. };
  800. char DxilCleanupAddrSpaceCast::ID = 0;
  801. ModulePass *llvm::createDxilCleanupAddrSpaceCastPass() {
  802. return new DxilCleanupAddrSpaceCast();
  803. }
  804. INITIALIZE_PASS(DxilCleanupAddrSpaceCast, "hlsl-dxil-cleanup-addrspacecast", "HLSL DXIL Cleanup Address Space Cast", false, false)
  805. ///////////////////////////////////////////////////////////////////////////////
  806. namespace {
  807. class DxilEmitMetadata : public ModulePass {
  808. public:
  809. static char ID; // Pass identification, replacement for typeid
  810. explicit DxilEmitMetadata() : ModulePass(ID) {}
  811. const char *getPassName() const override { return "HLSL DXIL Metadata Emit"; }
  812. bool runOnModule(Module &M) override {
  813. if (M.HasDxilModule()) {
  814. DxilModule::ClearDxilMetadata(M);
  815. patchIsFrontfaceTy(M);
  816. M.GetDxilModule().EmitDxilMetadata();
  817. return true;
  818. }
  819. return false;
  820. }
  821. private:
  822. void patchIsFrontfaceTy(Module &M);
  823. };
  824. void patchIsFrontface(DxilSignatureElement &Elt, bool bForceUint) {
  825. // If force to uint, change i1 to u32.
  826. // If not force to uint, change u32 to i1.
  827. if (bForceUint && Elt.GetCompType() == CompType::Kind::I1)
  828. Elt.SetCompType(CompType::Kind::U32);
  829. else if (!bForceUint && Elt.GetCompType() == CompType::Kind::U32)
  830. Elt.SetCompType(CompType::Kind::I1);
  831. }
  832. void patchIsFrontface(DxilSignature &sig, bool bForceUint) {
  833. for (auto &Elt : sig.GetElements()) {
  834. if (Elt->GetSemantic()->GetKind() == Semantic::Kind::IsFrontFace) {
  835. patchIsFrontface(*Elt, bForceUint);
  836. }
  837. }
  838. }
  839. void DxilEmitMetadata::patchIsFrontfaceTy(Module &M) {
  840. DxilModule &DM = M.GetDxilModule();
  841. const ShaderModel *pSM = DM.GetShaderModel();
  842. if (!pSM->IsGS() && !pSM->IsPS())
  843. return;
  844. unsigned ValMajor, ValMinor;
  845. DM.GetValidatorVersion(ValMajor, ValMinor);
  846. bool bForceUint = ValMajor == 0 || (ValMajor >= 1 && ValMinor >= 2);
  847. if (pSM->IsPS()) {
  848. patchIsFrontface(DM.GetInputSignature(), bForceUint);
  849. } else if (pSM->IsGS()) {
  850. patchIsFrontface(DM.GetOutputSignature(), bForceUint);
  851. }
  852. }
  853. }
  854. char DxilEmitMetadata::ID = 0;
  855. ModulePass *llvm::createDxilEmitMetadataPass() {
  856. return new DxilEmitMetadata();
  857. }
  858. INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", false, false)
  859. ///////////////////////////////////////////////////////////////////////////////
  860. namespace {
  861. const StringRef UniNoWaveSensitiveGradientErrMsg =
  862. "Gradient operations are not affected by wave-sensitive data or control "
  863. "flow.";
  864. class DxilValidateWaveSensitivity : public ModulePass {
  865. public:
  866. static char ID; // Pass identification, replacement for typeid
  867. explicit DxilValidateWaveSensitivity() : ModulePass(ID) {}
  868. const char *getPassName() const override {
  869. return "HLSL DXIL wave sensitiveity validation";
  870. }
  871. bool runOnModule(Module &M) override {
  872. // Only check ps and lib profile.
  873. DxilModule &DM = M.GetDxilModule();
  874. const ShaderModel *pSM = DM.GetShaderModel();
  875. if (!pSM->IsPS() && !pSM->IsLib())
  876. return false;
  877. SmallVector<CallInst *, 16> gradientOps;
  878. SmallVector<CallInst *, 16> barriers;
  879. SmallVector<CallInst *, 16> waveOps;
  880. for (auto &F : M) {
  881. if (!F.isDeclaration())
  882. continue;
  883. for (User *U : F.users()) {
  884. CallInst *CI = dyn_cast<CallInst>(U);
  885. if (!CI)
  886. continue;
  887. Function *FCalled = CI->getCalledFunction();
  888. if (!FCalled || !FCalled->isDeclaration())
  889. continue;
  890. if (!hlsl::OP::IsDxilOpFunc(FCalled))
  891. continue;
  892. DXIL::OpCode dxilOpcode = hlsl::OP::GetDxilOpFuncCallInst(CI);
  893. if (OP::IsDxilOpWave(dxilOpcode)) {
  894. waveOps.emplace_back(CI);
  895. }
  896. if (OP::IsDxilOpGradient(dxilOpcode)) {
  897. gradientOps.push_back(CI);
  898. }
  899. if (dxilOpcode == DXIL::OpCode::Barrier) {
  900. barriers.push_back(CI);
  901. }
  902. }
  903. }
  904. // Skip if not have wave op.
  905. if (waveOps.empty())
  906. return false;
  907. // Skip if no gradient op.
  908. if (gradientOps.empty())
  909. return false;
  910. for (auto &F : M) {
  911. if (F.isDeclaration())
  912. continue;
  913. SmallVector<CallInst *, 16> localGradientOps;
  914. for (CallInst *CI : gradientOps) {
  915. if (CI->getParent()->getParent() == &F)
  916. localGradientOps.emplace_back(CI);
  917. }
  918. if (localGradientOps.empty())
  919. continue;
  920. PostDominatorTree PDT;
  921. PDT.runOnFunction(F);
  922. std::unique_ptr<WaveSensitivityAnalysis> WaveVal(
  923. WaveSensitivityAnalysis::create(PDT));
  924. WaveVal->Analyze(&F);
  925. for (CallInst *op : localGradientOps) {
  926. if (WaveVal->IsWaveSensitive(op)) {
  927. dxilutil::EmitWarningOnInstruction(op,
  928. UniNoWaveSensitiveGradientErrMsg);
  929. }
  930. }
  931. }
  932. return false;
  933. }
  934. };
  935. }
  936. char DxilValidateWaveSensitivity::ID = 0;
  937. ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
  938. return new DxilValidateWaveSensitivity();
  939. }
  940. INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)
  941. namespace {
// Cull from BreakBBs the blocks that contain instructions sensitive to the
// wave-sensitive Inst. Sensitivity entails being Inst or an eventual user of
// it, and belonging to a block that ends in a conditional branch and whose
// innermost loop also contains WaveBB (the block of the wave operation).
// As filled by CollectBreakBlocks, BreakBBs holds blocks whose branch depends
// on dx.break, so only those blocks can actually be removed here.
// LInfo is needed to determine loop contents. VisitedPhis is needed to prevent
// infinite looping through PHI cycles in the use graph.
// LastBB is the block of the previous instruction on this use chain (nullptr
// for the first call); it avoids re-checking a block for consecutive
// instructions in that same block.
// Tracking stops at instructions that are not in any loop, and once BreakBBs
// becomes empty.
static void CullSensitiveBlocks(LoopInfo *LInfo, BasicBlock *WaveBB, BasicBlock *LastBB, Instruction *Inst,
                                SmallPtrSet<Instruction *, 16> &VisitedPhis,
                                SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
  BasicBlock *BB = Inst->getParent();
  Loop *BreakLoop = LInfo->getLoopFor(BB);
  // If this instruction isn't in a loop, there is no need to track its sensitivity further
  if (!BreakLoop || BreakBBs.empty())
    return;
  // To prevent infinite looping, only visit each PHI once
  // If we've seen this PHI before, don't reprocess it
  if (isa<PHINode>(Inst) && !VisitedPhis.insert(Inst).second)
    return;
  // If this BB wasn't already just processed, handle it now
  if (LastBB != BB) {
    // Determine if the instruction's block has an artificially-conditional break
    // and breaks out of a loop that contains the waveCI
    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    if (BI && BI->isConditional() && BreakLoop->contains(WaveBB))
      BreakBBs.erase(BB);
  }
  // Recurse on the users
  for (User *U : Inst->users()) {
    Instruction *I = cast<Instruction>(U);
    CullSensitiveBlocks(LInfo, WaveBB, BB, I, VisitedPhis, BreakBBs);
  }
}
  972. // Collect blocks that end in a dx.break dependent branch by tracing the descendants of BreakFunc
  973. // that are found in ThisFunc and store the block and call instruction in BreakBBs
  974. static void CollectBreakBlocks(Function *BreakFunc, Function *ThisFunc,
  975. SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
  976. for (User *U : BreakFunc->users()) {
  977. SmallVector<User *, 16> WorkList;
  978. Instruction *CI = cast<Instruction>(U);
  979. // If this user doesn't pertain to the current function, skip it.
  980. if (CI->getParent()->getParent() != ThisFunc)
  981. continue;
  982. WorkList.append(CI->user_begin(), CI->user_end());
  983. while (!WorkList.empty()) {
  984. Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
  985. // When we find a Branch that depends on dx.break, save it and stop
  986. // This should almost always be the first user of the Call Inst
  987. // If not, iterate on the users
  988. if (BranchInst *BI = dyn_cast<BranchInst>(I))
  989. BreakBBs[BI->getParent()] = CI;
  990. else
  991. WorkList.append(I->user_begin(), I->user_end());
  992. }
  993. }
  994. }
  995. // A pass to remove conditions from breaks that do not contain instructions that
  996. // depend on wave operations that are in the loop that the break leaves.
  997. class CleanupDxBreak : public FunctionPass {
  998. public:
  999. static char ID; // Pass identification, replacement for typeid
  1000. explicit CleanupDxBreak() : FunctionPass(ID) {}
  1001. const char *getPassName() const override { return "HLSL Remove unnecessary dx.break conditions"; }
  1002. void getAnalysisUsage(AnalysisUsage &AU) const override {
  1003. AU.addRequired<LoopInfoWrapperPass>();
  1004. }
  1005. LoopInfo *LInfo;
  1006. bool runOnFunction(Function &F) override {
  1007. if (F.isDeclaration())
  1008. return false;
  1009. Module *M = F.getEntryBlock().getModule();
  1010. Function *BreakFunc = M->getFunction(DXIL::kDxBreakFuncName);
  1011. if (!BreakFunc)
  1012. return false;
  1013. LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  1014. // Collect the blocks that depend on dx.break and the instructions that call dx.break()
  1015. SmallDenseMap<BasicBlock *, Instruction *, 16> BreakBBs;
  1016. CollectBreakBlocks(BreakFunc, &F, BreakBBs);
  1017. if (BreakBBs.empty())
  1018. return false;
  1019. // For each wave operation, remove all the dx.break blocks that are sensitive to it
  1020. for (Function &IF : M->functions()) {
  1021. HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
  1022. // Only consider wave-sensitive intrinsics or extintrinsics
  1023. if (IF.isDeclaration() && IsHLWaveSensitive(&IF) && !BreakBBs.empty() &&
  1024. (opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
  1025. // For each user of the function, trace all its users to remove the blocks
  1026. for (User *U : IF.users()) {
  1027. CallInst *CI = cast<CallInst>(U);
  1028. if (CI->getParent()->getParent() == &F) {
  1029. SmallPtrSet<Instruction *, 16> VisitedPhis;
  1030. CullSensitiveBlocks(LInfo, CI->getParent(), nullptr, CI, VisitedPhis, BreakBBs);
  1031. }
  1032. }
  1033. }
  1034. }
  1035. bool Changed = false;
  1036. // Revert artificially conditional breaks in non-wave-sensitive blocks that remain in BreakBBs
  1037. Constant *C = ConstantInt::get(Type::getInt1Ty(M->getContext()), 1);
  1038. for (auto &BB : BreakBBs) {
  1039. // Replace the call instruction with a constant boolen
  1040. BB.second->replaceAllUsesWith(C);
  1041. BB.second->eraseFromParent();
  1042. Changed = true;
  1043. }
  1044. return Changed;
  1045. }
  1046. };
  1047. }
  1048. char CleanupDxBreak::ID = 0;
  1049. INITIALIZE_PASS_BEGIN(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)
  1050. INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  1051. INITIALIZE_PASS_END(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)
  1052. FunctionPass *llvm::createCleanupDxBreakPass() {
  1053. return new CleanupDxBreak();
  1054. }