|
|
@@ -17,7 +17,8 @@
|
|
|
#include "dxc/HlslIntrinsicOp.h"
|
|
|
#include "dxc/Support/Global.h"
|
|
|
#include "dxc/HLSL/DxilOperations.h"
|
|
|
-#include "dxc/hlsl/DxilTypeSystem.h"
|
|
|
+#include "dxc/HLSL/DxilTypeSystem.h"
|
|
|
+#include "dxc/HLSL/DxilModule.h"
|
|
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
|
#include "llvm/IR/Module.h"
|
|
|
@@ -89,6 +90,18 @@ Type *LowerMatrixType(Type *Ty) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// Translate matrix type to array type.
|
|
|
+Type *LowerMatrixTypeToOneDimArray(Type *Ty) {
|
|
|
+ if (IsMatrixType(Ty)) {
|
|
|
+ unsigned row, col;
|
|
|
+ Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
+ return ArrayType::get(EltTy, row * col);
|
|
|
+ } else {
|
|
|
+ return Ty;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
|
|
|
DXASSERT(IsMatrixType(Ty), "not matrix type");
|
|
|
StructType *ST = cast<StructType>(Ty);
|
|
|
@@ -110,6 +123,7 @@ bool IsMatrixArrayPointer(llvm::Type *Ty) {
|
|
|
return IsMatrixType(Ty);
|
|
|
}
|
|
|
Type *LowerMatrixArrayPointer(Type *Ty) {
|
|
|
+ unsigned addrSpace = Ty->getPointerAddressSpace();
|
|
|
Ty = Ty->getPointerElementType();
|
|
|
std::vector<unsigned> arraySizeList;
|
|
|
while (Ty->isArrayTy()) {
|
|
|
@@ -121,9 +135,25 @@ Type *LowerMatrixArrayPointer(Type *Ty) {
|
|
|
for (auto arraySize = arraySizeList.rbegin();
|
|
|
arraySize != arraySizeList.rend(); arraySize++)
|
|
|
Ty = ArrayType::get(Ty, *arraySize);
|
|
|
- return PointerType::get(Ty, 0);
|
|
|
+ return PointerType::get(Ty, addrSpace);
|
|
|
}
|
|
|
|
|
|
+Type *LowerMatrixArrayPointerToOneDimArray(Type *Ty) {
|
|
|
+ unsigned addrSpace = Ty->getPointerAddressSpace();
|
|
|
+ Ty = Ty->getPointerElementType();
|
|
|
+
|
|
|
+ unsigned arraySize = 1;
|
|
|
+ while (Ty->isArrayTy()) {
|
|
|
+ arraySize *= Ty->getArrayNumElements();
|
|
|
+ Ty = Ty->getArrayElementType();
|
|
|
+ }
|
|
|
+ unsigned row, col;
|
|
|
+ Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
|
+ arraySize *= row*col;
|
|
|
+
|
|
|
+ Ty = ArrayType::get(EltTy, arraySize);
|
|
|
+ return PointerType::get(Ty, addrSpace);
|
|
|
+}
|
|
|
Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
|
|
|
IRBuilder<> &Builder) {
|
|
|
Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
|
|
|
@@ -475,7 +505,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
|
|
|
+ if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
|
|
|
+ GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
|
|
|
Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
result = Builder.CreateLoad(vecPtr);
|
|
|
} else
|
|
|
@@ -484,7 +515,8 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
|
|
|
+ if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr) ||
|
|
|
+ GetIfMatrixGEPOfUDTArg(matPtr, *m_pHLModule)) {
|
|
|
Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
Value *vecVal =
|
|
|
@@ -2179,7 +2211,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
DXASSERT(matToVecMap.count(useCall), "must have vec version");
|
|
|
Value *vecUser = matToVecMap[useCall];
|
|
|
- if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
|
|
|
+ if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal) ||
|
|
|
+ GetIfMatrixGEPOfUDTArg(matVal, *m_pHLModule)) {
|
|
|
// Load Already translated in lowerToVec.
|
|
|
// Store val operand will be set by the val use.
|
|
|
// Do nothing here.
|
|
|
@@ -2508,7 +2541,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
// The matrix operands will be undefval for these instructions.
|
|
|
for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
|
|
|
BasicBlock *BB = BBI;
|
|
|
- for (Instruction &I : BB->getInstList()) {
|
|
|
+ for (auto II = BB->begin(); II != BB->end(); ) {
|
|
|
+ Instruction &I = *(II++);
|
|
|
if (IsMatrixType(I.getType())) {
|
|
|
lowerToVec(&I);
|
|
|
} else if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
|
|
|
@@ -2531,7 +2565,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
// Lower it here to make sure it is ready before replace.
|
|
|
lowerToVec(&I);
|
|
|
}
|
|
|
- } else if (GetIfMatrixGEPOfUDTAlloca(&I)) {
|
|
|
+ } else if (GetIfMatrixGEPOfUDTAlloca(&I) ||
|
|
|
+ GetIfMatrixGEPOfUDTArg(&I, *m_pHLModule)) {
|
|
|
lowerToVec(&I);
|
|
|
}
|
|
|
}
|
|
|
@@ -2582,6 +2617,20 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
// %170 = bitcast %class.matrix.float.4.3* %arrayidx.i to <12 x float>*
|
|
|
|
|
|
namespace {
|
|
|
+
|
|
|
+Type *TryLowerMatTy(Type *Ty) {
|
|
|
+ Type *VecTy = nullptr;
|
|
|
+ if (HLMatrixLower::IsMatrixArrayPointer(Ty)) {
|
|
|
+ VecTy = HLMatrixLower::LowerMatrixArrayPointerToOneDimArray(Ty);
|
|
|
+ } else if (isa<PointerType>(Ty) &&
|
|
|
+ HLMatrixLower::IsMatrixType(Ty->getPointerElementType())) {
|
|
|
+ VecTy = HLMatrixLower::LowerMatrixTypeToOneDimArray(
|
|
|
+ Ty->getPointerElementType());
|
|
|
+ VecTy = PointerType::get(VecTy, Ty->getPointerAddressSpace());
|
|
|
+ }
|
|
|
+ return VecTy;
|
|
|
+}
|
|
|
+
|
|
|
class MatrixBitcastLowerPass : public FunctionPass {
|
|
|
|
|
|
public:
|
|
|
@@ -2590,18 +2639,166 @@ public:
|
|
|
|
|
|
const char *getPassName() const override { return "Matrix Bitcast lower"; }
|
|
|
// Removes bitcasts to matrix pointer types from F.
// Every bitcast whose result type is accepted by TryLowerMatTy is collected,
// then the users of each collected cast are rewritten by lowerMatrix to
// address the cast's source operand directly.
// Returns true only when at least one cast was handed to lowerMatrix.
bool runOnFunction(Function &F) override {
  // Collect candidates up front: lowering erases instructions, so it must
  // not run while the instruction lists are being walked.
  std::unordered_set<BitCastInst *> matCastSet;
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB.getInstList()) {
      BitCastInst *BCI = dyn_cast<BitCastInst>(&I);
      if (BCI && TryLowerMatTy(BCI->getType()))
        matCastSet.insert(BCI);
    }
  }

  DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
  // For library shader models, leave casts that (transitively) feed a call
  // untouched.
  if (DM.GetShaderModel()->IsLib()) {
    for (auto it = matCastSet.begin(); it != matCastSet.end();) {
      if (hasCallUser(*it))
        it = matCastSet.erase(it);
      else
        ++it;
    }
  }

  // NOTE(review): lowerMatrix erases bitcast users of the cast it lowers. If
  // such a nested cast is itself in matCastSet, this loop would later visit
  // a dangling pointer -- confirm that matrix-to-matrix cast chains cannot
  // reach this pass.
  for (BitCastInst *BCI : matCastSet)
    lowerMatrix(BCI, BCI->getOperand(0));

  // Report a change only if something was actually lowered.
  return !matCastSet.empty();
}
|
|
|
private:
|
|
|
- void lowerMatrixBitcast(BitCastInst *BCI);
|
|
|
+ void lowerMatrix(Instruction *M, Value *A);
|
|
|
+ bool hasCallUser(Instruction *M);
|
|
|
};
|
|
|
|
|
|
}
|
|
|
|
|
|
-void MatrixBitcastLowerPass::lowerMatrixBitcast(BitCastInst *BCI) {
|
|
|
- // to matrix.
|
|
|
- // from matrix.
|
|
|
+bool MatrixBitcastLowerPass::hasCallUser(Instruction *M) {
|
|
|
+ for (auto it = M->user_begin(); it != M->user_end();) {
|
|
|
+ User *U = *(it++);
|
|
|
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
+ Type *EltTy = GEP->getType()->getPointerElementType();
|
|
|
+ if (HLMatrixLower::IsMatrixType(EltTy)) {
|
|
|
+ if (hasCallUser(GEP))
|
|
|
+ return true;
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid GEP for matrix");
|
|
|
+ }
|
|
|
+ } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
|
|
|
+ if (hasCallUser(BCI))
|
|
|
+ return true;
|
|
|
+ } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
|
|
|
+ if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid load for matrix");
|
|
|
+ }
|
|
|
+ } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
|
|
|
+ Value *V = ST->getValueOperand();
|
|
|
+ if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid load for matrix");
|
|
|
+ }
|
|
|
+ } else if (isa<CallInst>(U)) {
|
|
|
+ return true;
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid use of matrix");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+}
|
|
|
+
|
|
|
+namespace {
|
|
|
+Value *CreateEltGEP(Value *A, unsigned i, Value *zeroIdx,
|
|
|
+ IRBuilder<> &Builder) {
|
|
|
+ Value *GEP = nullptr;
|
|
|
+ if (GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A)) {
|
|
|
+ // A should be gep oneDimArray, 0, index * matSize
|
|
|
+ // Here add eltIdx to index * matSize foreach elt.
|
|
|
+ Instruction *EltGEP = GEPA->clone();
|
|
|
+ unsigned eltIdx = EltGEP->getNumOperands() - 1;
|
|
|
+ Value *NewIdx =
|
|
|
+ Builder.CreateAdd(EltGEP->getOperand(eltIdx), Builder.getInt32(i));
|
|
|
+ EltGEP->setOperand(eltIdx, NewIdx);
|
|
|
+ Builder.Insert(EltGEP);
|
|
|
+ GEP = EltGEP;
|
|
|
+ } else {
|
|
|
+ GEP = Builder.CreateInBoundsGEP(A, {zeroIdx, Builder.getInt32(i)});
|
|
|
+ }
|
|
|
+ return GEP;
|
|
|
+}
|
|
|
+} // namespace
|
|
|
+
|
|
|
+void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
|
|
|
+ for (auto it = M->user_begin(); it != M->user_end();) {
|
|
|
+ User *U = *(it++);
|
|
|
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
+ Type *EltTy = GEP->getType()->getPointerElementType();
|
|
|
+ if (HLMatrixLower::IsMatrixType(EltTy)) {
|
|
|
+ // Change gep matrixArray, 0, index
|
|
|
+ // into
|
|
|
+ // gep oneDimArray, 0, index * matSize
|
|
|
+ IRBuilder<> Builder(GEP);
|
|
|
+ SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
|
|
|
+ DXASSERT(idxList.size() == 2,
|
|
|
+ "else not one dim matrix array index to matrix");
|
|
|
+ unsigned col = 0;
|
|
|
+ unsigned row = 0;
|
|
|
+ HLMatrixLower::GetMatrixInfo(EltTy, col, row);
|
|
|
+ Value *matSize = Builder.getInt32(col * row);
|
|
|
+ idxList.back() = Builder.CreateMul(idxList.back(), matSize);
|
|
|
+ Value *NewGEP = Builder.CreateGEP(A, idxList);
|
|
|
+ lowerMatrix(GEP, NewGEP);
|
|
|
+ DXASSERT(GEP->user_empty(), "else lower matrix fail");
|
|
|
+ GEP->eraseFromParent();
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid GEP for matrix");
|
|
|
+ }
|
|
|
+ } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
|
|
|
+ lowerMatrix(BCI, A);
|
|
|
+ DXASSERT(BCI->user_empty(), "else lower matrix fail");
|
|
|
+ BCI->eraseFromParent();
|
|
|
+ } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
|
|
|
+ if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
|
|
|
+ IRBuilder<> Builder(LI);
|
|
|
+ Value *zeroIdx = Builder.getInt32(0);
|
|
|
+ unsigned vecSize = Ty->getNumElements();
|
|
|
+ Value *NewVec = UndefValue::get(LI->getType());
|
|
|
+ for (unsigned i = 0; i < vecSize; i++) {
|
|
|
+ Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
|
|
|
+ Value *Elt = Builder.CreateLoad(GEP);
|
|
|
+ NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
|
|
|
+ }
|
|
|
+ LI->replaceAllUsesWith(NewVec);
|
|
|
+ LI->eraseFromParent();
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid load for matrix");
|
|
|
+ }
|
|
|
+ } else if (StoreInst *ST = dyn_cast<StoreInst>(U)) {
|
|
|
+ Value *V = ST->getValueOperand();
|
|
|
+ if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
|
|
|
+ IRBuilder<> Builder(LI);
|
|
|
+ Value *zeroIdx = Builder.getInt32(0);
|
|
|
+ unsigned vecSize = Ty->getNumElements();
|
|
|
+ for (unsigned i = 0; i < vecSize; i++) {
|
|
|
+ Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
|
|
|
+ Value *Elt = Builder.CreateExtractElement(V, i);
|
|
|
+ Builder.CreateStore(Elt, GEP);
|
|
|
+ }
|
|
|
+ ST->eraseFromParent();
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid load for matrix");
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ DXASSERT(0, "invalid use of matrix");
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#include "dxc/HLSL/DxilGenerationPass.h"
|