瀏覽代碼

Fix lh matrix cast on param case for RowMatrixToColMatrix and reverse

Tex Riddell 7 年之前
父節點
當前提交
2a1a3d2b99
共有 1 個文件被更改,包括 13 次插入7 次删除
  1. 13 7
      lib/HLSL/HLMatrixLowerPass.cpp

+ 13 - 7
lib/HLSL/HLMatrixLowerPass.cpp

@@ -2499,19 +2499,25 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
   switch (group) {
   case HLOpcodeGroup::HLCast: {
     HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
-    if (opcode == HLCastOpcode::RowMatrixToVecCast ||
-        opcode == HLCastOpcode::ColMatrixToVecCast) {
+    bool bTranspose = false;
+    switch (opcode) {
+    case HLCastOpcode::ColMatrixToRowMatrix:
+    case HLCastOpcode::RowMatrixToColMatrix:
+      bTranspose = true;
+    case HLCastOpcode::ColMatrixToVecCast:
+    case HLCastOpcode::RowMatrixToVecCast: {
       Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
       Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
                                         /*bOrigAllocaTy*/false,
                                         matVal->getName());
-      //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
-      //  unsigned row, col;
-      //  HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
-      //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-      //}
+      if (bTranspose) {
+        unsigned row, col;
+        HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+        vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+      }
       CI->replaceAllUsesWith(vecVal);
       CI->eraseFromParent();
+    } break;
     }
   } break;
   case HLOpcodeGroup::HLMatLoadStore: {