|
|
@@ -269,6 +269,9 @@ private:
|
|
|
void TranslateMatSubscriptOnGlobalPtr(CallInst *matSubInst, Value *vecPtr);
|
|
|
void TranslateMatLoadStoreOnGlobalPtr(CallInst *matLdStInst, Value *vecPtr);
|
|
|
|
|
|
+ // Get new matrix value corresponding to vecVal
|
|
|
+ Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
|
|
+
|
|
|
// Replace matVal with vecVal on matUseInst.
|
|
|
void TrivialMatReplace(Value *matVal, Value *vecVal,
|
|
|
CallInst *matUseInst);
|
|
|
@@ -282,6 +285,10 @@ private:
|
|
|
void DeleteDeadInsts();
|
|
|
// Map from matrix value to its vector version.
|
|
|
DenseMap<Value *, Value *> matToVecMap;
|
|
|
+ // Map from new vector version to matrix version needed by user call or return.
|
|
|
+ DenseMap<Value *, Value *> vecToMatMap;
|
|
|
+ // Record matrix defining instructions that need preserving (in library functions).
|
|
|
+ std::vector<Instruction*> matInstsToKeep;
|
|
|
};
|
|
|
}
|
|
|
|
|
|
@@ -841,6 +848,20 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
|
case HLOpcodeGroup::HLSubscript: {
|
|
|
vecVal = MatSubscriptToVec(CI);
|
|
|
} break;
|
|
|
+ case HLOpcodeGroup::NotHL: {
|
|
|
+ // Translate user function return
|
|
|
+ vecVal = BitCastValueOrPtr( matInst,
|
|
|
+ matInst->getNextNode(),
|
|
|
+ HLMatrixLower::LowerMatrixType(matInst->getType()),
|
|
|
+ /*bOrigAllocaTy*/ false,
|
|
|
+ matInst->getName());
|
|
|
+ // matrix equivalent of this new vector will be the original, retained user call
|
|
|
+ vecToMatMap[vecVal] = matInst;
|
|
|
+ // Add to matInstsToKeep so we don't delete this call
|
|
|
+ matInstsToKeep.push_back(matInst);
|
|
|
+ } break;
|
|
|
+ default:
|
|
|
+ DXASSERT(0, "invalid inst");
|
|
|
}
|
|
|
} else if (AllocaInst *AI = dyn_cast<AllocaInst>(matInst)) {
|
|
|
Type *Ty = AI->getAllocatedType();
|
|
|
@@ -2069,6 +2090,23 @@ void HLMatrixLowerPass::TranslateMatArrayGEP(Value *matInst,
|
|
|
AddToDeadInsts(matGEP);
|
|
|
}
|
|
|
|
|
|
+Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
|
|
|
+ Value *newMatVal = nullptr;
|
|
|
+ if (vecToMatMap.count(vecVal)) {
|
|
|
+ newMatVal = vecToMatMap[vecVal];
|
|
|
+ } else {
|
|
|
+ // create conversion instructions if necessary, caching result for subsequent replacements.
|
|
|
+ // do so right after the vecVal def so it's available to all potential uses.
|
|
|
+ newMatVal = BitCastValueOrPtr(vecVal,
|
|
|
+ cast<Instruction>(vecVal)->getNextNode(), // vecVal must be instruction
|
|
|
+ matTy,
|
|
|
+ /*bOrigAllocaTy*/true,
|
|
|
+ vecVal->getName());
|
|
|
+ vecToMatMap[vecVal] = newMatVal;
|
|
|
+ }
|
|
|
+ return newMatVal;
|
|
|
+}
|
|
|
+
|
|
|
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
Value *vecVal) {
|
|
|
for (Value::user_iterator user = matVal->user_begin();
|
|
|
@@ -2140,10 +2178,24 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
DXASSERT(!isa<AllocaInst>(matVal), "array of matrix init should lowered in StoreInitListToDestPtr at CGHLSLMS.cpp");
|
|
|
TranslateMatInit(useCall);
|
|
|
} break;
|
|
|
+ case HLOpcodeGroup::NotHL: {
|
|
|
+ // translate user function parameters as necessary
|
|
|
+ for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
|
|
|
+ if (useCall->getArgOperand(i) == matVal) {
|
|
|
+ // update the user call with the correct matrix value in new code sequence
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
+ if (matVal != newMatVal)
|
|
|
+ useCall->setArgOperand(i, newMatVal);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } break;
|
|
|
}
|
|
|
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
|
|
// Just replace the src with vec version.
|
|
|
useInst->setOperand(0, vecVal);
|
|
|
+ } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
+ RI->setOperand(0, newMatVal);
|
|
|
} else {
|
|
|
// Must be GEP on mat array alloca.
|
|
|
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
|
|
@@ -2462,6 +2514,11 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
finalMatTranslation(matToVec->first);
|
|
|
}
|
|
|
|
|
|
+ // Remove matInstsToKeep from matToVecMap before adding the rest to dead insts.
|
|
|
+ for (auto I : matInstsToKeep) {
|
|
|
+ matToVecMap.erase(I);
|
|
|
+ }
|
|
|
+
|
|
|
// Delete the matrix version insts.
|
|
|
for (auto matToVecIter = matToVecMap.begin();
|
|
|
matToVecIter != matToVecMap.end();) {
|