瀏覽代碼

[spirv] Allow explicitly controlling SPIR-V extensions (#1151)

Added FeatureManager to record all extensions specified from
the command-line and emit error if trying to use one not permitted.

Added command-line option -fspv-extension= to specify
whitelisted extensions.
Lei Zhang 7 年之前
父節點
當前提交
241d32c810

+ 22 - 1
docs/SPIR-V.rst

@@ -267,7 +267,7 @@ The namespace ``vk`` will be used for all Vulkan attributes:
 - ``push_constant``: For marking a variable as the push constant block. Allowed
 - ``push_constant``: For marking a variable as the push constant block. Allowed
   on global variables of struct type. At most one variable can be marked as
   on global variables of struct type. At most one variable can be marked as
   ``push_constant`` in a shader.
   ``push_constant`` in a shader.
-- ``constant_id``: For marking a global constant as a specialization constant.
+- ``constant_id(X)``: For marking a global constant as a specialization constant.
   Allowed on global variables of boolean/integer/float types.
   Allowed on global variables of boolean/integer/float types.
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
   list to the annotated object. Only allowed on objects whose type are
   list to the annotated object. Only allowed on objects whose type are
@@ -294,6 +294,25 @@ interface variables:
   main([[vk::location(N)]] float4 input: A) : B
   main([[vk::location(N)]] float4 input: A) : B
   { ... }
   { ... }
 
 
+SPIR-V version and extension
+----------------------------
+
+In the **defult** mode (without ``-fspv-extension=<extension>`` command-line
+option), SPIR-V CodeGen will try its best to use the lowest SPIR-V version, and
+only require higher SPIR-V versions and extensions when they are truly needed
+for translating the input source code.
+
+For example, unless `Shader Model 6.0 wave intrinsics`_ are used, the generated
+SPIR-V will always be of version 1.0. The ``SPV_KHR_multivew`` extension will
+not be emitted unless you use ``SV_ViewID``.
+
+You can of course have fine-grained control of what extensions are permitted
+in the CodeGen using the **explicit** mode, turned on by the
+``-fspv-extension=<extension>`` command-line option. Only extensions supplied
+via ``-fspv-extension=`` will be used. If that does not suffice, errors will
+be emitted explaining what additional extensions are required to translate what
+specific feature in the source code.
+
 Legalization, optimization, validation
 Legalization, optimization, validation
 --------------------------------------
 --------------------------------------
 
 
@@ -2631,6 +2650,8 @@ codegen for Vulkan:
   location number according to alphabetical order or declaration order. See
   location number according to alphabetical order or declaration order. See
   `HLSL semantic and Vulkan Location`_ for more details.
   `HLSL semantic and Vulkan Location`_ for more details.
 - ``-fspv-reflect``: Emits additional SPIR-V instructions to aid reflection.
 - ``-fspv-reflect``: Emits additional SPIR-V instructions to aid reflection.
+- ``-fspv-extension=<extension>``: Only allows using ``<extension>`` in CodeGen.
+  If you want to allow multiple extensions, provide more than one such option.
 
 
 Unsupported HLSL Features
 Unsupported HLSL Features
 =========================
 =========================

+ 7 - 6
include/dxc/Support/HLSLOptions.h

@@ -160,16 +160,17 @@ public:
 
 
   // SPIRV Change Starts
   // SPIRV Change Starts
 #ifdef ENABLE_SPIRV_CODEGEN
 #ifdef ENABLE_SPIRV_CODEGEN
-  bool GenSPIRV; // OPT_spirv
-  bool VkIgnoreUnusedResources; // OPT_fvk_ignore_used_resources
-  bool VkInvertY; // OPT_fvk_invert_y
-  bool VkUseGlslLayout; // OPT_fvk_use_glsl_layout
-  bool SpvEnableReflect; // OPT_fspv_reflect
-  llvm::StringRef VkStageIoOrder; // OPT_fvk_stage_io_order
+  bool GenSPIRV;                           // OPT_spirv
+  bool VkIgnoreUnusedResources;            // OPT_fvk_ignore_used_resources
+  bool VkInvertY;                          // OPT_fvk_invert_y
+  bool VkUseGlslLayout;                    // OPT_fvk_use_glsl_layout
+  bool SpvEnableReflect;                   // OPT_fspv_reflect
+  llvm::StringRef VkStageIoOrder;          // OPT_fvk_stage_io_order
   llvm::SmallVector<uint32_t, 4> VkBShift; // OPT_fvk_b_shift
   llvm::SmallVector<uint32_t, 4> VkBShift; // OPT_fvk_b_shift
   llvm::SmallVector<uint32_t, 4> VkTShift; // OPT_fvk_t_shift
   llvm::SmallVector<uint32_t, 4> VkTShift; // OPT_fvk_t_shift
   llvm::SmallVector<uint32_t, 4> VkSShift; // OPT_fvk_s_shift
   llvm::SmallVector<uint32_t, 4> VkSShift; // OPT_fvk_s_shift
   llvm::SmallVector<uint32_t, 4> VkUShift; // OPT_fvk_u_shift
   llvm::SmallVector<uint32_t, 4> VkUShift; // OPT_fvk_u_shift
+  llvm::SmallVector<llvm::StringRef, 4> SpvExtensions; // OPT_fspv_extension
 #endif
 #endif
   // SPIRV Change Ends
   // SPIRV Change Ends
 };
 };

+ 2 - 0
include/dxc/Support/HLSLOptions.td

@@ -254,6 +254,8 @@ def fvk_use_glsl_layout: Flag<["-"], "fvk-use-glsl-layout">, Group<spirv_Group>,
   HelpText<"Use conventional GLSL std140/std430 layout for resources">;
   HelpText<"Use conventional GLSL std140/std430 layout for resources">;
 def fspv_reflect: Flag<["-"], "fspv-reflect">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
 def fspv_reflect: Flag<["-"], "fspv-reflect">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
   HelpText<"Emit additional SPIR-V instructions to aid reflection">;
   HelpText<"Emit additional SPIR-V instructions to aid reflection">;
+def fspv_extension_EQ : Joined<["-"], "fspv-extension=">, Group<spirv_Group>, Flags<[CoreOption, DriverOption]>,
+  HelpText<"Specify SPIR-V extension permitted to use">;
 // SPIRV Change Ends
 // SPIRV Change Ends
 
 
 //////////////////////////////////////////////////////////////////////////////
 //////////////////////////////////////////////////////////////////////////////

+ 5 - 0
lib/DxcSupport/HLSLOptions.cpp

@@ -522,6 +522,10 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
            << opts.VkStageIoOrder;
            << opts.VkStageIoOrder;
     return 1;
     return 1;
   }
   }
+
+  for (const Arg *A : Args.filtered(OPT_fspv_extension_EQ)) {
+    opts.SpvExtensions.push_back(A->getValue());
+  }
 #else
 #else
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
   if (Args.hasFlag(OPT_spirv, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_invert_y, OPT_INVALID, false) ||
@@ -529,6 +533,7 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
       Args.hasFlag(OPT_fspv_reflect, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fspv_reflect, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_ignore_unused_resources, OPT_INVALID, false) ||
       Args.hasFlag(OPT_fvk_ignore_unused_resources, OPT_INVALID, false) ||
       !Args.getLastArgValue(OPT_fvk_stage_io_order_EQ).empty() ||
       !Args.getLastArgValue(OPT_fvk_stage_io_order_EQ).empty() ||
+      !Args.getLastArgValue(OPT_fspv_extension_EQ).empty() ||
       !Args.getLastArgValue(OPT_fvk_b_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_b_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_t_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_t_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_s_shift).empty() ||
       !Args.getLastArgValue(OPT_fvk_s_shift).empty() ||

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h

@@ -29,6 +29,7 @@ struct EmitSPIRVOptions {
   llvm::SmallVector<uint32_t, 4> tShift;
   llvm::SmallVector<uint32_t, 4> tShift;
   llvm::SmallVector<uint32_t, 4> sShift;
   llvm::SmallVector<uint32_t, 4> sShift;
   llvm::SmallVector<uint32_t, 4> uShift;
   llvm::SmallVector<uint32_t, 4> uShift;
+  llvm::SmallVector<llvm::StringRef, 4> allowedExtensions;
 };
 };
 } // end namespace clang
 } // end namespace clang
 
 

+ 90 - 0
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -0,0 +1,90 @@
+//===------ FeatureManager.h - SPIR-V Version/Extension Manager -*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//===----------------------------------------------------------------------===//
+//
+//  This file defines a SPIR-V version and extension manager.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H
+#define LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H
+
+#include <string>
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+namespace spirv {
+
+/// A list of SPIR-V extensions known to our CodeGen.
+enum class Extension {
+  KHR_device_group,
+  KHR_multiview,
+  KHR_shader_draw_parameters,
+  EXT_fragment_fully_covered,
+  EXT_shader_stencil_export,
+  AMD_gpu_shader_half_float,
+  AMD_shader_explicit_vertex_parameter,
+  GOOGLE_decorate_string,
+  GOOGLE_hlsl_functionality1,
+  Unknown,
+};
+
+/// The class for handling SPIR-V version and extension requests.
+class FeatureManager {
+public:
+  explicit FeatureManager(DiagnosticsEngine &de);
+
+  /// Allows the given extension to be used in CodeGen.
+  bool allowExtension(llvm::StringRef);
+  /// Allows all extensions to be used in CodeGen.
+  void allowAllKnownExtensions();
+  /// Rqeusts the given extension for translating the given target feature at
+  /// the given source location. Emits an error if the given extension is not
+  /// permitted to use.
+  bool requestExtension(Extension, llvm::StringRef target, SourceLocation);
+
+  /// Translates extension name to symbol.
+  static Extension getExtensionSymbol(llvm::StringRef name);
+  /// Translates extension symbol to name.
+  static const char *getExtensionName(Extension symbol);
+
+  /// Returns the names of all known extensions as a string.
+  std::string getKnownExtensions(const char *delimiter, const char *prefix = "",
+                                 const char *postfix = "");
+
+private:
+  /// \brief Wrapper method to create an error message and report it
+  /// in the diagnostic engine associated with this object.
+  template <unsigned N>
+  DiagnosticBuilder emitError(const char (&message)[N], SourceLocation loc) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
+    return diags.Report(loc, diagId);
+  }
+
+  /// \brief Wrapper method to create an note message and report it
+  /// in the diagnostic engine associated with this object.
+  template <unsigned N>
+  DiagnosticBuilder emitNote(const char (&message)[N], SourceLocation loc) {
+    const auto diagId =
+        diags.getCustomDiagID(clang::DiagnosticsEngine::Note, message);
+    return diags.Report(loc, diagId);
+  }
+
+  DiagnosticsEngine &diags;
+
+  llvm::SmallBitVector allowedExtensions;
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SPIRV_FEATUREMANAGER_H

+ 9 - 11
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -14,6 +14,7 @@
 #include <vector>
 #include <vector>
 
 
 #include "clang/AST/Type.h"
 #include "clang/AST/Type.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/Structure.h"
 #include "clang/SPIRV/Structure.h"
@@ -35,7 +36,7 @@ namespace spirv {
 class ModuleBuilder {
 class ModuleBuilder {
 public:
 public:
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
   /// \brief Constructs a ModuleBuilder with the given SPIR-V context.
-  explicit ModuleBuilder(SPIRVContext *, bool enablReflect);
+  ModuleBuilder(SPIRVContext *, FeatureManager *features, bool enableReflect);
 
 
   /// \brief Returns the associated SPIRVContext.
   /// \brief Returns the associated SPIRVContext.
   inline SPIRVContext *getSPIRVContext();
   inline SPIRVContext *getSPIRVContext();
@@ -335,8 +336,9 @@ public:
   void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
   void addExecutionMode(uint32_t entryPointId, spv::ExecutionMode em,
                         llvm::ArrayRef<uint32_t> params);
                         llvm::ArrayRef<uint32_t> params);
 
 
-  /// \brief Adds an extension to the module under construction.
-  inline void addExtension(llvm::StringRef extension);
+  /// \brief Adds an extension to the module under construction for translating
+  /// the given target at the given source location.
+  void addExtension(Extension, llvm::StringRef target, SourceLocation);
 
 
   /// \brief If not added already, adds an OpExtInstImport (import of extended
   /// \brief If not added already, adds an OpExtInstImport (import of extended
   /// instruction set) of the GLSL instruction set. Returns the <result-id> for
   /// instruction set) of the GLSL instruction set. Returns the <result-id> for
@@ -468,11 +470,11 @@ private:
       uint32_t sample, uint32_t minLod,
       uint32_t sample, uint32_t minLod,
       llvm::SmallVectorImpl<uint32_t> *orderedParams);
       llvm::SmallVectorImpl<uint32_t> *orderedParams);
 
 
-  SPIRVContext &theContext; ///< The SPIR-V context.
-  SPIRVModule theModule;    ///< The module under building.
-
-  const bool allowReflect; ///< Whether allow reflect instructions.
+  SPIRVContext &theContext;       ///< The SPIR-V context.
+  FeatureManager *featureManager; ///< SPIR-V version/extension manager.
+  const bool allowReflect;        ///< Whether allow reflect instructions.
 
 
+  SPIRVModule theModule;                 ///< The module under building.
   std::unique_ptr<Function> theFunction; ///< The function under building.
   std::unique_ptr<Function> theFunction; ///< The function under building.
   OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
   OrderedBasicBlockMap basicBlocks;      ///< The basic blocks under building.
   BasicBlock *insertPoint;               ///< The current insertion point.
   BasicBlock *insertPoint;               ///< The current insertion point.
@@ -516,10 +518,6 @@ void ModuleBuilder::setShaderModelVersion(uint32_t major, uint32_t minor) {
   theModule.setShaderModelVersion(major * 100 + minor * 10);
   theModule.setShaderModelVersion(major * 100 + minor * 10);
 }
 }
 
 
-void ModuleBuilder::addExtension(llvm::StringRef extension) {
-  theModule.addExtension(extension);
-}
-
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang
 
 

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -8,6 +8,7 @@ add_clang_library(clangSPIRV
   DeclResultIdMapper.cpp
   DeclResultIdMapper.cpp
   Decoration.cpp
   Decoration.cpp
   EmitSPIRVAction.cpp
   EmitSPIRVAction.cpp
+  FeatureManager.cpp
   GlPerVertex.cpp
   GlPerVertex.cpp
   InitListHandler.cpp
   InitListHandler.cpp
   InstBuilderAuto.cpp
   InstBuilderAuto.cpp

+ 13 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1872,11 +1872,14 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
     case BuiltIn::BaseVertex:
     case BuiltIn::BaseVertex:
     case BuiltIn::BaseInstance:
     case BuiltIn::BaseInstance:
     case BuiltIn::DrawIndex:
     case BuiltIn::DrawIndex:
-      theBuilder.addExtension("SPV_KHR_shader_draw_parameters");
+      theBuilder.addExtension(Extension::KHR_shader_draw_parameters,
+                              builtinAttr->getBuiltIn(),
+                              builtinAttr->getLocation());
       theBuilder.requireCapability(spv::Capability::DrawParameters);
       theBuilder.requireCapability(spv::Capability::DrawParameters);
       break;
       break;
     case BuiltIn::DeviceIndex:
     case BuiltIn::DeviceIndex:
-      theBuilder.addExtension("SPV_KHR_device_group");
+      theBuilder.addExtension(Extension::KHR_device_group,
+                              stageVar->getSemanticStr(), srcLoc);
       theBuilder.requireCapability(spv::Capability::DeviceGroup);
       theBuilder.requireCapability(spv::Capability::DeviceGroup);
       break;
       break;
     }
     }
@@ -2095,7 +2098,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   }
   }
   // According to DXIL spec, the StencilRef SV can only be used by PSOut.
   // According to DXIL spec, the StencilRef SV can only be used by PSOut.
   case hlsl::Semantic::Kind::StencilRef: {
   case hlsl::Semantic::Kind::StencilRef: {
-    theBuilder.addExtension("SPV_EXT_shader_stencil_export");
+    theBuilder.addExtension(Extension::EXT_shader_stencil_export,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::StencilExportEXT);
     theBuilder.requireCapability(spv::Capability::StencilExportEXT);
 
 
     stageVar->setIsSpirvBuiltin();
     stageVar->setIsSpirvBuiltin();
@@ -2103,7 +2107,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   }
   }
   // According to DXIL spec, the ViewID SV can only be used by PSIn.
   // According to DXIL spec, the ViewID SV can only be used by PSIn.
   case hlsl::Semantic::Kind::Barycentrics: {
   case hlsl::Semantic::Kind::Barycentrics: {
-    theBuilder.addExtension("SPV_AMD_shader_explicit_vertex_parameter");
+    theBuilder.addExtension(Extension::AMD_shader_explicit_vertex_parameter,
+                            stageVar->getSemanticStr(), srcLoc);
     stageVar->setIsSpirvBuiltin();
     stageVar->setIsSpirvBuiltin();
 
 
     // Selecting the correct builtin according to interpolation mode
     // Selecting the correct builtin according to interpolation mode
@@ -2192,7 +2197,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
   // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
   // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
   // VS/HS/DS/GS/PS input.
   // VS/HS/DS/GS/PS input.
   case hlsl::Semantic::Kind::ViewID: {
   case hlsl::Semantic::Kind::ViewID: {
-    theBuilder.addExtension("SPV_KHR_multiview");
+    theBuilder.addExtension(Extension::KHR_multiview,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::MultiView);
     theBuilder.requireCapability(spv::Capability::MultiView);
 
 
     stageVar->setIsSpirvBuiltin();
     stageVar->setIsSpirvBuiltin();
@@ -2202,7 +2208,8 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
     // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
     // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
     // PSIn.
     // PSIn.
   case hlsl::Semantic::Kind::InnerCoverage: {
   case hlsl::Semantic::Kind::InnerCoverage: {
-    theBuilder.addExtension("SPV_EXT_fragment_fully_covered");
+    theBuilder.addExtension(Extension::EXT_fragment_fully_covered,
+                            stageVar->getSemanticStr(), srcLoc);
     theBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
     theBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
 
 
     stageVar->setIsSpirvBuiltin();
     stageVar->setIsSpirvBuiltin();

+ 8 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -19,6 +19,7 @@
 #include "spirv/unified1/spirv.hpp11"
 #include "spirv/unified1/spirv.hpp11"
 #include "clang/AST/Attr.h"
 #include "clang/AST/Attr.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/Optional.h"
@@ -258,7 +259,8 @@ private:
 class DeclResultIdMapper {
 class DeclResultIdMapper {
 public:
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            ModuleBuilder &builder,
+                            ModuleBuilder &builder, TypeTranslator &translator,
+                            FeatureManager &features,
                             const EmitSPIRVOptions &spirvOptions);
                             const EmitSPIRVOptions &spirvOptions);
 
 
   /// \brief Returns the <result-id> for a SPIR-V builtin variable.
   /// \brief Returns the <result-id> for a SPIR-V builtin variable.
@@ -632,7 +634,8 @@ private:
   ASTContext &astContext;
   ASTContext &astContext;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
 
 
-  TypeTranslator typeTranslator;
+  TypeTranslator &typeTranslator;
+  FeatureManager &featureManager;
 
 
   uint32_t entryFunctionId;
   uint32_t entryFunctionId;
 
 
@@ -737,10 +740,12 @@ void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        ASTContext &context,
                                        ASTContext &context,
                                        ModuleBuilder &builder,
                                        ModuleBuilder &builder,
+                                       TypeTranslator &translator,
+                                       FeatureManager &features,
                                        const EmitSPIRVOptions &options)
                                        const EmitSPIRVOptions &options)
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
       astContext(context), diags(context.getDiagnostics()),
-      typeTranslator(context, builder, diags, options), entryFunctionId(0),
+      typeTranslator(translator), featureManager(features), entryFunctionId(0),
       laneCountBuiltinId(0), laneIndexBuiltinId(0), needsLegalization(false),
       laneCountBuiltinId(0), laneIndexBuiltinId(0), needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator, options.invertY) {}
       glPerVertex(model, context, builder, typeTranslator, options.invertY) {}
 
 

+ 118 - 0
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -0,0 +1,118 @@
+//===---- FeatureManager.cpp - SPIR-V Version/Extension Manager -*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/FeatureManager.h"
+
+#include <sstream>
+
+#include "llvm/ADT/StringSwitch.h"
+
+namespace clang {
+namespace spirv {
+
+FeatureManager::FeatureManager(DiagnosticsEngine &de) : diags(de) {
+  allowedExtensions.resize(static_cast<unsigned>(Extension::Unknown) + 1);
+}
+
+bool FeatureManager::allowExtension(llvm::StringRef name) {
+  const auto symbol = getExtensionSymbol(name);
+  if (symbol == Extension::Unknown) {
+    emitError("unknown SPIR-V extension '%0'", {}) << name;
+    emitNote("known extensions are\n%0", {})
+        << getKnownExtensions("\n* ", "* ");
+    return false;
+  }
+
+  allowedExtensions.set(static_cast<unsigned>(symbol));
+  if (symbol == Extension::GOOGLE_hlsl_functionality1)
+    allowedExtensions.set(
+        static_cast<unsigned>(Extension::GOOGLE_decorate_string));
+
+  return true;
+}
+
+void FeatureManager::allowAllKnownExtensions() { allowedExtensions.set(); }
+
+bool FeatureManager::requestExtension(Extension ext, llvm::StringRef target,
+                                      SourceLocation srcLoc) {
+  if (allowedExtensions.test(static_cast<unsigned>(ext)))
+    return true;
+
+  emitError("SPIR-V extension '%0' required for %1 but not permitted to use",
+            srcLoc)
+      << getExtensionName(ext) << target;
+  return false;
+}
+
+Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
+  return llvm::StringSwitch<Extension>(name)
+      .Case("SPV_KHR_device_group", Extension::KHR_device_group)
+      .Case("SPV_KHR_multiview", Extension::KHR_multiview)
+      .Case("SPV_KHR_shader_draw_parameters",
+            Extension::KHR_shader_draw_parameters)
+      .Case("SPV_EXT_fragment_fully_covered",
+            Extension::EXT_fragment_fully_covered)
+      .Case("SPV_EXT_shader_stencil_export",
+            Extension::EXT_shader_stencil_export)
+      .Case("SPV_AMD_gpu_shader_half_float",
+            Extension::AMD_gpu_shader_half_float)
+      .Case("SPV_AMD_shader_explicit_vertex_parameter",
+            Extension::AMD_shader_explicit_vertex_parameter)
+      .Case("SPV_GOOGLE_decorate_string", Extension::GOOGLE_decorate_string)
+      .Case("SPV_GOOGLE_hlsl_functionality1",
+            Extension::GOOGLE_hlsl_functionality1)
+      .Default(Extension::Unknown);
+}
+
+const char *FeatureManager::getExtensionName(Extension symbol) {
+  switch (symbol) {
+  case Extension::KHR_device_group:
+    return "SPV_KHR_device_group";
+  case Extension::KHR_multiview:
+    return "SPV_KHR_multiview";
+  case Extension::KHR_shader_draw_parameters:
+    return "SPV_KHR_shader_draw_parameters";
+  case Extension::EXT_fragment_fully_covered:
+    return "SPV_EXT_fragment_fully_covered";
+  case Extension::EXT_shader_stencil_export:
+    return "SPV_EXT_shader_stencil_export";
+  case Extension::AMD_gpu_shader_half_float:
+    return "SPV_AMD_gpu_shader_half_float";
+  case Extension::AMD_shader_explicit_vertex_parameter:
+    return "SPV_AMD_shader_explicit_vertex_parameter";
+  case Extension::GOOGLE_decorate_string:
+    return "SPV_GOOGLE_decorate_string";
+  case Extension::GOOGLE_hlsl_functionality1:
+    return "SPV_GOOGLE_hlsl_functionality1";
+  default:
+    break;
+  }
+  return "<unknown extension>";
+}
+
+std::string FeatureManager::getKnownExtensions(const char *delimiter,
+                                               const char *prefix,
+                                               const char *postfix) {
+  std::ostringstream oss;
+
+  oss << prefix;
+
+  const auto numExtensions = static_cast<uint32_t>(Extension::Unknown);
+  for (uint32_t i = 0; i < numExtensions; ++i) {
+    oss << getExtensionName(static_cast<Extension>(i));
+    if (i + 1 < numExtensions)
+      oss << delimiter;
+  }
+
+  oss << postfix;
+
+  return oss.str();
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 18 - 7
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -18,9 +18,11 @@
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
 
 
-ModuleBuilder::ModuleBuilder(SPIRVContext *C, bool reflect)
-    : theContext(*C), theModule(), allowReflect(reflect), theFunction(nullptr),
-      insertPoint(nullptr), instBuilder(nullptr), glslExtSetId(0) {
+ModuleBuilder::ModuleBuilder(SPIRVContext *C, FeatureManager *features,
+                             bool reflect)
+    : theContext(*C), featureManager(features), allowReflect(reflect),
+      theModule(), theFunction(nullptr), insertPoint(nullptr),
+      instBuilder(nullptr), glslExtSetId(0) {
   instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
   instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
     this->constructSite = std::move(words);
     this->constructSite = std::move(words);
   });
   });
@@ -752,6 +754,13 @@ void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
   theModule.addExecutionMode(std::move(constructSite));
   theModule.addExecutionMode(std::move(constructSite));
 }
 }
 
 
+void ModuleBuilder::addExtension(Extension ext, llvm::StringRef target,
+                                 SourceLocation srcLoc) {
+  assert(featureManager);
+  featureManager->requestExtension(ext, target, srcLoc);
+  theModule.addExtension(featureManager->getExtensionName(ext));
+}
+
 uint32_t ModuleBuilder::getGLSLExtInstSet() {
 uint32_t ModuleBuilder::getGLSLExtInstSet() {
   if (glslExtSetId == 0) {
   if (glslExtSetId == 0) {
     glslExtSetId = theContext.takeNextId();
     glslExtSetId = theContext.takeNextId();
@@ -817,7 +826,8 @@ void ModuleBuilder::decorateInputAttachmentIndex(uint32_t targetId,
 void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
 void ModuleBuilder::decorateCounterBufferId(uint32_t mainBufferId,
                                             uint32_t counterBufferId) {
                                             uint32_t counterBufferId) {
   if (allowReflect) {
   if (allowReflect) {
-    addExtension("SPV_GOOGLE_hlsl_functionality1");
+    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
+                 {});
     theModule.addDecoration(
     theModule.addDecoration(
         Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
         Decoration::getHlslCounterBufferGOOGLE(theContext, counterBufferId),
         mainBufferId);
         mainBufferId);
@@ -828,8 +838,9 @@ void ModuleBuilder::decorateHlslSemantic(uint32_t targetId,
                                          llvm::StringRef semantic,
                                          llvm::StringRef semantic,
                                          llvm::Optional<uint32_t> memberIdx) {
                                          llvm::Optional<uint32_t> memberIdx) {
   if (allowReflect) {
   if (allowReflect) {
-    addExtension("SPV_GOOGLE_decorate_string");
-    addExtension("SPV_GOOGLE_hlsl_functionality1");
+    addExtension(Extension::GOOGLE_decorate_string, "SPIR-V reflection", {});
+    addExtension(Extension::GOOGLE_hlsl_functionality1, "SPIR-V reflection",
+                 {});
     theModule.addDecoration(
     theModule.addDecoration(
         Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
         Decoration::getHlslSemanticGOOGLE(theContext, semantic, memberIdx),
         targetId);
         targetId);
@@ -902,7 +913,7 @@ IMPL_GET_PRIMITIVE_TYPE(Float32)
                                                                                \
                                                                                \
   uint32_t ModuleBuilder::get##ty##Type() {                                    \
   uint32_t ModuleBuilder::get##ty##Type() {                                    \
     if (spv::Capability::cap == spv::Capability::Float16)                      \
     if (spv::Capability::cap == spv::Capability::Float16)                      \
-      theModule.addExtension("SPV_AMD_gpu_shader_half_float");                 \
+      addExtension(Extension::AMD_gpu_shader_half_float, "16-bit float", {});  \
     else                                                                       \
     else                                                                       \
       requireCapability(spv::Capability::cap);                                 \
       requireCapability(spv::Capability::cap);                                 \
     const Type *type = Type::get##ty(theContext);                              \
     const Type *type = Type::get##ty(theContext);                              \

+ 14 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -519,17 +519,29 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
       shaderModel(*hlsl::ShaderModel::GetByName(
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
           ci.getCodeGenOpts().HLSLProfile.c_str())),
-      theContext(), theBuilder(&theContext, options.enableReflect),
-      declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
+      theContext(), featureManager(diags),
+      theBuilder(&theContext, &featureManager, options.enableReflect),
       typeTranslator(astContext, theBuilder, diags, options),
       typeTranslator(astContext, theBuilder, diags, options),
+      declIdMapper(shaderModel, astContext, theBuilder, typeTranslator,
+                   featureManager, spirvOptions),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
       needsSpirv1p3(false) {
       needsSpirv1p3(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
+
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
       !shaderModel.IsGS())
       !shaderModel.IsGS())
     emitError("-fvk-invert-y can only be used in VS/DS/GS", {});
     emitError("-fvk-invert-y can only be used in VS/DS/GS", {});
+
+  if (options.allowedExtensions.empty()) {
+    // If no explicit extension control from command line, use the default mode:
+    // allowing all extensions.
+    featureManager.allowAllKnownExtensions();
+  } else {
+    for (auto ext : options.allowedExtensions)
+      featureManager.allowExtension(ext);
+  }
 }
 }
 
 
 void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
 void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {

+ 3 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -28,6 +28,7 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/EmitSPIRVOptions.h"
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SetVector.h"
@@ -911,9 +912,10 @@ private:
   const hlsl::ShaderModel &shaderModel;
   const hlsl::ShaderModel &shaderModel;
 
 
   SPIRVContext theContext;
   SPIRVContext theContext;
+  FeatureManager featureManager;
   ModuleBuilder theBuilder;
   ModuleBuilder theBuilder;
-  DeclResultIdMapper declIdMapper;
   TypeTranslator typeTranslator;
   TypeTranslator typeTranslator;
+  DeclResultIdMapper declIdMapper;
 
 
   /// A queue of decls reachable from the entry function. Decls inserted into
   /// A queue of decls reachable from the entry function. Decls inserted into
   /// this queue will persist to avoid duplicated translations. And we'd like
   /// this queue will persist to avoid duplicated translations. And we'd like

+ 10 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.allow.hlsl

@@ -0,0 +1,10 @@
+// Run: %dxc -T vs_6_1 -E main -fspv-extension=SPV_KHR_multiview -fspv-extension=SPV_KHR_shader_draw_parameters
+
+// CHECK:      OpExtension "SPV_KHR_shader_draw_parameters"
+// CHECK:      OpExtension "SPV_KHR_multiview"
+
+float4 main(
+    [[vk::builtin("BaseVertex")]]    int baseVertex : A,
+                                     uint viewid    : SV_ViewID) : B {
+    return baseVertex + viewid;
+}

+ 9 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.forbid.hlsl

@@ -0,0 +1,9 @@
+// Run: %dxc -T vs_6_1 -E main -fspv-extension=SPV_KHR_shader_draw_parameters
+
+float4 main(
+    [[vk::builtin("BaseVertex")]]    int baseVertex : A,
+                                     uint viewid    : SV_ViewID) : B {
+    return baseVertex + viewid;
+}
+
+// CHECK: :5:55: error: SPIR-V extension 'SPV_KHR_multiview' required for SV_ViewID but not permitted to use

+ 8 - 0
tools/clang/test/CodeGenSPIRV/spirv.ext.cl.unknown.hlsl

@@ -0,0 +1,8 @@
+// Run: %dxc -T ps_6_1 -E main -fspv-extension=MyExtension
+
+float4 main(uint viewid: SV_ViewID) : SV_Target {
+    return viewid;
+}
+
+// CHECK: error: unknown SPIR-V extension 'MyExtension'
+// CHECK: note: known extensions are

+ 1 - 0
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -478,6 +478,7 @@ public:
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.uShift = opts.VkUShift;
           spirvOpts.uShift = opts.VkUShift;
+          spirvOpts.allowedExtensions = opts.SpvExtensions;
           spirvOpts.enable16BitTypes = opts.Enable16BitTypes;
           spirvOpts.enable16BitTypes = opts.Enable16BitTypes;
           clang::EmitSPIRVAction action(spirvOpts);
           clang::EmitSPIRVAction action(spirvOpts);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1168,6 +1168,16 @@ TEST_F(FileTest, SpirvBuiltInDeviceIndexInvalidUsage) {
   runFileTest("spirv.builtin.device-index.invalid.hlsl", Expect::Failure);
   runFileTest("spirv.builtin.device-index.invalid.hlsl", Expect::Failure);
 }
 }
 
 
+TEST_F(FileTest, SpirvExtensionCLAllow) {
+  runFileTest("spirv.ext.cl.allow.hlsl");
+}
+TEST_F(FileTest, SpirvExtensionCLForbid) {
+  runFileTest("spirv.ext.cl.forbid.hlsl", Expect::Failure);
+}
+TEST_F(FileTest, SpirvExtensionCLUnknown) {
+  runFileTest("spirv.ext.cl.unknown.hlsl", Expect::Failure);
+}
+
 // For shader stage input/output interface
 // For shader stage input/output interface
 // For semantic SV_Position, SV_ClipDistance, SV_CullDistance
 // For semantic SV_Position, SV_ClipDistance, SV_CullDistance
 TEST_F(FileTest, SpirvStageIOInterfaceVS) {
 TEST_F(FileTest, SpirvStageIOInterfaceVS) {

+ 3 - 3
tools/clang/unittests/SPIRV/ModuleBuilderTest.cpp

@@ -21,7 +21,7 @@ using ::testing::ElementsAre;
 
 
 TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
 TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
   SPIRVContext context;
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
 
   EXPECT_THAT(builder.takeModule(),
   EXPECT_THAT(builder.takeModule(),
               ElementsAre(spv::MagicNumber, 0x00010000, 14u << 16, 1u, 0u));
               ElementsAre(spv::MagicNumber, 0x00010000, 14u << 16, 1u, 0u));
@@ -29,7 +29,7 @@ TEST(ModuleBuilder, TakeModuleDirectlyCreatesHeader) {
 
 
 TEST(ModuleBuilder, CreateFunction) {
 TEST(ModuleBuilder, CreateFunction) {
   SPIRVContext context;
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
 
   const auto rType = context.takeNextId();
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();
   const auto fType = context.takeNextId();
@@ -47,7 +47,7 @@ TEST(ModuleBuilder, CreateFunction) {
 
 
 TEST(ModuleBuilder, CreateBasicBlock) {
 TEST(ModuleBuilder, CreateBasicBlock) {
   SPIRVContext context;
   SPIRVContext context;
-  ModuleBuilder builder(&context, false);
+  ModuleBuilder builder(&context, nullptr, false);
 
 
   const auto rType = context.takeNextId();
   const auto rType = context.takeNextId();
   const auto fType = context.takeNextId();
   const auto fType = context.takeNextId();