2
0
Эх сурвалжийг харах

Remove transpose during library arg translation of matrix load/store
(should be present already if necessary)

Tex Riddell 7 жил өмнө
parent
commit
456a78da3a

+ 21 - 19
lib/HLSL/HLMatrixLowerPass.cpp

@@ -2467,6 +2467,8 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
 void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
   // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
   // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
+  // NOTE: Transpose has been removed, as it should have explicit cast op
+  //       ColMatrixToRowMatrix or RowMatrixToColMatrix.
   for (auto &arg : F.args()) {
     SmallVector<CallInst *, 4> Candidates;
     for (User *U : arg.users()) {
@@ -2492,47 +2494,47 @@ void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
           Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
                                             /*bOrigAllocaTy*/false,
                                             matVal->getName());
-          if (opcode == HLCastOpcode::ColMatrixToVecCast) {
-            unsigned row, col;
-            HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
-            vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-          }
+          //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
+          //  unsigned row, col;
+          //  HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+          //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+          //}
           CI->replaceAllUsesWith(vecVal);
           CI->eraseFromParent();
         }
       } break;
       case HLOpcodeGroup::HLMatLoadStore: {
         HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
-        bool bTranspose = false;
+        //bool bTranspose = false;
         switch (opcode) {
         case HLMatLoadStoreOpcode::ColMatStore:
-          bTranspose = true;
+          //bTranspose = true;
         case HLMatLoadStoreOpcode::RowMatStore: {
           // shuffle if transposed, bitcast, and store
           Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
           Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-          if (bTranspose) {
-            unsigned row, col;
-            HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-            vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-          }
+          //if (bTranspose) {
+          //  unsigned row, col;
+          //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+          //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+          //}
           Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
           Builder.CreateStore(vecVal, castPtr);
           CI->eraseFromParent();
         } break;
         case HLMatLoadStoreOpcode::ColMatLoad:
-          bTranspose = true;
+          //bTranspose = true;
         case HLMatLoadStoreOpcode::RowMatLoad: {
           // bitcast, load, and shuffle if transposed
           Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
           Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
           Value *vecVal = Builder.CreateLoad(castPtr);
-          if (bTranspose) {
-            unsigned row, col;
-            HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-            // row/col swapped for col major source
-            vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
-          }
+          //if (bTranspose) {
+          //  unsigned row, col;
+          //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+          //  // row/col swapped for col major source
+          //  vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
+          //}
           CI->replaceAllUsesWith(vecVal);
           CI->eraseFromParent();
         } break;