  ///////////////////////////////////////////////////////////////////////////////
  //                                                                           //
  // DxilOperations.h                                                          //
  // Copyright (C) Microsoft Corporation. All rights reserved.                 //
  // This file is distributed under the University of Illinois Open Source     //
  // License. See LICENSE.TXT for details.                                     //
  //                                                                           //
  // Implementation of DXIL operation tables.                                  //
  //                                                                           //
  ///////////////////////////////////////////////////////////////////////////////
  #pragma once

  // Forward declarations of the LLVM IR classes that this header only refers
  // to through pointers and references, so their full definitions are not
  // needed here.
  namespace llvm {
  class LLVMContext;
  class Module;
  class Type;
  class StructType;
  class Function;
  class Constant;
  class Value;
  class Instruction;
  } // namespace llvm
  #include "llvm/IR/Attributes.h"
  #include "llvm/ADT/StringRef.h"
  #include "llvm/ADT/DenseMap.h"
  #include "DxilConstants.h"
  #include <cstdint>
  #include <string>
  #include <unordered_map>
  27. namespace hlsl {
  28. /// Use this utility class to interact with DXIL operations.
  29. class OP {
  30. public:
  31. using OpCode = DXIL::OpCode;
  32. using OpCodeClass = DXIL::OpCodeClass;
  33. public:
  34. OP() = delete;
  35. OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
  36. void RefreshCache();
  37. llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
  38. const llvm::SmallDenseMap<llvm::Type *, llvm::Function *, 8> &GetOpFuncList(OpCode OpCode) const;
  39. void RemoveFunction(llvm::Function *F);
  40. llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
  41. llvm::LLVMContext &GetCtx() { return m_Ctx; }
  42. llvm::Type *GetHandleType() const;
  43. llvm::Type *GetDimensionsType() const;
  44. llvm::Type *GetSamplePosType() const;
  45. llvm::Type *GetBinaryWithCarryType() const;
  46. llvm::Type *GetBinaryWithTwoOutputsType() const;
  47. llvm::Type *GetSplitDoubleType() const;
  48. llvm::Type *GetInt4Type() const;
  49. llvm::Type *GetResRetType(llvm::Type *pOverloadType);
  50. llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
  51. bool IsResRetType(llvm::Type *Ty);
  52. // Try to get the opcode class for a function.
  53. // Return true and set `opClass` if the given function is a dxil function.
  54. // Return false if the given function is not a dxil function.
  55. bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
  56. // To check if operation uses strict precision types
  57. bool UseMinPrecision();
  58. // Get the size of the type for a given layout
  59. uint64_t GetAllocSizeForType(llvm::Type *Ty);
  60. // LLVM helpers. Perhaps, move to a separate utility class.
  61. llvm::Constant *GetI1Const(bool v);
  62. llvm::Constant *GetI8Const(char v);
  63. llvm::Constant *GetU8Const(unsigned char v);
  64. llvm::Constant *GetI16Const(int v);
  65. llvm::Constant *GetU16Const(unsigned v);
  66. llvm::Constant *GetI32Const(int v);
  67. llvm::Constant *GetU32Const(unsigned v);
  68. llvm::Constant *GetU64Const(unsigned long long v);
  69. llvm::Constant *GetFloatConst(float v);
  70. llvm::Constant *GetDoubleConst(double v);
  71. static OpCode GetDxilOpFuncCallInst(const llvm::Instruction *I);
  72. static const char *GetOpCodeName(OpCode OpCode);
  73. static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);
  74. static OpCodeClass GetOpCodeClass(OpCode OpCode);
  75. static const char *GetOpCodeClassName(OpCode OpCode);
  76. static bool IsOverloadLegal(OpCode OpCode, llvm::Type *pType);
  77. static bool CheckOpCodeTable();
  78. static bool IsDxilOpFuncName(llvm::StringRef name);
  79. static bool IsDxilOpFunc(const llvm::Function *F);
  80. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
  81. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
  82. static bool IsDxilOpWave(OpCode C);
  83. static bool IsDxilOpGradient(OpCode C);
  84. static bool IsDxilOpTypeName(llvm::StringRef name);
  85. static bool IsDxilOpType(llvm::StructType *ST);
  86. static bool IsDupDxilOpType(llvm::StructType *ST);
  87. static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
  88. llvm::Module &M);
  89. static void GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
  90. unsigned &major, unsigned &minor,
  91. unsigned &mask);
  92. private:
  93. // Per-module properties.
  94. llvm::LLVMContext &m_Ctx;
  95. llvm::Module *m_pModule;
  96. llvm::Type *m_pHandleType;
  97. llvm::Type *m_pDimensionsType;
  98. llvm::Type *m_pSamplePosType;
  99. llvm::Type *m_pBinaryWithCarryType;
  100. llvm::Type *m_pBinaryWithTwoOutputsType;
  101. llvm::Type *m_pSplitDoubleType;
  102. llvm::Type *m_pInt4Type;
  103. DXIL::LowPrecisionMode m_LowPrecisionMode;
  104. static const unsigned kUserDefineTypeSlot = 9;
  105. static const unsigned kObjectTypeSlot = 10;
  106. static const unsigned kNumTypeOverloads = 11; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj
  107. llvm::Type *m_pResRetType[kNumTypeOverloads];
  108. llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
  109. struct OpCodeCacheItem {
  110. llvm::SmallDenseMap<llvm::Type *, llvm::Function *, 8> pOverloads;
  111. };
  112. OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
  113. std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
  114. void UpdateCache(OpCodeClass opClass, llvm::Type * Ty, llvm::Function *F);
  115. private:
  116. // Static properties.
  117. struct OpCodeProperty {
  118. OpCode opCode;
  119. const char *pOpCodeName;
  120. OpCodeClass opCodeClass;
  121. const char *pOpCodeClassName;
  122. bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64, udt
  123. llvm::Attribute::AttrKind FuncAttr;
  124. };
  125. static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
  126. static const char *m_OverloadTypeName[kNumTypeOverloads];
  127. static const char *m_NamePrefix;
  128. static const char *m_TypePrefix;
  129. static const char *m_MatrixTypePrefix;
  130. static unsigned GetTypeSlot(llvm::Type *pType);
  131. static const char *GetOverloadTypeName(unsigned TypeSlot);
  132. static llvm::StringRef GetTypeName(llvm::Type *Ty, std::string &str);
  133. };
  134. } // namespace hlsl