DxilUtil.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilUtil.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Dxil helper functions. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "llvm/IR/GlobalVariable.h"
  12. #include "dxc/HLSL/DxilTypeSystem.h"
  13. #include "dxc/HLSL/DxilUtil.h"
  14. #include "dxc/HLSL/DxilModule.h"
  15. #include "llvm/Bitcode/ReaderWriter.h"
  16. #include "llvm/IR/DiagnosticInfo.h"
  17. #include "llvm/IR/DiagnosticPrinter.h"
  18. #include "llvm/IR/LLVMContext.h"
  19. #include "llvm/IR/Module.h"
  20. #include "llvm/Support/MemoryBuffer.h"
  21. #include "llvm/Support/raw_ostream.h"
  22. using namespace llvm;
  23. using namespace hlsl;
  24. namespace hlsl {
  25. namespace dxilutil {
  26. Type *GetArrayEltTy(Type *Ty) {
  27. if (isa<PointerType>(Ty))
  28. Ty = Ty->getPointerElementType();
  29. while (isa<ArrayType>(Ty)) {
  30. Ty = Ty->getArrayElementType();
  31. }
  32. return Ty;
  33. }
  34. unsigned
  35. GetLegacyCBufferFieldElementSize(DxilFieldAnnotation &fieldAnnotation,
  36. llvm::Type *Ty,
  37. DxilTypeSystem &typeSys) {
  38. while (isa<ArrayType>(Ty)) {
  39. Ty = Ty->getArrayElementType();
  40. }
  41. // Bytes.
  42. CompType compType = fieldAnnotation.GetCompType();
  43. unsigned compSize = compType.Is64Bit() ? 8 : compType.Is16Bit() && !typeSys.UseMinPrecision() ? 2 : 4;
  44. unsigned fieldSize = compSize;
  45. if (Ty->isVectorTy()) {
  46. fieldSize *= Ty->getVectorNumElements();
  47. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  48. DxilStructAnnotation *EltAnnotation = typeSys.GetStructAnnotation(ST);
  49. if (EltAnnotation) {
  50. fieldSize = EltAnnotation->GetCBufferSize();
  51. } else {
  52. // Calculate size when don't have annotation.
  53. if (fieldAnnotation.HasMatrixAnnotation()) {
  54. const DxilMatrixAnnotation &matAnnotation =
  55. fieldAnnotation.GetMatrixAnnotation();
  56. unsigned rows = matAnnotation.Rows;
  57. unsigned cols = matAnnotation.Cols;
  58. if (matAnnotation.Orientation == MatrixOrientation::ColumnMajor) {
  59. rows = cols;
  60. cols = matAnnotation.Rows;
  61. } else if (matAnnotation.Orientation != MatrixOrientation::RowMajor) {
  62. // Invalid matrix orientation.
  63. fieldSize = 0;
  64. }
  65. fieldSize = (rows - 1) * 16 + cols * 4;
  66. } else {
  67. // Cannot find struct annotation.
  68. fieldSize = 0;
  69. }
  70. }
  71. }
  72. return fieldSize;
  73. }
  74. bool IsStaticGlobal(GlobalVariable *GV) {
  75. return GV->getLinkage() == GlobalValue::LinkageTypes::InternalLinkage &&
  76. GV->getType()->getPointerAddressSpace() == DXIL::kDefaultAddrSpace;
  77. }
  78. bool IsSharedMemoryGlobal(llvm::GlobalVariable *GV) {
  79. return GV->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace;
  80. }
  81. bool RemoveUnusedFunctions(Module &M, Function *EntryFunc,
  82. Function *PatchConstantFunc, bool IsLib) {
  83. std::vector<Function *> deadList;
  84. for (auto &F : M.functions()) {
  85. if (&F == EntryFunc || &F == PatchConstantFunc)
  86. continue;
  87. if (F.isDeclaration() || !IsLib) {
  88. if (F.user_empty())
  89. deadList.emplace_back(&F);
  90. }
  91. }
  92. bool bUpdated = deadList.size();
  93. for (Function *F : deadList)
  94. F->eraseFromParent();
  95. return bUpdated;
  96. }
  97. void PrintDiagnosticHandler(const llvm::DiagnosticInfo &DI, void *Context) {
  98. DiagnosticPrinter *printer = reinterpret_cast<DiagnosticPrinter *>(Context);
  99. DI.print(*printer);
  100. }
  101. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::MemoryBuffer *MB,
  102. llvm::LLVMContext &Ctx,
  103. std::string &DiagStr) {
  104. raw_string_ostream DiagStream(DiagStr);
  105. llvm::DiagnosticPrinterRawOStream DiagPrinter(DiagStream);
  106. Ctx.setDiagnosticHandler(PrintDiagnosticHandler, &DiagPrinter, true);
  107. ErrorOr<std::unique_ptr<llvm::Module>> pModule(
  108. llvm::parseBitcodeFile(MB->getMemBufferRef(), Ctx));
  109. if (std::error_code ec = pModule.getError()) {
  110. return nullptr;
  111. }
  112. return std::unique_ptr<llvm::Module>(pModule.get().release());
  113. }
  114. std::unique_ptr<llvm::Module> LoadModuleFromBitcode(llvm::StringRef BC,
  115. llvm::LLVMContext &Ctx,
  116. std::string &DiagStr) {
  117. std::unique_ptr<llvm::MemoryBuffer> pBitcodeBuf(
  118. llvm::MemoryBuffer::getMemBuffer(BC, "", false));
  119. return LoadModuleFromBitcode(pBitcodeBuf.get(), Ctx, DiagStr);
  120. }
  121. }
  122. }