Browse Source

[spirv] Add names for other types.

We were missing the names for Image types, sampler types, and sampled
image types. We also should only add an extension if it hasn't been
added already.
Ehsan Nasiri 6 years ago
parent
commit
14ecf42d31

+ 1 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -659,6 +659,7 @@ private:
   SpirvConstantBoolean *boolFalseSpecConstant;
 
   llvm::SetVector<spv::Capability> existingCapabilities;
+  llvm::SetVector<Extension> existingExtensions;
 };
 
 void SpirvBuilder::requireCapability(spv::Capability cap, SourceLocation loc) {

+ 11 - 8
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -55,6 +55,7 @@ public:
   virtual ~SpirvType() = default;
 
   Kind getKind() const { return kind; }
+  llvm::StringRef getName() const { return debugName; }
 
   static bool isTexture(const SpirvType *);
   static bool isRWTexture(const SpirvType *);
@@ -67,10 +68,11 @@ public:
   static bool isOrContains16BitType(const SpirvType *);
 
 protected:
-  SpirvType(Kind k) : kind(k) {}
+  SpirvType(Kind k, llvm::StringRef name = "") : kind(k), debugName(name) {}
 
 private:
   const Kind kind;
+  std::string debugName;
 };
 
 class VoidType : public SpirvType {
@@ -202,6 +204,9 @@ public:
   WithSampler withSampler() const { return isSampled; }
   spv::ImageFormat getImageFormat() const { return imageFormat; }
 
+private:
+  std::string getImageName(spv::Dim, bool arrayed);
+
 private:
   const NumericalType *sampledType;
   spv::Dim dimension;
@@ -214,7 +219,7 @@ private:
 
 class SamplerType : public SpirvType {
 public:
-  SamplerType() : SpirvType(TK_Sampler) {}
+  SamplerType() : SpirvType(TK_Sampler, "type.sampler") {}
 
   static bool classof(const SpirvType *t) { return t->getKind() == TK_Sampler; }
 };
@@ -222,7 +227,7 @@ public:
 class SampledImageType : public SpirvType {
 public:
   SampledImageType(const ImageType *image)
-      : SpirvType(TK_SampledImage), imageType(image) {}
+      : SpirvType(TK_SampledImage, "type.sampled.image"), imageType(image) {}
 
   static bool classof(const SpirvType *t) {
     return t->getKind() == TK_SampledImage;
@@ -294,7 +299,7 @@ public:
 
   llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
   bool isReadOnly() const { return readOnly; }
-  std::string getStructName() const { return structName; }
+  llvm::StringRef getStructName() const { return getName(); }
   StructInterfaceType getInterfaceType() const { return interfaceType; }
 
   bool operator==(const StructType &that) const;
@@ -305,7 +310,6 @@ private:
   // names when considering unification. Otherwise, reflection will be confused.
 
   llvm::SmallVector<FieldInfo, 8> fields;
-  std::string structName;
   bool readOnly;
   // Indicates the interface type of this structure. If this structure is a
   // storage buffer shader-interface, it will be decorated with 'BufferBlock'.
@@ -363,7 +367,7 @@ public:
   }
 
 protected:
-  HybridType(Kind k) : SpirvType(k) {}
+  HybridType(Kind k, llvm::StringRef name = "") : SpirvType(k, name) {}
 };
 
 /// **NOTE**: This type is created in order to facilitate transition of old
@@ -403,7 +407,7 @@ public:
 
   llvm::ArrayRef<FieldInfo> getFields() const { return fields; }
   bool isReadOnly() const { return readOnly; }
-  std::string getStructName() const { return structName; }
+  llvm::StringRef getStructName() const { return getName(); }
   StructInterfaceType getInterfaceType() const { return interfaceType; }
 
   bool operator==(const HybridStructType &that) const;
@@ -414,7 +418,6 @@ private:
   // names when considering unification. Otherwise, reflection will be confused.
 
   llvm::SmallVector<FieldInfo, 8> fields;
-  std::string structName;
   bool readOnly;
   // Indicates the interface type of this structure. If this structure is a
   // storage buffer shader-interface, it will be decorated with 'BufferBlock'.

+ 3 - 3
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1148,6 +1148,9 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   if (alreadyExists)
     return id;
 
+  // Emit OpName for the type (if any).
+  emitNameForType(type->getName(), id);
+
   if (isa<VoidType>(type)) {
     initTypeInstruction(spv::Op::OpTypeVoid);
     curTypeInst.push_back(id);
@@ -1250,9 +1253,6 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
     llvm::ArrayRef<StructType::FieldInfo> fields = structType->getFields();
     size_t numFields = fields.size();
 
-    // Emit OpName for the struct.
-    emitNameForType(structType->getStructName(), id);
-
     // Emit OpMemberName for the struct members.
     for (size_t i = 0; i < numFields; ++i)
       emitNameForType(fields[i].name, id, i);

+ 9 - 7
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -723,13 +723,15 @@ void SpirvBuilder::addExtension(Extension ext, llvm::StringRef target,
   // TODO: The extension management should be removed from here and added as a
   // separate pass.
 
-  assert(featureManager);
-  featureManager->requestExtension(ext, target, loc);
-  // Do not emit OpExtension if the given extension is natively supported in the
-  // target environment.
-  if (featureManager->isExtensionRequiredForTargetEnv(ext))
-    module->addExtension(new (context) SpirvExtension(
-        loc, featureManager->getExtensionName(ext)));
+  if (existingExtensions.insert(ext)) {
+    assert(featureManager);
+    featureManager->requestExtension(ext, target, loc);
+    // Do not emit OpExtension if the given extension is natively supported in
+    // the target environment.
+    if (featureManager->isExtensionRequiredForTargetEnv(ext))
+      module->addExtension(new (context) SpirvExtension(
+          loc, featureManager->getExtensionName(ext)));
+  }
 }
 
 SpirvExtInstImport *SpirvBuilder::getGLSLExtInstSet(SourceLocation loc) {

+ 46 - 15
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -141,9 +141,41 @@ bool MatrixType::operator==(const MatrixType &that) const {
 ImageType::ImageType(const NumericalType *type, spv::Dim dim, WithDepth depth,
                      bool arrayed, bool ms, WithSampler sampled,
                      spv::ImageFormat format)
-    : SpirvType(TK_Image), sampledType(type), dimension(dim), imageDepth(depth),
-      isArrayed(arrayed), isMultiSampled(ms), isSampled(sampled),
-      imageFormat(format) {}
+    : SpirvType(TK_Image, getImageName(dim, arrayed)), sampledType(type),
+      dimension(dim), imageDepth(depth), isArrayed(arrayed), isMultiSampled(ms),
+      isSampled(sampled), imageFormat(format) {}
+
+std::string ImageType::getImageName(spv::Dim dim, bool arrayed) {
+  const char *dimStr = "";
+  switch (dim) {
+  case spv::Dim::Dim1D:
+    dimStr = "1d.";
+    break;
+  case spv::Dim::Dim2D:
+    dimStr = "2d.";
+    break;
+  case spv::Dim::Dim3D:
+    dimStr = "3d.";
+    break;
+  case spv::Dim::Cube:
+    dimStr = "cube.";
+    break;
+  case spv::Dim::Rect:
+    dimStr = "rect.";
+    break;
+  case spv::Dim::Buffer:
+    dimStr = "buffer.";
+    break;
+  case spv::Dim::SubpassData:
+    dimStr = "subpass.";
+    break;
+  default:
+    break;
+  }
+  std::string name =
+      std::string("type.") + dimStr + "image" + (arrayed ? ".array" : "");
+  return name;
+}
 
 bool ImageType::operator==(const ImageType &that) const {
   return sampledType == that.sampledType && dimension == that.dimension &&
@@ -154,37 +186,36 @@ bool ImageType::operator==(const ImageType &that) const {
 StructType::StructType(llvm::ArrayRef<StructType::FieldInfo> fieldsVec,
                        llvm::StringRef name, bool isReadOnly,
                        StructInterfaceType iface)
-    : SpirvType(TK_Struct), fields(fieldsVec.begin(), fieldsVec.end()),
-      structName(name), readOnly(isReadOnly), interfaceType(iface) {}
+    : SpirvType(TK_Struct, name), fields(fieldsVec.begin(), fieldsVec.end()),
+      readOnly(isReadOnly), interfaceType(iface) {}
 
 bool StructType::FieldInfo::
 operator==(const StructType::FieldInfo &that) const {
-  return type == that.type && name == that.name &&
-         vkOffsetAttr == that.vkOffsetAttr &&
+  return type == that.type && vkOffsetAttr == that.vkOffsetAttr &&
          packOffsetAttr == that.packOffsetAttr;
 }
 
 bool StructType::operator==(const StructType &that) const {
-  return fields == that.fields && structName == that.structName &&
-         readOnly == that.readOnly;
+  return fields == that.fields && getName() == that.getName() &&
+         readOnly == that.readOnly && interfaceType == that.interfaceType;
 }
 
 HybridStructType::HybridStructType(
     llvm::ArrayRef<HybridStructType::FieldInfo> fieldsVec, llvm::StringRef name,
     bool isReadOnly, StructInterfaceType iface)
-    : HybridType(TK_HybridStruct), fields(fieldsVec.begin(), fieldsVec.end()),
-      structName(name), readOnly(isReadOnly), interfaceType(iface) {}
+    : HybridType(TK_HybridStruct, name),
+      fields(fieldsVec.begin(), fieldsVec.end()), readOnly(isReadOnly),
+      interfaceType(iface) {}
 
 bool HybridStructType::FieldInfo::
 operator==(const HybridStructType::FieldInfo &that) const {
-  return astType == that.astType && name == that.name &&
-         vkOffsetAttr == that.vkOffsetAttr &&
+  return astType == that.astType && vkOffsetAttr == that.vkOffsetAttr &&
          packOffsetAttr == that.packOffsetAttr;
 }
 
 bool HybridStructType::operator==(const HybridStructType &that) const {
-  return fields == that.fields && structName == that.structName &&
-         readOnly == that.readOnly;
+  return fields == that.fields && getName() == that.getName() &&
+         readOnly == that.readOnly && interfaceType == that.interfaceType;
 }
 
 } // namespace spirv