HLMatrixType.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLMatrixType.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/HLSL/HLMatrixType.h"
  10. #include "dxc/Support/Global.h"
  11. #include "llvm/IR/IRBuilder.h"
  12. #include "llvm/IR/Module.h"
  13. #include "llvm/IR/Type.h"
  14. #include "llvm/IR/DerivedTypes.h"
  15. #include "llvm/IR/Value.h"
  16. using namespace llvm;
  17. using namespace hlsl;
  18. HLMatrixType::HLMatrixType(Type *RegReprElemTy, unsigned NumRows, unsigned NumColumns)
  19. : RegReprElemTy(RegReprElemTy), NumRows(NumRows), NumColumns(NumColumns) {
  20. DXASSERT(RegReprElemTy != nullptr && (RegReprElemTy->isIntegerTy() || RegReprElemTy->isFloatingPointTy()),
  21. "Invalid matrix element type.");
  22. DXASSERT(NumRows >= 1 && NumRows <= 4 && NumColumns >= 1 && NumColumns <= 4,
  23. "Invalid matrix dimensions.");
  24. }
  25. Type *HLMatrixType::getElementType(bool MemRepr) const {
  26. // Bool i1s become i32s
  27. return MemRepr && RegReprElemTy->isIntegerTy(1)
  28. ? IntegerType::get(RegReprElemTy->getContext(), 32)
  29. : RegReprElemTy;
  30. }
  31. unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx) const {
  32. return getRowMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
  33. }
  34. unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx) const {
  35. return getColumnMajorIndex(RowIdx, ColIdx, NumRows, NumColumns);
  36. }
  37. unsigned HLMatrixType::getRowMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) {
  38. DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
  39. return RowIdx * NumColumns + ColIdx;
  40. }
  41. unsigned HLMatrixType::getColumnMajorIndex(unsigned RowIdx, unsigned ColIdx, unsigned NumRows, unsigned NumColumns) {
  42. DXASSERT_NOMSG(RowIdx < NumRows && ColIdx < NumColumns);
  43. return ColIdx * NumRows + RowIdx;
  44. }
  45. VectorType *HLMatrixType::getLoweredVectorType(bool MemRepr) const {
  46. return VectorType::get(getElementType(MemRepr), getNumElements());
  47. }
  48. Value *HLMatrixType::emitLoweredMemToReg(Value *Val, IRBuilder<> &Builder) const {
  49. DXASSERT(Val->getType()->getScalarType() == getElementTypeForMem(), "Lowered matrix type mismatch.");
  50. if (RegReprElemTy->isIntegerTy(1)) {
  51. Val = Builder.CreateICmpNE(Val, Constant::getNullValue(Val->getType()), "tobool");
  52. }
  53. return Val;
  54. }
  55. Value *HLMatrixType::emitLoweredRegToMem(Value *Val, IRBuilder<> &Builder) const {
  56. DXASSERT(Val->getType()->getScalarType() == RegReprElemTy, "Lowered matrix type mismatch.");
  57. if (RegReprElemTy->isIntegerTy(1)) {
  58. Type *MemReprTy = Val->getType()->isVectorTy() ? getLoweredVectorTypeForMem() : getElementTypeForMem();
  59. Val = Builder.CreateZExt(Val, MemReprTy, "frombool");
  60. }
  61. return Val;
  62. }
  63. Value *HLMatrixType::emitLoweredLoad(Value *Ptr, IRBuilder<> &Builder) const {
  64. return emitLoweredMemToReg(Builder.CreateLoad(Ptr), Builder);
  65. }
  66. StoreInst *HLMatrixType::emitLoweredStore(Value *Val, Value *Ptr, IRBuilder<> &Builder) const {
  67. return Builder.CreateStore(emitLoweredRegToMem(Val, Builder), Ptr);
  68. }
  69. Value *HLMatrixType::emitLoweredVectorRowToCol(Value *VecVal, IRBuilder<> &Builder) const {
  70. DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
  71. if (NumRows == 1 || NumColumns == 1) return VecVal;
  72. SmallVector<int, 16> ShuffleIndices;
  73. for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx)
  74. for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx)
  75. ShuffleIndices.emplace_back((int)getRowMajorIndex(RowIdx, ColIdx));
  76. return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "row2col");
  77. }
  78. Value *HLMatrixType::emitLoweredVectorColToRow(Value *VecVal, IRBuilder<> &Builder) const {
  79. DXASSERT(VecVal->getType() == getLoweredVectorTypeForReg(), "Lowered matrix type mismatch.");
  80. if (NumRows == 1 || NumColumns == 1) return VecVal;
  81. SmallVector<int, 16> ShuffleIndices;
  82. for (unsigned RowIdx = 0; RowIdx < NumRows; ++RowIdx)
  83. for (unsigned ColIdx = 0; ColIdx < NumColumns; ++ColIdx)
  84. ShuffleIndices.emplace_back((int)getColumnMajorIndex(RowIdx, ColIdx));
  85. return Builder.CreateShuffleVector(VecVal, VecVal, ShuffleIndices, "col2row");
  86. }
  87. bool HLMatrixType::isa(Type *Ty) {
  88. StructType *StructTy = llvm::dyn_cast<StructType>(Ty);
  89. return StructTy != nullptr && !StructTy->isLiteral() && StructTy->getName().startswith(StructNamePrefix);
  90. }
  91. bool HLMatrixType::isMatrixPtr(Type *Ty) {
  92. PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
  93. return PtrTy != nullptr && isa(PtrTy->getElementType());
  94. }
  95. bool HLMatrixType::isMatrixArray(Type *Ty) {
  96. ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty);
  97. if (ArrayTy == nullptr) return false;
  98. while (ArrayType *NestedArrayTy = llvm::dyn_cast<ArrayType>(ArrayTy->getElementType()))
  99. ArrayTy = NestedArrayTy;
  100. return isa(ArrayTy->getElementType());
  101. }
  102. bool HLMatrixType::isMatrixArrayPtr(Type *Ty) {
  103. PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
  104. if (PtrTy == nullptr) return false;
  105. return isMatrixArray(PtrTy->getElementType());
  106. }
  107. bool HLMatrixType::isMatrixPtrOrArrayPtr(Type *Ty) {
  108. PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty);
  109. if (PtrTy == nullptr) return false;
  110. Ty = PtrTy->getElementType();
  111. while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty))
  112. Ty = Ty->getArrayElementType();
  113. return isa(Ty);
  114. }
  115. bool HLMatrixType::isMatrixOrPtrOrArrayPtr(Type *Ty) {
  116. if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) Ty = PtrTy->getElementType();
  117. while (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) Ty = ArrayTy->getElementType();
  118. return isa(Ty);
  119. }
  120. // Converts a matrix, matrix pointer, or matrix array pointer type to its lowered equivalent.
  121. // If the type is not matrix-derived, the original type is returned.
  122. // Does not lower struct types containing matrices.
  123. Type *HLMatrixType::getLoweredType(Type *Ty, bool MemRepr) {
  124. if (PointerType *PtrTy = llvm::dyn_cast<PointerType>(Ty)) {
  125. // Pointees are always in memory representation
  126. Type *LoweredElemTy = getLoweredType(PtrTy->getElementType(), /* MemRepr */ true);
  127. return LoweredElemTy == PtrTy->getElementType()
  128. ? Ty : PointerType::get(LoweredElemTy, PtrTy->getAddressSpace());
  129. }
  130. else if (ArrayType *ArrayTy = llvm::dyn_cast<ArrayType>(Ty)) {
  131. // Arrays are always in memory and so their elements are in memory representation
  132. Type *LoweredElemTy = getLoweredType(ArrayTy->getElementType(), /* MemRepr */ true);
  133. return LoweredElemTy == ArrayTy->getElementType()
  134. ? Ty : ArrayType::get(LoweredElemTy, ArrayTy->getNumElements());
  135. }
  136. else if (HLMatrixType MatrixTy = HLMatrixType::dyn_cast(Ty)) {
  137. return MatrixTy.getLoweredVectorType(MemRepr);
  138. }
  139. else return Ty;
  140. }
  141. HLMatrixType HLMatrixType::cast(Type *Ty) {
  142. DXASSERT_NOMSG(isa(Ty));
  143. StructType *StructTy = llvm::cast<StructType>(Ty);
  144. DXASSERT_NOMSG(Ty->getNumContainedTypes() == 1);
  145. ArrayType *RowArrayTy = llvm::cast<ArrayType>(StructTy->getElementType(0));
  146. DXASSERT_NOMSG(RowArrayTy->getNumElements() >= 1 && RowArrayTy->getNumElements() <= 4);
  147. VectorType *RowTy = llvm::cast<VectorType>(RowArrayTy->getElementType());
  148. DXASSERT_NOMSG(RowTy->getNumElements() >= 1 && RowTy->getNumElements() <= 4);
  149. return HLMatrixType(RowTy->getElementType(), RowArrayTy->getNumElements(), RowTy->getNumElements());
  150. }
  151. HLMatrixType HLMatrixType::dyn_cast(Type *Ty) {
  152. return isa(Ty) ? cast(Ty) : HLMatrixType();
  153. }