|
@@ -807,6 +807,16 @@ void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) { // Returns the HL-intrinsic function for madOp in module M, typed Ty(opcodeTy, Ty, Ty, Ty).
|
|
|
+ llvm::FunctionType *MadFuncTy = // Signature: leading opcode argument followed by three operands of type Ty; result is Ty.
|
|
|
+ llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false); // Final `false` is LLVM's isVarArg flag: the function is not variadic.
|
|
|
+
|
|
|
+ Function *MAD = // Lookup/creation is delegated to GetOrCreateHLFunction (presumably reuses an existing declaration in M -- confirm against the helper).
|
|
|
+ GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic, // Registered under the HLIntrinsic opcode group.
|
|
|
+ (unsigned)madOp); // The IntrinsicOp enumerator is passed as its unsigned numeric value.
|
|
|
+ return MAD; // Handed back unchanged; callers in this file build calls whose first argument is the same opcode as an i32 constant.
|
|
|
+}
|
|
|
+
|
|
|
void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
|
Instruction *vecInst,
|
|
|
CallInst *mulInst, bool isSigned) {
|
|
@@ -841,11 +851,12 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
|
|
|
: Builder.CreateMul(lMatElt, rMatElt);
|
|
|
};
|
|
|
|
|
|
- DXIL::OpCode madOp =
|
|
|
- isFloat ? DXIL::OpCode::FMad
|
|
|
- : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
- Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
+ IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
+ Type *opcodeTy = Builder.getInt32Ty();
|
|
|
+ Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
+ *m_pHLModule->GetModule());
|
|
|
Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
+
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
|
|
|
Value *acc) -> Value * {
|
|
|
unsigned lMatIdx = GetMatIdx(r, lc, row);
|
|
@@ -893,11 +904,12 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
|
|
|
Value *vec = RVal;
|
|
|
Value *mat = vecInst; // vec version of matInst;
|
|
|
|
|
|
- DXIL::OpCode madOp =
|
|
|
- isFloat ? DXIL::OpCode::FMad
|
|
|
- : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
- Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
+ IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
+ Type *opcodeTy = Builder.getInt32Ty();
|
|
|
+ Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
+ *m_pHLModule->GetModule());
|
|
|
Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
+
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
Value *vecElt = Builder.CreateExtractElement(vec, c);
|
|
|
uint32_t matIdx = GetMatIdx(r, c, row);
|
|
@@ -944,11 +956,12 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
|
|
|
Value *vec = LVal;
|
|
|
Value *mat = RVal;
|
|
|
|
|
|
- DXIL::OpCode madOp =
|
|
|
- isFloat ? DXIL::OpCode::FMad
|
|
|
- : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
|
|
|
- Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
|
|
|
+ IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
|
|
|
+ Type *opcodeTy = Builder.getInt32Ty();
|
|
|
+ Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
|
|
|
+ *m_pHLModule->GetModule());
|
|
|
Value *madOpArg = Builder.getInt32((unsigned)madOp);
|
|
|
+
|
|
|
auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
|
|
|
Value *vecElt = Builder.CreateExtractElement(vec, r);
|
|
|
uint32_t matIdx = GetMatIdx(r, c, row);
|