Răsfoiți Sursa

[spirv] Remove usages of TypeTranslator.

Ehsan Nasiri 6 ani în urmă
părinte
comite
36c1cd728e

+ 122 - 6
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -117,22 +117,22 @@ bool canTreatAsSameScalarType(QualType type1, QualType type2);
 /// regardless of constness and literalness.
 bool isSameScalarOrVecType(QualType type1, QualType type2);
 
-  /// \brief Returns true if the two types are the same type, regardless of
-  /// constness and literalness.
+/// \brief Returns true if the two types are the same type, regardless of
+/// constness and literalness.
 bool isSameType(const ASTContext &, QualType type1, QualType type2);
 
 /// Returns true if all members in structType are of the same element
 /// type and can be fit into a 4-component vector. Writes element type and
 /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
 /// explaining why not.
-bool canFitIntoOneRegister(QualType structType, QualType *elemType,
-                           uint32_t *elemCount = nullptr);
+bool canFitIntoOneRegister(const ASTContext &, QualType structType,
+                           QualType *elemType, uint32_t *elemCount = nullptr);
 
 /// Returns the element type of the given type. The given type may be a scalar
 /// type, vector type, matrix type, or array type. It may also be a struct with
 /// members that can fit into a register. In such case, the result would be the
 /// struct member type.
-QualType getElementType(QualType type);
+QualType getElementType(const ASTContext &, QualType type);
 
 QualType getTypeWithCustomBitwidth(const ASTContext &, QualType type,
                                    uint32_t bitwidth);
@@ -148,7 +148,7 @@ bool isLitTypeOrVecOfLitType(QualType type);
 /// desugared one. If isRowMajor is not nullptr, and a 'row_major' or
 /// 'column-major' attribute is found during desugaring, this information is
 /// written to *isRowMajor.
-QualType desugarType(QualType type, llvm::Optional<bool>* isRowMajor);
+QualType desugarType(QualType type, llvm::Optional<bool> *isRowMajor);
 
 /// Returns true if type is a SPIR-V row-major matrix or array of matrices.
 /// Returns false if type is a SPIR-V col-major matrix or array of matrices.
@@ -160,6 +160,122 @@ QualType desugarType(QualType type, llvm::Optional<bool>* isRowMajor);
 /// a row into a column here.
 bool isRowMajorMatrix(const SpirvCodeGenOptions &, QualType type);
 
+/// \brief Returns true if the given type is a (RW)StructuredBuffer type.
+bool isStructuredBuffer(QualType type);
+
+/// \brief Returns true if the given type is an AppendStructuredBuffer type.
+bool isAppendStructuredBuffer(QualType type);
+
+/// \brief Returns true if the given type is a ConsumeStructuredBuffer type.
+bool isConsumeStructuredBuffer(QualType type);
+
+/// \brief Returns true if the given type is a RW/Append/Consume
+/// StructuredBuffer type.
+bool isRWAppendConsumeSBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL ByteAddressBufferType.
+bool isByteAddressBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL RWByteAddressBufferType.
+bool isRWByteAddressBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+/// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
+bool isAKindOfStructuredOrByteBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+/// (RW)ByteAddressBuffer, {Append|Consume}StructuredBuffer, or a struct
+/// containing one of the above.
+bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL Buffer type.
+bool isBuffer(QualType type);
+
+/// \brief Returns true if the given type is the HLSL RWBuffer type.
+bool isRWBuffer(QualType type);
+
+/// \brief Returns true if the given type is an HLSL Texture type.
+bool isTexture(QualType);
+
+/// \brief Returns true if the given type is an HLSL Texture2DMS or
+/// Texture2DMSArray type.
+bool isTextureMS(QualType);
+
+/// \brief Returns true if the given type is an HLSL RWTexture type.
+bool isRWTexture(QualType);
+
+/// \brief Returns true if the given type is an HLSL sampler type.
+bool isSampler(QualType);
+
+/// \brief Returns true if the given type is SubpassInput.
+bool isSubpassInput(QualType);
+
+/// \brief Returns true if the given type is SubpassInputMS.
+bool isSubpassInputMS(QualType);
+
+/// Returns true if the given type will be translated into a SPIR-V image,
+/// sampler or struct containing images or samplers.
+///
+/// Note: legalization specific code
+bool isOpaqueType(QualType type);
+
+/// Returns true if the given type will be translated into a array of SPIR-V
+/// images or samplers.
+bool isOpaqueArrayType(QualType type);
+
+/// Returns true if the given type is a struct type who has an opaque field
+/// (in a recursive away).
+///
+/// Note: legalization specific code
+bool isOpaqueStructType(QualType tye);
+
+/// \brief Returns true if the given type can use relaxed precision
+/// decoration. Integer and float types with lower than 32 bits can be
+/// operated on with a relaxed precision.
+bool isRelaxedPrecisionType(QualType, const SpirvCodeGenOptions &);
+
+/// Returns true if the given type is a bool or vector of bool type.
+bool isBoolOrVecOfBoolType(QualType type);
+
+/// Returns true if the given type is a signed integer or vector of signed
+/// integer type.
+bool isSintOrVecOfSintType(QualType type);
+
+/// Returns true if the given type is an unsigned integer or vector of unsigned
+/// integer type.
+bool isUintOrVecOfUintType(QualType type);
+
+/// Returns true if the given type is a float or vector of float type.
+bool isFloatOrVecOfFloatType(QualType type);
+
+/// Returns true if the given type is a bool or vector/matrix of bool type.
+bool isBoolOrVecMatOfBoolType(QualType type);
+
+/// Returns true if the given type is a signed integer or vector/matrix of
+/// signed integer type.
+bool isSintOrVecMatOfSintType(QualType type);
+
+/// Returns true if the given type is an unsigned integer or vector/matrix of
+/// unsigned integer type.
+bool isUintOrVecMatOfUintType(QualType type);
+
+/// Returns true if the given type is a float or vector/matrix of float type.
+bool isFloatOrVecMatOfFloatType(QualType type);
+
+/// \brief Returns true if the decl type is a non-floating-point matrix and
+/// the matrix is column major, or if it is an array/struct containing such
+/// matrices.
+bool isOrContainsNonFpColMajorMatrix(const ASTContext &,
+                                     const SpirvCodeGenOptions &, QualType type,
+                                     const Decl *decl);
+
+/// \brief Generates the corresponding SPIR-V vector type for the given Clang
+/// frontend matrix type's vector component and returns the <result-id>.
+///
+/// This method will panic if the given matrix type is not a SPIR-V acceptable
+/// matrix type.
+QualType getComponentVectorType(const ASTContext &, QualType matrixType);
+
 } // namespace spirv
 } // namespace clang
 

+ 0 - 1
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -68,7 +68,6 @@ public:
   static bool isSubpassInputMS(const SpirvType *);
   static bool isResourceType(const SpirvType *);
   static bool isOrContains16BitType(const SpirvType *);
-  static bool isMatrixOrArrayOfMatrix(const SpirvType *);
 
 protected:
   SpirvType(Kind k, llvm::StringRef name = "") : kind(k), debugName(name) {}

+ 355 - 42
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -9,9 +9,21 @@
 
 #include "clang/SPIRV/AstTypeProbe.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/HlslTypes.h"
 
+namespace {
+template <unsigned N>
+clang::DiagnosticBuilder emitError(const clang::ASTContext &astContext,
+                                   const char (&message)[N],
+                                   clang::SourceLocation srcLoc = {}) {
+  const auto diagId = astContext.getDiagnostics().getCustomDiagID(
+      clang::DiagnosticsEngine::Error, message);
+  return astContext.getDiagnostics().Report(srcLoc, diagId);
+}
+} // namespace
+
 namespace clang {
 namespace spirv {
 
@@ -212,22 +224,6 @@ bool isMxNMatrix(QualType type, QualType *elemType, uint32_t *numRows,
   return false;
 }
 
-bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
-  if (const RecordType *recordType = type->getAs<RecordType>()) {
-    StringRef name = recordType->getDecl()->getName();
-    if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
-        name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
-        name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer")
-      return true;
-
-    for (const auto *field : recordType->getDecl()->fields()) {
-      if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType()))
-        return true;
-    }
-  }
-  return false;
-}
-
 bool isSubpassInput(QualType type) {
   if (const auto *rt = type->getAs<RecordType>())
     return rt->getDecl()->getName() == "SubpassInput";
@@ -267,21 +263,6 @@ bool isResourceType(const ValueDecl *decl) {
   return hlsl::IsHLSLResourceType(declType);
 }
 
-bool isAKindOfStructuredOrByteBuffer(QualType type) {
-  // Strip outer arrayness first
-  while (type->isArrayType())
-    type = type->getAsArrayTypeUnsafe()->getElementType();
-
-  if (const RecordType *recordType = type->getAs<RecordType>()) {
-    StringRef name = recordType->getDecl()->getName();
-    return name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
-           name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
-           name == "AppendStructuredBuffer" ||
-           name == "ConsumeStructuredBuffer";
-  }
-  return false;
-}
-
 bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
   // Primitive types
   {
@@ -478,8 +459,8 @@ bool canTreatAsSameScalarType(QualType type1, QualType type2) {
           !type1->isSpecificBuiltinType(BuiltinType::Bool));
 }
 
-bool canFitIntoOneRegister(QualType structType, QualType *elemType,
-                           uint32_t *elemCount) {
+bool canFitIntoOneRegister(const ASTContext &astContext, QualType structType,
+                           QualType *elemType, uint32_t *elemCount) {
   if (structType->getAsStructureType() == nullptr)
     return false;
 
@@ -497,23 +478,29 @@ bool canFitIntoOneRegister(QualType structType, QualType *elemType,
         firstElemType = type;
       } else {
         if (!canTreatAsSameScalarType(firstElemType, type)) {
-          assert(false && "all struct members should have the same element "
-                          "type for resource template instantiation");
+          emitError(astContext,
+                    "all struct members should have the same element type for "
+                    "resource template instantiation",
+                    structDecl->getLocation());
           return false;
         }
       }
       totalCount += count;
     } else {
-      assert(false && "unsupported struct element type for resource template "
-                      "instantiation");
+      emitError(
+          astContext,
+          "unsupported struct element type for resource template instantiation",
+          structDecl->getLocation());
       return false;
     }
   }
 
   if (totalCount > 4) {
-    assert(
-        false &&
-        "resource template element type cannot fit into four 32-bit scalars");
+    emitError(
+        astContext,
+        "resource template element type %0 cannot fit into four 32-bit scalars",
+        structDecl->getLocation())
+        << structType;
     return false;
   }
 
@@ -524,10 +511,11 @@ bool canFitIntoOneRegister(QualType structType, QualType *elemType,
   return true;
 }
 
-QualType getElementType(QualType type) {
+QualType getElementType(const ASTContext &astContext, QualType type) {
   QualType elemType = {};
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType) ||
-      isMxNMatrix(type, &elemType) || canFitIntoOneRegister(type, &elemType)) {
+      isMxNMatrix(type, &elemType) ||
+      canFitIntoOneRegister(astContext, type, &elemType)) {
     return elemType;
   }
 
@@ -701,5 +689,330 @@ bool isRowMajorMatrix(const SpirvCodeGenOptions &spvOptions, QualType type) {
   return !spvOptions.defaultRowMajor;
 }
 
+bool isStructuredBuffer(QualType type) {
+  const auto *recordType = type->getAs<RecordType>();
+  if (!recordType)
+    return false;
+  const auto name = recordType->getDecl()->getName();
+  return name == "StructuredBuffer" || name == "RWStructuredBuffer";
+}
+
+bool isByteAddressBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "ByteAddressBuffer";
+  }
+  return false;
+}
+
+bool isRWBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "RWBuffer";
+  }
+  return false;
+}
+
+bool isBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "Buffer";
+  }
+  return false;
+}
+
+bool isRWTexture(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    const auto name = rt->getDecl()->getName();
+    if (name == "RWTexture1D" || name == "RWTexture1DArray" ||
+        name == "RWTexture2D" || name == "RWTexture2DArray" ||
+        name == "RWTexture3D")
+      return true;
+  }
+  return false;
+}
+
+bool isTexture(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    const auto name = rt->getDecl()->getName();
+    if (name == "Texture1D" || name == "Texture1DArray" ||
+        name == "Texture2D" || name == "Texture2DArray" ||
+        name == "Texture2DMS" || name == "Texture2DMSArray" ||
+        name == "TextureCube" || name == "TextureCubeArray" ||
+        name == "Texture3D")
+      return true;
+  }
+  return false;
+}
+
+bool isTextureMS(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    const auto name = rt->getDecl()->getName();
+    if (name == "Texture2DMS" || name == "Texture2DMSArray")
+      return true;
+  }
+  return false;
+}
+
+bool isSampler(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    const auto name = rt->getDecl()->getName();
+    if (name == "SamplerState" || name == "SamplerComparisonState")
+      return true;
+  }
+  return false;
+}
+
+bool isRWByteAddressBuffer(QualType type) {
+  if (const auto *rt = type->getAs<RecordType>()) {
+    return rt->getDecl()->getName() == "RWByteAddressBuffer";
+  }
+  return false;
+}
+
+bool isAppendStructuredBuffer(QualType type) {
+  const auto *recordType = type->getAs<RecordType>();
+  if (!recordType)
+    return false;
+  const auto name = recordType->getDecl()->getName();
+  return name == "AppendStructuredBuffer";
+}
+
+bool isConsumeStructuredBuffer(QualType type) {
+  const auto *recordType = type->getAs<RecordType>();
+  if (!recordType)
+    return false;
+  const auto name = recordType->getDecl()->getName();
+  return name == "ConsumeStructuredBuffer";
+}
+
+bool isRWAppendConsumeSBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    return name == "RWStructuredBuffer" || name == "AppendStructuredBuffer" ||
+           name == "ConsumeStructuredBuffer";
+  }
+  return false;
+}
+
+bool isAKindOfStructuredOrByteBuffer(QualType type) {
+  // Strip outer arrayness first
+  while (type->isArrayType())
+    type = type->getAsArrayTypeUnsafe()->getElementType();
+
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    return name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+           name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+           name == "AppendStructuredBuffer" ||
+           name == "ConsumeStructuredBuffer";
+  }
+  return false;
+}
+
+bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
+  if (const RecordType *recordType = type->getAs<RecordType>()) {
+    StringRef name = recordType->getDecl()->getName();
+    if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
+        name == "ByteAddressBuffer" || name == "RWByteAddressBuffer" ||
+        name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer")
+      return true;
+
+    for (const auto *field : recordType->getDecl()->fields()) {
+      if (isOrContainsAKindOfStructuredOrByteBuffer(field->getType()))
+        return true;
+    }
+  }
+  return false;
+}
+
+bool isOpaqueType(QualType type) {
+  if (const auto *recordType = type->getAs<RecordType>()) {
+    const auto name = recordType->getDecl()->getName();
+
+    if (name == "Texture1D" || name == "RWTexture1D")
+      return true;
+    if (name == "Texture2D" || name == "RWTexture2D")
+      return true;
+    if (name == "Texture2DMS" || name == "RWTexture2DMS")
+      return true;
+    if (name == "Texture3D" || name == "RWTexture3D")
+      return true;
+    if (name == "TextureCube" || name == "RWTextureCube")
+      return true;
+
+    if (name == "Texture1DArray" || name == "RWTexture1DArray")
+      return true;
+    if (name == "Texture2DArray" || name == "RWTexture2DArray")
+      return true;
+    if (name == "Texture2DMSArray" || name == "RWTexture2DMSArray")
+      return true;
+    if (name == "TextureCubeArray" || name == "RWTextureCubeArray")
+      return true;
+
+    if (name == "Buffer" || name == "RWBuffer")
+      return true;
+
+    if (name == "SamplerState" || name == "SamplerComparisonState")
+      return true;
+  }
+  return false;
+}
+
+bool isOpaqueStructType(QualType type) {
+  if (isOpaqueType(type))
+    return false;
+
+  if (const auto *recordType = type->getAs<RecordType>())
+    for (const auto *field : recordType->getDecl()->decls())
+      if (const auto *fieldDecl = dyn_cast<FieldDecl>(field))
+        if (isOpaqueType(fieldDecl->getType()) ||
+            isOpaqueStructType(fieldDecl->getType()))
+          return true;
+
+  return false;
+}
+
+bool isOpaqueArrayType(QualType type) {
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe())
+    return isOpaqueType(arrayType->getElementType());
+  return false;
+}
+
+bool isRelaxedPrecisionType(QualType type, const SpirvCodeGenOptions &opts) {
+  // Primitive types
+  {
+    QualType ty = {};
+    if (isScalarType(type, &ty))
+      if (const auto *builtinType = ty->getAs<BuiltinType>())
+        switch (builtinType->getKind()) {
+        case BuiltinType::Min12Int:
+        case BuiltinType::Min16Int:
+        case BuiltinType::Min16UInt:
+        case BuiltinType::Min16Float:
+        case BuiltinType::Min10Float: {
+          // If '-enable-16bit-types' options is enabled, these types are
+          // translated to real 16-bit type, and therefore are not
+          // RelaxedPrecision.
+          // If the options is not enabled, these types are translated to 32-bit
+          // types with the added RelaxedPrecision decoration.
+          return !opts.enable16BitTypes;
+        default:
+          // Filter switch only interested in relaxed precision eligible types.
+          break;
+        }
+        }
+  }
+
+  // Vector & Matrix types could use relaxed precision based on their element
+  // type.
+  {
+    QualType elemType = {};
+    if (isVectorType(type, &elemType) || isMxNMatrix(type, &elemType))
+      return isRelaxedPrecisionType(elemType, opts);
+  }
+
+  return false;
+}
+
+/// Returns true if the given type is a bool or vector of bool type.
+bool isBoolOrVecOfBoolType(QualType type) {
+  QualType elemType = {};
+  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
+         elemType->isBooleanType();
+}
+
+/// Returns true if the given type is a signed integer or vector of signed
+/// integer type.
+bool isSintOrVecOfSintType(QualType type) {
+  QualType elemType = {};
+  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
+         elemType->isSignedIntegerType();
+}
+
+/// Returns true if the given type is an unsigned integer or vector of unsigned
+/// integer type.
+bool isUintOrVecOfUintType(QualType type) {
+  QualType elemType = {};
+  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
+         elemType->isUnsignedIntegerType();
+}
+
+/// Returns true if the given type is a float or vector of float type.
+bool isFloatOrVecOfFloatType(QualType type) {
+  QualType elemType = {};
+  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
+         elemType->isFloatingType();
+}
+
+/// Returns true if the given type is a bool or vector/matrix of bool type.
+bool isBoolOrVecMatOfBoolType(QualType type) {
+  return isBoolOrVecOfBoolType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isBooleanType());
+}
+
+/// Returns true if the given type is a signed integer or vector/matrix of
+/// signed integer type.
+bool isSintOrVecMatOfSintType(QualType type) {
+  return isSintOrVecOfSintType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
+}
+
+/// Returns true if the given type is an unsigned integer or vector/matrix of
+/// unsigned integer type.
+bool isUintOrVecMatOfUintType(QualType type) {
+  return isUintOrVecOfUintType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
+}
+
+/// Returns true if the given type is a float or vector/matrix of float type.
+bool isFloatOrVecMatOfFloatType(QualType type) {
+  return isFloatOrVecOfFloatType(type) ||
+         (hlsl::IsHLSLMatType(type) &&
+          hlsl::GetHLSLMatElementType(type)->isFloatingType());
+}
+
+bool isOrContainsNonFpColMajorMatrix(const ASTContext &astContext,
+                                     const SpirvCodeGenOptions &spirvOptions,
+                                     QualType type, const Decl *decl) {
+  const auto isColMajorDecl = [&spirvOptions](const Decl *decl) {
+    return decl->hasAttr<clang::HLSLColumnMajorAttr>() ||
+           (!decl->hasAttr<clang::HLSLRowMajorAttr>() &&
+            !spirvOptions.defaultRowMajor);
+  };
+
+  QualType elemType = {};
+  if (isMxNMatrix(type, &elemType) && !elemType->isFloatingType()) {
+    return isColMajorDecl(decl);
+  }
+
+  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
+    if (isMxNMatrix(arrayType->getElementType(), &elemType) &&
+        !elemType->isFloatingType())
+      return isColMajorDecl(decl);
+  }
+
+  if (const auto *structType = type->getAs<RecordType>()) {
+    const auto *decl = structType->getDecl();
+    for (const auto *field : decl->fields()) {
+      if (isOrContainsNonFpColMajorMatrix(astContext, spirvOptions,
+                                          field->getType(), field))
+        return true;
+    }
+  }
+
+  return false;
+}
+
+QualType getComponentVectorType(const ASTContext &astContext,
+                                QualType matrixType) {
+  assert(isMxNMatrix(matrixType));
+
+  const QualType elemType = hlsl::GetHLSLMatElementType(matrixType);
+  uint32_t rowCount = 0, colCount = 0;
+  hlsl::GetHLSLMatRowColCount(matrixType, rowCount, colCount);
+  return astContext.getExtVectorType(elemType, colCount);
+}
+
 } // namespace spirv
 } // namespace clang

+ 3 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -554,8 +554,7 @@ DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
 void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
   const QualType declType = getTypeOrFnRetType(decl);
 
-  if (!counterVars.count(decl) &&
-      TypeTranslator::isRWAppendConsumeSBuffer(declType)) {
+  if (!counterVars.count(decl) && isRWAppendConsumeSBuffer(declType)) {
     createCounterVar(decl, /*declId=*/0, /*isAlias=*/true);
   } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
              // Exclude other resource types which are represented as structs
@@ -994,7 +993,7 @@ void DeclResultIdMapper::createFieldCounterVars(
     indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());
 
     const QualType fieldType = field->getType();
-    if (TypeTranslator::isRWAppendConsumeSBuffer(fieldType))
+    if (isRWAppendConsumeSBuffer(fieldType))
       createCounterVar(rootDecl, /*declId=*/0, /*isAlias=*/true, indices);
     else if (fieldType->isStructureType() &&
              !hlsl::IsHLSLResourceType(fieldType))
@@ -2140,7 +2139,7 @@ DeclResultIdMapper::invertWIfRequested(SpirvInstruction *position) {
 void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
                                                      QualType type,
                                                      SpirvVariable *varInstr) {
-  const QualType elemType = getElementType(type);
+  const QualType elemType = getElementType(astContext, type);
   const auto loc = decl->getLocation();
 
   if (elemType->isBooleanType() || elemType->isIntegerType()) {

+ 0 - 1
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -13,7 +13,6 @@
 #include "dxc/DXIL/DxilSemantic.h"
 #include "dxc/DXIL/DxilShaderModel.h"
 #include "dxc/DXIL/DxilSigPoint.h"
-#include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"

+ 7 - 7
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InitListHandler.h"
+#include "clang/SPIRV/AstTypeProbe.h"
 
 #include <algorithm>
 #include <iterator>
@@ -23,8 +24,7 @@ namespace spirv {
 
 InitListHandler::InitListHandler(const ASTContext &ctx, SPIRVEmitter &emitter)
     : astContext(ctx), theEmitter(emitter),
-      spvBuilder(emitter.getModuleBuilder()),
-      typeTranslator(emitter.getTypeTranslator()),
+      spvBuilder(emitter.getSpirvBuilder()),
       diags(emitter.getDiagnosticsEngine()) {}
 
 SpirvInstruction *InitListHandler::process(const InitListExpr *expr) {
@@ -119,7 +119,7 @@ bool InitListHandler::tryToSplitStruct() {
   const QualType initType = init->getType();
   if (!initType->isStructureType() ||
       // Sampler types will pass the above check but we cannot split it.
-      TypeTranslator::isSampler(initType))
+      isSampler(initType))
     return false;
 
   // We are certain the current intializer will be replaced by now.
@@ -202,11 +202,11 @@ SpirvInstruction *InitListHandler::createInitForType(QualType type,
 
   // Samplers, (RW)Buffers, (RW)Textures
   // It is important that this happens before checking of structure types.
-  if (TypeTranslator::isOpaqueType(type))
+  if (isOpaqueType(type))
     return createInitForSamplerImageType(type, srcLoc);
 
   // This should happen before the check for normal struct types
-  if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) {
+  if (isAKindOfStructuredOrByteBuffer(type)) {
     emitError("cannot handle structured/byte buffer as initializer", srcLoc);
     return nullptr;
   }
@@ -338,7 +338,7 @@ InitListHandler::createInitForMatrixType(QualType matrixType,
 }
 
 SpirvInstruction *InitListHandler::createInitForStructType(QualType type) {
-  assert(type->isStructureType() && !TypeTranslator::isSampler(type));
+  assert(type->isStructureType() && !isSampler(type));
 
   // Same as the vector case, first try to see if we already have a struct at
   // the beginning of the initializer queue.
@@ -417,7 +417,7 @@ InitListHandler::createInitForConstantArrayType(QualType type,
 SpirvInstruction *
 InitListHandler::createInitForSamplerImageType(QualType type,
                                                SourceLocation srcLoc) {
-  assert(TypeTranslator::isOpaqueType(type));
+  assert(isOpaqueType(type));
 
   // Samplers, (RW)Buffers, and (RW)Textures are translated into OpTypeSampler
   // and OpTypeImage. They should be treated similar as builtin types.

+ 0 - 3
tools/clang/lib/SPIRV/InitListHandler.h

@@ -20,10 +20,8 @@
 
 #include "clang/AST/Expr.h"
 #include "clang/Basic/Diagnostic.h"
-#include "clang/SPIRV/ModuleBuilder.h"
 
 #include "SPIRVEmitter.h"
-#include "TypeTranslator.h"
 
 namespace clang {
 namespace spirv {
@@ -137,7 +135,6 @@ private:
   const ASTContext &astContext;
   SPIRVEmitter &theEmitter;
   SpirvBuilder &spvBuilder;
-  TypeTranslator &typeTranslator;
   DiagnosticsEngine &diags;
 
   /// A queue keeping track of unused AST nodes for initializers. Since we will

+ 7 - 7
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -485,7 +485,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       const bool isMS = (name == "Texture2DMS" || name == "Texture2DMSArray");
       const auto sampledType = hlsl::GetHLSLResourceResultType(type);
       return spvContext.getImageType(
-          lowerType(getElementType(sampledType), rule,
+          lowerType(getElementType(astContext, sampledType), rule,
                     /*isRowMajor*/ llvm::None, srcLoc),
           dim, ImageType::WithDepth::Unknown, isArray, isMS,
           ImageType::WithSampler::Yes, spv::ImageFormat::Unknown);
@@ -501,7 +501,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       const auto format =
           translateSampledTypeToImageFormat(sampledType, srcLoc);
       return spvContext.getImageType(
-          lowerType(getElementType(sampledType), rule,
+          lowerType(getElementType(astContext, sampledType), rule,
                     /*isRowMajor*/ llvm::None, srcLoc),
           dim, ImageType::WithDepth::Unknown, isArray,
           /*isMultiSampled=*/false, /*sampled=*/ImageType::WithSampler::No,
@@ -602,8 +602,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     }
     const auto format = translateSampledTypeToImageFormat(sampledType, srcLoc);
     return spvContext.getImageType(
-        lowerType(getElementType(sampledType), rule, /*isRowMajor*/ llvm::None,
-                  srcLoc),
+        lowerType(getElementType(astContext, sampledType), rule,
+                  /*isRowMajor*/ llvm::None, srcLoc),
         spv::Dim::Buffer, ImageType::WithDepth::Unknown,
         /*isArrayed=*/false, /*isMultiSampled=*/false,
         /*sampled*/ name == "Buffer" ? ImageType::WithSampler::Yes
@@ -637,8 +637,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
   if (name == "SubpassInput" || name == "SubpassInputMS") {
     const auto sampledType = hlsl::GetHLSLResourceResultType(type);
     return spvContext.getImageType(
-        lowerType(getElementType(sampledType), rule, /*isRowMajor*/ llvm::None,
-                  srcLoc),
+        lowerType(getElementType(astContext, sampledType), rule,
+                  /*isRowMajor*/ llvm::None, srcLoc),
         spv::Dim::SubpassData, ImageType::WithDepth::Unknown,
         /*isArrayed=*/false,
         /*isMultipleSampled=*/name == "SubpassInputMS",
@@ -655,7 +655,7 @@ LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
   QualType ty = {};
   if (isScalarType(sampledType, &ty) ||
       isVectorType(sampledType, &ty, &elemCount) ||
-      canFitIntoOneRegister(sampledType, &ty, &elemCount)) {
+      canFitIntoOneRegister(astContext, sampledType, &ty, &elemCount)) {
     if (const auto *builtinType = ty->getAs<BuiltinType>()) {
       switch (builtinType->getKind()) {
       case BuiltinType::Int:

+ 69 - 142
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -60,68 +60,6 @@ bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
   return false;
 }
 
-// TODO: Maybe we should move these type probing functions to TypeTranslator.
-
-/// Returns true if the given type is a bool or vector of bool type.
-bool isBoolOrVecOfBoolType(QualType type) {
-  QualType elemType = {};
-  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
-         elemType->isBooleanType();
-}
-
-/// Returns true if the given type is a signed integer or vector of signed
-/// integer type.
-bool isSintOrVecOfSintType(QualType type) {
-  QualType elemType = {};
-  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
-         elemType->isSignedIntegerType();
-}
-
-/// Returns true if the given type is an unsigned integer or vector of unsigned
-/// integer type.
-bool isUintOrVecOfUintType(QualType type) {
-  QualType elemType = {};
-  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
-         elemType->isUnsignedIntegerType();
-}
-
-/// Returns true if the given type is a float or vector of float type.
-bool isFloatOrVecOfFloatType(QualType type) {
-  QualType elemType = {};
-  return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
-         elemType->isFloatingType();
-}
-
-/// Returns true if the given type is a bool or vector/matrix of bool type.
-bool isBoolOrVecMatOfBoolType(QualType type) {
-  return isBoolOrVecOfBoolType(type) ||
-         (hlsl::IsHLSLMatType(type) &&
-          hlsl::GetHLSLMatElementType(type)->isBooleanType());
-}
-
-/// Returns true if the given type is a signed integer or vector/matrix of
-/// signed integer type.
-bool isSintOrVecMatOfSintType(QualType type) {
-  return isSintOrVecOfSintType(type) ||
-         (hlsl::IsHLSLMatType(type) &&
-          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
-}
-
-/// Returns true if the given type is an unsigned integer or vector/matrix of
-/// unsigned integer type.
-bool isUintOrVecMatOfUintType(QualType type) {
-  return isUintOrVecOfUintType(type) ||
-         (hlsl::IsHLSLMatType(type) &&
-          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
-}
-
-/// Returns true if the given type is a float or vector/matrix of float type.
-bool isFloatOrVecMatOfFloatType(QualType type) {
-  return isFloatOrVecOfFloatType(type) ||
-         (hlsl::IsHLSLMatType(type) &&
-          hlsl::GetHLSLMatElementType(type)->isFloatingType());
-}
-
 inline bool isSpirvMatrixOp(spv::Op opcode) {
   return opcode == spv::Op::OpMatrixTimesMatrix ||
          opcode == spv::Op::OpMatrixTimesVector ||
@@ -143,7 +81,7 @@ const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
     if (GetIntrinsicOp(callee, opcode, group)) {
       if (static_cast<IntrinsicOp>(opcode) == IntrinsicOp::MOP_Load) {
         const auto *object = indexing->getImplicitObjectArgument();
-        if (TypeTranslator::isStructuredBuffer(object->getType())) {
+        if (isStructuredBuffer(object->getType())) {
           *index = indexing->getArg(0);
           return indexing->getImplicitObjectArgument();
         }
@@ -172,7 +110,7 @@ inline bool isExternalVar(const VarDecl *var) {
 const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) {
   if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
     if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
-      if (TypeTranslator::isConstantTextureBuffer(varDecl))
+      if (isConstantTextureBuffer(varDecl))
         return varDecl->getType()->getAs<RecordType>()->getDecl();
 
   return nullptr;
@@ -191,10 +129,10 @@ bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
   expr = expr->IgnoreParenCasts();
   if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
     if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
-      if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType()))
+      if (isAKindOfStructuredOrByteBuffer(varDecl->getType()))
         return isExternalVar(varDecl);
   } else if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
-    if (TypeTranslator::isAKindOfStructuredOrByteBuffer(callExpr->getType()))
+    if (isAKindOfStructuredOrByteBuffer(callExpr->getType()))
       return true;
   } else if (const auto *arrSubExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
     return isReferencingNonAliasStructuredOrByteBuffer(arrSubExpr->getBase());
@@ -543,10 +481,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
-      theContext(), spvContext(), featureManager(diags, spirvOptions),
-      theBuilder(&theContext, &featureManager, spirvOptions),
+      spvContext(), featureManager(diags, spirvOptions),
       spvBuilder(astContext, spvContext, &featureManager, spirvOptions),
-      typeTranslator(astContext, theBuilder, diags, spirvOptions),
       declIdMapper(shaderModel, astContext, spvContext, spvBuilder, *this,
                    featureManager, spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
@@ -860,7 +796,7 @@ SpirvInstruction *SPIRVEmitter::loadIfGLValue(const Expr *expr,
   // If true, we are likely to copy it as a whole. To assist per-element
   // copying, avoid the load here and return the pointer directly.
   // TODO: consider moving this hack into SPIRV-Tools as a transformation.
-  if (TypeTranslator::isOpaqueArrayType(expr->getType()))
+  if (isOpaqueArrayType(expr->getType()))
     return info;
 
   // Check whether we are trying to load an externally visible structured/byte
@@ -1102,8 +1038,8 @@ bool SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) {
 
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     const auto varType = varDecl->getType();
-    if ((TypeTranslator::isSubpassInput(varType) ||
-         TypeTranslator::isSubpassInputMS(varType)) &&
+    if ((isSubpassInput(varType) ||
+         isSubpassInputMS(varType)) &&
         !varDecl->hasAttr<VKInputAttachmentIndexAttr>()) {
       emitError("missing vk::input_attachment_index attribute",
                 varDecl->getLocation());
@@ -1134,7 +1070,7 @@ bool SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) {
           "only scalar/vector types allowed as SubpassInput(MS) parameter type",
           decl->getLocation());
       // Return directly to avoid further type processing, which will hit
-      // asserts in TypeTranslator.
+      // asserts when lowering the type.
       return false;
     }
   }
@@ -1191,8 +1127,8 @@ void SPIRVEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
       }
 
       // We cannot handle external initialization of column-major matrices now.
-      if (typeTranslator.isOrContainsNonFpColMajorMatrix(varMember->getType(),
-                                                         varMember)) {
+      if (isOrContainsNonFpColMajorMatrix(astContext, spirvOptions,
+                                          varMember->getType(), varMember)) {
         emitError("externally initialized non-floating-point column-major "
                   "matrices not supported yet",
                   varMember->getLocation());
@@ -1228,7 +1164,8 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
 
   // We cannot handle external initialization of column-major matrices now.
   if (isExternalVar(decl) &&
-      typeTranslator.isOrContainsNonFpColMajorMatrix(decl->getType(), decl)) {
+      isOrContainsNonFpColMajorMatrix(astContext, spirvOptions, decl->getType(),
+                                      decl)) {
     emitError("externally initialized non-floating-point column-major "
               "matrices not supported yet",
               decl->getLocation());
@@ -1242,7 +1179,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       type = type->getAsArrayTypeUnsafe()->getElementType();
     } while (type->isArrayType());
 
-    if (TypeTranslator::isRWAppendConsumeSBuffer(type)) {
+    if (isRWAppendConsumeSBuffer(type)) {
       emitError("arrays of RW/append/consume structured buffers unsupported",
                 decl->getLocation());
       return;
@@ -1314,16 +1251,16 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
 
     // Variables that are not externally visible and of opaque types should
     // request legalization.
-    if (!needsLegalization && TypeTranslator::isOpaqueType(decl->getType()))
+    if (!needsLegalization && isOpaqueType(decl->getType()))
       needsLegalization = true;
   }
 
-  if (TypeTranslator::isRelaxedPrecisionType(decl->getType(), spirvOptions)) {
+  if (isRelaxedPrecisionType(decl->getType(), spirvOptions)) {
     spvBuilder.decorateRelaxedPrecision(var);
   }
 
   // All variables that are of opaque struct types should request legalization.
-  if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType()))
+  if (!needsLegalization && isOpaqueStructType(decl->getType()))
     needsLegalization = true;
 }
 
@@ -2619,13 +2556,11 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
   const auto *object = expr->getImplicitObjectArgument();
   auto *objectInstr = loadIfAliasVarRef(object);
   const auto type = object->getType();
-  const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) ||
-                                   TypeTranslator::isRWByteAddressBuffer(type);
-  const bool isStructuredBuffer =
-      TypeTranslator::isStructuredBuffer(type) ||
-      TypeTranslator::isAppendStructuredBuffer(type) ||
-      TypeTranslator::isConsumeStructuredBuffer(type);
-  assert(isByteAddressBuffer || isStructuredBuffer);
+  const bool isBABuf = isByteAddressBuffer(type) || isRWByteAddressBuffer(type);
+  const bool isStructuredBuf = isStructuredBuffer(type) ||
+                               isAppendStructuredBuffer(type) ||
+                               isConsumeStructuredBuffer(type);
+  assert(isBABuf || isStructuredBuf);
 
   // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a structure
   // with only one member that is a runtime array. We need to perform
@@ -2635,7 +2570,7 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
   // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
   // in bytes, but OpArrayLength returns the number of uints in the runtime
   // array. Therefore we must multiply the results by 4.
-  if (isByteAddressBuffer) {
+  if (isBABuf) {
     length = spvBuilder.createBinaryOp(
         spv::Op::OpIMul, astContext.UnsignedIntTy, length,
         spvBuilder.getConstantInt(astContext.UnsignedIntTy,
@@ -2643,7 +2578,7 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
   }
   spvBuilder.createStore(doExpr(expr->getArg(0)), length);
 
-  if (isStructuredBuffer) {
+  if (isStructuredBuf) {
     // For (RW)StructuredBuffer, the stride of the runtime array (which is the
     // size of the struct) must also be written to the second argument.
     AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
@@ -2747,8 +2682,8 @@ SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
   const auto numArgs = expr->getNumArgs();
   const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr;
 
-  assert(TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
-         TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type));
+  assert(isTexture(type) || isRWTexture(type) || isBuffer(type) ||
+         isRWBuffer(type));
 
   // For Texture1D, arguments are either:
   // a) width
@@ -2804,7 +2739,7 @@ SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
     mipLevel = expr->getArg(0);
     numLevels = expr->getArg(numArgs - 1);
   }
-  if (TypeTranslator::isTextureMS(type)) {
+  if (isTextureMS(type)) {
     numSamples = expr->getArg(numArgs - 1);
   }
 
@@ -2825,7 +2760,7 @@ SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
   // Only Texture types use ImageQuerySizeLod.
   // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize.
   SpirvInstruction *lod = nullptr;
-  if (TypeTranslator::isTexture(type) && !numSamples) {
+  if (isTexture(type) && !numSamples) {
     if (mipLevel) {
       // For Texture types when mipLevel argument is present.
       lod = doExpr(mipLevel);
@@ -3061,20 +2996,16 @@ SpirvInstruction *SPIRVEmitter::processBufferTextureLoad(
   // Loading for Buffer and RWBuffer translates to an OpImageFetch.
   // The result type of an OpImageFetch must be a vec4 of float or int.
   const auto type = object->getType();
-  assert(TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type) ||
-         TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
-         TypeTranslator::isSubpassInput(type) ||
-         TypeTranslator::isSubpassInputMS(type));
+  assert(isBuffer(type) || isRWBuffer(type) || isTexture(type) ||
+         isRWTexture(type) || isSubpassInput(type) || isSubpassInputMS(type));
 
-  const bool doFetch =
-      TypeTranslator::isBuffer(type) || TypeTranslator::isTexture(type);
+  const bool doFetch = isBuffer(type) || isTexture(type);
 
   auto *objectInfo = loadIfGLValue(object);
 
   // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod.
   SpirvInstruction *sampleNumber = nullptr;
-  if (TypeTranslator::isTextureMS(type) ||
-      TypeTranslator::isSubpassInputMS(type)) {
+  if (isTextureMS(type) || isSubpassInputMS(type)) {
     sampleNumber = lod;
     lod = nullptr;
   }
@@ -3091,8 +3022,8 @@ SpirvInstruction *SPIRVEmitter::processBufferTextureLoad(
       // For struct type, we need to make sure it can fit into a 4-component
       // vector. Detailed failing reasons will be emitted by the function so
       // we don't need to emit errors here.
-      if (!typeTranslator.canFitIntoOneRegister(sampledType, &elemType,
-                                                &elemCount))
+      if (!canFitIntoOneRegister(astContext, sampledType, &elemType,
+                                 &elemCount))
         return nullptr;
     }
   }
@@ -3126,11 +3057,11 @@ SpirvInstruction *SPIRVEmitter::processByteAddressBufferLoadStore(
   auto *objectInfo = loadIfAliasVarRef(object);
   assert(numWords >= 1 && numWords <= 4);
   if (doStore) {
-    assert(typeTranslator.isRWByteAddressBuffer(object->getType()));
+    assert(isRWByteAddressBuffer(object->getType()));
     assert(expr->getNumArgs() == 2);
   } else {
-    assert(typeTranslator.isRWByteAddressBuffer(object->getType()) ||
-           typeTranslator.isByteAddressBuffer(object->getType()));
+    assert(isRWByteAddressBuffer(object->getType()) ||
+           isByteAddressBuffer(object->getType()));
     if (expr->getNumArgs() == 2) {
       emitError(
           "(RW)ByteAddressBuffer::Load(in address, out status) not supported",
@@ -4221,23 +4152,20 @@ SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
   const auto *object = expr->getImplicitObjectArgument();
   const auto objectType = object->getType();
 
-  if (typeTranslator.isRWByteAddressBuffer(objectType) ||
-      typeTranslator.isByteAddressBuffer(objectType))
+  if (isRWByteAddressBuffer(objectType) || isByteAddressBuffer(objectType))
     return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);
 
-  if (TypeTranslator::isStructuredBuffer(objectType))
+  if (isStructuredBuffer(objectType))
     return processStructuredBufferLoad(expr);
 
   const auto numArgs = expr->getNumArgs();
   const auto *locationArg = expr->getArg(0);
-  const bool isTextureMS = TypeTranslator::isTextureMS(objectType);
+  const bool textureMS = isTextureMS(objectType);
   const bool hasStatusArg =
       expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
   auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;
 
-  if (TypeTranslator::isBuffer(objectType) ||
-      TypeTranslator::isRWBuffer(objectType) ||
-      TypeTranslator::isRWTexture(objectType))
+  if (isBuffer(objectType) || isRWBuffer(objectType) || isRWTexture(objectType))
     return processBufferTextureLoad(object, doExpr(locationArg),
                                     /*constOffset*/ nullptr,
                                     /*varOffset*/ nullptr, /*lod*/ nullptr,
@@ -4245,15 +4173,15 @@ SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
 
   // Subtract 1 for status (if it exists), and 1 for sampleIndex (if it exists),
   // and 1 for location.
-  const bool hasOffsetArg = numArgs - hasStatusArg - isTextureMS - 1 > 0;
+  const bool hasOffsetArg = numArgs - hasStatusArg - textureMS - 1 > 0;
 
-  if (TypeTranslator::isTexture(objectType)) {
+  if (isTexture(objectType)) {
     // .Load() has a second optional paramter for offset.
     SpirvInstruction *location = doExpr(locationArg);
     SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
     SpirvInstruction *coordinate = location, *lod = nullptr;
 
-    if (isTextureMS) {
+    if (textureMS) {
       // SampleIndex is only available when the Object is of Texture2DMS or
       // Texture2DMSArray types. Under those cases, Offset will be the third
       // parameter (index 2).
@@ -4283,16 +4211,14 @@ SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
 SpirvInstruction *
 SPIRVEmitter::processGetDimensions(const CXXMemberCallExpr *expr) {
   const auto objectType = expr->getImplicitObjectArgument()->getType();
-  if (TypeTranslator::isTexture(objectType) ||
-      TypeTranslator::isRWTexture(objectType) ||
-      TypeTranslator::isBuffer(objectType) ||
-      TypeTranslator::isRWBuffer(objectType)) {
+  if (isTexture(objectType) || isRWTexture(objectType) ||
+      isBuffer(objectType) || isRWBuffer(objectType)) {
     return processBufferTextureGetDimensions(expr);
-  } else if (TypeTranslator::isByteAddressBuffer(objectType) ||
-             TypeTranslator::isRWByteAddressBuffer(objectType) ||
-             TypeTranslator::isStructuredBuffer(objectType) ||
-             TypeTranslator::isAppendStructuredBuffer(objectType) ||
-             TypeTranslator::isConsumeStructuredBuffer(objectType)) {
+  } else if (isByteAddressBuffer(objectType) ||
+             isRWByteAddressBuffer(objectType) ||
+             isStructuredBuffer(objectType) ||
+             isAppendStructuredBuffer(objectType) ||
+             isConsumeStructuredBuffer(objectType)) {
     return processByteAddressBufferStructuredBufferGetDimensions(expr);
   } else {
     emitError("GetDimensions() of the given object type unimplemented",
@@ -4310,7 +4236,7 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
 
     // For Textures, regular indexing (operator[]) uses slice 0.
     if (isBufferTextureIndexing(expr, &baseExpr, &indexExpr)) {
-      auto *lod = TypeTranslator::isTexture(baseExpr->getType())
+      auto *lod = isTexture(baseExpr->getType())
                       ? spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                   llvm::APInt(32, 0))
                       : nullptr;
@@ -4809,7 +4735,7 @@ void SPIRVEmitter::storeValue(SpirvInstruction *lhsPtr,
     }
 
     spvBuilder.createStore(lhsPtr, rhsVal);
-  } else if (TypeTranslator::isOpaqueType(lhsValType)) {
+  } else if (isOpaqueType(lhsValType)) {
     // Resource types are represented using RecordType in the AST.
     // Handle them before the general RecordType.
     //
@@ -4838,7 +4764,7 @@ void SPIRVEmitter::storeValue(SpirvInstruction *lhsPtr,
     // assignments/returns from ConstantBuffer<T>/TextureBuffer<T> to function
     // parameters/returns/variables of type T. And ConstantBuffer<T> is not
     // represented differently as struct T.
-  } else if (TypeTranslator::isOpaqueArrayType(lhsValType)) {
+  } else if (isOpaqueArrayType(lhsValType)) {
     // For opaque array types, we cannot perform OpLoad on the whole array and
     // then write out as a whole; instead, we need to OpLoad each element
     // using access chains. This is to influence later SPIR-V transformations
@@ -5241,7 +5167,7 @@ bool SPIRVEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
 
   const Expr *object = memberExpr->getBase();
   const auto objectType = object->getType();
-  if (!TypeTranslator::isTexture(objectType))
+  if (!isTexture(objectType))
     return false;
 
   if (base)
@@ -5264,10 +5190,8 @@ bool SPIRVEmitter::isBufferTextureIndexing(const CXXOperatorCallExpr *indexExpr,
     return false;
   const Expr *object = indexExpr->getArg(0);
   const auto objectType = object->getType();
-  if (TypeTranslator::isBuffer(objectType) ||
-      TypeTranslator::isRWBuffer(objectType) ||
-      TypeTranslator::isTexture(objectType) ||
-      TypeTranslator::isRWTexture(objectType)) {
+  if (isBuffer(objectType) || isRWBuffer(objectType) || isTexture(objectType) ||
+      isRWTexture(objectType)) {
     if (base)
       *base = object;
     if (index)
@@ -5722,7 +5646,7 @@ SpirvInstruction *SPIRVEmitter::processEachVectorInMatrix(
         actOnEachVector) {
   const auto matType = matrix->getType();
   assert(isMxNMatrix(matType));
-  const QualType vecType = typeTranslator.getComponentVectorType(matType);
+  const QualType vecType = getComponentVectorType(astContext, matType);
 
   uint32_t rowCount = 0, colCount = 0;
   hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);
@@ -5932,7 +5856,7 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
           collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
 
       if (thisBaseType != base->getType() &&
-          TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
+          isAKindOfStructuredOrByteBuffer(thisBaseType)) {
         // The immediate base is a kind of structured or byte buffer. It should
         // be an alias variable. Break the normal index collecting chain.
         // Return the immediate base as the base so that we can apply other
@@ -5946,7 +5870,7 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
       // If the base is a StructureType, we need to push an addtional index 0
       // here. This is because we created an additional OpTypeRuntimeArray
       // in the structure.
-      if (TypeTranslator::isStructuredBuffer(thisBaseType))
+      if (isStructuredBuffer(thisBaseType))
         indices->push_back(
             spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)));
 
@@ -6100,7 +6024,7 @@ SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
       // Casting to a matrix of integers: Cast each row and construct a
       // composite.
       llvm::SmallVector<SpirvInstruction *, 4> castedRows;
-      const QualType vecType = typeTranslator.getComponentVectorType(fromType);
+      const QualType vecType = getComponentVectorType(astContext, fromType);
       const auto fromVecQualType =
           astContext.getExtVectorType(elemType, numCols);
       const auto toIntVecQualType =
@@ -6131,8 +6055,10 @@ SpirvInstruction *SPIRVEmitter::convertBitwidth(SpirvInstruction *fromVal,
       fromType->isSpecificBuiltinType(BuiltinType::LitInt))
     return fromVal;
 
-  const auto fromBitwidth = typeTranslator.getElementSpirvBitwidth(fromType);
-  const auto toBitwidth = typeTranslator.getElementSpirvBitwidth(toType);
+  const auto fromBitwidth = getElementSpirvBitwidth(
+      astContext, fromType, spirvOptions.enable16BitTypes);
+  const auto toBitwidth = getElementSpirvBitwidth(
+      astContext, toType, spirvOptions.enable16BitTypes);
   if (fromBitwidth == toBitwidth) {
     if (resultType)
       *resultType = fromType;
@@ -6204,7 +6130,7 @@ SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
       // Casting to a matrix of floats: Cast each row and construct a
       // composite.
       llvm::SmallVector<SpirvInstruction *, 4> castedRows;
-      const QualType vecType = typeTranslator.getComponentVectorType(fromType);
+      const QualType vecType = getComponentVectorType(astContext, fromType);
       const auto fromVecQualType =
           astContext.getExtVectorType(elemType, numCols);
       const auto toIntVecQualType =
@@ -8052,7 +7978,7 @@ SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
     uint32_t matRowCount = 0, matColCount = 0;
     if (isMxNMatrix(argType, &elemType, &matRowCount, &matColCount)) {
       auto *matrix = doExpr(arg);
-      const QualType vecType = typeTranslator.getComponentVectorType(argType);
+      const QualType vecType = getComponentVectorType(astContext, argType);
       llvm::SmallVector<SpirvInstruction *, 4> rowResults;
       for (uint32_t i = 0; i < matRowCount; ++i) {
         // Extract the row which is a float vector of size matColCount.
@@ -8660,7 +8586,8 @@ SpirvConstant *SPIRVEmitter::getMaskForBitwidthValue(QualType type) {
   uint32_t count = 1;
 
   if (isScalarType(type, &elemType) || isVectorType(type, &elemType, &count)) {
-    const auto bitwidth = typeTranslator.getElementSpirvBitwidth(elemType);
+    const auto bitwidth = getElementSpirvBitwidth(
+        astContext, elemType, spirvOptions.enable16BitTypes);
     SpirvConstant *mask = nullptr;
     switch (bitwidth) {
     case 16:

+ 1 - 7
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -28,7 +28,6 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/FeatureManager.h"
-#include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/STLExtras.h"
@@ -36,7 +35,6 @@
 
 #include "DeclResultIdMapper.h"
 #include "SpirvEvalInfo.h"
-#include "TypeTranslator.h"
 
 namespace clang {
 namespace spirv {
@@ -52,8 +50,7 @@ public:
   void HandleTranslationUnit(ASTContext &context) override;
 
   ASTContext &getASTContext() { return astContext; }
-  SpirvBuilder &getModuleBuilder() { return spvBuilder; }
-  TypeTranslator &getTypeTranslator() { return typeTranslator; }
+  SpirvBuilder &getSpirvBuilder() { return spvBuilder; }
   DiagnosticsEngine &getDiagnosticsEngine() { return diags; }
 
   void doDecl(const Decl *decl);
@@ -966,12 +963,9 @@ private:
   const llvm::StringRef entryFunctionName;
   const hlsl::ShaderModel &shaderModel;
 
-  SPIRVContext theContext;
   SpirvContext spvContext;
   FeatureManager featureManager;
-  ModuleBuilder theBuilder;
   SpirvBuilder spvBuilder;
-  TypeTranslator typeTranslator;
   DeclResultIdMapper declIdMapper;
 
   /// A queue of decls reachable from the entry function. Decls inserted into

+ 0 - 1
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -12,7 +12,6 @@
 #include "EmitVisitor.h"
 #include "LiteralTypeVisitor.h"
 #include "LowerTypeVisitor.h"
-#include "TypeTranslator.h"
 #include "clang/SPIRV/AstTypeProbe.h"
 
 namespace clang {

+ 0 - 9
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -126,15 +126,6 @@ bool SpirvType::isOrContains16BitType(const SpirvType *type) {
   return false;
 }
 
-bool SpirvType::isMatrixOrArrayOfMatrix(const SpirvType *type) {
-  if (isa<MatrixType>(type))
-    return true;
-  if (const auto *arrayType = dyn_cast<ArrayType>(type))
-    return isMatrixOrArrayOfMatrix(arrayType->getElementType());
-
-  return false;
-}
-
 MatrixType::MatrixType(const VectorType *vecType, uint32_t vecCount)
     : SpirvType(TK_Matrix), vectorType(vecType), vectorCount(vecCount) {}