Procházet zdrojové kódy

[spirv] Use SetVector for storing Capabilities.

Ehsan Nasiri před 6 roky
rodič
revize
e0098e8f68

+ 2 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -222,6 +222,8 @@ public:
 
   bool invokeVisitor(Visitor *v) override;
 
+  bool operator==(const SpirvCapability &that) const;
+
   spv::Capability getCapability() const { return capability; }
 
 private:

+ 20 - 3
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -29,7 +29,7 @@ struct ExtensionComparisonInfo {
     return llvm::hash_combine(ext->getExtensionName());
   }
   static bool isEqual(SpirvExtension *LHS, SpirvExtension *RHS) {
-    // Either both are null, or both should have the same underlying type.
+    // Either both are null, or both should have the same underlying extension.
     return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
   }
 };
@@ -47,6 +47,18 @@ struct DecorationComparisonInfo {
   }
 };
 
+struct CapabilityComparisonInfo {
+  static inline SpirvCapability *getEmptyKey() { return nullptr; }
+  static inline SpirvCapability *getTombstoneKey() { return nullptr; }
+  static unsigned getHashValue(const SpirvCapability *cap) {
+    return llvm::hash_combine(static_cast<uint32_t>(cap->getCapability()));
+  }
+  static bool isEqual(SpirvCapability *LHS, SpirvCapability *RHS) {
+    // Either both are null, or both should have the same underlying capability.
+    return (LHS == RHS) || (LHS && RHS && *LHS == *RHS);
+  }
+};
+
 /// The class representing a SPIR-V module in memory.
 ///
 /// A SPIR-V module contains two main parts: instructions for "metadata" (e.g.,
@@ -118,8 +130,13 @@ public:
   void addModuleProcessed(SpirvModuleProcessed *);
 
 private:
-  // "Metadata" instructions
-  llvm::SmallVector<SpirvCapability *, 8> capabilities;
+  // Use a set for storing capabilities. This will ensure there are no duplicate
+  // capabilities. Although the set stores pointers, the provided
+  // CapabilityComparisonInfo compares the SpirvCapability objects, not the
+  // pointers.
+  llvm::SetVector<SpirvCapability *, std::vector<SpirvCapability *>,
+                  llvm::DenseSet<SpirvCapability *, CapabilityComparisonInfo>>
+      capabilities;
 
   // Use a set for storing extensions. This will ensure there are no duplicate
   // extensions. Although the set stores pointers, the provided

+ 4 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -132,6 +132,10 @@ SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(), loc),
       capability(cap) {}
 
+bool SpirvCapability::operator==(const SpirvCapability &that) const {
+  return capability == that.capability;
+}
+
 SpirvExtension::SpirvExtension(SourceLocation loc,
                                llvm::StringRef extensionName)
     : SpirvInstruction(IK_Extension, spv::Op::OpExtension, QualType(), loc),

+ 5 - 16
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -102,9 +102,10 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
         return false;
     }
 
-    for (auto iter = capabilities.rbegin(); iter != capabilities.rend();
-         ++iter) {
-      auto *capability = *iter;
+    // Since SetVector doesn't have 'rbegin()' and 'rend()' methods, we use
+    // manual indexing.
+    for (auto capIndex = capabilities.size(); capIndex > 0; --capIndex) {
+      auto *capability = capabilities[capIndex - 1];
       if (!capability->invokeVisitor(visitor))
         return false;
     }
@@ -172,17 +173,7 @@ void SpirvModule::addFunction(SpirvFunction *fn) {
 
 void SpirvModule::addCapability(SpirvCapability *cap) {
   assert(cap && "cannot add null capability to the module");
-  // Only add the capability to the module if it is not already added.
-  // Due to the small number of capabilities, this should not be too expensive.
-  const spv::Capability capability = cap->getCapability();
-  auto found =
-      std::find_if(capabilities.begin(), capabilities.end(),
-                   [capability](SpirvCapability *existingCapability) {
-                     return capability == existingCapability->getCapability();
-                   });
-  if (found == capabilities.end()) {
-    capabilities.push_back(cap);
-  }
+  capabilities.insert(cap);
 }
 
 void SpirvModule::setMemoryModel(SpirvMemoryModel *model) {
@@ -202,8 +193,6 @@ void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
 
 void SpirvModule::addExtension(SpirvExtension *ext) {
   assert(ext && "cannot add null extension");
-  // The underlying data structure is a set, so there will not be any duplicate
-  // extensions.
   extensions.insert(ext);
 }
 

+ 13 - 0
tools/clang/test/CodeGenSPIRV/capability.unique.hlsl

@@ -0,0 +1,13 @@
+// Run: %dxc -T ps_6_2 -E main
+
+// Make sure the same capability is not applied twice.
+//
+// CHECK:     OpCapability Int64
+// CHECK-NOT: OpCapability Int64
+
+void main() {
+  int64_t a = 1;
+  int64_t b = 2;
+  int64_t c = a + b;
+}
+

+ 11 - 0
tools/clang/test/CodeGenSPIRV/extension.unique.hlsl

@@ -0,0 +1,11 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// Make sure the same decoration is not applied twice.
+//
+// CHECK:     OpExtension "SPV_KHR_16bit_storage"
+// CHECK-NOT: OpExtension "SPV_KHR_16bit_storage"
+
+float4 main(int16_t pix_pos : INPUT_1,
+            int16_t pix_pos2: INPUT_2): SV_Target {
+  return 0;
+}

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -1852,6 +1852,12 @@ TEST_F(FileTest, RayTracingNVLibrary) {
 // For decoration uniqueness
 TEST_F(FileTest, DecorationUnique) { runFileTest("decoration.unique.hlsl"); }
 
+// For capability uniqueness
+TEST_F(FileTest, CapabilityUnique) { runFileTest("capability.unique.hlsl"); }
+
+// For extension uniqueness
+TEST_F(FileTest, ExtensionUnique) { runFileTest("extension.unique.hlsl"); }
+
 // For RelaxedPrecision decorations
 TEST_F(FileTest, DecorationRelaxedPrecisionBasic) {
   runFileTest("decoration.relaxed-precision.basic.hlsl");