DxilNoops.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilNoops.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
// Passes to insert dx.noop/dx.preserve calls and lower them to plain IR //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. //
  12. // Here is how dx.preserve and dx.noop work.
  13. //
  14. // For example, the following HLSL code:
  15. //
  16. // float foo(float y) {
  17. // float x = 10;
  18. // x = 20;
  19. // x += y;
  20. // return x;
  21. // }
  22. //
  23. // float main() : SV_Target {
  24. // float ret = foo(10);
  25. // return ret;
  26. // }
  27. //
  28. // Ordinarily, it gets lowered as:
  29. //
  30. // dx.op.storeOutput(3.0)
  31. //
  32. // Intermediate steps at "x = 20;", "x += y;", "return x", and
  33. // even the call to "foo()" are lost.
  34. //
// But with Preserve and Noop:
  36. //
  37. // void call dx.noop() // float ret = foo(10);
  38. // %y = dx.preserve(10.0, 10.0) // argument: y=10
  39. // %x0 = dx.preserve(10.0, 10.0) // float x = 10;
  40. // %x1 = dx.preserve(20.0, %x0) // x = 20;
  41. // %x2 = fadd %x1, %y // x += y;
  42. // void call dx.noop() // return x
  43. // %ret = dx.preserve(%x2, %x2) // ret = returned from foo()
  44. // dx.op.storeOutput(%ret)
  45. //
  46. // All the intermediate transformations are visible and could be
  47. // made inspectable in the debugger.
  48. //
  49. // The reason why dx.preserve takes 2 arguments is so that the previous
  50. // value of a variable does not get cleaned up by DCE. For example:
  51. //
  52. // float x = ...;
  53. // do_some_stuff_with(x);
  54. // do_some_other_stuff(); // At this point, x's last values
  55. // // are dead and register allocators
  56. // // are free to reuse its location during
// // this call.
  58. // // So until x is assigned a new value below
  59. // // x could become unavailable.
  60. // //
  61. // // The second parameter in dx.preserve
  62. // // keeps x's previous value alive.
  63. //
  64. // x = ...; // Assign something else
  65. //
  66. //
  67. // When emitting proper DXIL, dx.noop and dx.preserve are lowered to
// ordinary LLVM instructions that do not affect the semantics of the
  69. // shader, but can be used by a debugger or backend generator if they
  70. // know what to look for.
  71. //
// We generate two special internal constant global vars, each a
// [1 x i32] array holding 0:
//
//    @dx.preserve.value.a
//    @dx.nothing.a
//
// "call dx.noop()" is lowered to a load of element 0 of @dx.nothing.a.
//
// "... = call dx.preserve(%cur_val, %last_val)" is lowered to:
//
//    %v = load i32 (element 0 of @dx.preserve.value.a)
//    %p = trunc i32 %v to i1
//    ... = select i1 %p, %last_val, %cur_val
//
// (%v and %p are created once per function and shared by its selects.)
// Since %p is guaranteed to be false, the select is guaranteed
// to return %cur_val.
  86. //
  87. #include "llvm/Pass.h"
  88. #include "llvm/IR/Module.h"
  89. #include "llvm/IR/Instructions.h"
  90. #include "llvm/IR/IntrinsicInst.h"
  91. #include "llvm/IR/Intrinsics.h"
  92. #include "llvm/IR/IRBuilder.h"
  93. #include "llvm/Transforms/Scalar.h"
  94. #include "llvm/Support/raw_os_ostream.h"
  95. #include "dxc/DXIL/DxilMetadataHelper.h"
  96. #include "dxc/DXIL/DxilConstants.h"
  97. #include <unordered_set>
  98. using namespace llvm;
  99. namespace {
  100. StringRef kNoopName = "dx.noop";
  101. StringRef kPreservePrefix = "dx.preserve.";
  102. StringRef kNothingName = "dx.nothing.a";
  103. StringRef kPreserveName = "dx.preserve.value.a";
  104. }
  105. static Function *GetOrCreateNoopF(Module &M) {
  106. LLVMContext &Ctx = M.getContext();
  107. FunctionType *FT = FunctionType::get(Type::getVoidTy(Ctx), false);
  108. Function *NoopF = cast<Function>(M.getOrInsertFunction(::kNoopName, FT));
  109. NoopF->addFnAttr(Attribute::AttrKind::Convergent);
  110. return NoopF;
  111. }
  112. static Constant *GetConstGep(Constant *Ptr, unsigned Idx0, unsigned Idx1) {
  113. Type *i32Ty = Type::getInt32Ty(Ptr->getContext());
  114. Constant *Indices[] = { ConstantInt::get(i32Ty, Idx0), ConstantInt::get(i32Ty, Idx1) };
  115. return ConstantExpr::getGetElementPtr(nullptr, Ptr, Indices);
  116. }
  117. static bool ShouldPreserve(Value *V) {
  118. if (isa<Constant>(V)) return true;
  119. if (isa<Argument>(V)) return true;
  120. if (isa<LoadInst>(V)) return true;
  121. if (ExtractElementInst *GEP = dyn_cast<ExtractElementInst>(V)) {
  122. return ShouldPreserve(GEP->getVectorOperand());
  123. }
  124. if (isa<CallInst>(V)) return true;
  125. return false;
  126. }
  127. struct Store_Info {
  128. Instruction *StoreOrMC = nullptr;
  129. Value *Source = nullptr; // Alloca, GV, or Argument
  130. bool AllowLoads = false;
  131. };
  132. static void FindAllStores(Value *Ptr, std::vector<Store_Info> *Stores, std::vector<Value *> &WorklistStorage, std::unordered_set<Value *> &SeenStorage) {
  133. assert(isa<Argument>(Ptr) || isa<AllocaInst>(Ptr) || isa<GlobalVariable>(Ptr));
  134. WorklistStorage.clear();
  135. WorklistStorage.push_back(Ptr);
  136. // Don't clear Seen Storage because two pointers can be involved with the same
  137. // memcpy. Clearing it can get the memcpy added twice.
  138. unsigned StartIdx = Stores->size();
  139. bool AllowLoad = false;
  140. while (WorklistStorage.size()) {
  141. Value *V = WorklistStorage.back();
  142. WorklistStorage.pop_back();
  143. SeenStorage.insert(V);
  144. if (isa<BitCastOperator>(V) || isa<GEPOperator>(V) || isa<GlobalVariable>(V) || isa<AllocaInst>(V) || isa<Argument>(V)) {
  145. for (User *U : V->users()) {
  146. // Allow load if MC reads from pointer
  147. if (MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
  148. AllowLoad |= MC->getSource() == V;
  149. }
  150. else if (isa<LoadInst>(U)) {
  151. AllowLoad = true;
  152. }
  153. // Add to worklist if we haven't seen it before.
  154. else {
  155. if (!SeenStorage.count(U))
  156. WorklistStorage.push_back(U);
  157. }
  158. }
  159. }
  160. else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
  161. if (ShouldPreserve(Store->getValueOperand())) {
  162. Store_Info Info;
  163. Info.StoreOrMC = Store;
  164. Info.Source = Ptr;
  165. Stores->push_back(Info);
  166. }
  167. }
  168. else if (MemCpyInst *MC = dyn_cast<MemCpyInst>(V)) {
  169. Store_Info Info;
  170. Info.StoreOrMC = MC;
  171. Info.Source = Ptr;
  172. Stores->push_back(Info);
  173. }
  174. }
  175. if (isa<GlobalVariable>(Ptr)) {
  176. AllowLoad = true;
  177. }
  178. if (AllowLoad) {
  179. Store_Info *ptr = Stores->data();
  180. for (unsigned i = StartIdx; i < Stores->size(); i++)
  181. ptr[i].AllowLoads = true;
  182. }
  183. }
  184. static User *GetUniqueUser(Value *V) {
  185. if (V->user_begin() != V->user_end()) {
  186. if (std::next(V->user_begin()) == V->user_end())
  187. return *V->user_begin();
  188. }
  189. return nullptr;
  190. }
  191. static Value *GetOrCreatePreserveCond(Function *F) {
  192. assert(!F->isDeclaration());
  193. Module *M = F->getParent();
  194. GlobalVariable *GV = M->getGlobalVariable(kPreserveName, true);
  195. if (!GV) {
  196. Type *i32Ty = Type::getInt32Ty(M->getContext());
  197. Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
  198. unsigned int Values[1] = { 0 };
  199. Constant *InitialValue = llvm::ConstantDataArray::get(M->getContext(), Values);
  200. GV = new GlobalVariable(*M,
  201. i32ArrayTy, true,
  202. llvm::GlobalValue::InternalLinkage,
  203. InitialValue, kPreserveName);
  204. }
  205. for (User *U : GV->users()) {
  206. GEPOperator *Gep = Gep = cast<GEPOperator>(U);
  207. for (User *GepU : Gep->users()) {
  208. LoadInst *LI = cast<LoadInst>(GepU);
  209. if (LI->getParent()->getParent() == F) {
  210. return GetUniqueUser(LI);
  211. }
  212. }
  213. }
  214. BasicBlock *BB = &F->getEntryBlock();
  215. Instruction *InsertPt = &BB->front();
  216. while (isa<AllocaInst>(InsertPt) || isa<DbgInfoIntrinsic>(InsertPt))
  217. InsertPt = InsertPt->getNextNode();
  218. IRBuilder<> B(InsertPt);
  219. Constant *Gep = GetConstGep(GV, 0, 0);
  220. LoadInst *Load = B.CreateLoad(Gep);
  221. return B.CreateTrunc(Load, B.getInt1Ty());
  222. }
  223. static Function *GetOrCreatePreserveF(Module *M, Type *Ty) {
  224. std::string str = kPreservePrefix;
  225. raw_string_ostream os(str);
  226. Ty->print(os);
  227. os.flush();
  228. FunctionType *FT = FunctionType::get(Ty, { Ty, Ty }, false);
  229. Function *PreserveF = cast<Function>(M->getOrInsertFunction(str, FT));
  230. PreserveF->addFnAttr(Attribute::AttrKind::ReadNone);
  231. PreserveF->addFnAttr(Attribute::AttrKind::NoUnwind);
  232. return PreserveF;
  233. }
  234. static Instruction *CreatePreserve(Value *V, Value *LastV, Instruction *InsertPt) {
  235. assert(V->getType() == LastV->getType());
  236. Type *Ty = V->getType();
  237. Function *PreserveF = GetOrCreatePreserveF(InsertPt->getModule(), Ty);
  238. return CallInst::Create(PreserveF, ArrayRef<Value *> { V, LastV }, "", InsertPt);
  239. }
  240. static void LowerPreserveToSelect(CallInst *CI) {
  241. Value *V = CI->getArgOperand(0);
  242. Value *LastV = CI->getArgOperand(1);
  243. if (LastV == V)
  244. LastV = UndefValue::get(V->getType());
  245. Value *Cond = GetOrCreatePreserveCond(CI->getParent()->getParent());
  246. SelectInst *Select = SelectInst::Create(Cond, LastV, V, "", CI);
  247. Select->setDebugLoc(CI->getDebugLoc());
  248. CI->replaceAllUsesWith(Select);
  249. CI->eraseFromParent();
  250. }
  251. static void InsertNoopAt(Instruction *I) {
  252. Module &M = *I->getModule();
  253. Function *NoopF = GetOrCreateNoopF(M);
  254. CallInst *Noop = CallInst::Create(NoopF, {}, I);
  255. Noop->setDebugLoc(I->getDebugLoc());
  256. }
  257. //==========================================================
  258. // Insertion pass
  259. //
  260. // This pass inserts dx.noop and dx.preserve where we want
  261. // to preserve line mapping or perserve some intermediate
  262. // values.
  263. struct DxilInsertPreserves : public ModulePass {
  264. static char ID;
  265. DxilInsertPreserves() : ModulePass(ID) {
  266. initializeDxilInsertPreservesPass(*PassRegistry::getPassRegistry());
  267. }
  268. bool runOnModule(Module &M) override {
  269. std::vector<Store_Info> Stores;
  270. std::vector<Value *> WorklistStorage;
  271. std::unordered_set<Value *> SeenStorage;
  272. for (GlobalVariable &GV : M.globals()) {
  273. if (GV.getLinkage() != GlobalValue::LinkageTypes::InternalLinkage ||
  274. GV.getType()->getPointerAddressSpace() == hlsl::DXIL::kTGSMAddrSpace)
  275. {
  276. continue;
  277. }
  278. for (User *U : GV.users()) {
  279. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  280. InsertNoopAt(LI);
  281. }
  282. }
  283. FindAllStores(&GV, &Stores, WorklistStorage, SeenStorage);
  284. }
  285. bool Changed = false;
  286. for (Function &F : M) {
  287. if (F.isDeclaration())
  288. continue;
  289. // Collect Stores on Allocas in function
  290. BasicBlock *Entry = &*F.begin();
  291. for (Instruction &I : *Entry) {
  292. AllocaInst *AI = dyn_cast<AllocaInst>(&I);
  293. if (!AI)
  294. continue;
  295. // Skip temp allocas
  296. if (!AI->getMetadata(hlsl::DxilMDHelper::kDxilTempAllocaMDName))
  297. FindAllStores(AI, &Stores, WorklistStorage, SeenStorage);
  298. }
  299. // Collect Stores on pointer Arguments in function
  300. for (Argument &Arg : F.args()) {
  301. if (Arg.getType()->isPointerTy())
  302. FindAllStores(&Arg, &Stores, WorklistStorage, SeenStorage);
  303. }
  304. // For every real function call, insert a nop
  305. // so we can put a breakpoint there.
  306. for (User *U : F.users()) {
  307. if (CallInst *CI = dyn_cast<CallInst>(U)) {
  308. InsertNoopAt(CI);
  309. }
  310. }
  311. // Insert nops for void return statements
  312. for (BasicBlock &BB : F) {
  313. ReturnInst *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
  314. if (Ret)
  315. InsertNoopAt(Ret);
  316. }
  317. }
  318. // Insert preserves or noops for these stores
  319. for (Store_Info &Info : Stores) {
  320. if (StoreInst *Store = dyn_cast<StoreInst>(Info.StoreOrMC)) {
  321. Value *V = Store->getValueOperand();
  322. if (V &&
  323. !V->getType()->isAggregateType() &&
  324. !V->getType()->isPointerTy())
  325. {
  326. IRBuilder<> B(Store);
  327. Value *Last_Value = nullptr;
  328. // If there's never any loads for this memory location,
  329. // don't generate a load.
  330. if (Info.AllowLoads) {
  331. Last_Value = B.CreateLoad(Store->getPointerOperand());
  332. }
  333. else {
  334. Last_Value = UndefValue::get(V->getType());
  335. }
  336. Instruction *Preserve = CreatePreserve(V, Last_Value, Store);
  337. Preserve->setDebugLoc(Store->getDebugLoc());
  338. Store->replaceUsesOfWith(V, Preserve);
  339. Changed = true;
  340. }
  341. else {
  342. InsertNoopAt(Store);
  343. }
  344. }
  345. else if (MemCpyInst *MC = cast<MemCpyInst>(Info.StoreOrMC)) {
  346. // TODO: Do something to preserve pointer's previous value.
  347. InsertNoopAt(MC);
  348. }
  349. }
  350. return Changed;
  351. }
  352. const char *getPassName() const override { return "Dxil Insert Preserves"; }
  353. };
  354. char DxilInsertPreserves::ID;
  355. Pass *llvm::createDxilInsertPreservesPass() {
  356. return new DxilInsertPreserves();
  357. }
  358. INITIALIZE_PASS(DxilInsertPreserves, "dxil-insert-preserves", "Dxil Insert Preserves", false, false)
  359. //==========================================================
  360. // Lower dx.preserve to select
  361. //
  362. // This pass replaces all dx.preserve calls to select
  363. //
  364. namespace {
  365. class DxilPreserveToSelect : public ModulePass {
  366. public:
  367. static char ID;
  368. SmallDenseMap<Type *, Function *> PreserveFunctions;
  369. DxilPreserveToSelect() : ModulePass(ID) {
  370. initializeDxilPreserveToSelectPass(*PassRegistry::getPassRegistry());
  371. }
  372. bool runOnModule(Module &M) override {
  373. bool Changed = false;
  374. for (auto fit = M.getFunctionList().begin(), end = M.getFunctionList().end();
  375. fit != end;)
  376. {
  377. Function *F = &*(fit++);
  378. if (!F->isDeclaration())
  379. continue;
  380. if (F->getName().startswith(kPreservePrefix)) {
  381. for (auto uit = F->user_begin(), end = F->user_end(); uit != end;) {
  382. User *U = *(uit++);
  383. CallInst *CI = cast<CallInst>(U);
  384. LowerPreserveToSelect(CI);
  385. }
  386. F->eraseFromParent();
  387. Changed = true;
  388. }
  389. }
  390. return Changed;
  391. }
  392. const char *getPassName() const override { return "Dxil Lower Preserves to Selects"; }
  393. };
  394. char DxilPreserveToSelect::ID;
  395. }
  396. Pass *llvm::createDxilPreserveToSelectPass() {
  397. return new DxilPreserveToSelect();
  398. }
  399. INITIALIZE_PASS(DxilPreserveToSelect, "dxil-insert-noops", "Dxil Insert Noops", false, false)
  400. //==========================================================
  401. // Finalize pass
  402. //
  403. namespace {
  404. class DxilFinalizePreserves : public ModulePass {
  405. public:
  406. static char ID;
  407. GlobalVariable *NothingGV = nullptr;
  408. DxilFinalizePreserves() : ModulePass(ID) {
  409. initializeDxilFinalizePreservesPass(*PassRegistry::getPassRegistry());
  410. }
  411. Instruction *GetFinalNoopInst(Module &M, Instruction *InsertBefore) {
  412. Type *i32Ty = Type::getInt32Ty(M.getContext());
  413. if (!NothingGV) {
  414. NothingGV = M.getGlobalVariable(kNothingName);
  415. if (!NothingGV) {
  416. Type *i32ArrayTy = ArrayType::get(i32Ty, 1);
  417. unsigned int Values[1] = { 0 };
  418. Constant *InitialValue = llvm::ConstantDataArray::get(M.getContext(), Values);
  419. NothingGV = new GlobalVariable(M,
  420. i32ArrayTy, true,
  421. llvm::GlobalValue::InternalLinkage,
  422. InitialValue, kNothingName);
  423. }
  424. }
  425. Constant *Gep = GetConstGep(NothingGV, 0, 0);
  426. return new llvm::LoadInst(Gep, nullptr, InsertBefore);
  427. }
  428. bool LowerPreserves(Module &M);
  429. bool LowerNoops(Module &M);
  430. bool runOnModule(Module &M) override;
  431. const char *getPassName() const override { return "Dxil Finalize Preserves"; }
  432. };
  433. char DxilFinalizePreserves::ID;
  434. }
  435. // Fix undefs in the dx.preserve -> selects
  436. bool DxilFinalizePreserves::LowerPreserves(Module &M) {
  437. bool Changed = false;
  438. GlobalVariable *GV = M.getGlobalVariable(kPreserveName, true);
  439. if (GV) {
  440. for (User *U : GV->users()) {
  441. GEPOperator *Gep = cast<GEPOperator>(U);
  442. for (User *GepU : Gep->users()) {
  443. LoadInst *LI = cast<LoadInst>(GepU);
  444. assert(LI->user_begin() != LI->user_end() &&
  445. std::next(LI->user_begin()) == LI->user_end());
  446. Instruction *I = cast<Instruction>(*LI->user_begin());
  447. for (User *UU : I->users()) {
  448. SelectInst *P = cast<SelectInst>(UU);
  449. Value *PrevV = P->getTrueValue();
  450. Value *CurV = P->getFalseValue();
  451. if (isa<UndefValue>(PrevV) || isa<Constant>(PrevV)) {
  452. P->setOperand(1, CurV);
  453. Changed = true;
  454. }
  455. }
  456. }
  457. }
  458. }
  459. return Changed;
  460. }
  461. // Replace all @dx.noop's with load @dx.nothing.value
  462. bool DxilFinalizePreserves::LowerNoops(Module &M) {
  463. bool Changed = false;
  464. Function *NoopF = nullptr;
  465. for (Function &F : M) {
  466. if (!F.isDeclaration())
  467. continue;
  468. if (F.getName() == kNoopName) {
  469. NoopF = &F;
  470. }
  471. }
  472. if (NoopF) {
  473. for (auto It = NoopF->user_begin(), E = NoopF->user_end(); It != E;) {
  474. User *U = *(It++);
  475. CallInst *CI = cast<CallInst>(U);
  476. Instruction *Nop = GetFinalNoopInst(M, CI);
  477. Nop->setDebugLoc(CI->getDebugLoc());
  478. CI->eraseFromParent();
  479. Changed = true;
  480. }
  481. assert(NoopF->user_empty() && "dx.noop calls must be all removed now");
  482. NoopF->eraseFromParent();
  483. }
  484. return Changed;
  485. }
  486. // Replace all preserves and nops
  487. bool DxilFinalizePreserves::runOnModule(Module &M) {
  488. bool Changed = false;
  489. Changed |= LowerPreserves(M);
  490. Changed |= LowerNoops(M);
  491. return Changed;
  492. }
  493. Pass *llvm::createDxilFinalizePreservesPass() {
  494. return new DxilFinalizePreserves();
  495. }
  496. INITIALIZE_PASS(DxilFinalizePreserves, "dxil-finalize-preserves", "Dxil Finalize Preserves", false, false)