|
@@ -73,10 +73,10 @@ Type *GetMatrixInfo(Type *Ty, unsigned &col, unsigned &row) {
|
|
DXASSERT(IsMatrixType(Ty), "not matrix type");
|
|
DXASSERT(IsMatrixType(Ty), "not matrix type");
|
|
StructType *ST = cast<StructType>(Ty);
|
|
StructType *ST = cast<StructType>(Ty);
|
|
Type *EltTy = ST->getElementType(0);
|
|
Type *EltTy = ST->getElementType(0);
|
|
- Type *ColTy = EltTy->getArrayElementType();
|
|
|
|
- col = EltTy->getArrayNumElements();
|
|
|
|
- row = ColTy->getVectorNumElements();
|
|
|
|
- return ColTy->getVectorElementType();
|
|
|
|
|
|
+ Type *RowTy = EltTy->getArrayElementType();
|
|
|
|
+ row = EltTy->getArrayNumElements();
|
|
|
|
+ col = RowTy->getVectorNumElements();
|
|
|
|
+ return RowTy->getVectorElementType();
|
|
}
|
|
}
|
|
|
|
|
|
bool IsMatrixArrayPointer(llvm::Type *Ty) {
|
|
bool IsMatrixArrayPointer(llvm::Type *Ty) {
|
|
@@ -104,23 +104,49 @@ Type *LowerMatrixArrayPointer(Type *Ty) {
|
|
return PointerType::get(Ty, 0);
|
|
return PointerType::get(Ty, 0);
|
|
}
|
|
}
|
|
|
|
|
|
-Value *BuildMatrix(Type *EltTy, unsigned col, unsigned row,
|
|
|
|
- bool colMajor, ArrayRef<Value *> elts,
|
|
|
|
- IRBuilder<> &Builder) {
|
|
|
|
- Value *Result = UndefValue::get(VectorType::get(EltTy, col * row));
|
|
|
|
- if (colMajor) {
|
|
|
|
- for (unsigned i = 0; i < col * row; i++)
|
|
|
|
- Result = Builder.CreateInsertElement(Result, elts[i], i);
|
|
|
|
|
|
+Value *BuildVector(Type *EltTy, unsigned size, ArrayRef<llvm::Value *> elts,
|
|
|
|
+ IRBuilder<> &Builder) {
|
|
|
|
+ Value *Vec = UndefValue::get(VectorType::get(EltTy, size));
|
|
|
|
+ for (unsigned i = 0; i < size; i++)
|
|
|
|
+ Vec = Builder.CreateInsertElement(Vec, elts[i], i);
|
|
|
|
+ return Vec;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+Value *LowerGEPOnMatIndexListToIndex(
|
|
|
|
+ llvm::GetElementPtrInst *GEP, ArrayRef<Value *> IdxList) {
|
|
|
|
+ IRBuilder<> Builder(GEP);
|
|
|
|
+ Value *zero = Builder.getInt32(0);
|
|
|
|
+ DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
|
|
|
|
+ Value *baseIdx = (GEP->idx_begin())->get();
|
|
|
|
+ DXASSERT_LOCALVAR(baseIdx, baseIdx == zero, "base index must be 0");
|
|
|
|
+ Value *Idx = (GEP->idx_begin() + 1)->get();
|
|
|
|
+
|
|
|
|
+ if (ConstantInt *immIdx = dyn_cast<ConstantInt>(Idx)) {
|
|
|
|
+ return IdxList[immIdx->getSExtValue()];
|
|
} else {
|
|
} else {
|
|
- for (unsigned r = 0; r < row; r++)
|
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
|
- unsigned rowMajorIdx = r * col + c;
|
|
|
|
- unsigned colMajorIdx = c * row + r;
|
|
|
|
- Result =
|
|
|
|
- Builder.CreateInsertElement(Result, elts[rowMajorIdx], colMajorIdx);
|
|
|
|
- }
|
|
|
|
|
|
+ IRBuilder<> AllocaBuilder(
|
|
|
|
+ GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
|
|
|
|
+ unsigned size = IdxList.size();
|
|
|
|
+ // Store idxList to temp array.
|
|
|
|
+ ArrayType *AT = ArrayType::get(IdxList[0]->getType(), size);
|
|
|
|
+ Value *tempArray = AllocaBuilder.CreateAlloca(AT);
|
|
|
|
+
|
|
|
|
+ for (unsigned i = 0; i < size; i++) {
|
|
|
|
+ Value *EltPtr = Builder.CreateGEP(tempArray, {zero, Builder.getInt32(i)});
|
|
|
|
+ Builder.CreateStore(IdxList[i], EltPtr);
|
|
|
|
+ }
|
|
|
|
+ // Load the idx.
|
|
|
|
+ Value *GEPOffset = Builder.CreateGEP(tempArray, {zero, Idx});
|
|
|
|
+ return Builder.CreateLoad(GEPOffset);
|
|
}
|
|
}
|
|
- return Result;
|
|
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+unsigned GetColMajorIdx(unsigned r, unsigned c, unsigned row) {
|
|
|
|
+ return c * row + r;
|
|
|
|
+}
|
|
|
|
+unsigned GetRowMajorIdx(unsigned r, unsigned c, unsigned col) {
|
|
|
|
+ return r * col + c;
|
|
}
|
|
}
|
|
|
|
|
|
} // namespace HLMatrixLower
|
|
} // namespace HLMatrixLower
|
|
@@ -222,6 +248,8 @@ private:
|
|
CallInst *castInst);
|
|
CallInst *castInst);
|
|
void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
|
|
void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
|
|
CallInst *castInst);
|
|
CallInst *castInst);
|
|
|
|
+ void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
|
|
|
|
+ CallInst *castInst, bool rowToCol);
|
|
// Replace matInst with vecInst in matSubscript
|
|
// Replace matInst with vecInst in matSubscript
|
|
void TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
void TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
CallInst *matSubInst);
|
|
CallInst *matSubInst);
|
|
@@ -236,8 +264,6 @@ private:
|
|
CallInst *matLdStInst);
|
|
CallInst *matLdStInst);
|
|
void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
|
|
void TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
|
|
CallInst *matLdStInst);
|
|
CallInst *matLdStInst);
|
|
- void TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal, GlobalVariable *vecGlobal,
|
|
|
|
- CallInst *matSubInst);
|
|
|
|
void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
|
|
void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
|
|
void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
|
|
void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
|
|
|
|
|
|
@@ -263,11 +289,6 @@ ModulePass *llvm::createHLMatrixLowerPass() { return new HLMatrixLowerPass(); }
|
|
|
|
|
|
INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
|
|
INITIALIZE_PASS(HLMatrixLowerPass, "hlmatrixlower", "HLSL High-Level Matrix Lower", false, false)
|
|
|
|
|
|
-// All calculation on col major.
|
|
|
|
-static unsigned GetMatIdx(unsigned r, unsigned c, unsigned rowSize) {
|
|
|
|
- return (c * rowSize + r);
|
|
|
|
-}
|
|
|
|
-
|
|
|
|
static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
|
|
static Instruction *CreateTypeCast(HLCastOpcode castOp, Type *toTy, Value *src,
|
|
IRBuilder<> Builder) {
|
|
IRBuilder<> Builder) {
|
|
// Cast to bool.
|
|
// Cast to bool.
|
|
@@ -843,8 +864,8 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
Value *rMat = matToVecMap[cast<Instruction>(RVal)];
|
|
Value *rMat = matToVecMap[cast<Instruction>(RVal)];
|
|
|
|
|
|
auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
|
|
auto CreateOneEltMul = [&](unsigned r, unsigned lc, unsigned c) -> Value * {
|
|
- unsigned lMatIdx = GetMatIdx(r, lc, row);
|
|
|
|
- unsigned rMatIdx = GetMatIdx(lc, c, rRow);
|
|
|
|
|
|
+ unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
|
|
|
|
+ unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
|
|
Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
|
|
return isFloat ? Builder.CreateFMul(lMatElt, rMatElt)
|
|
@@ -859,8 +880,8 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
|
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
|
|
Value *acc) -> Value * {
|
|
Value *acc) -> Value * {
|
|
- unsigned lMatIdx = GetMatIdx(r, lc, row);
|
|
|
|
- unsigned rMatIdx = GetMatIdx(lc, c, rRow);
|
|
|
|
|
|
+ unsigned lMatIdx = HLMatrixLower::GetRowMajorIdx(r, lc, col);
|
|
|
|
+ unsigned rMatIdx = HLMatrixLower::GetRowMajorIdx(lc, c, rCol);
|
|
Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
|
|
return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
|
|
@@ -874,7 +895,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
for (lc = 1; lc < col; lc++) {
|
|
for (lc = 1; lc < col; lc++) {
|
|
tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
|
|
tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
|
|
}
|
|
}
|
|
- unsigned matIdx = GetMatIdx(r, c, row);
|
|
|
|
|
|
+ unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, rCol);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -912,7 +933,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
|
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
|
|
+ uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
};
|
|
};
|
|
@@ -920,7 +941,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
for (unsigned r = 0; r < row; r++) {
|
|
for (unsigned r = 0; r < row; r++) {
|
|
unsigned c = 0;
|
|
unsigned c = 0;
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
|
|
+ uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
|
|
Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
@@ -964,7 +985,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
|
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
|
|
+ uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
};
|
|
};
|
|
@@ -972,7 +993,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
for (unsigned c = 0; c < col; c++) {
|
|
for (unsigned c = 0; c < col; c++) {
|
|
unsigned r = 0;
|
|
unsigned r = 0;
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
|
|
+ uint32_t matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
|
|
Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
Value *tmpVal = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
@@ -1008,25 +1029,8 @@ void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
|
|
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
CallInst *transposeInst) {
|
|
CallInst *transposeInst) {
|
|
- unsigned row, col;
|
|
|
|
- GetMatrixInfo(transposeInst->getType(), col, row);
|
|
|
|
- IRBuilder<> Builder(transposeInst);
|
|
|
|
- std::vector<int> transposeMask(col * row);
|
|
|
|
- unsigned idx = 0;
|
|
|
|
- for (unsigned c = 0; c < col; c++)
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- // change to row major
|
|
|
|
- unsigned matIdx = GetMatIdx(c, r, col);
|
|
|
|
- transposeMask[idx++] = (matIdx);
|
|
|
|
- }
|
|
|
|
- Instruction *shuf = cast<Instruction>(
|
|
|
|
- Builder.CreateShuffleVector(vecInst, vecInst, transposeMask));
|
|
|
|
- // Replace vec transpose function call with shuf.
|
|
|
|
- DXASSERT(matToVecMap.count(transposeInst), "must has vec version");
|
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[transposeInst]);
|
|
|
|
- vecUseInst->replaceAllUsesWith(shuf);
|
|
|
|
- AddToDeadInsts(vecUseInst);
|
|
|
|
- matToVecMap[transposeInst] = shuf;
|
|
|
|
|
|
+ // Matrix value is row major, transpose is cast it to col major.
|
|
|
|
+ TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
|
|
}
|
|
}
|
|
|
|
|
|
static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
|
|
static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
|
|
@@ -1144,6 +1148,44 @@ void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
|
|
|
|
+ Instruction *vecInst,
|
|
|
|
+ CallInst *castInst,
|
|
|
|
+ bool bRowToCol) {
|
|
|
|
+ unsigned col, row;
|
|
|
|
+ GetMatrixInfo(castInst->getType(), col, row);
|
|
|
|
+ DXASSERT(castInst->getType() == matInst->getType(), "type must match");
|
|
|
|
+
|
|
|
|
+ IRBuilder<> Builder(castInst);
|
|
|
|
+
|
|
|
|
+ // shuf to change major.
|
|
|
|
+ SmallVector<int, 16> castMask(col * row);
|
|
|
|
+ unsigned idx = 0;
|
|
|
|
+ if (bRowToCol) {
|
|
|
|
+ for (unsigned c = 0; c < col; c++)
|
|
|
|
+ for (unsigned r = 0; r < row; r++) {
|
|
|
|
+ unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
|
+ castMask[idx++] = matIdx;
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ for (unsigned r = 0; r < row; r++)
|
|
|
|
+ for (unsigned c = 0; c < col; c++) {
|
|
|
|
+ unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
|
|
|
|
+ castMask[idx++] = matIdx;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ Instruction *vecCast = cast<Instruction>(
|
|
|
|
+ Builder.CreateShuffleVector(vecInst, vecInst, castMask));
|
|
|
|
+
|
|
|
|
+ // Replace vec cast function call with vecCast.
|
|
|
|
+ DXASSERT(matToVecMap.count(castInst), "must has vec version");
|
|
|
|
+ Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
|
+ vecUseInst->replaceAllUsesWith(vecCast);
|
|
|
|
+ AddToDeadInsts(vecUseInst);
|
|
|
|
+ matToVecMap[castInst] = vecCast;
|
|
|
|
+}
|
|
|
|
+
|
|
void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
CallInst *castInst) {
|
|
CallInst *castInst) {
|
|
@@ -1167,9 +1209,9 @@ void HLMatrixLowerPass::TranslateMatMatCast(CallInst *matInst,
|
|
// shuf first
|
|
// shuf first
|
|
std::vector<int> castMask(toCol * toRow);
|
|
std::vector<int> castMask(toCol * toRow);
|
|
unsigned idx = 0;
|
|
unsigned idx = 0;
|
|
- for (unsigned c = 0; c < toCol; c++)
|
|
|
|
- for (unsigned r = 0; r < toRow; r++) {
|
|
|
|
- unsigned matIdx = GetMatIdx(r, c, fromRow);
|
|
|
|
|
|
+ for (unsigned r = 0; r < toRow; r++)
|
|
|
|
+ for (unsigned c = 0; c < toCol; c++) {
|
|
|
|
+ unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, fromCol);
|
|
castMask[idx++] = matIdx;
|
|
castMask[idx++] = matIdx;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -1232,14 +1274,21 @@ void HLMatrixLowerPass::TranslateMatToOtherCast(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
CallInst *castInst) {
|
|
CallInst *castInst) {
|
|
- bool ToMat = IsMatrixType(castInst->getType());
|
|
|
|
- bool FromMat = IsMatrixType(matInst->getType());
|
|
|
|
- if (ToMat && FromMat) {
|
|
|
|
- TranslateMatMatCast(matInst, vecInst, castInst);
|
|
|
|
- } else if (FromMat)
|
|
|
|
- TranslateMatToOtherCast(matInst, vecInst, castInst);
|
|
|
|
- else
|
|
|
|
- DXASSERT(0, "Not translate as user of matInst");
|
|
|
|
|
|
+ HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
|
|
|
|
+ if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
|
|
|
|
+ opcode == HLCastOpcode::RowMatrixToColMatrix) {
|
|
|
|
+ TranslateMatMajorCast(matInst, vecInst, castInst,
|
|
|
|
+ opcode == HLCastOpcode::RowMatrixToColMatrix);
|
|
|
|
+ } else {
|
|
|
|
+ bool ToMat = IsMatrixType(castInst->getType());
|
|
|
|
+ bool FromMat = IsMatrixType(matInst->getType());
|
|
|
|
+ if (ToMat && FromMat) {
|
|
|
|
+ TranslateMatMatCast(matInst, vecInst, castInst);
|
|
|
|
+ } else if (FromMat)
|
|
|
|
+ TranslateMatToOtherCast(matInst, vecInst, castInst);
|
|
|
|
+ else
|
|
|
|
+ DXASSERT(0, "Not translate as user of matInst");
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
|
|
void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
|
|
@@ -1290,7 +1339,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
(matOpcode == HLSubscriptOpcode::RowMatElement);
|
|
(matOpcode == HLSubscriptOpcode::RowMatElement);
|
|
Value *mask =
|
|
Value *mask =
|
|
matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
- // For temp matrix, all use col major.
|
|
|
|
|
|
+
|
|
if (isElement) {
|
|
if (isElement) {
|
|
Type *resultType = matSubInst->getType()->getPointerElementType();
|
|
Type *resultType = matSubInst->getType()->getPointerElementType();
|
|
unsigned resultSize = 1;
|
|
unsigned resultSize = 1;
|
|
@@ -1298,20 +1347,10 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
resultSize = resultType->getVectorNumElements();
|
|
resultSize = resultType->getVectorNumElements();
|
|
|
|
|
|
std::vector<int> shufMask(resultSize);
|
|
std::vector<int> shufMask(resultSize);
|
|
- if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(mask)) {
|
|
|
|
- unsigned count = elts->getNumElements();
|
|
|
|
- for (unsigned i = 0; i < count; i += 2) {
|
|
|
|
- unsigned rowIdx = elts->getElementAsInteger(i);
|
|
|
|
- unsigned colIdx = elts->getElementAsInteger(i + 1);
|
|
|
|
- unsigned matIdx = GetMatIdx(rowIdx, colIdx, row);
|
|
|
|
- shufMask[i>>1] = matIdx;
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- else {
|
|
|
|
- ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(mask);
|
|
|
|
- unsigned size = zeros->getNumElements()>>1;
|
|
|
|
- for (unsigned i=0;i<size;i++)
|
|
|
|
- shufMask[i] = 0;
|
|
|
|
|
|
+ Constant *EltIdxs = cast<Constant>(mask);
|
|
|
|
+ for (unsigned i = 0; i < resultSize; i++) {
|
|
|
|
+ shufMask[i] =
|
|
|
|
+ cast<ConstantInt>(EltIdxs->getAggregateElement(i))->getLimitedValue();
|
|
}
|
|
}
|
|
|
|
|
|
for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
@@ -1357,21 +1396,24 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
Value *tempArray = AllocaBuilder.CreateAlloca(AT);
|
|
Value *tempArray = AllocaBuilder.CreateAlloca(AT);
|
|
Value *zero = AllocaBuilder.getInt32(0);
|
|
Value *zero = AllocaBuilder.getInt32(0);
|
|
bool isDynamicIndexing = !isa<ConstantInt>(mask);
|
|
bool isDynamicIndexing = !isa<ConstantInt>(mask);
|
|
|
|
+ SmallVector<Value *, 4> idxList;
|
|
|
|
+ for (unsigned i = 0; i < col; i++) {
|
|
|
|
+ idxList.emplace_back(
|
|
|
|
+ matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + i));
|
|
|
|
+ }
|
|
|
|
|
|
for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
for (Value::use_iterator CallUI = matSubInst->use_begin(),
|
|
CallE = matSubInst->use_end();
|
|
CallE = matSubInst->use_end();
|
|
CallUI != CallE;) {
|
|
CallUI != CallE;) {
|
|
Use &CallUse = *CallUI++;
|
|
Use &CallUse = *CallUI++;
|
|
Instruction *CallUser = cast<Instruction>(CallUse.getUser());
|
|
Instruction *CallUser = cast<Instruction>(CallUse.getUser());
|
|
- Value *idx = mask;
|
|
|
|
IRBuilder<> Builder(CallUser);
|
|
IRBuilder<> Builder(CallUser);
|
|
Value *vecLd = Builder.CreateLoad(vecInst);
|
|
Value *vecLd = Builder.CreateLoad(vecInst);
|
|
if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
|
|
if (LoadInst *ld = dyn_cast<LoadInst>(CallUser)) {
|
|
Value *sub = UndefValue::get(ld->getType());
|
|
Value *sub = UndefValue::get(ld->getType());
|
|
if (!isDynamicIndexing) {
|
|
if (!isDynamicIndexing) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
- // col major: matIdx = c * row + r;
|
|
|
|
- Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
|
|
|
|
|
|
+ Value *matIdx = idxList[i];
|
|
Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
|
|
Value *valElt = Builder.CreateExtractElement(vecLd, matIdx);
|
|
sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
}
|
|
}
|
|
@@ -1385,8 +1427,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
Builder.CreateStore(Elt, Ptr);
|
|
Builder.CreateStore(Elt, Ptr);
|
|
}
|
|
}
|
|
for (unsigned i = 0; i < col; i++) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
- // col major: matIdx = c * row + r;
|
|
|
|
- Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
|
|
|
|
|
|
+ Value *matIdx = idxList[i];
|
|
Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
Value *valElt = Builder.CreateLoad(Ptr);
|
|
Value *valElt = Builder.CreateLoad(Ptr);
|
|
sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
sub = Builder.CreateInsertElement(sub, valElt, i);
|
|
@@ -1397,8 +1438,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
Value *val = st->getValueOperand();
|
|
Value *val = st->getValueOperand();
|
|
if (!isDynamicIndexing) {
|
|
if (!isDynamicIndexing) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
- // col major: matIdx = c * row + r;
|
|
|
|
- Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
|
|
|
|
|
|
+ Value *matIdx = idxList[i];
|
|
Value *valElt = Builder.CreateExtractElement(val, i);
|
|
Value *valElt = Builder.CreateExtractElement(val, i);
|
|
vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
|
|
vecLd = Builder.CreateInsertElement(vecLd, valElt, matIdx);
|
|
}
|
|
}
|
|
@@ -1413,8 +1453,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
}
|
|
}
|
|
// Update array.
|
|
// Update array.
|
|
for (unsigned i = 0; i < col; i++) {
|
|
for (unsigned i = 0; i < col; i++) {
|
|
- // col major: matIdx = c * row + r;
|
|
|
|
- Value *matIdx = Builder.CreateAdd(idx, Builder.getInt32(i * row));
|
|
|
|
|
|
+ Value *matIdx = idxList[i];
|
|
Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
Value *Ptr = Builder.CreateGEP(tempArray, { zero, matIdx});
|
|
Value *valElt = Builder.CreateExtractElement(val, i);
|
|
Value *valElt = Builder.CreateExtractElement(val, i);
|
|
Builder.CreateStore(valElt, Ptr);
|
|
Builder.CreateStore(valElt, Ptr);
|
|
@@ -1430,17 +1469,8 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
Builder.CreateStore(vecLd, vecInst);
|
|
Builder.CreateStore(vecLd, vecInst);
|
|
} else if (GetElementPtrInst *GEP =
|
|
} else if (GetElementPtrInst *GEP =
|
|
dyn_cast<GetElementPtrInst>(CallUser)) {
|
|
dyn_cast<GetElementPtrInst>(CallUser)) {
|
|
- // Must be for subscript on vector
|
|
|
|
- auto idxIter = GEP->idx_begin();
|
|
|
|
- // skip the zero
|
|
|
|
- idxIter++;
|
|
|
|
- Value *gepIdx = *idxIter;
|
|
|
|
- // Col major matIdx = r + c * row; r is idx, c is gepIdx
|
|
|
|
- Value *iMulRow = Builder.CreateMul(gepIdx, Builder.getInt32(row));
|
|
|
|
- Value *vecIdx = Builder.CreateAdd(iMulRow, idx);
|
|
|
|
-
|
|
|
|
- llvm::Constant *zero = llvm::ConstantInt::get(vecIdx->getType(), 0);
|
|
|
|
- Value *NewGEP = Builder.CreateGEP(vecInst, {zero, vecIdx});
|
|
|
|
|
|
+ Value *GEPOffset = HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
|
|
|
|
+ Value *NewGEP = Builder.CreateGEP(vecInst, {zero, GEPOffset});
|
|
GEP->replaceAllUsesWith(NewGEP);
|
|
GEP->replaceAllUsesWith(NewGEP);
|
|
} else
|
|
} else
|
|
DXASSERT(0, "matrix subscript should only used by load/store.");
|
|
DXASSERT(0, "matrix subscript should only used by load/store.");
|
|
@@ -1458,7 +1488,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
|
|
Value *matGlobal, ArrayRef<Value *> vecGlobals,
|
|
Value *matGlobal, ArrayRef<Value *> vecGlobals,
|
|
CallInst *matLdStInst) {
|
|
CallInst *matLdStInst) {
|
|
// No dynamic indexing on matrix, flatten matrix to scalars.
|
|
// No dynamic indexing on matrix, flatten matrix to scalars.
|
|
- // vecGlobals already in col major.
|
|
|
|
|
|
+ // vecGlobals already in correct major.
|
|
Type *matType = matGlobal->getType()->getPointerElementType();
|
|
Type *matType = matGlobal->getType()->getPointerElementType();
|
|
unsigned col, row;
|
|
unsigned col, row;
|
|
HLMatrixLower::GetMatrixInfo(matType, col, row);
|
|
HLMatrixLower::GetMatrixInfo(matType, col, row);
|
|
@@ -1491,7 +1521,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
|
|
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
|
|
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
|
|
GlobalVariable *scalarArrayGlobal,
|
|
GlobalVariable *scalarArrayGlobal,
|
|
CallInst *matLdStInst) {
|
|
CallInst *matLdStInst) {
|
|
- // vecGlobals already in col major.
|
|
|
|
|
|
+ // vecGlobals already in correct major.
|
|
const bool bColMajor = true;
|
|
const bool bColMajor = true;
|
|
HLMatLoadStoreOpcode opcode =
|
|
HLMatLoadStoreOpcode opcode =
|
|
static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
|
|
static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
|
|
@@ -1513,7 +1543,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
|
|
}
|
|
}
|
|
|
|
|
|
Value *newVec =
|
|
Value *newVec =
|
|
- HLMatrixLower::BuildMatrix(EltTy, col, row, bColMajor, matElts, Builder);
|
|
|
|
|
|
+ HLMatrixLower::BuildVector(EltTy, col * row, matElts, Builder);
|
|
matLdStInst->replaceAllUsesWith(newVec);
|
|
matLdStInst->replaceAllUsesWith(newVec);
|
|
matLdStInst->eraseFromParent();
|
|
matLdStInst->eraseFromParent();
|
|
} break;
|
|
} break;
|
|
@@ -1540,9 +1570,10 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
|
|
} break;
|
|
} break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
|
|
|
|
- GlobalVariable *scalarArrayGlobal,
|
|
|
|
- CallInst *matSubInst) {
|
|
|
|
|
|
+void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(
|
|
|
|
+ CallInst *matSubInst, Value *vecPtr) {
|
|
|
|
+ Value *basePtr =
|
|
|
|
+ matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
IRBuilder<> subBuilder(matSubInst);
|
|
IRBuilder<> subBuilder(matSubInst);
|
|
Value *zeroIdx = subBuilder.getInt32(0);
|
|
Value *zeroIdx = subBuilder.getInt32(0);
|
|
@@ -1550,46 +1581,32 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
|
|
HLSubscriptOpcode opcode =
|
|
HLSubscriptOpcode opcode =
|
|
static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
|
|
static_cast<HLSubscriptOpcode>(GetHLOpcode(matSubInst));
|
|
|
|
|
|
- Type *matTy = matGlobal->getType()->getPointerElementType();
|
|
|
|
|
|
+ Type *matTy = basePtr->getType()->getPointerElementType();
|
|
unsigned col, row;
|
|
unsigned col, row;
|
|
HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
|
|
|
- std::vector<Value *> Ptrs;
|
|
|
|
|
|
+ std::vector<Value *> idxList;
|
|
switch (opcode) {
|
|
switch (opcode) {
|
|
case HLSubscriptOpcode::ColMatSubscript:
|
|
case HLSubscriptOpcode::ColMatSubscript:
|
|
case HLSubscriptOpcode::RowMatSubscript: {
|
|
case HLSubscriptOpcode::RowMatSubscript: {
|
|
- // Use col major for internal matrix.
|
|
|
|
- // And subscripts will return a row.
|
|
|
|
|
|
+ // Just use index created in EmitHLSLMatrixSubscript.
|
|
for (unsigned c = 0; c < col; c++) {
|
|
for (unsigned c = 0; c < col; c++) {
|
|
- Value *colIdxBase = subBuilder.getInt32(c * row);
|
|
|
|
- Value *matIdx = subBuilder.CreateAdd(colIdxBase, idx);
|
|
|
|
- Value *Ptr =
|
|
|
|
- subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
|
|
|
|
- Ptrs.emplace_back(Ptr);
|
|
|
|
|
|
+ Value *matIdx =
|
|
|
|
+ matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx + c);
|
|
|
|
+ idxList.emplace_back(matIdx);
|
|
}
|
|
}
|
|
} break;
|
|
} break;
|
|
case HLSubscriptOpcode::RowMatElement:
|
|
case HLSubscriptOpcode::RowMatElement:
|
|
case HLSubscriptOpcode::ColMatElement: {
|
|
case HLSubscriptOpcode::ColMatElement: {
|
|
- // Use col major for internal matrix.
|
|
|
|
- if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
|
|
|
|
- unsigned count = elts->getNumElements();
|
|
|
|
-
|
|
|
|
- for (unsigned i = 0; i < count; i += 2) {
|
|
|
|
- unsigned rowIdx = elts->getElementAsInteger(i);
|
|
|
|
- unsigned colIdx = elts->getElementAsInteger(i + 1);
|
|
|
|
- Value *matIdx = subBuilder.getInt32(colIdx * row + rowIdx);
|
|
|
|
- Value *Ptr =
|
|
|
|
- subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, matIdx});
|
|
|
|
- Ptrs.emplace_back(Ptr);
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
- ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
|
|
|
|
- unsigned size = zeros->getNumElements() >> 1;
|
|
|
|
- for (unsigned i = 0; i < size; i++) {
|
|
|
|
- Value *Ptr =
|
|
|
|
- subBuilder.CreateInBoundsGEP(scalarArrayGlobal, {zeroIdx, zeroIdx});
|
|
|
|
- Ptrs.emplace_back(Ptr);
|
|
|
|
- }
|
|
|
|
|
|
+ Type *resultType = matSubInst->getType()->getPointerElementType();
|
|
|
|
+ unsigned resultSize = 1;
|
|
|
|
+ if (resultType->isVectorTy())
|
|
|
|
+ resultSize = resultType->getVectorNumElements();
|
|
|
|
+ // Just use index created in EmitHLSLMatrixElement.
|
|
|
|
+ Constant *EltIdxs = cast<Constant>(idx);
|
|
|
|
+ for (unsigned i = 0; i < resultSize; i++) {
|
|
|
|
+ Value *matIdx = EltIdxs->getAggregateElement(i);
|
|
|
|
+ idxList.emplace_back(matIdx);
|
|
}
|
|
}
|
|
} break;
|
|
} break;
|
|
default:
|
|
default:
|
|
@@ -1599,90 +1616,54 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
|
|
|
|
|
|
// Cannot generate vector pointer
|
|
// Cannot generate vector pointer
|
|
// Replace all uses with scalar pointers.
|
|
// Replace all uses with scalar pointers.
|
|
- if (Ptrs.size() == 1)
|
|
|
|
- matSubInst->replaceAllUsesWith(Ptrs[0]);
|
|
|
|
- else {
|
|
|
|
|
|
+ if (idxList.size() == 1) {
|
|
|
|
+ Value *Ptr =
|
|
|
|
+ subBuilder.CreateInBoundsGEP(vecPtr, {zeroIdx, idxList[0]});
|
|
|
|
+ matSubInst->replaceAllUsesWith(Ptr);
|
|
|
|
+ } else {
|
|
// Split the use of CI with Ptrs.
|
|
// Split the use of CI with Ptrs.
|
|
for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
|
|
for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
|
|
Instruction *subsUser = cast<Instruction>(*(U++));
|
|
Instruction *subsUser = cast<Instruction>(*(U++));
|
|
IRBuilder<> userBuilder(subsUser);
|
|
IRBuilder<> userBuilder(subsUser);
|
|
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
|
|
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
|
|
- DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
|
|
|
|
- Value *baseIdx = (GEP->idx_begin())->get();
|
|
|
|
- DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
|
|
|
|
- Value *idx = (GEP->idx_begin() + 1)->get();
|
|
|
|
|
|
+ Value *IndexPtr =
|
|
|
|
+ HLMatrixLower::LowerGEPOnMatIndexListToIndex(GEP, idxList);
|
|
|
|
+ Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
|
+ {zeroIdx, IndexPtr});
|
|
for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
|
|
for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
|
|
Instruction *gepUser = cast<Instruction>(*(gepU++));
|
|
Instruction *gepUser = cast<Instruction>(*(gepU++));
|
|
IRBuilder<> gepUserBuilder(gepUser);
|
|
IRBuilder<> gepUserBuilder(gepUser);
|
|
if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
|
|
if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
|
|
Value *subData = stUser->getValueOperand();
|
|
Value *subData = stUser->getValueOperand();
|
|
- if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
|
|
|
|
- Value *Ptr = Ptrs[immIdx->getSExtValue()];
|
|
|
|
- gepUserBuilder.CreateStore(subData, Ptr);
|
|
|
|
- } else {
|
|
|
|
- // Create a temp array.
|
|
|
|
- IRBuilder<> allocaBuilder(stUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
|
|
|
|
- Value *tempArray = allocaBuilder.CreateAlloca(
|
|
|
|
- ArrayType::get(subData->getType(), Ptrs.size()));
|
|
|
|
- // Store value to temp array.
|
|
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
|
|
|
|
- Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
|
|
|
|
- gepUserBuilder.CreateStore(Elt, EltGEP);
|
|
|
|
- }
|
|
|
|
- // Dynamic indexing.
|
|
|
|
- Value *subGEP =
|
|
|
|
- gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
|
|
|
|
- gepUserBuilder.CreateStore(subData, subGEP);
|
|
|
|
- // Store temp array to value.
|
|
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
|
|
|
|
- Value *Elt = gepUserBuilder.CreateLoad(EltGEP);
|
|
|
|
- gepUserBuilder.CreateStore(Elt, Ptrs[i]);
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
|
|
+ gepUserBuilder.CreateStore(subData, Ptr);
|
|
stUser->eraseFromParent();
|
|
stUser->eraseFromParent();
|
|
- } else {
|
|
|
|
- // Must be load here;
|
|
|
|
- LoadInst *ldUser = cast<LoadInst>(gepUser);
|
|
|
|
- Value *subData = nullptr;
|
|
|
|
- if (ConstantInt *immIdx = dyn_cast<ConstantInt>(idx)) {
|
|
|
|
- Value *Ptr = Ptrs[immIdx->getSExtValue()];
|
|
|
|
- subData = gepUserBuilder.CreateLoad(Ptr);
|
|
|
|
- } else {
|
|
|
|
- // Create a temp array.
|
|
|
|
- IRBuilder<> allocaBuilder(ldUser->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
|
|
|
|
- Value *tempArray = allocaBuilder.CreateAlloca(
|
|
|
|
- ArrayType::get(ldUser->getType(), Ptrs.size()));
|
|
|
|
- // Store value to temp array.
|
|
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *Elt = gepUserBuilder.CreateLoad(Ptrs[i]);
|
|
|
|
- Value *EltGEP = gepUserBuilder.CreateGEP(tempArray, {zeroIdx, gepUserBuilder.getInt32(i)} );
|
|
|
|
- gepUserBuilder.CreateStore(Elt, EltGEP);
|
|
|
|
- }
|
|
|
|
- // Dynamic indexing.
|
|
|
|
- Value *subGEP =
|
|
|
|
- gepUserBuilder.CreateInBoundsGEP(tempArray, {zeroIdx, idx});
|
|
|
|
- subData = gepUserBuilder.CreateLoad(subGEP);
|
|
|
|
- }
|
|
|
|
|
|
+ } else if (LoadInst *ldUser = dyn_cast<LoadInst>(gepUser)) {
|
|
|
|
+ Value *subData = gepUserBuilder.CreateLoad(Ptr);
|
|
ldUser->replaceAllUsesWith(subData);
|
|
ldUser->replaceAllUsesWith(subData);
|
|
ldUser->eraseFromParent();
|
|
ldUser->eraseFromParent();
|
|
|
|
+ } else {
|
|
|
|
+ AddrSpaceCastInst *Cast = cast<AddrSpaceCastInst>(gepUser);
|
|
|
|
+ Cast->setOperand(0, Ptr);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
GEP->eraseFromParent();
|
|
GEP->eraseFromParent();
|
|
} else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
|
|
} else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
|
|
Value *val = stUser->getValueOperand();
|
|
Value *val = stUser->getValueOperand();
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
|
|
+ for (unsigned i = 0; i < idxList.size(); i++) {
|
|
Value *Elt = userBuilder.CreateExtractElement(val, i);
|
|
Value *Elt = userBuilder.CreateExtractElement(val, i);
|
|
- userBuilder.CreateStore(Elt, Ptrs[i]);
|
|
|
|
|
|
+ Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
|
+ {zeroIdx, idxList[i]});
|
|
|
|
+ userBuilder.CreateStore(Elt, Ptr);
|
|
}
|
|
}
|
|
stUser->eraseFromParent();
|
|
stUser->eraseFromParent();
|
|
} else {
|
|
} else {
|
|
|
|
|
|
Value *ldVal =
|
|
Value *ldVal =
|
|
UndefValue::get(matSubInst->getType()->getPointerElementType());
|
|
UndefValue::get(matSubInst->getType()->getPointerElementType());
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
|
|
|
|
|
|
+ for (unsigned i = 0; i < idxList.size(); i++) {
|
|
|
|
+ Value *Ptr = userBuilder.CreateInBoundsGEP(vecPtr,
|
|
|
|
+ {zeroIdx, idxList[i]});
|
|
|
|
+ Value *Elt = userBuilder.CreateLoad(Ptr);
|
|
ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
|
|
ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
|
|
}
|
|
}
|
|
// Must be load here.
|
|
// Must be load here.
|
|
@@ -1695,128 +1676,9 @@ void HLMatrixLowerPass::TranslateMatSubscriptOnGlobal(GlobalVariable *matGlobal,
|
|
matSubInst->eraseFromParent();
|
|
matSubInst->eraseFromParent();
|
|
}
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr) {
|
|
|
|
- // Just translate into vec array here.
|
|
|
|
- // DynamicIndexingVectorToArray will change it to scalar array.
|
|
|
|
- Value *basePtr = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptMatOpIdx);
|
|
|
|
- IRBuilder<> subBuilder(matSubInst);
|
|
|
|
- unsigned opcode = hlsl::GetHLOpcode(matSubInst);
|
|
|
|
- HLSubscriptOpcode subOp = static_cast<HLSubscriptOpcode>(opcode);
|
|
|
|
-
|
|
|
|
- // Vector array is inside struct.
|
|
|
|
- Value *zeroIdx = subBuilder.getInt32(0);
|
|
|
|
- Value *vecArrayGep = vecPtr;
|
|
|
|
-
|
|
|
|
- Type *matType = basePtr->getType()->getPointerElementType();
|
|
|
|
- unsigned col, row;
|
|
|
|
- HLMatrixLower::GetMatrixInfo(matType, col, row);
|
|
|
|
-
|
|
|
|
- Value *idx = matSubInst->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx);
|
|
|
|
-
|
|
|
|
- std::vector<Value *> Ptrs;
|
|
|
|
- switch (subOp) {
|
|
|
|
- case HLSubscriptOpcode::ColMatSubscript:
|
|
|
|
- case HLSubscriptOpcode::RowMatSubscript: {
|
|
|
|
- // Vector array is row major.
|
|
|
|
- // And subscripts will return a row.
|
|
|
|
- Value *rowIdx = idx;
|
|
|
|
- Value *subPtr = subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx});
|
|
|
|
- matSubInst->replaceAllUsesWith(subPtr);
|
|
|
|
- matSubInst->eraseFromParent();
|
|
|
|
- return;
|
|
|
|
- } break;
|
|
|
|
- case HLSubscriptOpcode::RowMatElement:
|
|
|
|
- case HLSubscriptOpcode::ColMatElement: {
|
|
|
|
- // Vector array is row major.
|
|
|
|
- if (ConstantDataSequential *elts = dyn_cast<ConstantDataSequential>(idx)) {
|
|
|
|
- unsigned count = elts->getNumElements();
|
|
|
|
-
|
|
|
|
- for (unsigned i = 0; i < count; i += 2) {
|
|
|
|
- Value *rowIdx = subBuilder.getInt32(elts->getElementAsInteger(i));
|
|
|
|
- Value *colIdx = subBuilder.getInt32(elts->getElementAsInteger(i + 1));
|
|
|
|
- Value *Ptr =
|
|
|
|
- subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, rowIdx, colIdx});
|
|
|
|
- Ptrs.emplace_back(Ptr);
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
- ConstantAggregateZero *zeros = cast<ConstantAggregateZero>(idx);
|
|
|
|
- unsigned size = zeros->getNumElements() >> 1;
|
|
|
|
- for (unsigned i = 0; i < size; i++) {
|
|
|
|
- Value *Ptr =
|
|
|
|
- subBuilder.CreateInBoundsGEP(vecArrayGep, {zeroIdx, zeroIdx, zeroIdx});
|
|
|
|
- Ptrs.emplace_back(Ptr);
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- } break;
|
|
|
|
- default:
|
|
|
|
- DXASSERT(0, "invalid operation for TranslateMatSubscriptOnGlobalPtr");
|
|
|
|
- break;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if (Ptrs.size() == 1)
|
|
|
|
- matSubInst->replaceAllUsesWith(Ptrs[0]);
|
|
|
|
- else {
|
|
|
|
- // Split the use of CI with Ptrs.
|
|
|
|
- for (auto U = matSubInst->user_begin(); U != matSubInst->user_end();) {
|
|
|
|
- Instruction *subsUser = cast<Instruction>(*(U++));
|
|
|
|
- IRBuilder<> userBuilder(subsUser);
|
|
|
|
- if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
|
|
|
|
- DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
|
|
|
|
- Value *baseIdx = (GEP->idx_begin())->get();
|
|
|
|
- DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
|
|
|
|
- Value *idx = (GEP->idx_begin() + 1)->get();
|
|
|
|
- for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
|
|
|
|
- Instruction *gepUser = cast<Instruction>(*(gepU++));
|
|
|
|
- IRBuilder<> gepUserBuilder(gepUser);
|
|
|
|
- if (StoreInst *stUser = dyn_cast<StoreInst>(gepUser)) {
|
|
|
|
- Value *subData = stUser->getValueOperand();
|
|
|
|
- // Only element can reach here.
|
|
|
|
- // So index must be imm.
|
|
|
|
- ConstantInt *immIdx = cast<ConstantInt>(idx);
|
|
|
|
- Value *Ptr = Ptrs[immIdx->getSExtValue()];
|
|
|
|
- gepUserBuilder.CreateStore(subData, Ptr);
|
|
|
|
- stUser->eraseFromParent();
|
|
|
|
- } else {
|
|
|
|
- // Must be load here;
|
|
|
|
- LoadInst *ldUser = cast<LoadInst>(gepUser);
|
|
|
|
- Value *subData = nullptr;
|
|
|
|
- // Only element can reach here.
|
|
|
|
- // So index must be imm.
|
|
|
|
- ConstantInt *immIdx = cast<ConstantInt>(idx);
|
|
|
|
- Value *Ptr = Ptrs[immIdx->getSExtValue()];
|
|
|
|
- subData = gepUserBuilder.CreateLoad(Ptr);
|
|
|
|
- ldUser->replaceAllUsesWith(subData);
|
|
|
|
- ldUser->eraseFromParent();
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- GEP->eraseFromParent();
|
|
|
|
- } else if (StoreInst *stUser = dyn_cast<StoreInst>(subsUser)) {
|
|
|
|
- Value *val = stUser->getValueOperand();
|
|
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *Elt = userBuilder.CreateExtractElement(val, i);
|
|
|
|
- userBuilder.CreateStore(Elt, Ptrs[i]);
|
|
|
|
- }
|
|
|
|
- stUser->eraseFromParent();
|
|
|
|
- } else {
|
|
|
|
- // Must be load here.
|
|
|
|
- LoadInst *ldUser = cast<LoadInst>(subsUser);
|
|
|
|
- // reload the value.
|
|
|
|
- Value *ldVal =
|
|
|
|
- UndefValue::get(matSubInst->getType()->getPointerElementType());
|
|
|
|
- for (unsigned i = 0; i < Ptrs.size(); i++) {
|
|
|
|
- Value *Elt = userBuilder.CreateLoad(Ptrs[i]);
|
|
|
|
- ldVal = userBuilder.CreateInsertElement(ldVal, Elt, i);
|
|
|
|
- }
|
|
|
|
- ldUser->replaceAllUsesWith(ldVal);
|
|
|
|
- ldUser->eraseFromParent();
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- matSubInst->eraseFromParent();
|
|
|
|
-}
|
|
|
|
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
|
|
void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
|
|
CallInst *matLdStInst, Value *vecPtr) {
|
|
CallInst *matLdStInst, Value *vecPtr) {
|
|
- // Just translate into vec array here.
|
|
|
|
|
|
+ // Just translate into vector here.
|
|
// DynamicIndexingVectorToArray will change it to scalar array.
|
|
// DynamicIndexingVectorToArray will change it to scalar array.
|
|
IRBuilder<> Builder(matLdStInst);
|
|
IRBuilder<> Builder(matLdStInst);
|
|
unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
|
|
unsigned opcode = hlsl::GetHLOpcode(matLdStInst);
|
|
@@ -1824,56 +1686,19 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobalPtr(
|
|
switch (matLdStOp) {
|
|
switch (matLdStOp) {
|
|
case HLMatLoadStoreOpcode::ColMatLoad:
|
|
case HLMatLoadStoreOpcode::ColMatLoad:
|
|
case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
- // Load as vector array.
|
|
|
|
|
|
+ // Load as vector.
|
|
Value *newLoad = Builder.CreateLoad(vecPtr);
|
|
Value *newLoad = Builder.CreateLoad(vecPtr);
|
|
- // Then change to vector.
|
|
|
|
- // Use col major.
|
|
|
|
- Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
|
- Type *matTy = Ptr->getType()->getPointerElementType();
|
|
|
|
- unsigned col, row;
|
|
|
|
- HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
|
- Value *NewVal = UndefValue::get(matLdStInst->getType());
|
|
|
|
- // Vector array is row major.
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- Value *eltRow = Builder.CreateExtractValue(newLoad, r);
|
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
|
- Value *elt = Builder.CreateExtractElement(eltRow, c);
|
|
|
|
- // Vector is col major.
|
|
|
|
- unsigned matIdx = c * row + r;
|
|
|
|
- NewVal = Builder.CreateInsertElement(NewVal, elt, matIdx);
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- matLdStInst->replaceAllUsesWith(NewVal);
|
|
|
|
|
|
+
|
|
|
|
+ matLdStInst->replaceAllUsesWith(newLoad);
|
|
matLdStInst->eraseFromParent();
|
|
matLdStInst->eraseFromParent();
|
|
} break;
|
|
} break;
|
|
case HLMatLoadStoreOpcode::ColMatStore:
|
|
case HLMatLoadStoreOpcode::ColMatStore:
|
|
case HLMatLoadStoreOpcode::RowMatStore: {
|
|
case HLMatLoadStoreOpcode::RowMatStore: {
|
|
// Change value to vector array, then store.
|
|
// Change value to vector array, then store.
|
|
Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
- Value *Ptr = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
|
-
|
|
|
|
- Type *matTy = Ptr->getType()->getPointerElementType();
|
|
|
|
- unsigned col, row;
|
|
|
|
- Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
|
|
|
|
- Type *rowTy = VectorType::get(EltTy, row);
|
|
|
|
|
|
|
|
Value *vecArrayGep = vecPtr;
|
|
Value *vecArrayGep = vecPtr;
|
|
-
|
|
|
|
- Value *NewVal =
|
|
|
|
- UndefValue::get(vecArrayGep->getType()->getPointerElementType());
|
|
|
|
-
|
|
|
|
- // Vector array val is row major.
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- Value *NewElt = UndefValue::get(rowTy);
|
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
|
- // Vector val is col major.
|
|
|
|
- unsigned matIdx = c * row + r;
|
|
|
|
- Value *elt = Builder.CreateExtractElement(Val, matIdx);
|
|
|
|
- NewElt = Builder.CreateInsertElement(NewElt, elt, c);
|
|
|
|
- }
|
|
|
|
- NewVal = Builder.CreateInsertValue(NewVal, NewElt, r);
|
|
|
|
- }
|
|
|
|
- Builder.CreateStore(NewVal, vecArrayGep);
|
|
|
|
|
|
+ Builder.CreateStore(Val, vecArrayGep);
|
|
matLdStInst->eraseFromParent();
|
|
matLdStInst->eraseFromParent();
|
|
} break;
|
|
} break;
|
|
default:
|
|
default:
|
|
@@ -1928,7 +1753,7 @@ static void IterateInitList(MutableArrayRef<Value *> elts, unsigned &idx,
|
|
HLMatrixLower::GetMatrixInfo(valTy, col, row);
|
|
HLMatrixLower::GetMatrixInfo(valTy, col, row);
|
|
unsigned matSize = col * row;
|
|
unsigned matSize = col * row;
|
|
val = matToVecMap[cast<Instruction>(val)];
|
|
val = matToVecMap[cast<Instruction>(val)];
|
|
- // temp matrix all col major
|
|
|
|
|
|
+ // temp matrix all row major
|
|
for (unsigned i = 0; i < matSize; i++) {
|
|
for (unsigned i = 0; i < matSize; i++) {
|
|
Value *Elt = Builder.CreateExtractElement(val, i);
|
|
Value *Elt = Builder.CreateExtractElement(val, i);
|
|
elts[idx + i] = Elt;
|
|
elts[idx + i] = Elt;
|
|
@@ -2005,15 +1830,13 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
|
|
}
|
|
}
|
|
|
|
|
|
Value *newInit = UndefValue::get(vecTy);
|
|
Value *newInit = UndefValue::get(vecTy);
|
|
- // InitList is row major, the result is col major.
|
|
|
|
- for (unsigned c = 0; c < col; c++)
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- unsigned rowMajorIdx = r * col + c;
|
|
|
|
- unsigned colMajorIdx = c * row + r;
|
|
|
|
- Constant *vecIdx = Builder.getInt32(colMajorIdx);
|
|
|
|
- newInit = InsertElementInst::Create(newInit, elts[rowMajorIdx], vecIdx);
|
|
|
|
|
|
+ // InitList is row major, the result is row major too.
|
|
|
|
+ for (unsigned i=0;i< col * row;i++) {
|
|
|
|
+ Constant *vecIdx = Builder.getInt32(i);
|
|
|
|
+ newInit = InsertElementInst::Create(newInit, elts[i], vecIdx);
|
|
Builder.Insert(cast<Instruction>(newInit));
|
|
Builder.Insert(cast<Instruction>(newInit));
|
|
- }
|
|
|
|
|
|
+ }
|
|
|
|
+
|
|
// Replace matInit function call with matInitInst.
|
|
// Replace matInit function call with matInitInst.
|
|
DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
|
|
DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
|
|
Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
|
|
Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
|
|
@@ -2247,25 +2070,36 @@ static bool OnlyUsedByMatrixLdSt(Value *V) {
|
|
return onlyLdSt;
|
|
return onlyLdSt;
|
|
}
|
|
}
|
|
|
|
|
|
-static Constant *LowerMatrixArrayConst(Constant *MA, ArrayType *ResultTy) {
|
|
|
|
- if (ArrayType *AT = dyn_cast<ArrayType>(MA->getType())) {
|
|
|
|
|
|
+static Constant *LowerMatrixArrayConst(Constant *MA, Type *ResultTy) {
|
|
|
|
+ if (ArrayType *AT = dyn_cast<ArrayType>(ResultTy)) {
|
|
std::vector<Constant *> Elts;
|
|
std::vector<Constant *> Elts;
|
|
- ArrayType *EltResultTy = cast<ArrayType>(ResultTy->getElementType());
|
|
|
|
|
|
+ Type *EltResultTy = AT->getElementType();
|
|
for (unsigned i = 0; i < AT->getNumElements(); i++) {
|
|
for (unsigned i = 0; i < AT->getNumElements(); i++) {
|
|
Constant *Elt =
|
|
Constant *Elt =
|
|
LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
|
|
LowerMatrixArrayConst(MA->getAggregateElement(i), EltResultTy);
|
|
Elts.emplace_back(Elt);
|
|
Elts.emplace_back(Elt);
|
|
}
|
|
}
|
|
- return ConstantArray::get(ResultTy, Elts);
|
|
|
|
|
|
+ return ConstantArray::get(AT, Elts);
|
|
} else {
|
|
} else {
|
|
|
|
+ // Cast float[row][col] -> float< row * col>.
|
|
// Get float[row][col] from the struct.
|
|
// Get float[row][col] from the struct.
|
|
- return MA->getAggregateElement((unsigned)0);
|
|
|
|
|
|
+ Constant *rows = MA->getAggregateElement((unsigned)0);
|
|
|
|
+ ArrayType *RowAT = cast<ArrayType>(rows->getType());
|
|
|
|
+ std::vector<Constant *> Elts;
|
|
|
|
+ for (unsigned r=0;r<RowAT->getArrayNumElements();r++) {
|
|
|
|
+ Constant *row = rows->getAggregateElement(r);
|
|
|
|
+ VectorType *VT = cast<VectorType>(row->getType());
|
|
|
|
+ for (unsigned c = 0; c < VT->getVectorNumElements(); c++) {
|
|
|
|
+ Elts.emplace_back(row->getAggregateElement(c));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return ConstantVector::get(Elts);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
|
|
void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
|
|
- // Lower to array of vector array like float[row][col].
|
|
|
|
- // It's row major.
|
|
|
|
|
|
+ // Lower to array of vector array like float[row * col].
|
|
|
|
+ // It's follow the major of decl.
|
|
// DynamicIndexingVectorToArray will change it to scalar array.
|
|
// DynamicIndexingVectorToArray will change it to scalar array.
|
|
Type *Ty = GV->getType()->getPointerElementType();
|
|
Type *Ty = GV->getType()->getPointerElementType();
|
|
std::vector<unsigned> arraySizeList;
|
|
std::vector<unsigned> arraySizeList;
|
|
@@ -2275,8 +2109,7 @@ void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
|
|
}
|
|
}
|
|
unsigned row, col;
|
|
unsigned row, col;
|
|
Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
Type *EltTy = GetMatrixInfo(Ty, col, row);
|
|
- Ty = VectorType::get(EltTy, col);
|
|
|
|
- Ty = ArrayType::get(Ty, row);
|
|
|
|
|
|
+ Ty = VectorType::get(EltTy, col * row);
|
|
|
|
|
|
for (auto arraySize = arraySizeList.rbegin();
|
|
for (auto arraySize = arraySizeList.rbegin();
|
|
arraySize != arraySizeList.rend(); arraySize++)
|
|
arraySize != arraySizeList.rend(); arraySize++)
|
|
@@ -2348,12 +2181,13 @@ static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
|
|
Elts.emplace_back(Elt);
|
|
Elts.emplace_back(Elt);
|
|
} else {
|
|
} else {
|
|
M = M->getAggregateElement((unsigned)0);
|
|
M = M->getAggregateElement((unsigned)0);
|
|
- // Initializer is row major.
|
|
|
|
- // Make it col major to match temp matrix.
|
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- Constant *R = M->getAggregateElement(r);
|
|
|
|
- Elts.emplace_back(R->getAggregateElement(c));
|
|
|
|
|
|
+ // Initializer is already in correct major.
|
|
|
|
+ // Just read it here.
|
|
|
|
+ // The type is vector<element, col>[row].
|
|
|
|
+ for (unsigned r = 0; r < row; r++) {
|
|
|
|
+ Constant *C = M->getAggregateElement(r);
|
|
|
|
+ for (unsigned c = 0; c < col; c++) {
|
|
|
|
+ Elts.emplace_back(C->getAggregateElement(c));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -2441,7 +2275,7 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
|
|
}
|
|
}
|
|
else {
|
|
else {
|
|
DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
|
|
DXASSERT(group == HLOpcodeGroup::HLSubscript, "Must be subscript operation");
|
|
- TranslateMatSubscriptOnGlobal(GV, arrayMat, CI);
|
|
|
|
|
|
+ TranslateMatSubscriptOnGlobalPtr(CI, arrayMat);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
GV->removeDeadConstantUsers();
|
|
GV->removeDeadConstantUsers();
|
|
@@ -2464,6 +2298,18 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
} else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
|
|
} else if (HLMatrixLower::IsMatrixArrayPointer(AI->getType())) {
|
|
lowerToVec(&I);
|
|
lowerToVec(&I);
|
|
}
|
|
}
|
|
|
|
+ } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
|
|
|
|
+ HLOpcodeGroup group =
|
|
|
|
+ hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
|
+ if (group == HLOpcodeGroup::HLMatLoadStore) {
|
|
|
|
+ HLMatLoadStoreOpcode opcode =
|
|
|
|
+ static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
|
+ DXASSERT(opcode == HLMatLoadStoreOpcode::ColMatStore ||
|
|
|
|
+ opcode == HLMatLoadStoreOpcode::RowMatStore,
|
|
|
|
+ "Must MatStore here, load will go IsMatrixType path");
|
|
|
|
+ // Lower it here to make sure it is ready before replace.
|
|
|
|
+ lowerToVec(&I);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|