瀏覽代碼

[spirv] Allow explicitly controlling SPIR-V extensions (#1151)

Added FeatureManager to record all extensions specified from
the command-line and emit error if trying to use one not permitted.

Added command-line option -fspv-extension= to specify
whitelisted extensions.
Lei Zhang 7 年之前
父節點
當前提交
241d32c810

+ 22 - 1
docs/SPIR-V.rst

@@ -267,7 +267,7 @@ The namespace ``vk`` will be used for all Vulkan attributes:
 - ``push_constant``: For marking a variable as the push constant block. Allowed
   on global variables of struct type. At most one variable can be marked as
   ``push_constant`` in a shader.
-- ``constant_id``: For marking a global constant as a specialization constant.
+- ``constant_id(X)``: For marking a global constant as a specialization constant.
   Allowed on global variables of boolean/integer/float types.
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
   list to the annotated object. Only allowed on objects whose type are
@@ -294,6 +294,25 @@ interface variables:
   main([[vk::location(N)]] float4 input: A) : B
   { ... }
 
+SPIR-V version and extension
+----------------------------
+
+In the **defult** mode (without ``-fspv-extension=<extension>`` command-line
+option), SPIR-V CodeGen will try its best to use the lowest SPIR-V version, and
+only require higher SPIR-V versions and extensions when they are truly needed
+for translating the input source code.
+
+For example, unless `Shader Model 6.0 wave intrinsics`_ are used, the generated
+SPIR-V will always be of version 1.0. The ``SPV_KHR_multivew`` extension will
+not be emitted unless you use ``SV_ViewID``.
+
+You can of course have fine-grained control of what extensions are permitted
+in the CodeGen using the **explicit** mode, turned on by the
+``-fspv-extension=<extension>`` command-line option. Only extensions supplied
+via ``-fspv-extension=`` will be used. If that does not suffice, errors will
+be emitted explaining what additional extensions are required to translate what
+specific feature in the source code.
+
 Legalization, optimization, validation
 --------------------------------------
 
@@ -2631,6 +2650,8 @@ codegen for Vulkan:
   location number according to alphabetical order or declaration order. See
   `HLSL semantic and Vulkan Location`_ for more details.
 - ``-fspv-reflect``: Emits additional SPIR-V instructions to aid reflection.
+- ``-fspv-extension=<extension>``: Only allows using ``<extension>`` in CodeGen.
+  If you want to allow multiple extensions, provide more than one such option.
 
 Unsupported HLSL Features
 =========================

+ 7 - 6
include/dxc/Support/HLSLOptions.h

@@ -160,16 +160,17 @@ public:
 
   // SPIRV Change Starts
 #ifdef ENABLE_SPIRV_CODEGEN
-  bool GenSPIRV; // OPT_spirv
-  bool VkIgnoreUnusedResources; // OPT_fvk_ignore_used_resources
-  bool VkInvertY; // OPT_fvk_invert_y
-  bool VkUseGlslLayout; // OPT_fvk_use_glsl_layout
-  bool SpvEnableReflect; // OPT_fspv_reflect
-  llvm::StringRef VkStageIoOrder; // OPT_fvk_stage_io_order
+  bool GenSPIRV;                           // OPT_spirv
+  bool VkIgnoreUnusedResources;            // OPT_fvk_ignore_used_resources
+  bool VkInvertY;                          // OPT_fvk_invert_y
+  bool VkUseGlslLayout;                    // OPT_fvk_use_glsl_layout
+  bool SpvEnableReflect;                   // OPT_fspv_reflect
+  llvm::StringRef VkStageIoOrder;          // OPT_fvk_stage_io_order
   llvm::SmallVector<uint32_t, 4> VkBShift; // OPT_fvk_b_shift
   llvm::SmallVector<uint32_t, 4> VkTShift; // OPT_fvk_t_shift
   llvm::SmallVector<uint32_t, 4> VkSShift; // OPT_fvk_s_shift
   llvm::SmallVector<uint32_t, 4> VkUShift; // OPT_fvk_u_shift
+  llvm::SmallVector<llvm::StringRef, 4> SpvExtensions; // OPT_fspv_extension
 #endif
   // SPIRV Change Ends
 };

+ 2 - 0
include/dxc/Support/HLSLOptions.td

@@ -254,6 +254,8 @@ def fvk_use_glsl_layout: Flag<["-"], "fvk-use-glsl-layout">, Group<spirv_Group>,
   HelpText<"Use conventional GLSL std140/std430 layout for resources">;
 def fspv_reflect: Flag<["-"], "fspv-reflect">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Emit additional SPIR-V instructions to aid reflection">;
+def fspv_extension_EQ : Joined<["-"], "fspv-extension=">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
+  HelpText<"Specify SPIR-V extension permitted to use">;
 // SPIRV Change Ends
 
 //////////////////////////////////////////////////////////////////////////////

+ 5 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -522,6 +522,10 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
            << opts.VkStageIoOrder;
     return 1;
   }
+
+  for (const Arg *A : Args.filtered(OPT_fspv_extension_EQ)) {
+    opts.SpvExtensions.push_back(A->getValue());
+  }
 #else
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||
@@ -529,6 +533,7 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
       Args.hasFlag(OPT_fspv_reflect, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_ignore_unused_resources, OPT_INVALID, false) ||
       !Args.getLastArgValue(OPT_fvk_stage_io_order_EQ).empty() ||
+      !Args.getLastArgValue(OPT_fspv_extension_EQ).empty() ||
       !Args.getLastArgValue(OPT_fvk_b_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_t_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_s_shift).empty() ||

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h

@@ -29,6 +29,7 @@ struct EmitSPIRVOptions {
   llvm::SmallVector<uint32_t, 4> tShift;
   llvm::SmallVector<uint32_t, 4> sShift;
   llvm::SmallVector<uint32_t, 4> uShift;
+  llvm::SmallVector<llvm::StringRef, 4> allowedExtensions;
 };
 } // end namespace clang
 

+ 90 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -0,0 +1,90 @@
+//===------ FeatureManager.h - SPIR-V Version/Extension Manager -*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//===----------------------------------------------------------------------===//
+//
+//  This file defines a SPIR-V version and extension manager.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H
+#define LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H
+
+#include <string>
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+namespace spirv {
+
+/// A list of SPIR-V extensions known to our CodeGen.
+enum class Extension {
+  KHR_device_group,
+  KHR_multiview,
+  KHR_shader_draw_parameters,
+  EXT_fragment_fully_covered,
+  EXT_shader_stencil_export,
+  AMD_gpu_shader_half_float,
+  AMD_shader_explicit_vertex_parameter,
+  GOOGLE_decorate_string,
+  GOOGLE_hlsl_functionality1,
+  Unknown,
+};
+
+/// The class for handling SPIR-V version and extension requests.
+class FeatureManager {
+public:
+  explicit FeatureManager(DiagnosticsEngine &de);
+
+  /// Allows the given extension to be used in CodeGen.
+  bool allowExtension(llvm::StringRef);
+  /// Allows all extensions to be used in CodeGen.
+  void allowAllKnownExtensions();
+  /// Rqeusts the given extension for translating the given target feature at
+  /// the given source location. Emits an error if the given extension is not
+  /// permitted to use.
+  bool requestExtension(Extension, llvm::StringRef target, SourceLocation);
+
+  /// Translates extension name to symbol.
+  static Extension getExtensionSymbol(llvm::StringRef name);
+  /// Translates extension symbol to name.
+  static const char *getExtensionName(Extension symbol);
+
+  /// Returns the names of all known extensions as a string.
+  std::string getKnownExtensions(const char *delimiter, const char *prefix = "",
+                                 const char *postfix = "");
+
+private:
+  /// \brief Wrapper method to create an error message and report it
+  /// in the diagnostic engine associated with this object.
+  template <unsigned N>
+  DiagnosticBuilder emitError(const char (&message)[N], SourceLocation loc) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
+    return diags.Report(loc, diagId);
+  }
+
+  /// \brief Wrapper method to create an note message and report it
+  /// in the diagnostic engine associated with this object.
+  template <unsigned N>
+  DiagnosticBuilder emitNote(const char (&message)[N], SourceLocation loc) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Note, message);
+    return diags.Report(loc, diagId);
+  }
+
+  DiagnosticsEngine &diags;
+
+  llvm::SmallBitVector allowedExtensions;
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H

+ 9 - 11
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -14,6 +14,7 @@
 #include <vector>
 
 #include "clang/AST/Type.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/Structure.h"
@@ -35,7 +36,7 @@ namespace spirv {
 class ModuleBuilder {
 public:
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
-  explicit ModuleBuilder(SPIRVContext *, bool enablReflect);
+  ModuleBuilder(SPIRVContext *, FeatureManager *features, bool enableReflect);
 
   /// \brief Returns the associated SPIRVContext.
   inline SPIRVContext *getSPIRVContext();
@@ -335,8 +336,9 @@ public:
   void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
                         llvm::ArrayRef<uint32_t> params);
 
-  /// \brief Adds an extension to the module under construction.
-  inline void addExtension(llvm::StringRef extension);
+  /// \brief Adds an extension to the module under construction for translating
+  /// the given target at the given source location.
+  void addExtension(Extension, llvm::StringRef target, SourceLocation);
 
   /// \brief If not added already, adds an OpExtInstImport (import of extended
   /// instruction set) of the GLSL instruction set. Returns the <result-id> for
@@ -468,11 +470,11 @@ private:
       uint32_t sample, uint32_t minLod,
       llvm::SmallVectorImpl<uint32_t> *orderedParams);
 
-  SPIRVContext &theContext; ///< The SPIR-V context.
-  SPIRVModule theModule;    ///< The module under building.
-
-  const bool allowReflect; ///< Whether allow reflect instructions.
+  SPIRVContext &theContext;       ///< The SPIR-V context.
+  FeatureManager *featureManager; ///< SPIR-V version/extension manager.
+  const bool allowReflect;        ///< Whether allow reflect instructions.
 
+  SPIRVModule theModule;                 ///< The module under building.
   std::unique_ptr<Function> theFunction; ///< The function under building.
   OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
   BasicBlock *insertPoint;               ///< The current insertion point.
@@ -516,10 +518,6 @@ void ModuleBuilder::setShaderModelVersion(uint32_t major, uint32_t minor) {
   theModule.setShaderModelVersion(major * 100 + minor * 10);
 }
 
-void ModuleBuilder::addExtension(llvm::StringRef extension) {
-  theModule.addExtension(extension);
-}
-
 } // end namespace spirv
 } // end namespace clang
 

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -8,6 +8,7 @@ add_clang_library(clangSPIRV
   DeclResultIdMapper.cpp
   Decoration.cpp
   EmitSPIRVAction.cpp
+  FeatureManager.cpp
   GlPerVertex.cpp
   InitListHandler.cpp
   InstBuilderAuto.cpp

+ 13 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1872,11 +1872,14 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
     case BuiltIn::BaseVertex:
     case BuiltIn::BaseInstance:
     case BuiltIn::DrawIndex:
-      theBuilder.addExtension("SPV_KHR_shader_draw_parameters");
+      theBuilder.addExtension(Extension::KHR_shader_draw_parameters,
+                              builtinAttr->getBuiltIn(),
+                              builtinAttr->getLocation());
       theBuilder.requireCapability(spv::Capability::DrawParameters);
       break;
     case BuiltIn::DeviceIndex:
-      theBuilder.addExtension("SPV_KHR_device_group");
+      theBuilder.addExtension(Extension::KHR_device_group,
+                              stageVar->getSemanticStr(), srcLoc);
       theBuilder.requireCapability(spv::Capability::DeviceGroup);
       break;
     }
@@ -2095,7 +2098,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   }
   // According to DXIL spec, the StencilRef SV can only be used by PSOut.
   case hlsl::Semantic::Kind::StencilRef: {
-    theBuilder.addExtension("SPV_EXT_shader_stencil_export");
+    theBuilder.addExtension(Extension::EXT_shader_stencil_export,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::StencilExportEXT);
 
     stageVar->setIsSpirvBuiltin();
@@ -2103,7 +2107,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   }
   // According to DXIL spec, the ViewID SV can only be used by PSIn.
   case hlsl::Semantic::Kind::Barycentrics: {
-    theBuilder.addExtension("SPV_AMD_shader_explicit_vertex_parameter");
+    theBuilder.addExtension(Extension::AMD_shader_explicit_vertex_parameter,
+                            stageVar->getSemanticStr(), srcLoc);
     stageVar->setIsSpirvBuiltin();
 
     // Selecting the correct builtin according to interpolation mode
@@ -2192,7 +2197,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
   // VS/HS/DS/GS/PS input.
   case hlsl::Semantic::Kind::ViewID: {
-    theBuilder.addExtension("SPV_KHR_multiview");
+    theBuilder.addExtension(Extension::KHR_multiview,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::MultiView);
 
     stageVar->setIsSpirvBuiltin();
@@ -2202,7 +2208,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
     // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
     // PSIn.
   case hlsl::Semantic::Kind::InnerCoverage: {
-    theBuilder.addExtension("SPV_EXT_fragment_fully_covered");
+    theBuilder.addExtension(Extension::EXT_fragment_fully_covered,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
 
     stageVar->setIsSpirvBuiltin();

+ 8 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -19,6 +19,7 @@
 #include "spirv/unified1/spirv.hpp11"
 #include "clang/AST/Attr.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
@@ -258,7 +259,8 @@ private:
 class DeclResultIdMapper {
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            ModuleBuilder &builder,
+                            ModuleBuilder &builder, TypeTranslator &translator,
+                            FeatureManager &features,
                             const EmitSPIRVOptions &spirvOptions);
 
   /// \brief Returns the <result-id> for a SPIR-V builtin variable.
@@ -632,7 +634,8 @@ private:
   ASTContext &astContext;
   DiagnosticsEngine &diags;
 
-  TypeTranslator typeTranslator;
+  TypeTranslator &typeTranslator;
+  FeatureManager &featureManager;
 
   uint32_t entryFunctionId;
 
@@ -737,10 +740,12 @@ void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        ASTContext &context,
                                        ModuleBuilder &builder,
+                                       TypeTranslator &translator,
+                                       FeatureManager &features,
                                        const EmitSPIRVOptions &options)
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
-      typeTranslator(context, builder, diags, options), entryFunctionId(0),
+      typeTranslator(translator), featureManager(features), entryFunctionId(0),
       laneCountBuiltinId(0), laneIndexBuiltinId(0), needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator, options.invertY) {}
 

+ 118 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -0,0 +1,118 @@
+//===---- FeatureManager.cpp - SPIR-V Version/Extension Manager -*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/FeatureManager.h"
+
+#include <sstream>
+
+#include "llvm/ADT/StringSwitch.h"
+
+namespace clang {
+namespace spirv {
+
+FeatureManager::FeatureManager(DiagnosticsEngine &de) : diags(de) {
+  allowedExtensions.resize(static_cast<unsigned>(Extension::Unknown) + 1);
+}
+
+bool FeatureManager::allowExtension(llvm::StringRef name) {
+  const auto symbol = getExtensionSymbol(name);
+  if (symbol == Extension::Unknown) {
+    emitError("unknown SPIR-V extension '%0'", {}) << name;
+    emitNote("known extensions are\n%0", {})
+        << getKnownExtensions("\n* ", "* ");
+    return false;
+  }
+
+  allowedExtensions.set(static_cast<unsigned>(symbol));
+  if (symbol == Extension::GOOGLE_hlsl_functionality1)
+    allowedExtensions.set(
+        static_cast<unsigned>(Extension::GOOGLE_decorate_string));
+
+  return true;
+}
+
+void FeatureManager::allowAllKnownExtensions() { allowedExtensions.set(); }
+
+bool FeatureManager::requestExtension(Extension ext, llvm::StringRef target,
+                                      SourceLocation srcLoc) {
+  if (allowedExtensions.test(static_cast<unsigned>(ext)))
+    return true;
+
+  emitError("SPIR-V extension '%0' required for %1 but not permitted to use",
+            srcLoc)
+      << getExtensionName(ext) << target;
+  return false;
+}
+
+Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
+  return llvm::StringSwitch<Extension>(name)
+      .Case("SPV_KHR_device_group", Extension::KHR_device_group)
+      .Case("SPV_KHR_multiview", Extension::KHR_multiview)
+      .Case("SPV_KHR_shader_draw_parameters",
+            Extension::KHR_shader_draw_parameters)
+      .Case("SPV_EXT_fragment_fully_covered",
+            Extension::EXT_fragment_fully_covered)
+      .Case("SPV_EXT_shader_stencil_export",
+            Extension::EXT_shader_stencil_export)
+      .Case("SPV_AMD_gpu_shader_half_float",
+            Extension::AMD_gpu_shader_half_float)
+      .Case("SPV_AMD_shader_explicit_vertex_parameter",
+            Extension::AMD_shader_explicit_vertex_parameter)
+      .Case("SPV_GOOGLE_decorate_string", Extension::GOOGLE_decorate_string)
+      .Case("SPV_GOOGLE_hlsl_functionality1",
+            Extension::GOOGLE_hlsl_functionality1)
+      .Default(Extension::Unknown);
+}
+
+const char *FeatureManager::getExtensionName(Extension symbol) {
+  switch (symbol) {
+  case Extension::KHR_device_group:
+    return "SPV_KHR_device_group";
+  case Extension::KHR_multiview:
+    return "SPV_KHR_multiview";
+  case Extension::KHR_shader_draw_parameters:
+    return "SPV_KHR_shader_draw_parameters";
+  case Extension::EXT_fragment_fully_covered:
+    return "SPV_EXT_fragment_fully_covered";
+  case Extension::EXT_shader_stencil_export:
+    return "SPV_EXT_shader_stencil_export";
+  case Extension::AMD_gpu_shader_half_float:
+    return "SPV_AMD_gpu_shader_half_float";
+  case Extension::AMD_shader_explicit_vertex_parameter:
+    return "SPV_AMD_shader_explicit_vertex_parameter";
+  case Extension::GOOGLE_decorate_string:
+    return "SPV_GOOGLE_decorate_string";
+  case Extension::GOOGLE_hlsl_functionality1:
+    return "SPV_GOOGLE_hlsl_functionality1";
+  default:
+    break;
+  }
+  return "<unknown extension>";
+}
+
+std::string FeatureManager::getKnownExtensions(const char *delimiter,
+                                               const char *prefix,
+                                               const char *postfix) {
+  std::ostringstream oss;
+
+  oss << prefix;
+
+  const auto numExtensions = static_cast<uint32_t>(Extension::Unknown);
+  for (uint32_t i = 0; i < numExtensions; ++i) {
+    oss << getExtensionName(static_cast<Extension>(i));
+    if (i + 1 < numExtensions)
+      oss << delimiter;
+  }
+
+  oss << postfix;
+
+  return oss.str();
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 18 - 7
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -18,9 +18,11 @@
 namespace clang {
 namespace spirv {
 
-ModuleBuilder::ModuleBuilder(SPIRVContext *C, bool reflect)
-    : theContext(*C), theModule(), allowReflect(reflect), theFunction(nullptr),
-      insertPoint(nullptr), instBuilder(nullptr), glslExtSetId(0) {
+ModuleBuilder::ModuleBuilder(SPIRVContext *C, FeatureManager *features,
+                             bool reflect)
+    : theContext(*C), featureManager(features), allowReflect(reflect),
+      theModule(), theFunction(nullptr), insertPoint(nullptr),
+      instBuilder(nullptr), glslExtSetId(0) {
   instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
     this->constructSite = std::move(words);
   });
@@ -752,6 +754,13 @@ void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
   theModule.addExecutionMode(std::move(constructSite));
 }
 
+void ModuleBuilder::addExtension(Extension ext, llvm::StringRef target,
+                                 SourceLocation srcLoc) {
+  assert(featureManager);
+  featureManager->requestExtension(ext, target, srcLoc);
+  theModule.addExtension(featureManager->getExtensionName(ext));
+}
+
 uint32_t ModuleBuilder::getGLSLExtInstSet() {
   if (glslExtSetId == 0) {
     glslExtSetId = theContext.takeNextId();
@@ -817,7 +826,8 @@ void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
 void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
                                             uint32_t counterBufferId) {
   if (allowReflect) {
-    addExtension("SPV_GOOGLE_hlsl_functionality1");
+    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
+                 {});
     theModule.addDecoration(
         Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
         mainBufferId);
@@ -828,8 +838,9 @@ void ModuleBuilder::decorateHlslSemantic(uint32_t targetId,
                                          llvm::StringRef semantic,
                                          llvm::Optional<uint32_t> memberIdx) {
   if (allowReflect) {
-    addExtension("SPV_GOOGLE_decorate_string");
-    addExtension("SPV_GOOGLE_hlsl_functionality1");
+    addExtension(Extension::GOOGLE_decorate_string, "SPIR-V reflection", {});
+    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
+                 {});
     theModule.addDecoration(
         Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
         targetId);
@@ -902,7 +913,7 @@ IMPL_GET_PRIMITIVE_TYPE(Float32)
                                                                                \
   uint32_t ModuleBuilder::get##ty##Type() {                                    \
     if (spv::Capability::cap == spv::Capability::Float16)                      \
-      theModule.addExtension("SPV_AMD_gpu_shader_half_float");                 \
+      addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", {});  \
     else                                                                       \
       requireCapability(spv::Capability::cap);                                 \
     const Type *type = Type::get##ty(theContext);                              \

+ 14 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -519,17 +519,29 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
-      theContext(), theBuilder(&theContext, options.enableReflect),
-      declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
+      theContext(), featureManager(diags),
+      theBuilder(&theContext, &featureManager, options.enableReflect),
       typeTranslator(astContext, theBuilder, diags, options),
+      declIdMapper(shaderModel, astContext, theBuilder, typeTranslator,
+                   featureManager, spirvOptions),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
       needsSpirv1p3(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
+
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
       !shaderModel.IsGS())
     emitError("-fvk-invert-y can only be used in VS/DS/GS", {});
+
+  if (options.allowedExtensions.empty()) {
+    // If no explicit extension control from command line, use the default mode:
+    // allowing all extensions.
+    featureManager.allowAllKnownExtensions();
+  } else {
+    for (auto ext : options.allowedExtensions)
+      featureManager.allowExtension(ext);
+  }
 }
 
 void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {

+ 3 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -28,6 +28,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -911,9 +912,10 @@ private:
   const hlsl::ShaderModel &shaderModel;
 
   SPIRVContext theContext;
+  FeatureManager featureManager;
   ModuleBuilder theBuilder;
-  DeclResultIdMapper declIdMapper;
   TypeTranslator typeTranslator;
+  DeclResultIdMapper declIdMapper;
 
   /// A queue of decls reachable from the entry function. Decls inserted into
   /// this queue will persist to avoid duplicated translations. And we'd like

+ 10 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.allow.hlsl

@@ -0,0 +1,10 @@
+// Run: %dxc -T vs_6_1 -E main -fspv-extension=SPV_KHR_multiview -fspv-extension=SPV_KHR_shader_draw_parameters
+
+// CHECK:      OpExtension "SPV_KHR_shader_draw_parameters"
+// CHECK:      OpExtension "SPV_KHR_multiview"
+
+float4 main(
+    [[vk::builtin("BaseVertex")]]    int baseVertex : A,
+                                     uint viewid    : SV_ViewID) : B {
+    return baseVertex + viewid;
+}

+ 9 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.forbid.hlsl

@@ -0,0 +1,9 @@
+// Run: %dxc -T vs_6_1 -E main -fspv-extension=SPV_KHR_shader_draw_parameters
+
+float4 main(
+    [[vk::builtin("BaseVertex")]]    int baseVertex : A,
+                                     uint viewid    : SV_ViewID) : B {
+    return baseVertex + viewid;
+}
+
+// CHECK: :5:55: error: SPIR-V extension 'SPV_KHR_multiview' required for SV_ViewID but not permitted to use

+ 8 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.unknown.hlsl

@@ -0,0 +1,8 @@
+// Run: %dxc -T ps_6_1 -E main -fspv-extension=MyExtension
+
+float4 main(uint viewid: SV_ViewID) : SV_Target {
+    return viewid;
+}
+
+// CHECK: error: unknown SPIR-V extension 'MyExtension'
+// CHECK: note: known extensions are

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -478,6 +478,7 @@ public:
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.uShift = opts.VkUShift;
+          spirvOpts.allowedExtensions = opts.SpvExtensions;
           spirvOpts.enable16BitTypes = opts.Enable16BitTypes;
           clang::EmitSPIRVAction action(spirvOpts);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1168,6 +1168,16 @@ TEST_F(FileTest, SpirvBuiltInDeviceIndexInvalidUsage) {
   runFileTest("spirv.builtin.device-index.invalid.hlsl", Expect::Failure);
 }
 
+TEST_F(FileTest, SpirvExtensionCLAllow) {
+  runFileTest("spirv.ext.cl.allow.hlsl");
+}
+TEST_F(FileTest, SpirvExtensionCLForbid) {
+  runFileTest("spirv.ext.cl.forbid.hlsl", Expect::Failure);
+}
+TEST_F(FileTest, SpirvExtensionCLUnknown) {
+  runFileTest("spirv.ext.cl.unknown.hlsl", Expect::Failure);
+}
+
 // For shader stage input/output interface
 // For semantic SV_Position, SV_ClipDistance, SV_CullDistance
 TEST_F(FileTest, SpirvStageIOInterfaceVS) {

+ 3 - 3
tools/clang/unittests/SPIRV/ModuleBuilderTest.cpp

@@ -21,7 +21,7 @@ using ::testing::ElementsAre;
 
 TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
   EXPECT_THAT(builder.takeModule(),
               ElementsAre(spv::MagicNumber, 0x00010000, 14u << 16, 1u, 0u));
@@ -29,7 +29,7 @@ TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
 
 TEST(ModuleBuilder, CreateFunction) {
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();
@@ -47,7 +47,7 @@ TEST(ModuleBuilder, CreateFunction) {
 
 TEST(ModuleBuilder, CreateBasicBlock) {
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();