DxilUtil.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilUtil.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Dxil helper functions. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/GlobalVariable.h"
  12. #include "dxc/HLSL/DxilTypeSystem.h"
  13. #include "dxc/HLSL/DxilUtil.h"
  14. #include "dxc/HLSL/DxilModule.h"
  15. #include "dxc/HLSL/HLModule.h"
  16. #include "llvm/Bitcode/ReaderWriter.h"
  17. #include "llvm/IR/DiagnosticInfo.h"
  18. #include "llvm/IR/DiagnosticPrinter.h"
  19. #include "llvm/IR/LLVMContext.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/Support/MemoryBuffer.h"
  22. #include "llvm/Support/raw_ostream.h"
  23. #include "llvm/IR/Instructions.h"
  24. #include "llvm/IR/Constants.h"
  25. #include "llvm/IR/IRBuilder.h"
  26. #include "dxc/Support/Global.h"
  27. #include "llvm/ADT/StringExtras.h"
  28. #include "llvm/ADT/Twine.h"
  29. using namespace llvm;
  30. using namespace hlsl;
  31. namespace hlsl {
  32. namespace dxilutil {
  // Prefix that marks a mangled function name. It is two characters: the
  // byte 0x01 followed by '?'. DemangleFunctionName and ReplaceFunctionName
  // test for it with startswith().
  const char ManglingPrefix[] = "\01?";

  // Prefix that ReplaceFunctionName preserves when renaming an entry function.
  const char EntryPrefix[] = "dx.entry.";
  35. Type *GetArrayEltTy(Type *Ty) {
  36. if (isa<PointerType>(Ty))
  37. Ty = Ty->getPointerElementType();
  38. while (isa<ArrayType>(Ty)) {
  39. Ty = Ty->getArrayElementType();
  40. }
  41. return Ty;
  42. }
  43. bool HasDynamicIndexing(Value *V) {
  44. for (auto User : V->users()) {
  45. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  46. for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) {
  47. if (!isa<ConstantInt>(Idx))
  48. return true;
  49. }
  50. }
  51. }
  52. return false;
  53. }
  54. unsigned
  55. GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
  56. llvm::Type *Ty,
  57. DxilTypeSystem &typeSys) {
  58. while (isa<ArrayType>(Ty)) {
  59. Ty = Ty->getArrayElementType();
  60. }
  61. // Bytes.
  62. CompType compType = fieldAnnotation.GetCompType();
  63. unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4;
  64. unsigned fieldSize = compSize;
  65. if (Ty->isVectorTy()) {
  66. fieldSize *= Ty->getVectorNumElements();
  67. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  68. DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST);
  69. if (EltAnnotation) {
  70. fieldSize = EltAnnotation->GetCBufferSize();
  71. } else {
  72. // Calculate size when don't have annotation.
  73. if (fieldAnnotation.HasMatrixAnnotation()) {
  74. const DxilMatrixAnnotation &matAnnotation =
  75. fieldAnnotation.GetMatrixAnnotation();
  76. unsigned rows = matAnnotation.Rows;
  77. unsigned cols = matAnnotation.Cols;
  78. if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
  79. rows = cols;
  80. cols = matAnnotation.Rows;
  81. } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
  82. // Invalid matrix orientation.
  83. fieldSize = 0;
  84. }
  85. fieldSize = (rows - 1) * 16 + cols * 4;
  86. } else {
  87. // Cannot find struct annotation.
  88. fieldSize = 0;
  89. }
  90. }
  91. }
  92. return fieldSize;
  93. }
  94. bool IsStaticGlobal(GlobalVariable *GV) {
  95. return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage &&
  96. GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace;
  97. }
  98. bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  99. return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace;
  100. }
  101. bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
  102. Function *PatchConstantFunc, bool IsLib) {
  103. std::vector<Function *> deadList;
  104. for (auto &F : M.functions()) {
  105. if (&F == EntryFunc || &F == PatchConstantFunc)
  106. continue;
  107. if (F.isDeclaration() || !IsLib) {
  108. if (F.user_empty())
  109. deadList.emplace_back(&F);
  110. }
  111. }
  112. bool bUpdated = deadList.size();
  113. for (Function *F : deadList)
  114. F->eraseFromParent();
  115. return bUpdated;
  116. }
  117. void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  118. DiagnosticPrinter *printer = reinterpret_cast<DiagnosticPrinter *>(Context);
  119. DI.print(*printer);
  120. }
  121. StringRef DemangleFunctionName(StringRef name) {
  122. if (!name.startswith(ManglingPrefix)) {
  123. // Name isn't mangled.
  124. return name;
  125. }
  126. size_t nameEnd = name.find_first_of("@");
  127. DXASSERT(nameEnd != StringRef::npos, "else Name isn't mangled but has \01?");
  128. return name.substr(2, nameEnd - 2);
  129. }
  130. std::string ReplaceFunctionName(StringRef originalName, StringRef newName) {
  131. if (originalName.startswith(ManglingPrefix)) {
  132. return (Twine(ManglingPrefix) + newName +
  133. originalName.substr(originalName.find_first_of('@'))).str();
  134. } else if (originalName.startswith(EntryPrefix)) {
  135. return (Twine(EntryPrefix) + newName).str();
  136. }
  137. return newName.str();
  138. }
  139. // From AsmWriter.cpp
  140. // PrintEscapedString - Print each character of the specified string, escaping
  141. // it if it is not printable or if it is an escape char.
  142. void PrintEscapedString(StringRef Name, raw_ostream &Out) {
  143. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  144. unsigned char C = Name[i];
  145. if (isprint(C) && C != '\\' && C != '"')
  146. Out << C;
  147. else
  148. Out << '\\' << hexdigit(C >> 4) << hexdigit(C & 0x0F);
  149. }
  150. }
  151. void PrintUnescapedString(StringRef Name, raw_ostream &Out) {
  152. for (unsigned i = 0, e = Name.size(); i != e; ++i) {
  153. unsigned char C = Name[i];
  154. if (C == '\\') {
  155. C = Name[++i];
  156. unsigned value = hexDigitValue(C);
  157. if (value != -1U) {
  158. C = (unsigned char)value;
  159. unsigned value2 = hexDigitValue(Name[i+1]);
  160. assert(value2 != -1U && "otherwise, not a two digit hex escape");
  161. if (value2 != -1U) {
  162. C = (C << 4) + (unsigned char)value2;
  163. ++i;
  164. }
  165. } // else, the next character (in C) should be the escaped character
  166. }
  167. Out << C;
  168. }
  169. }
  170. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
  171. llvm::LLVMContext &Ctx,
  172. std::string &DiagStr) {
  173. raw_string_ostream DiagStream(DiagStr);
  174. llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
  175. LLVMContext::DiagnosticHandlerTy OrigHandler = Ctx.getDiagnosticHandler();
  176. void *OrigContext = Ctx.getDiagnosticContext();
  177. Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagPrinter, true);
  178. ErrorOr<std::unique_ptr<llvm::Module>> pModule(
  179. llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx));
  180. Ctx.setDiagnosticHandler(OrigHandler, OrigContext);
  181. if (std::error_code ec = pModule.getError()) {
  182. return nullptr;
  183. }
  184. return std::unique_ptr<llvm::Module>(pModule.get().release());
  185. }
  186. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
  187. llvm::LLVMContext &Ctx,
  188. std::string &DiagStr) {
  189. std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
  190. llvm::MemoryBuffer::getMemBuffer(BC, "", false));
  191. return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
  192. }
  193. // If we don't have debug location and this is select/phi,
  194. // try recursing users to find instruction with debug info.
  195. // Only recurse phi/select and limit depth to prevent doing
  196. // too much work if no debug location found.
  197. static bool EmitErrorOnInstructionFollowPhiSelect(
  198. Instruction *I, StringRef Msg, unsigned depth=0) {
  199. if (depth > 4)
  200. return false;
  201. if (I->getDebugLoc().get()) {
  202. EmitErrorOnInstruction(I, Msg);
  203. return true;
  204. }
  205. if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  206. for (auto U : I->users())
  207. if (Instruction *UI = dyn_cast<Instruction>(U))
  208. if (EmitErrorOnInstructionFollowPhiSelect(UI, Msg, depth+1))
  209. return true;
  210. }
  211. return false;
  212. }
  213. void EmitErrorOnInstruction(Instruction *I, StringRef Msg) {
  214. const DebugLoc &DL = I->getDebugLoc();
  215. if (DL.get()) {
  216. std::string locString;
  217. raw_string_ostream os(locString);
  218. DL.print(os);
  219. I->getContext().emitError(os.str() + ": " + Twine(Msg));
  220. return;
  221. } else if (isa<PHINode>(I) || isa<SelectInst>(I)) {
  222. if (EmitErrorOnInstructionFollowPhiSelect(I, Msg))
  223. return;
  224. }
  225. I->getContext().emitError(Twine(Msg) + " Use /Zi for source location.");
  226. }
  227. const StringRef kResourceMapErrorMsg =
  228. "local resource not guaranteed to map to unique global resource.";
  229. void EmitResMappingError(Instruction *Res) {
  230. EmitErrorOnInstruction(Res, kResourceMapErrorMsg);
  231. }
  232. void CollectSelect(llvm::Instruction *Inst,
  233. std::unordered_set<llvm::Instruction *> &selectSet) {
  234. unsigned startOpIdx = 0;
  235. // Skip Cond for Select.
  236. if (isa<SelectInst>(Inst)) {
  237. startOpIdx = 1;
  238. } else if (!isa<PHINode>(Inst)) {
  239. // Only check phi and select here.
  240. return;
  241. }
  242. // Already add.
  243. if (selectSet.count(Inst))
  244. return;
  245. selectSet.insert(Inst);
  246. // Scan operand to add node which is phi/select.
  247. unsigned numOperands = Inst->getNumOperands();
  248. for (unsigned i = startOpIdx; i < numOperands; i++) {
  249. Value *V = Inst->getOperand(i);
  250. if (Instruction *I = dyn_cast<Instruction>(V)) {
  251. CollectSelect(I, selectSet);
  252. }
  253. }
  254. }
  255. Value *MergeSelectOnSameValue(Instruction *SelInst, unsigned startOpIdx,
  256. unsigned numOperands) {
  257. Value *op0 = nullptr;
  258. for (unsigned i = startOpIdx; i < numOperands; i++) {
  259. Value *op = SelInst->getOperand(i);
  260. if (i == startOpIdx) {
  261. op0 = op;
  262. } else {
  263. if (op0 != op)
  264. return nullptr;
  265. }
  266. }
  267. if (op0) {
  268. SelInst->replaceAllUsesWith(op0);
  269. SelInst->eraseFromParent();
  270. }
  271. return op0;
  272. }
  273. Value *SelectOnOperation(llvm::Instruction *Inst, unsigned operandIdx) {
  274. Instruction *prototype = Inst;
  275. for (unsigned i = 0; i < prototype->getNumOperands(); i++) {
  276. if (i == operandIdx)
  277. continue;
  278. if (!isa<Constant>(prototype->getOperand(i)))
  279. return nullptr;
  280. }
  281. Value *V = prototype->getOperand(operandIdx);
  282. if (SelectInst *SI = dyn_cast<SelectInst>(V)) {
  283. IRBuilder<> Builder(SI);
  284. Instruction *trueClone = Inst->clone();
  285. trueClone->setOperand(operandIdx, SI->getTrueValue());
  286. Builder.Insert(trueClone);
  287. Instruction *falseClone = Inst->clone();
  288. falseClone->setOperand(operandIdx, SI->getFalseValue());
  289. Builder.Insert(falseClone);
  290. Value *newSel =
  291. Builder.CreateSelect(SI->getCondition(), trueClone, falseClone);
  292. return newSel;
  293. }
  294. if (PHINode *Phi = dyn_cast<PHINode>(V)) {
  295. Type *Ty = Inst->getType();
  296. unsigned numOperands = Phi->getNumOperands();
  297. IRBuilder<> Builder(Phi);
  298. PHINode *newPhi = Builder.CreatePHI(Ty, numOperands);
  299. for (unsigned i = 0; i < numOperands; i++) {
  300. BasicBlock *b = Phi->getIncomingBlock(i);
  301. Value *V = Phi->getIncomingValue(i);
  302. Instruction *iClone = Inst->clone();
  303. IRBuilder<> iBuilder(b->getTerminator()->getPrevNode());
  304. iClone->setOperand(operandIdx, V);
  305. iBuilder.Insert(iClone);
  306. newPhi->addIncoming(iClone, b);
  307. }
  308. return newPhi;
  309. }
  310. return nullptr;
  311. }
  312. llvm::Instruction *SkipAllocas(llvm::Instruction *I) {
  313. // Step past any allocas:
  314. while (I && isa<AllocaInst>(I))
  315. I = I->getNextNode();
  316. return I;
  317. }
  318. llvm::Instruction *FindAllocaInsertionPt(llvm::Instruction* I) {
  319. Function *F = I->getParent()->getParent();
  320. if (F)
  321. return F->getEntryBlock().getFirstInsertionPt();
  322. else // BB with no parent function
  323. return I->getParent()->getFirstInsertionPt();
  324. }
  325. llvm::Instruction *FindAllocaInsertionPt(llvm::Function* F) {
  326. return F->getEntryBlock().getFirstInsertionPt();
  327. }
  328. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Instruction* I) {
  329. return SkipAllocas(FindAllocaInsertionPt(I));
  330. }
  331. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::BasicBlock* BB) {
  332. return SkipAllocas(
  333. BB->getFirstInsertionPt());
  334. }
  335. llvm::Instruction *FirstNonAllocaInsertionPt(llvm::Function* F) {
  336. return SkipAllocas(
  337. F->getEntryBlock().getFirstInsertionPt());
  338. }
  339. bool ContainsHLSLObjectType(llvm::Type *Ty) {
  340. // Unwrap pointer/array
  341. while (llvm::isa<llvm::PointerType>(Ty))
  342. Ty = llvm::cast<llvm::PointerType>(Ty)->getPointerElementType();
  343. while (llvm::isa<llvm::ArrayType>(Ty))
  344. Ty = llvm::cast<llvm::ArrayType>(Ty)->getArrayElementType();
  345. if (llvm::StructType *ST = llvm::dyn_cast<llvm::StructType>(Ty)) {
  346. if (ST->getName().startswith("dx.types."))
  347. return true;
  348. // TODO: How is this suppoed to check for Input/OutputPatch types if
  349. // these have already been eliminated in function arguments during CG?
  350. if (HLModule::IsHLSLObjectType(Ty))
  351. return true;
  352. // Otherwise, recurse elements of UDT
  353. for (auto ETy : ST->elements()) {
  354. if (ContainsHLSLObjectType(ETy))
  355. return true;
  356. }
  357. }
  358. return false;
  359. }
  360. // Based on the implementation available in LLVM's trunk:
  361. // http://llvm.org/doxygen/Constants_8cpp_source.html#l02734
  362. bool IsSplat(llvm::ConstantDataVector *cdv) {
  363. const char *Base = cdv->getRawDataValues().data();
  364. // Compare elements 1+ to the 0'th element.
  365. unsigned EltSize = cdv->getElementByteSize();
  366. for (unsigned i = 1, e = cdv->getNumElements(); i != e; ++i)
  367. if (memcmp(Base, Base + i * EltSize, EltSize))
  368. return false;
  369. return true;
  370. }
  371. }
  372. }