Browse Source

[spirv] Support matrices in any, all, asfloat, asint, asuint (#533)

Ehsan 8 years ago
parent
commit
b5043877bc

+ 96 - 24
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1015,7 +1015,7 @@ uint32_t SPIRVEmitter::doCastExpr(const CastExpr *expr) {
   }
   case CastKind::CK_HLSLVectorToMatrixCast: {
     // The target type should already be a 1xN matrix type.
-    assert(TypeTranslator::is1xNMatrixType(toType));
+    assert(TypeTranslator::is1xNMatrix(toType));
     return doExpr(subExpr);
   }
   case CastKind::CK_HLSLMatrixSplat: {
@@ -1046,12 +1046,12 @@ uint32_t SPIRVEmitter::doCastExpr(const CastExpr *expr) {
   }
   case CastKind::CK_HLSLMatrixToScalarCast: {
     // The underlying should already be a matrix of 1x1.
-    assert(TypeTranslator::is1x1MatrixType(subExpr->getType()));
+    assert(TypeTranslator::is1x1Matrix(subExpr->getType()));
     return doExpr(subExpr);
   }
   case CastKind::CK_HLSLMatrixToVectorCast: {
     // The underlying should already be a matrix of 1xN.
-    assert(TypeTranslator::is1xNMatrixType(subExpr->getType()));
+    assert(TypeTranslator::is1xNMatrix(subExpr->getType()));
     return doExpr(subExpr);
   }
   case CastKind::CK_FunctionToPointerDecay:
@@ -1327,7 +1327,7 @@ uint32_t SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
                              : getValueOne(subType);
     uint32_t incValue = 0;
     if (TypeTranslator::isSpirvAcceptableMatrixType(subType)) {
-      // For matrices, we can only incremnt/decrement each vector of it.
+      // For matrices, we can only increment/decrement each vector of it.
       const auto actOnEachVec = [this, spvOp, one](
           uint32_t /*index*/, uint32_t vecType, uint32_t lhsVec) {
         return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, one);
@@ -1810,8 +1810,7 @@ uint32_t SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
   const QualType rhsType = rhs->getType();
 
   const auto selectOpcode = [](const QualType ty) {
-    return TypeTranslator::isMx1MatrixType(ty) ||
-                   TypeTranslator::is1xNMatrixType(ty)
+    return TypeTranslator::isMx1Matrix(ty) || TypeTranslator::is1xNMatrix(ty)
                ? spv::Op::OpVectorTimesScalar
                : spv::Op::OpMatrixTimesScalar;
   };
@@ -2071,6 +2070,13 @@ uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
   if (isBoolOrVecOfBoolType(fromType))
     return fromVal;
 
+  // Handle 1x1 bool matrix, Mx1 bool matrix, and 1xN bool matrix.
+  {
+    if (typeTranslator.is1x1OrMx1Or1xNMatrix(fromType) &&
+        hlsl::GetHLSLMatElementType(fromType)->isBooleanType())
+      return fromVal;
+  }
+
   // Converting to bool means comparing with value zero.
   const spv::Op spvOp = translateOp(BO_NE, fromType);
   const uint32_t boolType = typeTranslator.translateType(toBoolType);
@@ -2243,29 +2249,81 @@ uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
 
 uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                                 spv::Op spvOp) {
-  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
-
   // 'all' and 'any' take only 1 parameter.
   assert(callExpr->getNumArgs() == 1u);
+  const QualType returnType = callExpr->getType();
+  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
   const Expr *arg = callExpr->getArg(0);
   const QualType argType = arg->getType();
 
-  if (hlsl::IsHLSLMatType(argType)) {
-    emitError("'all' and 'any' do not support matrix arguments yet.");
-    return 0;
+  // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
+  // Optimization: can directly cast them to boolean. No need for OpAny/OpAll.
+  {
+    if (argType->isBooleanType() || argType->isFloatingType() ||
+        argType->isIntegerType() || TypeTranslator::isVec1Type(argType) ||
+        TypeTranslator::is1x1Matrix(argType))
+      return castToBool(doExpr(arg), argType, returnType);
   }
 
-  bool isSpirvAcceptableVecType =
-      hlsl::IsHLSLVecType(argType) && hlsl::GetHLSLVecSize(argType) > 1;
-  if (!isSpirvAcceptableVecType) {
-    // For a scalar or vector of 1 scalar, we can simply cast to boolean.
-    return castToBool(doExpr(arg), arg->getType(), callExpr->getType());
-  } else {
-    // First cast the vector to a vector of booleans, then use OpAll
-    uint32_t boolVecId =
-        castToBool(doExpr(arg), arg->getType(), callExpr->getType());
-    return theBuilder.createUnaryOp(spvOp, returnType, boolVecId);
+  // Handle vectors larger than 1, Mx1 matrices, and 1xN matrices as arguments.
+  // Cast the vector to a boolean vector, then run OpAny/OpAll on it.
+  {
+    QualType elemType = {};
+    uint32_t size = 0;
+    if (TypeTranslator::isVectorType(argType, &elemType, &size) ||
+        TypeTranslator::isMx1Or1xNMatrix(argType, &elemType, &size)) {
+      const QualType castToBoolType =
+          astContext.getExtVectorType(returnType, size);
+      uint32_t castedToBoolId =
+          castToBool(doExpr(arg), argType, castToBoolType);
+      return theBuilder.createUnaryOp(spvOp, returnTypeId, castedToBoolId);
+    }
   }
+
+  // Handle MxN matrices as arguments.
+  {
+    QualType elemType = {};
+    uint32_t matRowCount = 0, matColCount = 0;
+    if (TypeTranslator::isMxNMatrix(argType, &elemType, &matRowCount,
+                                    &matColCount)) {
+      if (!elemType->isFloatingType()) {
+        emitError("'all' and 'any' currently do not take non-floating point "
+                  "matrices as argument.");
+        return 0;
+      }
+
+      uint32_t matrixId = doExpr(arg);
+      const uint32_t vecType = typeTranslator.getComponentVectorType(argType);
+      llvm::SmallVector<uint32_t, 4> rowResults;
+      for (uint32_t i = 0; i < matRowCount; ++i) {
+        // Extract the row which is a float vector of size matColCount.
+        const uint32_t rowFloatVec =
+            theBuilder.createCompositeExtract(vecType, matrixId, {i});
+        // Cast the float vector to boolean vector.
+        const auto rowFloatQualType =
+            astContext.getExtVectorType(elemType, matColCount);
+        const auto rowBoolQualType =
+            astContext.getExtVectorType(returnType, matColCount);
+        const uint32_t rowBoolVec =
+            castToBool(rowFloatVec, rowFloatQualType, rowBoolQualType);
+        // Perform OpAny/OpAll on the boolean vector.
+        rowResults.push_back(
+            theBuilder.createUnaryOp(spvOp, returnTypeId, rowBoolVec));
+      }
+      // Create a new vector that is the concatenation of results of all rows.
+      uint32_t boolId = theBuilder.getBoolType();
+      uint32_t vecOfBoolsId = theBuilder.getVecType(boolId, matRowCount);
+      const uint32_t rowResultsId =
+          theBuilder.createCompositeConstruct(vecOfBoolsId, rowResults);
+
+      // Run OpAny/OpAll on the newly-created vector.
+      return theBuilder.createUnaryOp(spvOp, returnTypeId, rowResultsId);
+    }
+  }
+
+  // All types should be handled already.
+  llvm_unreachable("Unknown argument type passed to all()/any().");
+  return 0;
 }
 
 uint32_t SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
@@ -2280,9 +2338,11 @@ uint32_t SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
   if (returnType.getCanonicalType() == argType.getCanonicalType())
     return doExpr(arg);
 
-  if (hlsl::IsHLSLMatType(argType)) {
-    emitError("'asfloat', 'asint', and 'asuint' do not support matrix "
-              "arguments yet.");
+  // SPIR-V does not support non-floating point matrices. So 'asint' and
+  // 'asuint' for MxN matrices are currently not supported.
+  if (TypeTranslator::isMxNMatrix(argType)) {
+    emitError("SPIR-V does not support non-floating point matrices. Thus, "
+              "'asint' and 'asuint' currently do not take matrix arguments.");
     return 0;
   }
 
@@ -2311,6 +2371,18 @@ uint32_t SPIRVEmitter::getValueZero(QualType type) {
     }
   }
 
+  // 1x1, Mx1, and 1xN Matrices
+  {
+    QualType elemType = {};
+    uint32_t size = {};
+    if (TypeTranslator::is1x1Matrix(type, &elemType))
+      return getValueZero(elemType);
+    if (TypeTranslator::isMx1Or1xNMatrix(type, &elemType, &size))
+      return getVecValueZero(elemType, size);
+
+    // TODO: Handle getValueZero for MxN matrices.
+  }
+
   emitError("getting value 0 for type '%0' unimplemented")
       << type.getAsString();
   return 0;

+ 82 - 12
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -114,6 +114,12 @@ uint32_t TypeTranslator::translateType(QualType type) {
   return 0;
 }
 
+bool TypeTranslator::isVec1Type(QualType type, QualType *elemType) {
+  uint32_t count = 0;
+  const bool isVec = isVectorType(type, elemType, &count);
+  return isVec && count == 1;
+}
+
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
                                   uint32_t *count) {
   if (hlsl::IsHLSLVecType(type)) {
@@ -135,49 +141,113 @@ bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
   return false;
 }
 
-bool TypeTranslator::is1x1MatrixType(QualType type) {
+bool TypeTranslator::is1x1Matrix(QualType type, QualType *elemType) {
   if (!hlsl::IsHLSLMatType(type))
     return false;
 
   uint32_t rowCount = 0, colCount = 0;
   hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
 
-  return rowCount == 1 && colCount == 1;
+  const bool is1x1 = rowCount == 1 && colCount == 1;
+
+  if (!is1x1)
+    return false;
+
+  if (elemType)
+    *elemType = hlsl::GetHLSLMatElementType(type);
+  return true;
 }
 
-bool TypeTranslator::is1xNMatrixType(QualType type) {
+bool TypeTranslator::is1xNMatrix(QualType type, QualType *elemType,
+                                 uint32_t *count) {
   if (!hlsl::IsHLSLMatType(type))
     return false;
 
   uint32_t rowCount = 0, colCount = 0;
   hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
 
-  return rowCount == 1 && colCount > 1;
+  const bool is1xN = rowCount == 1 && colCount > 1;
+
+  if (!is1xN)
+    return false;
+
+  if (elemType)
+    *elemType = hlsl::GetHLSLMatElementType(type);
+  if (count)
+    *count = colCount;
+  return true;
 }
 
-bool TypeTranslator::isMx1MatrixType(QualType type) {
+bool TypeTranslator::isMx1Matrix(QualType type, QualType *elemType,
+                                 uint32_t *count) {
   if (!hlsl::IsHLSLMatType(type))
     return false;
 
   uint32_t rowCount = 0, colCount = 0;
   hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
 
-  return rowCount > 1 && colCount == 1;
+  const bool isMx1 = rowCount > 1 && colCount == 1;
+
+  if (!isMx1)
+    return false;
+
+  if (elemType)
+    *elemType = hlsl::GetHLSLMatElementType(type);
+  if (count)
+    *count = rowCount;
+  return true;
 }
 
-/// Returns true if the given type is a SPIR-V acceptable matrix type, i.e.,
-/// with floating point elements and greater than 1 row and column counts.
-bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
+bool TypeTranslator::isMx1Or1xNMatrix(QualType type, QualType *elemType,
+                                      uint32_t *count) {
   if (!hlsl::IsHLSLMatType(type))
     return false;
 
-  const auto elemType = hlsl::GetHLSLMatElementType(type);
-  if (!elemType->isFloatingType())
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+  const bool isSingleRowOrCol =
+      (rowCount == 1 && colCount > 1) || (rowCount > 1 && colCount == 1);
+
+  if (!isSingleRowOrCol)
+    return false;
+
+  if (elemType)
+    *elemType = hlsl::GetHLSLMatElementType(type);
+  if (count)
+    *count = rowCount > 1 ? rowCount : colCount;
+  return true;
+}
+
+bool TypeTranslator::is1x1OrMx1Or1xNMatrix(QualType type) {
+  return is1x1Matrix(type) || isMx1Or1xNMatrix(type);
+}
+
+bool TypeTranslator::isMxNMatrix(QualType type, QualType *elemType,
+                                 uint32_t *numRows, uint32_t *numCols) {
+  if (!hlsl::IsHLSLMatType(type))
     return false;
 
   uint32_t rowCount = 0, colCount = 0;
   hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
-  return rowCount > 1 && colCount > 1;
+
+  const bool isMxN = rowCount > 1 && colCount > 1;
+
+  if (!isMxN)
+    return false;
+
+  if (elemType)
+    *elemType = hlsl::GetHLSLMatElementType(type);
+  if (numRows)
+    *numRows = rowCount;
+  if (numCols)
+    *numCols = colCount;
+  return true;
+}
+
+bool TypeTranslator::isSpirvAcceptableMatrixType(QualType type) {
+  QualType elemType = {};
+  return isMxNMatrix(type, &elemType) && elemType->isFloatingType();
 }
 
 uint32_t TypeTranslator::getComponentVectorType(QualType matrixType) {

+ 36 - 6
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -42,14 +42,44 @@ public:
   /// into *elementType and *count respectively if they are not nullptr.
   static bool isVectorType(QualType type, QualType *elemType, uint32_t *count);
 
-  /// \brief Returns true if the givne type is a 1x1 matrix type.
-  static bool is1x1MatrixType(QualType type);
+  /// \brief Returns true if the given type is a vector type of size 1.
+  /// If elemType is not nullptr, writes the element type to *elemType.
+  static bool isVec1Type(QualType type, QualType *elemType = nullptr);
 
-  /// \brief Returns true if the givne type is a 1xN (N > 1) matrix type.
-  static bool is1xNMatrixType(QualType type);
+  /// \brief Returns true if the given type is a 1x1 matrix type.
+  /// If elemType is not nullptr, writes the element type to *elemType.
+  static bool is1x1Matrix(QualType type, QualType *elemType = nullptr);
 
-  /// \brief Returns true if the givne type is a Mx1 (M > 1) matrix type.
-  static bool isMx1MatrixType(QualType type);
+  /// \brief Returns true if the given type is a 1xN (N > 1) matrix type.
+  /// If elemType is not nullptr, writes the element type to *elemType.
+  /// If count is not nullptr, writes the value of N into *count.
+  static bool is1xNMatrix(QualType type, QualType *elemType = nullptr,
+                          uint32_t *count = nullptr);
+
+  /// \brief Returns true if the given type is a Mx1 (M > 1) matrix type.
+  /// If elemType is not nullptr, writes the element type to *elemType.
+  /// If count is not nullptr, writes the value of M into *count.
+  static bool isMx1Matrix(QualType type, QualType *elemType = nullptr,
+                          uint32_t *count = nullptr);
+
+  /// \brief Returns true if the given type is a Mx1 (M > 1), or 1xN (N > 1)
+  /// matrix type. If elemType is not nullptr, writes the matrix element type to
+  /// *elemType. If count is not nullptr, writes the size (M or N) into *count.
+  static bool isMx1Or1xNMatrix(QualType type, QualType *elemType = nullptr,
+                               uint32_t *count = nullptr);
+
+  /// \brief Returns true if the given type is a 1x1, or Mx1 (M > 1), or
+  /// 1xN (N > 1) matrix type.
+  static bool is1x1OrMx1Or1xNMatrix(QualType type);
+
+  /// \brief returns true if the given type is a matrix with more than 1 row and
+  /// more than 1 column.
+  /// If elemType is not nullptr, writes the element type to *elemType.
+  /// If rowCount is not nullptr, writes the number of rows (M) into *rowCount.
+  /// If colCount is not nullptr, writes the number of cols (N) into *colCount.
+  static bool isMxNMatrix(QualType type, QualType *elemType = nullptr,
+                          uint32_t *rowCount = nullptr,
+                          uint32_t *colCount = nullptr);
 
   /// \brief Returns true if the given type is a SPIR-V acceptable matrix type,
   /// i.e., with floating point elements and greater than 1 row and column

+ 41 - 4
tools/clang/test/CodeGenSPIRV/intrinsics.all.hlsl

@@ -7,6 +7,8 @@
 // CHECK:      [[v4int_0:%\d+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
 // CHECK-NEXT: [[v4uint_0:%\d+]] = OpConstantComposite %v4uint %uint_0 %uint_0 %uint_0 %uint_0
 // CHECK-NEXT: [[v4float_0:%\d+]] = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+// CHECK-NEXT: [[v3float_0:%\d+]] = OpConstantComposite %v3float %float_0 %float_0 %float_0
+// CHECK-NEXT: [[v2float_0:%\d+]] = OpConstantComposite %v2float %float_0 %float_0
 
 void main() {
     bool result;
@@ -58,14 +60,14 @@ void main() {
     result = all(h);
 
     // CHECK-NEXT: [[i:%\d+]] = OpLoad %v4int %i
-    // CHECK-NEXT: [[v4int_to_bool:%\d+]] = OpINotEqual %bool [[i]] [[v4int_0]]
+    // CHECK-NEXT: [[v4int_to_bool:%\d+]] = OpINotEqual %v4bool [[i]] [[v4int_0]]
     // CHECK-NEXT: [[all_int4:%\d+]] = OpAll %bool [[v4int_to_bool]]
     // CHECK-NEXT: OpStore %result [[all_int4]]
     int4 i;
     result = all(i);
 
     // CHECK-NEXT: [[j:%\d+]] = OpLoad %v4uint %j
-    // CHECK-NEXT: [[v4uint_to_bool:%\d+]] = OpINotEqual %bool [[j]] [[v4uint_0]]
+    // CHECK-NEXT: [[v4uint_to_bool:%\d+]] = OpINotEqual %v4bool [[j]] [[v4uint_0]]
     // CHECK-NEXT: [[all_uint4:%\d+]] = OpAll %bool [[v4uint_to_bool]]
     // CHECK-NEXT: OpStore %result [[all_uint4]]
     uint4 j;
@@ -78,10 +80,45 @@ void main() {
     result = all(k);
 
     // CHECK-NEXT: [[l:%\d+]] = OpLoad %v4float %l
-    // CHECK-NEXT: [[v4float_to_bool:%\d+]] = OpFOrdNotEqual %bool [[l]] [[v4float_0]]
+    // CHECK-NEXT: [[v4float_to_bool:%\d+]] = OpFOrdNotEqual %v4bool [[l]] [[v4float_0]]
     // CHECK-NEXT: [[all_float4:%\d+]] = OpAll %bool [[v4float_to_bool]]
     // CHECK-NEXT: OpStore %result [[all_float4]]
     float4 l;
     result = all(l);
-}
 
+    // CHECK-NEXT: [[m:%\d+]] = OpLoad %float %m
+    // CHECK-NEXT: [[mat1x1_to_bool:%\d+]] = OpFOrdNotEqual %bool [[m]] %float_0
+    // CHECK-NEXT: OpStore %result [[mat1x1_to_bool]]
+    float1x1 m;
+    result = all(m);
+
+    // CHECK-NEXT: [[n:%\d+]] = OpLoad %v3float %n
+    // CHECK-NEXT: [[mat1x3_to_bool:%\d+]] = OpFOrdNotEqual %v3bool [[n]] [[v3float_0]]
+    // CHECK-NEXT: [[all_mat1x3:%\d+]] = OpAll %bool [[mat1x3_to_bool]]
+    // CHECK-NEXT: OpStore %result [[all_mat1x3]]
+    float1x3 n;
+    result = all(n);
+
+    // CHECK-NEXT: [[o:%\d+]] = OpLoad %v2float %o
+    // CHECK-NEXT: [[mat2x1_to_bool:%\d+]] = OpFOrdNotEqual %v2bool [[o]] [[v2float_0]]
+    // CHECK-NEXT: [[all_mat2x1:%\d+]] = OpAll %bool [[mat2x1_to_bool]]
+    // CHECK-NEXT: OpStore %result [[all_mat2x1]]
+    float2x1 o;
+    result = all(o);
+
+    // CHECK-NEXT: [[p:%\d+]] = OpLoad %mat3v4float %p
+    // CHECK-NEXT: [[row0:%\d+]] = OpCompositeExtract %v4float [[p]] 0
+    // CHECK-NEXT: [[row0_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row0]] [[v4float_0]]
+    // CHECK-NEXT: [[all_row0:%\d+]] = OpAll %bool [[row0_to_bool_vec]]
+    // CHECK-NEXT: [[row1:%\d+]] = OpCompositeExtract %v4float [[p]] 1
+    // CHECK-NEXT: [[row1_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row1]] [[v4float_0]]
+    // CHECK-NEXT: [[all_row1:%\d+]] = OpAll %bool [[row1_to_bool_vec]]
+    // CHECK-NEXT: [[row2:%\d+]] = OpCompositeExtract %v4float [[p]] 2
+    // CHECK-NEXT: [[row2_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row2]] [[v4float_0]]
+    // CHECK-NEXT: [[all_row2:%\d+]] = OpAll %bool [[row2_to_bool_vec]]
+    // CHECK-NEXT: [[all_rows:%\d+]] = OpCompositeConstruct %v3bool [[all_row0]] [[all_row1]] [[all_row2]]
+    // CHECK-NEXT: [[all_mat3x4:%\d+]] = OpAll %bool [[all_rows]]
+    // CHECK-NEXT: OpStore %result [[all_mat3x4]]
+    float3x4 p;
+    result = all(p);
+}

+ 41 - 3
tools/clang/test/CodeGenSPIRV/intrinsics.any.hlsl

@@ -7,6 +7,8 @@
 // CHECK:      [[v4int_0:%\d+]] = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
 // CHECK-NEXT: [[v4uint_0:%\d+]] = OpConstantComposite %v4uint %uint_0 %uint_0 %uint_0 %uint_0
 // CHECK-NEXT: [[v4float_0:%\d+]] = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+// CHECK-NEXT: [[v3float_0:%\d+]] = OpConstantComposite %v3float %float_0 %float_0 %float_0
+// CHECK-NEXT: [[v2float_0:%\d+]] = OpConstantComposite %v2float %float_0 %float_0
 
 void main() {
     bool result;
@@ -58,14 +60,14 @@ void main() {
     result = any(h);
 
     // CHECK-NEXT: [[i:%\d+]] = OpLoad %v4int %i
-    // CHECK-NEXT: [[v4int_to_bool:%\d+]] = OpINotEqual %bool [[i]] [[v4int_0]]
+    // CHECK-NEXT: [[v4int_to_bool:%\d+]] = OpINotEqual %v4bool [[i]] [[v4int_0]]
     // CHECK-NEXT: [[any_int4:%\d+]] = OpAny %bool [[v4int_to_bool]]
     // CHECK-NEXT: OpStore %result [[any_int4]]
     int4 i;
     result = any(i);
 
     // CHECK-NEXT: [[j:%\d+]] = OpLoad %v4uint %j
-    // CHECK-NEXT: [[v4uint_to_bool:%\d+]] = OpINotEqual %bool [[j]] [[v4uint_0]]
+    // CHECK-NEXT: [[v4uint_to_bool:%\d+]] = OpINotEqual %v4bool [[j]] [[v4uint_0]]
     // CHECK-NEXT: [[any_uint4:%\d+]] = OpAny %bool [[v4uint_to_bool]]
     // CHECK-NEXT: OpStore %result [[any_uint4]]
     uint4 j;
@@ -78,10 +80,46 @@ void main() {
     result = any(k);
 
     // CHECK-NEXT: [[l:%\d+]] = OpLoad %v4float %l
-    // CHECK-NEXT: [[v4float_to_bool:%\d+]] = OpFOrdNotEqual %bool [[l]] [[v4float_0]]
+    // CHECK-NEXT: [[v4float_to_bool:%\d+]] = OpFOrdNotEqual %v4bool [[l]] [[v4float_0]]
     // CHECK-NEXT: [[any_float4:%\d+]] = OpAny %bool [[v4float_to_bool]]
     // CHECK-NEXT: OpStore %result [[any_float4]]
     float4 l;
     result = any(l);
+
+    // CHECK-NEXT: [[m:%\d+]] = OpLoad %float %m
+    // CHECK-NEXT: [[mat1x1_to_bool:%\d+]] = OpFOrdNotEqual %bool [[m]] %float_0
+    // CHECK-NEXT: OpStore %result [[mat1x1_to_bool]]
+    float1x1 m;
+    result = any(m);
+
+    // CHECK-NEXT: [[n:%\d+]] = OpLoad %v3float %n
+    // CHECK-NEXT: [[mat1x3_to_bool:%\d+]] = OpFOrdNotEqual %v3bool [[n]] [[v3float_0]]
+    // CHECK-NEXT: [[any_mat1x3:%\d+]] = OpAny %bool [[mat1x3_to_bool]]
+    // CHECK-NEXT: OpStore %result [[any_mat1x3]]
+    float1x3 n;
+    result = any(n);
+
+    // CHECK-NEXT: [[o:%\d+]] = OpLoad %v2float %o
+    // CHECK-NEXT: [[mat2x1_to_bool:%\d+]] = OpFOrdNotEqual %v2bool [[o]] [[v2float_0]]
+    // CHECK-NEXT: [[any_mat2x1:%\d+]] = OpAny %bool [[mat2x1_to_bool]]
+    // CHECK-NEXT: OpStore %result [[any_mat2x1]]
+    float2x1 o;
+    result = any(o);
+
+    // CHECK-NEXT: [[p:%\d+]] = OpLoad %mat3v4float %p
+    // CHECK-NEXT: [[row0:%\d+]] = OpCompositeExtract %v4float [[p]] 0
+    // CHECK-NEXT: [[row0_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row0]] [[v4float_0]]
+    // CHECK-NEXT: [[any_row0:%\d+]] = OpAny %bool [[row0_to_bool_vec]]
+    // CHECK-NEXT: [[row1:%\d+]] = OpCompositeExtract %v4float [[p]] 1
+    // CHECK-NEXT: [[row1_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row1]] [[v4float_0]]
+    // CHECK-NEXT: [[any_row1:%\d+]] = OpAny %bool [[row1_to_bool_vec]]
+    // CHECK-NEXT: [[row2:%\d+]] = OpCompositeExtract %v4float [[p]] 2
+    // CHECK-NEXT: [[row2_to_bool_vec:%\d+]] = OpFOrdNotEqual %v4bool [[row2]] [[v4float_0]]
+    // CHECK-NEXT: [[any_row2:%\d+]] = OpAny %bool [[row2_to_bool_vec]]
+    // CHECK-NEXT: [[any_rows:%\d+]] = OpCompositeConstruct %v3bool [[any_row0]] [[any_row1]] [[any_row2]]
+    // CHECK-NEXT: [[any_mat3x4:%\d+]] = OpAny %bool [[any_rows]]
+    // CHECK-NEXT: OpStore %result [[any_mat3x4]]
+    float3x4 p;
+    result = any(p);
 }
 

+ 25 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.asfloat.hlsl

@@ -7,6 +7,11 @@
 void main() {
     float result;
     float4 result4;
+    float1x1 result1x1;
+    float1x3 result1x3;
+    float2x1 result2x1;
+    float2x3 result2x3; 
+
 
     // CHECK:      [[a:%\d+]] = OpLoad %int %a
     // CHECK-NEXT: [[a_as_float:%\d+]] = OpBitcast %float [[a]]
@@ -58,4 +63,24 @@ void main() {
     // CHECK-NEXT: OpStore %result4 [[i]]
     float4 i;
     result4 = asfloat(i);
+    
+    // CHECK-NEXT: [[j:%\d+]] = OpLoad %float %j
+    // CHECK-NEXT: OpStore %result1x1 [[j]]
+    float1x1 j;
+    result1x1 = asfloat(j);
+    
+    // CHECK-NEXT: [[k:%\d+]] = OpLoad %v3float %k
+    // CHECK-NEXT: OpStore %result1x3 [[k]]    
+    float1x3 k;
+    result1x3 = asfloat(k);
+    
+    // CHECK-NEXT: [[l:%\d+]] = OpLoad %v2float %l
+    // CHECK-NEXT: OpStore %result2x1 [[l]]
+    float2x1 l;
+    result2x1 = asfloat(l);
+    
+    // CHECK-NEXT: [[m:%\d+]] = OpLoad %mat2v3float %m
+    // CHECK-NEXT: OpStore %result2x3 [[m]]
+    float2x3 m;
+    result2x3 = asfloat(m);
 }