HLMatrixType.h 4.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixType.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #pragma once
  10. #include "llvm/IR/IRBuilder.h"
  namespace llvm {
  // Forward declarations of LLVM IR classes.
  // Of these, Type, Value, VectorType and StoreInst are referenced by the
  // HLMatrixType interface in this header.
  // NOTE(review): ArrayRef, Constant and StructType are not referenced anywhere
  // in this header; they look like candidates for removal.
  class Constant;
  class StoreInst;
  class StructType;
  class Type;
  class Value;
  class VectorType;
  template <typename T> class ArrayRef;
  } // namespace llvm
  21. namespace hlsl {
  22. class DxilFieldAnnotation;
  23. class DxilTypeSystem;
  24. // A high-level matrix type in LLVM IR.
  25. //
  26. // Matrices are represented by an llvm struct type of the following form:
  27. // { [RowCount x <ColCount x RegReprTy>] }
  28. // Note that the element type is always in its register representation (ie bools are i1s).
  29. // This allows preserving the original type and is okay since matrix types are only
  30. // manipulated in an opaque way, through intrinsics.
  31. //
  32. // During matrix lowering, matrices are converted to vectors of the following form:
  33. // <RowCount*ColCount x Ty>
  34. // At this point, register vs memory representation starts to matter and we have to
  35. // imitate the codegen for scalar and vector bools: i1s when in llvm registers,
  36. // and i32s when in memory (allocas, pointers, or in structs/lists, which are always in memory).
  37. //
  38. // This class is designed to resemble a llvm::Type-derived class.
  39. class HLMatrixType
  40. {
  41. public:
  42. static constexpr const char* StructNamePrefix = "class.matrix.";
  43. HLMatrixType() : RegReprElemTy(nullptr), NumRows(0), NumColumns(0) {}
  44. HLMatrixType(llvm::Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns);
  45. // We allow default construction to an invalid state to support the dynCast pattern.
  46. // This tests whether we have a legit object.
  47. operator bool() const { return RegReprElemTy != nullptr; }
  48. llvm::Type *getElementType(bool MemRepr) const;
  49. llvm::Type *getElementTypeForReg() const { return getElementType(false); }
  50. llvm::Type *getElementTypeForMem() const { return getElementType(true); }
  51. unsigned getNumRows() const { return NumRows; }
  52. unsigned getNumColumns() const { return NumColumns; }
  53. unsigned getNumElements() const { return NumRows * NumColumns; }
  54. unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
  55. unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const;
  56. static unsigned getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
  57. static unsigned getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns);
  58. llvm::VectorType *getLoweredVectorType(bool MemRepr) const;
  59. llvm::VectorType *getLoweredVectorTypeForReg() const { return getLoweredVectorType(false); }
  60. llvm::VectorType *getLoweredVectorTypeForMem() const { return getLoweredVectorType(true); }
  61. llvm::Value *emitLoweredMemToReg(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
  62. llvm::Value *emitLoweredRegToMem(llvm::Value *Val, llvm::IRBuilder<> &Builder) const;
  63. llvm::Value *emitLoweredLoad(llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
  64. llvm::StoreInst *emitLoweredStore(llvm::Value *Val, llvm::Value *Ptr, llvm::IRBuilder<> &Builder) const;
  65. llvm::Value *emitLoweredVectorRowToCol(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
  66. llvm::Value *emitLoweredVectorColToRow(llvm::Value *VecVal, llvm::IRBuilder<> &Builder) const;
  67. static bool isa(llvm::Type *Ty);
  68. static bool isMatrixPtr(llvm::Type *Ty);
  69. static bool isMatrixArray(llvm::Type *Ty);
  70. static bool isMatrixArrayPtr(llvm::Type *Ty);
  71. static bool isMatrixPtrOrArrayPtr(llvm::Type *Ty);
  72. static bool isMatrixOrPtrOrArrayPtr(llvm::Type *Ty);
  73. static llvm::Type *getLoweredType(llvm::Type *Ty, bool MemRepr = false);
  74. static HLMatrixType cast(llvm::Type *Ty);
  75. static HLMatrixType dyn_cast(llvm::Type *Ty);
  76. private:
  77. llvm::Type *RegReprElemTy;
  78. unsigned NumRows, NumColumns;
  79. };
  80. } // namespace hlsl