瀏覽代碼

Support mul on vectors and scalars (#1561)

Fixes #1538
Helena Kotas 7 年之前
父節點
當前提交
582d26bc28
共有 2 個文件被更改,包括 95 次插入和 1 次删除
  1. 45 1
      lib/HLSL/HLOperationLower.cpp
  2. 50 0
      tools/clang/test/CodeGenHLSL/quick-test/mul-vector-scalar.hlsl

+ 45 - 1
lib/HLSL/HLOperationLower.cpp

@@ -2351,6 +2351,49 @@ Value *SplatToVector(Value *Elt, Type *DstTy, IRBuilder<> &Builder) {
   return Result;
 }
 
+Value *TranslateMul(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
+  HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
+  // Lowers mul() on vector/scalar operands: vector x vector becomes a dot product, otherwise a plain fmul/mul after splatting any scalar operand.
+  hlsl::OP *hlslOP = &helper.hlslOP;
+  Value *arg0 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc0Idx);
+  Value *arg1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
+  Type *arg0Ty = arg0->getType();
+  Type *arg1Ty = arg1->getType();
+  IRBuilder<> Builder(CI); // new instructions are inserted before CI
+  // NOTE(review): only vector and scalar operand types are distinguished below; matrix operands are presumably lowered elsewhere — confirm.
+  if (arg0Ty->isVectorTy()) {
+    if (arg1Ty->isVectorTy()) {
+      // mul(vector, vector) == dot(vector, vector); the dot helper's result is returned directly.
+      unsigned vecSize = arg0Ty->getVectorNumElements(); // size comes from arg0 only; assumes arg1 has the same element count — TODO confirm
+      if (arg0Ty->getScalarType()->isFloatingPointTy()) {
+        return TranslateFDot(arg0, arg1, vecSize, hlslOP, Builder);
+      }
+      else {
+        return TranslateIDot(arg0, arg1, vecSize, hlslOP, Builder); // NOTE(review): every non-float element type takes this path; confirm unsigned operands are handled correctly
+      }
+    }
+    else {
+      // mul(vector, scalar) == vector * scalar-splat: arg1 is widened to arg0's vector type.
+      arg1 = SplatToVector(arg1, arg0Ty, Builder);
+    }
+  }
+  else {
+    if (arg1Ty->isVectorTy()) {
+      // mul(scalar, vector) == scalar-splat * vector: arg0 is widened to arg1's vector type.
+      arg0 = SplatToVector(arg0, arg1Ty, Builder);
+    }
+    // else mul(scalar, scalar) == scalar * scalar; no splat is needed.
+  }
+
+  // create fmul/mul for the pair of vectors or scalars; arg0Ty is the pre-splat type, and only its scalar element type is consulted here.
+  if (arg0Ty->getScalarType()->isFloatingPointTy()) {
+    return Builder.CreateFMul(arg0, arg1);
+  }
+  else {
+    return Builder.CreateMul(arg0, arg1);
+  }
+}
+
 // Sample intrinsics.
 struct SampleHelper {
   SampleHelper(CallInst *CI, OP::OpCode op, HLObjectOperationLowerHelper *pObjHelper);
@@ -4251,6 +4294,7 @@ Value *TranslateProcessTessFactors(CallInst *CI, IntrinsicOp IOP, OP::OpCode opc
   return nullptr;
 }
 
+
 }
 
 // Ray Tracing.
@@ -4582,7 +4626,7 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_min, TranslateFUIBinary, DXIL::OpCode::IMin},
     {IntrinsicOp::IOP_modf, TranslateModF, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_msad4, TranslateMSad4, DXIL::OpCode::NumOpCodes},
-    {IntrinsicOp::IOP_mul, EmptyLower, DXIL::OpCode::NumOpCodes},
+    {IntrinsicOp::IOP_mul, TranslateMul, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_normalize, TranslateNormalize, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_pow, TranslatePow, DXIL::OpCode::NumOpCodes},
     {IntrinsicOp::IOP_radians, TranslateRadians, DXIL::OpCode::NumOpCodes},

+ 50 - 0
tools/clang/test/CodeGenHLSL/quick-test/mul-vector-scalar.hlsl

@@ -0,0 +1,50 @@
+// RUN: %dxc -T vs_6_0 -E main -Od %s  | FileCheck %s
+
+void main() {
+    // Exercises mul() for each vector/scalar operand pairing, first with float operands and then with int operands.
+    float3 fvec1 = { 0.1, 0.2, 0.3};
+    float4 fvec2 = { 1.1, 1.2, 1.3, 1.4};
+    float fx1 = 0.5;
+    float fx2 = 1.5;
+
+// CHECK: call float @dx.op.dot3.f32
+    float4 a = mul(fvec1, fvec2); // vector x vector, expected as a dot3 op; NOTE(review): operands are float3 and float4 and the result is stored in a float4 — confirm the mismatch is intentional
+
+// CHECK: fmul fast float
+// CHECK: fmul fast float
+// CHECK: fmul fast float
+    float3 b = mul(fvec1, fx1); // vector x scalar: one fmul expected per component of the float3
+
+// CHECK: fmul fast float
+// CHECK: fmul fast float
+// CHECK: fmul fast float
+    float3 c = mul(fx1, fvec1); // scalar x vector: same expectation with the operands swapped
+
+// CHECK: fmul fast float
+    float d = mul(fx1, fx2); // scalar x scalar: a single fmul
+    
+    int4 ivec1 = { 1, 2, 3, 4};
+    int3 ivec2 = { 4, 5, 6};
+    int i1 = 1;
+    int i2 = 2;   
+
+// CHECK: mul i32
+// CHECK: call i32 @dx.op.tertiary.i32(i32 48,
+// CHECK: call i32 @dx.op.tertiary.i32(i32 48,
+    int e = mul(ivec1, ivec2); // integer vector x vector (int4 and int3): expected as one mul followed by two dx.op.tertiary calls
+
+// CHECK: mul i32
+// CHECK: mul i32
+// CHECK: mul i32
+// CHECK: mul i32
+    int4 f = mul(ivec1, i1); // integer vector x scalar: one mul expected per component of the int4
+
+// CHECK: mul i32
+// CHECK: mul i32
+// CHECK: mul i32
+// CHECK: mul i32
+    int4 g = mul(i1, ivec1); // integer scalar x vector: same expectation with the operands swapped
+
+// CHECK: mul i32
+    int h = mul(i1, i2); // integer scalar x scalar: a single mul
+}