Browse Source

[spirv] Use SetVector for storing Extensions.

If we use SetVector, we don't need to do a linear search of existing
extensions to figure out whether an extension has already been used or
not.
Ehsan Nasiri 6 years ago
parent
commit
156d7facde

+ 2 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -240,6 +240,8 @@ public:
 
 
   bool invokeVisitor(Visitor *v) override;
   bool invokeVisitor(Visitor *v) override;
 
 
+  bool operator==(const SpirvExtension &that) const;
+
   llvm::StringRef getExtensionName() const { return extName; }
   llvm::StringRef getExtensionName() const { return extName; }
 
 
 private:
 private:

+ 23 - 1
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -12,6 +12,8 @@
 #include <vector>
 #include <vector>
 
 
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/SmallVector.h"
 
 
 namespace clang {
 namespace clang {
@@ -20,6 +22,18 @@ namespace spirv {
 class SpirvFunction;
 class SpirvFunction;
 class SpirvVisitor;
 class SpirvVisitor;
 
 
+struct ExtensionComparisonInfo {
+  static inline SpirvExtension *getEmptyKey() { return nullptr; }
+  static inline SpirvExtension *getTombstoneKey() { return nullptr; }
+  static unsigned getHashValue(const SpirvExtension *ext) {
+    return llvm::hash_combine(ext->getExtensionName());
+  }
+  static bool isEqual(SpirvExtension *LHS, SpirvExtension *RHS) {
+    // Either both are null, or both should have the same underlying type.
+    return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
+  }
+};
+
 /// The class representing a SPIR-V module in memory.
 /// The class representing a SPIR-V module in memory.
 ///
 ///
 /// A SPIR-V module contains two main parts: instructions for "metadata" (e.g.,
 /// A SPIR-V module contains two main parts: instructions for "metadata" (e.g.,
@@ -93,7 +107,15 @@ public:
 private:
 private:
   // "Metadata" instructions
   // "Metadata" instructions
   llvm::SmallVector<SpirvCapability *, 8> capabilities;
   llvm::SmallVector<SpirvCapability *, 8> capabilities;
-  llvm::SmallVector<SpirvExtension *, 4> extensions;
+
+  // Use a set for storing extensions. This will ensure there are no duplicate
+  // extensions. Although the set stores pointers, the provided
+  // ExtensionComparisonInfo compares the SpirvExtension objects, not the
+  // pointers.
+  llvm::SetVector<SpirvExtension *, std::vector<SpirvExtension *>,
+                  llvm::DenseSet<SpirvExtension *, ExtensionComparisonInfo>>
+      extensions;
+
   llvm::SmallVector<SpirvExtInstImport *, 1> extInstSets;
   llvm::SmallVector<SpirvExtInstImport *, 1> extInstSets;
   SpirvMemoryModel *memoryModel;
   SpirvMemoryModel *memoryModel;
   llvm::SmallVector<SpirvEntryPoint *, 1> entryPoints;
   llvm::SmallVector<SpirvEntryPoint *, 1> entryPoints;

+ 4 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -137,6 +137,10 @@ SpirvExtension::SpirvExtension(SourceLocation loc,
     : SpirvInstruction(IK_Extension, spv::Op::OpExtension, QualType(), loc),
     : SpirvInstruction(IK_Extension, spv::Op::OpExtension, QualType(), loc),
       extName(extensionName) {}
       extName(extensionName) {}
 
 
+bool SpirvExtension::operator==(const SpirvExtension &that) const {
+  return extName == that.extName;
+}
+
 SpirvExtInstImport::SpirvExtInstImport(SourceLocation loc,
 SpirvExtInstImport::SpirvExtInstImport(SourceLocation loc,
                                        llvm::StringRef extensionName)
                                        llvm::StringRef extensionName)
     : SpirvInstruction(IK_ExtInstImport, spv::Op::OpExtInstImport, QualType(),
     : SpirvInstruction(IK_ExtInstImport, spv::Op::OpExtInstImport, QualType(),

+ 7 - 13
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -92,8 +92,10 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
         return false;
         return false;
     }
     }
 
 
-    for (auto iter = extensions.rbegin(); iter != extensions.rend(); ++iter) {
-      auto *extension = *iter;
+    // Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
+    // manual indexing.
+    for (auto extIndex = extensions.size(); extIndex > 0; --extIndex) {
+      auto *extension = extensions[extIndex - 1];
       if (!extension->invokeVisitor(visitor))
       if (!extension->invokeVisitor(visitor))
         return false;
         return false;
     }
     }
@@ -198,17 +200,9 @@ void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
 
 
 void SpirvModule::addExtension(SpirvExtension *ext) {
 void SpirvModule::addExtension(SpirvExtension *ext) {
   assert(ext && "cannot add null extension");
   assert(ext && "cannot add null extension");
-  // Only add the extension to the module if it is not already added.
-  // Due to the small number of extensions, this should not be too expensive.
-  const auto extName = ext->getExtensionName();
-  auto found =
-      std::find_if(extensions.begin(), extensions.end(),
-                   [&extName](SpirvExtension *existingExtension) {
-                     return extName == existingExtension->getExtensionName();
-                   });
-  if (found == extensions.end()) {
-    extensions.push_back(ext);
-  }
+  // The underlying data structure is a set, so there will not be any duplicate
+  // extensions.
+  extensions.insert(ext);
 }
 }
 
 
 void SpirvModule::addExtInstSet(SpirvExtInstImport *set) {
 void SpirvModule::addExtInstSet(SpirvExtInstImport *set) {