瀏覽代碼

Legalize IO for lib func, HLMatrixLowerPass: fix value args, other cleanup.

Tex Riddell 7 年之前
父節點
當前提交
0c55780ce4
共有 2 個文件被更改,包括 55 次插入40 次删除
  1. 46 38
      lib/HLSL/HLMatrixLowerPass.cpp
  2. 9 2
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

+ 46 - 38
lib/HLSL/HLMatrixLowerPass.cpp

@@ -283,6 +283,8 @@ private:
   void lowerToVec(Instruction *matInst);
   // Lower users of a matrix type instruction.
   void replaceMatWithVec(Value *matVal, Value *vecVal);
+  // Translate user library function call arguments
+  void castMatrixArgs(Instruction *I);
   // Translate mat inst which need all operands ready.
   void finalMatTranslation(Value *matVal);
   // Delete dead insts in m_deadInsts.
@@ -291,8 +293,6 @@ private:
   DenseMap<Value *, Value *> matToVecMap;
   // Map from new vector version to matrix version needed by user call or return.
   DenseMap<Value *, Value *> vecToMatMap;
-  // Record matrix defining instructions that need preserving (in library functions).
-  std::vector<Instruction*> matInstsToKeep;
 };
 }
 
@@ -878,8 +878,6 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
                                   matInst->getName());
       // matrix equivalent of this new vector will be the original, retained user call
       vecToMatMap[vecVal] = matInst;
-      // Add to matInstsToKeep so we don't delete this call
-      matInstsToKeep.push_back(matInst);
     } break;
     default:
       DXASSERT(0, "invalid inst");
@@ -919,8 +917,6 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
         matInst->getType()->getPointerElementType() )->getPointerTo());
     // matrix equivalent of this new vector will be the original, retained GEP
     vecToMatMap[vecVal] = matInst;
-    // Add to matInstsToKeep so we don't delete this GEP
-    matInstsToKeep.push_back(matInst);
   } else {
     DXASSERT(0, "invalid inst");
   }
@@ -992,7 +988,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
                                            Value *vecVal,
                                            CallInst *mulInst, bool isSigned) {
   (void)(matVal); // Unused; retrieved from matToVecMap directly
-  DXASSERT(matToVecMap.count(mulInst), "must has vec version");
+  DXASSERT(matToVecMap.count(mulInst), "must have vec version");
   Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
   // Already translated.
   if (!isa<CallInst>(vecUseInst))
@@ -1009,7 +1005,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
   bool isFloat = EltTy->isFloatingPointTy();
 
   Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
-  IRBuilder<> Builder(mulInst);
+  IRBuilder<> Builder(vecUseInst);
 
   Value *lMat = matToVecMap[cast<Instruction>(LVal)];
   Value *rMat = matToVecMap[cast<Instruction>(RVal)];
@@ -1330,15 +1326,16 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
     row = srcRow;
   }
 
-  IRBuilder<> Builder(castInst);
+  DXASSERT(matToVecMap.count(castInst), "must have vec version");
+  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
+  // Create before vecUseInst to prevent instructions being inserted after uses.
+  IRBuilder<> Builder(vecUseInst);
 
   if (bRowToCol)
     std::swap(row, col);
   Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
 
   // Replace vec cast function call with vecCast.
-  DXASSERT(matToVecMap.count(castInst), "must has vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
   vecUseInst->replaceAllUsesWith(vecCast);
   AddToDeadInsts(vecUseInst);
   matToVecMap[castInst] = vecCast;
@@ -1354,8 +1351,10 @@ void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
   unsigned fromSize = fromCol * fromRow;
   unsigned toSize = toCol * toRow;
   DXASSERT(fromSize >= toSize, "cannot extend matrix");
+  DXASSERT(matToVecMap.count(castInst), "must have vec version");
+  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
 
-  IRBuilder<> Builder(castInst);
+  IRBuilder<> Builder(vecUseInst);
   Instruction *vecCast = nullptr;
 
   HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
@@ -1383,8 +1382,6 @@ void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
       vecCast = shuf;
   }
   // Replace vec cast function call with vecCast.
-  DXASSERT(matToVecMap.count(castInst), "must has vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
   vecUseInst->replaceAllUsesWith(vecCast);
   AddToDeadInsts(vecUseInst);
   matToVecMap[castInst] = vecCast;
@@ -1633,7 +1630,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
   }
   // Check vec version.
   DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
-  // All the user should has been removed.
+  // All the user should have been removed.
   matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
   AddToDeadInsts(matSubInst);
 }
@@ -1968,7 +1965,9 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
   if (matInitInst->getType()->isVoidTy())
     return;
 
-  IRBuilder<> Builder(matInitInst);
+  DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
+  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
+  IRBuilder<> Builder(vecUseInst);
   unsigned col, row;
   Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
 
@@ -1992,15 +1991,12 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
   }
 
   // Replace matInit function call with matInitInst.
-  DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
-  Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
   vecUseInst->replaceAllUsesWith(newInit);
   AddToDeadInsts(vecUseInst);
   matToVecMap[matInitInst] = newInit;
 }
 
 void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
-  IRBuilder<> Builder(matSelectInst);
   unsigned col, row;
   Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
 
@@ -2011,6 +2007,8 @@ void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
   Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
   Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
 
+  IRBuilder<> Builder(vecUseInst);
+
   Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
   bool isVecCond = Cond->getType()->isVectorTy();
   if (isVecCond) {
@@ -2070,7 +2068,7 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
           if (useCall->getType()->isVectorTy())
             continue;
           Value *newLd = Builder.CreateLoad(newGEP);
-          DXASSERT(matToVecMap.count(useCall), "must has vec version");
+          DXASSERT(matToVecMap.count(useCall), "must have vec version");
           Value *oldLd = matToVecMap[useCall];
           // Delete the oldLd.
           AddToDeadInsts(cast<Instruction>(oldLd));
@@ -2090,7 +2088,7 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
 
           Instruction *matInst = cast<Instruction>(matVal);
 
-          DXASSERT(matToVecMap.count(matInst), "must has vec version");
+          DXASSERT(matToVecMap.count(matInst), "must have vec version");
           Value *vecVal = matToVecMap[matInst];
           Builder.CreateStore(vecVal, vecPtr);
         } break;
@@ -2134,7 +2132,6 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
 
 void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
                                           Value *vecVal) {
-  Type *matTy = matVal->getType();
   for (Value::user_iterator user = matVal->user_begin();
        user != matVal->user_end();) {
     Instruction *useInst = cast<Instruction>(*(user++));
@@ -2178,7 +2175,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
         TranslateMatCast(matVal, vecVal, useCall);
       } break;
       case HLOpcodeGroup::HLMatLoadStore: {
-        DXASSERT(matToVecMap.count(useCall), "must has vec version");
+        DXASSERT(matToVecMap.count(useCall), "must have vec version");
         Value *vecUser = matToVecMap[useCall];
         if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
           // Load Already translated in lowerToVec.
@@ -2202,15 +2199,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
         TranslateMatInit(useCall);
       } break;
       case HLOpcodeGroup::NotHL: {
-        // translate user function parameters as necessary
-        for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
-          if (useCall->getArgOperand(i) == matVal) {
-            // update the user call with the correct matrix value in new code sequence
-            Value *newMatVal = GetMatrixForVec(vecVal, matTy);
-            if (matVal != newMatVal)
-              useCall->setArgOperand(i, newMatVal);
-          }
-        }
+        castMatrixArgs(useCall);
       } break;
       }
     } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
@@ -2218,7 +2207,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       if (useInst != vecVal)
         useInst->setOperand(0, vecVal);
     } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
-      Value *newMatVal = GetMatrixForVec(vecVal, matTy);
+      Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
       RI->setOperand(0, newMatVal);
     } else if (isa<StoreInst>(useInst)) {
       DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
@@ -2231,6 +2220,22 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
   }
 }
 
+void HLMatrixLowerPass::castMatrixArgs(Instruction *I) {
+  // translate user function parameters as necessary
+  for (unsigned i = 0; i < I->getNumOperands(); i++) {
+    Value *argVal = I->getOperand(i);
+    Type *argTy = argVal->getType();
+    if (argTy->isPointerTy())
+      argTy->getPointerElementType();
+    if (argTy->isStructTy() && IsMatrixType(argTy)) {
+      Value *vecVal = matToVecMap[argVal];
+      Value *newMatVal = GetMatrixForVec(vecVal, argVal->getType());
+      if (argVal != newMatVal)
+        I->setOperand(i, newMatVal);
+    }
+  }
+}
+
 void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
   // Translate matInit.
   if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
@@ -2500,8 +2505,10 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
   case HLOpcodeGroup::HLCast: {
     HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
     bool bTranspose = false;
+    bool bColSource = false;
     switch (opcode) {
     case HLCastOpcode::ColMatrixToRowMatrix:
+      bColSource = true;
     case HLCastOpcode::RowMatrixToColMatrix:
       bTranspose = true;
     case HLCastOpcode::ColMatrixToVecCast:
@@ -2513,6 +2520,7 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
       if (bTranspose) {
         unsigned row, col;
         HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+        if (bColSource) std::swap(row,col);
         vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
       }
       CI->replaceAllUsesWith(vecVal);
@@ -2649,9 +2657,9 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
       finalMatTranslation(matToVec->first);
   }
 
-  // Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
-  for (auto I : matInstsToKeep) {
-    matToVecMap.erase(I);
+  // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
+  for (auto &it : vecToMatMap) {
+    matToVecMap.erase(it.second);
   }
 
   // Delete the matrix version insts.
@@ -2659,8 +2667,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
        matToVecIter != matToVecMap.end();) {
     auto matToVec = matToVecIter++;
     // Add to m_deadInsts.
-    Instruction *matInst = cast<Instruction>(matToVec->first);
-    AddToDeadInsts(matInst);
+    if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
+      AddToDeadInsts(matInst);
   }
 
   DeleteDeadInsts();

+ 9 - 2
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -3935,6 +3935,10 @@ bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
 // SROA on function parameters.
 //===----------------------------------------------------------------------===//
 
+static void LegalizeDxilInputOutputs(Function *F,
+  DxilFunctionAnnotation *EntryAnnotation,
+  DxilTypeSystem &typeSys);
+
 namespace {
 class SROA_Parameter_HLSL : public ModulePass {
   HLModule *m_pHLModule;
@@ -3975,10 +3979,13 @@ public:
       if (F.getReturnType()->isVoidTy() && F.arg_size() == 0)
         continue;
 
-      // Skip library functions
+      // Skip library function, except to LegalizeDxilInputOutputs
       if (&F != m_pHLModule->GetEntryFunction() &&
-          !m_pHLModule->IsEntryThatUsesSignatures(&F))
+          !m_pHLModule->IsEntryThatUsesSignatures(&F)) {
+        if (!F.isDeclaration())
+          LegalizeDxilInputOutputs(&F, m_pHLModule->GetFunctionAnnotation(&F), m_pHLModule->GetTypeSystem());
         continue;
+      }
 
       WorkList.emplace_back(&F);
     }