소스 검색

HLMatrixLower: Handle unflattened lib function matrix return val and param.

Tex Riddell 7 년 전
부모
커밋
2ae113596a
1개의 변경된 파일57개의 추가작업 그리고 0개의 파일을 삭제
  1. 57 0
      lib/HLSL/HLMatrixLowerPass.cpp

+ 57 - 0
lib/HLSL/HLMatrixLowerPass.cpp

@@ -269,6 +269,9 @@ private:
   void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
   void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
 
+  // Get new matrix value corresponding to vecVal
+  Value *GetMatrixForVec(Value *vecVal, Type *matTy);
+
   // Replace matVal with vecVal on matUseInst.
   void TrivialMatReplace(Value *matVal, Value *vecVal,
                         CallInst *matUseInst);
@@ -282,6 +285,10 @@ private:
   void DeleteDeadInsts();
   // Map from matrix value to its vector version.
   DenseMap<Value *, Value *> matToVecMap;
+  // Map from new vector version to matrix version needed by user call or return.
+  DenseMap<Value *, Value *> vecToMatMap;
+  // Record matrix defining instructions that need preserving (in library functions).
+  std::vector<Instruction*> matInstsToKeep;
 };
 }
 
@@ -841,6 +848,20 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
     case HLOpcodeGroup::HLSubscript: {
       vecVal = MatSubscriptToVec(CI);
     } break;
+    case HLOpcodeGroup::NotHL: {
+      // Translate user function return
+      vecVal = BitCastValueOrPtr( matInst,
+                                  matInst->getNextNode(),
+                                  HLMatrixLower::LowerMatrixType(matInst->getType()),
+                                  /*bOrigAllocaTy*/ false,
+                                  matInst->getName());
+      // matrix equivalent of this new vector will be the original, retained user call
+      vecToMatMap[vecVal] = matInst;
+      // Add to matInstsToKeep so we don't delete this call
+      matInstsToKeep.push_back(matInst);
+    } break;
+    default:
+      DXASSERT(0, "invalid inst");
     }
   } else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
     Type *Ty = AI->getAllocatedType();
@@ -2069,6 +2090,23 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
   AddToDeadInsts(matGEP);
 }
 
+Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
+  Value *newMatVal = nullptr;
+  if (vecToMatMap.count(vecVal)) {
+    newMatVal = vecToMatMap[vecVal];
+  } else {
+    // create conversion instructions if necessary, caching result for subsequent replacements.
+    // do so right after the vecVal def so it's available to all potential uses.
+    newMatVal = BitCastValueOrPtr(vecVal,
+      cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
+      matTy,
+      /*bOrigAllocaTy*/true,
+      vecVal->getName());
+    vecToMatMap[vecVal] = newMatVal;
+  }
+  return newMatVal;
+}
+
 void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
                                           Value *vecVal) {
   for (Value::user_iterator user = matVal->user_begin();
@@ -2140,10 +2178,24 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
         DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
         TranslateMatInit(useCall);
       } break;
+      case HLOpcodeGroup::NotHL: {
+        // translate user function parameters as necessary
+        for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
+          if (useCall->getArgOperand(i) == matVal) {
+            // update the user call with the correct matrix value in new code sequence
+            Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
+            if (matVal != newMatVal)
+              useCall->setArgOperand(i, newMatVal);
+          }
+        }
+      } break;
       }
     } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
       // Just replace the src with vec version.
       useInst->setOperand(0, vecVal);
+    } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
+      Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
+      RI->setOperand(0, newMatVal);
     } else {
       // Must be GEP on mat array alloca.
       GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
@@ -2462,6 +2514,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
       finalMatTranslation(matToVec->first);
   }
 
+  // Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
+  for (auto I : matInstsToKeep) {
+    matToVecMap.erase(I);
+  }
+
   // Delete the matrix version insts.
   for (auto matToVecIter = matToVecMap.begin();
        matToVecIter != matToVecMap.end();) {