|
@@ -272,6 +272,9 @@ private:
|
|
// Get new matrix value corresponding to vecVal
|
|
// Get new matrix value corresponding to vecVal
|
|
Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
|
Value *GetMatrixForVec(Value *vecVal, Type *matTy);
|
|
|
|
|
|
|
|
+ // Translate library function input/output to preserve function signatures
|
|
|
|
+ void TranslateLibraryArgs(Function &F);
|
|
|
|
+
|
|
// Replace matVal with vecVal on matUseInst.
|
|
// Replace matVal with vecVal on matUseInst.
|
|
void TrivialMatReplace(Value *matVal, Value *vecVal,
|
|
void TrivialMatReplace(Value *matVal, Value *vecVal,
|
|
CallInst *matUseInst);
|
|
CallInst *matUseInst);
|
|
@@ -1269,6 +1272,16 @@ void HLMatrixLowerPass::TrivialMatReplace(Value *matVal,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned row, unsigned col) {
|
|
|
|
+ SmallVector<int, 16> castMask(col * row);
|
|
|
|
+ unsigned idx = 0;
|
|
|
|
+ for (unsigned c = 0; c < col; c++)
|
|
|
|
+ for (unsigned r = 0; r < row; r++)
|
|
|
|
+ castMask[idx++] = r * col + c;
|
|
|
|
+ return cast<Instruction>(
|
|
|
|
+ Builder.CreateShuffleVector(vecVal, vecVal, castMask));
|
|
|
|
+}
|
|
|
|
+
|
|
void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
|
|
void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
|
|
Value *vecVal,
|
|
Value *vecVal,
|
|
CallInst *castInst,
|
|
CallInst *castInst,
|
|
@@ -1291,25 +1304,9 @@ void HLMatrixLowerPass::TranslateMatMajorCast(Value *matVal,
|
|
|
|
|
|
IRBuilder<> Builder(castInst);
|
|
IRBuilder<> Builder(castInst);
|
|
|
|
|
|
- // shuf to change major.
|
|
|
|
- SmallVector<int, 16> castMask(col * row);
|
|
|
|
- unsigned idx = 0;
|
|
|
|
- if (bRowToCol) {
|
|
|
|
- for (unsigned c = 0; c < col; c++)
|
|
|
|
- for (unsigned r = 0; r < row; r++) {
|
|
|
|
- unsigned matIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
|
|
|
|
- castMask[idx++] = matIdx;
|
|
|
|
- }
|
|
|
|
- } else {
|
|
|
|
- for (unsigned r = 0; r < row; r++)
|
|
|
|
- for (unsigned c = 0; c < col; c++) {
|
|
|
|
- unsigned matIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
|
|
|
|
- castMask[idx++] = matIdx;
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- Instruction *vecCast = cast<Instruction>(
|
|
|
|
- Builder.CreateShuffleVector(vecVal, vecVal, castMask));
|
|
|
|
|
|
+ if (bRowToCol)
|
|
|
|
+ std::swap(row, col);
|
|
|
|
+ Instruction *vecCast = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
|
|
// Replace vec cast function call with vecCast.
|
|
// Replace vec cast function call with vecCast.
|
|
DXASSERT(matToVecMap.count(castInst), "must has vec version");
|
|
DXASSERT(matToVecMap.count(castInst), "must has vec version");
|
|
@@ -2109,12 +2106,10 @@ Value *HLMatrixLowerPass::GetMatrixForVec(Value *vecVal, Type *matTy) {
|
|
|
|
|
|
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
Value *vecVal) {
|
|
Value *vecVal) {
|
|
|
|
+ Type *matTy = matVal->getType();
|
|
for (Value::user_iterator user = matVal->user_begin();
|
|
for (Value::user_iterator user = matVal->user_begin();
|
|
user != matVal->user_end();) {
|
|
user != matVal->user_end();) {
|
|
Instruction *useInst = cast<Instruction>(*(user++));
|
|
Instruction *useInst = cast<Instruction>(*(user++));
|
|
- // Skip return here.
|
|
|
|
- if (isa<ReturnInst>(useInst))
|
|
|
|
- continue;
|
|
|
|
// User must be function call.
|
|
// User must be function call.
|
|
if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
|
|
if (CallInst *useCall = dyn_cast<CallInst>(useInst)) {
|
|
hlsl::HLOpcodeGroup group =
|
|
hlsl::HLOpcodeGroup group =
|
|
@@ -2183,7 +2178,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
|
|
for (unsigned i = 0; i < useCall->getNumArgOperands(); i++) {
|
|
if (useCall->getArgOperand(i) == matVal) {
|
|
if (useCall->getArgOperand(i) == matVal) {
|
|
// update the user call with the correct matrix value in new code sequence
|
|
// update the user call with the correct matrix value in new code sequence
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, matTy);
|
|
if (matVal != newMatVal)
|
|
if (matVal != newMatVal)
|
|
useCall->setArgOperand(i, newMatVal);
|
|
useCall->setArgOperand(i, newMatVal);
|
|
}
|
|
}
|
|
@@ -2194,8 +2189,10 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
|
|
// Just replace the src with vec version.
|
|
// Just replace the src with vec version.
|
|
useInst->setOperand(0, vecVal);
|
|
useInst->setOperand(0, vecVal);
|
|
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
} else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
|
|
- Value *newMatVal = GetMatrixForVec(vecVal, matVal->getType());
|
|
|
|
|
|
+ Value *newMatVal = GetMatrixForVec(vecVal, matTy);
|
|
RI->setOperand(0, newMatVal);
|
|
RI->setOperand(0, newMatVal);
|
|
|
|
+ } else if (isa<StoreInst>(useInst)) {
|
|
|
|
+ DXASSERT(vecToMatMap.count(vecVal) && vecToMatMap[vecVal] == matVal, "matrix store should only be used with preserved matrix values");
|
|
} else {
|
|
} else {
|
|
// Must be GEP on mat array alloca.
|
|
// Must be GEP on mat array alloca.
|
|
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
|
GetElementPtrInst *GEP = cast<GetElementPtrInst>(useInst);
|
|
@@ -2467,6 +2464,85 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
|
|
|
|
+ // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
|
|
|
|
+ // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
|
|
|
|
+ for (auto &arg : F.args()) {
|
|
|
|
+ SmallVector<CallInst *, 4> Candidates;
|
|
|
|
+ for (User *U : arg.users()) {
|
|
|
|
+ if (CallInst *CI = dyn_cast<CallInst>(U)) {
|
|
|
|
+ HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
|
+ switch (group) {
|
|
|
|
+ case HLOpcodeGroup::HLCast:
|
|
|
|
+ case HLOpcodeGroup::HLMatLoadStore:
|
|
|
|
+ Candidates.push_back(CI);
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ for (CallInst *CI : Candidates) {
|
|
|
|
+ IRBuilder<> Builder(CI);
|
|
|
|
+ HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
|
+ switch (group) {
|
|
|
|
+ case HLOpcodeGroup::HLCast: {
|
|
|
|
+ HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
|
+ if (opcode == HLCastOpcode::RowMatrixToVecCast ||
|
|
|
|
+ opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
|
+ Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
|
|
+ Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
|
|
|
|
+ /*bOrigAllocaTy*/false,
|
|
|
|
+ matVal->getName());
|
|
|
|
+ if (opcode == HLCastOpcode::ColMatrixToVecCast) {
|
|
|
|
+ unsigned row, col;
|
|
|
|
+ HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
|
|
|
|
+ vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
+ }
|
|
|
|
+ CI->replaceAllUsesWith(vecVal);
|
|
|
|
+ CI->eraseFromParent();
|
|
|
|
+ }
|
|
|
|
+ } break;
|
|
|
|
+ case HLOpcodeGroup::HLMatLoadStore: {
|
|
|
|
+ HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
|
|
|
|
+ bool bTranspose = false;
|
|
|
|
+ switch (opcode) {
|
|
|
|
+ case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
|
+ bTranspose = true;
|
|
|
|
+ case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
|
+ // shuffle if transposed, bitcast, and store
|
|
|
|
+ Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
|
+ Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
|
|
|
|
+ if (bTranspose) {
|
|
|
|
+ unsigned row, col;
|
|
|
|
+ HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
|
+ vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
|
|
|
|
+ }
|
|
|
|
+ Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
|
|
|
|
+ Builder.CreateStore(vecVal, castPtr);
|
|
|
|
+ CI->eraseFromParent();
|
|
|
|
+ } break;
|
|
|
|
+ case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
|
+ bTranspose = true;
|
|
|
|
+ case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
|
+ // bitcast, load, and shuffle if transposed
|
|
|
|
+ Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
|
|
|
|
+ Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
|
|
|
|
+ Value *vecVal = Builder.CreateLoad(castPtr);
|
|
|
|
+ if (bTranspose) {
|
|
|
|
+ unsigned row, col;
|
|
|
|
+ HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
|
|
|
|
+ // row/col swapped for col major source
|
|
|
|
+ vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
|
|
|
|
+ }
|
|
|
|
+ CI->replaceAllUsesWith(vecVal);
|
|
|
|
+ CI->eraseFromParent();
|
|
|
|
+ } break;
|
|
|
|
+ }
|
|
|
|
+ } break;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
// Create vector version of matrix instructions first.
|
|
// Create vector version of matrix instructions first.
|
|
// The matrix operands will be undefval for these instructions.
|
|
// The matrix operands will be undefval for these instructions.
|
|
@@ -2531,4 +2607,12 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
|
|
DeleteDeadInsts();
|
|
DeleteDeadInsts();
|
|
|
|
|
|
matToVecMap.clear();
|
|
matToVecMap.clear();
|
|
|
|
+ vecToMatMap.clear();
|
|
|
|
+
|
|
|
|
+ // If this is a library function, now fix input/output matrix params
|
|
|
|
+ // TODO: What about Patch Constant Shaders?
|
|
|
|
+ if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
|
|
|
|
+ TranslateLibraryArgs(F);
|
|
|
|
+ }
|
|
|
|
+ return;
|
|
}
|
|
}
|