|
@@ -588,6 +588,7 @@ Value *GenerateLdInput(Function *loadInput, ArrayRef<Value *> args,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+
|
|
|
Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
|
|
|
unsigned cols, MutableArrayRef<Value *> args,
|
|
|
bool bCast) {
|
|
@@ -654,6 +655,96 @@ Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+void replaceMatStWithStOutputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
|
|
|
+ Function *ldStFunc, Constant *OpArg, Constant *ID,
|
|
|
+ Constant *columnConsts[],Value *vertexOrPrimID,
|
|
|
+ Value *idxVal) {
|
|
|
+ IRBuilder<> LocalBuilder(CI);
|
|
|
+ Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(
|
|
|
+ CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
|
|
|
+ ->getType()->getPointerElementType());
|
|
|
+
|
|
|
+ Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
|
|
|
+
|
|
|
+ if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
|
|
|
+ for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
+ Constant *constColIdx = LocalBuilder.getInt32(c);
|
|
|
+ Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
|
|
|
+
|
|
|
+ for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
+ unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
|
|
|
+ Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
|
|
|
+ LocalBuilder.CreateCall(ldStFunc,
|
|
|
+ { OpArg, ID, colIdx, columnConsts[r], Elt });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
+ Constant *constRowIdx = LocalBuilder.getInt32(r);
|
|
|
+ Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
+ for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
+ unsigned matIdx = MatTy.getRowMajorIndex(r, c);
|
|
|
+ Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
|
|
|
+ LocalBuilder.CreateCall(ldStFunc,
|
|
|
+ { OpArg, ID, rowIdx, columnConsts[c], Elt });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ CI->eraseFromParent();
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void replaceMatLdWithLdInputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
|
|
|
+ Function *ldStFunc, Constant *OpArg, Constant *ID,
|
|
|
+ Constant *columnConsts[],Value *vertexOrPrimID,
|
|
|
+ Value *idxVal) {
|
|
|
+ IRBuilder<> LocalBuilder(CI);
|
|
|
+ HLMatrixType MatTy = HLMatrixType::cast(
|
|
|
+ CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
|
|
|
+ ->getType()->getPointerElementType());
|
|
|
+ std::vector<Value *> matElts(MatTy.getNumElements());
|
|
|
+
|
|
|
+ if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
|
|
|
+ for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
+ Constant *constRowIdx = LocalBuilder.getInt32(c);
|
|
|
+ Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
+ for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
+ SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
|
|
|
+ if (vertexOrPrimID)
|
|
|
+ args.emplace_back(vertexOrPrimID);
|
|
|
+
|
|
|
+ Value *input = LocalBuilder.CreateCall(ldStFunc, args);
|
|
|
+ unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
|
|
|
+ matElts[matIdx] = input;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
+ Constant *constRowIdx = LocalBuilder.getInt32(r);
|
|
|
+ Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
+ for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
+ SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
|
|
|
+ if (vertexOrPrimID)
|
|
|
+ args.emplace_back(vertexOrPrimID);
|
|
|
+
|
|
|
+ Value *input = LocalBuilder.CreateCall(ldStFunc, args);
|
|
|
+ unsigned matIdx = MatTy.getRowMajorIndex(r, c);
|
|
|
+ matElts[matIdx] = input;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Value *newVec =
|
|
|
+ HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
|
|
|
+ newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
|
|
|
+
|
|
|
+ CI->replaceAllUsesWith(newVec);
|
|
|
+ CI->eraseFromParent();
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
void replaceDirectInputParameter(Value *param, Function *loadInput,
|
|
|
unsigned cols, MutableArrayRef<Value *> args,
|
|
|
bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
|
|
@@ -964,84 +1055,11 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
|
|
|
switch (matOp) {
|
|
|
case HLMatLoadStoreOpcode::ColMatLoad:
|
|
|
case HLMatLoadStoreOpcode::RowMatLoad: {
|
|
|
- IRBuilder<> LocalBuilder(CI);
|
|
|
- HLMatrixType MatTy = HLMatrixType::cast(
|
|
|
- CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
|
|
|
- ->getType()->getPointerElementType());
|
|
|
- std::vector<Value *> matElts(MatTy.getNumElements());
|
|
|
-
|
|
|
- if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
|
|
|
- for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
- Constant *constRowIdx = LocalBuilder.getInt32(c);
|
|
|
- Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
- for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
- SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
|
|
|
- if (vertexOrPrimID)
|
|
|
- args.emplace_back(vertexOrPrimID);
|
|
|
-
|
|
|
- Value *input = LocalBuilder.CreateCall(ldStFunc, args);
|
|
|
- unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
|
|
|
- matElts[matIdx] = input;
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
- Constant *constRowIdx = LocalBuilder.getInt32(r);
|
|
|
- Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
- for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
- SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
|
|
|
- if (vertexOrPrimID)
|
|
|
- args.emplace_back(vertexOrPrimID);
|
|
|
-
|
|
|
- Value *input = LocalBuilder.CreateCall(ldStFunc, args);
|
|
|
- unsigned matIdx = MatTy.getRowMajorIndex(r, c);
|
|
|
- matElts[matIdx] = input;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- Value *newVec =
|
|
|
- HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
|
|
|
- newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
|
|
|
-
|
|
|
- CI->replaceAllUsesWith(newVec);
|
|
|
- CI->eraseFromParent();
|
|
|
+ replaceMatLdWithLdInputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
|
|
|
} break;
|
|
|
case HLMatLoadStoreOpcode::ColMatStore:
|
|
|
case HLMatLoadStoreOpcode::RowMatStore: {
|
|
|
- IRBuilder<> LocalBuilder(CI);
|
|
|
- Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
|
|
|
- HLMatrixType MatTy = HLMatrixType::cast(
|
|
|
- CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
|
|
|
- ->getType()->getPointerElementType());
|
|
|
-
|
|
|
- Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
|
|
|
-
|
|
|
- if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
|
|
|
- for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
- Constant *constColIdx = LocalBuilder.getInt32(c);
|
|
|
- Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
|
|
|
-
|
|
|
- for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
- unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
|
|
|
- Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
|
|
|
- LocalBuilder.CreateCall(ldStFunc,
|
|
|
- { OpArg, ID, colIdx, columnConsts[r], Elt });
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
|
|
|
- Constant *constRowIdx = LocalBuilder.getInt32(r);
|
|
|
- Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
|
|
|
- for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
|
|
|
- unsigned matIdx = MatTy.getRowMajorIndex(r, c);
|
|
|
- Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
|
|
|
- LocalBuilder.CreateCall(ldStFunc,
|
|
|
- { OpArg, ID, rowIdx, columnConsts[c], Elt });
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- CI->eraseFromParent();
|
|
|
+ replaceMatStWithStOutputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
|
|
|
} break;
|
|
|
}
|
|
|
} else {
|
|
@@ -1386,6 +1404,14 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
|
|
|
Type *i1Ty = Type::getInt1Ty(constZero->getContext());
|
|
|
Type *i32Ty = constZero->getType();
|
|
|
|
|
|
+ Constant *columnConsts[] = {
|
|
|
+ hlslOP->GetU8Const(0), hlslOP->GetU8Const(1), hlslOP->GetU8Const(2),
|
|
|
+ hlslOP->GetU8Const(3), hlslOP->GetU8Const(4), hlslOP->GetU8Const(5),
|
|
|
+ hlslOP->GetU8Const(6), hlslOP->GetU8Const(7), hlslOP->GetU8Const(8),
|
|
|
+ hlslOP->GetU8Const(9), hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
|
|
|
+ hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
|
|
|
+ hlslOP->GetU8Const(15)};
|
|
|
+
|
|
|
for (Argument &arg : patchConstantFunc->args()) {
|
|
|
DxilParameterAnnotation ¶mAnnotation =
|
|
|
patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
|
|
@@ -1422,11 +1448,21 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
|
|
|
collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
|
|
|
/*hasVertexOrPrimID*/ true, true, bRowMajor, false);
|
|
|
for (InputOutputAccessInfo &info : accessInfoList) {
|
|
|
+ Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
|
|
|
if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
|
|
|
- Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
|
|
|
Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
|
|
|
info.vertexOrPrimID};
|
|
|
replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
|
|
|
+ } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
|
|
|
+ HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
|
|
|
+ // Intrinsic will be translated later.
|
|
|
+ if (group == HLOpcodeGroup::HLIntrinsic || group == HLOpcodeGroup::NotHL)
|
|
|
+ return;
|
|
|
+ unsigned opcode = GetHLOpcode(CI);
|
|
|
+ DXASSERT_NOMSG(group == HLOpcodeGroup::HLMatLoadStore);
|
|
|
+ HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
|
|
|
+ if (matOp == HLMatLoadStoreOpcode::ColMatLoad || matOp == HLMatLoadStoreOpcode::RowMatLoad)
|
|
|
+ replaceMatLdWithLdInputs(CI, matOp, dxilLdFunc, OpArg, inputID, columnConsts, info.vertexOrPrimID, info.idx);
|
|
|
} else {
|
|
|
DXASSERT(0, "input should only be ld");
|
|
|
}
|