DxilUtil.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilUtil.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Dxil helper functions. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/GlobalVariable.h"
  12. #include "dxc/DXIL/DxilTypeSystem.h"
  13. #include "dxc/DXIL/DxilUtil.h"
  14. #include "dxc/DXIL/DxilModule.h"
  15. #include "llvm/Bitcode/ReaderWriter.h"
  16. #include "llvm/IR/DiagnosticInfo.h"
  17. #include "llvm/IR/DiagnosticPrinter.h"
  18. #include "llvm/IR/LLVMContext.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/MemoryBuffer.h"
  21. #include "llvm/Support/raw_ostream.h"
  22. #include "llvm/IR/Instructions.h"
  23. #include "llvm/IR/Constants.h"
  24. #include "llvm/IR/IRBuilder.h"
  25. #include "dxc/Support/Global.h"
  26. #include "llvm/ADT/StringExtras.h"
  27. #include "llvm/ADT/Twine.h"
  28. using namespace llvm;
  29. using namespace hlsl;
  30. namespace hlsl {
  31. namespace dxilutil {
  // Prefix of entry-point function names; ReplaceFunctionName keeps it when
  // renaming such a function.
  const char EntryPrefix[] = "dx.entry.";
  // Two-character prefix ('\01' followed by '?') that marks a mangled function
  // name; DemangleFunctionName and ReplaceFunctionName test for it.
  const char ManglingPrefix[] = "\01?";
  34. Type *GetArrayEltTy(Type *Ty) {
  35. if (isa<PointerType>(Ty))
  36. Ty = Ty->getPointerElementType();
  37. while (isa<ArrayType>(Ty)) {
  38. Ty = Ty->getArrayElementType();
  39. }
  40. return Ty;
  41. }
  42. bool HasDynamicIndexing(Value *V) {
  43. for (auto User : V->users()) {
  44. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  45. for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) {
  46. if (!isa<ConstantInt>(Idx))
  47. return true;
  48. }
  49. }
  50. }
  51. return false;
  52. }
  53. unsigned
  54. GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
  55. llvm::Type *Ty,
  56. DxilTypeSystem &typeSys) {
  57. while (isa<ArrayType>(Ty)) {
  58. Ty = Ty->getArrayElementType();
  59. }
  60. // Bytes.
  61. CompType compType = fieldAnnotation.GetCompType();
  62. unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4;
  63. unsigned fieldSize = compSize;
  64. if (Ty->isVectorTy()) {
  65. fieldSize *= Ty->getVectorNumElements();
  66. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  67. DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST);
  68. if (EltAnnotation) {
  69. fieldSize = EltAnnotation->GetCBufferSize();
  70. } else {
  71. // Calculate size when don't have annotation.
  72. if (fieldAnnotation.HasMatrixAnnotation()) {
  73. const DxilMatrixAnnotation &matAnnotation =
  74. fieldAnnotation.GetMatrixAnnotation();
  75. unsigned rows = matAnnotation.Rows;
  76. unsigned cols = matAnnotation.Cols;
  77. if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
  78. rows = cols;
  79. cols = matAnnotation.Rows;
  80. } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
  81. // Invalid matrix orientation.
  82. fieldSize = 0;
  83. }
  84. fieldSize = (rows - 1) * 16 + cols * 4;
  85. } else {
  86. // Cannot find struct annotation.
  87. fieldSize = 0;
  88. }
  89. }
  90. }
  91. return fieldSize;
  92. }
  93. bool IsStaticGlobal(GlobalVariable *GV) {
  94. return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage &&
  95. GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace;
  96. }
  97. bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  98. return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace;
  99. }
  100. bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
  101. Function *PatchConstantFunc, bool IsLib) {
  102. std::vector<Function *> deadList;
  103. for (auto &F : M.functions()) {
  104. if (&F == EntryFunc || &F == PatchConstantFunc)
  105. continue;
  106. if (F.isDeclaration() || !IsLib) {
  107. if (F.user_empty())
  108. deadList.emplace_back(&F);
  109. }
  110. }
  111. bool bUpdated = deadList.size();
  112. for (Function *F : deadList)
  113. F->eraseFromParent();
  114. return bUpdated;
  115. }
  116. void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  117. DiagnosticPrinter *printer = reinterpret_cast<DiagnosticPrinter *>(Context);
  118. DI.print(*printer);
  119. }
  120. StringRef DemangleFunctionName(StringRef name) {
  121. if (!name.startswith(ManglingPrefix)) {
  122. // Name isn't mangled.
  123. return name;
  124. }
  125. size_t nameEnd = name.find_first_of("@");
  126. DXASSERT(nameEnd != StringRef::npos, "else Name isn't mangled but has \01?");
  127. return name.substr(2, nameEnd - 2);
  128. }
  129. std::string ReplaceFunctionName(StringRef originalName, StringRef newName) {
  130. if (originalName.startswith(ManglingPrefix)) {
  131. return (Twine(ManglingPrefix) + newName +
  132. originalName.substr(originalName.find_first_of('@'))).str();
  133. } else if (originalName.startswith(EntryPrefix)) {
  134. return (Twine(EntryPrefix) + newName).str();
  135. }
  136. return newName.str();
  137. }
  138. // From AsmWriter.cpp
  139. // PrintEscapedString - Print each character of the specified string, escaping
  140. // it if it is not printable or if it is an escape char.
  141. void PrintEscapedString(StringRef Name, raw_ostream &Out) {
  142. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  143. unsigned char C = Name[i];
  144. if (isprint(C) && C != '\\' && C != '"')
  145. Out << C;
  146. else
  147. Out << '\\' << hexdigit(C >> 4) << hexdigit(C & 0x0F);
  148. }
  149. }
  150. void PrintUnescapedString(StringRef Name, raw_ostream &Out) {
  151. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  152. unsigned char C = Name[i];
  153. if (C == '\\') {
  154. C = Name[++i];
  155. unsigned value = hexDigitValue(C);
  156. if (value != -1U) {
  157. C = (unsigned char)value;
  158. unsigned value2 = hexDigitValue(Name[i+1]);
  159. assert(value2 != -1U && "otherwise, not a two digit hex escape");
  160. if (value2 != -1U) {
  161. C = (C << 4) + (unsigned char)value2;
  162. ++i;
  163. }
  164. } // else, the next character (in C) should be the escaped character
  165. }
  166. Out << C;
  167. }
  168. }
  169. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
  170. llvm::LLVMContext &Ctx,
  171. std::string &DiagStr) {
  172. // Note: the DiagStr is not used.
  173. auto pModule = llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx);
  174. if (!pModule) {
  175. return nullptr;
  176. }
  177. return std::unique_ptr<llvm::Module>(pModule.get().release());
  178. }
  179. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
  180. llvm::LLVMContext &Ctx,
  181. std::string &DiagStr) {
  182. std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
  183. llvm::MemoryBuffer::getMemBuffer(BC, "", false));
  184. return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
  185. }
  186. // If we don't have debug location and this is select/phi,
  187. // try recursing users to find instruction with debug info.
  188. // Only recurse phi/select and limit depth to prevent doing
  189. // too much work if no debug location found.
  190. static bool EmitErrorOnInstructionFollowPhiSelect(
  191. Instruction *I, StringRef Msg, unsigned depth=0) {
  192. if (depth > 4)
  193. return false;
  194. if (I->getDebugLoc().get()) {
  195. EmitErrorOnInstruction(I, Msg);
  196. return true;
  197. }
  198. if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  199. for (auto U : I->users())
  200. if (Instruction *UI = dyn_cast<Instruction>(U))
  201. if (EmitErrorOnInstructionFollowPhiSelect(UI, Msg, depth+1))
  202. return true;
  203. }
  204. return false;
  205. }
  206. std::string FormatMessageAtLocation(const DebugLoc &DL, const Twine& Msg) {
  207. std::string locString;
  208. raw_string_ostream os(locString);
  209. DL.print(os);
  210. os << ": " << Msg;
  211. return os.str();
  212. }
  213. Twine FormatMessageWithoutLocation(const Twine& Msg) {
  214. return Msg + " Use /Zi for source location.";
  215. }
  216. void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
  217. const DebugLoc &DL = I->getDebugLoc();
  218. if (DL.get()) {
  219. I->getContext().emitError(FormatMessageAtLocation(DL, Msg));
  220. return;
  221. } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  222. if (EmitErrorOnInstructionFollowPhiSelect(I, Msg))
  223. return;
  224. }
  225. I->getContext().emitError(FormatMessageWithoutLocation(Msg));
  226. }
  227. const StringRef kResourceMapErrorMsg =
  228. "local resource not guaranteed to map to unique global resource.";
  229. void EmitResMappingError(Instruction *Res) {
  230. EmitErrorOnInstruction(Res, kResourceMapErrorMsg);
  231. }
  232. void CollectSelect(llvm::Instruction *Inst,
  233. std::unordered_set<llvm::Instruction *> &selectSet) {
  234. unsigned startOpIdx = 0;
  235. // Skip Cond for Select.
  236. if (isa<SelectInst>(Inst)) {
  237. startOpIdx = 1;
  238. } else if (!isa<PHINode>(Inst)) {
  239. // Only check phi and select here.
  240. return;
  241. }
  242. // Already add.
  243. if (selectSet.count(Inst))
  244. return;
  245. selectSet.insert(Inst);
  246. // Scan operand to add node which is phi/select.
  247. unsigned numOperands = Inst->getNumOperands();
  248. for (unsigned i = startOpIdx; i < numOperands; i++) {
  249. Value *V = Inst->getOperand(i);
  250. if (Instruction *I = dyn_cast<Instruction>(V)) {
  251. CollectSelect(I, selectSet);
  252. }
  253. }
  254. }
  255. Value *MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
  256. unsigned numOperands) {
  257. Value *op0 = nullptr;
  258. for (unsigned i = startOpIdx; i < numOperands; i++) {
  259. Value *op = SelInst->getOperand(i);
  260. if (i == startOpIdx) {
  261. op0 = op;
  262. } else {
  263. if (op0 != op)
  264. return nullptr;
  265. }
  266. }
  267. if (op0) {
  268. SelInst->replaceAllUsesWith(op0);
  269. SelInst->eraseFromParent();
  270. }
  271. return op0;
  272. }
  273. bool SimplifyTrivialPHIs(BasicBlock *BB) {
  274. bool Changed = false;
  275. SmallVector<Instruction *, 16> Removed;
  276. for (Instruction &I : *BB) {
  277. PHINode *PN = dyn_cast<PHINode>(&I);
  278. if (!PN)
  279. continue;
  280. if (PN->getNumIncomingValues() == 1) {
  281. Value *V = PN->getIncomingValue(0);
  282. PN->replaceAllUsesWith(V);
  283. Removed.push_back(PN);
  284. Changed = true;
  285. }
  286. }
  287. for (Instruction *I : Removed)
  288. I->eraseFromParent();
  289. return Changed;
  290. }
  291. Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) {
  292. Instruction *prototype = Inst;
  293. for (unsigned i = 0; i < prototype->getNumOperands(); i++) {
  294. if (i == operandIdx)
  295. continue;
  296. if (!isa<Constant>(prototype->getOperand(i)))
  297. return nullptr;
  298. }
  299. Value *V = prototype->getOperand(operandIdx);
  300. if (SelectInst *SI = dyn_cast<SelectInst>(V)) {
  301. IRBuilder<> Builder(SI);
  302. Instruction *trueClone = Inst->clone();
  303. trueClone->setOperand(operandIdx, SI->getTrueValue());
  304. Builder.Insert(trueClone);
  305. Instruction *falseClone = Inst->clone();
  306. falseClone->setOperand(operandIdx, SI->getFalseValue());
  307. Builder.Insert(falseClone);
  308. Value *newSel =
  309. Builder.CreateSelect(SI->getCondition(), trueClone, falseClone);
  310. return newSel;
  311. }
  312. if (PHINode *Phi = dyn_cast<PHINode>(V)) {
  313. Type *Ty = Inst->getType();
  314. unsigned numOperands = Phi->getNumOperands();
  315. IRBuilder<> Builder(Phi);
  316. PHINode *newPhi = Builder.CreatePHI(Ty, numOperands);
  317. for (unsigned i = 0; i < numOperands; i++) {
  318. BasicBlock *b = Phi->getIncomingBlock(i);
  319. Value *V = Phi->getIncomingValue(i);
  320. Instruction *iClone = Inst->clone();
  321. IRBuilder<> iBuilder(b->getTerminator()->getPrevNode());
  322. iClone->setOperand(operandIdx, V);
  323. iBuilder.Insert(iClone);
  324. newPhi->addIncoming(iClone, b);
  325. }
  326. return newPhi;
  327. }
  328. return nullptr;
  329. }
  330. llvm::Instruction *SkipAllocas(llvm::Instruction *I) {
  331. // Step past any allocas:
  332. while (I && isa<AllocaInst>(I))
  333. I = I->getNextNode();
  334. return I;
  335. }
  336. llvm::Instruction *FindAllocaInsertionPt(llvm::BasicBlock* BB) {
  337. return &*BB->getFirstInsertionPt();
  338. }
  339. llvm::Instruction *FindAllocaInsertionPt(llvm::Function* F) {
  340. return FindAllocaInsertionPt(&F->getEntryBlock());
  341. }
  342. llvm::Instruction *FindAllocaInsertionPt(llvm::Instruction* I) {
  343. Function *F = I->getParent()->getParent();
  344. if (F)
  345. return FindAllocaInsertionPt(F);
  346. else // BB with no parent function
  347. return FindAllocaInsertionPt(I->getParent());
  348. }
  349. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Instruction* I) {
  350. return SkipAllocas(FindAllocaInsertionPt(I));
  351. }
  352. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::BasicBlock* BB) {
  353. return SkipAllocas(FindAllocaInsertionPt(BB));
  354. }
  355. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
  356. return SkipAllocas(FindAllocaInsertionPt(F));
  357. }
  358. bool IsHLSLResourceType(llvm::Type *Ty) {
  359. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  360. StringRef name = ST->getName();
  361. name = name.ltrim("class.");
  362. name = name.ltrim("struct.");
  363. if (name == "SamplerState")
  364. return true;
  365. if (name == "SamplerComparisonState")
  366. return true;
  367. if (name.startswith("AppendStructuredBuffer<"))
  368. return true;
  369. if (name.startswith("ConsumeStructuredBuffer<"))
  370. return true;
  371. if (name.startswith("ConstantBuffer<"))
  372. return true;
  373. if (name == "RaytracingAccelerationStructure")
  374. return true;
  375. name = name.ltrim("RasterizerOrdered");
  376. name = name.ltrim("RW");
  377. if (name == "ByteAddressBuffer")
  378. return true;
  379. if (name.startswith("Buffer<"))
  380. return true;
  381. if (name.startswith("StructuredBuffer<"))
  382. return true;
  383. if (name.startswith("Texture")) {
  384. name = name.ltrim("Texture");
  385. if (name.startswith("1D<"))
  386. return true;
  387. if (name.startswith("1DArray<"))
  388. return true;
  389. if (name.startswith("2D<"))
  390. return true;
  391. if (name.startswith("2DArray<"))
  392. return true;
  393. if (name.startswith("3D<"))
  394. return true;
  395. if (name.startswith("Cube<"))
  396. return true;
  397. if (name.startswith("CubeArray<"))
  398. return true;
  399. if (name.startswith("2DMS<"))
  400. return true;
  401. if (name.startswith("2DMSArray<"))
  402. return true;
  403. }
  404. }
  405. return false;
  406. }
  407. bool IsHLSLObjectType(llvm::Type *Ty) {
  408. if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
  409. StringRef name = ST->getName();
  410. // TODO: don't check names.
  411. if (name.startswith("dx.types.wave_t"))
  412. return true;
  413. if (name.endswith("_slice_type"))
  414. return false;
  415. if (IsHLSLResourceType(Ty))
  416. return true;
  417. name = name.ltrim("class.");
  418. name = name.ltrim("struct.");
  419. if (name.startswith("TriangleStream<"))
  420. return true;
  421. if (name.startswith("PointStream<"))
  422. return true;
  423. if (name.startswith("LineStream<"))
  424. return true;
  425. }
  426. return false;
  427. }
  428. bool IsHLSLMatrixType(Type *Ty) {
  429. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  430. Type *EltTy = ST->getElementType(0);
  431. if (!ST->getName().startswith("class.matrix"))
  432. return false;
  433. bool isVecArray =
  434. EltTy->isArrayTy() && EltTy->getArrayElementType()->isVectorTy();
  435. return isVecArray && EltTy->getArrayNumElements() <= 4;
  436. }
  437. return false;
  438. }
  439. bool IsIntegerOrFloatingPointType(llvm::Type *Ty) {
  440. return Ty->isIntegerTy() || Ty->isFloatingPointTy();
  441. }
  442. bool ContainsHLSLObjectType(llvm::Type *Ty) {
  443. // Unwrap pointer/array
  444. while (llvm::isa<llvm::PointerType>(Ty))
  445. Ty = llvm::cast<llvm::PointerType>(Ty)->getPointerElementType();
  446. while (llvm::isa<llvm::ArrayType>(Ty))
  447. Ty = llvm::cast<llvm::ArrayType>(Ty)->getArrayElementType();
  448. if (llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty)) {
  449. if (ST->getName().startswith("dx.types."))
  450. return true;
  451. // TODO: How is this suppoed to check for Input/OutputPatch types if
  452. // these have already been eliminated in function arguments during CG?
  453. if (IsHLSLObjectType(Ty))
  454. return true;
  455. // Otherwise, recurse elements of UDT
  456. for (auto ETy : ST->elements()) {
  457. if (ContainsHLSLObjectType(ETy))
  458. return true;
  459. }
  460. }
  461. return false;
  462. }
  463. // Based on the implementation available in LLVM's trunk:
  464. // http://llvm.org/doxygen/Constants_8cpp_source.html#l02734
  465. bool IsSplat(llvm::ConstantDataVector *cdv) {
  466. const char *Base = cdv->getRawDataValues().data();
  467. // Compare elements 1+ to the 0'th element.
  468. unsigned EltSize = cdv->getElementByteSize();
  469. for (unsigned i = 1, e = cdv->getNumElements(); i != e; ++i)
  470. if (memcmp(Base, Base + i * EltSize, EltSize))
  471. return false;
  472. return true;
  473. }
  474. }
  475. }
  476. ///////////////////////////////////////////////////////////////////////////////
  477. namespace {
  478. class DxilLoadMetadata : public ModulePass {
  479. public:
  480. static char ID; // Pass identification, replacement for typeid
  481. explicit DxilLoadMetadata () : ModulePass(ID) {}
  482. const char *getPassName() const override { return "HLSL load DxilModule from metadata"; }
  483. bool runOnModule(Module &M) override {
  484. if (!M.HasDxilModule()) {
  485. (void)M.GetOrCreateDxilModule();
  486. return true;
  487. }
  488. return false;
  489. }
  490. };
  491. }
  492. char DxilLoadMetadata::ID = 0;
  493. ModulePass *llvm::createDxilLoadMetadataPass() {
  494. return new DxilLoadMetadata();
  495. }
  496. INITIALIZE_PASS(DxilLoadMetadata, "hlsl-dxilload", "HLSL load DxilModule from metadata", false, false)