|
@@ -2499,19 +2499,25 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
|
|
switch (group) {
|
|
switch (group) {
|
|
case HLOpcodeGroup::HLCast: {
|
|
case HLOpcodeGroup::HLCast: {
|
|
HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
- if (opcode == HLCastOpcode::RowMatrixToVecCast ||
|
|
|
|
- opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
|
|
|
+ bool bTranspose = false;
|
|
|
|
+ switch (opcode) {
|
|
|
|
+ case HLCastOpcode::ColMatrixToRowMatrix:
|
|
|
|
+ case HLCastOpcode::RowMatrixToColMatrix:
|
|
|
|
+ bTranspose = true;
|
|
|
|
+ case HLCastOpcode::ColMatrixToVecCast:
|
|
|
|
+ case HLCastOpcode::RowMatrixToVecCast: {
|
|
Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
|
|
Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
|
|
/*bOrigAllocaTy*/false,
|
|
/*bOrigAllocaTy*/false,
|
|
matVal->getName());
|
|
matVal->getName());
|
|
- //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
|
- // unsigned row, col;
|
|
|
|
- // HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
|
- // vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
- //}
|
|
|
|
|
|
+ if (bTranspose) {
|
|
|
|
+ unsigned row, col;
|
|
|
|
+ HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
|
+ vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
+ }
|
|
CI->replaceAllUsesWith(vecVal);
|
|
CI->replaceAllUsesWith(vecVal);
|
|
CI->eraseFromParent();
|
|
CI->eraseFromParent();
|
|
|
|
+ } break;
|
|
}
|
|
}
|
|
} break;
|
|
} break;
|
|
case HLOpcodeGroup::HLMatLoadStore: {
|
|
case HLOpcodeGroup::HLMatLoadStore: {
|