Browse Source

[spirv] Sanitize scalar/vector/matrix type probing methods (#540)

Make the type probing methods in TypeTranslator much more clear
as for their responsibilities.
Lei Zhang 8 years ago
parent
commit
c7622ec618

+ 59 - 61
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -27,15 +27,21 @@ namespace {
 
 
 /// Returns true if the two types are the same scalar or vector type.
 /// Returns true if the two types are the same scalar or vector type.
 bool isSameScalarOrVecType(QualType type1, QualType type2) {
 bool isSameScalarOrVecType(QualType type1, QualType type2) {
-  if (type1->isBuiltinType())
-    return type1.getCanonicalType() == type2.getCanonicalType();
+  {
+    QualType scalarType1 = {}, scalarType2 = {};
+    if (TypeTranslator::isScalarType(type1, &scalarType1) &&
+        TypeTranslator::isScalarType(type2, &scalarType2))
+      return scalarType1.getCanonicalType() == scalarType2.getCanonicalType();
+  }
 
 
-  QualType elemType1 = {}, elemType2 = {};
-  uint32_t count1 = {}, count2 = {};
-  if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
-      TypeTranslator::isVectorType(type2, &elemType2, &count2))
-    return count1 == count2 &&
-           elemType1.getCanonicalType() == elemType2.getCanonicalType();
+  {
+    QualType elemType1 = {}, elemType2 = {};
+    uint32_t count1 = {}, count2 = {};
+    if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
+        TypeTranslator::isVectorType(type2, &elemType2, &count2))
+      return count1 == count2 &&
+             elemType1.getCanonicalType() == elemType2.getCanonicalType();
+  }
 
 
   return false;
   return false;
 }
 }
@@ -43,35 +49,35 @@ bool isSameScalarOrVecType(QualType type1, QualType type2) {
 /// Returns true if the given type is a bool or vector of bool type.
 /// Returns true if the given type is a bool or vector of bool type.
 bool isBoolOrVecOfBoolType(QualType type) {
 bool isBoolOrVecOfBoolType(QualType type) {
   QualType elemType = {};
   QualType elemType = {};
-  return type->isBooleanType() ||
-         (TypeTranslator::isVectorType(type, &elemType, nullptr) &&
-          elemType->isBooleanType());
+  return (TypeTranslator::isScalarType(type, &elemType) ||
+          TypeTranslator::isVectorType(type, &elemType)) &&
+         elemType->isBooleanType();
 }
 }
 
 
 /// Returns true if the given type is a signed integer or vector of signed
 /// Returns true if the given type is a signed integer or vector of signed
 /// integer type.
 /// integer type.
 bool isSintOrVecOfSintType(QualType type) {
 bool isSintOrVecOfSintType(QualType type) {
   QualType elemType = {};
   QualType elemType = {};
-  return type->isSignedIntegerType() ||
-         (TypeTranslator::isVectorType(type, &elemType, nullptr) &&
-          elemType->isSignedIntegerType());
+  return (TypeTranslator::isScalarType(type, &elemType) ||
+          TypeTranslator::isVectorType(type, &elemType)) &&
+         elemType->isSignedIntegerType();
 }
 }
 
 
 /// Returns true if the given type is an unsigned integer or vector of unsigned
 /// Returns true if the given type is an unsigned integer or vector of unsigned
 /// integer type.
 /// integer type.
 bool isUintOrVecOfUintType(QualType type) {
 bool isUintOrVecOfUintType(QualType type) {
   QualType elemType = {};
   QualType elemType = {};
-  return type->isUnsignedIntegerType() ||
-         (TypeTranslator::isVectorType(type, &elemType, nullptr) &&
-          elemType->isUnsignedIntegerType());
+  return (TypeTranslator::isScalarType(type, &elemType) ||
+          TypeTranslator::isVectorType(type, &elemType)) &&
+         elemType->isUnsignedIntegerType();
 }
 }
 
 
 /// Returns true if the given type is a float or vector of float type.
 /// Returns true if the given type is a float or vector of float type.
 bool isFloatOrVecOfFloatType(QualType type) {
 bool isFloatOrVecOfFloatType(QualType type) {
   QualType elemType = {};
   QualType elemType = {};
-  return type->isFloatingType() ||
-         (TypeTranslator::isVectorType(type, &elemType, nullptr) &&
-          elemType->isFloatingType());
+  return (TypeTranslator::isScalarType(type, &elemType) ||
+          TypeTranslator::isVectorType(type, &elemType)) &&
+         elemType->isFloatingType();
 }
 }
 
 
 /// Returns true if the given type is a bool or vector/matrix of bool type.
 /// Returns true if the given type is a bool or vector/matrix of bool type.
@@ -2137,17 +2143,9 @@ uint32_t SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 
 
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                   QualType toBoolType) {
                                   QualType toBoolType) {
-  // Semantic analysis should already checked the size
-  if (isBoolOrVecOfBoolType(fromType))
+  if (isSameScalarOrVecType(fromType, toBoolType))
     return fromVal;
     return fromVal;
 
 
-  // Handle 1x1 bool matrix, Mx1 bool matrix, and 1xN bool matrix.
-  {
-    if (typeTranslator.is1x1OrMx1Or1xNMatrix(fromType) &&
-        hlsl::GetHLSLMatElementType(fromType)->isBooleanType())
-      return fromVal;
-  }
-
   // Converting to bool means comparing with value zero.
   // Converting to bool means comparing with value zero.
   const spv::Op spvOp = translateOp(BO_NE, fromType);
   const spv::Op spvOp = translateOp(BO_NE, fromType);
   const uint32_t boolType = typeTranslator.translateType(toBoolType);
   const uint32_t boolType = typeTranslator.translateType(toBoolType);
@@ -2384,9 +2382,10 @@ uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
   // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
   // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
   // Optimization: can directly cast them to boolean. No need for OpAny/OpAll.
   // Optimization: can directly cast them to boolean. No need for OpAny/OpAll.
   {
   {
-    if (argType->isBooleanType() || argType->isFloatingType() ||
-        argType->isIntegerType() || TypeTranslator::isVec1Type(argType) ||
-        TypeTranslator::is1x1Matrix(argType))
+    QualType scalarType = {};
+    if (TypeTranslator::isScalarType(argType, &scalarType) &&
+        (scalarType->isBooleanType() || scalarType->isFloatingType() ||
+         scalarType->isIntegerType()))
       return castToBool(doExpr(arg), argType, returnType);
       return castToBool(doExpr(arg), argType, returnType);
   }
   }
 
 
@@ -2395,8 +2394,7 @@ uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
   {
   {
     QualType elemType = {};
     QualType elemType = {};
     uint32_t size = 0;
     uint32_t size = 0;
-    if (TypeTranslator::isVectorType(argType, &elemType, &size) ||
-        TypeTranslator::isMx1Or1xNMatrix(argType, &elemType, &size)) {
+    if (TypeTranslator::isVectorType(argType, &elemType, &size)) {
       const QualType castToBoolType =
       const QualType castToBoolType =
           astContext.getExtVectorType(returnType, size);
           astContext.getExtVectorType(returnType, size);
       uint32_t castedToBoolId =
       uint32_t castedToBoolId =
@@ -2536,16 +2534,21 @@ uint32_t SPIRVEmitter::processIntrinsicUsingGLSLInst(
 }
 }
 
 
 uint32_t SPIRVEmitter::getValueZero(QualType type) {
 uint32_t SPIRVEmitter::getValueZero(QualType type) {
-  if (type->isSignedIntegerType()) {
-    return theBuilder.getConstantInt32(0);
-  }
+  {
+    QualType scalarType = {};
+    if (TypeTranslator::isScalarType(type, &scalarType)) {
+      if (scalarType->isSignedIntegerType()) {
+        return theBuilder.getConstantInt32(0);
+      }
 
 
-  if (type->isUnsignedIntegerType()) {
-    return theBuilder.getConstantUint32(0);
-  }
+      if (scalarType->isUnsignedIntegerType()) {
+        return theBuilder.getConstantUint32(0);
+      }
 
 
-  if (type->isFloatingType()) {
-    return theBuilder.getConstantFloat32(0.0);
+      if (scalarType->isFloatingType()) {
+        return theBuilder.getConstantFloat32(0.0);
+      }
+    }
   }
   }
 
 
   {
   {
@@ -2556,17 +2559,7 @@ uint32_t SPIRVEmitter::getValueZero(QualType type) {
     }
     }
   }
   }
 
 
-  // 1x1, Mx1, and 1xN Matrices
-  {
-    QualType elemType = {};
-    uint32_t size = {};
-    if (TypeTranslator::is1x1Matrix(type, &elemType))
-      return getValueZero(elemType);
-    if (TypeTranslator::isMx1Or1xNMatrix(type, &elemType, &size))
-      return getVecValueZero(elemType, size);
-
-    // TODO: Handle getValueZero for MxN matrices.
-  }
+  // TODO: Handle getValueZero for MxN matrices.
 
 
   emitError("getting value 0 for type '%0' unimplemented")
   emitError("getting value 0 for type '%0' unimplemented")
       << type.getAsString();
       << type.getAsString();
@@ -2587,16 +2580,21 @@ uint32_t SPIRVEmitter::getVecValueZero(QualType elemType, uint32_t size) {
 }
 }
 
 
 uint32_t SPIRVEmitter::getValueOne(QualType type) {
 uint32_t SPIRVEmitter::getValueOne(QualType type) {
-  if (type->isSignedIntegerType()) {
-    return theBuilder.getConstantInt32(1);
-  }
+  {
+    QualType scalarType = {};
+    if (TypeTranslator::isScalarType(type, &scalarType)) {
+      if (scalarType->isSignedIntegerType()) {
+        return theBuilder.getConstantInt32(1);
+      }
 
 
-  if (type->isUnsignedIntegerType()) {
-    return theBuilder.getConstantUint32(1);
-  }
+      if (scalarType->isUnsignedIntegerType()) {
+        return theBuilder.getConstantUint32(1);
+      }
 
 
-  if (type->isFloatingType()) {
-    return theBuilder.getConstantFloat32(1.0);
+      if (scalarType->isFloatingType()) {
+        return theBuilder.getConstantFloat32(1.0);
+      }
+    }
   }
   }
 
 
   {
   {

+ 80 - 83
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -20,28 +20,34 @@ uint32_t TypeTranslator::translateType(QualType type) {
   if (canonicalType != type)
   if (canonicalType != type)
     return translateType(canonicalType);
     return translateType(canonicalType);
 
 
-  const auto *typePtr = type.getTypePtr();
-
   // Primitive types
   // Primitive types
-  if (const auto *builtinType = dyn_cast<BuiltinType>(typePtr)) {
-    switch (builtinType->getKind()) {
-    case BuiltinType::Void:
-      return theBuilder.getVoidType();
-    case BuiltinType::Bool:
-      return theBuilder.getBoolType();
-    case BuiltinType::Int:
-      return theBuilder.getInt32Type();
-    case BuiltinType::UInt:
-      return theBuilder.getUint32Type();
-    case BuiltinType::Float:
-      return theBuilder.getFloat32Type();
-    default:
-      emitError("Primitive type '%0' is not supported yet.")
-          << builtinType->getTypeClassName();
-      return 0;
+  {
+    QualType ty = {};
+    if (isScalarType(type, &ty)) {
+      if (const auto *builtinType = cast<BuiltinType>(ty.getTypePtr())) {
+        switch (builtinType->getKind()) {
+        case BuiltinType::Void:
+          return theBuilder.getVoidType();
+        case BuiltinType::Bool:
+          return theBuilder.getBoolType();
+        case BuiltinType::Int:
+          return theBuilder.getInt32Type();
+        case BuiltinType::UInt:
+          return theBuilder.getUint32Type();
+        case BuiltinType::Float:
+          return theBuilder.getFloat32Type();
+        default:
+          emitError("Primitive type '%0' is not supported yet.")
+              << builtinType->getTypeClassName();
+          return 0;
+        }
+      }
     }
     }
   }
   }
 
 
+  const auto *typePtr = type.getTypePtr();
+
+  // Typedefs
   if (const auto *typedefType = dyn_cast<TypedefType>(typePtr)) {
   if (const auto *typedefType = dyn_cast<TypedefType>(typePtr)) {
     return translateType(typedefType->desugar());
     return translateType(typedefType->desugar());
   }
   }
@@ -49,6 +55,7 @@ uint32_t TypeTranslator::translateType(QualType type) {
   // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
   // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
   // We handle them via HLSL type inspection functions.
   // We handle them via HLSL type inspection functions.
 
 
+  // Vector types
   {
   {
     QualType elemType = {};
     QualType elemType = {};
     uint32_t elemCount = {};
     uint32_t elemCount = {};
@@ -62,18 +69,12 @@ uint32_t TypeTranslator::translateType(QualType type) {
     }
     }
   }
   }
 
 
+  // Matrix types
   if (hlsl::IsHLSLMatType(type)) {
   if (hlsl::IsHLSLMatType(type)) {
-    uint32_t elemCount = 0;
-    const auto elemTy = hlsl::GetHLSLMatElementType(type);
-
-    // 1x1 matrix is a scalar
-    if (is1x1Matrix(type))
-      return translateType(elemTy);
-
-    // Mx1 matrix or 1xN matrix is a vector.
-    if (isMx1Or1xNMatrix(type, nullptr, &elemCount))
-      return theBuilder.getVecType(translateType(elemTy), elemCount);
+    // The other cases should already be handled in the above.
+    assert(isMxNMatrix(type));
 
 
+    const auto elemTy = hlsl::GetHLSLMatElementType(type);
     // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
     // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
     // Validation Rules":
     // Validation Rules":
     //   Matrix types can only be parameterized with floating-point types.
     //   Matrix types can only be parameterized with floating-point types.
@@ -84,23 +85,11 @@ uint32_t TypeTranslator::translateType(QualType type) {
       emitError("Non-floating-point matrices not supported yet");
       emitError("Non-floating-point matrices not supported yet");
       return 0;
       return 0;
     }
     }
-    const auto elemType = translateType(elemTy);
 
 
+    const auto elemType = translateType(elemTy);
     uint32_t rowCount = 0, colCount = 0;
     uint32_t rowCount = 0, colCount = 0;
     hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
     hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
 
 
-    // In SPIR-V, matrices must have two or more columns.
-    // Handle degenerated cases first.
-
-    if (rowCount == 1 && colCount == 1)
-      return elemType;
-
-    if (rowCount == 1)
-      return theBuilder.getVecType(elemType, colCount);
-
-    if (colCount == 1)
-      return theBuilder.getVecType(elemType, rowCount);
-
     // HLSL matrices are row major, while SPIR-V matrices are column major.
     // HLSL matrices are row major, while SPIR-V matrices are column major.
     // We are mapping what HLSL semantically mean a row into a column here.
     // We are mapping what HLSL semantically mean a row into a column here.
     const uint32_t vecType = theBuilder.getVecType(elemType, colCount);
     const uint32_t vecType = theBuilder.getVecType(elemType, colCount);
@@ -124,31 +113,64 @@ uint32_t TypeTranslator::translateType(QualType type) {
   return 0;
   return 0;
 }
 }
 
 
-bool TypeTranslator::isVec1Type(QualType type, QualType *elemType) {
-  uint32_t count = 0;
-  const bool isVec = isVectorType(type, elemType, &count);
-  return isVec && count == 1;
+bool TypeTranslator::isScalarType(QualType type, QualType *scalarType) {
+  bool isScalar = false;
+  QualType ty = {};
+
+  if (type->isBuiltinType()) {
+    isScalar = true;
+    ty = type;
+  } else if (hlsl::IsHLSLVecType(type) && hlsl::GetHLSLVecSize(type) == 1) {
+    isScalar = true;
+    ty = hlsl::GetHLSLVecElementType(type);
+  } else if (const auto *extVecType =
+                 dyn_cast<ExtVectorType>(type.getTypePtr())) {
+    if (extVecType->getNumElements() == 1) {
+      isScalar = true;
+      ty = extVecType->getElementType();
+    }
+  } else if (is1x1Matrix(type)) {
+    isScalar = true;
+    ty = hlsl::GetHLSLMatElementType(type);
+  }
+
+  if (isScalar && scalarType)
+    *scalarType = ty;
+
+  return isScalar;
 }
 }
 
 
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
 bool TypeTranslator::isVectorType(QualType type, QualType *elemType,
-                                  uint32_t *count) {
+                                  uint32_t *elemCount) {
+  bool isVec = false;
+  QualType ty = {};
+  uint32_t count = 0;
+
   if (hlsl::IsHLSLVecType(type)) {
   if (hlsl::IsHLSLVecType(type)) {
-    if (elemType)
-      *elemType = hlsl::GetHLSLVecElementType(type);
-    if (count)
-      *count = hlsl::GetHLSLVecSize(type);
-    return true;
+    ty = hlsl::GetHLSLVecElementType(type);
+    count = hlsl::GetHLSLVecSize(type);
+    isVec = count > 1;
+  } else if (const auto *extVecType =
+                 dyn_cast<ExtVectorType>(type.getTypePtr())) {
+    ty = extVecType->getElementType();
+    count = extVecType->getNumElements();
+    isVec = count > 1;
+  } else if (hlsl::IsHLSLMatType(type)) {
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+    ty = hlsl::GetHLSLMatElementType(type);
+    count = rowCount == 1 ? colCount : rowCount;
+    isVec = (rowCount == 1) != (colCount == 1);
   }
   }
 
 
-  if (const auto *extVecType = dyn_cast<ExtVectorType>(type.getTypePtr())) {
+  if (isVec) {
     if (elemType)
     if (elemType)
-      *elemType = extVecType->getElementType();
-    if (count)
-      *count = extVecType->getNumElements();
-    return true;
+      *elemType = ty;
+    if (elemCount)
+      *elemCount = count;
   }
   }
-
-  return false;
+  return isVec;
 }
 }
 
 
 bool TypeTranslator::is1x1Matrix(QualType type, QualType *elemType) {
 bool TypeTranslator::is1x1Matrix(QualType type, QualType *elemType) {
@@ -208,31 +230,6 @@ bool TypeTranslator::isMx1Matrix(QualType type, QualType *elemType,
   return true;
   return true;
 }
 }
 
 
-bool TypeTranslator::isMx1Or1xNMatrix(QualType type, QualType *elemType,
-                                      uint32_t *count) {
-  if (!hlsl::IsHLSLMatType(type))
-    return false;
-
-  uint32_t rowCount = 0, colCount = 0;
-  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
-
-  const bool isSingleRowOrCol =
-      (rowCount == 1 && colCount > 1) || (rowCount > 1 && colCount == 1);
-
-  if (!isSingleRowOrCol)
-    return false;
-
-  if (elemType)
-    *elemType = hlsl::GetHLSLMatElementType(type);
-  if (count)
-    *count = rowCount > 1 ? rowCount : colCount;
-  return true;
-}
-
-bool TypeTranslator::is1x1OrMx1Or1xNMatrix(QualType type) {
-  return is1x1Matrix(type) || isMx1Or1xNMatrix(type);
-}
-
 bool TypeTranslator::isMxNMatrix(QualType type, QualType *elemType,
 bool TypeTranslator::isMxNMatrix(QualType type, QualType *elemType,
                                  uint32_t *numRows, uint32_t *numCols) {
                                  uint32_t *numRows, uint32_t *numCols) {
   if (!hlsl::IsHLSLMatType(type))
   if (!hlsl::IsHLSLMatType(type))

+ 13 - 18
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -37,14 +37,19 @@ public:
   /// on will be generated.
   /// on will be generated.
   uint32_t translateType(QualType type);
   uint32_t translateType(QualType type);
 
 
-  /// \breif Returns true if the given type is a vector type (either
-  /// ExtVectorType or HLSL vector type) and writes the element type and count
-  /// into *elementType and *count respectively if they are not nullptr.
-  static bool isVectorType(QualType type, QualType *elemType, uint32_t *count);
-
-  /// \brief Returns true if the given type is a vector type of size 1.
-  /// If elemType is not nullptr, writes the element type to *elemType.
-  static bool isVec1Type(QualType type, QualType *elemType = nullptr);
+  /// \brief Returns true if the given type will be translated into a SPIR-V
+  /// scalar type. This includes normal scalar types, vectors of size 1, and
+  /// 1x1 matrices. If scalarType is not nullptr, writes the scalar type to
+  /// *scalarType.
+  static bool isScalarType(QualType type, QualType *scalarType = nullptr);
+
+  /// \breif Returns true if the given type will be translated into a SPIR-V
+  /// vector type. This includes normal types (either ExtVectorType or HLSL
+  /// vector type) with more than one elements and matrices with exactly one
+  /// row or one column. Writes the element type and count into *elementType and
+  /// *count respectively if they are not nullptr.
+  static bool isVectorType(QualType type, QualType *elemType = nullptr,
+                           uint32_t *count = nullptr);
 
 
   /// \brief Returns true if the given type is a 1x1 matrix type.
   /// \brief Returns true if the given type is a 1x1 matrix type.
   /// If elemType is not nullptr, writes the element type to *elemType.
   /// If elemType is not nullptr, writes the element type to *elemType.
@@ -62,16 +67,6 @@ public:
   static bool isMx1Matrix(QualType type, QualType *elemType = nullptr,
   static bool isMx1Matrix(QualType type, QualType *elemType = nullptr,
                           uint32_t *count = nullptr);
                           uint32_t *count = nullptr);
 
 
-  /// \brief Returns true if the given type is a Mx1 (M > 1), or 1xN (N > 1)
-  /// matrix type. If elemType is not nullptr, writes the matrix element type to
-  /// *elemType. If count is not nullptr, writes the size (M or N) into *count.
-  static bool isMx1Or1xNMatrix(QualType type, QualType *elemType = nullptr,
-                               uint32_t *count = nullptr);
-
-  /// \brief Returns true if the given type is a 1x1, or Mx1 (M > 1), or
-  /// 1xN (N > 1) matrix type.
-  static bool is1x1OrMx1Or1xNMatrix(QualType type);
-
   /// \brief returns true if the given type is a matrix with more than 1 row and
   /// \brief returns true if the given type is a matrix with more than 1 row and
   /// more than 1 column.
   /// more than 1 column.
   /// If elemType is not nullptr, writes the element type to *elemType.
   /// If elemType is not nullptr, writes the element type to *elemType.