|
@@ -15,6 +15,7 @@
|
|
#include "dxc/HLSL/HLModule.h"
|
|
#include "dxc/HLSL/HLModule.h"
|
|
#include "dxc/HlslIntrinsicOp.h"
|
|
#include "dxc/HlslIntrinsicOp.h"
|
|
#include "dxc/Support/Global.h"
|
|
#include "dxc/Support/Global.h"
|
|
|
|
+#include "dxc/HLSL/DxilOperations.h"
|
|
|
|
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Module.h"
|
|
@@ -200,12 +201,13 @@ private:
|
|
CallInst *matUseInst);
|
|
CallInst *matUseInst);
|
|
// Replace matInst with vecInst on mulInst.
|
|
// Replace matInst with vecInst on mulInst.
|
|
void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
|
|
void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
|
|
- CallInst *mulInst);
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned);
|
|
void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
|
|
void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
|
|
- CallInst *mulInst);
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned);
|
|
void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
|
|
void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
|
|
- CallInst *mulInst);
|
|
|
|
- void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst);
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned);
|
|
|
|
+ void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
|
|
|
|
+ bool isSigned);
|
|
// Replace matInst with vecInst on transposeInst.
|
|
// Replace matInst with vecInst on transposeInst.
|
|
void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
|
|
void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
|
|
CallInst *transposeInst);
|
|
CallInst *transposeInst);
|
|
@@ -565,7 +567,7 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
|
|
if (isFloat)
|
|
if (isFloat)
|
|
Result = BinaryOperator::CreateFDiv(tmp, tmp);
|
|
Result = BinaryOperator::CreateFDiv(tmp, tmp);
|
|
else
|
|
else
|
|
- Result = BinaryOperator::CreateFDiv(tmp, tmp);
|
|
|
|
|
|
+ Result = BinaryOperator::CreateSDiv(tmp, tmp);
|
|
break;
|
|
break;
|
|
case HLBinaryOpcode::Rem:
|
|
case HLBinaryOpcode::Rem:
|
|
if (isFloat)
|
|
if (isFloat)
|
|
@@ -807,7 +809,7 @@ void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
|
|
|
|
|
|
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
- CallInst *mulInst) {
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned) {
|
|
DXASSERT(matToVecMap.count(mulInst), "must has vec version");
|
|
DXASSERT(matToVecMap.count(mulInst), "must has vec version");
|
|
Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
|
|
Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
|
|
// Already translated.
|
|
// Already translated.
|
|
@@ -839,15 +841,27 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
: Builder.CreateMul(lMatElt, rMatElt);
|
|
: Builder.CreateMul(lMatElt, rMatElt);
|
|
};
|
|
};
|
|
|
|
|
|
|
|
+ DXIL::OpCode madOp =
|
|
|
|
+ isFloat ? DXIL::OpCode::FMad
|
|
|
|
+ : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
|
+ Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
|
+ Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
|
+ auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
|
|
|
|
+ Value *acc) -> Value * {
|
|
|
|
+ unsigned lMatIdx = GetMatIdx(r, lc, row);
|
|
|
|
+ unsigned rMatIdx = GetMatIdx(lc, c, rRow);
|
|
|
|
+ Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
|
|
|
|
+ Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
|
|
|
|
+ return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
|
|
|
|
+ };
|
|
|
|
+
|
|
for (unsigned r = 0; r < row; r++) {
|
|
for (unsigned r = 0; r < row; r++) {
|
|
for (unsigned c = 0; c < rCol; c++) {
|
|
for (unsigned c = 0; c < rCol; c++) {
|
|
unsigned lc = 0;
|
|
unsigned lc = 0;
|
|
Value *tmpVal = CreateOneEltMul(r, lc, c);
|
|
Value *tmpVal = CreateOneEltMul(r, lc, c);
|
|
|
|
|
|
for (lc = 1; lc < col; lc++) {
|
|
for (lc = 1; lc < col; lc++) {
|
|
- Value *tmpMul = CreateOneEltMul(r, lc, c);
|
|
|
|
- tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
|
|
|
|
- : Builder.CreateAdd(tmpVal, tmpMul);
|
|
|
|
|
|
+ tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
|
|
}
|
|
}
|
|
unsigned matIdx = GetMatIdx(r, c, row);
|
|
unsigned matIdx = GetMatIdx(r, c, row);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
|
|
@@ -863,7 +877,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
|
|
|
|
void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
- CallInst *mulInst) {
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned) {
|
|
// matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
// matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
|
|
|
@@ -879,6 +893,18 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
Value *vec = RVal;
|
|
Value *vec = RVal;
|
|
Value *mat = vecInst; // vec version of matInst;
|
|
Value *mat = vecInst; // vec version of matInst;
|
|
|
|
|
|
|
|
+ DXIL::OpCode madOp =
|
|
|
|
+ isFloat ? DXIL::OpCode::FMad
|
|
|
|
+ : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
|
+ Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
|
+ Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
|
+ auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
|
+ Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
|
|
+ uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
+ Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
+ return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
|
|
+ };
|
|
|
|
+
|
|
for (unsigned r = 0; r < row; r++) {
|
|
for (unsigned r = 0; r < row; r++) {
|
|
unsigned c = 0;
|
|
unsigned c = 0;
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
@@ -889,13 +915,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
: Builder.CreateMul(vecElt, matElt);
|
|
: Builder.CreateMul(vecElt, matElt);
|
|
|
|
|
|
for (c = 1; c < col; c++) {
|
|
for (c = 1; c < col; c++) {
|
|
- vecElt = Builder.CreateExtractElement(vec, c);
|
|
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
- Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
|
|
- : Builder.CreateMul(vecElt, matElt);
|
|
|
|
- tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
|
|
|
|
- : Builder.CreateAdd(tmpVal, tmpMul);
|
|
|
|
|
|
+ tmpVal = CreateOneEltMad(r, c, tmpVal);
|
|
}
|
|
}
|
|
|
|
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
|
|
@@ -907,7 +927,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
|
|
|
|
void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
Instruction *vecInst,
|
|
Instruction *vecInst,
|
|
- CallInst *mulInst) {
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned) {
|
|
Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
// matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
// matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
Value *RVal = vecInst;
|
|
Value *RVal = vecInst;
|
|
@@ -924,6 +944,18 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
Value *vec = LVal;
|
|
Value *vec = LVal;
|
|
Value *mat = RVal;
|
|
Value *mat = RVal;
|
|
|
|
|
|
|
|
+ DXIL::OpCode madOp =
|
|
|
|
+ isFloat ? DXIL::OpCode::FMad
|
|
|
|
+ : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
|
+ Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
|
+ Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
|
+ auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
|
+ Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
|
|
+ uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
+ Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
+ return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
|
|
|
|
+ };
|
|
|
|
+
|
|
for (unsigned c = 0; c < col; c++) {
|
|
for (unsigned c = 0; c < col; c++) {
|
|
unsigned r = 0;
|
|
unsigned r = 0;
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
@@ -934,13 +966,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
: Builder.CreateMul(vecElt, matElt);
|
|
: Builder.CreateMul(vecElt, matElt);
|
|
|
|
|
|
for (r = 1; r < row; r++) {
|
|
for (r = 1; r < row; r++) {
|
|
- vecElt = Builder.CreateExtractElement(vec, r);
|
|
|
|
- uint32_t matIdx = GetMatIdx(r, c, row);
|
|
|
|
- Value *matElt = Builder.CreateExtractElement(mat, matIdx);
|
|
|
|
- Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
|
|
|
|
- : Builder.CreateMul(vecElt, matElt);
|
|
|
|
- tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
|
|
|
|
- : Builder.CreateAdd(tmpVal, tmpMul);
|
|
|
|
|
|
+ tmpVal = CreateOneEltMad(r, c, tmpVal);
|
|
}
|
|
}
|
|
|
|
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
|
|
retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
|
|
@@ -951,18 +977,18 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
}
|
|
}
|
|
|
|
|
|
void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
|
|
void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
|
|
- CallInst *mulInst) {
|
|
|
|
|
|
+ CallInst *mulInst, bool isSigned) {
|
|
Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
|
|
Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
|
|
|
|
|
|
bool LMat = IsMatrixType(LVal->getType());
|
|
bool LMat = IsMatrixType(LVal->getType());
|
|
bool RMat = IsMatrixType(RVal->getType());
|
|
bool RMat = IsMatrixType(RVal->getType());
|
|
if (LMat && RMat) {
|
|
if (LMat && RMat) {
|
|
- TranslateMatMatMul(matInst, vecInst, mulInst);
|
|
|
|
|
|
+ TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
|
|
} else if (LMat) {
|
|
} else if (LMat) {
|
|
- TranslateMatVecMul(matInst, vecInst, mulInst);
|
|
|
|
|
|
+ TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
|
|
} else {
|
|
} else {
|
|
- TranslateVecMatMul(matInst, vecInst, mulInst);
|
|
|
|
|
|
+ TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -1209,8 +1235,11 @@ void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
|
|
IRBuilder<> Builder(matUseInst);
|
|
IRBuilder<> Builder(matUseInst);
|
|
IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
|
|
IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
|
|
switch (opcode) {
|
|
switch (opcode) {
|
|
|
|
+ case IntrinsicOp::IOP_umul:
|
|
|
|
+ TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
|
|
|
|
+ break;
|
|
case IntrinsicOp::IOP_mul:
|
|
case IntrinsicOp::IOP_mul:
|
|
- TranslateMul(matInst, vecInst, matUseInst);
|
|
|
|
|
|
+ TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
|
|
break;
|
|
break;
|
|
case IntrinsicOp::IOP_transpose:
|
|
case IntrinsicOp::IOP_transpose:
|
|
TranslateMatTranspose(matInst, vecInst, matUseInst);
|
|
TranslateMatTranspose(matInst, vecInst, matUseInst);
|