DxilOperations.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxilOperations.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Implementation of DXIL operation tables. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #pragma once
  12. namespace llvm {
  13. class LLVMContext;
  14. class Module;
  15. class Type;
  16. class StructType;
  17. class Function;
  18. class Constant;
  19. class Value;
  20. class Instruction;
  21. class CallInst;
  22. }
  23. #include "llvm/IR/Attributes.h"
  24. #include "llvm/ADT/StringRef.h"
  25. #include "llvm/ADT/MapVector.h"
  26. #include "DxilConstants.h"
  27. #include <unordered_map>
  28. namespace hlsl {
  29. /// Use this utility class to interact with DXIL operations.
  30. class OP {
  31. public:
  32. using OpCode = DXIL::OpCode;
  33. using OpCodeClass = DXIL::OpCodeClass;
  34. public:
  35. OP() = delete;
  36. OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
  37. void RefreshCache();
  38. llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
  39. const llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> &GetOpFuncList(OpCode OpCode) const;
  40. void RemoveFunction(llvm::Function *F);
  41. llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
  42. llvm::LLVMContext &GetCtx() { return m_Ctx; }
  43. llvm::Type *GetHandleType() const;
  44. llvm::Type *GetResourcePropertiesType() const;
  45. llvm::Type *GetDimensionsType() const;
  46. llvm::Type *GetSamplePosType() const;
  47. llvm::Type *GetBinaryWithCarryType() const;
  48. llvm::Type *GetBinaryWithTwoOutputsType() const;
  49. llvm::Type *GetSplitDoubleType() const;
  50. llvm::Type *GetInt4Type() const;
  51. llvm::Type *GetResRetType(llvm::Type *pOverloadType);
  52. llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
  53. bool IsResRetType(llvm::Type *Ty);
  54. // Try to get the opcode class for a function.
  55. // Return true and set `opClass` if the given function is a dxil function.
  56. // Return false if the given function is not a dxil function.
  57. bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
  58. // To check if operation uses strict precision types
  59. bool UseMinPrecision();
  60. // Set if operation uses strict precision types or not.
  61. void SetMinPrecision(bool bMinPrecision);
  62. // Get the size of the type for a given layout
  63. uint64_t GetAllocSizeForType(llvm::Type *Ty);
  64. // LLVM helpers. Perhaps, move to a separate utility class.
  65. llvm::Constant *GetI1Const(bool v);
  66. llvm::Constant *GetI8Const(char v);
  67. llvm::Constant *GetU8Const(unsigned char v);
  68. llvm::Constant *GetI16Const(int v);
  69. llvm::Constant *GetU16Const(unsigned v);
  70. llvm::Constant *GetI32Const(int v);
  71. llvm::Constant *GetU32Const(unsigned v);
  72. llvm::Constant *GetU64Const(unsigned long long v);
  73. llvm::Constant *GetFloatConst(float v);
  74. llvm::Constant *GetDoubleConst(double v);
  75. static OpCode GetDxilOpFuncCallInst(const llvm::Instruction *I);
  76. static const char *GetOpCodeName(OpCode OpCode);
  77. static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);
  78. static OpCodeClass GetOpCodeClass(OpCode OpCode);
  79. static const char *GetOpCodeClassName(OpCode OpCode);
  80. static llvm::Attribute::AttrKind GetMemAccessAttr(OpCode opCode);
  81. static bool IsOverloadLegal(OpCode OpCode, llvm::Type *pType);
  82. static bool CheckOpCodeTable();
  83. static bool IsDxilOpFuncName(llvm::StringRef name);
  84. static bool IsDxilOpFunc(const llvm::Function *F);
  85. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
  86. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
  87. static bool IsDxilOpWave(OpCode C);
  88. static bool IsDxilOpGradient(OpCode C);
  89. static bool IsDxilOpFeedback(OpCode C);
  90. static bool IsDxilOpTypeName(llvm::StringRef name);
  91. static bool IsDxilOpType(llvm::StructType *ST);
  92. static bool IsDupDxilOpType(llvm::StructType *ST);
  93. static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
  94. llvm::Module &M);
  95. static void GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
  96. unsigned &major, unsigned &minor,
  97. unsigned &mask);
  98. static void GetMinShaderModelAndMask(const llvm::CallInst *CI, bool bWithTranslation,
  99. unsigned valMajor, unsigned valMinor,
  100. unsigned &major, unsigned &minor,
  101. unsigned &mask);
  102. private:
  103. // Per-module properties.
  104. llvm::LLVMContext &m_Ctx;
  105. llvm::Module *m_pModule;
  106. llvm::Type *m_pHandleType;
  107. llvm::Type *m_pResourcePropertiesType;
  108. llvm::Type *m_pDimensionsType;
  109. llvm::Type *m_pSamplePosType;
  110. llvm::Type *m_pBinaryWithCarryType;
  111. llvm::Type *m_pBinaryWithTwoOutputsType;
  112. llvm::Type *m_pSplitDoubleType;
  113. llvm::Type *m_pInt4Type;
  114. DXIL::LowPrecisionMode m_LowPrecisionMode;
  115. static const unsigned kUserDefineTypeSlot = 9;
  116. static const unsigned kObjectTypeSlot = 10;
  117. static const unsigned kNumTypeOverloads = 11; // void, h,f,d, i1, i8,i16,i32,i64, udt, obj
  118. llvm::Type *m_pResRetType[kNumTypeOverloads];
  119. llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
  120. struct OpCodeCacheItem {
  121. llvm::SmallMapVector<llvm::Type *, llvm::Function *, 8> pOverloads;
  122. };
  123. OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
  124. std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
  125. void UpdateCache(OpCodeClass opClass, llvm::Type * Ty, llvm::Function *F);
  126. private:
  127. // Static properties.
  128. struct OpCodeProperty {
  129. OpCode opCode;
  130. const char *pOpCodeName;
  131. OpCodeClass opCodeClass;
  132. const char *pOpCodeClassName;
  133. bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64, udt
  134. llvm::Attribute::AttrKind FuncAttr;
  135. };
  136. static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
  137. static const char *m_OverloadTypeName[kNumTypeOverloads];
  138. static const char *m_NamePrefix;
  139. static const char *m_TypePrefix;
  140. static const char *m_MatrixTypePrefix;
  141. static unsigned GetTypeSlot(llvm::Type *pType);
  142. static const char *GetOverloadTypeName(unsigned TypeSlot);
  143. static llvm::StringRef GetTypeName(llvm::Type *Ty, std::string &str);
  144. };
  145. } // namespace hlsl