Przeglądaj źródła

[spirv] Optimize floating point matrix scaling codegen (#525)

SPIR-V has a specific OpMatrixTimesScalar for scaling floating
point matrices.
Lei Zhang 8 lat temu
rodzic
commit
76796801b8

+ 85 - 3
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -122,6 +122,18 @@ bool isCompoundAssignment(BinaryOperatorKind opcode) {
   }
 }
 
+bool isSpirvMatrixOp(spv::Op opcode) {
+  switch (opcode) {
+  case spv::Op::OpMatrixTimesMatrix:
+  case spv::Op::OpMatrixTimesVector:
+  case spv::Op::OpMatrixTimesScalar:
+    return true;
+  default:
+    break;
+  }
+  return false;
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
@@ -731,8 +743,10 @@ uint32_t SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
   if (opcode == BO_Assign)
     return processAssignment(expr->getLHS(), doExpr(expr->getRHS()), false);
 
-  // Try to optimize floatN * float case
+  // Try to optimize floatMxN * float and floatN * float case
   if (opcode == BO_Mul) {
+    if (const uint32_t result = tryToGenFloatMatrixScale(expr))
+      return result;
     if (const uint32_t result = tryToGenFloatVectorScale(expr))
       return result;
   }
@@ -940,8 +954,10 @@ uint32_t
 SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
   const auto opcode = expr->getOpcode();
 
-  // Try to optimize floatN *= float case
+  // Try to optimize floatMxN *= float and floatN *= float case
   if (opcode == BO_MulAssign) {
+    if (const uint32_t result = tryToGenFloatMatrixScale(expr))
+      return result;
     if (const uint32_t result = tryToGenFloatVectorScale(expr))
       return result;
   }
@@ -1370,7 +1386,11 @@ uint32_t SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                        const uint32_t resultType,
                                        uint32_t *lhsResultId,
                                        const spv::Op mandateGenOpcode) {
-  if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
+  // If the operands are of matrix type, we need to dispatch the operation
+  // onto each element vector iff the operands are not degenerated matrices
+  // and we don't have a matrix specific SPIR-V instruction for the operation.
+  if (!isSpirvMatrixOp(mandateGenOpcode) &&
+      TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
     return processMatrixBinaryOp(lhs, rhs, opcode);
   }
 
@@ -1647,6 +1667,68 @@ uint32_t SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
   return 0;
 }
 
+uint32_t SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
+  const QualType type = expr->getType();
+  // We can only translate floatMxN * float into OpMatrixTimesScalar.
+  // So the result type must be floatMxN.
+  if (!hlsl::IsHLSLMatType(type) ||
+      !hlsl::GetHLSLMatElementType(type)->isFloatingType())
+    return 0;
+
+  const Expr *lhs = expr->getLHS();
+  const Expr *rhs = expr->getRHS();
+  const QualType lhsType = lhs->getType();
+  const QualType rhsType = rhs->getType();
+
+  const auto selectOpcode = [](const QualType ty) {
+    return TypeTranslator::isMx1MatrixType(ty) ||
+                   TypeTranslator::is1xNMatrixType(ty)
+               ? spv::Op::OpVectorTimesScalar
+               : spv::Op::OpMatrixTimesScalar;
+  };
+
+  // Multiplying a float matrix with a float scalar will be represented in
+  // AST via a binary operation with two float matrices as operands; one of
+  // the operand is from an implicit cast with kind CK_HLSLMatrixSplat.
+
+  // matrix * scalar
+  if (hlsl::IsHLSLMatType(lhsType)) {
+    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
+      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
+        const uint32_t matType = typeTranslator.translateType(expr->getType());
+        const spv::Op opcode = selectOpcode(lhsType);
+        if (isa<CompoundAssignOperator>(expr)) {
+          uint32_t lhsPtr = 0;
+          const uint32_t result =
+              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
+                              matType, &lhsPtr, opcode);
+          return processAssignment(lhs, result, true, lhsPtr);
+        } else {
+          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
+                                 matType, nullptr, opcode);
+        }
+      }
+    }
+  }
+
+  // scalar * matrix
+  if (hlsl::IsHLSLMatType(rhsType)) {
+    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
+      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
+        const uint32_t matType = typeTranslator.translateType(expr->getType());
+        const spv::Op opcode = selectOpcode(rhsType);
+        // We need to switch the positions of lhs and rhs here because
+        // OpMatrixTimesScalar requires the first operand to be a matrix and
+        // the second to be a scalar.
+        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
+                               matType, nullptr, opcode);
+      }
+    }
+  }
+
+  return 0;
+}
+
 uint32_t SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
                                                    const uint32_t rhs) {
   // Assigning to a vector swizzling lhs is tricky if we are neither

+ 5 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -158,6 +158,11 @@ private:
   /// floatN * float.
   uint32_t tryToGenFloatVectorScale(const BinaryOperator *expr);
 
+  /// Translates a floatMxN * float multiplication into SPIR-V instructions and
+  /// returns the <result-id>. Returns 0 if the given binary operation is not
+  /// floatMxN * float.
+  uint32_t tryToGenFloatMatrixScale(const BinaryOperator *expr);
+
   /// Tries to emit instructions for assigning to the given vector element
   /// accessing expression. Returns 0 if the trial fails and no instructions
   /// are generated.

+ 10 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -155,6 +155,16 @@ bool TypeTranslator::is1xNMatrixType(QualType type) {
   return rowCount == 1 && colCount > 1;
 }
 
+bool TypeTranslator::isMx1MatrixType(QualType type) {
+  if (!hlsl::IsHLSLMatType(type))
+    return false;
+
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+  return rowCount > 1 && colCount == 1;
+}
+
 /// Returns true if the given type is a SPIR-V acceptable matrix type, i.e.,
 /// with floating point elements and greater than 1 row and column counts.
 bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {

+ 3 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -48,6 +48,9 @@ public:
   /// \brief Returns true if the givne type is a 1xN (N > 1) matrix type.
   static bool is1xNMatrixType(QualType type);
 
+  /// \brief Returns true if the givne type is a Mx1 (M > 1) matrix type.
+  static bool isMx1MatrixType(QualType type);
+
   /// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
   /// i.e., with floating point elements and greater than 1 row and column
   /// counts.

+ 57 - 7
tools/clang/test/CodeGenSPIRV/binary-op.arith-assign.mixed.hlsl

@@ -1,7 +1,5 @@
 // Run: %dxc -T vs_6_0 -E main
 
-// TODO: matrix *= scalar
-
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 
@@ -11,18 +9,70 @@ void main() {
     int3 c;
     int t;
 
+    float1 e;
+    int1 g;
+
+    float2x3 i;
+    float1x3 k;
+    float2x1 m;
+    float1x1 o;
+
     // Use OpVectorTimesScalar for floatN * float
-// CHECK:      [[s4:%\d+]] = OpLoad %float %s
-// CHECK-NEXT: [[a4:%\d+]] = OpLoad %v4float %a
-// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
+// CHECK:      [[s0:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[a0:%\d+]] = OpLoad %v4float %a
+// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a0]] [[s0]]
 // CHECK-NEXT: OpStore %a [[mul0]]
     a *= s;
 
     // Use normal OpCompositeConstruct and OpIMul for intN * int
 // CHECK-NEXT: [[t0:%\d+]] = OpLoad %int %t
-// CHECK-NEXT: [[cc10:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
+// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3int [[t0]] [[t0]] [[t0]]
 // CHECK-NEXT: [[c0:%\d+]] = OpLoad %v3int %c
-// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc10]]
+// CHECK-NEXT: [[mul2:%\d+]] = OpIMul %v3int [[c0]] [[cc0]]
 // CHECK-NEXT: OpStore %c [[mul2]]
     c *= t;
+
+    // Vector of size 1
+// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
+// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
+// CHECK-NEXT: OpStore %e [[mul4]]
+    e *= s;
+// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
+// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
+// CHECK-NEXT: OpStore %g [[mul6]]
+    g *= t;
+
+    // Use OpMatrixTimesScalar for floatMxN * float
+// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
+// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
+// CHECK-NEXT: OpStore %i [[mul8]]
+    i *= s;
+
+    // Use OpVectorTimesScalar for float1xN * float
+    // Sadly, the AST is constructed differently for 'float1xN *= float' cases.
+    // So we are not able generate an OpVectorTimesScalar here.
+    // TODO: Minor issue. Fix this later maybe.
+// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v3float [[s6]] [[s6]] [[s6]]
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
+// CHECK-NEXT: [[mul10:%\d+]] = OpFMul %v3float [[k0]] [[cc1]]
+// CHECK-NEXT: OpStore %k [[mul10]]
+    k *= s;
+
+    // Use OpVectorTimesScalar for floatMx1 * float
+// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
+// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
+// CHECK-NEXT: OpStore %m [[mul12]]
+    m *= s;
+
+    // Matrix of size 1x1
+// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
+// CHECK-NEXT: OpStore %o [[mul14]]
+    o *= s;
 }

+ 82 - 6
tools/clang/test/CodeGenSPIRV/binary-op.arithmetic.mixed.hlsl

@@ -1,7 +1,5 @@
 // Run: %dxc -T vs_6_0 -E main
 
-// TODO: matrix * scalar
-
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 
@@ -11,15 +9,23 @@ void main() {
     int3 c, d;
     int t;
 
+    float1 e, f;
+    int1 g, h;
+
+    float2x3 i, j;
+    float1x3 k, l;
+    float2x1 m, n;
+    float1x1 o, p;
+
     // Use OpVectorTimesScalar for floatN * float
 // CHECK:      [[a4:%\d+]] = OpLoad %v4float %a
-// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
-// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s4]]
+// CHECK-NEXT: [[s0:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul0:%\d+]] = OpVectorTimesScalar %v4float [[a4]] [[s0]]
 // CHECK-NEXT: OpStore %b [[mul0]]
     b = a * s;
 // CHECK-NEXT: [[a5:%\d+]] = OpLoad %v4float %a
-// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
-// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s5]]
+// CHECK-NEXT: [[s1:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul1:%\d+]] = OpVectorTimesScalar %v4float [[a5]] [[s1]]
 // CHECK-NEXT: OpStore %b [[mul1]]
     b = s * a;
 
@@ -36,4 +42,74 @@ void main() {
 // CHECK-NEXT: [[mul3:%\d+]] = OpIMul %v3int [[cc11]] [[c1]]
 // CHECK-NEXT: OpStore %d [[mul3]]
     d = t * c;
+
+    // Vector of size 1
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %float %e
+// CHECK-NEXT: [[s2:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul4:%\d+]] = OpFMul %float [[e0]] [[s2]]
+// CHECK-NEXT: OpStore %f [[mul4]]
+    f = e * s;
+// CHECK-NEXT: [[s3:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[e1:%\d+]] = OpLoad %float %e
+// CHECK-NEXT: [[mul5:%\d+]] = OpFMul %float [[s3]] [[e1]]
+// CHECK-NEXT: OpStore %f [[mul5]]
+    f = s * e;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %int %g
+// CHECK-NEXT: [[t2:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: [[mul6:%\d+]] = OpIMul %int [[g0]] [[t2]]
+// CHECK-NEXT: OpStore %h [[mul6]]
+    h = g * t;
+// CHECK-NEXT: [[t3:%\d+]] = OpLoad %int %t
+// CHECK-NEXT: [[g1:%\d+]] = OpLoad %int %g
+// CHECK-NEXT: [[mul7:%\d+]] = OpIMul %int [[t3]] [[g1]]
+// CHECK-NEXT: OpStore %h [[mul7]]
+    h = t * g;
+
+    // Use OpMatrixTimesScalar for floatMxN * float
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat2v3float %i
+// CHECK-NEXT: [[s4:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul8:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i0]] [[s4]]
+// CHECK-NEXT: OpStore %j [[mul8]]
+    j = i * s;
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %mat2v3float %i
+// CHECK-NEXT: [[s5:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul9:%\d+]] = OpMatrixTimesScalar %mat2v3float [[i1]] [[s5]]
+// CHECK-NEXT: OpStore %j [[mul9]]
+    j = s * i;
+
+    // Use OpVectorTimesScalar for float1xN * float
+// CHECK-NEXT: [[k0:%\d+]] = OpLoad %v3float %k
+// CHECK-NEXT: [[s6:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul10:%\d+]] = OpVectorTimesScalar %v3float [[k0]] [[s6]]
+// CHECK-NEXT: OpStore %l [[mul10]]
+    l = k * s;
+// CHECK-NEXT: [[k1:%\d+]] = OpLoad %v3float %k
+// CHECK-NEXT: [[s7:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul11:%\d+]] = OpVectorTimesScalar %v3float [[k1]] [[s7]]
+// CHECK-NEXT: OpStore %l [[mul11]]
+    l = s * k;
+
+    // Use OpVectorTimesScalar for floatMx1 * float
+// CHECK-NEXT: [[m0:%\d+]] = OpLoad %v2float %m
+// CHECK-NEXT: [[s8:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul12:%\d+]] = OpVectorTimesScalar %v2float [[m0]] [[s8]]
+// CHECK-NEXT: OpStore %n [[mul12]]
+    n = m * s;
+// CHECK-NEXT: [[m1:%\d+]] = OpLoad %v2float %m
+// CHECK-NEXT: [[s9:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul13:%\d+]] = OpVectorTimesScalar %v2float [[m1]] [[s9]]
+// CHECK-NEXT: OpStore %n [[mul13]]
+    n = s * m;
+
+    // Matrix of size 1x1
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[s10:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[mul14:%\d+]] = OpFMul %float [[o0]] [[s10]]
+// CHECK-NEXT: OpStore %p [[mul14]]
+    p = o * s;
+// CHECK-NEXT: [[s11:%\d+]] = OpLoad %float %s
+// CHECK-NEXT: [[o1:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[mul15:%\d+]] = OpFMul %float [[s11]] [[o1]]
+// CHECK-NEXT: OpStore %p [[mul15]]
+    p = s * o;
 }