Parcourir la source

[spirv] Clean up extension & capability management. (#1994)

Ehsan il y a 6 ans
Parent
commit
d16812ed3f

+ 21 - 22
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -9,7 +9,6 @@
 #ifndef LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 #define LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 
-#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvFunction.h"
@@ -31,9 +30,10 @@ namespace spirv {
 /// Call `getModule()` to get the SPIR-V words after finishing building the
 /// module.
 class SpirvBuilder {
+  friend class CapabilityVisitor;
+
 public:
-  SpirvBuilder(ASTContext &ac, SpirvContext &c, FeatureManager *,
-               const SpirvCodeGenOptions &);
+  SpirvBuilder(ASTContext &ac, SpirvContext &c, const SpirvCodeGenOptions &);
   ~SpirvBuilder() = default;
 
   // Forbid copy construction and assignment
@@ -426,9 +426,6 @@ public:
                         SourceLocation loc);
 
   // === SPIR-V Module Structure ===
-
-  inline void requireCapability(spv::Capability, SourceLocation loc = {});
-
   inline void setMemoryModel(spv::AddressingModel, spv::MemoryModel);
 
   /// \brief Adds an entry point for the module under construction. We only
@@ -449,10 +446,6 @@ public:
                                llvm::ArrayRef<uint32_t> params,
                                SourceLocation loc = {});
 
-  /// \brief Adds an extension to the module under construction for translating
-  /// the given target at the given source location.
-  void addExtension(Extension, llvm::StringRef target, SourceLocation);
-
   /// \brief Adds an OpModuleProcessed instruction to the module under
   /// construction.
   void addModuleProcessed(llvm::StringRef process);
@@ -569,6 +562,19 @@ public:
 public:
   std::vector<uint32_t> takeModule();
 
+
+protected:
+  /// Only friend classes are allowed to add capability/extension to the module
+  /// under construction.
+
+  /// \brief Adds the given capability to the module under construction due to
+  /// the feature used at the given source location.
+  inline void requireCapability(spv::Capability, SourceLocation loc = {});
+
+  /// \brief Adds an extension to the module under construction for translating
+  /// the given target at the given source location.
+  inline void requireExtension(llvm::StringRef extension, SourceLocation);
+
 private:
   /// \brief Returns the composed ImageOperandsMask from non-zero parameters
   /// and pushes non-zero parameters to *orderedParams in the expected order.
@@ -594,12 +600,8 @@ private:
   /// the entry block.
   std::vector<SpirvBasicBlock *> basicBlocks;
 
-  FeatureManager *featureManager; ///< SPIR-V version/extension manager.
   const SpirvCodeGenOptions &spirvOptions; ///< Command line options.
 
-  llvm::SetVector<spv::Capability> existingCapabilities;
-  llvm::SetVector<Extension> existingExtensions;
-
   /// A struct containing information regarding a builtin variable.
   struct BuiltInVarInfo {
     BuiltInVarInfo(spv::StorageClass s, spv::BuiltIn b, SpirvVariable *v)
@@ -613,14 +615,11 @@ private:
 };
 
 void SpirvBuilder::requireCapability(spv::Capability cap, SourceLocation loc) {
-  if (cap != spv::Capability::Max) {
-    // No need to create a new capability nor add it to the module if it has
-    // already been added.
-    if (existingCapabilities.insert(cap)) {
-      auto *capability = new (context) SpirvCapability(loc, cap);
-      module->addCapability(capability);
-    }
-  }
+  module->addCapability(new (context) SpirvCapability(loc, cap));
+}
+
+void SpirvBuilder::requireExtension(llvm::StringRef ext, SourceLocation loc) {
+  module->addExtension(new (context) SpirvExtension(loc, ext));
 }
 
 void SpirvBuilder::setMemoryModel(spv::AddressingModel addrModel,

+ 116 - 94
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -13,6 +13,21 @@
 namespace clang {
 namespace spirv {
 
+void CapabilityVisitor::addExtension(Extension ext, llvm::StringRef target,
+                                     SourceLocation loc) {
+  featureManager.requestExtension(ext, target, loc);
+  // Do not emit OpExtension if the given extension is natively supported in
+  // the target environment.
+  if (featureManager.isExtensionRequiredForTargetEnv(ext))
+    spvBuilder.requireExtension(featureManager.getExtensionName(ext), loc);
+}
+
+void CapabilityVisitor::addCapability(spv::Capability cap, SourceLocation loc) {
+  if (cap != spv::Capability::Max) {
+    spvBuilder.requireCapability(cap, loc);
+  }
+}
+
 void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
                                              SourceLocation loc,
                                              spv::StorageClass sc) {
@@ -25,18 +40,18 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
     switch (intType->getBitwidth()) {
     case 16: {
       // Usage of a 16-bit integer type.
-      spvBuilder.requireCapability(spv::Capability::Int16);
+      addCapability(spv::Capability::Int16);
 
       // Usage of a 16-bit integer type as stage I/O.
       if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
-        spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                                "16-bit stage IO variables", loc);
-        spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
+        addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
+                     loc);
+        addCapability(spv::Capability::StorageInputOutput16);
       }
       break;
     }
     case 64: {
-      spvBuilder.requireCapability(spv::Capability::Int64);
+      addCapability(spv::Capability::Int64);
       break;
     }
     default:
@@ -51,20 +66,19 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
       // It looks like the validator does not approve of Float16
       // capability even though we do use the necessary extension.
       // TODO: Re-enable adding Float16 capability below.
-      // spvBuilder.requireCapability(spv::Capability::Float16);
-      spvBuilder.addExtension(Extension::AMD_gpu_shader_half_float,
-                              "16-bit float", loc);
+      // addCapability(spv::Capability::Float16);
+      addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", loc);
 
       // Usage of a 16-bit float type as stage I/O.
       if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
-        spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                                "16-bit stage IO variables", loc);
-        spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
+        addExtension(Extension::KHR_16bit_storage, "16-bit stage IO variables",
+                     loc);
+        addCapability(spv::Capability::StorageInputOutput16);
       }
       break;
     }
     case 64: {
-      spvBuilder.requireCapability(spv::Capability::Float64);
+      addCapability(spv::Capability::Float64);
       break;
     }
     default:
@@ -87,9 +101,9 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
     if (SpirvType::isResourceType(raType->getElementType())) {
       // the elements inside the runtime array are resources
-      spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
-                              "runtime array of resources", loc);
-      spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
+      addExtension(Extension::EXT_descriptor_indexing,
+                   "runtime array of resources", loc);
+      addCapability(spv::Capability::RuntimeDescriptorArrayEXT);
     }
     addCapabilityForType(raType->getElementType(), loc, sc);
   }
@@ -97,22 +111,22 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
   else if (const auto *imageType = dyn_cast<ImageType>(type)) {
     switch (imageType->getDimension()) {
     case spv::Dim::Buffer: {
-      spvBuilder.requireCapability(spv::Capability::SampledBuffer);
+      addCapability(spv::Capability::SampledBuffer);
       if (imageType->withSampler() == ImageType::WithSampler::No) {
-        spvBuilder.requireCapability(spv::Capability::ImageBuffer);
+        addCapability(spv::Capability::ImageBuffer);
       }
       break;
     }
     case spv::Dim::Dim1D: {
       if (imageType->withSampler() == ImageType::WithSampler::No) {
-        spvBuilder.requireCapability(spv::Capability::Image1D);
+        addCapability(spv::Capability::Image1D);
       } else {
-        spvBuilder.requireCapability(spv::Capability::Sampled1D);
+        addCapability(spv::Capability::Sampled1D);
       }
       break;
     }
     case spv::Dim::SubpassData: {
-      spvBuilder.requireCapability(spv::Capability::InputAttachment);
+      addCapability(spv::Capability::InputAttachment);
       break;
     }
     default:
@@ -146,8 +160,7 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
     case spv::ImageFormat::Rg8ui:
     case spv::ImageFormat::R16ui:
     case spv::ImageFormat::R8ui:
-      spvBuilder.requireCapability(
-          spv::Capability::StorageImageExtendedFormats);
+      addCapability(spv::Capability::StorageImageExtendedFormats);
       break;
     default:
       // Only image formats requiring extended formats are relevant. The rest
@@ -156,7 +169,7 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
     }
 
     if (imageType->isArrayedImage() && imageType->isMSImage())
-      spvBuilder.requireCapability(spv::Capability::ImageMSArray);
+      addCapability(spv::Capability::ImageMSArray);
 
     addCapabilityForType(imageType->getSampledType(), loc, sc);
   }
@@ -171,17 +184,16 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
   // Struct type
   else if (const auto *structType = dyn_cast<StructType>(type)) {
     if (SpirvType::isOrContains16BitType(structType)) {
-      spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                              "16-bit types in resource", loc);
+      addExtension(Extension::KHR_16bit_storage, "16-bit types in resource",
+                   loc);
       if (sc == spv::StorageClass::PushConstant) {
-        spvBuilder.requireCapability(spv::Capability::StoragePushConstant16);
+        addCapability(spv::Capability::StoragePushConstant16);
       } else if (structType->getInterfaceType() ==
                  StructInterfaceType::UniformBuffer) {
-        spvBuilder.requireCapability(spv::Capability::StorageUniform16);
+        addCapability(spv::Capability::StorageUniform16);
       } else if (structType->getInterfaceType() ==
                  StructInterfaceType::StorageBuffer) {
-        spvBuilder.requireCapability(
-            spv::Capability::StorageUniformBufferBlock16);
+        addCapability(spv::Capability::StorageUniformBufferBlock16);
       }
     }
     for (auto field : structType->getFields())
@@ -193,16 +205,21 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
   const auto loc = decor->getSourceLocation();
   switch (decor->getDecoration()) {
   case spv::Decoration::Sample: {
-    spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
+    addCapability(spv::Capability::SampleRateShading, loc);
     break;
   }
   case spv::Decoration::NonUniformEXT: {
-    spvBuilder.addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT",
-                            loc);
-    spvBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);
+    addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
+    addCapability(spv::Capability::ShaderNonUniformEXT);
 
     break;
   }
+  case spv::Decoration::HlslSemanticGOOGLE:
+  case spv::Decoration::HlslCounterBufferGOOGLE: {
+    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
+                 loc);
+    break;
+  }
   // Capabilities needed for built-ins
   case spv::Decoration::BuiltIn: {
     assert(decor->getParams().size() == 1);
@@ -210,74 +227,71 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
     switch (builtin) {
     case spv::BuiltIn::SampleId:
     case spv::BuiltIn::SamplePosition: {
-      spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
+      addCapability(spv::Capability::SampleRateShading, loc);
       break;
     }
     case spv::BuiltIn::SubgroupSize:
     case spv::BuiltIn::NumSubgroups:
     case spv::BuiltIn::SubgroupId:
     case spv::BuiltIn::SubgroupLocalInvocationId: {
-      spvBuilder.requireCapability(spv::Capability::GroupNonUniform, loc);
+      addCapability(spv::Capability::GroupNonUniform, loc);
       break;
     }
     case spv::BuiltIn::BaseVertex: {
-      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
-                              "BaseVertex Builtin", loc);
-      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      addExtension(Extension::KHR_shader_draw_parameters, "BaseVertex Builtin",
+                   loc);
+      addCapability(spv::Capability::DrawParameters);
       break;
     }
     case spv::BuiltIn::BaseInstance: {
-      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
-                              "BaseInstance Builtin", loc);
-      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      addExtension(Extension::KHR_shader_draw_parameters,
+                   "BaseInstance Builtin", loc);
+      addCapability(spv::Capability::DrawParameters);
       break;
     }
     case spv::BuiltIn::DrawIndex: {
-      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
-                              "DrawIndex Builtin", loc);
-      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      addExtension(Extension::KHR_shader_draw_parameters, "DrawIndex Builtin",
+                   loc);
+      addCapability(spv::Capability::DrawParameters);
       break;
     }
     case spv::BuiltIn::DeviceIndex: {
-      spvBuilder.addExtension(Extension::KHR_device_group,
-                              "DeviceIndex Builtin", loc);
-      spvBuilder.requireCapability(spv::Capability::DeviceGroup);
+      addExtension(Extension::KHR_device_group, "DeviceIndex Builtin", loc);
+      addCapability(spv::Capability::DeviceGroup);
       break;
     }
     case spv::BuiltIn::FragStencilRefEXT: {
-      spvBuilder.addExtension(Extension::EXT_shader_stencil_export,
-                              "SV_StencilRef", loc);
-      spvBuilder.requireCapability(spv::Capability::StencilExportEXT);
+      addExtension(Extension::EXT_shader_stencil_export, "SV_StencilRef", loc);
+      addCapability(spv::Capability::StencilExportEXT);
       break;
     }
     case spv::BuiltIn::ViewIndex: {
-      spvBuilder.addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
-      spvBuilder.requireCapability(spv::Capability::MultiView);
+      addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
+      addCapability(spv::Capability::MultiView);
       break;
     }
     case spv::BuiltIn::FullyCoveredEXT: {
-      spvBuilder.addExtension(Extension::EXT_fragment_fully_covered,
-                              "SV_InnerCoverage", loc);
-      spvBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
+      addExtension(Extension::EXT_fragment_fully_covered, "SV_InnerCoverage",
+                   loc);
+      addCapability(spv::Capability::FragmentFullyCoveredEXT);
       break;
     }
     case spv::BuiltIn::PrimitiveId: {
       // PrimitiveID can be used as PSIn
       if (shaderModel == spv::ExecutionModel::Fragment)
-        spvBuilder.requireCapability(spv::Capability::Geometry);
+        addCapability(spv::Capability::Geometry);
       break;
     }
     case spv::BuiltIn::Layer: {
       if (shaderModel == spv::ExecutionModel::Vertex ||
           shaderModel == spv::ExecutionModel::TessellationControl ||
           shaderModel == spv::ExecutionModel::TessellationEvaluation) {
-        spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
-                                "SV_RenderTargetArrayIndex", loc);
-        spvBuilder.requireCapability(
-            spv::Capability::ShaderViewportIndexLayerEXT);
+        addExtension(Extension::EXT_shader_viewport_index_layer,
+                     "SV_RenderTargetArrayIndex", loc);
+        addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
       } else if (shaderModel == spv::ExecutionModel::Fragment) {
         // SV_RenderTargetArrayIndex can be used as PSIn.
-        spvBuilder.requireCapability(spv::Capability::Geometry);
+        addCapability(spv::Capability::Geometry);
       }
       break;
     }
@@ -285,25 +299,34 @@ bool CapabilityVisitor::visit(SpirvDecoration *decor) {
       if (shaderModel == spv::ExecutionModel::Vertex ||
           shaderModel == spv::ExecutionModel::TessellationControl ||
           shaderModel == spv::ExecutionModel::TessellationEvaluation) {
-        spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
-                                "SV_ViewPortArrayIndex", loc);
-        spvBuilder.requireCapability(
-            spv::Capability::ShaderViewportIndexLayerEXT);
+        addExtension(Extension::EXT_shader_viewport_index_layer,
+                     "SV_ViewPortArrayIndex", loc);
+        addCapability(spv::Capability::ShaderViewportIndexLayerEXT);
       } else if (shaderModel == spv::ExecutionModel::Fragment ||
                  shaderModel == spv::ExecutionModel::Geometry) {
         // SV_ViewportArrayIndex can be used as PSIn.
-        spvBuilder.requireCapability(spv::Capability::MultiViewport);
+        addCapability(spv::Capability::MultiViewport);
       }
       break;
     }
     case spv::BuiltIn::ClipDistance: {
-      spvBuilder.requireCapability(spv::Capability::ClipDistance);
+      addCapability(spv::Capability::ClipDistance);
       break;
     }
     case spv::BuiltIn::CullDistance: {
-      spvBuilder.requireCapability(spv::Capability::CullDistance);
+      addCapability(spv::Capability::CullDistance);
       break;
     }
+    case spv::BuiltIn::BaryCoordNoPerspAMD:
+    case spv::BuiltIn::BaryCoordNoPerspCentroidAMD:
+    case spv::BuiltIn::BaryCoordNoPerspSampleAMD:
+    case spv::BuiltIn::BaryCoordSmoothAMD:
+    case spv::BuiltIn::BaryCoordSmoothCentroidAMD:
+    case spv::BuiltIn::BaryCoordSmoothSampleAMD:
+    case spv::BuiltIn::BaryCoordPullModelAMD: {
+      addExtension(Extension::AMD_shader_explicit_vertex_parameter,
+                   "SV_Barycentrics", loc);
+    }
     default:
       break;
     }
@@ -347,14 +370,14 @@ CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
 bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
   addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
                        instr->getStorageClass());
-  spvBuilder.requireCapability(spv::Capability::ImageQuery);
+  addCapability(spv::Capability::ImageQuery);
   return true;
 }
 
 bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
   addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
                        instr->getStorageClass());
-  spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
+  addCapability(spv::Capability::ImageGatherExtended);
   return true;
 }
 
@@ -362,11 +385,11 @@ bool CapabilityVisitor::visit(SpirvImageOp *instr) {
   addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
                        instr->getStorageClass());
   if (instr->hasOffset() || instr->hasConstOffsets())
-    spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
+    addCapability(spv::Capability::ImageGatherExtended);
   if (instr->hasMinLod())
-    spvBuilder.requireCapability(spv::Capability::MinLod);
+    addCapability(spv::Capability::MinLod);
   if (instr->isSparse())
-    spvBuilder.requireCapability(spv::Capability::SparseResidency);
+    addCapability(spv::Capability::SparseResidency);
 
   return true;
 }
@@ -374,17 +397,16 @@ bool CapabilityVisitor::visit(SpirvImageOp *instr) {
 bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   const SpirvType *resultType = instr->getResultType();
   const auto opcode = instr->getopcode();
+  const auto loc = instr->getSourceLocation();
 
   // Add result-type-specific capabilities
-  addCapabilityForType(resultType, instr->getSourceLocation(),
-                       instr->getStorageClass());
+  addCapabilityForType(resultType, loc, instr->getStorageClass());
 
   // Add NonUniform capabilities if necessary
   if (instr->isNonUniform()) {
-    spvBuilder.addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT",
-                            instr->getSourceLocation());
-    spvBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);
-    spvBuilder.requireCapability(getNonUniformCapability(resultType));
+    addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT", loc);
+    addCapability(spv::Capability::ShaderNonUniformEXT);
+    addCapability(getNonUniformCapability(resultType));
   }
 
   // Add opcode-specific capabilities
@@ -395,15 +417,15 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpDPdxFine:
   case spv::Op::OpDPdyFine:
   case spv::Op::OpFwidthFine:
-    spvBuilder.requireCapability(spv::Capability::DerivativeControl);
+    addCapability(spv::Capability::DerivativeControl);
     break;
   case spv::Op::OpGroupNonUniformElect:
-    spvBuilder.requireCapability(spv::Capability::GroupNonUniform);
+    addCapability(spv::Capability::GroupNonUniform);
     break;
   case spv::Op::OpGroupNonUniformAny:
   case spv::Op::OpGroupNonUniformAll:
   case spv::Op::OpGroupNonUniformAllEqual:
-    spvBuilder.requireCapability(spv::Capability::GroupNonUniformVote);
+    addCapability(spv::Capability::GroupNonUniformVote);
     break;
   case spv::Op::OpGroupNonUniformBallot:
   case spv::Op::OpGroupNonUniformInverseBallot:
@@ -413,7 +435,7 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpGroupNonUniformBallotFindMSB:
   case spv::Op::OpGroupNonUniformBroadcast:
   case spv::Op::OpGroupNonUniformBroadcastFirst:
-    spvBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
+    addCapability(spv::Capability::GroupNonUniformBallot);
     break;
   case spv::Op::OpGroupNonUniformIAdd:
   case spv::Op::OpGroupNonUniformFAdd:
@@ -431,11 +453,11 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   case spv::Op::OpGroupNonUniformLogicalAnd:
   case spv::Op::OpGroupNonUniformLogicalOr:
   case spv::Op::OpGroupNonUniformLogicalXor:
-    spvBuilder.requireCapability(spv::Capability::GroupNonUniformArithmetic);
+    addCapability(spv::Capability::GroupNonUniformArithmetic);
     break;
   case spv::Op::OpGroupNonUniformQuadBroadcast:
   case spv::Op::OpGroupNonUniformQuadSwap:
-    spvBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
+    addCapability(spv::Capability::GroupNonUniformQuad);
     break;
   default:
     break;
@@ -450,14 +472,14 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
   case spv::ExecutionModel::Fragment:
   case spv::ExecutionModel::Vertex:
   case spv::ExecutionModel::GLCompute:
-    spvBuilder.requireCapability(spv::Capability::Shader);
+    addCapability(spv::Capability::Shader);
     break;
   case spv::ExecutionModel::Geometry:
-    spvBuilder.requireCapability(spv::Capability::Geometry);
+    addCapability(spv::Capability::Geometry);
     break;
   case spv::ExecutionModel::TessellationControl:
   case spv::ExecutionModel::TessellationEvaluation:
-    spvBuilder.requireCapability(spv::Capability::Tessellation);
+    addCapability(spv::Capability::Tessellation);
     break;
   case spv::ExecutionModel::RayGenerationNV:
   case spv::ExecutionModel::IntersectionNV:
@@ -465,9 +487,8 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
   case spv::ExecutionModel::AnyHitNV:
   case spv::ExecutionModel::MissNV:
   case spv::ExecutionModel::CallableNV:
-    spvBuilder.requireCapability(spv::Capability::RayTracingNV);
-    spvBuilder.addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing",
-                            {});
+    addCapability(spv::Capability::RayTracingNV);
+    addExtension(Extension::NV_ray_tracing, "SPV_NV_ray_tracing", {});
     break;
   default:
     llvm_unreachable("found unknown shader model");
@@ -478,9 +499,10 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
 
 bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
   if (execMode->getExecutionMode() == spv::ExecutionMode::PostDepthCoverage) {
-    spvBuilder.requireCapability(
-        spv::Capability::SampleMaskPostDepthCoverage,
-        execMode->getEntryPoint()->getSourceLocation());
+    addCapability(spv::Capability::SampleMaskPostDepthCoverage,
+                  execMode->getEntryPoint()->getSourceLocation());
+    addExtension(Extension::KHR_post_depth_coverage,
+                 "[[vk::post_depth_coverage]]", execMode->getSourceLocation());
   }
   return true;
 }

+ 18 - 5
tools/clang/lib/SPIRV/CapabilityVisitor.h

@@ -10,6 +10,7 @@
 #ifndef LLVM_CLANG_LIB_SPIRV_CAPABILITYVISITOR_H
 #define LLVM_CLANG_LIB_SPIRV_CAPABILITYVISITOR_H
 
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvVisitor.h"
 
@@ -20,9 +21,10 @@ class SpirvBuilder;
 
 class CapabilityVisitor : public Visitor {
 public:
-  CapabilityVisitor(SpirvContext &spvCtx, const SpirvCodeGenOptions &opts,
-                    SpirvBuilder &builder)
-      : Visitor(opts, spvCtx), spvBuilder(builder) {}
+  CapabilityVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
+                    const SpirvCodeGenOptions &opts, SpirvBuilder &builder)
+      : Visitor(opts, spvCtx), spvBuilder(builder),
+        featureManager(astCtx.getDiagnostics(), opts) {}
 
   bool visit(SpirvDecoration *decor);
   bool visit(SpirvEntryPoint *);
@@ -46,13 +48,24 @@ private:
   void addCapabilityForType(const SpirvType *, SourceLocation loc,
                             spv::StorageClass sc);
 
+  /// Checks that the given extension is a valid extension for the target
+  /// environment (e.g. Vulkan 1.0). And if so, utilizes the SpirvBuilder to add
+  /// the given extension to the SPIR-V module in memory.
+  void addExtension(Extension ext, llvm::StringRef target, SourceLocation loc);
+
+  /// Checks that the given capability is a valid capability. And if so,
+  /// utilizes the SpirvBuilder to add the given capability to the SPIR-V module
+  /// in memory.
+  void addCapability(spv::Capability, SourceLocation loc = {});
+
   /// Returns the capability required to non-uniformly index into the given
   /// type.
   spv::Capability getNonUniformCapability(const SpirvType *);
 
 private:
-  SpirvBuilder &spvBuilder;        /// SPIR-V builder
-  spv::ExecutionModel shaderModel; /// Execution model
+  SpirvBuilder &spvBuilder;        ///< SPIR-V builder
+  spv::ExecutionModel shaderModel; ///< Execution model
+  FeatureManager featureManager;   ///< SPIR-V version/extension manager.
 };
 
 } // end namespace spirv

+ 1 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -2489,10 +2489,8 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT,
                                          srcLoc);
   }
-  // According to DXIL spec, the ViewID SV can only be used by PSIn.
+  // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
   case hlsl::Semantic::Kind::Barycentrics: {
-    spvBuilder.addExtension(Extension::AMD_shader_explicit_vertex_parameter,
-                            stageVar->getSemanticStr(), srcLoc);
     stageVar->setIsSpirvBuiltin();
 
     // Selecting the correct builtin according to interpolation mode

+ 3 - 23
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -18,9 +18,9 @@ namespace clang {
 namespace spirv {
 
 SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
-                           FeatureManager *fm, const SpirvCodeGenOptions &opt)
+                           const SpirvCodeGenOptions &opt)
     : astContext(ac), context(ctx), module(nullptr), function(nullptr),
-      featureManager(fm), spirvOptions(opt) {
+      spirvOptions(opt) {
   module = new (context) SpirvModule;
 }
 
@@ -814,22 +814,6 @@ SpirvBuilder::createRayTracingOpsNV(spv::Op opcode, QualType resultType,
   return inst;
 }
 
-void SpirvBuilder::addExtension(Extension ext, llvm::StringRef target,
-                                SourceLocation loc) {
-  // TODO: The extension management should be removed from here and added as a
-  // separate pass.
-
-  if (existingExtensions.insert(ext)) {
-    assert(featureManager);
-    featureManager->requestExtension(ext, target, loc);
-    // Do not emit OpExtension if the given extension is natively supported in
-    // the target environment.
-    if (featureManager->isExtensionRequiredForTargetEnv(ext))
-      module->addExtension(new (context) SpirvExtension(
-          loc, featureManager->getExtensionName(ext)));
-  }
-}
-
 void SpirvBuilder::addModuleProcessed(llvm::StringRef process) {
   module->addModuleProcessed(new (context) SpirvModuleProcessed({}, process));
 }
@@ -956,8 +940,6 @@ void SpirvBuilder::decorateCounterBuffer(SpirvInstruction *mainBuffer,
                                          SpirvInstruction *counterBuffer,
                                          SourceLocation srcLoc) {
   if (spirvOptions.enableReflect) {
-    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
-                 srcLoc);
     auto *decor = new (context) SpirvDecoration(
         srcLoc, mainBuffer, spv::Decoration::HlslCounterBufferGOOGLE,
         {counterBuffer});
@@ -970,8 +952,6 @@ void SpirvBuilder::decorateHlslSemantic(SpirvInstruction *target,
                                         llvm::Optional<uint32_t> memberIdx,
                                         SourceLocation srcLoc) {
   if (spirvOptions.enableReflect) {
-    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
-                 srcLoc);
     auto *decor = new (context)
         SpirvDecoration(srcLoc, target, spv::Decoration::HlslSemanticGOOGLE,
                         semantic, memberIdx);
@@ -1082,7 +1062,7 @@ std::vector<uint32_t> SpirvBuilder::takeModule() {
   // Run necessary visitor passes first
   LiteralTypeVisitor literalTypeVisitor(astContext, context, spirvOptions);
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
-  CapabilityVisitor capabilityVisitor(context, spirvOptions, *this);
+  CapabilityVisitor capabilityVisitor(astContext, context, spirvOptions, *this);
   EmitVisitor emitVisitor(astContext, context, spirvOptions);
 
   module->invokeVisitor(&literalTypeVisitor, true);

+ 1 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -481,7 +481,7 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
       spirvOptions(ci.getCodeGenOpts().SpirvOptions),
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(),
       featureManager(diags, spirvOptions),
-      spvBuilder(astContext, spvContext, &featureManager, spirvOptions),
+      spvBuilder(astContext, spvContext, spirvOptions),
       declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
                    spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
@@ -9422,8 +9422,6 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
                                 decl->getLocation());
   }
   if (decl->getAttr<VKPostDepthCoverageAttr>()) {
-    spvBuilder.addExtension(Extension::KHR_post_depth_coverage,
-                            "[[vk::post_depth_coverage]]", decl->getLocation());
     spvBuilder.addExecutionMode(entryFunction,
                                 spv::ExecutionMode::PostDepthCoverage, {},
                                 decl->getLocation());

+ 22 - 2
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -89,7 +89,17 @@ void SpirvModule::addFunction(SpirvFunction *fn) {
 
 void SpirvModule::addCapability(SpirvCapability *cap) {
   assert(cap && "cannot add null capability to the module");
-  capabilities.push_back(cap);
+  // Only add the capability to the module if it is not already added.
+  // Due to the small number of capabilities, this should not be too expensive.
+  const spv::Capability capability = cap->getCapability();
+  auto found =
+      std::find_if(capabilities.begin(), capabilities.end(),
+                   [capability](SpirvCapability *existingCapability) {
+                     return capability == existingCapability->getCapability();
+                   });
+  if (found == capabilities.end()) {
+    capabilities.push_back(cap);
+  }
 }
 
 void SpirvModule::setMemoryModel(SpirvMemoryModel *model) {
@@ -109,7 +119,17 @@ void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
 
 void SpirvModule::addExtension(SpirvExtension *ext) {
   assert(ext && "cannot add null extension");
-  extensions.push_back(ext);
+  // Only add the extension to the module if it is not already added.
+  // Due to the small number of extensions, this should not be too expensive.
+  const auto extName = ext->getExtensionName();
+  auto found =
+      std::find_if(extensions.begin(), extensions.end(),
+                   [&extName](SpirvExtension *existingExtension) {
+                     return extName == existingExtension->getExtensionName();
+                   });
+  if (found == extensions.end()) {
+    extensions.push_back(ext);
+  }
 }
 
 void SpirvModule::addExtInstSet(SpirvExtInstImport *set) {