///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// DxilOperations.h                                                          //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// Implementation of DXIL operation tables.                                  //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
  #pragma once

  // Forward declarations of the LLVM IR classes that the declarations in this
  // header only name through pointers or references, so their full definitions
  // are not required here.
  namespace llvm {
  class LLVMContext;
  class Module;
  class Type;
  class StructType;
  class Function;
  class Constant;
  class Value;
  class Instruction;
  } // namespace llvm
  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/StringRef.h"
  #include "llvm/IR/Attributes.h"
  #include "DxilConstants.h"
  #include <cstdint>
  #include <unordered_map>
  26. namespace hlsl {
  27. /// Use this utility class to interact with DXIL operations.
  28. class OP {
  29. public:
  30. using OpCode = DXIL::OpCode;
  31. using OpCodeClass = DXIL::OpCodeClass;
  32. public:
  33. OP() = delete;
  34. OP(llvm::LLVMContext &Ctx, llvm::Module *pModule);
  35. void RefreshCache();
  36. llvm::Function *GetOpFunc(OpCode OpCode, llvm::Type *pOverloadType);
  37. llvm::ArrayRef<llvm::Function *> GetOpFuncList(OpCode OpCode) const;
  38. void RemoveFunction(llvm::Function *F);
  39. llvm::Type *GetOverloadType(OpCode OpCode, llvm::Function *F);
  40. llvm::LLVMContext &GetCtx() { return m_Ctx; }
  41. llvm::Type *GetHandleType() const;
  42. llvm::Type *GetDimensionsType() const;
  43. llvm::Type *GetSamplePosType() const;
  44. llvm::Type *GetBinaryWithCarryType() const;
  45. llvm::Type *GetBinaryWithTwoOutputsType() const;
  46. llvm::Type *GetSplitDoubleType() const;
  47. llvm::Type *GetInt4Type() const;
  48. llvm::Type *GetResRetType(llvm::Type *pOverloadType);
  49. llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType);
  50. bool IsResRetType(llvm::Type *Ty);
  51. // Try to get the opcode class for a function.
  52. // Return true and set `opClass` if the given function is a dxil function.
  53. // Return false if the given function is not a dxil function.
  54. bool GetOpCodeClass(const llvm::Function *F, OpCodeClass &opClass);
  55. // To check if operation uses strict precision types
  56. bool UseMinPrecision();
  57. // Get the size of the type for a given layout
  58. uint64_t GetAllocSizeForType(llvm::Type *Ty);
  59. // LLVM helpers. Perhaps, move to a separate utility class.
  60. llvm::Constant *GetI1Const(bool v);
  61. llvm::Constant *GetI8Const(char v);
  62. llvm::Constant *GetU8Const(unsigned char v);
  63. llvm::Constant *GetI16Const(int v);
  64. llvm::Constant *GetU16Const(unsigned v);
  65. llvm::Constant *GetI32Const(int v);
  66. llvm::Constant *GetU32Const(unsigned v);
  67. llvm::Constant *GetU64Const(unsigned long long v);
  68. llvm::Constant *GetFloatConst(float v);
  69. llvm::Constant *GetDoubleConst(double v);
  70. static OpCode GetDxilOpFuncCallInst(const llvm::Instruction *I);
  71. static const char *GetOpCodeName(OpCode OpCode);
  72. static const char *GetAtomicOpName(DXIL::AtomicBinOpCode OpCode);
  73. static OpCodeClass GetOpCodeClass(OpCode OpCode);
  74. static const char *GetOpCodeClassName(OpCode OpCode);
  75. static bool IsOverloadLegal(OpCode OpCode, llvm::Type *pType);
  76. static bool CheckOpCodeTable();
  77. static bool IsDxilOpFuncName(llvm::StringRef name);
  78. static bool IsDxilOpFunc(const llvm::Function *F);
  79. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I);
  80. static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
  81. static bool IsDxilOpWave(OpCode C);
  82. static bool IsDxilOpGradient(OpCode C);
  83. static bool IsDxilOpType(llvm::StructType *ST);
  84. static bool IsDupDxilOpType(llvm::StructType *ST);
  85. static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
  86. llvm::Module &M);
  87. private:
  88. // Per-module properties.
  89. llvm::LLVMContext &m_Ctx;
  90. llvm::Module *m_pModule;
  91. llvm::Type *m_pHandleType;
  92. llvm::Type *m_pDimensionsType;
  93. llvm::Type *m_pSamplePosType;
  94. llvm::Type *m_pBinaryWithCarryType;
  95. llvm::Type *m_pBinaryWithTwoOutputsType;
  96. llvm::Type *m_pSplitDoubleType;
  97. llvm::Type *m_pInt4Type;
  98. DXIL::LowPrecisionMode m_LowPrecisionMode;
  99. static const unsigned kNumTypeOverloads = 9;
  100. llvm::Type *m_pResRetType[kNumTypeOverloads];
  101. llvm::Type *m_pCBufferRetType[kNumTypeOverloads];
  102. struct OpCodeCacheItem {
  103. llvm::Function *pOverloads[kNumTypeOverloads];
  104. };
  105. OpCodeCacheItem m_OpCodeClassCache[(unsigned)OpCodeClass::NumOpClasses];
  106. std::unordered_map<const llvm::Function *, OpCodeClass> m_FunctionToOpClass;
  107. void UpdateCache(OpCodeClass opClass, unsigned typeSlot, llvm::Function *F);
  108. private:
  109. // Static properties.
  110. struct OpCodeProperty {
  111. OpCode OpCode;
  112. const char *pOpCodeName;
  113. OpCodeClass OpCodeClass;
  114. const char *pOpCodeClassName;
  115. bool bAllowOverload[kNumTypeOverloads]; // void, h,f,d, i1, i8,i16,i32,i64
  116. llvm::Attribute::AttrKind FuncAttr;
  117. };
  118. static const OpCodeProperty m_OpCodeProps[(unsigned)OpCode::NumOpCodes];
  119. static const char *m_OverloadTypeName[kNumTypeOverloads];
  120. static const char *m_NamePrefix;
  121. static const char *m_TypePrefix;
  122. static unsigned GetTypeSlot(llvm::Type *pType);
  123. static const char *GetOverloadTypeName(unsigned TypeSlot);
  124. };
  125. } // namespace hlsl