DxilUtil.cpp 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilUtil.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Dxil helper functions. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/GlobalVariable.h"
  12. #include "dxc/HLSL/DxilTypeSystem.h"
  13. #include "dxc/HLSL/DxilUtil.h"
  14. #include "dxc/HLSL/DxilModule.h"
  15. #include "llvm/Bitcode/ReaderWriter.h"
  16. #include "llvm/IR/DiagnosticInfo.h"
  17. #include "llvm/IR/DiagnosticPrinter.h"
  18. #include "llvm/IR/LLVMContext.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/MemoryBuffer.h"
  21. #include "llvm/Support/raw_ostream.h"
  22. #include "llvm/IR/Instructions.h"
  23. #include "llvm/IR/Constants.h"
  24. using namespace llvm;
  25. using namespace hlsl;
  26. namespace hlsl {
  27. namespace dxilutil {
  28. Type *GetArrayEltTy(Type *Ty) {
  29. if (isa<PointerType>(Ty))
  30. Ty = Ty->getPointerElementType();
  31. while (isa<ArrayType>(Ty)) {
  32. Ty = Ty->getArrayElementType();
  33. }
  34. return Ty;
  35. }
  36. bool HasDynamicIndexing(Value *V) {
  37. for (auto User : V->users()) {
  38. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  39. for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) {
  40. if (!isa<ConstantInt>(Idx))
  41. return true;
  42. }
  43. }
  44. }
  45. return false;
  46. }
  47. unsigned
  48. GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
  49. llvm::Type *Ty,
  50. DxilTypeSystem &typeSys) {
  51. while (isa<ArrayType>(Ty)) {
  52. Ty = Ty->getArrayElementType();
  53. }
  54. // Bytes.
  55. CompType compType = fieldAnnotation.GetCompType();
  56. unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4;
  57. unsigned fieldSize = compSize;
  58. if (Ty->isVectorTy()) {
  59. fieldSize *= Ty->getVectorNumElements();
  60. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  61. DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST);
  62. if (EltAnnotation) {
  63. fieldSize = EltAnnotation->GetCBufferSize();
  64. } else {
  65. // Calculate size when don't have annotation.
  66. if (fieldAnnotation.HasMatrixAnnotation()) {
  67. const DxilMatrixAnnotation &matAnnotation =
  68. fieldAnnotation.GetMatrixAnnotation();
  69. unsigned rows = matAnnotation.Rows;
  70. unsigned cols = matAnnotation.Cols;
  71. if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
  72. rows = cols;
  73. cols = matAnnotation.Rows;
  74. } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
  75. // Invalid matrix orientation.
  76. fieldSize = 0;
  77. }
  78. fieldSize = (rows - 1) * 16 + cols * 4;
  79. } else {
  80. // Cannot find struct annotation.
  81. fieldSize = 0;
  82. }
  83. }
  84. }
  85. return fieldSize;
  86. }
  87. bool IsStaticGlobal(GlobalVariable *GV) {
  88. return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage &&
  89. GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace;
  90. }
  91. bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  92. return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace;
  93. }
  94. bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
  95. Function *PatchConstantFunc, bool IsLib) {
  96. std::vector<Function *> deadList;
  97. for (auto &F : M.functions()) {
  98. if (&F == EntryFunc || &F == PatchConstantFunc)
  99. continue;
  100. if (F.isDeclaration() || !IsLib) {
  101. if (F.user_empty())
  102. deadList.emplace_back(&F);
  103. }
  104. }
  105. bool bUpdated = deadList.size();
  106. for (Function *F : deadList)
  107. F->eraseFromParent();
  108. return bUpdated;
  109. }
  110. void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  111. DiagnosticPrinter *printer = reinterpret_cast<DiagnosticPrinter *>(Context);
  112. DI.print(*printer);
  113. }
  114. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
  115. llvm::LLVMContext &Ctx,
  116. std::string &DiagStr) {
  117. raw_string_ostream DiagStream(DiagStr);
  118. llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
  119. Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagPrinter, true);
  120. ErrorOr<std::unique_ptr<llvm::Module>> pModule(
  121. llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx));
  122. if (std::error_code ec = pModule.getError()) {
  123. return nullptr;
  124. }
  125. return std::unique_ptr<llvm::Module>(pModule.get().release());
  126. }
  127. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
  128. llvm::LLVMContext &Ctx,
  129. std::string &DiagStr) {
  130. std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
  131. llvm::MemoryBuffer::getMemBuffer(BC, "", false));
  132. return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
  133. }
  134. }
  135. }