Browse Source

[spirv] Add support for matrix arithmetic operations (#508)

Lei Zhang 8 years ago
parent
commit
2fde79b470

+ 18 - 0
tools/clang/include/clang/SPIRV/TypeTranslator.h

@@ -37,6 +37,24 @@ public:
   /// on will be generated.
   uint32_t translateType(QualType type);
 
+  /// \brief Returns true if the givne type is a 1x1 matrix type.
+  static bool is1x1MatrixType(QualType type);
+
+  /// \brief Returns true if the givne type is a 1xN (N > 1) matrix type.
+  static bool is1xNMatrixType(QualType type);
+
+  /// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
+  /// i.e., with floating point elements and greater than 1 row and column
+  /// counts.
+  static bool isSpirvAcceptableMatrixType(QualType type);
+
+  /// \brief Generates the corresponding SPIR-V vector type for the given Clang
+  /// frontend matrix type's vector component and returns the <result-id>.
+  ///
+  /// This method will panic if the given matrix type is not a SPIR-V acceptable
+  /// matrix type.
+  uint32_t getComponentVectorType(QualType matrixType);
+
 private:
   /// \brief Wrapper method to create an error message and report it
   /// in the diagnostic engine associated with this consumer.

+ 128 - 4
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -27,6 +27,8 @@ namespace spirv {
 
 namespace {
 
+// TODO: Maybe we should move these type probing functions to TypeTranslator.
+
 /// Returns true if the given type is a bool or vector of bool type.
 bool isBoolOrVecOfBoolType(QualType type) {
   return type->isBooleanType() ||
@@ -57,6 +59,36 @@ bool isFloatOrVecOfFloatType(QualType type) {
           hlsl::GetHLSLVecElementType(type)->isFloatingType());
 }
 
+/// Returns true if the given type is a bool or vector/matrix of bool type.
+bool isBoolOrVecMatOfBoolType(QualType type) {
+  return isBoolOrVecOfBoolType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isBooleanType());
+}
+
+/// Returns true if the given type is a signed integer or vector/matrix of
+/// signed integer type.
+bool isSintOrVecMatOfSintType(QualType type) {
+  return isSintOrVecOfSintType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
+}
+
+/// Returns true if the given type is an unsigned integer or vector/matrix of
+/// unsigned integer type.
+bool isUintOrVecMatOfUintType(QualType type) {
+  return isUintOrVecOfUintType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
+}
+
+/// Returns true if the given type is a float or vector/matrix of float type.
+bool isFloatOrVecMatOfFloatType(QualType type) {
+  return isFloatOrVecOfFloatType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isFloatingType());
+}
+
 bool isCompoundAssignment(BinaryOperatorKind opcode) {
   switch (opcode) {
   case BO_AddAssign:
@@ -1039,6 +1071,72 @@ public:
     return isCompoundAssignment ? lhsPtr : rhs;
   }
 
+  /// Generates the necessary instructions for conducting the given binary
+  /// operation on lhs and rhs.
+  ///
+  /// This method expects that both lhs and rhs are SPIR-V acceptable matrices.
+  uint32_t processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
+                                 const BinaryOperatorKind opcode) {
+    // TODO: some code are duplicated from processBinaryOp. Try to unify them.
+    const auto lhsType = lhs->getType();
+    assert(TypeTranslator::isSpirvAcceptableMatrixType(lhsType));
+    const spv::Op spvOp = translateOp(opcode, lhsType);
+
+    uint32_t rhsVal, lhsPtr, lhsVal;
+    if (isCompoundAssignment(opcode)) {
+      // Evalute rhs before lhs
+      rhsVal = doExpr(rhs);
+      lhsPtr = doExpr(lhs);
+      const uint32_t lhsTy = typeTranslator.translateType(lhsType);
+      lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
+    } else {
+      // Evalute lhs before rhs
+      lhsVal = lhsPtr = doExpr(lhs);
+      rhsVal = doExpr(rhs);
+    }
+
+    switch (opcode) {
+    case BO_Add:
+    case BO_Sub:
+    case BO_Mul:
+    case BO_Div:
+    case BO_Rem:
+    case BO_AddAssign:
+    case BO_SubAssign:
+    case BO_MulAssign:
+    case BO_DivAssign:
+    case BO_RemAssign: {
+      const uint32_t vecType = typeTranslator.getComponentVectorType(lhsType);
+
+      uint32_t rowCount = 0, colCount = 0;
+      hlsl::GetHLSLMatRowColCount(lhsType, rowCount, colCount);
+
+      llvm::SmallVector<uint32_t, 4> vectors;
+      // Extract each component vector and do operation on it
+      for (uint32_t i = 0; i < rowCount; ++i) {
+        const uint32_t lhsVec =
+            theBuilder.createCompositeExtract(vecType, lhsVal, {i});
+        const uint32_t rhsVec =
+            theBuilder.createCompositeExtract(vecType, rhsVal, {i});
+        vectors.push_back(
+            theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec));
+      }
+
+      // Construct the result matrix
+      return theBuilder.createCompositeConstruct(
+          typeTranslator.translateType(lhsType), vectors);
+    }
+    case BO_Assign:
+      llvm_unreachable("assignment should not be handled here");
+    default:
+      break;
+    }
+
+    emitError("BinaryOperator '%0' for matrices not supported yet")
+        << BinaryOperator::getOpcodeStr(opcode);
+    return 0;
+  }
+
   /// Generates the necessary instructions for conducting the given binary
   /// operation on lhs and rhs. If lhsResultId is not nullptr, the evaluated
   /// pointer from lhs during the process will be written into it. If
@@ -1049,6 +1147,11 @@ public:
                            const uint32_t resultType,
                            uint32_t *lhsResultId = nullptr,
                            const spv::Op mandateGenOpcode = spv::Op::Max) {
+
+    if (TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
+      return processMatrixBinaryOp(lhs, rhs, opcode);
+    }
+
     const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
                               ? translateOp(opcode, lhs->getType())
                               : mandateGenOpcode;
@@ -1484,6 +1587,28 @@ public:
       assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
       return doExpr(subExpr);
     }
+    case CastKind::CK_HLSLVectorToMatrixCast: {
+      // The target type should already be a 1xN matrix type.
+      assert(TypeTranslator::is1xNMatrixType(toType));
+      return doExpr(subExpr);
+    }
+    case CastKind::CK_HLSLMatrixSplat: {
+      if (TypeTranslator::is1x1MatrixType(toType))
+        return doExpr(subExpr);
+
+      emitError("matrix splatting not supported yet");
+      return 0;
+    }
+    case CastKind::CK_HLSLMatrixToScalarCast: {
+      // The underlying should already be a matrix of 1x1.
+      assert(TypeTranslator::is1x1MatrixType(subExpr->getType()));
+      return doExpr(subExpr);
+    }
+    case CastKind::CK_HLSLMatrixToVectorCast: {
+      // The underlying should already be a matrix of 1xN.
+      assert(TypeTranslator::is1xNMatrixType(subExpr->getType()));
+      return doExpr(subExpr);
+    }
     case CastKind::CK_FunctionToPointerDecay:
       // Just need to return the function id
       return doExpr(subExpr);
@@ -1840,10 +1965,9 @@ public:
   /// Translates the given frontend binary operator into its SPIR-V equivalent
   /// taking consideration of the operand type.
   spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
-    // TODO: the following is not considering vector types yet.
-    const bool isSintType = isSintOrVecOfSintType(type);
-    const bool isUintType = isUintOrVecOfUintType(type);
-    const bool isFloatType = isFloatOrVecOfFloatType(type);
+    const bool isSintType = isSintOrVecMatOfSintType(type);
+    const bool isUintType = isUintOrVecMatOfUintType(type);
+    const bool isFloatType = isFloatOrVecMatOfFloatType(type);
 
 #define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                      \
   \

+ 47 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -106,5 +106,52 @@ uint32_t TypeTranslator::translateType(QualType type) {
   return 0;
 }
 
+bool TypeTranslator::is1x1MatrixType(QualType type) {
+  if (!hlsl::IsHLSLMatType(type))
+    return false;
+
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+  return rowCount == 1 && colCount == 1;
+}
+
+bool TypeTranslator::is1xNMatrixType(QualType type) {
+  if (!hlsl::IsHLSLMatType(type))
+    return false;
+
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+  return rowCount == 1 && colCount > 1;
+}
+
+/// Returns true if the given type is a SPIR-V acceptable matrix type, i.e.,
+/// with floating point elements and greater than 1 row and column counts.
+bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
+  if (!hlsl::IsHLSLMatType(type))
+    return false;
+
+  const auto elemType = hlsl::GetHLSLMatElementType(type);
+  if (!elemType->isFloatingType())
+    return false;
+
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+  return rowCount > 1 && colCount > 1;
+}
+
+uint32_t TypeTranslator::getComponentVectorType(QualType matrixType) {
+  assert(isSpirvAcceptableMatrixType(matrixType));
+
+  const uint32_t elemType =
+      translateType(hlsl::GetHLSLMatElementType(matrixType));
+
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(matrixType, rowCount, colCount);
+
+  return theBuilder.getVecType(elemType, colCount);
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 55 - 0
tools/clang/test/CodeGenSPIRV/binary-op.arith-assign.matrix.hlsl

@@ -0,0 +1,55 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    float1x1 a, b;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[b1:%\d+]] = OpFAdd %float [[b0]] [[a0]]
+// CHECK-NEXT: OpStore %b [[b1]]
+    b += a;
+
+    float2x1 c, d;
+// CHECK-NEXT: [[c0:%\d+]] = OpLoad %v2float %c
+// CHECK-NEXT: [[d0:%\d+]] = OpLoad %v2float %d
+// CHECK-NEXT: [[d1:%\d+]] = OpFSub %v2float [[d0]] [[c0]]
+// CHECK-NEXT: OpStore %d [[d1]]
+    d -= c;
+
+    float1x3 e, f;
+// CHECK-NEXT: [[e0:%\d+]] = OpLoad %v3float %e
+// CHECK-NEXT: [[f0:%\d+]] = OpLoad %v3float %f
+// CHECK-NEXT: [[f1:%\d+]] = OpFMul %v3float [[f0]] [[e0]]
+// CHECK-NEXT: OpStore %f [[f1]]
+    f *= e;
+
+    float2x3 g, h;
+// CHECK-NEXT: [[g0:%\d+]] = OpLoad %mat2v3float %g
+// CHECK-NEXT: [[h0:%\d+]] = OpLoad %mat2v3float %h
+// CHECK-NEXT: [[h0v0:%\d+]] = OpCompositeExtract %v3float [[h0]] 0
+// CHECK-NEXT: [[g0v0:%\d+]] = OpCompositeExtract %v3float [[g0]] 0
+// CHECK-NEXT: [[h1v0:%\d+]] = OpFDiv %v3float [[h0v0]] [[g0v0]]
+// CHECK-NEXT: [[h0v1:%\d+]] = OpCompositeExtract %v3float [[h0]] 1
+// CHECK-NEXT: [[g0v1:%\d+]] = OpCompositeExtract %v3float [[g0]] 1
+// CHECK-NEXT: [[h1v1:%\d+]] = OpFDiv %v3float [[h0v1]] [[g0v1]]
+// CHECK-NEXT: [[h1:%\d+]] = OpCompositeConstruct %mat2v3float [[h1v0]] [[h1v1]]
+// CHECK-NEXT: OpStore %h [[h1]]
+    h /= g;
+
+    float3x2 i, j;
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %mat3v2float %i
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %mat3v2float %j
+// CHECK-NEXT: [[j0v0:%\d+]] = OpCompositeExtract %v2float [[j0]] 0
+// CHECK-NEXT: [[i0v0:%\d+]] = OpCompositeExtract %v2float [[i0]] 0
+// CHECK-NEXT: [[j1v0:%\d+]] = OpFRem %v2float [[j0v0]] [[i0v0]]
+// CHECK-NEXT: [[j0v1:%\d+]] = OpCompositeExtract %v2float [[j0]] 1
+// CHECK-NEXT: [[i0v1:%\d+]] = OpCompositeExtract %v2float [[i0]] 1
+// CHECK-NEXT: [[j1v1:%\d+]] = OpFRem %v2float [[j0v1]] [[i0v1]]
+// CHECK-NEXT: [[j0v2:%\d+]] = OpCompositeExtract %v2float [[j0]] 2
+// CHECK-NEXT: [[i0v2:%\d+]] = OpCompositeExtract %v2float [[i0]] 2
+// CHECK-NEXT: [[j1v2:%\d+]] = OpFRem %v2float [[j0v2]] [[i0v2]]
+// CHECK-NEXT: [[j1:%\d+]] = OpCompositeConstruct %mat3v2float [[j1v0]] [[j1v1]] [[j1v2]]
+// CHECK-NEXT: OpStore %j [[j1]]
+    j %= i;
+}

+ 2 - 0
tools/clang/test/CodeGenSPIRV/binary-op.arith-assign.mixed.hlsl

@@ -1,5 +1,7 @@
 // Run: %dxc -T vs_6_0 -E main
 
+// TODO: matrix *= scalar
+
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 

+ 147 - 0
tools/clang/test/CodeGenSPIRV/binary-op.arithmetic.matrix.hlsl

@@ -0,0 +1,147 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    // 1x1
+    float1x1 a, b, c;
+// CHECK:      [[a0:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[c0:%\d+]] = OpFAdd %float [[a0]] [[b0]]
+// CHECK-NEXT: OpStore %c [[c0]]
+    c = a + b;
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b1:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[c1:%\d+]] = OpFSub %float [[a1]] [[b1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+    c = a - b;
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b2:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[c2:%\d+]] = OpFMul %float [[a2]] [[b2]]
+// CHECK-NEXT: OpStore %c [[c2]]
+    c = a * b;
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b3:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[c3:%\d+]] = OpFDiv %float [[a3]] [[b3]]
+// CHECK-NEXT: OpStore %c [[c3]]
+    c = a / b;
+// CHECK-NEXT: [[a4:%\d+]] = OpLoad %float %a
+// CHECK-NEXT: [[b4:%\d+]] = OpLoad %float %b
+// CHECK-NEXT: [[c4:%\d+]] = OpFRem %float [[a4]] [[b4]]
+// CHECK-NEXT: OpStore %c [[c4]]
+    c = a % b;
+
+    // Mx1
+    float2x1 h, i, j;
+// CHECK-NEXT: [[h0:%\d+]] = OpLoad %v2float %h
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %v2float %i
+// CHECK-NEXT: [[j0:%\d+]] = OpFAdd %v2float [[h0]] [[i0]]
+// CHECK-NEXT: OpStore %j [[j0]]
+    j = h + i;
+// CHECK-NEXT: [[h1:%\d+]] = OpLoad %v2float %h
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %v2float %i
+// CHECK-NEXT: [[j1:%\d+]] = OpFSub %v2float [[h1]] [[i1]]
+// CHECK-NEXT: OpStore %j [[j1]]
+    j = h - i;
+// CHECK-NEXT: [[h2:%\d+]] = OpLoad %v2float %h
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %v2float %i
+// CHECK-NEXT: [[j2:%\d+]] = OpFMul %v2float [[h2]] [[i2]]
+// CHECK-NEXT: OpStore %j [[j2]]
+    j = h * i;
+// CHECK-NEXT: [[h3:%\d+]] = OpLoad %v2float %h
+// CHECK-NEXT: [[i3:%\d+]] = OpLoad %v2float %i
+// CHECK-NEXT: [[j3:%\d+]] = OpFDiv %v2float [[h3]] [[i3]]
+// CHECK-NEXT: OpStore %j [[j3]]
+    j = h / i;
+// CHECK-NEXT: [[h4:%\d+]] = OpLoad %v2float %h
+// CHECK-NEXT: [[i4:%\d+]] = OpLoad %v2float %i
+// CHECK-NEXT: [[j4:%\d+]] = OpFRem %v2float [[h4]] [[i4]]
+// CHECK-NEXT: OpStore %j [[j4]]
+    j = h % i;
+
+    // 1xN
+    float1x3 o, p, q;
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT: [[p0:%\d+]] = OpLoad %v3float %p
+// CHECK-NEXT: [[q0:%\d+]] = OpFAdd %v3float [[o0]] [[p0]]
+// CHECK-NEXT: OpStore %q [[q0]]
+    q = o + p;
+// CHECK-NEXT: [[o1:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT: [[p1:%\d+]] = OpLoad %v3float %p
+// CHECK-NEXT: [[q1:%\d+]] = OpFSub %v3float [[o1]] [[p1]]
+// CHECK-NEXT: OpStore %q [[q1]]
+    q = o - p;
+// CHECK-NEXT: [[o2:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT: [[p2:%\d+]] = OpLoad %v3float %p
+// CHECK-NEXT: [[q2:%\d+]] = OpFMul %v3float [[o2]] [[p2]]
+// CHECK-NEXT: OpStore %q [[q2]]
+    q = o * p;
+// CHECK-NEXT: [[o3:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT: [[p3:%\d+]] = OpLoad %v3float %p
+// CHECK-NEXT: [[q3:%\d+]] = OpFDiv %v3float [[o3]] [[p3]]
+// CHECK-NEXT: OpStore %q [[q3]]
+    q = o / p;
+// CHECK-NEXT: [[o4:%\d+]] = OpLoad %v3float %o
+// CHECK-NEXT: [[p4:%\d+]] = OpLoad %v3float %p
+// CHECK-NEXT: [[q4:%\d+]] = OpFRem %v3float [[o4]] [[p4]]
+// CHECK-NEXT: OpStore %q [[q4]]
+    q = o % p;
+
+    // MxN
+    float2x3 r, s, t;
+// CHECK-NEXT: [[r0:%\d+]] = OpLoad %mat2v3float %r
+// CHECK-NEXT: [[s0:%\d+]] = OpLoad %mat2v3float %s
+// CHECK-NEXT: [[r0v0:%\d+]] = OpCompositeExtract %v3float [[r0]] 0
+// CHECK-NEXT: [[s0v0:%\d+]] = OpCompositeExtract %v3float [[s0]] 0
+// CHECK-NEXT: [[t0v0:%\d+]] = OpFAdd %v3float [[r0v0]] [[s0v0]]
+// CHECK-NEXT: [[r0v1:%\d+]] = OpCompositeExtract %v3float [[r0]] 1
+// CHECK-NEXT: [[s0v1:%\d+]] = OpCompositeExtract %v3float [[s0]] 1
+// CHECK-NEXT: [[t0v1:%\d+]] = OpFAdd %v3float [[r0v1]] [[s0v1]]
+// CHECK-NEXT: [[t0:%\d+]] = OpCompositeConstruct %mat2v3float [[t0v0]] [[t0v1]]
+// CHECK-NEXT: OpStore %t [[t0]]
+    t = r + s;
+// CHECK-NEXT: [[r1:%\d+]] = OpLoad %mat2v3float %r
+// CHECK-NEXT: [[s1:%\d+]] = OpLoad %mat2v3float %s
+// CHECK-NEXT: [[r1v0:%\d+]] = OpCompositeExtract %v3float [[r1]] 0
+// CHECK-NEXT: [[s1v0:%\d+]] = OpCompositeExtract %v3float [[s1]] 0
+// CHECK-NEXT: [[t1v0:%\d+]] = OpFSub %v3float [[r1v0]] [[s1v0]]
+// CHECK-NEXT: [[r1v1:%\d+]] = OpCompositeExtract %v3float [[r1]] 1
+// CHECK-NEXT: [[s1v1:%\d+]] = OpCompositeExtract %v3float [[s1]] 1
+// CHECK-NEXT: [[t1v1:%\d+]] = OpFSub %v3float [[r1v1]] [[s1v1]]
+// CHECK-NEXT: [[t1:%\d+]] = OpCompositeConstruct %mat2v3float [[t1v0]] [[t1v1]]
+// CHECK-NEXT: OpStore %t [[t1]]
+    t = r - s;
+// CHECK-NEXT: [[r2:%\d+]] = OpLoad %mat2v3float %r
+// CHECK-NEXT: [[s2:%\d+]] = OpLoad %mat2v3float %s
+// CHECK-NEXT: [[r2v0:%\d+]] = OpCompositeExtract %v3float [[r2]] 0
+// CHECK-NEXT: [[s2v0:%\d+]] = OpCompositeExtract %v3float [[s2]] 0
+// CHECK-NEXT: [[t2v0:%\d+]] = OpFMul %v3float [[r2v0]] [[s2v0]]
+// CHECK-NEXT: [[r2v1:%\d+]] = OpCompositeExtract %v3float [[r2]] 1
+// CHECK-NEXT: [[s2v1:%\d+]] = OpCompositeExtract %v3float [[s2]] 1
+// CHECK-NEXT: [[t2v1:%\d+]] = OpFMul %v3float [[r2v1]] [[s2v1]]
+// CHECK-NEXT: [[t2:%\d+]] = OpCompositeConstruct %mat2v3float [[t2v0]] [[t2v1]]
+// CHECK-NEXT: OpStore %t [[t2]]
+    t = r * s;
+// CHECK-NEXT: [[r3:%\d+]] = OpLoad %mat2v3float %r
+// CHECK-NEXT: [[s3:%\d+]] = OpLoad %mat2v3float %s
+// CHECK-NEXT: [[r3v0:%\d+]] = OpCompositeExtract %v3float [[r3]] 0
+// CHECK-NEXT: [[s3v0:%\d+]] = OpCompositeExtract %v3float [[s3]] 0
+// CHECK-NEXT: [[t3v0:%\d+]] = OpFDiv %v3float [[r3v0]] [[s3v0]]
+// CHECK-NEXT: [[r3v1:%\d+]] = OpCompositeExtract %v3float [[r3]] 1
+// CHECK-NEXT: [[s3v1:%\d+]] = OpCompositeExtract %v3float [[s3]] 1
+// CHECK-NEXT: [[t3v1:%\d+]] = OpFDiv %v3float [[r3v1]] [[s3v1]]
+// CHECK-NEXT: [[t3:%\d+]] = OpCompositeConstruct %mat2v3float [[t3v0]] [[t3v1]]
+// CHECK-NEXT: OpStore %t [[t3]]
+    t = r / s;
+// CHECK-NEXT: [[r4:%\d+]] = OpLoad %mat2v3float %r
+// CHECK-NEXT: [[s4:%\d+]] = OpLoad %mat2v3float %s
+// CHECK-NEXT: [[r4v0:%\d+]] = OpCompositeExtract %v3float [[r4]] 0
+// CHECK-NEXT: [[s4v0:%\d+]] = OpCompositeExtract %v3float [[s4]] 0
+// CHECK-NEXT: [[t4v0:%\d+]] = OpFRem %v3float [[r4v0]] [[s4v0]]
+// CHECK-NEXT: [[r4v1:%\d+]] = OpCompositeExtract %v3float [[r4]] 1
+// CHECK-NEXT: [[s4v1:%\d+]] = OpCompositeExtract %v3float [[s4]] 1
+// CHECK-NEXT: [[t4v1:%\d+]] = OpFRem %v3float [[r4v1]] [[s4v1]]
+// CHECK-NEXT: [[t4:%\d+]] = OpCompositeConstruct %mat2v3float [[t4v0]] [[t4v1]]
+// CHECK-NEXT: OpStore %t [[t4]]
+    t = r % s;
+}

+ 2 - 0
tools/clang/test/CodeGenSPIRV/binary-op.arithmetic.mixed.hlsl

@@ -1,5 +1,7 @@
 // Run: %dxc -T vs_6_0 -E main
 
+// TODO: matrix * scalar
+
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -77,6 +77,9 @@ TEST_F(FileTest, BinaryOpScalarArithmetic) {
 TEST_F(FileTest, BinaryOpVectorArithmetic) {
   runFileTest("binary-op.arithmetic.vector.hlsl");
 }
+TEST_F(FileTest, BinaryOpMatrixArithmetic) {
+  runFileTest("binary-op.arithmetic.matrix.hlsl");
+}
 TEST_F(FileTest, BinaryOpMixedArithmetic) {
   runFileTest("binary-op.arithmetic.mixed.hlsl");
 }
@@ -88,6 +91,9 @@ TEST_F(FileTest, BinaryOpScalarArithAssign) {
 TEST_F(FileTest, BinaryOpVectorArithAssign) {
   runFileTest("binary-op.arith-assign.vector.hlsl");
 }
+TEST_F(FileTest, BinaryOpMatrixArithAssign) {
+  runFileTest("binary-op.arith-assign.matrix.hlsl");
+}
 TEST_F(FileTest, BinaryOpMixedArithAssign) {
   runFileTest("binary-op.arith-assign.mixed.hlsl");
 }