|
@@ -251,7 +251,7 @@ private:
|
|
|
void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
|
|
|
CallInst *castInst);
|
|
|
void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
|
|
|
- CallInst *castInst, bool rowToCol);
|
|
|
+ CallInst *castInst, bool rowToCol, bool transpose);
|
|
|
// Replace matInst with vecInst in matSubscript
|
|
|
void TranslateMatSubscript(Value *matInst, Value *vecInst,
|
|
|
CallInst *matSubInst);
|
|
@@ -1073,7 +1073,8 @@ void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
|
|
|
Instruction *vecInst,
|
|
|
CallInst *transposeInst) {
|
|
|
// Matrix value is row major, transpose is cast it to col major.
|
|
|
- TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
|
|
|
+ TranslateMatMajorCast(matInst, vecInst, transposeInst,
|
|
|
+ /*bRowToCol*/ true, /*bTranspose*/ true);
|
|
|
}
|
|
|
|
|
|
static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
|
|
@@ -1194,10 +1195,22 @@ void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
|
|
|
void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
|
|
|
Instruction *vecInst,
|
|
|
CallInst *castInst,
|
|
|
- bool bRowToCol) {
|
|
|
+ bool bRowToCol,
|
|
|
+ bool bTranspose) {
|
|
|
unsigned col, row;
|
|
|
- GetMatrixInfo(castInst->getType(), col, row);
|
|
|
- DXASSERT(castInst->getType() == matInst->getType(), "type must match");
|
|
|
+ if (!bTranspose) {
|
|
|
+ GetMatrixInfo(castInst->getType(), col, row);
|
|
|
+ DXASSERT(castInst->getType() == matInst->getType(), "type must match");
|
|
|
+ } else {
|
|
|
+ unsigned castCol, castRow;
|
|
|
+ Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
|
|
|
+ unsigned srcCol, srcRow;
|
|
|
+ Type *srcTy = GetMatrixInfo(matInst->getType(), srcCol, srcRow);
|
|
|
+ DXASSERT(srcTy == castTy, "type must match");
|
|
|
+ DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
|
|
|
+ col = srcCol;
|
|
|
+ row = srcRow;
|
|
|
+ }
|
|
|
|
|
|
IRBuilder<> Builder(castInst);
|
|
|
|
|
@@ -1321,7 +1334,8 @@ void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
|
|
|
if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
|
|
|
opcode == HLCastOpcode::RowMatrixToColMatrix) {
|
|
|
TranslateMatMajorCast(matInst, vecInst, castInst,
|
|
|
- opcode == HLCastOpcode::RowMatrixToColMatrix);
|
|
|
+ opcode == HLCastOpcode::RowMatrixToColMatrix,
|
|
|
+ /*bTranspose*/false);
|
|
|
} else {
|
|
|
bool ToMat = IsMatrixType(castInst->getType());
|
|
|
bool FromMat = IsMatrixType(matInst->getType());
|