DxilOperations.h 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  /////////////////////////////////////////////////////////////////////////////
  //                                                                         //
  // DxilOperations.h                                                        //
  // Copyright (C) Microsoft Corporation. All rights reserved.               //
  // This file is distributed under the University of Illinois Open Source   //
  // License. See LICENSE.TXT for details.                                   //
  //                                                                         //
  // Implementation of DXIL operation tables.                                //
  //                                                                         //
  /////////////////////////////////////////////////////////////////////////////
  #pragma once

  // Forward declarations of LLVM classes. This header refers to them only
  // through pointers and references, so complete types are not needed here.
  namespace llvm {
  class CallInst;
  class Constant;
  class Function;
  class Instruction;
  class LLVMContext;
  class Module;
  class StructType;
  class Type;
  class Value;
  } // namespace llvm
  #include "llvm/ADT/MapVector.h"
  #include "llvm/ADT/StringRef.h"
  #include "llvm/IR/Attributes.h"
  #include "DxilConstants.h"
  #include <cstdint>
  #include <string>
  #include <unordered_map>
  28. namespace hlsl {
  29. /// Use this utility class to interact with DXIL operations.
  30. class OP {
  31. public:
  32. using OpCode = DXIL::OpCode;
  33. using OpCodeClass = DXIL::OpCodeClass;
  34. public:
  35. OP() = delete;
  36. OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
  37. void RefreshCache();
  38. void FixOverloadNames();
  39. llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
  40. const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &GetOpFuncList(OpCode OpCode) const;
  41. bool IsDxilOpUsed(OpCode opcode) const;
  42. void RemoveFunction(llvm::Function *F);
  43. llvm::LLVMContext &GetCtx() { return m_Ctx; }
  44. llvm::Type *GetHandleType() const;
  45. llvm::Type *GetResourcePropertiesType() const;
  46. llvm::Type *GetResourceBindingType() const;
  47. llvm::Type *GetDimensionsType() const;
  48. llvm::Type *GetSamplePosType() const;
  49. llvm::Type *GetBinaryWithCarryType() const;
  50. llvm::Type *GetBinaryWithTwoOutputsType() const;
  51. llvm::Type *GetSplitDoubleType() const;
  52. llvm::Type *GetFourI32Type() const;
  53. llvm::Type *GetFourI16Type() const;
  54. llvm::Type *GetResRetType(llvm::Type *pOverloadType);
  55. llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
  56. llvm::Type *GetVectorType(unsigned numElements, llvm::Type *pOverloadType);
  57. bool IsResRetType(llvm::Type *Ty);
  58. // Try to get the opcode class for a function.
  59. // Return true and set `opClass` if the given function is a dxil function.
  60. // Return false if the given function is not a dxil function.
  61. bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
  62. // To check if operation uses strict precision types
  63. bool UseMinPrecision();
  64. // Set if operation uses strict precision types or not.
  65. void SetMinPrecision(bool bMinPrecision);
  66. // Get the size of the type for a given layout
  67. uint64_t GetAllocSizeForType(llvm::Type *Ty);
  68. // LLVM helpers. Perhaps, move to a separate utility class.
  69. llvm::Constant *GetI1Const(bool v);
  70. llvm::Constant *GetI8Const(char v);
  71. llvm::Constant *GetU8Const(unsigned char v);
  72. llvm::Constant *GetI16Const(int v);
  73. llvm::Constant *GetU16Const(unsigned v);
  74. llvm::Constant *GetI32Const(int v);
  75. llvm::Constant *GetU32Const(unsigned v);
  76. llvm::Constant *GetU64Const(unsigned long long v);
  77. llvm::Constant *GetFloatConst(float v);
  78. llvm::Constant *GetDoubleConst(double v);
  79. static llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
  80. static OpCode GetDxilOpFuncCallInst(const llvm::Instruction *I);
  81. static const char *GetOpCodeName(OpCode OpCode);
  82. static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);
  83. static OpCodeClass GetOpCodeClass(OpCode OpCode);
  84. static const char *GetOpCodeClassName(OpCode OpCode);
  85. static llvm::Attribute::AttrKind GetMemAccessAttr(OpCode opCode);
  86. static bool IsOverloadLegal(OpCode OpCode, llvm::Type *pType);
  87. static bool CheckOpCodeTable();
  88. static bool IsDxilOpFuncName(llvm::StringRef name);
  89. static bool IsDxilOpFunc(const llvm::Function *F);
  90. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
  91. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
  92. static bool IsDxilOpWave(OpCode C);
  93. static bool IsDxilOpGradient(OpCode C);
  94. static bool IsDxilOpFeedback(OpCode C);
  95. static bool IsDxilOpTypeName(llvm::StringRef name);
  96. static bool IsDxilOpType(llvm::StructType *ST);
  97. static bool IsDupDxilOpType(llvm::StructType *ST);
  98. static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
  99. llvm::Module &M);
  100. static void GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
  101. unsigned &major, unsigned &minor,
  102. unsigned &mask);
  103. static void GetMinShaderModelAndMask(const llvm::CallInst *CI, bool bWithTranslation,
  104. unsigned valMajor, unsigned valMinor,
  105. unsigned &major, unsigned &minor,
  106. unsigned &mask);
  107. private:
  108. // Per-module properties.
  109. llvm::LLVMContext &m_Ctx;
  110. llvm::Module *m_pModule;
  111. llvm::Type *m_pHandleType;
  112. llvm::Type *m_pResourcePropertiesType;
  113. llvm::Type *m_pResourceBindingType;
  114. llvm::Type *m_pDimensionsType;
  115. llvm::Type *m_pSamplePosType;
  116. llvm::Type *m_pBinaryWithCarryType;
  117. llvm::Type *m_pBinaryWithTwoOutputsType;
  118. llvm::Type *m_pSplitDoubleType;
  119. llvm::Type *m_pFourI32Type;
  120. llvm::Type *m_pFourI16Type;
  121. DXIL::LowPrecisionMode m_LowPrecisionMode;
  122. static const unsigned kUserDefineTypeSlot = 9;
  123. static const unsigned kObjectTypeSlot = 10;
  124. static const unsigned kNumTypeOverloads = 11; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj
  125. llvm::Type *m_pResRetType[kNumTypeOverloads];
  126. llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
  127. struct OpCodeCacheItem {
  128. llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> pOverloads;
  129. };
  130. OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
  131. std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
  132. void UpdateCache(OpCodeClass opClass, llvm::Type * Ty, llvm::Function *F);
  133. private:
  134. // Static properties.
  135. struct OpCodeProperty {
  136. OpCode opCode;
  137. const char *pOpCodeName;
  138. OpCodeClass opCodeClass;
  139. const char *pOpCodeClassName;
  140. bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64, udt
  141. llvm::Attribute::AttrKind FuncAttr;
  142. };
  143. static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
  144. static const char *m_OverloadTypeName[kNumTypeOverloads];
  145. static const char *m_NamePrefix;
  146. static const char *m_TypePrefix;
  147. static const char *m_MatrixTypePrefix;
  148. static unsigned GetTypeSlot(llvm::Type *pType);
  149. static const char *GetOverloadTypeName(unsigned TypeSlot);
  150. static llvm::StringRef GetTypeName(llvm::Type *Ty, std::string &str);
  151. static llvm::StringRef ConstructOverloadName(llvm::Type *Ty, DXIL::OpCode opCode,
  152. std::string &funcNameStorage);
  153. };
  154. } // namespace hlsl