2
0
Xiang Li 8 жил өмнө
parent
commit
73e76ff294

+ 4 - 0
include/dxc/HlslIntrinsicOp.h

@@ -237,6 +237,7 @@ import hctdb_instrhelp
   IOP_umad,
   IOP_umax,
   IOP_umin,
+  IOP_umul,
   MOP_InterlockedUMax,
   MOP_InterlockedUMin,
   Num_Intrinsics,
@@ -264,6 +265,7 @@ import hctdb_instrhelp
   case IntrinsicOp::IOP_mad:
   case IntrinsicOp::IOP_max:
   case IntrinsicOp::IOP_min:
+  case IntrinsicOp::IOP_mul:
   case IntrinsicOp::MOP_InterlockedMax:
   case IntrinsicOp::MOP_InterlockedMin:
 // HLSL-HAS-UNSIGNED-INTRINSICS:END
@@ -307,6 +309,8 @@ import hctdb_instrhelp
     return static_cast<unsigned>(IntrinsicOp::IOP_umax);
   case IntrinsicOp::IOP_min:
     return static_cast<unsigned>(IntrinsicOp::IOP_umin);
+  case IntrinsicOp::IOP_mul:
+    return static_cast<unsigned>(IntrinsicOp::IOP_umul);
   case IntrinsicOp::MOP_InterlockedMax:
     return static_cast<unsigned>(IntrinsicOp::MOP_InterlockedUMax);
   case IntrinsicOp::MOP_InterlockedMin:

+ 59 - 30
lib/HLSL/HLMatrixLowerPass.cpp

@@ -15,6 +15,7 @@
 #include "dxc/HLSL/HLModule.h"
 #include "dxc/HlslIntrinsicOp.h"
 #include "dxc/Support/Global.h"
+#include "dxc/HLSL/DxilOperations.h"
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
@@ -200,12 +201,13 @@ private:
                              CallInst *matUseInst);
   // Replace matInst with vecInst on mulInst.
   void TranslateMatMatMul(CallInst *matInst, Instruction *vecInst,
-                          CallInst *mulInst);
+                          CallInst *mulInst, bool isSigned);
   void TranslateMatVecMul(CallInst *matInst, Instruction *vecInst,
-                          CallInst *mulInst);
+                          CallInst *mulInst, bool isSigned);
   void TranslateVecMatMul(CallInst *matInst, Instruction *vecInst,
-                          CallInst *mulInst);
-  void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst);
+                          CallInst *mulInst, bool isSigned);
+  void TranslateMul(CallInst *matInst, Instruction *vecInst, CallInst *mulInst,
+                    bool isSigned);
   // Replace matInst with vecInst on transposeInst.
   void TranslateMatTranspose(CallInst *matInst, Instruction *vecInst,
                              CallInst *transposeInst);
@@ -565,7 +567,7 @@ Instruction *HLMatrixLowerPass::TrivialMatBinOpToVec(CallInst *CI) {
     if (isFloat)
       Result = BinaryOperator::CreateFDiv(tmp, tmp);
     else
-      Result = BinaryOperator::CreateFDiv(tmp, tmp);
+      Result = BinaryOperator::CreateSDiv(tmp, tmp);
     break;
   case HLBinaryOpcode::Rem:
     if (isFloat)
@@ -807,7 +809,7 @@ void HLMatrixLowerPass::TrivialMatBinOpReplace(CallInst *matInst,
 
 void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                                            Instruction *vecInst,
-                                           CallInst *mulInst) {
+                                           CallInst *mulInst, bool isSigned) {
   DXASSERT(matToVecMap.count(mulInst), "must has vec version");
   Instruction *vecUseInst = cast<Instruction>(matToVecMap[mulInst]);
   // Already translated.
@@ -839,15 +841,27 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
                    : Builder.CreateMul(lMatElt, rMatElt);
   };
 
+  DXIL::OpCode madOp =
+      isFloat ? DXIL::OpCode::FMad
+              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
+  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  Value *madOpArg = Builder.getInt32((unsigned)madOp);
+  auto CreateOneEltMad = [&](unsigned r, unsigned lc, unsigned c,
+                             Value *acc) -> Value * {
+    unsigned lMatIdx = GetMatIdx(r, lc, row);
+    unsigned rMatIdx = GetMatIdx(lc, c, rRow);
+    Value *lMatElt = Builder.CreateExtractElement(lMat, lMatIdx);
+    Value *rMatElt = Builder.CreateExtractElement(rMat, rMatIdx);
+    return Builder.CreateCall(Mad, {madOpArg, lMatElt, rMatElt, acc});
+  };
+
   for (unsigned r = 0; r < row; r++) {
     for (unsigned c = 0; c < rCol; c++) {
       unsigned lc = 0;
       Value *tmpVal = CreateOneEltMul(r, lc, c);
 
       for (lc = 1; lc < col; lc++) {
-        Value *tmpMul = CreateOneEltMul(r, lc, c);
-        tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
-                         : Builder.CreateAdd(tmpVal, tmpMul);
+        tmpVal = CreateOneEltMad(r, lc, c, tmpVal);
       }
       unsigned matIdx = GetMatIdx(r, c, row);
       retVal = Builder.CreateInsertElement(retVal, tmpVal, matIdx);
@@ -863,7 +877,7 @@ void HLMatrixLowerPass::TranslateMatMatMul(CallInst *matInst,
 
 void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
                                            Instruction *vecInst,
-                                           CallInst *mulInst) {
+                                           CallInst *mulInst, bool isSigned) {
   // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
 
@@ -879,6 +893,18 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
   Value *vec = RVal;
   Value *mat = vecInst; // vec version of matInst;
 
+  DXIL::OpCode madOp =
+      isFloat ? DXIL::OpCode::FMad
+              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
+  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  Value *madOpArg = Builder.getInt32((unsigned)madOp);
+  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
+    Value *vecElt = Builder.CreateExtractElement(vec, c);
+    uint32_t matIdx = GetMatIdx(r, c, row);
+    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
+    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
+  };
+
   for (unsigned r = 0; r < row; r++) {
     unsigned c = 0;
     Value *vecElt = Builder.CreateExtractElement(vec, c);
@@ -889,13 +915,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
                             : Builder.CreateMul(vecElt, matElt);
 
     for (c = 1; c < col; c++) {
-      vecElt = Builder.CreateExtractElement(vec, c);
-      uint32_t matIdx = GetMatIdx(r, c, row);
-      Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-      Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                              : Builder.CreateMul(vecElt, matElt);
-      tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
-                       : Builder.CreateAdd(tmpVal, tmpMul);
+      tmpVal = CreateOneEltMad(r, c, tmpVal);
     }
 
     retVal = Builder.CreateInsertElement(retVal, tmpVal, r);
@@ -907,7 +927,7 @@ void HLMatrixLowerPass::TranslateMatVecMul(CallInst *matInst,
 
 void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
                                            Instruction *vecInst,
-                                           CallInst *mulInst) {
+                                           CallInst *mulInst, bool isSigned) {
   Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   // matInst should == mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
   Value *RVal = vecInst;
@@ -924,6 +944,18 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
   Value *vec = LVal;
   Value *mat = RVal;
 
+  DXIL::OpCode madOp =
+      isFloat ? DXIL::OpCode::FMad
+              : (isSigned ? DXIL::OpCode::IMad : DXIL::OpCode::UMad);
+  Function *Mad = m_pHLModule->GetOP()->GetOpFunc(madOp, EltTy);
+  Value *madOpArg = Builder.getInt32((unsigned)madOp);
+  auto CreateOneEltMad = [&](unsigned r, unsigned c, Value *acc) -> Value * {
+    Value *vecElt = Builder.CreateExtractElement(vec, r);
+    uint32_t matIdx = GetMatIdx(r, c, row);
+    Value *matElt = Builder.CreateExtractElement(mat, matIdx);
+    return Builder.CreateCall(Mad, {madOpArg, vecElt, matElt, acc});
+  };
+
   for (unsigned c = 0; c < col; c++) {
     unsigned r = 0;
     Value *vecElt = Builder.CreateExtractElement(vec, r);
@@ -934,13 +966,7 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
                             : Builder.CreateMul(vecElt, matElt);
 
     for (r = 1; r < row; r++) {
-      vecElt = Builder.CreateExtractElement(vec, r);
-      uint32_t matIdx = GetMatIdx(r, c, row);
-      Value *matElt = Builder.CreateExtractElement(mat, matIdx);
-      Value *tmpMul = isFloat ? Builder.CreateFMul(vecElt, matElt)
-                              : Builder.CreateMul(vecElt, matElt);
-      tmpVal = isFloat ? Builder.CreateFAdd(tmpVal, tmpMul)
-                       : Builder.CreateAdd(tmpVal, tmpMul);
+      tmpVal = CreateOneEltMad(r, c, tmpVal);
     }
 
     retVal = Builder.CreateInsertElement(retVal, tmpVal, c);
@@ -951,18 +977,18 @@ void HLMatrixLowerPass::TranslateVecMatMul(CallInst *matInst,
 }
 
 void HLMatrixLowerPass::TranslateMul(CallInst *matInst, Instruction *vecInst,
-                                     CallInst *mulInst) {
+                                     CallInst *mulInst, bool isSigned) {
   Value *LVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
   Value *RVal = mulInst->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
 
   bool LMat = IsMatrixType(LVal->getType());
   bool RMat = IsMatrixType(RVal->getType());
   if (LMat && RMat) {
-    TranslateMatMatMul(matInst, vecInst, mulInst);
+    TranslateMatMatMul(matInst, vecInst, mulInst, isSigned);
   } else if (LMat) {
-    TranslateMatVecMul(matInst, vecInst, mulInst);
+    TranslateMatVecMul(matInst, vecInst, mulInst, isSigned);
   } else {
-    TranslateVecMatMul(matInst, vecInst, mulInst);
+    TranslateVecMatMul(matInst, vecInst, mulInst, isSigned);
   }
 }
 
@@ -1209,8 +1235,11 @@ void HLMatrixLowerPass::MatIntrinsicReplace(CallInst *matInst,
   IRBuilder<> Builder(matUseInst);
   IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(matUseInst));
   switch (opcode) {
+  case IntrinsicOp::IOP_umul:
+    TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/false);
+    break;
   case IntrinsicOp::IOP_mul:
-    TranslateMul(matInst, vecInst, matUseInst);
+    TranslateMul(matInst, vecInst, matUseInst, /*isSigned*/true);
     break;
   case IntrinsicOp::IOP_transpose:
     TranslateMatTranspose(matInst, vecInst, matUseInst);

+ 1 - 0
lib/HLSL/HLOperationLower.cpp

@@ -4236,6 +4236,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     { IntrinsicOp::IOP_umad, TranslateFUITrinary, DXIL::OpCode::UMad},
     { IntrinsicOp::IOP_umax, TranslateFUIBinary, DXIL::OpCode::UMax},
     { IntrinsicOp::IOP_umin,   TranslateFUIBinary, DXIL::OpCode::UMin },
+    { IntrinsicOp::IOP_umul,   TranslateFUIBinary, DXIL::OpCode::UMul },
     { IntrinsicOp::MOP_InterlockedUMax, TranslateMopAtomicBinaryOperation, DXIL::OpCode::NumOpCodes },
     { IntrinsicOp::MOP_InterlockedUMin, TranslateMopAtomicBinaryOperation, DXIL::OpCode::NumOpCodes },
 };

+ 4 - 1
tools/clang/test/CodeGenHLSL/matOps.hlsl

@@ -1,4 +1,6 @@
-// RUN: %dxc -E main -T ps_6_0 %s
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: sdiv
 
 float1x1 f1x1;
 float1x2 f1x2;
@@ -44,6 +46,7 @@ float4 main(float4 a : A) : SV_TARGET
   int4x4 im = i;
   im[2] = 1;
   im |= ~(i4x4<<2) + i4x4>>2 % (im & 2 | im ^ i);
+  im /= 3;
   bool4x4 b = (im++) < i;
   b = !b;
   float4 f4b = mul(f4, x+mt-x*f4x4b/im);

+ 1 - 2
tools/clang/test/CodeGenHLSL/precise2.hlsl

@@ -1,7 +1,6 @@
 // RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
 
-// CHECK: fmul float
-// CHECK: fadd float
+// CHECK: !dx.precise
 
 //--------------------------------------------------------------------------------------
 // File: BasicHLSL11_VS.hlsl

+ 1 - 2
tools/clang/test/CodeGenHLSL/precise3.hlsl

@@ -1,7 +1,6 @@
 // RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
 
-// CHECK: fmul float
-// CHECK: fadd float
+// CHECK: !dx.precise
 
 //--------------------------------------------------------------------------------------
 // File: BasicHLSL11_VS.hlsl

+ 1 - 2
tools/clang/test/CodeGenHLSL/precise4.hlsl

@@ -1,7 +1,6 @@
 // RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
 
-// CHECK: fmul float
-// CHECK: fadd float
+// CHECK: !dx.precise
 
 //--------------------------------------------------------------------------------------
 // File: BasicHLSL11_VS.hlsl

+ 3 - 3
utils/hct/gen_intrin_main.txt

@@ -172,10 +172,10 @@ numeric<c2> [[rn]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<c2> b)
 numeric<r2, c2> [[rn]] mul(in $match<1, 0> numeric a, in $match<2, 0> numeric<r2, c2> b) : mul_sm;
 numeric<c> [[rn]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric b) : mul_vs;
 numeric [[rn]] mul(in $match<1, 0> numeric<c> a, in $match<2, 0> numeric<c> b) : mul_vv;
-numeric<c2> [[rn]] mul(in $match<1, 0> numeric<c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_vm;
+numeric<c2> [[rn,unsigned_op=umul]] mul(in $match<1, 0> numeric<c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_vm;
 numeric<r, c> [[rn]] mul(in $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric b) : mul_ms;
-numeric<r> [[rn]] mul(in row_major $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric<c> b) : mul_mv;
-numeric<r, c2> [[rn]] mul(in row_major $match<1, 0> numeric<r, c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_mm;
+numeric<r> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in $match<2, 0> numeric<c> b) : mul_mv;
+numeric<r, c2> [[rn,unsigned_op=umul]] mul(in row_major $match<1, 0> numeric<r, c> a, in col_major $match<2, 0> numeric<c, c2> b) : mul_mm;
 $match<0, 1> float_like [[rn]] noise(in float_like<c> x);
 $type1 [[rn]] normalize(in float_like<c> x);
 $type1 [[rn]] pow(in float_like<> x, in $type1 y);