Browse Source

[spirv] Translate intrinsic HLSL mul() function (#547)

Ehsan 8 years ago
parent
commit
9d7db791eb

+ 135 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2497,6 +2497,8 @@ uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
   case hlsl::IntrinsicOp::IOP_dot:
     return processIntrinsicDot(callExpr);
+  case hlsl::IntrinsicOp::IOP_mul:
+    return processIntrinsicMul(callExpr);
   case hlsl::IntrinsicOp::IOP_all:
     return processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
   case hlsl::IntrinsicOp::IOP_any:
@@ -2550,6 +2552,139 @@ uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, actOnEachVecInMat);
 }
 
+uint32_t SPIRVEmitter::processIntrinsicMul(const CallExpr *callExpr) {
+  // Translates a call to the HLSL mul() intrinsic. Returns the result id, or
+  // 0 after reporting an error for an unsupported argument combination.
+  const QualType returnType = callExpr->getType();
+  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
+
+  // Get the function parameters. Expect 2 parameters.
+  assert(callExpr->getNumArgs() == 2u);
+  const Expr *arg0 = callExpr->getArg(0);
+  const Expr *arg1 = callExpr->getArg(1);
+  const QualType arg0Type = arg0->getType();
+  const QualType arg1Type = arg1->getType();
+
+  // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
+  // vector, or matrix. The frontend ensures that the two arguments have the
+  // same component type. The only allowed component types are int and float.
+
+  // mul(scalar, vector)
+  {
+    uint32_t elemCount = 0;
+    if (TypeTranslator::isScalarType(arg0Type) &&
+        TypeTranslator::isVectorType(arg1Type, nullptr, &elemCount)) {
+      // The vector operand (arg1) is translated before the scalar (arg0).
+      const uint32_t arg1Id = doExpr(arg1);
+
+      // We can use OpVectorTimesScalar if arguments are floats.
+      if (arg0Type->isFloatingType())
+        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
+                                         returnTypeId, arg1Id, doExpr(arg0));
+
+      // Use OpIMul for integers, after splatting the scalar into a vector.
+      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId,
+                                       createVectorSplat(arg0, elemCount),
+                                       arg1Id);
+    }
+  }
+
+  // mul(vector, scalar)
+  {
+    uint32_t elemCount = 0;
+    if (TypeTranslator::isVectorType(arg0Type, nullptr, &elemCount) &&
+        TypeTranslator::isScalarType(arg1Type)) {
+      const uint32_t arg0Id = doExpr(arg0);
+
+      // We can use OpVectorTimesScalar if arguments are floats.
+      if (arg1Type->isFloatingType())
+        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
+                                         returnTypeId, arg0Id, doExpr(arg1));
+
+      // Use OpIMul for integers, after splatting the scalar into a vector.
+      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId, arg0Id,
+                                       createVectorSplat(arg1, elemCount));
+    }
+  }
+
+  // mul(vector, vector) is a dot product, so reuse the dot() translation.
+  if (TypeTranslator::isVectorType(arg0Type) &&
+      TypeTranslator::isVectorType(arg1Type))
+    return processIntrinsicDot(callExpr);
+
+  // All the following cases require handling arg0 and arg1 expressions first.
+  const uint32_t arg0Id = doExpr(arg0);
+  const uint32_t arg1Id = doExpr(arg1);
+
+  // mul(scalar, scalar)
+  if (TypeTranslator::isScalarType(arg0Type) &&
+      TypeTranslator::isScalarType(arg1Type))
+    return theBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type),
+                                     returnTypeId, arg0Id, arg1Id);
+
+  // mul(scalar, matrix)
+  if (TypeTranslator::isScalarType(arg0Type) &&
+      TypeTranslator::isMxNMatrix(arg1Type)) {
+    // Only float matrices are supported for now, so use OpMatrixTimesScalar.
+    if (arg0Type->isFloatingType())
+      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
+                                       returnTypeId, arg1Id, arg0Id);
+  }
+
+  // mul(matrix, scalar)
+  if (TypeTranslator::isScalarType(arg1Type) &&
+      TypeTranslator::isMxNMatrix(arg0Type)) {
+    // Only float matrices are supported for now, so use OpMatrixTimesScalar.
+    if (arg1Type->isFloatingType())
+      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
+                                       returnTypeId, arg0Id, arg1Id);
+  }
+
+  // The matrix cases below pass the operands in swapped order: an HLSL MxN
+  // matrix is emitted as M column vectors of size N, i.e. as its transpose.
+  // mul(vector, matrix)
+  {
+    QualType elemType = {};
+    uint32_t elemCount = 0, numRows = 0;
+    if (TypeTranslator::isVectorType(arg0Type, &elemType, &elemCount) &&
+        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &numRows, nullptr) &&
+        elemType->isFloatingType()) {
+      assert(elemCount == numRows);
+      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
+                                       returnTypeId, arg1Id, arg0Id);
+    }
+  }
+
+  // mul(matrix, vector)
+  {
+    QualType elemType = {};
+    uint32_t elemCount = 0, numCols = 0;
+    if (TypeTranslator::isMxNMatrix(arg0Type, nullptr, nullptr, &numCols) &&
+        TypeTranslator::isVectorType(arg1Type, &elemType, &elemCount) &&
+        elemType->isFloatingType()) {
+      assert(elemCount == numCols);
+      return theBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
+                                       returnTypeId, arg1Id, arg0Id);
+    }
+  }
+
+  // mul(matrix, matrix)
+  {
+    QualType elemType = {};
+    uint32_t arg0Cols = 0, arg1Rows = 0;
+    if (TypeTranslator::isMxNMatrix(arg0Type, &elemType, nullptr, &arg0Cols) &&
+        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &arg1Rows, nullptr) &&
+        elemType->isFloatingType()) {
+      assert(arg0Cols == arg1Rows);
+      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
+                                       returnTypeId, arg1Id, arg0Id);
+    }
+  }
+
+  emitError("Unsupported arguments passed to mul() function.");
+  return 0;
+}
+
 uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
   const QualType returnType = callExpr->getType();
   const uint32_t returnTypeId =

+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -227,6 +227,9 @@ private:
   /// Processes HLSL instrinsic functions.
   uint32_t processIntrinsicCallExpr(const CallExpr *);
 
+  /// Processes the 'mul' intrinsic function.
+  uint32_t processIntrinsicMul(const CallExpr *);
+
   /// Processes the 'dot' intrinsic function.
   uint32_t processIntrinsicDot(const CallExpr *);
 

+ 142 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.mul.hlsl

@@ -0,0 +1,142 @@
+// Run: %dxc -T ps_6_0 -E main
+
+/*
+According to HLSL reference, mul() has the following versions:
+
+|Name|Purpose|Template|Component Type  |size|
+|====|=======|========|================|==========================================================================|
+|x   |in     |scalar  |float, int      |1                                                                         |
+|y   |in     |scalar  |same as input x |1                                                                         |
+|ret |out    |scalar  |same as input x |1                                                                         |
+|====|=======|========|================|==========================================================================|
+|x   |in     |scalar  |float, int      | 1                                                                        |
+|y   |in     |vector  |float, int      |any                                                                       |
+|ret |out    |vector  |float, int      |same dimension(s) as input y                                              |
+|====|=======|========|================|==========================================================================|
+|x   |in     |scalar  |float, int      |1                                                                         |
+|y   |in     |matrix  |float, int      |any                                                                       |
+|ret |out    |matrix  |same as input y |same dimension(s) as input y                                              |
+|====|=======|========|================|==========================================================================|
+|x   |in     |vector  |float, int      |any                                                                       |
+|y   |in     |scalar  |float, int      |1                                                                         |
+|ret |out    |vector  |float, int      |same dimension(s) as input x                                              |
+|====|=======|========|================|==========================================================================|
+|x   |in     |vector  |float, int      |any                                                                       |
+|y   |in     |vector  |float, int      |same dimension(s) as input x                                              |
+|ret |out    |scalar  |float, int      |1                                                                         |
+|====|=======|========|================|==========================================================================|
+|x   |in     |vector  |float, int      |any                                                                       |
+|y   |in     |matrix  |float, int      |rows = same dimension(s) as input x, columns = any                        |
+|ret |out    |vector  |float, int      |same dimension(s) as input y columns                                      |
+|====|=======|========|================|==========================================================================|
+|x   |in     |matrix  |float, int      |any                                                                       |
+|y   |in     |scalar  |float, int      |1                                                                         |
+|ret |out    |matrix  |float, int      |same dimension(s) as input x                                              |
+|====|=======|========|================|==========================================================================|
+|x   |in     |matrix  |float, int      |any                                                                       |
+|y   |in     |vector  |float, int      |number of columns in input x                                              |
+|ret |out    |vector  |float, int      |number of rows in input x                                                 |
+|====|=======|========|================|==========================================================================|
+|x   |in     |matrix  |float, int      |any                                                                       |
+|y   |in     |matrix  |float, int      |rows = number of columns in input x                                       |
+|ret |out    |matrix  |float, int      |rows = number of rows in input x, columns = number of columns in input y  |
+|====|=======|========|================|==========================================================================|
+*/
+
+void main() {
+  // scalar * scalar: a single floating-point multiply is expected.
+  float a, b;
+// CHECK: {{%\d+}} = OpFMul %float {{%\d+}} {{%\d+}}
+  float scalarMulscalar = mul(a,b);
+
+  float float_c;
+  float4 float4_d;
+  // float scalar * vector: the vector is loaded first, then the scalar.
+// CHECK:      [[float4_d:%\d+]] = OpLoad %v4float %float4_d
+// CHECK-NEXT: [[float_c:%\d+]] = OpLoad %float %float_c
+// CHECK-NEXT: {{%\d+}} = OpVectorTimesScalar %v4float [[float4_d]] [[float_c]]
+  float4 float_scalarMulVector = mul(float_c,float4_d);
+  // float vector * scalar: same instruction and operand order as above.
+// CHECK:      [[float4_d1:%\d+]] = OpLoad %v4float %float4_d
+// CHECK-NEXT: [[float_c1:%\d+]] = OpLoad %float %float_c
+// CHECK-NEXT: {{%\d+}} = OpVectorTimesScalar %v4float [[float4_d1]] [[float_c1]]
+  float4 float_vectorMulScalar = mul(float4_d,float_c);
+
+  int int_c;
+  int4 int4_d;
+  // int scalar * vector: the scalar is splatted to a vector for OpIMul.
+// CHECK:      [[int4_d:%\d+]] = OpLoad %v4int %int4_d
+// CHECK-NEXT: [[int_c:%\d+]] = OpLoad %int %int_c
+// CHECK-NEXT: [[c_splat:%\d+]] = OpCompositeConstruct %v4int [[int_c]] [[int_c]] [[int_c]] [[int_c]]
+// CHECK-NEXT: {{%\d+}} = OpIMul %v4int [[c_splat]] [[int4_d]]
+  int4 int_scalarMulVector = mul(int_c,int4_d);
+  // int vector * scalar: same lowering; OpIMul operands follow the call order.
+// CHECK:      [[int4_d1:%\d+]] = OpLoad %v4int %int4_d
+// CHECK-NEXT: [[int_c1:%\d+]] = OpLoad %int %int_c
+// CHECK-NEXT: [[c_splat1:%\d+]] = OpCompositeConstruct %v4int [[int_c1]] [[int_c1]] [[int_c1]] [[int_c1]]
+// CHECK-NEXT: {{%\d+}} = OpIMul %v4int [[int4_d1]] [[c_splat1]]
+  int4 int_vectorMulScalar = mul(int4_d,int_c);
+  // Matrix cases follow; note that float3x4 is loaded as %mat3v4float.
+  float e;
+  float3x4 f;
+  // scalar * matrix: OpMatrixTimesScalar takes the matrix operand first.
+// CHECK:      [[e:%\d+]] = OpLoad %float %e
+// CHECK-NEXT: [[f:%\d+]] = OpLoad %mat3v4float %f
+// CHECK-NEXT: {{%\d+}} = OpMatrixTimesScalar %mat3v4float [[f]] [[e]]
+  float3x4 scalarMulMatrix = mul(e,f);
+  // matrix * scalar: same instruction; only the load order differs.
+// CHECK:      [[f1:%\d+]] = OpLoad %mat3v4float %f
+// CHECK-NEXT: [[e1:%\d+]] = OpLoad %float %e
+// CHECK-NEXT: {{%\d+}} = OpMatrixTimesScalar %mat3v4float [[f1]] [[e1]]
+  float3x4 matrixMulScalar = mul(f,e);
+
+  // int vector * vector: expanded into per-component OpIMul plus OpIAdd.
+  int4 g,h;
+// CHECK:      [[g:%\d+]] = OpLoad %v4int %g
+// CHECK-NEXT: [[h:%\d+]] = OpLoad %v4int %h
+// CHECK-NEXT: [[g0:%\d+]] = OpCompositeExtract %int [[g]] 0
+// CHECK-NEXT: [[h0:%\d+]] = OpCompositeExtract %int [[h]] 0
+// CHECK-NEXT: [[g0h0:%\d+]] = OpIMul %int [[g0]] [[h0]]
+// CHECK-NEXT: [[g1:%\d+]] = OpCompositeExtract %int [[g]] 1
+// CHECK-NEXT: [[h1:%\d+]] = OpCompositeExtract %int [[h]] 1
+// CHECK-NEXT: [[g1h1:%\d+]] = OpIMul %int [[g1]] [[h1]]
+// CHECK-NEXT: [[g2:%\d+]] = OpCompositeExtract %int [[g]] 2
+// CHECK-NEXT: [[h2:%\d+]] = OpCompositeExtract %int [[h]] 2
+// CHECK-NEXT: [[g2h2:%\d+]] = OpIMul %int [[g2]] [[h2]]
+// CHECK-NEXT: [[g3:%\d+]] = OpCompositeExtract %int [[g]] 3
+// CHECK-NEXT: [[h3:%\d+]] = OpCompositeExtract %int [[h]] 3
+// CHECK-NEXT: [[g3h3:%\d+]] = OpIMul %int [[g3]] [[h3]]
+// CHECK-NEXT: [[add_1:%\d+]] = OpIAdd %int [[g0h0]] [[g1h1]]
+// CHECK-NEXT: [[add_2:%\d+]] = OpIAdd %int [[add_1]] [[g2h2]]
+// CHECK-NEXT: [[add_3:%\d+]] = OpIAdd %int [[add_2]] [[g3h3]]
+// CHECK-NEXT: OpStore %vectorMulVector [[add_3]]
+  int vectorMulVector = mul(g,h);
+  // float vector * vector: maps directly to OpDot.
+  float3 float_g, float_h;
+// CHECK:      [[float_g:%\d+]] = OpLoad %v3float %float_g
+// CHECK-NEXT: [[float_h:%\d+]] = OpLoad %v3float %float_h
+// CHECK-NEXT: {{%\d+}} = OpDot %float [[float_g]] [[float_h]]
+  float float_vectorMulVector = mul(float_g, float_h);
+  // vector * matrix: expect OpMatrixTimesVector with the operands swapped.
+  float4 i;
+  float4x3 j;
+// CHECK:      [[i:%\d+]] = OpLoad %v4float %i
+// CHECK-NEXT: [[j:%\d+]] = OpLoad %mat4v3float %j
+// CHECK-NEXT: {{%\d+}} = OpMatrixTimesVector %v3float [[j]] [[i]]
+  float3 vectorMulMatrix = mul(i,j);
+  // matrix * vector: expect OpVectorTimesMatrix with the operands swapped.
+  float2x3 k;
+  float3 l;
+// CHECK:      [[k:%\d+]] = OpLoad %mat2v3float %k
+// CHECK-NEXT: [[l:%\d+]] = OpLoad %v3float %l
+// CHECK-NEXT: {{%\d+}} = OpVectorTimesMatrix %v2float [[l]] [[k]]
+  float2 matrixMulVector = mul(k,l);
+
+  // matrix * matrix: expect OpMatrixTimesMatrix with the operands swapped.
+  float3x4 m;
+  float4x2 n;
+// CHECK:      [[m:%\d+]] = OpLoad %mat3v4float %m
+// CHECK-NEXT: [[n:%\d+]] = OpLoad %mat4v2float %n
+// CHECK-NEXT: {{%\d+}} = OpMatrixTimesMatrix %mat3v2float [[n]] [[m]]
+  float3x2 matrixMulMatrix = mul(m,n);
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -262,6 +262,7 @@ TEST_F(FileTest, SemanticArbitrary) { runFileTest("semantic.arbitrary.hlsl"); }
 
 // For intrinsic functions
 TEST_F(FileTest, IntrinsicsDot) { runFileTest("intrinsics.dot.hlsl"); }
+TEST_F(FileTest, IntrinsicsMul) { runFileTest("intrinsics.mul.hlsl"); }
 TEST_F(FileTest, IntrinsicsAll) { runFileTest("intrinsics.all.hlsl"); }
 TEST_F(FileTest, IntrinsicsAny) { runFileTest("intrinsics.any.hlsl"); }
 TEST_F(FileTest, IntrinsicsAsfloat) { runFileTest("intrinsics.asfloat.hlsl"); }