瀏覽代碼

Use high-level intrinsics when lowering matrix mul to make sure DXIL operations only exist once the DxilModule is ready. (#189)

Xiang Li 8 年之前
父節點
當前提交
9e8137c54d
共有 1 個文件被更改,包括 25 次插入和 12 次删除
  1. 25 12
      lib/HLSL/HLMatrixLowerPass.cpp

+ 25 - 12
lib/HLSL/HLMatrixLowerPass.cpp

@@ -807,6 +807,16 @@ void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
   }
 }
 
+static Function *GetOrCreateMadIntrinsic(Type *Ty, Type *opcodeTy, IntrinsicOp madOp, Module &M) { // Returns the HLIntrinsic-group function for madOp in module M.
+  llvm::FunctionType *MadFuncTy =
+      llvm::FunctionType::get(Ty, { opcodeTy, Ty, Ty, Ty}, false); // Signature: Ty (opcodeTy, Ty, Ty, Ty), non-variadic.

+  Function *MAD =
+      GetOrCreateHLFunction(M, MadFuncTy, HLOpcodeGroup::HLIntrinsic,
+                            (unsigned)madOp); // NOTE(review): presumably reuses an existing declaration when one matches; confirm in GetOrCreateHLFunction.
+  return MAD;
+}
+
 void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                            Instruction *vecInst,
                                            CallInst *mulInst, bool isSigned) {
@@ -841,11 +851,12 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                    : Builder.CreateMul(lMatElt, rMatElt);
   };
 
-  DXIL::OpCode madOp =
-      isFloat ? DXIL::OpCode::FMad
-              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
-  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
+  Type *opcodeTy = Builder.getInt32Ty();
+  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
+                                          *m_pHLModule->GetModule());
   Value *madOpArg = Builder.getInt32((unsigned)madOp);
+
   auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
                              Value *acc) -> Value * {
     unsigned lMatIdx = GetMatIdx(r, lc, row);
@@ -893,11 +904,12 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
   Value *vec = RVal;
   Value *mat = vecInst; // vec version of matInst;
 
-  DXIL::OpCode madOp =
-      isFloat ? DXIL::OpCode::FMad
-              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
-  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
+  Type *opcodeTy = Builder.getInt32Ty();
+  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
+                                          *m_pHLModule->GetModule());
   Value *madOpArg = Builder.getInt32((unsigned)madOp);
+
   auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
     Value *vecElt = Builder.CreateExtractElement(vec, c);
     uint32_t matIdx = GetMatIdx(r, c, row);
@@ -944,11 +956,12 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
   Value *vec = LVal;
   Value *mat = RVal;
 
-  DXIL::OpCode madOp =
-      isFloat ? DXIL::OpCode::FMad
-              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
-  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  IntrinsicOp madOp = isSigned ? IntrinsicOp::IOP_mad : IntrinsicOp::IOP_umad;
+  Type *opcodeTy = Builder.getInt32Ty();
+  Function *Mad = GetOrCreateMadIntrinsic(EltTy, opcodeTy, madOp,
+                                          *m_pHLModule->GetModule());
   Value *madOpArg = Builder.getInt32((unsigned)madOp);
+
   auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
     Value *vecElt = Builder.CreateExtractElement(vec, r);
     uint32_t matIdx = GetMatIdx(r, c, row);