DxilUtil.cpp 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilUtil.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Dxil helper functions. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/DXIL/DxilTypeSystem.h"
  12. #include "dxc/DXIL/DxilUtil.h"
  13. #include "dxc/DXIL/DxilModule.h"
  14. #include "dxc/DXIL/DxilOperations.h"
  15. #include "dxc/Support/Global.h"
  16. #include "llvm/ADT/StringExtras.h"
  17. #include "llvm/ADT/Twine.h"
  18. #include "llvm/Bitcode/ReaderWriter.h"
  19. #include "llvm/IR/DiagnosticInfo.h"
  20. #include "llvm/IR/DiagnosticPrinter.h"
  21. #include "llvm/IR/GlobalVariable.h"
  22. #include "llvm/IR/IntrinsicInst.h"
  23. #include "llvm/IR/LLVMContext.h"
  24. #include "llvm/IR/Module.h"
  25. #include "llvm/Support/MemoryBuffer.h"
  26. #include "llvm/Support/raw_ostream.h"
  27. #include "llvm/IR/Instructions.h"
  28. #include "llvm/IR/Constants.h"
  29. #include "llvm/IR/DIBuilder.h"
  30. #include "llvm/IR/IRBuilder.h"
  31. using namespace llvm;
  32. using namespace hlsl;
  33. namespace hlsl {
  34. namespace dxilutil {
  35. const char ManglingPrefix[] = "\01?";
  36. const char EntryPrefix[] = "dx.entry.";
  37. Type *GetArrayEltTy(Type *Ty) {
  38. if (isa<PointerType>(Ty))
  39. Ty = Ty->getPointerElementType();
  40. while (isa<ArrayType>(Ty)) {
  41. Ty = Ty->getArrayElementType();
  42. }
  43. return Ty;
  44. }
  45. bool HasDynamicIndexing(Value *V) {
  46. for (auto User : V->users()) {
  47. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  48. for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) {
  49. if (!isa<ConstantInt>(Idx))
  50. return true;
  51. }
  52. }
  53. }
  54. return false;
  55. }
  56. unsigned
  57. GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
  58. llvm::Type *Ty,
  59. DxilTypeSystem &typeSys) {
  60. while (isa<ArrayType>(Ty)) {
  61. Ty = Ty->getArrayElementType();
  62. }
  63. // Bytes.
  64. CompType compType = fieldAnnotation.GetCompType();
  65. unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4;
  66. unsigned fieldSize = compSize;
  67. if (Ty->isVectorTy()) {
  68. fieldSize *= Ty->getVectorNumElements();
  69. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  70. DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST);
  71. if (EltAnnotation) {
  72. fieldSize = EltAnnotation->GetCBufferSize();
  73. } else {
  74. // Calculate size when don't have annotation.
  75. if (fieldAnnotation.HasMatrixAnnotation()) {
  76. const DxilMatrixAnnotation &matAnnotation =
  77. fieldAnnotation.GetMatrixAnnotation();
  78. unsigned rows = matAnnotation.Rows;
  79. unsigned cols = matAnnotation.Cols;
  80. if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
  81. rows = cols;
  82. cols = matAnnotation.Rows;
  83. } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
  84. // Invalid matrix orientation.
  85. fieldSize = 0;
  86. }
  87. fieldSize = (rows - 1) * 16 + cols * 4;
  88. } else {
  89. // Cannot find struct annotation.
  90. fieldSize = 0;
  91. }
  92. }
  93. }
  94. return fieldSize;
  95. }
  96. bool IsStaticGlobal(GlobalVariable *GV) {
  97. return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage &&
  98. GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace;
  99. }
  100. bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  101. return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace;
  102. }
  103. bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
  104. Function *PatchConstantFunc, bool IsLib) {
  105. std::vector<Function *> deadList;
  106. for (auto &F : M.functions()) {
  107. if (&F == EntryFunc || &F == PatchConstantFunc)
  108. continue;
  109. if (F.isDeclaration() || !IsLib) {
  110. if (F.user_empty())
  111. deadList.emplace_back(&F);
  112. }
  113. }
  114. bool bUpdated = deadList.size();
  115. for (Function *F : deadList)
  116. F->eraseFromParent();
  117. return bUpdated;
  118. }
  119. void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  120. DiagnosticPrinter *printer = reinterpret_cast<DiagnosticPrinter *>(Context);
  121. DI.print(*printer);
  122. }
  123. StringRef DemangleFunctionName(StringRef name) {
  124. if (!name.startswith(ManglingPrefix)) {
  125. // Name isn't mangled.
  126. return name;
  127. }
  128. size_t nameEnd = name.find_first_of("@");
  129. DXASSERT(nameEnd != StringRef::npos, "else Name isn't mangled but has \01?");
  130. return name.substr(2, nameEnd - 2);
  131. }
  132. std::string ReplaceFunctionName(StringRef originalName, StringRef newName) {
  133. if (originalName.startswith(ManglingPrefix)) {
  134. return (Twine(ManglingPrefix) + newName +
  135. originalName.substr(originalName.find_first_of('@'))).str();
  136. } else if (originalName.startswith(EntryPrefix)) {
  137. return (Twine(EntryPrefix) + newName).str();
  138. }
  139. return newName.str();
  140. }
  141. // From AsmWriter.cpp
  142. // PrintEscapedString - Print each character of the specified string, escaping
  143. // it if it is not printable or if it is an escape char.
  144. void PrintEscapedString(StringRef Name, raw_ostream &Out) {
  145. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  146. unsigned char C = Name[i];
  147. if (isprint(C) && C != '\\' && C != '"')
  148. Out << C;
  149. else
  150. Out << '\\' << hexdigit(C >> 4) << hexdigit(C & 0x0F);
  151. }
  152. }
  153. void PrintUnescapedString(StringRef Name, raw_ostream &Out) {
  154. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  155. unsigned char C = Name[i];
  156. if (C == '\\') {
  157. C = Name[++i];
  158. unsigned value = hexDigitValue(C);
  159. if (value != -1U) {
  160. C = (unsigned char)value;
  161. unsigned value2 = hexDigitValue(Name[i+1]);
  162. assert(value2 != -1U && "otherwise, not a two digit hex escape");
  163. if (value2 != -1U) {
  164. C = (C << 4) + (unsigned char)value2;
  165. ++i;
  166. }
  167. } // else, the next character (in C) should be the escaped character
  168. }
  169. Out << C;
  170. }
  171. }
  172. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
  173. llvm::LLVMContext &Ctx,
  174. std::string &DiagStr) {
  175. // Note: the DiagStr is not used.
  176. auto pModule = llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx);
  177. if (!pModule) {
  178. return nullptr;
  179. }
  180. return std::unique_ptr<llvm::Module>(pModule.get().release());
  181. }
  182. std::unique_ptr<llvm::Module> LoadModuleFromBitcodeLazy(std::unique_ptr<llvm::MemoryBuffer> &&MB,
  183. llvm::LLVMContext &Ctx, std::string &DiagStr)
  184. {
  185. // Note: the DiagStr is not used.
  186. auto pModule = llvm::getLazyBitcodeModule(std::move(MB), Ctx, nullptr, true);
  187. if (!pModule) {
  188. return nullptr;
  189. }
  190. return std::unique_ptr<llvm::Module>(pModule.get().release());
  191. }
  192. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
  193. llvm::LLVMContext &Ctx,
  194. std::string &DiagStr) {
  195. std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
  196. llvm::MemoryBuffer::getMemBuffer(BC, "", false));
  197. return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
  198. }
  199. DIGlobalVariable *FindGlobalVariableDebugInfo(GlobalVariable *GV,
  200. DebugInfoFinder &DbgInfoFinder) {
  201. struct GlobalFinder {
  202. GlobalVariable *GV;
  203. bool operator()(llvm::DIGlobalVariable *const arg) const {
  204. return arg->getVariable() == GV;
  205. }
  206. };
  207. GlobalFinder F = {GV};
  208. DebugInfoFinder::global_variable_iterator Found =
  209. std::find_if(DbgInfoFinder.global_variables().begin(),
  210. DbgInfoFinder.global_variables().end(), F);
  211. if (Found != DbgInfoFinder.global_variables().end()) {
  212. return *Found;
  213. }
  214. return nullptr;
  215. }
  216. std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
  217. std::string locString;
  218. raw_string_ostream os(locString);
  219. DL.print(os);
  220. os << ": " << Msg;
  221. return os.str();
  222. }
  223. std::string FormatMessageInSubProgram(DISubprogram *DISP, const Twine& Msg) {
  224. std::string locString;
  225. raw_string_ostream os(locString);
  226. auto *Scope = cast<DIScope>(DISP->getScope());
  227. os << Scope->getFilename();
  228. os << ':' << DISP->getLine();
  229. os << ": " << Msg;
  230. return os.str();
  231. }
  232. std::string FormatMessageInVariable(DIVariable *DIV, const Twine& Msg) {
  233. std::string locString;
  234. raw_string_ostream os(locString);
  235. auto *Scope = cast<DIScope>(DIV->getScope());
  236. os << Scope->getFilename();
  237. os << ':' << DIV->getLine();
  238. os << ": " << Msg;
  239. return os.str();
  240. }
  241. Twine FormatMessageWithoutLocation(const Twine& Msg) {
  242. return Msg + " Use /Zi for source location.";
  243. }
  244. static void EmitWarningOrErrorOnInstruction(Instruction *I, Twine Msg,
  245. bool bWarning);
  246. // If we don't have debug location and this is select/phi,
  247. // try recursing users to find instruction with debug info.
  248. // Only recurse phi/select and limit depth to prevent doing
  249. // too much work if no debug location found.
  250. static bool EmitWarningOrErrorOnInstructionFollowPhiSelect(Instruction *I,
  251. Twine Msg,
  252. bool bWarning,
  253. unsigned depth = 0) {
  254. if (depth > 4)
  255. return false;
  256. if (I->getDebugLoc().get()) {
  257. EmitWarningOrErrorOnInstruction(I, Msg, bWarning);
  258. return true;
  259. }
  260. if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  261. for (auto U : I->users())
  262. if (Instruction *UI = dyn_cast<Instruction>(U))
  263. if (EmitWarningOrErrorOnInstructionFollowPhiSelect(UI, Msg, bWarning,
  264. depth + 1))
  265. return true;
  266. }
  267. return false;
  268. }
  269. static void EmitWarningOrErrorOnInstruction(Instruction *I, Twine Msg,
  270. bool bWarning) {
  271. const DebugLoc &DL = I->getDebugLoc();
  272. if (DL.get()) {
  273. if (bWarning)
  274. I->getContext().emitWarning(FormatMessageAtLocation(DL, Msg));
  275. else
  276. I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
  277. return;
  278. } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  279. if (EmitWarningOrErrorOnInstructionFollowPhiSelect(I, Msg, bWarning))
  280. return;
  281. }
  282. if (bWarning)
  283. I->getContext().emitWarning(FormatMessageWithoutLocation(Msg));
  284. else
  285. I->getContext().emitError(FormatMessageWithoutLocation(Msg));
  286. }
  287. void EmitErrorOnInstruction(Instruction *I, Twine Msg) {
  288. EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/false);
  289. }
  290. void EmitWarningOnInstruction(Instruction *I, Twine Msg) {
  291. EmitWarningOrErrorOnInstruction(I, Msg, /*bWarning*/true);
  292. }
  293. static void EmitWarningOrErrorOnFunction(Function *F, Twine Msg,
  294. bool bWarning) {
  295. DISubprogram *DISP = getDISubprogram(F);
  296. if (DISP) {
  297. if (bWarning)
  298. F->getContext().emitWarning(FormatMessageInSubProgram(DISP, Msg));
  299. else
  300. F->getContext().emitError(FormatMessageInSubProgram(DISP, Msg));
  301. return;
  302. }
  303. if (bWarning)
  304. F->getContext().emitWarning(FormatMessageWithoutLocation(Msg));
  305. else
  306. F->getContext().emitError(FormatMessageWithoutLocation(Msg));
  307. }
  308. void EmitErrorOnFunction(Function *F, Twine Msg) {
  309. EmitWarningOrErrorOnFunction(F, Msg, /*bWarning*/false);
  310. }
  311. void EmitWarningOnFunction(Function *F, Twine Msg) {
  312. EmitWarningOrErrorOnFunction(F, Msg, /*bWarning*/true);
  313. }
  314. static void EmitWarningOrErrorOnGlobalVariable(GlobalVariable *GV,
  315. Twine Msg, bool bWarning) {
  316. DIVariable *DIV = nullptr;
  317. if (GV)
  318. DIV = FindGlobalVariableDebugInfo(GV, GV->getParent()->GetDxilModule().GetOrCreateDebugInfoFinder());
  319. if (DIV) {
  320. if (bWarning)
  321. GV->getContext().emitWarning(FormatMessageInVariable(DIV, Msg));
  322. else
  323. GV->getContext().emitError(FormatMessageInVariable(DIV, Msg));
  324. return;
  325. }
  326. if (bWarning)
  327. GV->getContext().emitWarning(FormatMessageWithoutLocation(Msg));
  328. else
  329. GV->getContext().emitError(FormatMessageWithoutLocation(Msg));
  330. }
  331. void EmitErrorOnGlobalVariable(GlobalVariable *GV, Twine Msg) {
  332. EmitWarningOrErrorOnGlobalVariable(GV, Msg, /*bWarning*/false);
  333. }
  334. void EmitWarningOnGlobalVariable(GlobalVariable *GV, Twine Msg) {
  335. EmitWarningOrErrorOnGlobalVariable(GV, Msg, /*bWarning*/true);
  336. }
  337. const char *kResourceMapErrorMsg =
  338. "local resource not guaranteed to map to unique global resource.";
  339. void EmitResMappingError(Instruction *Res) {
  340. EmitErrorOnInstruction(Res, kResourceMapErrorMsg);
  341. }
  342. void CollectSelect(llvm::Instruction *Inst,
  343. std::unordered_set<llvm::Instruction *> &selectSet) {
  344. unsigned startOpIdx = 0;
  345. // Skip Cond for Select.
  346. if (isa<SelectInst>(Inst)) {
  347. startOpIdx = 1;
  348. } else if (!isa<PHINode>(Inst)) {
  349. // Only check phi and select here.
  350. return;
  351. }
  352. // Already add.
  353. if (selectSet.count(Inst))
  354. return;
  355. selectSet.insert(Inst);
  356. // Scan operand to add node which is phi/select.
  357. unsigned numOperands = Inst->getNumOperands();
  358. for (unsigned i = startOpIdx; i < numOperands; i++) {
  359. Value *V = Inst->getOperand(i);
  360. if (Instruction *I = dyn_cast<Instruction>(V)) {
  361. CollectSelect(I, selectSet);
  362. }
  363. }
  364. }
  365. Value *MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
  366. unsigned numOperands) {
  367. Value *op0 = nullptr;
  368. for (unsigned i = startOpIdx; i < numOperands; i++) {
  369. Value *op = SelInst->getOperand(i);
  370. if (i == startOpIdx) {
  371. op0 = op;
  372. } else {
  373. if (op0 != op)
  374. return nullptr;
  375. }
  376. }
  377. if (op0) {
  378. SelInst->replaceAllUsesWith(op0);
  379. SelInst->eraseFromParent();
  380. }
  381. return op0;
  382. }
  383. bool SimplifyTrivialPHIs(BasicBlock *BB) {
  384. bool Changed = false;
  385. SmallVector<Instruction *, 16> Removed;
  386. for (Instruction &I : *BB) {
  387. PHINode *PN = dyn_cast<PHINode>(&I);
  388. if (!PN)
  389. continue;
  390. if (PN->getNumIncomingValues() == 1) {
  391. Value *V = PN->getIncomingValue(0);
  392. PN->replaceAllUsesWith(V);
  393. Removed.push_back(PN);
  394. Changed = true;
  395. }
  396. }
  397. for (Instruction *I : Removed)
  398. I->eraseFromParent();
  399. return Changed;
  400. }
  401. static DbgValueInst *FindDbgValueInst(Value *Val) {
  402. if (auto *ValAsMD = LocalAsMetadata::getIfExists(Val)) {
  403. if (auto *ValMDAsVal = MetadataAsValue::getIfExists(Val->getContext(), ValAsMD)) {
  404. for (User *ValMDUser : ValMDAsVal->users()) {
  405. if (DbgValueInst *DbgValInst = dyn_cast<DbgValueInst>(ValMDUser))
  406. return DbgValInst;
  407. }
  408. }
  409. }
  410. return nullptr;
  411. }
  412. void MigrateDebugValue(Value *Old, Value *New) {
  413. DbgValueInst *DbgValInst = FindDbgValueInst(Old);
  414. if (DbgValInst == nullptr) return;
  415. DbgValInst->setOperand(0, MetadataAsValue::get(New->getContext(), ValueAsMetadata::get(New)));
  416. // Move the dbg value after the new instruction
  417. if (Instruction *NewInst = dyn_cast<Instruction>(New)) {
  418. if (NewInst->getNextNode() != DbgValInst) {
  419. DbgValInst->removeFromParent();
  420. DbgValInst->insertAfter(NewInst);
  421. }
  422. }
  423. }
  424. // Propagates any llvm.dbg.value instruction for a given vector
  425. // to the elements that were used to create it through a series
  426. // of insertelement instructions.
  427. //
  428. // This is used after lowering a vector-returning intrinsic.
  429. // If we just keep the debug info on the recomposed vector,
  430. // we will lose it when we break it apart again during later
  431. // optimization stages.
  432. void TryScatterDebugValueToVectorElements(Value *Val) {
  433. if (!isa<InsertElementInst>(Val) || !Val->getType()->isVectorTy()) return;
  434. DbgValueInst *VecDbgValInst = FindDbgValueInst(Val);
  435. if (VecDbgValInst == nullptr) return;
  436. Type *ElemTy = Val->getType()->getVectorElementType();
  437. DIBuilder DbgInfoBuilder(*VecDbgValInst->getModule());
  438. unsigned ElemSizeInBits = VecDbgValInst->getModule()->getDataLayout().getTypeSizeInBits(ElemTy);
  439. DIExpression *ParentBitPiece = VecDbgValInst->getExpression();
  440. if (ParentBitPiece != nullptr && !ParentBitPiece->isBitPiece())
  441. ParentBitPiece = nullptr;
  442. while (InsertElementInst *InsertElt = dyn_cast<InsertElementInst>(Val)) {
  443. Value *NewElt = InsertElt->getOperand(1);
  444. unsigned EltIdx = static_cast<unsigned>(cast<ConstantInt>(InsertElt->getOperand(2))->getLimitedValue());
  445. unsigned OffsetInBits = EltIdx * ElemSizeInBits;
  446. if (ParentBitPiece) {
  447. assert(OffsetInBits + ElemSizeInBits <= ParentBitPiece->getBitPieceSize()
  448. && "Nested bit piece expression exceeds bounds of its parent.");
  449. OffsetInBits += ParentBitPiece->getBitPieceOffset();
  450. }
  451. DIExpression *DIExpr = DbgInfoBuilder.createBitPieceExpression(OffsetInBits, ElemSizeInBits);
  452. // Offset is basically unused and deprecated in later LLVM versions.
  453. // Emit it as zero otherwise later versions of the bitcode reader will drop the intrinsic.
  454. DbgInfoBuilder.insertDbgValueIntrinsic(NewElt, /* Offset */ 0, VecDbgValInst->getVariable(),
  455. DIExpr, VecDbgValInst->getDebugLoc(), InsertElt);
  456. Val = InsertElt->getOperand(0);
  457. }
  458. }
  459. Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) {
  460. Instruction *prototype = Inst;
  461. for (unsigned i = 0; i < prototype->getNumOperands(); i++) {
  462. if (i == operandIdx)
  463. continue;
  464. if (!isa<Constant>(prototype->getOperand(i)))
  465. return nullptr;
  466. }
  467. Value *V = prototype->getOperand(operandIdx);
  468. if (SelectInst *SI = dyn_cast<SelectInst>(V)) {
  469. IRBuilder<> Builder(SI);
  470. Instruction *trueClone = Inst->clone();
  471. trueClone->setOperand(operandIdx, SI->getTrueValue());
  472. Builder.Insert(trueClone);
  473. Instruction *falseClone = Inst->clone();
  474. falseClone->setOperand(operandIdx, SI->getFalseValue());
  475. Builder.Insert(falseClone);
  476. Value *newSel =
  477. Builder.CreateSelect(SI->getCondition(), trueClone, falseClone);
  478. return newSel;
  479. }
  480. if (PHINode *Phi = dyn_cast<PHINode>(V)) {
  481. Type *Ty = Inst->getType();
  482. unsigned numOperands = Phi->getNumOperands();
  483. IRBuilder<> Builder(Phi);
  484. PHINode *newPhi = Builder.CreatePHI(Ty, numOperands);
  485. for (unsigned i = 0; i < numOperands; i++) {
  486. BasicBlock *b = Phi->getIncomingBlock(i);
  487. Value *V = Phi->getIncomingValue(i);
  488. Instruction *iClone = Inst->clone();
  489. IRBuilder<> iBuilder(b->getTerminator()->getPrevNode());
  490. iClone->setOperand(operandIdx, V);
  491. iBuilder.Insert(iClone);
  492. newPhi->addIncoming(iClone, b);
  493. }
  494. return newPhi;
  495. }
  496. return nullptr;
  497. }
  498. llvm::Instruction *SkipAllocas(llvm::Instruction *I) {
  499. // Step past any allocas:
  500. while (I && (isa<AllocaInst>(I) || isa<DbgInfoIntrinsic>(I)))
  501. I = I->getNextNode();
  502. return I;
  503. }
  504. llvm::Instruction *FindAllocaInsertionPt(llvm::BasicBlock* BB) {
  505. return &*BB->getFirstInsertionPt();
  506. }
  507. llvm::Instruction *FindAllocaInsertionPt(llvm::Function* F) {
  508. return FindAllocaInsertionPt(&F->getEntryBlock());
  509. }
  510. llvm::Instruction *FindAllocaInsertionPt(llvm::Instruction* I) {
  511. Function *F = I->getParent()->getParent();
  512. if (F)
  513. return FindAllocaInsertionPt(F);
  514. else // BB with no parent function
  515. return FindAllocaInsertionPt(I->getParent());
  516. }
  517. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Instruction* I) {
  518. return SkipAllocas(FindAllocaInsertionPt(I));
  519. }
  520. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::BasicBlock* BB) {
  521. return SkipAllocas(FindAllocaInsertionPt(BB));
  522. }
  523. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
  524. return SkipAllocas(FindAllocaInsertionPt(F));
  525. }
  526. static bool ConsumePrefix(StringRef &Str, StringRef Prefix) {
  527. if (!Str.startswith(Prefix)) return false;
  528. Str = Str.substr(Prefix.size());
  529. return true;
  530. }
  531. bool IsResourceSingleComponent(Type *Ty) {
  532. if (llvm::ArrayType *arrType = llvm::dyn_cast<llvm::ArrayType>(Ty)) {
  533. if (arrType->getArrayNumElements() > 1) {
  534. return false;
  535. }
  536. return IsResourceSingleComponent(arrType->getArrayElementType());
  537. } else if (llvm::StructType *structType =
  538. llvm::dyn_cast<llvm::StructType>(Ty)) {
  539. if (structType->getStructNumElements() > 1) {
  540. return false;
  541. }
  542. return IsResourceSingleComponent(structType->getStructElementType(0));
  543. } else if (llvm::VectorType *vectorType =
  544. llvm::dyn_cast<llvm::VectorType>(Ty)) {
  545. if (vectorType->getNumElements() > 1) {
  546. return false;
  547. }
  548. return IsResourceSingleComponent(vectorType->getVectorElementType());
  549. }
  550. return true;
  551. }
  552. bool IsHLSLResourceType(llvm::Type *Ty) {
  553. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  554. if (!ST->hasName())
  555. return false;
  556. StringRef name = ST->getName();
  557. ConsumePrefix(name, "class.");
  558. ConsumePrefix(name, "struct.");
  559. if (name == "SamplerState")
  560. return true;
  561. if (name == "SamplerComparisonState")
  562. return true;
  563. if (name.startswith("AppendStructuredBuffer<"))
  564. return true;
  565. if (name.startswith("ConsumeStructuredBuffer<"))
  566. return true;
  567. if (name.startswith("ConstantBuffer<"))
  568. return true;
  569. if (name == "RaytracingAccelerationStructure")
  570. return true;
  571. if (ConsumePrefix(name, "FeedbackTexture2D")) {
  572. ConsumePrefix(name, "Array");
  573. return name.startswith("<");
  574. }
  575. ConsumePrefix(name, "RasterizerOrdered");
  576. ConsumePrefix(name, "RW");
  577. if (name == "ByteAddressBuffer")
  578. return true;
  579. if (name.startswith("Buffer<"))
  580. return true;
  581. if (name.startswith("StructuredBuffer<"))
  582. return true;
  583. if (ConsumePrefix(name, "Texture")) {
  584. if (name.startswith("1D<"))
  585. return true;
  586. if (name.startswith("1DArray<"))
  587. return true;
  588. if (name.startswith("2D<"))
  589. return true;
  590. if (name.startswith("2DArray<"))
  591. return true;
  592. if (name.startswith("3D<"))
  593. return true;
  594. if (name.startswith("Cube<"))
  595. return true;
  596. if (name.startswith("CubeArray<"))
  597. return true;
  598. if (name.startswith("2DMS<"))
  599. return true;
  600. if (name.startswith("2DMSArray<"))
  601. return true;
  602. return false;
  603. }
  604. }
  605. return false;
  606. }
  607. bool IsHLSLObjectType(llvm::Type *Ty) {
  608. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  609. if (!ST->hasName()) {
  610. return false;
  611. }
  612. StringRef name = ST->getName();
  613. // TODO: don't check names.
  614. if (name.startswith("dx.types.wave_t"))
  615. return true;
  616. if (name.endswith("_slice_type"))
  617. return false;
  618. if (IsHLSLResourceType(Ty))
  619. return true;
  620. ConsumePrefix(name, "class.");
  621. ConsumePrefix(name, "struct.");
  622. if (name.startswith("TriangleStream<"))
  623. return true;
  624. if (name.startswith("PointStream<"))
  625. return true;
  626. if (name.startswith("LineStream<"))
  627. return true;
  628. }
  629. return false;
  630. }
  631. bool IsHLSLRayQueryType(llvm::Type *Ty) {
  632. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  633. if (!ST->hasName())
  634. return false;
  635. StringRef name = ST->getName();
  636. // TODO: don't check names.
  637. ConsumePrefix(name, "class.");
  638. if (name.startswith("RayQuery<"))
  639. return true;
  640. }
  641. return false;
  642. }
  643. bool IsHLSLResourceDescType(llvm::Type *Ty) {
  644. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  645. if (!ST->hasName())
  646. return false;
  647. StringRef name = ST->getName();
  648. // TODO: don't check names.
  649. if (name == ("struct..Resource"))
  650. return true;
  651. }
  652. return false;
  653. }
  654. bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
  655. return Ty->isIntegerTy() || Ty->isFloatingPointTy();
  656. }
  657. bool ContainsHLSLObjectType(llvm::Type *Ty) {
  658. // Unwrap pointer/array
  659. while (llvm::isa<llvm::PointerType>(Ty))
  660. Ty = llvm::cast<llvm::PointerType>(Ty)->getPointerElementType();
  661. while (llvm::isa<llvm::ArrayType>(Ty))
  662. Ty = llvm::cast<llvm::ArrayType>(Ty)->getArrayElementType();
  663. if (llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty)) {
  664. if (ST->hasName() && ST->getName().startswith("dx.types."))
  665. return true;
  666. // TODO: How is this suppoed to check for Input/OutputPatch types if
  667. // these have already been eliminated in function arguments during CG?
  668. if (IsHLSLObjectType(Ty))
  669. return true;
  670. // Otherwise, recurse elements of UDT
  671. for (auto ETy : ST->elements()) {
  672. if (ContainsHLSLObjectType(ETy))
  673. return true;
  674. }
  675. }
  676. return false;
  677. }
  678. // Based on the implementation available in LLVM's trunk:
  679. // http://llvm.org/doxygen/Constants_8cpp_source.html#l02734
  680. bool IsSplat(llvm::ConstantDataVector *cdv) {
  681. const char *Base = cdv->getRawDataValues().data();
  682. // Compare elements 1+ to the 0'th element.
  683. unsigned EltSize = cdv->getElementByteSize();
  684. for (unsigned i = 1, e = cdv->getNumElements(); i != e; ++i)
  685. if (memcmp(Base, Base + i * EltSize, EltSize))
  686. return false;
  687. return true;
  688. }
  689. llvm::Type* StripArrayTypes(llvm::Type *Ty, llvm::SmallVectorImpl<unsigned> *OuterToInnerLengths) {
  690. DXASSERT_NOMSG(Ty);
  691. while (Ty->isArrayTy()) {
  692. if (OuterToInnerLengths) {
  693. OuterToInnerLengths->push_back(Ty->getArrayNumElements());
  694. }
  695. Ty = Ty->getArrayElementType();
  696. }
  697. return Ty;
  698. }
  699. llvm::Type* WrapInArrayTypes(llvm::Type *Ty, llvm::ArrayRef<unsigned> OuterToInnerLengths) {
  700. DXASSERT_NOMSG(Ty);
  701. for (auto it = OuterToInnerLengths.rbegin(), E = OuterToInnerLengths.rend(); it != E; ++it) {
  702. Ty = ArrayType::get(Ty, *it);
  703. }
  704. return Ty;
  705. }
  706. namespace {
  707. // Create { v0, v1 } from { v0.lo, v0.hi, v1.lo, v1.hi }
  708. void Make64bitResultForLoad(Type *EltTy, ArrayRef<Value *> resultElts32,
  709. unsigned size, MutableArrayRef<Value *> resultElts,
  710. hlsl::OP *hlslOP, IRBuilder<> &Builder) {
  711. Type *i64Ty = Builder.getInt64Ty();
  712. Type *doubleTy = Builder.getDoubleTy();
  713. if (EltTy == doubleTy) {
  714. Function *makeDouble =
  715. hlslOP->GetOpFunc(DXIL::OpCode::MakeDouble, doubleTy);
  716. Value *makeDoubleOpArg =
  717. Builder.getInt32((unsigned)DXIL::OpCode::MakeDouble);
  718. for (unsigned i = 0; i < size; i++) {
  719. Value *lo = resultElts32[2 * i];
  720. Value *hi = resultElts32[2 * i + 1];
  721. Value *V = Builder.CreateCall(makeDouble, {makeDoubleOpArg, lo, hi});
  722. resultElts[i] = V;
  723. }
  724. } else {
  725. for (unsigned i = 0; i < size; i++) {
  726. Value *lo = resultElts32[2 * i];
  727. Value *hi = resultElts32[2 * i + 1];
  728. lo = Builder.CreateZExt(lo, i64Ty);
  729. hi = Builder.CreateZExt(hi, i64Ty);
  730. hi = Builder.CreateShl(hi, 32);
  731. resultElts[i] = Builder.CreateOr(lo, hi);
  732. }
  733. }
  734. }
  735. // Split { v0, v1 } to { v0.lo, v0.hi, v1.lo, v1.hi }
  736. void Split64bitValForStore(Type *EltTy, ArrayRef<Value *> vals, unsigned size,
  737. MutableArrayRef<Value *> vals32, hlsl::OP *hlslOP,
  738. IRBuilder<> &Builder) {
  739. Type *i32Ty = Builder.getInt32Ty();
  740. Type *doubleTy = Builder.getDoubleTy();
  741. Value *undefI32 = UndefValue::get(i32Ty);
  742. if (EltTy == doubleTy) {
  743. Function *dToU = hlslOP->GetOpFunc(DXIL::OpCode::SplitDouble, doubleTy);
  744. Value *dToUOpArg = Builder.getInt32((unsigned)DXIL::OpCode::SplitDouble);
  745. for (unsigned i = 0; i < size; i++) {
  746. if (isa<UndefValue>(vals[i])) {
  747. vals32[2 * i] = undefI32;
  748. vals32[2 * i + 1] = undefI32;
  749. } else {
  750. Value *retVal = Builder.CreateCall(dToU, {dToUOpArg, vals[i]});
  751. Value *lo = Builder.CreateExtractValue(retVal, 0);
  752. Value *hi = Builder.CreateExtractValue(retVal, 1);
  753. vals32[2 * i] = lo;
  754. vals32[2 * i + 1] = hi;
  755. }
  756. }
  757. } else {
  758. for (unsigned i = 0; i < size; i++) {
  759. if (isa<UndefValue>(vals[i])) {
  760. vals32[2 * i] = undefI32;
  761. vals32[2 * i + 1] = undefI32;
  762. } else {
  763. Value *lo = Builder.CreateTrunc(vals[i], i32Ty);
  764. Value *hi = Builder.CreateLShr(vals[i], 32);
  765. hi = Builder.CreateTrunc(hi, i32Ty);
  766. vals32[2 * i] = lo;
  767. vals32[2 * i + 1] = hi;
  768. }
  769. }
  770. }
  771. }
  772. }
  773. llvm::CallInst *TranslateCallRawBufferLoadToBufferLoad(
  774. llvm::CallInst *CI, llvm::Function *newFunction, hlsl::OP *op) {
  775. IRBuilder<> Builder(CI);
  776. SmallVector<Value *, 4> args;
  777. args.emplace_back(op->GetI32Const((unsigned)DXIL::OpCode::BufferLoad));
  778. for (unsigned i = 1; i < 4; ++i) {
  779. args.emplace_back(CI->getArgOperand(i));
  780. }
  781. CallInst *newCall = Builder.CreateCall(newFunction, args);
  782. return newCall;
  783. }
  784. void ReplaceRawBufferLoadWithBufferLoad(
  785. llvm::Function *F, hlsl::OP *op) {
  786. Type *RTy = F->getReturnType();
  787. if (StructType *STy = dyn_cast<StructType>(RTy)) {
  788. Type *ETy = STy->getElementType(0);
  789. Function *newFunction = op->GetOpFunc(hlsl::DXIL::OpCode::BufferLoad, ETy);
  790. for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
  791. User *user = *(U++);
  792. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  793. CallInst *newCall = TranslateCallRawBufferLoadToBufferLoad(CI, newFunction, op);
  794. CI->replaceAllUsesWith(newCall);
  795. CI->eraseFromParent();
  796. } else {
  797. DXASSERT(false, "function can only be used with call instructions.");
  798. }
  799. }
  800. } else {
  801. DXASSERT(false, "RawBufferLoad should return struct type.");
  802. }
  803. }
  804. llvm::CallInst *TranslateCallRawBufferStoreToBufferStore(
  805. llvm::CallInst *CI, llvm::Function *newFunction, hlsl::OP *op) {
  806. IRBuilder<> Builder(CI);
  807. SmallVector<Value *, 4> args;
  808. args.emplace_back(op->GetI32Const((unsigned)DXIL::OpCode::BufferStore));
  809. for (unsigned i = 1; i < 9; ++i) {
  810. args.emplace_back(CI->getArgOperand(i));
  811. }
  812. CallInst *newCall = Builder.CreateCall(newFunction, args);
  813. return newCall;
  814. }
  815. void ReplaceRawBufferStoreWithBufferStore(llvm::Function *F, hlsl::OP *op) {
  816. DXASSERT(F->getReturnType()->isVoidTy(), "rawBufferStore should return a void type.");
  817. Type *ETy = F->getFunctionType()->getParamType(4); // value
  818. Function *newFunction = op->GetOpFunc(hlsl::DXIL::OpCode::BufferStore, ETy);
  819. for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
  820. User *user = *(U++);
  821. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  822. TranslateCallRawBufferStoreToBufferStore(CI, newFunction, op);
  823. CI->eraseFromParent();
  824. }
  825. else {
  826. DXASSERT(false, "function can only be used with call instructions.");
  827. }
  828. }
  829. }
  830. void ReplaceRawBufferLoad64Bit(llvm::Function *F, llvm::Type *EltTy, hlsl::OP *hlslOP) {
  831. Function *bufLd = hlslOP->GetOpFunc(DXIL::OpCode::RawBufferLoad,
  832. Type::getInt32Ty(hlslOP->GetCtx()));
  833. for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
  834. User *user = *(U++);
  835. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  836. IRBuilder<> Builder(CI);
  837. SmallVector<Value *, 4> args(CI->arg_operands());
  838. Value *offset = CI->getArgOperand(
  839. DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx);
  840. unsigned size = 0;
  841. bool bNeedStatus = false;
  842. for (User *U : CI->users()) {
  843. ExtractValueInst *Elt = cast<ExtractValueInst>(U);
  844. DXASSERT(Elt->getNumIndices() == 1, "else invalid use for resRet");
  845. unsigned idx = Elt->getIndices()[0];
  846. if (idx == 4) {
  847. bNeedStatus = true;
  848. } else {
  849. size = std::max(size, idx+1);
  850. }
  851. }
  852. unsigned maskHi = 0;
  853. unsigned maskLo = 0;
  854. switch (size) {
  855. case 1:
  856. maskLo = 3;
  857. break;
  858. case 2:
  859. maskLo = 0xf;
  860. break;
  861. case 3:
  862. maskLo = 0xf;
  863. maskHi = 3;
  864. break;
  865. case 4:
  866. maskLo = 0xf;
  867. maskHi = 0xf;
  868. break;
  869. }
  870. args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] =
  871. Builder.getInt8(maskLo);
  872. Value *resultElts[5] = {nullptr, nullptr, nullptr, nullptr, nullptr};
  873. CallInst *newLd = Builder.CreateCall(bufLd, args);
  874. Value *resultElts32[8];
  875. unsigned eltBase = 0;
  876. for (unsigned i = 0; i < size; i++) {
  877. if (i == 2) {
  878. // Update offset 4 by 4 bytes.
  879. if (isa<UndefValue>(offset)) {
  880. // [RW]ByteAddressBuffer has undef element offset -> update index
  881. Value *index = CI->getArgOperand(DXIL::OperandIndex::kRawBufferLoadIndexOpIdx);
  882. args[DXIL::OperandIndex::kRawBufferLoadIndexOpIdx] =
  883. Builder.CreateAdd(index, Builder.getInt32(4 * 4));
  884. }
  885. else {
  886. // [RW]StructuredBuffer -> update element offset
  887. args[DXIL::OperandIndex::kRawBufferLoadElementOffsetOpIdx] =
  888. Builder.CreateAdd(offset, Builder.getInt32(4 * 4));
  889. }
  890. args[DXIL::OperandIndex::kRawBufferLoadMaskOpIdx] =
  891. Builder.getInt8(maskHi);
  892. newLd = Builder.CreateCall(bufLd, args);
  893. eltBase = 4;
  894. }
  895. unsigned resBase = 2 * i;
  896. resultElts32[resBase] =
  897. Builder.CreateExtractValue(newLd, resBase - eltBase);
  898. resultElts32[resBase + 1] =
  899. Builder.CreateExtractValue(newLd, resBase + 1 - eltBase);
  900. }
  901. Make64bitResultForLoad(EltTy, resultElts32, size, resultElts, hlslOP, Builder);
  902. if (bNeedStatus) {
  903. resultElts[4] = Builder.CreateExtractValue(newLd, 4);
  904. }
  905. for (auto it = CI->user_begin(); it != CI->user_end(); ) {
  906. ExtractValueInst *Elt = cast<ExtractValueInst>(*(it++));
  907. DXASSERT(Elt->getNumIndices() == 1, "else invalid use for resRet");
  908. unsigned idx = Elt->getIndices()[0];
  909. if (!Elt->user_empty()) {
  910. Value *newElt = resultElts[idx];
  911. Elt->replaceAllUsesWith(newElt);
  912. }
  913. Elt->eraseFromParent();
  914. }
  915. CI->eraseFromParent();
  916. } else {
  917. DXASSERT(false, "function can only be used with call instructions.");
  918. }
  919. }
  920. }
  921. void ReplaceRawBufferStore64Bit(llvm::Function *F, llvm::Type *ETy, hlsl::OP *hlslOP) {
  922. Function *newFunction = hlslOP->GetOpFunc(hlsl::DXIL::OpCode::RawBufferStore,
  923. Type::getInt32Ty(hlslOP->GetCtx()));
  924. for (auto U = F->user_begin(), E = F->user_end(); U != E;) {
  925. User *user = *(U++);
  926. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  927. IRBuilder<> Builder(CI);
  928. SmallVector<Value *, 4> args(CI->arg_operands());
  929. Value *vals[4] = {
  930. CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal0OpIdx),
  931. CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal1OpIdx),
  932. CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal2OpIdx),
  933. CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreVal3OpIdx)};
  934. ConstantInt *cMask = cast<ConstantInt>(
  935. CI->getArgOperand(DXIL::OperandIndex::kRawBufferStoreMaskOpIdx));
  936. Value *undefI32 = UndefValue::get(Builder.getInt32Ty());
  937. Value *vals32[8] = {undefI32, undefI32, undefI32, undefI32,
  938. undefI32, undefI32, undefI32, undefI32};
  939. unsigned maskLo = 0;
  940. unsigned maskHi = 0;
  941. unsigned size = 0;
  942. unsigned mask = cMask->getLimitedValue();
  943. switch (mask) {
  944. case 1:
  945. maskLo = 3;
  946. size = 1;
  947. break;
  948. case 3:
  949. maskLo = 15;
  950. size = 2;
  951. break;
  952. case 7:
  953. maskLo = 15;
  954. maskHi = 3;
  955. size = 3;
  956. break;
  957. case 15:
  958. maskLo = 15;
  959. maskHi = 15;
  960. size = 4;
  961. break;
  962. default:
  963. DXASSERT(0, "invalid mask");
  964. }
  965. Split64bitValForStore(ETy, vals, size, vals32, hlslOP, Builder);
  966. args[DXIL::OperandIndex::kRawBufferStoreMaskOpIdx] =
  967. Builder.getInt8(maskLo);
  968. args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx] = vals32[0];
  969. args[DXIL::OperandIndex::kRawBufferStoreVal1OpIdx] = vals32[1];
  970. args[DXIL::OperandIndex::kRawBufferStoreVal2OpIdx] = vals32[2];
  971. args[DXIL::OperandIndex::kRawBufferStoreVal3OpIdx] = vals32[3];
  972. Builder.CreateCall(newFunction, args);
  973. if (maskHi) {
  974. // Update offset 4 by 4 bytes.
  975. Value *offset = args[DXIL::OperandIndex::kBufferStoreCoord1OpIdx];
  976. if (isa<UndefValue>(offset)) {
  977. // [RW]ByteAddressBuffer has element offset == undef -> update index instead
  978. Value *index = args[DXIL::OperandIndex::kBufferStoreCoord0OpIdx];
  979. index = Builder.CreateAdd(index, Builder.getInt32(4 * 4));
  980. args[DXIL::OperandIndex::kRawBufferStoreIndexOpIdx] = index;
  981. }
  982. else {
  983. // [RW]StructuredBuffer -> update element offset
  984. offset = Builder.CreateAdd(offset, Builder.getInt32(4 * 4));
  985. args[DXIL::OperandIndex::kRawBufferStoreElementOffsetOpIdx] = offset;
  986. }
  987. args[DXIL::OperandIndex::kRawBufferStoreMaskOpIdx] =
  988. Builder.getInt8(maskHi);
  989. args[DXIL::OperandIndex::kRawBufferStoreVal0OpIdx] = vals32[4];
  990. args[DXIL::OperandIndex::kRawBufferStoreVal1OpIdx] = vals32[5];
  991. args[DXIL::OperandIndex::kRawBufferStoreVal2OpIdx] = vals32[6];
  992. args[DXIL::OperandIndex::kRawBufferStoreVal3OpIdx] = vals32[7];
  993. Builder.CreateCall(newFunction, args);
  994. }
  995. CI->eraseFromParent();
  996. } else {
  997. DXASSERT(false, "function can only be used with call instructions.");
  998. }
  999. }
  1000. }
  1001. }
  1002. }
  1003. ///////////////////////////////////////////////////////////////////////////////
  1004. namespace {
  1005. class DxilLoadMetadata : public ModulePass {
  1006. public:
  1007. static char ID; // Pass identification, replacement for typeid
  1008. explicit DxilLoadMetadata () : ModulePass(ID) {}
  1009. const char *getPassName() const override { return "HLSL load DxilModule from metadata"; }
  1010. bool runOnModule(Module &M) override {
  1011. if (!M.HasDxilModule()) {
  1012. (void)M.GetOrCreateDxilModule();
  1013. return true;
  1014. }
  1015. return false;
  1016. }
  1017. };
  1018. }
  1019. char DxilLoadMetadata::ID = 0;
  1020. ModulePass *llvm::createDxilLoadMetadataPass() {
  1021. return new DxilLoadMetadata();
  1022. }
  1023. INITIALIZE_PASS(DxilLoadMetadata, "hlsl-dxilload", "HLSL load DxilModule from metadata", false, false)