|
@@ -283,6 +283,8 @@ private:
|
|
|
void lowerToVec(Instruction *matInst);
|
|
|
// Lower users of a matrix type instruction.
|
|
|
void replaceMatWithVec(Value *matVal, Value *vecVal);
|
|
|
+ // Translate user library function call arguments
|
|
|
+ void castMatrixArgs(Instruction *I);
|
|
|
// Translate mat inst which need all operands ready.
|
|
|
void finalMatTranslation(Value *matVal);
|
|
|
// Delete dead insts in m_deadInsts.
|
|
@@ -291,8 +293,6 @@ private:
|
|
|
DenseMap<Value *, Value *> matToVecMap;
|
|
|
// Map from new vector version to matrix version needed by user call or return.
|
|
|
DenseMap<Value *, Value *> vecToMatMap;
|
|
|
- // Record matrix defining instructions that need preserving (in library functions).
|
|
|
- std::vector<Instruction*> matInstsToKeep;
|
|
|
};
|
|
|
}
|
|
|
|
|
@@ -878,8 +878,6 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
|
matInst->getName());
|
|
|
// matrix equivalent of this new vector will be the original, retained user call
|
|
|
vecToMatMap[vecVal] = matInst;
|
|
|
- // Add to matInstsToKeep so we don't delete this call
|
|
|
- matInstsToKeep.push_back(matInst);
|
|
|
} break;
|
|
|
default:
|
|
|
DXASSERT(0, "invalid inst");
|
|
@@ -919,8 +917,6 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
|
matInst->getType()->getPointerElementType() )->getPointerTo());
|
|
|
// matrix equivalent of this new vector will be the original, retained GEP
|
|
|
vecToMatMap[vecVal] = matInst;
|
|
|
- // Add to matInstsToKeep so we don't delete this GEP
|
|
|
- matInstsToKeep.push_back(matInst);
|
|
|
} else {
|
|
|
DXASSERT(0, "invalid inst");
|
|
|
}
|
|
@@ -992,7 +988,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
|
|
|
Value *vecVal,
|
|
|
CallInst *mulInst, bool isSigned) {
|
|
|
(void)(matVal); // Unused; retrieved from matToVecMap directly
|
|
|
- DXASSERT(matToVecMap.count(mulInst), "must has vec version");
|
|
|
+ DXASSERT(matToVecMap.count(mulInst), "must have vec version");
|
|
|
Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
|
|
|
// Already translated.
|
|
|
if (!isa<CallInst>(vecUseInst))
|
|
@@ -1009,7 +1005,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(Value *matVal,
|
|
|
bool isFloat = EltTy->isFloatingPointTy();
|
|
|
|
|
|
Value *retVal = llvm::UndefValue::get(LowerMatrixType(mulInst->getType()));
|
|
|
- IRBuilder<> Builder(mulInst);
|
|
|
+ IRBuilder<> Builder(vecUseInst);
|
|
|
|
|
|
Value *lMat = matToVecMap[cast<Instruction>(LVal)];
|
|
|
Value *rMat = matToVecMap[cast<Instruction>(RVal)];
|
|
@@ -1330,15 +1326,16 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
|
|
|
row = srcRow;
|
|
|
}
|
|
|
|
|
|
- IRBuilder<> Builder(castInst);
|
|
|
+ DXASSERT(matToVecMap.count(castInst), "must have vec version");
|
|
|
+ Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
+ // Create before vecUseInst to prevent instructions being inserted after uses.
|
|
|
+ IRBuilder<> Builder(vecUseInst);
|
|
|
|
|
|
if (bRowToCol)
|
|
|
std::swap(row, col);
|
|
|
Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
|
|
// Replace vec cast function call with vecCast.
|
|
|
- DXASSERT(matToVecMap.count(castInst), "must has vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
vecUseInst->replaceAllUsesWith(vecCast);
|
|
|
AddToDeadInsts(vecUseInst);
|
|
|
matToVecMap[castInst] = vecCast;
|
|
@@ -1354,8 +1351,10 @@ void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
|
|
|
unsigned fromSize = fromCol * fromRow;
|
|
|
unsigned toSize = toCol * toRow;
|
|
|
DXASSERT(fromSize >= toSize, "cannot extend matrix");
|
|
|
+ DXASSERT(matToVecMap.count(castInst), "must have vec version");
|
|
|
+ Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
|
|
|
- IRBuilder<> Builder(castInst);
|
|
|
+ IRBuilder<> Builder(vecUseInst);
|
|
|
Instruction *vecCast = nullptr;
|
|
|
|
|
|
HLCastOpcode opcode = static_cast<HLCastOpcode>(GetHLOpcode(castInst));
|
|
@@ -1383,8 +1382,6 @@ void HLMatrixLowerPass::TranslateMatMatCast(Value *matVal,
|
|
|
vecCast = shuf;
|
|
|
}
|
|
|
// Replace vec cast function call with vecCast.
|
|
|
- DXASSERT(matToVecMap.count(castInst), "must has vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[castInst]);
|
|
|
vecUseInst->replaceAllUsesWith(vecCast);
|
|
|
AddToDeadInsts(vecUseInst);
|
|
|
matToVecMap[castInst] = vecCast;
|
|
@@ -1633,7 +1630,7 @@ void HLMatrixLowerPass::TranslateMatSubscript(Value *matVal, Value *vecVal,
|
|
|
}
|
|
|
// Check vec version.
|
|
|
DXASSERT(matToVecMap.count(matSubInst) == 0, "should not have vec version");
|
|
|
- // All the user should has been removed.
|
|
|
+ // All the user should have been removed.
|
|
|
matSubInst->replaceAllUsesWith(UndefValue::get(matSubInst->getType()));
|
|
|
AddToDeadInsts(matSubInst);
|
|
|
}
|
|
@@ -1968,7 +1965,9 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
|
|
|
if (matInitInst->getType()->isVoidTy())
|
|
|
return;
|
|
|
|
|
|
- IRBuilder<> Builder(matInitInst);
|
|
|
+ DXASSERT(matToVecMap.count(matInitInst), "must have vec version");
|
|
|
+ Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
|
|
|
+ IRBuilder<> Builder(vecUseInst);
|
|
|
unsigned col, row;
|
|
|
Type *EltTy = GetMatrixInfo(matInitInst->getType(), col, row);
|
|
|
|
|
@@ -1992,15 +1991,12 @@ void HLMatrixLowerPass::TranslateMatInit(CallInst *matInitInst) {
|
|
|
}
|
|
|
|
|
|
// Replace matInit function call with matInitInst.
|
|
|
- DXASSERT(matToVecMap.count(matInitInst), "must has vec version");
|
|
|
- Instruction *vecUseInst = cast<Instruction>(matToVecMap[matInitInst]);
|
|
|
vecUseInst->replaceAllUsesWith(newInit);
|
|
|
AddToDeadInsts(vecUseInst);
|
|
|
matToVecMap[matInitInst] = newInit;
|
|
|
}
|
|
|
|
|
|
void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
|
|
|
- IRBuilder<> Builder(matSelectInst);
|
|
|
unsigned col, row;
|
|
|
Type *EltTy = GetMatrixInfo(matSelectInst->getType(), col, row);
|
|
|
|
|
@@ -2011,6 +2007,8 @@ void HLMatrixLowerPass::TranslateMatSelect(CallInst *matSelectInst) {
|
|
|
Instruction *LHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc1Idx));
|
|
|
Instruction *RHS = cast<Instruction>(matSelectInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc2Idx));
|
|
|
|
|
|
+ IRBuilder<> Builder(vecUseInst);
|
|
|
+
|
|
|
Value *Cond = vecUseInst->getArgOperand(HLOperandIndex::kTrinaryOpSrc0Idx);
|
|
|
bool isVecCond = Cond->getType()->isVectorTy();
|
|
|
if (isVecCond) {
|
|
@@ -2070,7 +2068,7 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
|
|
|
if (useCall->getType()->isVectorTy())
|
|
|
continue;
|
|
|
Value *newLd = Builder.CreateLoad(newGEP);
|
|
|
- DXASSERT(matToVecMap.count(useCall), "must has vec version");
|
|
|
+ DXASSERT(matToVecMap.count(useCall), "must have vec version");
|
|
|
Value *oldLd = matToVecMap[useCall];
|
|
|
// Delete the oldLd.
|
|
|
AddToDeadInsts(cast<Instruction>(oldLd));
|
|
@@ -2090,7 +2088,7 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
|
|
|
|
|
|
Instruction *matInst = cast<Instruction>(matVal);
|
|
|
|
|
|
- DXASSERT(matToVecMap.count(matInst), "must has vec version");
|
|
|
+ DXASSERT(matToVecMap.count(matInst), "must have vec version");
|
|
|
Value *vecVal = matToVecMap[matInst];
|
|
|
Builder.CreateStore(vecVal, vecPtr);
|
|
|
} break;
|
|
@@ -2134,7 +2132,6 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
|
|
|
|
|
|
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
Value *vecVal) {
|
|
|
- Type *matTy = matVal->getType();
|
|
|
for (Value::user_iterator user = matVal->user_begin();
|
|
|
user != matVal->user_end();) {
|
|
|
Instruction *useInst = cast<Instruction>(*(user++));
|
|
@@ -2178,7 +2175,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
TranslateMatCast(matVal, vecVal, useCall);
|
|
|
} break;
|
|
|
case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
- DXASSERT(matToVecMap.count(useCall), "must has vec version");
|
|
|
+ DXASSERT(matToVecMap.count(useCall), "must have vec version");
|
|
|
Value *vecUser = matToVecMap[useCall];
|
|
|
if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
|
|
|
// Load Already translated in lowerToVec.
|
|
@@ -2202,15 +2199,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
TranslateMatInit(useCall);
|
|
|
} break;
|
|
|
case HLOpcodeGroup::NotHL: {
|
|
|
- // translate user function parameters as necessary
|
|
|
- for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
|
|
|
- if (useCall->getArgOperand(i) == matVal) {
|
|
|
- // update the user call with the correct matrix value in new code sequence
|
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, matTy);
|
|
|
- if (matVal != newMatVal)
|
|
|
- useCall->setArgOperand(i, newMatVal);
|
|
|
- }
|
|
|
- }
|
|
|
+ castMatrixArgs(useCall);
|
|
|
} break;
|
|
|
}
|
|
|
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
|
@@ -2218,7 +2207,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
if (useInst != vecVal)
|
|
|
useInst->setOperand(0, vecVal);
|
|
|
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, matTy);
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
RI->setOperand(0, newMatVal);
|
|
|
} else if (isa<StoreInst>(useInst)) {
|
|
|
DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
|
|
@@ -2231,6 +2220,22 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+void HLMatrixLowerPass::castMatrixArgs(Instruction *I) {
|
|
|
+ // translate user function parameters as necessary
|
|
|
+ for (unsigned i = 0; i < I->getNumOperands(); i++) {
|
|
|
+ Value *argVal = I->getOperand(i);
|
|
|
+ Type *argTy = argVal->getType();
|
|
|
+ if (argTy->isPointerTy())
|
|
|
+ argTy->getPointerElementType();
|
|
|
+ if (argTy->isStructTy() && IsMatrixType(argTy)) {
|
|
|
+ Value *vecVal = matToVecMap[argVal];
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, argVal->getType());
|
|
|
+ if (argVal != newMatVal)
|
|
|
+ I->setOperand(i, newMatVal);
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
void HLMatrixLowerPass::finalMatTranslation(Value *matVal) {
|
|
|
// Translate matInit.
|
|
|
if (CallInst *CI = dyn_cast<CallInst>(matVal)) {
|
|
@@ -2500,8 +2505,10 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
|
|
|
case HLOpcodeGroup::HLCast: {
|
|
|
HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
bool bTranspose = false;
|
|
|
+ bool bColSource = false;
|
|
|
switch (opcode) {
|
|
|
case HLCastOpcode::ColMatrixToRowMatrix:
|
|
|
+ bColSource = true;
|
|
|
case HLCastOpcode::RowMatrixToColMatrix:
|
|
|
bTranspose = true;
|
|
|
case HLCastOpcode::ColMatrixToVecCast:
|
|
@@ -2513,6 +2520,7 @@ void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
|
|
|
if (bTranspose) {
|
|
|
unsigned row, col;
|
|
|
HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
+ if (bColSource) std::swap(row,col);
|
|
|
vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
}
|
|
|
CI->replaceAllUsesWith(vecVal);
|
|
@@ -2649,9 +2657,9 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
finalMatTranslation(matToVec->first);
|
|
|
}
|
|
|
|
|
|
- // Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
|
|
|
- for (auto I : matInstsToKeep) {
|
|
|
- matToVecMap.erase(I);
|
|
|
+ // Remove matrix targets of vecToMatMap from matToVecMap before adding the rest to dead insts.
|
|
|
+ for (auto &it : vecToMatMap) {
|
|
|
+ matToVecMap.erase(it.second);
|
|
|
}
|
|
|
|
|
|
// Delete the matrix version insts.
|
|
@@ -2659,8 +2667,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
matToVecIter != matToVecMap.end();) {
|
|
|
auto matToVec = matToVecIter++;
|
|
|
// Add to m_deadInsts.
|
|
|
- Instruction *matInst = cast<Instruction>(matToVec->first);
|
|
|
- AddToDeadInsts(matInst);
|
|
|
+ if (Instruction *matInst = dyn_cast<Instruction>(matToVec->first))
|
|
|
+ AddToDeadInsts(matInst);
|
|
|
}
|
|
|
|
|
|
DeleteDeadInsts();
|