Pārlūkot izejas kodu

[spirv] Migrate DeclResultIdMapper to new infrastructure.

Ehsan Nasiri 6 gadi atpakaļ
vecāks
revīzija
3db3337bb9

+ 38 - 5
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -10,6 +10,7 @@
 
 #include <string>
 
+#include "clang/AST/Decl.h"
 #include "clang/AST/Type.h"
 
 namespace clang {
@@ -85,13 +86,45 @@ bool isConstantTextureBuffer(const Decl *decl);
 /// * SubpassInput(MS)
 bool isResourceType(const ValueDecl *decl);
 
-/// \brief Returns true if the given type is or contains 16-bit type.
-//bool isOrContains16BitType(QualType type);
-
-  /// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
-  /// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
+/// \brief Returns true if the given type is the HLSL (RW)StructuredBuffer,
+/// (RW)ByteAddressBuffer, or {Append|Consume}StructuredBuffer.
 bool isAKindOfStructuredOrByteBuffer(QualType type);
 
+/// Returns true if the given type is or contains a 16-bit type.
+/// The caller must also specify whether 16-bit types have been enabled via
+/// command line options.
+bool isOrContains16BitType(QualType type, bool enable16BitTypesOption);
+
+/// NOTE: This method doesn't handle Literal types correctly at the moment.
+///
+/// Note: This method will be deprecated once resolving of literal types are
+/// moved to a dedicated pass.
+///
+/// \brief Returns the realized bitwidth of the given type when represented in
+/// SPIR-V. Panics if the given type is not a scalar, a vector/matrix of float
+/// or integer, or an array of them. In case of vectors, it returns the
+/// realized SPIR-V bitwidth of the vector elements.
+uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
+                                 bool is16BitTypeEnabled);
+
+/// Returns true if the two types can be treated as the same scalar
+/// type, which means they have the same canonical type, regardless of
+/// constnesss and literalness.
+bool canTreatAsSameScalarType(QualType type1, QualType type2);
+
+/// Returns true if all members in structType are of the same element
+/// type and can be fit into a 4-component vector. Writes element type and
+/// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
+/// explaining why not.
+bool canFitIntoOneRegister(QualType structType, QualType *elemType,
+                           uint32_t *elemCount = nullptr);
+
+/// Returns the element type of the given type. The given type may be a scalar
+/// type, vector type, matrix type, or array type. It may also be a struct with
+/// members that can fit into a register. In such case, the result would be the
+/// struct member type.
+QualType getElementType(QualType type);
+
 } // namespace spirv
 } // namespace clang
 

+ 5 - 0
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -211,6 +211,8 @@ public:
 
   SpirvConstant *getConstantUint32(uint32_t value);
   SpirvConstant *getConstantInt32(int32_t value);
+  SpirvConstant *getConstantFloat32(float value);
+  SpirvConstant *getConstantBool(bool value);
   // TODO: Add getConstant* methods for other types.
 
 private:
@@ -266,6 +268,9 @@ private:
   // We currently do a linear search to find an existing constant (if any). This
   // can be done in a more efficient way if needed.
   llvm::SmallVector<SpirvConstantInteger *, 8> integerConstants;
+  llvm::SmallVector<SpirvConstantFloat *, 8> floatConstants;
+  SpirvConstantBoolean *boolTrueConstant;
+  SpirvConstantBoolean *boolFalseConstant;
   // TODO: Add vectors of other constant types here.
 };
 

+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -448,6 +448,9 @@ public:
   SpirvVariable *addStageBuiltinVar(const SpirvType *type,
                                     spv::StorageClass storageClass,
                                     spv::BuiltIn, SourceLocation loc = {});
+  SpirvVariable *addStageBuiltinVar(QualType type,
+                                    spv::StorageClass storageClass,
+                                    spv::BuiltIn, SourceLocation loc = {});
 
   /// \brief Adds a module variable. This variable should not have the Function
   /// storage class.

+ 257 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -281,5 +281,262 @@ bool isAKindOfStructuredOrByteBuffer(QualType type) {
   return false;
 }
 
+bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
+  // Primitive types
+  {
+    QualType ty = {};
+    if (isScalarType(type, &ty)) {
+      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
+        switch (builtinType->getKind()) {
+        case BuiltinType::Min12Int:
+        case BuiltinType::Min16Int:
+        case BuiltinType::Min16UInt:
+        case BuiltinType::Min10Float:
+        case BuiltinType::Min16Float:
+          return enable16BitTypesOption;
+        // the 'Half' enum always represents 16-bit and 'HalfFloat' always
+        // represents 32-bit floats.
+        // int16_t and uint16_t map to Short and UShort
+        case BuiltinType::Short:
+        case BuiltinType::UShort:
+        case BuiltinType::Half:
+          return true;
+        default:
+          return false;
+        }
+      }
+    }
+  }
+
+  // Vector types
+  {
+    QualType elemType = {};
+    if (isVectorType(type, &elemType))
+      return isOrContains16BitType(elemType, enable16BitTypesOption);
+  }
+
+  // Matrix types
+  {
+    QualType elemType = {};
+    if (isMxNMatrix(type, &elemType)) {
+      return isOrContains16BitType(elemType, enable16BitTypesOption);
+    }
+  }
+
+  // Struct type
+  if (const auto *structType = type->getAs<RecordType>()) {
+    const auto *decl = structType->getDecl();
+
+    for (const auto *field : decl->fields()) {
+      if (isOrContains16BitType(field->getType(), enable16BitTypesOption))
+        return true;
+    }
+
+    return false;
+  }
+
+  // Array type
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
+    return isOrContains16BitType(arrayType->getElementType(),
+                                 enable16BitTypesOption);
+  }
+
+  // Reference types
+  if (const auto *refType = type->getAs<ReferenceType>()) {
+    return isOrContains16BitType(refType->getPointeeType(),
+                                 enable16BitTypesOption);
+  }
+
+  // Pointer types
+  if (const auto *ptrType = type->getAs<PointerType>()) {
+    return isOrContains16BitType(ptrType->getPointeeType(),
+                                 enable16BitTypesOption);
+  }
+
+  if (const auto *typedefType = type->getAs<TypedefType>()) {
+    return isOrContains16BitType(typedefType->desugar(),
+                                 enable16BitTypesOption);
+  }
+
+  llvm_unreachable("checking 16-bit type unimplemented");
+  return 0;
+}
+
+uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
+                                 bool is16BitTypeEnabled) {
+  const auto canonicalType = type.getCanonicalType();
+  if (canonicalType != type)
+    return getElementSpirvBitwidth(astContext, canonicalType,
+                                   is16BitTypeEnabled);
+
+  // Vector types
+  {
+    QualType elemType = {};
+    if (isVectorType(type, &elemType))
+      return getElementSpirvBitwidth(astContext, elemType, is16BitTypeEnabled);
+  }
+
+  // Matrix types
+  if (hlsl::IsHLSLMatType(type))
+    return getElementSpirvBitwidth(
+        astContext, hlsl::GetHLSLMatElementType(type), is16BitTypeEnabled);
+
+  // Array types
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
+    return getElementSpirvBitwidth(astContext, arrayType->getElementType(),
+                                   is16BitTypeEnabled);
+  }
+
+  // Typedefs
+  if (const auto *typedefType = type->getAs<TypedefType>())
+    return getElementSpirvBitwidth(astContext, typedefType->desugar(),
+                                   is16BitTypeEnabled);
+
+  // Reference types
+  if (const auto *refType = type->getAs<ReferenceType>())
+    return getElementSpirvBitwidth(astContext, refType->getPointeeType(),
+                                   is16BitTypeEnabled);
+
+  // Pointer types
+  if (const auto *ptrType = type->getAs<PointerType>())
+    return getElementSpirvBitwidth(astContext, ptrType->getPointeeType(),
+                                   is16BitTypeEnabled);
+
+  // Scalar types
+  QualType ty = {};
+  const bool isScalar = isScalarType(type, &ty);
+  assert(isScalar);
+  (void)isScalar;
+  if (const auto *builtinType = ty->getAs<BuiltinType>()) {
+    switch (builtinType->getKind()) {
+    case BuiltinType::Bool:
+    case BuiltinType::Int:
+    case BuiltinType::UInt:
+    case BuiltinType::Float:
+      return 32;
+    case BuiltinType::Double:
+    case BuiltinType::LongLong:
+    case BuiltinType::ULongLong:
+      return 64;
+    // Half builtin type is always 16-bit. The HLSL 'half' keyword is translated
+    // to 'Half' enum if -enable-16bit-types is true.
+    // int16_t and uint16_t map to Short and UShort
+    case BuiltinType::Half:
+    case BuiltinType::Short:
+    case BuiltinType::UShort:
+      return 16;
+    // HalfFloat builtin type is just an alias for Float builtin type and is
+    // always 32-bit. The HLSL 'half' keyword is translated to 'HalfFloat' enum
+    // if -enable-16bit-types is false.
+    case BuiltinType::HalfFloat:
+      return 32;
+    // The following types are treated as 16-bit if '-enable-16bit-types' option
+    // is enabled. They are treated as 32-bit otherwise.
+    case BuiltinType::Min12Int:
+    case BuiltinType::Min16Int:
+    case BuiltinType::Min16UInt:
+    case BuiltinType::Min16Float:
+    case BuiltinType::Min10Float: {
+      return is16BitTypeEnabled ? 16 : 32;
+    }
+    case BuiltinType::LitFloat: {
+      // TODO(ehsan): Literal types not handled properly.
+      return 64;
+    }
+    case BuiltinType::LitInt: {
+      // TODO(ehsan): Literal types not handled properly.
+      return 64;
+    }
+    default:
+      // Other builtin types are either not relevant to bitcount or not in HLSL.
+      break;
+    }
+  }
+  llvm_unreachable("invalid type passed to getElementSpirvBitwidth");
+}
+
+bool canTreatAsSameScalarType(QualType type1, QualType type2) {
+  // Treat const int/float the same as const int/float
+  type1.removeLocalConst();
+  type2.removeLocalConst();
+
+  return (type1.getCanonicalType() == type2.getCanonicalType()) ||
+         // Treat 'literal float' and 'float' as the same
+         (type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+          type2->isFloatingType()) ||
+         (type2->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+          type1->isFloatingType()) ||
+         // Treat 'literal int' and 'int'/'uint' as the same
+         (type1->isSpecificBuiltinType(BuiltinType::LitInt) &&
+          type2->isIntegerType() &&
+          // Disallow boolean types
+          !type2->isSpecificBuiltinType(BuiltinType::Bool)) ||
+         (type2->isSpecificBuiltinType(BuiltinType::LitInt) &&
+          type1->isIntegerType() &&
+          // Disallow boolean types
+          !type1->isSpecificBuiltinType(BuiltinType::Bool));
+}
+
+bool canFitIntoOneRegister(QualType structType, QualType *elemType,
+                           uint32_t *elemCount) {
+  if (structType->getAsStructureType() == nullptr)
+    return false;
+
+  const auto *structDecl = structType->getAsStructureType()->getDecl();
+  QualType firstElemType;
+  uint32_t totalCount = 0;
+
+  for (const auto *field : structDecl->fields()) {
+    QualType type;
+    uint32_t count = 1;
+
+    if (isScalarType(field->getType(), &type) ||
+        isVectorType(field->getType(), &type, &count)) {
+      if (firstElemType.isNull()) {
+        firstElemType = type;
+      } else {
+        if (!canTreatAsSameScalarType(firstElemType, type)) {
+          assert(false && "all struct members should have the same element "
+                          "type for resource template instantiation");
+          return false;
+        }
+      }
+      totalCount += count;
+    } else {
+      assert(false && "unsupported struct element type for resource template "
+                      "instantiation");
+      return false;
+    }
+  }
+
+  if (totalCount > 4) {
+    assert(
+        false &&
+        "resource template element type cannot fit into four 32-bit scalars");
+    return false;
+  }
+
+  if (elemType)
+    *elemType = firstElemType;
+  if (elemCount)
+    *elemCount = totalCount;
+  return true;
+}
+
+QualType getElementType(QualType type) {
+  QualType elemType = {};
+  if (isScalarType(type, &elemType) || isVectorType(type, &elemType) ||
+      isMxNMatrix(type, &elemType) || canFitIntoOneRegister(type, &elemType)) {
+    return elemType;
+  }
+
+  if (const auto *arrType = dyn_cast<ConstantArrayType>(type)) {
+    return arrType->getElementType();
+  }
+
+  assert(false && "unsupported resource type parameter");
+  return type;
+}
+
 } // namespace spirv
 } // namespace clang

Failā izmaiņas netiks attēlotas, jo tās ir par lielu
+ 294 - 198
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp


+ 20 - 27
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -20,7 +20,6 @@
 #include "spirv/unified1/spirv.hpp11"
 #include "clang/AST/Attr.h"
 #include "clang/SPIRV/FeatureManager.h"
-#include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/SpirvBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
@@ -28,7 +27,6 @@
 
 #include "GlPerVertex.h"
 #include "SpirvEvalInfo.h"
-#include "TypeTranslator.h"
 
 namespace clang {
 namespace spirv {
@@ -55,10 +53,10 @@ struct SemanticInfo {
 class StageVar {
 public:
   inline StageVar(const hlsl::SigPoint *sig, SemanticInfo semaInfo,
-                  const VKBuiltInAttr *builtin, const SpirvType *spvType,
+                  const VKBuiltInAttr *builtin, QualType astType,
                   uint32_t locCount)
       : sigPoint(sig), semanticInfo(std::move(semaInfo)), builtinAttr(builtin),
-        type(spvType), value(nullptr), isBuiltin(false),
+        type(astType), value(nullptr), isBuiltin(false),
         storageClass(spv::StorageClass::Max), location(nullptr),
         locationCount(locCount) {
     isBuiltin = builtinAttr != nullptr;
@@ -68,7 +66,7 @@ public:
   const SemanticInfo &getSemanticInfo() const { return semanticInfo; }
   std::string getSemanticStr() const;
 
-  const SpirvType *getSpirvType() const { return type; }
+  QualType getAstType() const { return type; }
 
   SpirvVariable *getSpirvInstr() const { return value; }
   void setSpirvInstr(SpirvVariable *spvInstr) { value = spvInstr; }
@@ -97,8 +95,8 @@ private:
   SemanticInfo semanticInfo;
   /// SPIR-V BuiltIn attribute.
   const VKBuiltInAttr *builtinAttr;
-  /// SPIR-V type.
-  const SpirvType *type;
+  /// The AST QualType.
+  QualType type;
   /// SPIR-V instruction.
   SpirvVariable *value;
   /// Indicates whether this stage variable should be a SPIR-V builtin.
@@ -262,14 +260,13 @@ private:
 class DeclResultIdMapper {
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            SpirvContext &spirvContext, ModuleBuilder &builder,
+                            SpirvContext &spirvContext,
                             SpirvBuilder &spirvBuilder, SPIRVEmitter &emitter,
-                            TypeTranslator &translator,
                             FeatureManager &features,
                             const SpirvCodeGenOptions &spirvOptions);
 
-  /// \brief Returns the <result-id> for a SPIR-V builtin variable.
-  uint32_t getBuiltinVar(spv::BuiltIn builtIn);
+  /// \brief Returns the SPIR-V builtin variable.
+  SpirvVariable *getBuiltinVar(spv::BuiltIn builtIn);
 
   /// \brief Creates the stage output variables by parsing the semantics
   /// attached to the given function's parameter or return value and returns
@@ -442,7 +439,7 @@ public:
   /// This method is specially for writing back per-vertex data at the time of
   /// OpEmitVertex in GS.
   bool writeBackOutputStream(const NamedDecl *decl, QualType type,
-                             SpirvVariable *value);
+                             SpirvInstruction *value);
 
   /// \brief Negates to get the additive inverse of SV_Position.y if requested.
   SpirvInstruction *invertYIfRequested(SpirvInstruction *position);
@@ -634,16 +631,12 @@ private:
 
 private:
   const hlsl::ShaderModel &shaderModel;
-  ModuleBuilder &theBuilder;
   SpirvBuilder &spvBuilder;
   SPIRVEmitter &theEmitter;
   const SpirvCodeGenOptions &spirvOptions;
   ASTContext &astContext;
   SpirvContext &spvContext;
   DiagnosticsEngine &diags;
-
-  TypeTranslator &typeTranslator;
-
   SpirvFunction *entryFunction;
 
   /// Mapping of all Clang AST decls to their instruction pointers.
@@ -657,7 +650,7 @@ private:
   /// other cases, stage variable reading and writing is done at the time of
   /// creating that stage variable, so that we don't need to query them again
   /// for reading and writing.
-  llvm::DenseMap<const ValueDecl *, SpirvVariable *> stageVarIds;
+  llvm::DenseMap<const ValueDecl *, SpirvVariable *> stageVarInstructions;
   /// Vector of all defined resource variables.
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
@@ -752,16 +745,16 @@ void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
   builder.createStore(counterVar, srcPair.get(builder, context));
 }
 
-DeclResultIdMapper::DeclResultIdMapper(
-    const hlsl::ShaderModel &model, ASTContext &context,
-    SpirvContext &spirvContext, ModuleBuilder &builder,
-    SpirvBuilder &spirvBuilder, SPIRVEmitter &emitter,
-    TypeTranslator &translator, FeatureManager &features,
-    const SpirvCodeGenOptions &options)
-    : shaderModel(model), theBuilder(builder), spvBuilder(spirvBuilder),
-      theEmitter(emitter), spirvOptions(options), astContext(context),
-      spvContext(spirvContext), diags(context.getDiagnostics()),
-      typeTranslator(translator), entryFunction(nullptr),
+DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
+                                       ASTContext &context,
+                                       SpirvContext &spirvContext,
+                                       SpirvBuilder &spirvBuilder,
+                                       SPIRVEmitter &emitter,
+                                       FeatureManager &features,
+                                       const SpirvCodeGenOptions &options)
+    : shaderModel(model), spvBuilder(spirvBuilder), theEmitter(emitter),
+      spirvOptions(options), astContext(context), spvContext(spirvContext),
+      diags(context.getDiagnostics()), entryFunction(nullptr),
       laneCountBuiltinVar(nullptr), laneIndexBuiltinVar(nullptr),
       needsLegalization(false),
       glPerVertex(model, context, spirvContext, spirvBuilder) {}

+ 5 - 99
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -240,7 +240,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       const bool isMS = (name == "Texture2DMS" || name == "Texture2DMSArray");
       const auto sampledType = hlsl::GetHLSLResourceResultType(type);
       return spvContext.getImageType(
-          lowerType(getElementType(sampledType, srcLoc), rule, srcLoc), dim,
+          lowerType(getElementType(sampledType), rule, srcLoc), dim,
           ImageType::WithDepth::Unknown, isArray, isMS,
           ImageType::WithSampler::Yes, spv::ImageFormat::Unknown);
     }
@@ -255,7 +255,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       const auto format =
           translateSampledTypeToImageFormat(sampledType, srcLoc);
       return spvContext.getImageType(
-          lowerType(getElementType(sampledType, srcLoc), rule, srcLoc), dim,
+          lowerType(getElementType(sampledType), rule, srcLoc), dim,
           ImageType::WithDepth::Unknown, isArray,
           /*isMultiSampled=*/false, /*sampled=*/ImageType::WithSampler::No,
           format);
@@ -340,8 +340,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
     }
     const auto format = translateSampledTypeToImageFormat(sampledType, srcLoc);
     return spvContext.getImageType(
-        lowerType(getElementType(sampledType, srcLoc), rule, srcLoc),
-        spv::Dim::Buffer, ImageType::WithDepth::Unknown,
+        lowerType(getElementType(sampledType), rule, srcLoc), spv::Dim::Buffer,
+        ImageType::WithDepth::Unknown,
         /*isArrayed=*/false, /*isMultiSampled=*/false,
         /*sampled*/ name == "Buffer" ? ImageType::WithSampler::Yes
                                      : ImageType::WithSampler::No,
@@ -371,7 +371,7 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
   if (name == "SubpassInput" || name == "SubpassInputMS") {
     const auto sampledType = hlsl::GetHLSLResourceResultType(type);
     return spvContext.getImageType(
-        lowerType(getElementType(sampledType, srcLoc), rule, srcLoc),
+        lowerType(getElementType(sampledType), rule, srcLoc),
         spv::Dim::SubpassData, ImageType::WithDepth::Unknown,
         /*isArrayed=*/false,
         /*isMultipleSampled=*/name == "SubpassInputMS",
@@ -381,77 +381,6 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
   return nullptr;
 }
 
-QualType LowerTypeVisitor::getElementType(QualType type,
-                                          SourceLocation srcLoc) {
-  QualType elemType = {};
-  if (isScalarType(type, &elemType) || isVectorType(type, &elemType) ||
-      isMxNMatrix(type, &elemType) || canFitIntoOneRegister(type, &elemType)) {
-    return elemType;
-  }
-
-  if (const auto *arrType = astContext.getAsConstantArrayType(type)) {
-    return arrType->getElementType();
-  }
-
-  emitError("unsupported resource type parameter %0", srcLoc) << type;
-  // Note: We are returning the original type instead of a null QualType here
-  // to keep the translation going and avoid hitting asserts trying to query
-  // info from null QualType in other places of the compiler. Although we are
-  // likely generating invalid code here, it should be fine since the error
-  // reported will prevent the CodeGen from actually outputing.
-  return type;
-}
-
-bool LowerTypeVisitor::canFitIntoOneRegister(QualType structType,
-                                             QualType *elemType,
-                                             uint32_t *elemCount) {
-  if (structType->getAsStructureType() == nullptr)
-    return false;
-
-  const auto *structDecl = structType->getAsStructureType()->getDecl();
-  QualType firstElemType;
-  uint32_t totalCount = 0;
-
-  for (const auto *field : structDecl->fields()) {
-    QualType type;
-    uint32_t count = 1;
-
-    if (isScalarType(field->getType(), &type) ||
-        isVectorType(field->getType(), &type, &count)) {
-      if (firstElemType.isNull()) {
-        firstElemType = type;
-      } else {
-        if (!canTreatAsSameScalarType(firstElemType, type)) {
-          emitError("all struct members should have the same element type for "
-                    "resource template instantiation",
-                    structDecl->getLocation());
-          return false;
-        }
-      }
-      totalCount += count;
-    } else {
-      emitError("unsupported struct element type for resource template "
-                "instantiation",
-                structDecl->getLocation());
-      return false;
-    }
-  }
-
-  if (totalCount > 4) {
-    emitError(
-        "resource template element type %0 cannot fit into four 32-bit scalars",
-        structDecl->getLocation())
-        << structType;
-    return false;
-  }
-
-  if (elemType)
-    *elemType = firstElemType;
-  if (elemCount)
-    *elemCount = totalCount;
-  return true;
-}
-
 spv::ImageFormat
 LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
                                                     SourceLocation srcLoc) {
@@ -488,29 +417,6 @@ LowerTypeVisitor::translateSampledTypeToImageFormat(QualType sampledType,
   return spv::ImageFormat::Unknown;
 }
 
-bool LowerTypeVisitor::canTreatAsSameScalarType(QualType type1,
-                                                QualType type2) {
-  // Treat const int/float the same as const int/float
-  type1.removeLocalConst();
-  type2.removeLocalConst();
-
-  return (type1.getCanonicalType() == type2.getCanonicalType()) ||
-         // Treat 'literal float' and 'float' as the same
-         (type1->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-          type2->isFloatingType()) ||
-         (type2->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-          type1->isFloatingType()) ||
-         // Treat 'literal int' and 'int'/'uint' as the same
-         (type1->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type2->isIntegerType() &&
-          // Disallow boolean types
-          !type2->isSpecificBuiltinType(BuiltinType::Bool)) ||
-         (type2->isSpecificBuiltinType(BuiltinType::LitInt) &&
-          type1->isIntegerType() &&
-          // Disallow boolean types
-          !type1->isSpecificBuiltinType(BuiltinType::Bool));
-}
-
 QualType LowerTypeVisitor::desugarType(QualType type) {
   if (const auto *attrType = type->getAs<AttributedType>()) {
     switch (auto kind = attrType->getAttrKind()) {

+ 0 - 14
tools/clang/lib/SPIRV/LowerTypeVisitor.h

@@ -44,25 +44,11 @@ private:
   const SpirvType *lowerResourceType(QualType type, SpirvLayoutRule rule,
                                      SourceLocation);
 
-  QualType getElementType(QualType type, SourceLocation);
-
-  /// Returns true if all members in structType are of the same element
-  /// type and can be fit into a 4-component vector. Writes element type and
-  /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors
-  /// explaining why not.
-  bool canFitIntoOneRegister(QualType structType, QualType *elemType,
-                             uint32_t *elemCount = nullptr);
-
   /// For the given sampled type, returns the corresponding image format
   /// that can be used to create an image object.
   spv::ImageFormat translateSampledTypeToImageFormat(QualType sampledType,
                                                      SourceLocation);
 
-  /// Returns true if the two types can be treated as the same scalar
-  /// type, which means they have the same canonical type, regardless of
-  /// constnesss and literalness.
-  bool canTreatAsSameScalarType(QualType type1, QualType type2);
-
   /// Strips the attributes and typedefs from the given type and returns the
   /// desugared one.
   ///

+ 43 - 4
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -76,7 +76,8 @@ const Decoration *SPIRVContext::registerDecoration(const Decoration &d) {
 
 SpirvContext::SpirvContext()
     : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
-      uintTypes({}), floatTypes({}), samplerType(nullptr) {
+      uintTypes({}), floatTypes({}), samplerType(nullptr),
+      boolTrueConstant(nullptr), boolFalseConstant(nullptr) {
   voidType = new (this) VoidType;
   boolType = new (this) BoolType;
   samplerType = new (this) SamplerType;
@@ -259,9 +260,10 @@ const HybridStructType *SpirvContext::getHybridStructType(
 
   HybridStructType type(fields, name, isReadOnly, interfaceType);
 
-  auto found = std::find_if(
-      hybridStructTypes.begin(), hybridStructTypes.end(),
-      [&type](const HybridStructType *cachedType) { return type == *cachedType; });
+  auto found = std::find_if(hybridStructTypes.begin(), hybridStructTypes.end(),
+                            [&type](const HybridStructType *cachedType) {
+                              return type == *cachedType;
+                            });
 
   if (found != hybridStructTypes.end())
     return *found;
@@ -365,5 +367,42 @@ SpirvConstant *SpirvContext::getConstantInt32(int32_t value) {
   return intConst;
 }
 
+SpirvConstant *SpirvContext::getConstantFloat32(float value) {
+  const FloatType *floatType = getFloatType(32);
+  SpirvConstantFloat tempConstant(floatType, value);
+
+  auto found =
+      std::find_if(floatConstants.begin(), floatConstants.end(),
+                   [&tempConstant](SpirvConstantFloat *cachedConstant) {
+                     return tempConstant == *cachedConstant;
+                   });
+
+  if (found != floatConstants.end())
+    return *found;
+
+  // Couldn't find the constant. Create one.
+  auto *floatConst = new (this) SpirvConstantFloat(floatType, value);
+  floatConstants.push_back(floatConst);
+  return floatConst;
+}
+
+SpirvConstant *SpirvContext::getConstantBool(bool value) {
+  if (value && boolTrueConstant)
+    return boolTrueConstant;
+
+  if (!value && boolFalseConstant)
+    return boolFalseConstant;
+
+  // Couldn't find the constant. Create one.
+  auto *boolConst = new (this) SpirvConstantBoolean(getBoolType(), value);
+
+  if (value)
+    boolTrueConstant = boolConst;
+  else
+    boolFalseConstant = boolConst;
+
+  return boolConst;
+}
+
 } // end namespace spirv
 } // end namespace clang

+ 32 - 28
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -602,11 +602,11 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
-      theContext(), spvContext(astContext), featureManager(diags, spirvOptions),
+      theContext(), spvContext(), featureManager(diags, spirvOptions),
       theBuilder(&theContext, &featureManager, spirvOptions),
       spvBuilder(astContext, spvContext, &featureManager, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, spirvOptions),
-      declIdMapper(shaderModel, astContext, theBuilder, *this, typeTranslator,
+      declIdMapper(shaderModel, astContext, spvContext, spvBuilder, *this,
                    featureManager, spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(0),
       seenPushConstantAt(), isSpecConstantMode(false),
@@ -642,22 +642,20 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
   }
 
   // Set shader module version
-  theBuilder.setShaderModelVersion(shaderModel.GetMajor(),
+  spvBuilder.setShaderModelVersion(shaderModel.GetMajor(),
                                    shaderModel.GetMinor());
 
   // Set debug info
   const auto &inputFiles = ci.getFrontendOpts().Inputs;
   if (spirvOptions.debugInfoFile && !inputFiles.empty()) {
     // File name
-    mainSourceFileId = theContext.takeNextId();
-    theBuilder.setSourceFileName(mainSourceFileId,
-                                 inputFiles.front().getFile().str());
+    spvBuilder.setSourceFileName(inputFiles.front().getFile().str());
 
     // Source code
     const auto &sm = ci.getSourceManager();
     const llvm::MemoryBuffer *mainFile =
         sm.getBuffer(sm.getMainFileID(), SourceLocation());
-    theBuilder.setSourceFileContent(
+    spvBuilder.setSourceFileContent(
         StringRef(mainFile->getBufferStart(), mainFile->getBufferSize()));
   }
 }
@@ -695,13 +693,11 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   AddRequiredCapabilitiesForShaderModel();
 
   // Addressing and memory model are required in a valid SPIR-V module.
-  theBuilder.setAddressingModel(spv::AddressingModel::Logical);
-  theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);
+  spvBuilder.setMemoryModel(spv::AddressingModel::Logical,
+                            spv::MemoryModel::GLSL450);
 
-  theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId,
-                           entryFunctionName, declIdMapper.collectStageVars());
   spvBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunction,
-                           entryFunctionName, interfaces);
+                           entryFunctionName, declIdMapper.collectStageVars());
 
   // Add Location decorations to stage input/output variables.
   if (!declIdMapper.decorateStageIOLocations())
@@ -712,6 +708,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
     return;
 
   // Output the constructed module.
+  // TODO: Switch to new infra.
   std::vector<uint32_t> m = theBuilder.takeModule();
 
   if (!spirvOptions.codeGenHighLevel) {
@@ -838,11 +835,12 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
   }
 }
 
-SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
-  SpirvEvalInfo result(/*id*/ 0);
+SpirvInstruction *SPIRVEmitter::doExpr(const Expr *expr) {
+  SpirvInstruction *result = nullptr;
 
   // Provide a hint to the typeTranslator that if a literal is discovered, its
   // intended usage is as this expression type.
+  // TODO(ehsan): Literal type handling must be fixed.
   TypeTranslator::LiteralTypeHint hint(typeTranslator, expr->getType());
 
   expr = expr->IgnoreParens();
@@ -856,6 +854,8 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
     result = doInitListExpr(initListExpr);
   } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
+    // TODO: Wtf is isSpecConstantMode
+    // result = spvContext.getConstantBool(boolLiteral->getValue());
     const auto value =
         theBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode);
     result = SpirvEvalInfo(value).setConstant().setRValue();
@@ -1043,8 +1043,9 @@ bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) {
   return false;
 }
 
-uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
-                                  QualType toType, SourceLocation srcLoc) {
+SpirvInstruction *SPIRVEmitter::castToType(SpirvInstruction *value,
+                                           QualType fromType, QualType toType,
+                                           SourceLocation srcLoc) {
   if (isFloatOrVecOfFloatType(toType))
     return castToFloat(value, fromType, toType, srcLoc);
 
@@ -1059,7 +1060,7 @@ uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
     return castToInt(value, fromType, toType, srcLoc);
 
   emitError("casting to type %0 unimplemented", {}) << toType;
-  return 0;
+  return nullptr;
 }
 
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
@@ -5985,7 +5986,7 @@ void SPIRVEmitter::createSpecConstant(const VarDecl *varDecl) {
   // We are not creating a variable to hold the spec constant, instead, we
   // translate the varDecl directly into the spec constant here.
 
-  theBuilder.decorateSpecId(
+  spvBuilder.decorateSpecId(
       specConstant, varDecl->getAttr<VKConstantIdAttr>()->getSpecConstId());
 
   declIdMapper.registerSpecConstant(varDecl, specConstant);
@@ -6195,8 +6196,9 @@ SpirvEvalInfo &SPIRVEmitter::turnIntoElementPtr(
   return base;
 }
 
-uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
-                                  QualType toBoolType) {
+SpirvInstruction *SPIRVEmitter::castToBool(SpirvInstruction *fromVal,
+                                           QualType fromType,
+                                           QualType toBoolType) {
   if (TypeTranslator::isSameScalarOrVecType(fromType, toBoolType))
     return fromVal;
 
@@ -6228,8 +6230,9 @@ uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
   return theBuilder.createBinaryOp(spvOp, boolType, fromVal, zeroVal);
 }
 
-uint32_t SPIRVEmitter::castToInt(uint32_t fromVal, QualType fromType,
-                                 QualType toIntType, SourceLocation srcLoc) {
+SpirvInstruction *SPIRVEmitter::castToInt(SpirvInstruction *fromVal,
+                                          QualType fromType, QualType toIntType,
+                                          SourceLocation srcLoc) {
   if (TypeTranslator::isSameScalarOrVecType(fromType, toIntType))
     return fromVal;
 
@@ -6332,9 +6335,10 @@ uint32_t SPIRVEmitter::convertBitwidth(uint32_t fromVal, QualType fromType,
   llvm_unreachable("invalid type passed to convertBitwidth");
 }
 
-uint32_t SPIRVEmitter::castToFloat(uint32_t fromVal, QualType fromType,
-                                   QualType toFloatType,
-                                   SourceLocation srcLoc) {
+SpirvInstruction *SPIRVEmitter::castToFloat(SpirvInstruction *fromVal,
+                                            QualType fromType,
+                                            QualType toFloatType,
+                                            SourceLocation srcLoc) {
   if (TypeTranslator::isSameScalarOrVecType(fromType, toFloatType))
     return fromVal;
 
@@ -9319,11 +9323,11 @@ SPIRVEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
 
 void SPIRVEmitter::AddRequiredCapabilitiesForShaderModel() {
   if (shaderModel.IsHS() || shaderModel.IsDS()) {
-    theBuilder.requireCapability(spv::Capability::Tessellation);
+    spvBuilder.requireCapability(spv::Capability::Tessellation);
   } else if (shaderModel.IsGS()) {
-    theBuilder.requireCapability(spv::Capability::Geometry);
+    spvBuilder.requireCapability(spv::Capability::Geometry);
   } else {
-    theBuilder.requireCapability(spv::Capability::Shader);
+    spvBuilder.requireCapability(spv::Capability::Shader);
   }
 }
 

+ 9 - 8
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -58,7 +58,7 @@ public:
 
   void doDecl(const Decl *decl);
   void doStmt(const Stmt *stmt, llvm::ArrayRef<const Attr *> attrs = {});
-  SpirvEvalInfo doExpr(const Expr *expr);
+  SpirvInstruction *doExpr(const Expr *expr);
 
   /// Processes the given expression and emits SPIR-V instructions. If the
   /// result is a GLValue, does an additional load.
@@ -72,8 +72,8 @@ public:
 
   /// Casts the given value from fromType to toType. fromType and toType should
   /// both be scalar or vector types of the same size.
-  uint32_t castToType(uint32_t value, QualType fromType, QualType toType,
-                      SourceLocation);
+  SpirvInstruction *castToType(SpirvInstruction *value, QualType fromType,
+                               QualType toType, SourceLocation);
 
 private:
   void doFunctionDecl(const FunctionDecl *decl);
@@ -316,17 +316,18 @@ private:
 
   /// Processes the given expr, casts the result into the given bool (vector)
   /// type and returns the <result-id> of the casted value.
-  uint32_t castToBool(uint32_t value, QualType fromType, QualType toType);
+  SpirvInstruction *castToBool(SpirvInstruction *value, QualType fromType,
+                               QualType toType);
 
   /// Processes the given expr, casts the result into the given integer (vector)
   /// type and returns the <result-id> of the casted value.
-  uint32_t castToInt(uint32_t value, QualType fromType, QualType toType,
-                     SourceLocation);
+  SpirvInstruction *castToInt(SpirvInstruction *value, QualType fromType,
+                              QualType toType, SourceLocation);
 
   /// Processes the given expr, casts the result into the given float (vector)
   /// type and returns the <result-id> of the casted value.
-  uint32_t castToFloat(uint32_t value, QualType fromType, QualType toType,
-                       SourceLocation);
+  SpirvInstruction *castToFloat(SpirvInstruction *value, QualType fromType,
+                                QualType toType, SourceLocation);
 
 private:
   /// Processes HLSL instrinsic functions.

+ 17 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -759,6 +759,23 @@ SpirvVariable *SpirvBuilder::addStageBuiltinVar(const SpirvType *type,
   return var;
 }
 
+SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
+                                                spv::StorageClass storageClass,
+                                                spv::BuiltIn builtin,
+                                                SourceLocation loc) {
+  // Note: We store the underlying type in the variable, *not* the pointer type.
+  // TODO(ehsan): type pointer should be added in lowering the type.
+  auto *var = new (context) SpirvVariable(type, /*id*/ 0, loc, storageClass);
+  module->addVariable(var);
+
+  // Decorate with the specified Builtin
+  auto *decor = new (context) SpirvDecoration(
+      loc, var, spv::Decoration::BuiltIn, {static_cast<uint32_t>(builtin)});
+  module->addDecoration(decor);
+
+  return var;
+}
+
 SpirvVariable *SpirvBuilder::addModuleVar(
     QualType type, spv::StorageClass storageClass, llvm::StringRef name,
     llvm::Optional<SpirvInstruction *> init, SourceLocation loc) {

Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels