DxilPreparePasses.cpp 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilPreparePasses.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Passes to prepare DxilModule. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/DxilGenerationPass.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/HLSL/HLOperations.h"
  14. #include "dxc/DXIL/DxilModule.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "dxc/DXIL/DxilUtil.h"
  18. #include "dxc/DXIL/DxilFunctionProps.h"
  19. #include "dxc/DXIL/DxilInstructions.h"
  20. #include "dxc/DXIL/DxilConstants.h"
  21. #include "dxc/HlslIntrinsicOp.h"
  22. #include "llvm/IR/GetElementPtrTypeIterator.h"
  23. #include "llvm/IR/IRBuilder.h"
  24. #include "llvm/IR/Instructions.h"
  25. #include "llvm/IR/InstIterator.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/Analysis/PostDominators.h"
  28. #include "llvm/IR/Module.h"
  29. #include "llvm/IR/DebugInfo.h"
  30. #include "llvm/IR/DIBuilder.h"
  31. #include "llvm/IR/PassManager.h"
  32. #include "llvm/ADT/BitVector.h"
  33. #include "llvm/Pass.h"
  34. #include "llvm/Transforms/Utils/Local.h"
  35. #include "llvm/Analysis/AssumptionCache.h"
  36. #include "llvm/Analysis/DxilValueCache.h"
  37. #include "llvm/Analysis/LoopInfo.h"
  38. #include <memory>
  39. #include <unordered_set>
  40. using namespace llvm;
  41. using namespace hlsl;
namespace {
// Module pass that replaces undef resource operands of HLCreateHandle calls
// with zeroinitializer, a recognizable "invalid" sentinel, so that later
// cleanup passes cannot silently optimize the undefs away before diagnostics
// run (see runOnModule below).
class InvalidateUndefResources : public ModulePass {
public:
  static char ID; // Pass identification.
  explicit InvalidateUndefResources() : ModulePass(ID) {
    // NOTE(review): this initializes the Scalarizer pass, not this pass —
    // looks like a copy-paste artifact; confirm intended registration.
    initializeScalarizerPass(*PassRegistry::getPassRegistry());
  }
  const char *getPassName() const override { return "Invalidate undef resources"; }
  bool runOnModule(Module &M) override;
};
}
char InvalidateUndefResources::ID = 0;
// Factory used by the pass-manager setup code.
ModulePass *llvm::createInvalidateUndefResourcesPass() { return new InvalidateUndefResources(); }
INITIALIZE_PASS(InvalidateUndefResources, "invalidate-undef-resource", "Invalidate undef resources", false, false)
  56. bool InvalidateUndefResources::runOnModule(Module &M) {
  57. // Undef resources typically indicate uninitialized locals being used
  58. // in some code path, which we should catch and report. However, some
  59. // code patterns in large shaders cause dead undef resources to momentarily,
  60. // which is not an error. We must wait until cleanup passes
  61. // have run to know whether we must produce an error.
  62. // However, we can't leave the undef values in because they could eliminated,
  63. // such as by reading from resources seen in a code path that was not taken.
  64. // We avoid the problem by replacing undef values by another invalid
  65. // value that we can identify later.
  66. for (auto &F : M.functions()) {
  67. if (GetHLOpcodeGroupByName(&F) == HLOpcodeGroup::HLCreateHandle) {
  68. Type *ResTy = F.getFunctionType()->getParamType(
  69. HLOperandIndex::kCreateHandleResourceOpIdx);
  70. UndefValue *UndefRes = UndefValue::get(ResTy);
  71. if (!UndefRes->use_empty()) {
  72. Constant *InvalidRes = ConstantAggregateZero::get(ResTy);
  73. UndefRes->replaceAllUsesWith(InvalidRes);
  74. }
  75. }
  76. }
  77. return false;
  78. }
  79. ///////////////////////////////////////////////////////////////////////////////
namespace {
// Function pass that runs LLVM's generic instruction simplification over
// every basic block of a function.
class SimplifyInst : public FunctionPass {
public:
  static char ID; // Pass identification.
  SimplifyInst() : FunctionPass(ID) {
    // NOTE(review): this initializes the Scalarizer pass, not this pass —
    // looks like a copy-paste artifact; confirm intended registration.
    initializeScalarizerPass(*PassRegistry::getPassRegistry());
  }
  bool runOnFunction(Function &F) override;

private:
};
}
char SimplifyInst::ID = 0;
// Factory used by the pass-manager setup code.
FunctionPass *llvm::createSimplifyInstPass() { return new SimplifyInst(); }
INITIALIZE_PASS(SimplifyInst, "simplify-inst", "Simplify Instructions", false, false)
  94. bool SimplifyInst::runOnFunction(Function &F) {
  95. for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
  96. BasicBlock *BB = BBI;
  97. llvm::SimplifyInstructionsInBlock(BB, nullptr);
  98. }
  99. return true;
  100. }
  101. ///////////////////////////////////////////////////////////////////////////////
  102. namespace {
  103. class DxilDeadFunctionElimination : public ModulePass {
  104. public:
  105. static char ID; // Pass identification, replacement for typeid
  106. explicit DxilDeadFunctionElimination () : ModulePass(ID) {}
  107. const char *getPassName() const override { return "Remove all unused function except entry from DxilModule"; }
  108. bool runOnModule(Module &M) override {
  109. if (M.HasDxilModule()) {
  110. DxilModule &DM = M.GetDxilModule();
  111. bool IsLib = DM.GetShaderModel()->IsLib();
  112. // Remove unused functions except entry and patch constant func.
  113. // For library profile, only remove unused external functions.
  114. Function *EntryFunc = DM.GetEntryFunction();
  115. Function *PatchConstantFunc = DM.GetPatchConstantFunction();
  116. return dxilutil::RemoveUnusedFunctions(M, EntryFunc, PatchConstantFunc,
  117. IsLib);
  118. }
  119. return false;
  120. }
  121. };
  122. }
char DxilDeadFunctionElimination::ID = 0;
// Factory used by the pass-manager setup code.
ModulePass *llvm::createDxilDeadFunctionEliminationPass() {
  return new DxilDeadFunctionElimination();
}
INITIALIZE_PASS(DxilDeadFunctionElimination, "dxil-dfe", "Remove all unused function except entry from DxilModule", false, false)
  128. ///////////////////////////////////////////////////////////////////////////////
// Defined elsewhere in the DXIL passes; forward declared for use by
// DxilFinalizeModule::runOnModule below.
bool CleanupSharedMemoryAddrSpaceCast(Module &M);
  130. namespace {
  131. static void TransferEntryFunctionAttributes(Function *F, Function *NewFunc) {
  132. // Keep necessary function attributes
  133. AttributeSet attributeSet = F->getAttributes();
  134. StringRef attrKind, attrValue;
  135. if (attributeSet.hasAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString)) {
  136. Attribute attribute = attributeSet.getAttribute(AttributeSet::FunctionIndex, DXIL::kFP32DenormKindString);
  137. DXASSERT(attribute.isStringAttribute(), "otherwise we have wrong fp-denorm-mode attribute.");
  138. attrKind = attribute.getKindAsString();
  139. attrValue = attribute.getValueAsString();
  140. }
  141. if (F == NewFunc) {
  142. NewFunc->removeAttributes(AttributeSet::FunctionIndex, attributeSet);
  143. }
  144. if (!attrKind.empty() && !attrValue.empty())
  145. NewFunc->addFnAttr(attrKind, attrValue);
  146. }
// If this returns non-null, the old function F has been stripped and can be
// deleted. Builds a parameterless, void-returning clone of F, splices F's
// body into it, and migrates the debug-info subprogram, DXIL entry
// properties, and type-system annotation. Returns nullptr when F already has
// the stripped "void ()" shape or when any argument still has users (in
// which case stripping is not possible).
static Function *StripFunctionParameter(Function *F, DxilModule &DM,
    DenseMap<const Function *, DISubprogram *> &FunctionDIs) {
  if (F->arg_empty() && F->getReturnType()->isVoidTy()) {
    // This will strip non-entry function attributes
    TransferEntryFunctionAttributes(F, F);
    return nullptr;
  }
  Module &M = *DM.GetModule();
  Type *VoidTy = Type::getVoidTy(M.getContext());
  FunctionType *FT = FunctionType::get(VoidTy, false);
  // Bail out if an argument is still referenced; only dbg.declare users are
  // safe to drop.
  for (auto &arg : F->args()) {
    if (!arg.user_empty())
      return nullptr;
    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&arg);
    if (DDI) {
      DDI->eraseFromParent();
    }
  }
  Function *NewFunc = Function::Create(FT, F->getLinkage());
  M.getFunctionList().insert(F, NewFunc);
  // Splice the body of the old function right into the new function.
  NewFunc->getBasicBlockList().splice(NewFunc->begin(), F->getBasicBlockList());
  TransferEntryFunctionAttributes(F, NewFunc);
  // Patch the pointer to LLVM function in debug info descriptor.
  auto DI = FunctionDIs.find(F);
  if (DI != FunctionDIs.end()) {
    DISubprogram *SP = DI->second;
    SP->replaceFunction(NewFunc);
    // Ensure the map is updated so it can be reused on subsequent argument
    // promotions of the same function.
    FunctionDIs.erase(DI);
    FunctionDIs[NewFunc] = SP;
  }
  NewFunc->takeName(F);
  // Migrate DXIL entry properties and the function annotation to the clone.
  if (DM.HasDxilFunctionProps(F)) {
    DM.ReplaceDxilEntryProps(F, NewFunc);
  }
  DM.GetTypeSystem().EraseFunctionAnnotation(F);
  DM.GetTypeSystem().AddFunctionAnnotation(NewFunc);
  return NewFunc;
}
  189. void CheckInBoundForTGSM(GlobalVariable &GV, const DataLayout &DL) {
  190. for (User *U : GV.users()) {
  191. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  192. bool allImmIndex = true;
  193. for (auto Idx = GEP->idx_begin(), E = GEP->idx_end(); Idx != E; Idx++) {
  194. if (!isa<ConstantInt>(Idx)) {
  195. allImmIndex = false;
  196. break;
  197. }
  198. }
  199. if (!allImmIndex)
  200. GEP->setIsInBounds(false);
  201. else {
  202. Value *Ptr = GEP->getPointerOperand();
  203. unsigned size =
  204. DL.getTypeAllocSize(Ptr->getType()->getPointerElementType());
  205. unsigned valSize =
  206. DL.getTypeAllocSize(GEP->getType()->getPointerElementType());
  207. SmallVector<Value *, 8> Indices(GEP->idx_begin(), GEP->idx_end());
  208. unsigned offset =
  209. DL.getIndexedOffset(GEP->getPointerOperandType(), Indices);
  210. if ((offset + valSize) > size)
  211. GEP->setIsInBounds(false);
  212. }
  213. }
  214. }
  215. }
  216. static bool GetUnsignedVal(Value *V, uint32_t *pValue) {
  217. ConstantInt *CI = dyn_cast<ConstantInt>(V);
  218. if (!CI) return false;
  219. uint64_t u = CI->getZExtValue();
  220. if (u > UINT32_MAX) return false;
  221. *pValue = (uint32_t)u;
  222. return true;
  223. }
// Walks F and, for every signature-access DXIL op (loadInput, storeOutput,
// loadPatchConstant, storePatchConstant, storeVertexOutput,
// storePrimitiveOutput), marks the accessed component in the signature
// element's usage mask and records dynamic row indexing in its dyn-idx
// component mask.
static void MarkUsedSignatureElements(Function *F, DxilModule &DM) {
  DXASSERT_NOMSG(F != nullptr);
  // For every loadInput/storeOutput, update the corresponding ReadWriteMask.
  // F is a pointer to a Function instance
  for (llvm::inst_iterator I = llvm::inst_begin(F), E = llvm::inst_end(F);
       I != E; ++I) {
    // Each wrapper matches only its own DXIL opcode; at most one converts to
    // true in the chain below.
    DxilInst_LoadInput LI(&*I);
    DxilInst_StoreOutput SO(&*I);
    DxilInst_LoadPatchConstant LPC(&*I);
    DxilInst_StorePatchConstant SPC(&*I);
    DxilInst_StoreVertexOutput SVO(&*I);
    DxilInst_StorePrimitiveOutput SPO(&*I);
    DxilSignature *pSig;
    uint32_t col, row, sigId;
    // Set when the row index is not a constant (dynamically indexed).
    bool bDynIdx = false;
    // In every case: a non-constant signature id or column index means the
    // access cannot be attributed to an element, so the instruction is
    // skipped.
    if (LI) {
      if (!GetUnsignedVal(LI.get_inputSigId(), &sigId)) continue;
      if (!GetUnsignedVal(LI.get_colIndex(), &col)) continue;
      if (!GetUnsignedVal(LI.get_rowIndex(), &row)) bDynIdx = true;
      pSig = &DM.GetInputSignature();
    } else if (SO) {
      if (!GetUnsignedVal(SO.get_outputSigId(), &sigId)) continue;
      if (!GetUnsignedVal(SO.get_colIndex(), &col)) continue;
      if (!GetUnsignedVal(SO.get_rowIndex(), &row)) bDynIdx = true;
      pSig = &DM.GetOutputSignature();
    } else if (SPC) {
      if (!GetUnsignedVal(SPC.get_outputSigID(), &sigId)) continue;
      if (!GetUnsignedVal(SPC.get_col(), &col)) continue;
      if (!GetUnsignedVal(SPC.get_row(), &row)) bDynIdx = true;
      pSig = &DM.GetPatchConstOrPrimSignature();
    } else if (LPC) {
      if (!GetUnsignedVal(LPC.get_inputSigId(), &sigId)) continue;
      if (!GetUnsignedVal(LPC.get_col(), &col)) continue;
      if (!GetUnsignedVal(LPC.get_row(), &row)) bDynIdx = true;
      pSig = &DM.GetPatchConstOrPrimSignature();
    } else if (SVO) {
      if (!GetUnsignedVal(SVO.get_outputSigId(), &sigId)) continue;
      if (!GetUnsignedVal(SVO.get_colIndex(), &col)) continue;
      if (!GetUnsignedVal(SVO.get_rowIndex(), &row)) bDynIdx = true;
      pSig = &DM.GetOutputSignature();
    } else if (SPO) {
      if (!GetUnsignedVal(SPO.get_outputSigId(), &sigId)) continue;
      if (!GetUnsignedVal(SPO.get_colIndex(), &col)) continue;
      if (!GetUnsignedVal(SPO.get_rowIndex(), &row)) bDynIdx = true;
      pSig = &DM.GetPatchConstOrPrimSignature();
    } else {
      continue;
    }
    // Consider being more fine-grained about masks.
    // We report sometimes-read on input as always-read.
    auto &El = pSig->GetElement(sigId);
    unsigned UsageMask = El.GetUsageMask();
    unsigned colBit = 1 << col;
    if (!(colBit & UsageMask)) {
      El.SetUsageMask(UsageMask | colBit);
    }
    if (bDynIdx && (El.GetDynIdxCompMask() & colBit) == 0) {
      El.SetDynIdxCompMask(El.GetDynIdxCompMask() | colBit);
    }
  }
}
// Final cleanup pass that prepares a DxilModule for emission: patches
// constructs not allowed by the targeted validator/DXIL version, strips
// entry-function parameters, removes dead globals and intrinsics, and
// refreshes module flags (see runOnModule below).
class DxilFinalizeModule : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilFinalizeModule() : ModulePass(ID) {}
  const char *getPassName() const override { return "HLSL DXIL Finalize Module"; }
  // Strips instruction metadata kinds that validator versions <= 1.1 reject
  // (tbaa, prof, and the range between fpmath and dereferenceable_or_null),
  // leaving debug locations intact.
  void patchValidation_1_1(Module &M) {
    for (iplist<Function>::iterator F : M.getFunctionList()) {
      for (Function::iterator BBI = F->begin(), BBE = F->end(); BBI != BBE;
           ++BBI) {
        BasicBlock *BB = BBI;
        for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;
             ++II) {
          Instruction *I = II;
          if (I->hasMetadataOtherThanDebugLoc()) {
            SmallVector<std::pair<unsigned, MDNode*>, 2> MDs;
            I->getAllMetadataOtherThanDebugLoc(MDs);
            for (auto &MD : MDs) {
              unsigned kind = MD.first;
              // Remove metadata kinds the older validator does not allow.
              bool bNeedPatch = kind == LLVMContext::MD_tbaa ||
                  kind == LLVMContext::MD_prof ||
                  (kind > LLVMContext::MD_fpmath &&
                   kind <= LLVMContext::MD_dereferenceable_or_null);
              if (bNeedPatch)
                I->setMetadata(kind, nullptr);
            }
          }
        }
      }
    }
  }
  321. void patchDxil_1_6(Module &M, hlsl::OP *hlslOP, unsigned ValMajor, unsigned ValMinor) {
  322. for (auto it : hlslOP->GetOpFuncList(DXIL::OpCode::AnnotateHandle)) {
  323. Function *F = it.second;
  324. if (!F)
  325. continue;
  326. for (auto uit = F->user_begin(); uit != F->user_end();) {
  327. CallInst *CI = cast<CallInst>(*(uit++));
  328. DxilInst_AnnotateHandle annoteHdl(CI);
  329. Value *hdl = annoteHdl.get_res();
  330. CI->replaceAllUsesWith(hdl);
  331. CI->eraseFromParent();
  332. }
  333. }
  334. }
  // Replace llvm.lifetime.start/.end intrinsics with undef or zeroinitializer
  // stores (for earlier validator versions) unless the pointer is a global
  // that has an initializer.
  // This works around losing scoping information in earlier shader models
  // that do not support the intrinsics natively.
  void patchLifetimeIntrinsics(Module &M, unsigned ValMajor, unsigned ValMinor,
                               bool forceZeroStoreLifetimes) {
    // Get the declarations. This may introduce them if there were none before.
    Value *StartDecl = Intrinsic::getDeclaration(&M, Intrinsic::lifetime_start);
    Value *EndDecl = Intrinsic::getDeclaration(&M, Intrinsic::lifetime_end);
    // Collect all calls to both intrinsics up front, because the loop below
    // erases them (iterating the use lists directly would be invalidated).
    std::vector<CallInst*> intrinsicCalls;
    for (Use &U : StartDecl->uses()) {
      // All users must be call instructions.
      CallInst *CI = dyn_cast<CallInst>(U.getUser());
      DXASSERT(CI,
               "Expected user of lifetime.start intrinsic to be a CallInst");
      intrinsicCalls.push_back(CI);
    }
    for (Use &U : EndDecl->uses()) {
      // All users must be call instructions.
      CallInst *CI = dyn_cast<CallInst>(U.getUser());
      DXASSERT(CI, "Expected user of lifetime.end intrinsic to be a CallInst");
      intrinsicCalls.push_back(CI);
    }
    // Replace each intrinsic with an undef store.
    for (CallInst *CI : intrinsicCalls) {
      // Find the corresponding pointer (bitcast from alloca, global value, an
      // argument, ...). Operand 1 of the intrinsic is the i8* pointer.
      Value *voidPtr = CI->getArgOperand(1);
      DXASSERT(voidPtr->getType()->isPointerTy() &&
               voidPtr->getType()->getPointerElementType()->isIntegerTy(8),
               "Expected operand of lifetime intrinsic to be of type i8*" );
      Value *ptr = nullptr;
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(voidPtr)) {
        // This can happen if a local variable/array is promoted to a constant
        // global. In this case we must not introduce a store, since that would
        // overwrite the constant values in the initializer. Thus, we simply
        // remove the intrinsic (ptr stays null).
        DXASSERT(CE->getOpcode() == Instruction::BitCast,
                 "expected operand of lifetime intrinsic to be a bitcast");
      } else {
        // Otherwise, it must be a normal bitcast.
        DXASSERT(isa<BitCastInst>(voidPtr),
                 "Expected operand of lifetime intrinsic to be a bitcast");
        BitCastInst *BC = cast<BitCastInst>(voidPtr);
        ptr = BC->getOperand(0);
        // If the original pointer is a global with initializer, do not replace
        // the intrinsic with a store.
        if (GlobalVariable *GV = dyn_cast<GlobalVariable>(ptr))
          if (GV->hasInitializer() || GV->isExternallyInitialized())
            ptr = nullptr;
      }
      if (ptr) {
        // Determine the type to use when storing undef.
        DXASSERT(ptr->getType()->isPointerTy(),
                 "Expected type of operand of lifetime intrinsic bitcast operand to be a pointer");
        Type *T = ptr->getType()->getPointerElementType();
        // Store undef at the location of the start/end intrinsic.
        // If we are targeting validator version < 1.6 we cannot store undef
        // since it causes a validation error. As a workaround we store 0, which
        // achieves mostly the same as storing undef but can cause overhead in
        // some situations.
        // We also allow to force zeroinitializer through a flag.
        if (forceZeroStoreLifetimes || ValMajor < 1 || (ValMajor == 1 && ValMinor < 6))
          IRBuilder<>(CI).CreateStore(Constant::getNullValue(T), ptr);
        else
          IRBuilder<>(CI).CreateStore(UndefValue::get(T), ptr);
      }
      // Erase the intrinsic call and, if it has no uses anymore, the bitcast as
      // well.
      DXASSERT_NOMSG(CI->use_empty());
      CI->eraseFromParent();
      // Erase the bitcast inst if it is not a ConstantExpr.
      if (BitCastInst *BC = dyn_cast<BitCastInst>(voidPtr))
        if (BC->use_empty())
          BC->eraseFromParent();
    }
    // Erase the (now unused) intrinsic declarations.
    DXASSERT_NOMSG(StartDecl->use_empty());
    DXASSERT_NOMSG(EndDecl->use_empty());
    cast<Function>(StartDecl)->eraseFromParent();
    cast<Function>(EndDecl)->eraseFromParent();
  }
  // Performs the final module fixups in a fixed order. Returns true when the
  // module carried a DxilModule and was processed; false otherwise.
  bool runOnModule(Module &M) override {
    if (M.HasDxilModule()) {
      DxilModule &DM = M.GetDxilModule();
      unsigned ValMajor = 0;
      unsigned ValMinor = 0;
      DM.GetValidatorVersion(ValMajor, ValMinor);
      unsigned DxilMajor = 0;
      unsigned DxilMinor = 0;
      DM.GetDxilVersion(DxilMajor, DxilMinor);
      bool IsLib = DM.GetShaderModel()->IsLib();
      // Skip validation patch for lib.
      if (!IsLib) {
        if (ValMajor == 1 && ValMinor <= 1) {
          patchValidation_1_1(M);
        }
        // Set used masks for signature elements
        MarkUsedSignatureElements(DM.GetEntryFunction(), DM);
        if (DM.GetShaderModel()->IsHS())
          MarkUsedSignatureElements(DM.GetPatchConstantFunction(), DM);
      }
      // Replace lifetime intrinsics if requested or necessary.
      // NOTE(review): both checks below compare only the minor version and
      // assume DxilMajor == 1 — confirm if a major bump is ever expected.
      const bool forceZeroStoreLifetimes = DM.GetForceZeroStoreLifetimes();
      if (forceZeroStoreLifetimes || DxilMinor < 6) {
        patchLifetimeIntrinsics(M, ValMajor, ValMinor, forceZeroStoreLifetimes);
      }
      hlsl::OP *hlslOP = DM.GetOP();
      // AnnotateHandle does not exist before DXIL 1.6; strip it.
      if (DxilMinor < 6) {
        patchDxil_1_6(M, hlslOP, ValMajor, ValMinor);
      }
      // Remove store undef output.
      RemoveStoreUndefOutput(M, hlslOP);
      // Turn dx.break() conditional into global
      LowerDxBreak(M);
      RemoveUnusedStaticGlobal(M);
      // Remove unnecessary address space casts.
      CleanupSharedMemoryAddrSpaceCast(M);
      // Clear inbound for GEP which has none-const index.
      LegalizeSharedMemoryGEPInbound(M);
      // Strip parameters of entry function.
      StripEntryParameters(M, DM, IsLib);
      // Update flags to reflect any changes.
      DM.CollectShaderFlagsForModule();
      // Update Validator Version
      DM.UpgradeToMinValidatorVersion();
      // Clear intermediate options that shouldn't be in the final DXIL
      DM.ClearIntermediateOptions();
      // Remove unused AllocateRayQuery calls
      RemoveUnusedRayQuery(M);
      if (IsLib && DXIL::CompareVersions(ValMajor, ValMinor, 1, 4) <= 0) {
        // 1.4 validator requires function annotations for all functions
        AddFunctionAnnotationForInitializers(M, DM);
      }
      // Fix DIExpression fragments that cover whole variables
      LegalizeDbgFragments(M);
      return true;
    }
    return false;
  }
private:
  // Erases static/groupshared globals whose only uses are stores (or dead
  // constant expressions): a value that is written but never read cannot
  // affect results.
  void RemoveUnusedStaticGlobal(Module &M) {
    // Remove unused internal global.
    std::vector<GlobalVariable *> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (dxilutil::IsStaticGlobal(&GV) ||
          dxilutil::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }
    for (GlobalVariable *GV : staticGVs) {
      bool onlyStoreUse = true;
      for (User *user : GV->users()) {
        if (isa<StoreInst>(user))
          continue;
        // A constant expression with no users of its own is dead; ignore it.
        if (isa<ConstantExpr>(user) && user->user_empty())
          continue;
        onlyStoreUse = false;
        break;
      }
      if (onlyStoreUse) {
        // Erase users while iterating: advance the iterator before erasing.
        for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
          Value *User = *(UserIt++);
          if (Instruction *I = dyn_cast<Instruction>(User)) {
            I->eraseFromParent();
          } else {
            // Constant expressions are not erased, only unlinked from GV.
            ConstantExpr *CE = cast<ConstantExpr>(User);
            CE->dropAllReferences();
          }
        }
        GV->eraseFromParent();
      }
    }
  }
  510. static bool BitPieceCoversEntireVar(DIExpression *expr, DILocalVariable *var, DITypeIdentifierMap &TypeIdentifierMap) {
  511. if (expr->isBitPiece()) {
  512. DIType *ty = var->getType().resolve(TypeIdentifierMap);
  513. return expr->getBitPieceOffset() == 0 && expr->getBitPieceSize() == ty->getSizeInBits();
  514. }
  515. return false;
  516. }
  // For the llvm.dbg.value or llvm.dbg.declare declaration f, rewrites every
  // call whose bit-piece expression covers the entire variable into an
  // equivalent intrinsic with an empty (non-fragment) expression, and erases
  // calls whose value/address operand has been optimized away.
  static void LegalizeDbgFragmentsForDbgIntrinsic(Function *f, DITypeIdentifierMap &TypeIdentifierMap) {
    Intrinsic::ID intrinsic = f->getIntrinsicID();
    DIBuilder dib(*f->getParent());
    if (intrinsic == Intrinsic::dbg_value) {
      // Users are erased during iteration; advance before inspecting.
      for (auto it = f->user_begin(), end = f->user_end(); it != end;) {
        User *u = *(it++);
        DbgValueInst *di = cast<DbgValueInst>(u);
        Value *value = di->getValue();
        if (!value) {
          // Stale intrinsic whose described value is gone.
          di->eraseFromParent();
          continue;
        }
        DIExpression *expr = di->getExpression();
        DILocalVariable *var = di->getVariable();
        if (BitPieceCoversEntireVar(expr, var, TypeIdentifierMap)) {
          // Whole-variable fragment: re-emit with an empty expression.
          dib.insertDbgValueIntrinsic(value, 0, var, DIExpression::get(di->getContext(), {}), di->getDebugLoc(), di);
          di->eraseFromParent();
        }
      }
    }
    else if (intrinsic == Intrinsic::dbg_declare) {
      for (auto it = f->user_begin(), end = f->user_end(); it != end;) {
        User *u = *(it++);
        DbgDeclareInst *di = cast<DbgDeclareInst>(u);
        Value *addr = di->getAddress();
        if (!addr) {
          // Stale intrinsic whose described address is gone.
          di->eraseFromParent();
          continue;
        }
        DIExpression *expr = di->getExpression();
        DILocalVariable *var = di->getVariable();
        if (BitPieceCoversEntireVar(expr, var, TypeIdentifierMap)) {
          // Whole-variable fragment: re-emit with an empty expression.
          dib.insertDeclare(addr, var, DIExpression::get(di->getContext(), {}), di->getDebugLoc(), di);
          di->eraseFromParent();
        }
      }
    }
  }
  555. static void LegalizeDbgFragments(Module &M) {
  556. DITypeIdentifierMap TypeIdentifierMap;
  557. if (Function *f = M.getFunction(Intrinsic::getName(Intrinsic::dbg_value))) {
  558. LegalizeDbgFragmentsForDbgIntrinsic(f, TypeIdentifierMap);
  559. }
  560. if (Function *f = M.getFunction(Intrinsic::getName(Intrinsic::dbg_declare))) {
  561. LegalizeDbgFragmentsForDbgIntrinsic(f, TypeIdentifierMap);
  562. }
  563. }
  564. void RemoveStoreUndefOutput(Module &M, hlsl::OP *hlslOP) {
  565. for (iplist<Function>::iterator F : M.getFunctionList()) {
  566. if (!hlslOP->IsDxilOpFunc(F))
  567. continue;
  568. DXIL::OpCodeClass opClass;
  569. bool bHasOpClass = hlslOP->GetOpCodeClass(F, opClass);
  570. DXASSERT_LOCALVAR(bHasOpClass, bHasOpClass, "else not a dxil op func");
  571. if (opClass != DXIL::OpCodeClass::StoreOutput)
  572. continue;
  573. for (auto it = F->user_begin(); it != F->user_end();) {
  574. CallInst *CI = dyn_cast<CallInst>(*(it++));
  575. if (!CI)
  576. continue;
  577. Value *V = CI->getArgOperand(DXIL::OperandIndex::kStoreOutputValOpIdx);
  578. // Remove the store of undef.
  579. if (isa<UndefValue>(V))
  580. CI->eraseFromParent();
  581. }
  582. }
  583. }
  584. void LegalizeSharedMemoryGEPInbound(Module &M) {
  585. const DataLayout &DL = M.getDataLayout();
  586. // Clear inbound for GEP which has none-const index.
  587. for (GlobalVariable &GV : M.globals()) {
  588. if (dxilutil::IsSharedMemoryGlobal(&GV)) {
  589. CheckInBoundForTGSM(GV, DL);
  590. }
  591. }
  592. }
  // Rewrites all entry points (and hull-shader patch constant functions) into
  // parameterless void functions via StripFunctionParameter, updating the
  // DxilModule bookkeeping. For non-library modules, only the single entry
  // and patch constant function are handled; for libraries, every entry that
  // uses signatures is processed, and a patch constant function shared by
  // several hull shaders is rewritten exactly once.
  void StripEntryParameters(Module &M, DxilModule &DM, bool IsLib) {
    DenseMap<const Function *, DISubprogram *> FunctionDIs =
        makeSubprogramMap(M);
    // Strip parameters of entry function.
    if (!IsLib) {
      if (Function *OldPatchConstantFunc = DM.GetPatchConstantFunction()) {
        Function *NewPatchConstantFunc =
            StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
        if (NewPatchConstantFunc) {
          DM.SetPatchConstantFunction(NewPatchConstantFunc);
          // Erase once the DxilModule doesn't track the old function anymore
          DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
                   "Error while migrating to parameter-stripped patch constant function.");
          OldPatchConstantFunc->eraseFromParent();
        }
      }
      if (Function *OldEntryFunc = DM.GetEntryFunction()) {
        // Restore the original entry name before cloning so the clone takes it.
        StringRef Name = DM.GetEntryFunctionName();
        OldEntryFunc->setName(Name);
        Function *NewEntryFunc = StripFunctionParameter(OldEntryFunc, DM, FunctionDIs);
        if (NewEntryFunc) {
          DM.SetEntryFunction(NewEntryFunc);
          OldEntryFunc->eraseFromParent();
        }
      }
    } else {
      std::vector<Function *> entries;
      // Handle when multiple hull shaders point to the same patch constant function
      MapVector<Function*, llvm::SmallVector<Function*, 2>> PatchConstantFuncUsers;
      for (iplist<Function>::iterator F : M.getFunctionList()) {
        if (DM.IsEntryThatUsesSignatures(F)) {
          auto *FT = F->getFunctionType();
          // Only do this when has parameters.
          if (FT->getNumParams() > 0 || !FT->getReturnType()->isVoidTy()) {
            entries.emplace_back(F);
          }
          DxilFunctionProps& props = DM.GetDxilFunctionProps(F);
          if (props.IsHS() && props.ShaderProps.HS.patchConstantFunc) {
            FunctionType* PatchConstantFuncTy = props.ShaderProps.HS.patchConstantFunc->getFunctionType();
            if (PatchConstantFuncTy->getNumParams() > 0 || !PatchConstantFuncTy->getReturnType()->isVoidTy()) {
              // Accumulate all hull shaders using a given patch constant function,
              // so we can update it once and fix all hull shaders, without having an intermediary
              // state where some hull shaders point to a destroyed patch constant function.
              PatchConstantFuncUsers[props.ShaderProps.HS.patchConstantFunc].emplace_back(F);
            }
          }
        }
      }
      // Strip patch constant functions first
      for (auto &PatchConstantFuncEntry : PatchConstantFuncUsers) {
        Function* OldPatchConstantFunc = PatchConstantFuncEntry.first;
        Function* NewPatchConstantFunc = StripFunctionParameter(OldPatchConstantFunc, DM, FunctionDIs);
        if (NewPatchConstantFunc) {
          // Update all user hull shaders
          for (Function *HullShaderFunc : PatchConstantFuncEntry.second)
            DM.SetPatchConstantFunctionForHS(HullShaderFunc, NewPatchConstantFunc);
          // Erase once the DxilModule doesn't track the old function anymore
          DXASSERT(DM.IsPatchConstantShader(NewPatchConstantFunc) && !DM.IsPatchConstantShader(OldPatchConstantFunc),
                   "Error while migrating to parameter-stripped patch constant function.");
          OldPatchConstantFunc->eraseFromParent();
        }
      }
      for (Function *OldEntry : entries) {
        Function *NewEntry = StripFunctionParameter(OldEntry, DM, FunctionDIs);
        if (NewEntry) OldEntry->eraseFromParent();
      }
    }
  }
  661. void AddFunctionAnnotationForInitializers(Module &M, DxilModule &DM) {
  662. if (GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors")) {
  663. if (isa<ConstantAggregateZero>(GV->getInitializer())) {
  664. DXASSERT_NOMSG(GV->user_empty());
  665. GV->eraseFromParent();
  666. return;
  667. }
  668. ConstantArray *init = cast<ConstantArray>(GV->getInitializer());
  669. for (auto V : init->operand_values()) {
  670. if (isa<ConstantAggregateZero>(V))
  671. continue;
  672. ConstantStruct *CS = cast<ConstantStruct>(V);
  673. if (isa<ConstantPointerNull>(CS->getOperand(1)))
  674. continue;
  675. Function *F = cast<Function>(CS->getOperand(1));
  676. if (DM.GetTypeSystem().GetFunctionAnnotation(F) == nullptr)
  677. DM.GetTypeSystem().AddFunctionAnnotation(F);
  678. }
  679. }
  680. }
  681. void RemoveUnusedRayQuery(Module &M) {
  682. hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
  683. llvm::Function *AllocFn = hlslOP->GetOpFunc(
  684. DXIL::OpCode::AllocateRayQuery, Type::getVoidTy(M.getContext()));
  685. SmallVector<CallInst*, 4> DeadInsts;
  686. for (auto U : AllocFn->users()) {
  687. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  688. if (CI->user_empty()) {
  689. DeadInsts.emplace_back(CI);
  690. }
  691. }
  692. }
  693. for (auto CI : DeadInsts) {
  694. CI->eraseFromParent();
  695. }
  696. if (AllocFn->user_empty()) {
  697. AllocFn->eraseFromParent();
  698. }
  699. }
// Convert all uses of dx.break() into per-function load/cmp of dx.break.cond global constant
// Each function gets at most one load+icmp pair (cached in DxBreakCmpMap),
// inserted at the alloca insertion point; every dx.break() call in that
// function is then replaced by the shared compare. Finally, any dx.break
// branch metadata is cleared from all branch instructions in the module.
void LowerDxBreak(Module &M) {
  if (Function *BreakFunc = M.getFunction(DXIL::kDxBreakFuncName)) {
    if (!BreakFunc->use_empty()) {
      llvm::Type *i32Ty = llvm::Type::getInt32Ty(M.getContext());
      // Single-element constant array [1 x i32] zeroinitializer used as the
      // dx.break.cond condition source.
      Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
      unsigned int Values[1] = { 0 };
      Constant *InitialValue = ConstantDataArray::get(M.getContext(), Values);
      Constant *GV = new GlobalVariable(M, i32ArrayTy, true,
                                        GlobalValue::InternalLinkage,
                                        InitialValue, DXIL::kDxBreakCondName);
      Constant *Indices[] = { ConstantInt::get(i32Ty, 0), ConstantInt::get(i32Ty, 0) };
      Constant *Gep = ConstantExpr::getGetElementPtr(nullptr, GV, Indices);
      SmallDenseMap<llvm::Function*, llvm::ICmpInst*, 16> DxBreakCmpMap;
      // Replace all uses of dx.break with references to the constant global
      // NOTE: the iterator is advanced before the call is erased, so removing
      // the current use does not invalidate the traversal.
      for (auto I = BreakFunc->user_begin(), E = BreakFunc->user_end(); I != E;) {
        User *U = *I++;
        CallInst *CI = cast<CallInst>(U);
        Function *F = CI->getParent()->getParent();
        ICmpInst *Cmp = DxBreakCmpMap.lookup(F);
        if (!Cmp) {
          // First dx.break in this function: materialize load + (== 0) compare.
          Instruction *IP = dxilutil::FindAllocaInsertionPt(F);
          LoadInst *LI = new LoadInst(Gep, nullptr, false, IP);
          Cmp = new ICmpInst(IP, ICmpInst::ICMP_EQ, LI, llvm::ConstantInt::get(i32Ty,0));
          DxBreakCmpMap[F] = Cmp;
        }
        CI->replaceAllUsesWith(Cmp);
        CI->eraseFromParent();
      }
    }
    BreakFunc->eraseFromParent();
  }
  // Strip dx.break metadata from every branch terminator in the module.
  for (Function &F : M) {
    for (BasicBlock &BB : F) {
      if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
        BI->setMetadata(DXIL::kDxBreakMDName, nullptr);
      }
    }
  }
}
  740. };
  741. }
char DxilFinalizeModule::ID = 0;

// Factory used by the pass pipeline to instantiate DxilFinalizeModule.
ModulePass *llvm::createDxilFinalizeModulePass() {
  return new DxilFinalizeModule();
}

INITIALIZE_PASS(DxilFinalizeModule, "hlsl-dxilfinalize", "HLSL DXIL Finalize Module", false, false)
  747. ///////////////////////////////////////////////////////////////////////////////
namespace {
// Maps each PHI that needs rewriting to its per-incoming-value replacements
// (nullptr entries mark incoming values not yet resolved).
typedef MapVector< PHINode*, SmallVector<Value*,8> > PHIReplacementMap;
// Recursively rewrite users of Val to use NewVal (a pointer to the same data
// in a different address space), replicating bitcasts/GEPs in the target
// address space and eliding addrspacecasts. PHIs are deferred into
// phiReplacements for the caller's fixpoint pass; valueMap memoizes
// user -> replacement so shared users are only rebuilt once.
// Returns true if anything changed.
bool RemoveAddrSpaceCasts(Value *Val, Value *NewVal,
                          PHIReplacementMap &phiReplacements,
                          DenseMap<Value*, Value*> &valueMap) {
  bool bChanged = false;
  // Advance the use iterator before touching the use so in-place edits and
  // erasures do not invalidate the traversal.
  for (auto itU = Val->use_begin(), itEnd = Val->use_end(); itU != itEnd; ) {
    Use &use = *(itU++);
    User *user = use.getUser();
    Value *userReplacement = user;
    bool bConstructReplacement = false;
    bool bCleanupInst = false;
    auto valueMapIter = valueMap.find(user);
    if (valueMapIter != valueMap.end())
      userReplacement = valueMapIter->second;
    else if (Val != NewVal)
      bConstructReplacement = true;
    if (ConstantExpr* CE = dyn_cast<ConstantExpr>(user)) {
      if (CE->getOpcode() == Instruction::BitCast) {
        if (bConstructReplacement) {
          // Replicate bitcast in target address space
          Type* NewTy = PointerType::get(
              CE->getType()->getPointerElementType(),
              NewVal->getType()->getPointerAddressSpace());
          userReplacement = ConstantExpr::getBitCast(cast<Constant>(NewVal), NewTy);
        }
      } else if (CE->getOpcode() == Instruction::GetElementPtr) {
        if (bConstructReplacement) {
          // Replicate GEP in target address space
          GEPOperator *GEP = cast<GEPOperator>(CE);
          SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
          userReplacement = ConstantExpr::getGetElementPtr(
              nullptr, cast<Constant>(NewVal), idxList, GEP->isInBounds());
        }
      } else if (CE->getOpcode() == Instruction::AddrSpaceCast) {
        // The cast becomes a no-op: forward NewVal directly.
        userReplacement = NewVal;
        bConstructReplacement = false;
      } else {
        DXASSERT(false, "RemoveAddrSpaceCasts: unhandled pointer ConstantExpr");
      }
    } else if (Instruction *I = dyn_cast<Instruction>(user)) {
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
        if (bConstructReplacement) {
          // Rebuild the GEP on NewVal, preserving inbounds.
          IRBuilder<> Builder(GEP);
          SmallVector<Value*, 8> idxList(GEP->idx_begin(), GEP->idx_end());
          if (GEP->isInBounds())
            userReplacement = Builder.CreateInBoundsGEP(NewVal, idxList, GEP->getName());
          else
            userReplacement = Builder.CreateGEP(NewVal, idxList, GEP->getName());
        }
      } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
        if (bConstructReplacement) {
          // Rebuild the bitcast with the target address space.
          IRBuilder<> Builder(BC);
          Type* NewTy = PointerType::get(
              BC->getType()->getPointerElementType(),
              NewVal->getType()->getPointerAddressSpace());
          userReplacement = Builder.CreateBitCast(NewVal, NewTy);
        }
      } else if (PHINode *PHI = dyn_cast<PHINode>(user)) {
        // set replacement phi values for PHI pass
        unsigned numValues = PHI->getNumIncomingValues();
        auto &phiValues = phiReplacements[PHI];
        if (phiValues.empty())
          phiValues.resize(numValues, nullptr);
        for (unsigned idx = 0; idx < numValues; ++idx) {
          if (phiValues[idx] == nullptr &&
              PHI->getIncomingValue(idx) == Val) {
            phiValues[idx] = NewVal;
            bChanged = true;
          }
        }
        // PHIs are handled later by the caller's fixpoint loop; do not recurse.
        continue;
      } else if (isa<AddrSpaceCastInst>(user)) {
        // Cast becomes a no-op; mark it for cleanup once its uses are gone.
        userReplacement = NewVal;
        bConstructReplacement = false;
        bCleanupInst = true;
      } else if (isa<CallInst>(user)) {
        // Calls keep their original operand; nothing to rewrite here.
        continue;
      } else {
        // Leaf use (load/store/etc.): point it at NewVal in place.
        if (Val != NewVal) {
          use.set(NewVal);
          bChanged = true;
        }
        continue;
      }
    }
    if (bConstructReplacement && user != userReplacement)
      valueMap[user] = userReplacement;
    // Recurse so the replacement propagates through the user's own users.
    bChanged |= RemoveAddrSpaceCasts(user, userReplacement, phiReplacements,
                                     valueMap);
    if (bCleanupInst && user->use_empty()) {
      // Clean up old instruction if it's now unused.
      // Safe during this use iteration when only one use of V in instruction.
      if (Instruction *I = dyn_cast<Instruction>(user))
        I->eraseFromParent();
      bChanged = true;
    }
  }
  return bChanged;
}
}
// Eliminate addrspacecasts rooted at groupshared (shared memory) globals by
// rewriting their use chains into the globals' native address space.
// PHIs discovered during rewriting are resolved by a fixpoint loop: a PHI is
// replaced only once every incoming value has a known same-address-space
// replacement; PHIs with mixed or unchanged address spaces are dropped from
// the worklist. Returns true if the module changed.
bool CleanupSharedMemoryAddrSpaceCast(Module &M) {
  bool bChanged = false;
  // Eliminate address space casts if possible
  // Collect phi nodes so we can replace iteratively after pass over GVs
  PHIReplacementMap phiReplacements;
  DenseMap<Value*, Value*> valueMap;
  for (GlobalVariable &GV : M.globals()) {
    if (dxilutil::IsSharedMemoryGlobal(&GV)) {
      bChanged |= RemoveAddrSpaceCasts(&GV, &GV, phiReplacements,
                                       valueMap);
    }
  }
  // Fixpoint: each pass either resolves/removes a PHI (restart) or converges.
  bool bConverged = false;
  while (!phiReplacements.empty() && !bConverged) {
    bConverged = true;
    for (auto &phiReplacement : phiReplacements) {
      PHINode *PHI = phiReplacement.first;
      unsigned origAddrSpace = PHI->getType()->getPointerAddressSpace();
      unsigned incomingAddrSpace = UINT_MAX;
      bool bReplacePHI = true;
      bool bRemovePHI = false;
      // Check that all incoming replacements exist and agree on address space.
      for (auto V : phiReplacement.second) {
        if (nullptr == V) {
          // cannot replace phi (yet)
          bReplacePHI = false;
          break;
        }
        unsigned addrSpace = V->getType()->getPointerAddressSpace();
        if (incomingAddrSpace == UINT_MAX) {
          incomingAddrSpace = addrSpace;
        } else if (addrSpace != incomingAddrSpace) {
          bRemovePHI = true;
          break;
        }
      }
      // No address-space change needed: nothing to rewrite for this PHI.
      if (origAddrSpace == incomingAddrSpace)
        bRemovePHI = true;
      if (bRemovePHI) {
        // Cannot replace phi. Remove it and restart.
        phiReplacements.erase(PHI);
        bConverged = false;
        break;
      }
      if (!bReplacePHI)
        continue;
      // Build (or reuse) the replacement PHI in the incoming address space.
      auto &NewVal = valueMap[PHI];
      PHINode *NewPHI = nullptr;
      if (NewVal) {
        NewPHI = cast<PHINode>(NewVal);
      } else {
        IRBuilder<> Builder(PHI);
        NewPHI = Builder.CreatePHI(
            PointerType::get(PHI->getType()->getPointerElementType(),
                             incomingAddrSpace),
            PHI->getNumIncomingValues(),
            PHI->getName());
        NewVal = NewPHI;
        for (unsigned idx = 0; idx < PHI->getNumIncomingValues(); idx++) {
          NewPHI->addIncoming(phiReplacement.second[idx],
                              PHI->getIncomingBlock(idx));
        }
      }
      // Propagate the new PHI through the old PHI's users; restart if the map
      // may have been mutated under this iteration.
      if (RemoveAddrSpaceCasts(PHI, NewPHI, phiReplacements,
                               valueMap)) {
        bConverged = false;
        bChanged = true;
        break;
      }
      if (PHI->use_empty()) {
        phiReplacements.erase(PHI);
        bConverged = false;
        bChanged = true;
        break;
      }
    }
  }
  // Cleanup unused replacement instructions
  // WeakVH entries go null if a value is deleted by an earlier cleanup step.
  SmallVector<WeakVH, 8> cleanupInsts;
  for (auto it : valueMap) {
    if (isa<Instruction>(it.first))
      cleanupInsts.push_back(it.first);
    if (isa<Instruction>(it.second))
      cleanupInsts.push_back(it.second);
  }
  for (auto V : cleanupInsts) {
    if (!V)
      continue;
    if (PHINode *PHI = dyn_cast<PHINode>(V))
      RecursivelyDeleteDeadPHINode(PHI);
    else if (Instruction *I = dyn_cast<Instruction>(V))
      RecursivelyDeleteTriviallyDeadInstructions(I);
  }
  return bChanged;
}
// Module pass wrapper around CleanupSharedMemoryAddrSpaceCast: removes
// addrspacecasts rooted at groupshared globals.
class DxilCleanupAddrSpaceCast : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilCleanupAddrSpaceCast() : ModulePass(ID) {}
  const char *getPassName() const override { return "HLSL DXIL Cleanup Address Space Cast"; }
  bool runOnModule(Module &M) override {
    return CleanupSharedMemoryAddrSpaceCast(M);
  }
};

char DxilCleanupAddrSpaceCast::ID = 0;

// Factory used by the pass pipeline.
ModulePass *llvm::createDxilCleanupAddrSpaceCastPass() {
  return new DxilCleanupAddrSpaceCast();
}

INITIALIZE_PASS(DxilCleanupAddrSpaceCast, "hlsl-dxil-cleanup-addrspacecast", "HLSL DXIL Cleanup Address Space Cast", false, false)
  957. ///////////////////////////////////////////////////////////////////////////////
  958. namespace {
// Serializes the DxilModule state back into module-level metadata, after
// clearing any stale DXIL metadata and patching IsFrontFace signature element
// types to match the target validator version.
class DxilEmitMetadata : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilEmitMetadata() : ModulePass(ID) {}
  const char *getPassName() const override { return "HLSL DXIL Metadata Emit"; }
  bool runOnModule(Module &M) override {
    if (M.HasDxilModule()) {
      // Drop any previously-emitted metadata so emission starts clean.
      DxilModule::ClearDxilMetadata(M);
      patchIsFrontfaceTy(M);
      M.GetDxilModule().EmitDxilMetadata();
      return true;
    }
    return false;
  }
private:
  // Defined out-of-line below; adjusts IsFrontFace element component types.
  void patchIsFrontfaceTy(Module &M);
};
  976. void patchIsFrontface(DxilSignatureElement &Elt, bool bForceUint) {
  977. // If force to uint, change i1 to u32.
  978. // If not force to uint, change u32 to i1.
  979. if (bForceUint && Elt.GetCompType() == CompType::Kind::I1)
  980. Elt.SetCompType(CompType::Kind::U32);
  981. else if (!bForceUint && Elt.GetCompType() == CompType::Kind::U32)
  982. Elt.SetCompType(CompType::Kind::I1);
  983. }
  984. void patchIsFrontface(DxilSignature &sig, bool bForceUint) {
  985. for (auto &Elt : sig.GetElements()) {
  986. if (Elt->GetSemantic()->GetKind() == Semantic::Kind::IsFrontFace) {
  987. patchIsFrontface(*Elt, bForceUint);
  988. }
  989. }
  990. }
  991. void DxilEmitMetadata::patchIsFrontfaceTy(Module &M) {
  992. DxilModule &DM = M.GetDxilModule();
  993. const ShaderModel *pSM = DM.GetShaderModel();
  994. if (!pSM->IsGS() && !pSM->IsPS())
  995. return;
  996. unsigned ValMajor, ValMinor;
  997. DM.GetValidatorVersion(ValMajor, ValMinor);
  998. bool bForceUint = ValMajor == 0 || (ValMajor >= 1 && ValMinor >= 2);
  999. if (pSM->IsPS()) {
  1000. patchIsFrontface(DM.GetInputSignature(), bForceUint);
  1001. } else if (pSM->IsGS()) {
  1002. patchIsFrontface(DM.GetOutputSignature(), bForceUint);
  1003. }
  1004. }
  1005. }
char DxilEmitMetadata::ID = 0;

// Factory used by the pass pipeline.
ModulePass *llvm::createDxilEmitMetadataPass() {
  return new DxilEmitMetadata();
}

INITIALIZE_PASS(DxilEmitMetadata, "hlsl-dxilemit", "HLSL DXIL Metadata Emit", false, false)
  1011. ///////////////////////////////////////////////////////////////////////////////
  1012. namespace {
// Warning text emitted on gradient ops whose inputs are wave-sensitive.
const StringRef UniNoWaveSensitiveGradientErrMsg =
    "Gradient operations are not affected by wave-sensitive data or control "
    "flow.";

// Analysis-only pass: for PS and lib profiles, warns on gradient DXIL ops
// (derivatives/sample) whose operands are sensitive to wave intrinsics.
// Never modifies the module (always returns false).
class DxilValidateWaveSensitivity : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DxilValidateWaveSensitivity() : ModulePass(ID) {}
  const char *getPassName() const override {
    // NOTE(review): "sensitiveity" is a typo, but it is mirrored in the
    // INITIALIZE_PASS description below — fix both together if ever changed.
    return "HLSL DXIL wave sensitiveity validation";
  }
  bool runOnModule(Module &M) override {
    // Only check ps and lib profile.
    DxilModule &DM = M.GetDxilModule();
    const ShaderModel *pSM = DM.GetShaderModel();
    if (!pSM->IsPS() && !pSM->IsLib())
      return false;
    SmallVector<CallInst *, 16> gradientOps;
    SmallVector<CallInst *, 16> barriers;
    SmallVector<CallInst *, 16> waveOps;
    // DXIL ops are declarations; walk their call sites to bucket them.
    for (auto &F : M) {
      if (!F.isDeclaration())
        continue;
      for (User *U : F.users()) {
        CallInst *CI = dyn_cast<CallInst>(U);
        if (!CI)
          continue;
        Function *FCalled = CI->getCalledFunction();
        if (!FCalled || !FCalled->isDeclaration())
          continue;
        if (!hlsl::OP::IsDxilOpFunc(FCalled))
          continue;
        DXIL::OpCode dxilOpcode = hlsl::OP::GetDxilOpFuncCallInst(CI);
        if (OP::IsDxilOpWave(dxilOpcode)) {
          waveOps.emplace_back(CI);
        }
        if (OP::IsDxilOpGradient(dxilOpcode)) {
          gradientOps.push_back(CI);
        }
        if (dxilOpcode == DXIL::OpCode::Barrier) {
          barriers.push_back(CI);
        }
      }
    }
    // Skip if not have wave op.
    if (waveOps.empty())
      return false;
    // Skip if no gradient op.
    if (gradientOps.empty())
      return false;
    // Run the wave-sensitivity analysis per defined function and warn on any
    // gradient op in that function that consumes wave-sensitive values.
    for (auto &F : M) {
      if (F.isDeclaration())
        continue;
      SmallVector<CallInst *, 16> localGradientOps;
      for (CallInst *CI : gradientOps) {
        if (CI->getParent()->getParent() == &F)
          localGradientOps.emplace_back(CI);
      }
      if (localGradientOps.empty())
        continue;
      PostDominatorTree PDT;
      PDT.runOnFunction(F);
      std::unique_ptr<WaveSensitivityAnalysis> WaveVal(
          WaveSensitivityAnalysis::create(PDT));
      WaveVal->Analyze(&F);
      for (CallInst *op : localGradientOps) {
        if (WaveVal->IsWaveSensitive(op)) {
          dxilutil::EmitWarningOnInstruction(op,
                                             UniNoWaveSensitiveGradientErrMsg);
        }
      }
    }
    return false;
  }
};
  1087. }
char DxilValidateWaveSensitivity::ID = 0;

// Factory used by the pass pipeline.
ModulePass *llvm::createDxilValidateWaveSensitivityPass() {
  return new DxilValidateWaveSensitivity();
}

INITIALIZE_PASS(DxilValidateWaveSensitivity, "hlsl-validate-wave-sensitivity", "HLSL DXIL wave sensitiveity validation", false, false)
  1093. namespace {
  1094. // Cull blocks from BreakBBs that containing instructions that are sensitive to the wave-sensitive Inst
  1095. // Sensitivity entails being an eventual user of the Inst and also belonging to a block with
  1096. // a break conditional on dx.break that breaks out of a loop that contains WaveCI
  1097. // LInfo is needed to determine loop contents. Visited is needed to prevent infinite looping.
// Cull blocks from BreakBBs that containing instructions that are sensitive to the wave-sensitive Inst
// Sensitivity entails being an eventual user of the Inst and also belonging to a block with
// a break conditional on dx.break that breaks out of a loop that contains WaveCI
// LInfo is needed to determine loop contents. Visited is needed to prevent infinite looping.
static void CullSensitiveBlocks(LoopInfo *LInfo, Loop *WaveLoop, BasicBlock *LastBB, Instruction *Inst,
                                std::unordered_set<Instruction *> &Visited,
                                SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
  BasicBlock *BB = Inst->getParent();
  Loop *BreakLoop = LInfo->getLoopFor(BB);
  // If this instruction isn't in a loop, there is no need to track its sensitivity further
  // (also bail early when all candidate blocks have already been culled).
  if (!BreakLoop || BreakBBs.empty())
    return;
  // To prevent infinite looping, only visit each instruction once
  if (!Visited.insert(Inst).second)
    return;
  // If this BB wasn't already just processed, handle it now
  // (LastBB is the parent of the instruction we just recursed from, so
  // consecutive users in the same block skip the redundant check).
  if (LastBB != BB) {
    // Determine if the instruction's block has an artificially-conditional break
    // and breaks out of a loop that contains the waveCI
    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
    if (BI && BI->isConditional() && BreakLoop->contains(WaveLoop))
      BreakBBs.erase(BB);
  }
  // Recurse on the users
  for (User *U : Inst->users()) {
    Instruction *I = cast<Instruction>(U);
    CullSensitiveBlocks(LInfo, WaveLoop, BB, I, Visited, BreakBBs);
  }
}
  1123. // Collect blocks that end in a dx.break dependent branch by tracing the descendants of BreakFunc
  1124. // that are found in ThisFunc and store the block and call instruction in BreakBBs
  1125. static void CollectBreakBlocks(Function *BreakFunc, Function *ThisFunc,
  1126. SmallDenseMap<BasicBlock *, Instruction *, 16> &BreakBBs) {
  1127. for (User *U : BreakFunc->users()) {
  1128. SmallVector<User *, 16> WorkList;
  1129. Instruction *CI = cast<Instruction>(U);
  1130. // If this user doesn't pertain to the current function, skip it.
  1131. if (CI->getParent()->getParent() != ThisFunc)
  1132. continue;
  1133. WorkList.append(CI->user_begin(), CI->user_end());
  1134. while (!WorkList.empty()) {
  1135. Instruction *I = dyn_cast<Instruction>(WorkList.pop_back_val());
  1136. // When we find a Branch that depends on dx.break, save it and stop
  1137. // This should almost always be the first user of the Call Inst
  1138. // If not, iterate on the users
  1139. if (BranchInst *BI = dyn_cast<BranchInst>(I))
  1140. BreakBBs[BI->getParent()] = CI;
  1141. else
  1142. WorkList.append(I->user_begin(), I->user_end());
  1143. }
  1144. }
  1145. }
  1146. // A pass to remove conditions from breaks that do not contain instructions that
  1147. // depend on wave operations that are in the loop that the break leaves.
// A pass to remove conditions from breaks that do not contain instructions that
// depend on wave operations that are in the loop that the break leaves.
class CleanupDxBreak : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit CleanupDxBreak() : FunctionPass(ID) {}
  const char *getPassName() const override { return "HLSL Remove unnecessary dx.break conditions"; }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
  }
  // Loop info for the current function, set at the top of runOnFunction.
  LoopInfo *LInfo;
  bool runOnFunction(Function &F) override {
    if (F.isDeclaration())
      return false;
    Module *M = F.getEntryBlock().getModule();
    // Nothing to do if the module has no dx.break declaration.
    Function *BreakFunc = M->getFunction(DXIL::kDxBreakFuncName);
    if (!BreakFunc)
      return false;
    LInfo = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    // Collect the blocks that depend on dx.break and the instructions that call dx.break()
    SmallDenseMap<BasicBlock *, Instruction *, 16> BreakBBs;
    CollectBreakBlocks(BreakFunc, &F, BreakBBs);
    if (BreakBBs.empty())
      return false;
    // Collect all wave calls in this function and group by loop
    SmallDenseMap<Loop *, SmallVector<CallInst *, 8>, 16> WaveCalls;
    for (Function &IF : M->functions()) {
      HLOpcodeGroup opgroup = hlsl::GetHLOpcodeGroup(&IF);
      // Only consider wave-sensitive intrinsics or extintrinsics
      if (IF.isDeclaration() && IsHLWaveSensitive(&IF) && !BreakBBs.empty() &&
          (opgroup == HLOpcodeGroup::HLIntrinsic || opgroup == HLOpcodeGroup::HLExtIntrinsic)) {
        // For each user of the function, trace all its users to remove the blocks
        for (User *U : IF.users()) {
          CallInst *CI = cast<CallInst>(U);
          if (CI->getParent()->getParent() == &F) {
            Loop *WaveLoop = LInfo->getLoopFor(CI->getParent());
            WaveCalls[WaveLoop].emplace_back(CI);
          }
        }
      }
    }
    // For each wave operation, remove all the dx.break blocks that are sensitive to it
    for (DenseMap<Loop*, SmallVector<CallInst *, 8>>::iterator I =
             WaveCalls.begin(), E = WaveCalls.end();
         I != E; ++I) {
      Loop *loop = I->first;
      // Visited set is per-loop-group so each group traces independently.
      std::unordered_set<Instruction *> Visited;
      for (CallInst *CI : I->second) {
        CullSensitiveBlocks(LInfo, loop, nullptr, CI, Visited, BreakBBs);
      }
    }
    bool Changed = false;
    // Revert artificially conditional breaks in non-wave-sensitive blocks that remain in BreakBBs
    Constant *C = ConstantInt::get(Type::getInt1Ty(M->getContext()), 1);
    for (auto &BB : BreakBBs) {
      // Replace the call instruction with a constant boolen
      BB.second->replaceAllUsesWith(C);
      BB.second->eraseFromParent();
      Changed = true;
    }
    return Changed;
  }
};
  1209. }
char CleanupDxBreak::ID = 0;

INITIALIZE_PASS_BEGIN(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(CleanupDxBreak, "hlsl-cleanup-dxbreak", "HLSL Remove unnecessary dx.break conditions", false, false)

// Factory used by the pass pipeline.
FunctionPass *llvm::createCleanupDxBreakPass() {
  return new CleanupDxBreak();
}