|
@@ -273,7 +273,8 @@ private:
|
|
|
Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
|
|
|
|
|
// Translate library function input/output to preserve function signatures
|
|
|
- void TranslateLibraryArgs(Function &F);
|
|
|
+ void TranslateArgForLibFunc(CallInst *CI);
|
|
|
+ void TranslateArgsForLibFunc(Function &F);
|
|
|
|
|
|
// Replace matVal with vecVal on matUseInst.
|
|
|
void TrivialMatReplace(Value *matVal, Value *vecVal,
|
|
@@ -414,6 +415,23 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
|
|
|
return MatIntrinsicToVec(CI);
|
|
|
}
|
|
|
|
|
|
+// Returns V as a GEP when V is a GetElementPtrInst whose result element type is a matrix and whose pointer operand is an alloca of a non-matrix struct (UDT); otherwise returns nullptr.
|
|
|
+// The UDT alloca must be there for library function args.
|
|
|
+static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
|
|
|
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
|
|
|
+ if (IsMatrixType(GEP->getResultElementType())) { // the GEP must address a matrix
|
|
|
+ Value *ptr = GEP->getPointerOperand();
|
|
|
+ if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) { // only GEPs taken directly from an alloca qualify
|
|
|
+ Type *ATy = AI->getAllocatedType();
|
|
|
+ if (ATy->isStructTy() && !IsMatrixType(ATy)) { // allocated type is a struct that is not itself a matrix
|
|
|
+ return GEP;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nullptr; // V did not match the pattern
|
|
|
+}
|
|
|
+
|
|
|
Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
IRBuilder<> Builder(CI);
|
|
|
unsigned opcode = GetHLOpcode(CI);
|
|
@@ -423,7 +441,7 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr)) {
|
|
|
+ if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
|
|
|
Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
result = Builder.CreateLoad(vecPtr);
|
|
|
} else
|
|
@@ -432,7 +450,7 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
|
|
|
case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
- if (isa<AllocaInst>(matPtr)) {
|
|
|
+ if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
|
|
|
Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
|
|
|
Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
Value *vecVal =
|
|
@@ -893,6 +911,16 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
|
|
|
if (HLModule::HasPreciseAttributeWithMetadata(AI))
|
|
|
HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
|
|
|
|
|
|
+ } else if (GetIfMatrixGEPOfUDTAlloca(matInst)) {
|
|
|
+ // If GEP from alloca of non-matrix UDT, bitcast
|
|
|
+ IRBuilder<> Builder(matInst->getNextNode());
|
|
|
+ vecVal = Builder.CreateBitCast(matInst,
|
|
|
+ HLMatrixLower::LowerMatrixType(
|
|
|
+ matInst->getType()->getPointerElementType() )->getPointerTo());
|
|
|
+ // matrix equivalent of this new vector will be the original, retained GEP
|
|
|
+ vecToMatMap[vecVal] = matInst;
|
|
|
+ // Add to matInstsToKeep so we don't delete this GEP
|
|
|
+ matInstsToKeep.push_back(matInst);
|
|
|
} else {
|
|
|
DXASSERT(0, "invalid inst");
|
|
|
}
|
|
@@ -2152,7 +2180,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
DXASSERT(matToVecMap.count(useCall), "must has vec version");
|
|
|
Value *vecUser = matToVecMap[useCall];
|
|
|
- if (AllocaInst *AI = dyn_cast<AllocaInst>(matVal)) {
|
|
|
+ if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
|
|
|
// Load Already translated in lowerToVec.
|
|
|
// Store val operand will be set by the val use.
|
|
|
// Do nothing here.
|
|
@@ -2187,7 +2215,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
|
}
|
|
|
} else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
|
|
|
// Just replace the src with vec version.
|
|
|
- useInst->setOperand(0, vecVal);
|
|
|
+ if (useInst != vecVal)
|
|
|
+ useInst->setOperand(0, vecVal);
|
|
|
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
|
Value *newMatVal = GetMatrixForVec(vecVal, matTy);
|
|
|
RI->setOperand(0, newMatVal);
|
|
@@ -2464,82 +2493,98 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
|
|
|
+void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) { // Rewrites the HL call CI: matrix-to-vector casts become BitCastValueOrPtr, matrix loads/stores become bitcast + load/store; CI is erased in each handled case.
|
|
|
+ IRBuilder<> Builder(CI); // insertion point is CI, so new instructions go before it
|
|
|
+ HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
+ switch (group) { // groups other than HLCast and HLMatLoadStore are not handled here
|
|
|
+ case HLOpcodeGroup::HLCast: {
|
|
|
+ HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
+ if (opcode == HLCastOpcode::RowMatrixToVecCast ||
|
|
|
+ opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
+ Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
|
+ Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
|
|
|
+ /*bOrigAllocaTy*/false,
|
|
|
+ matVal->getName());
|
|
|
+ //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
+ // unsigned row, col;
|
|
|
+ // HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
+ // vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
+ //}
|
|
|
+ CI->replaceAllUsesWith(vecVal);
|
|
|
+ CI->eraseFromParent();
|
|
|
+ } // any other cast opcode leaves CI untouched
|
|
|
+ } break;
|
|
|
+ case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
+ HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
+ //bool bTranspose = false;
|
|
|
+ switch (opcode) {
|
|
|
+ case HLMatLoadStoreOpcode::ColMatStore: // falls through to the RowMatStore handling
|
|
|
+ //bTranspose = true;
|
|
|
+ case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
+ // bitcast and store; the transpose shuffle is commented out below
|
|
|
+ Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
+ Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
+ //if (bTranspose) {
|
|
|
+ // unsigned row, col;
|
|
|
+ // HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
+ // vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
+ //}
|
|
|
+ Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo()); // view the destination as a pointer to the stored value's type
|
|
|
+ Builder.CreateStore(vecVal, castPtr);
|
|
|
+ CI->eraseFromParent(); // NOTE(review): no replaceAllUsesWith here - presumably the store call has no users; confirm
|
|
|
+ } break;
|
|
|
+ case HLMatLoadStoreOpcode::ColMatLoad: // falls through to the RowMatLoad handling
|
|
|
+ //bTranspose = true;
|
|
|
+ case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
+ // bitcast and load; the transpose shuffle is commented out below
|
|
|
+ Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
+ Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo()); // view the source as a pointer to the call's result type
|
|
|
+ Value *vecVal = Builder.CreateLoad(castPtr);
|
|
|
+ //if (bTranspose) {
|
|
|
+ // unsigned row, col;
|
|
|
+ // HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
+ // // row/col swapped for col major source
|
|
|
+ // vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
|
|
|
+ //}
|
|
|
+ CI->replaceAllUsesWith(vecVal); // users of the original load call now read the plain load
|
|
|
+ CI->eraseFromParent();
|
|
|
+ } break;
|
|
|
+ }
|
|
|
+ } break;
|
|
|
+ }
|
|
|
+}
|
|
|
+static CallInst *GetAsMatCastOrLdSt(Value* V) { // Returns V as a CallInst when it is a call whose HL opcode group is HLCast or HLMatLoadStore; otherwise returns nullptr.
|
|
|
+ if (CallInst *CI = dyn_cast<CallInst>(V)) {
|
|
|
+ HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
+ switch (group) { // every other group falls out of the switch to the nullptr return
|
|
|
+ case HLOpcodeGroup::HLCast:
|
|
|
+ case HLOpcodeGroup::HLMatLoadStore:
|
|
|
+ return CI;
|
|
|
+ break; // not reached: the return above already leaves the function
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+}
|
|
|
+void HLMatrixLowerPass::TranslateArgsForLibFunc(Function &F) {
|
|
|
// Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
|
|
|
// Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
|
|
|
// NOTE: Transpose has been removed, as it should have explicit cast op
|
|
|
// ColMatrixToRowMatrix or RowMatrixToColMatrix.
|
|
|
for (auto &arg : F.args()) {
|
|
|
- SmallVector<CallInst *, 4> Candidates;
|
|
|
- for (User *U : arg.users()) {
|
|
|
- if (CallInst *CI = dyn_cast<CallInst>(U)) {
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLCast:
|
|
|
- case HLOpcodeGroup::HLMatLoadStore:
|
|
|
- Candidates.push_back(CI);
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- for (CallInst *CI : Candidates) {
|
|
|
- IRBuilder<> Builder(CI);
|
|
|
- HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
- switch (group) {
|
|
|
- case HLOpcodeGroup::HLCast: {
|
|
|
- HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
- if (opcode == HLCastOpcode::RowMatrixToVecCast ||
|
|
|
- opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
- Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
|
- Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
|
|
|
- /*bOrigAllocaTy*/false,
|
|
|
- matVal->getName());
|
|
|
- //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
- // unsigned row, col;
|
|
|
- // HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
- // vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
- //}
|
|
|
- CI->replaceAllUsesWith(vecVal);
|
|
|
- CI->eraseFromParent();
|
|
|
- }
|
|
|
- } break;
|
|
|
- case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
- HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
- //bool bTranspose = false;
|
|
|
- switch (opcode) {
|
|
|
- case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
- //bTranspose = true;
|
|
|
- case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- // shuffle if transposed, bitcast, and store
|
|
|
- Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
- Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
- //if (bTranspose) {
|
|
|
- // unsigned row, col;
|
|
|
- // HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
- // vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
- //}
|
|
|
- Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
|
|
|
- Builder.CreateStore(vecVal, castPtr);
|
|
|
- CI->eraseFromParent();
|
|
|
- } break;
|
|
|
- case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
- //bTranspose = true;
|
|
|
- case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- // bitcast, load, and shuffle if transposed
|
|
|
- Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
- Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
|
|
|
- Value *vecVal = Builder.CreateLoad(castPtr);
|
|
|
- //if (bTranspose) {
|
|
|
- // unsigned row, col;
|
|
|
- // HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
- // // row/col swapped for col major source
|
|
|
- // vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
|
|
|
- //}
|
|
|
- CI->replaceAllUsesWith(vecVal);
|
|
|
- CI->eraseFromParent();
|
|
|
- } break;
|
|
|
+ for (auto itU = arg.user_begin(); itU != arg.user_end();) {
|
|
|
+ Value *U = *(itU++);
|
|
|
+ if (CallInst *CI = GetAsMatCastOrLdSt(U)) {
|
|
|
+ TranslateArgForLibFunc(CI);
|
|
|
+ } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
|
|
|
+ if (IsMatrixType(GEP->getResultElementType())) {
|
|
|
+ // arg is struct, GEP returns matrix ptr
|
|
|
+ // look for load/store/cast to translate in GEP users
|
|
|
+ for (auto itGEPU = GEP->user_begin(); itGEPU != GEP->user_end();) {
|
|
|
+ Value *gepU = *(itGEPU++);
|
|
|
+ if (CallInst *CI = GetAsMatCastOrLdSt(gepU))
|
|
|
+ TranslateArgForLibFunc(CI);
|
|
|
+ }
|
|
|
}
|
|
|
- } break;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -2577,6 +2622,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
// Lower it here to make sure it is ready before replace.
|
|
|
lowerToVec(&I);
|
|
|
}
|
|
|
+ } else if (GetIfMatrixGEPOfUDTAlloca(&I)) {
|
|
|
+ lowerToVec(&I);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -2616,9 +2663,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
|
vecToMatMap.clear();
|
|
|
|
|
|
// If this is a library function, now fix input/output matrix params
|
|
|
- // TODO: What about Patch Constant Shaders?
|
|
|
if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
|
|
|
- TranslateLibraryArgs(F);
|
|
|
+ TranslateArgsForLibFunc(F);
|
|
|
}
|
|
|
return;
|
|
|
}
|