Browse Source

[spirv] Add SPIR-V capability visitor.

Collect all necessary capabilities in a visitor pass. Remove all the
capabilities added in different places in the code.
Ehsan 6 years ago
parent
commit
9c820e030a

+ 8 - 2
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -657,12 +657,18 @@ private:
   SpirvConstantBoolean *boolFalseConstant;
   SpirvConstantBoolean *boolTrueSpecConstant;
   SpirvConstantBoolean *boolFalseSpecConstant;
+
+  llvm::SetVector<spv::Capability> existingCapabilities;
 };
 
 void SpirvBuilder::requireCapability(spv::Capability cap, SourceLocation loc) {
   if (cap != spv::Capability::Max) {
-    auto *capability = new (context) SpirvCapability(loc, cap);
-    module->addCapability(capability);
+    // No need to create a new capability nor add it to the module if it has
+    // already been added.
+    if (existingCapabilities.insert(cap)) {
+      auto *capability = new (context) SpirvCapability(loc, cap);
+      module->addCapability(capability);
+    }
   }
 }
 

+ 3 - 1
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -1384,6 +1384,7 @@ public:
   SpirvInstruction *getCoordinate() const { return coordinate; }
   spv::ImageOperandsMask getImageOperandsMask() const { return operandsMask; }
 
+  bool isSparse() const;
   bool hasDref() const { return dref != nullptr; }
   bool hasBias() const { return bias != nullptr; }
   bool hasLod() const { return lod != nullptr; }
@@ -1540,7 +1541,8 @@ private:
 
 /// \brief OpSampledImage instruction
 /// Result Type must be the OpTypeSampledImage type whose Image Type operand is
-/// the type of Image. We store the QualType for the underlying image as result type.
+/// the type of Image. We store the QualType for the underlying image as result
+/// type.
 class SpirvSampledImage : public SpirvInstruction {
 public:
   SpirvSampledImage(QualType resultType, uint32_t resultId, SourceLocation loc,

+ 10 - 0
tools/clang/include/clang/SPIRV/SpirvType.h

@@ -56,6 +56,16 @@ public:
 
   Kind getKind() const { return kind; }
 
+  static bool isTexture(const SpirvType *);
+  static bool isRWTexture(const SpirvType *);
+  static bool isSampler(const SpirvType *);
+  static bool isBuffer(const SpirvType *);
+  static bool isRWBuffer(const SpirvType *);
+  static bool isSubpassInput(const SpirvType *);
+  static bool isSubpassInputMS(const SpirvType *);
+  static bool isResourceType(const SpirvType *);
+  static bool isOrContains16BitType(const SpirvType *);
+
 protected:
   SpirvType(Kind k) : kind(k) {}
 

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -5,6 +5,7 @@ set(LLVM_LINK_COMPONENTS
 add_clang_library(clangSPIRV
   AstTypeProbe.cpp
   BlockReadableOrder.cpp
+  CapabilityVisitor.cpp
   Constant.cpp
   DeclResultIdMapper.cpp
   Decoration.cpp

+ 467 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -0,0 +1,467 @@
+//===--- CapabilityVisitor.cpp - Capability Visitor --------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CapabilityVisitor.h"
+#include "clang/SPIRV/SpirvBuilder.h"
+
+namespace clang {
+namespace spirv {
+
+void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
+                                             SourceLocation loc,
+                                             spv::StorageClass sc) {
+  // Defent against instructions that do not have a return type.
+  if (!type)
+    return;
+
+  // Integer-related capabilities
+  if (const auto *intType = dyn_cast<IntegerType>(type)) {
+    switch (intType->getBitwidth()) {
+    case 16: {
+      // Usage of a 16-bit integer type.
+      spvBuilder.requireCapability(spv::Capability::Int16);
+
+      // Usage of a 16-bit integer type as stage I/O.
+      if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
+        spvBuilder.addExtension(Extension::KHR_16bit_storage,
+                                "16-bit stage IO variables", loc);
+        spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
+      }
+      break;
+    }
+    case 64: {
+      spvBuilder.requireCapability(spv::Capability::Int64);
+      break;
+    }
+    default:
+      break;
+    }
+  }
+  // Float-related capabilities
+  else if (const auto *floatType = dyn_cast<FloatType>(type)) {
+    switch (floatType->getBitwidth()) {
+    case 16: {
+      // Usage of a 16-bit float type.
+      spvBuilder.requireCapability(spv::Capability::Float16);
+      spvBuilder.addExtension(Extension::AMD_gpu_shader_half_float,
+                              "16-bit float", {});
+
+      // Usage of a 16-bit float type as stage I/O.
+      if (sc == spv::StorageClass::Input || sc == spv::StorageClass::Output) {
+        spvBuilder.addExtension(Extension::KHR_16bit_storage,
+                                "16-bit stage IO variables", loc);
+        spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
+      }
+      break;
+    }
+    case 64: {
+      spvBuilder.requireCapability(spv::Capability::Float64);
+      break;
+    }
+    default:
+      break;
+    }
+  }
+  // Vectors
+  else if (const auto *vecType = dyn_cast<VectorType>(type)) {
+    addCapabilityForType(vecType->getElementType());
+  }
+  // Matrices
+  else if (const auto *matType = dyn_cast<MatrixType>(type)) {
+    addCapabilityForType(matType->getElementType());
+  }
+  // Arrays
+  else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
+    addCapabilityForType(arrType->getElementType());
+  }
+  // Runtime array of resources requires additional capability.
+  else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
+    if (SpirvType::isResourceType(raType->getElementType())) {
+      // the elements inside the runtime array are resources
+      spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
+                              "runtime array of resources", {});
+      spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
+    }
+    addCapabilityForType(raType->getElementType());
+  }
+  // Image types
+  else if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    switch (imageType->getDimension()) {
+    case spv::Dim::Buffer: {
+      spvBuilder.requireCapability(spv::Capability::SampledBuffer);
+      if (imageType->withSampler() == ImageType::WithSampler::No) {
+        spvBuilder.requireCapability(spv::Capability::ImageBuffer);
+      }
+      break;
+    }
+    case spv::Dim::Dim1D: {
+      if (imageType->withSampler() == ImageType::WithSampler::No) {
+        spvBuilder.requireCapability(spv::Capability::Image1D);
+      } else {
+        spvBuilder.requireCapability(spv::Capability::Sampled1D);
+      }
+      break;
+    }
+    case spv::Dim::SubpassData: {
+      spvBuilder.requireCapability(spv::Capability::InputAttachment);
+      break;
+    }
+    default:
+      break;
+    }
+
+    switch (imageType->getImageFormat()) {
+    case spv::ImageFormat::Rg32f:
+    case spv::ImageFormat::Rg16f:
+    case spv::ImageFormat::R11fG11fB10f:
+    case spv::ImageFormat::R16f:
+    case spv::ImageFormat::Rgba16:
+    case spv::ImageFormat::Rgb10A2:
+    case spv::ImageFormat::Rg16:
+    case spv::ImageFormat::Rg8:
+    case spv::ImageFormat::R16:
+    case spv::ImageFormat::R8:
+    case spv::ImageFormat::Rgba16Snorm:
+    case spv::ImageFormat::Rg16Snorm:
+    case spv::ImageFormat::Rg8Snorm:
+    case spv::ImageFormat::R16Snorm:
+    case spv::ImageFormat::R8Snorm:
+    case spv::ImageFormat::Rg32i:
+    case spv::ImageFormat::Rg16i:
+    case spv::ImageFormat::Rg8i:
+    case spv::ImageFormat::R16i:
+    case spv::ImageFormat::R8i:
+    case spv::ImageFormat::Rgb10a2ui:
+    case spv::ImageFormat::Rg32ui:
+    case spv::ImageFormat::Rg16ui:
+    case spv::ImageFormat::Rg8ui:
+    case spv::ImageFormat::R16ui:
+    case spv::ImageFormat::R8ui:
+      spvBuilder.requireCapability(
+          spv::Capability::StorageImageExtendedFormats);
+      break;
+    default:
+      // Only image formats requiring extended formats are relevant. The rest
+      // just pass through.
+      break;
+    }
+
+    if (imageType->isArrayedImage() && imageType->isMSImage())
+      spvBuilder.requireCapability(spv::Capability::ImageMSArray);
+
+    addCapabilityForType(imageType->getSampledType());
+  }
+  // Sampled image type
+  else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
+    addCapabilityForType(sampledImageType->getImageType());
+  }
+  // Pointer type
+  else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
+    addCapabilityForType(ptrType->getPointeeType());
+  }
+  // Struct type
+  else if (const auto *structType = dyn_cast<StructType>(type)) {
+    if (SpirvType::isOrContains16BitType(structType)) {
+      spvBuilder.addExtension(Extension::KHR_16bit_storage,
+                              "16-bit types in resource", loc);
+      if (sc == spv::StorageClass::PushConstant) {
+        spvBuilder.requireCapability(spv::Capability::StoragePushConstant16);
+      } else if (structType->getInterfaceType() ==
+                 StructInterfaceType::UniformBuffer) {
+        spvBuilder.requireCapability(spv::Capability::StorageUniform16);
+      } else if (structType->getInterfaceType() ==
+                 StructInterfaceType::StorageBuffer) {
+        spvBuilder.requireCapability(
+            spv::Capability::StorageUniformBufferBlock16);
+      }
+    }
+    for (auto field : structType->getFields())
+      addCapabilityForType(field.type, loc, sc);
+  }
+}
+
+bool CapabilityVisitor::visit(SpirvDecoration *decor) {
+  const auto loc = decor->getSourceLocation();
+  switch (decor->getDecoration()) {
+  case spv::Decoration::Sample: {
+    spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
+    break;
+  }
+  case spv::Decoration::NonUniformEXT: {
+    spvBuilder.addExtension(Extension::EXT_descriptor_indexing, "NonUniformEXT",
+                            loc);
+    spvBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);
+
+    break;
+  }
+  // Capabilities needed for built-ins
+  case spv::Decoration::BuiltIn: {
+    assert(decor->getParams().size() == 1);
+    const auto builtin = static_cast<spv::BuiltIn>(decor->getParams()[0]);
+    switch (builtin) {
+    case spv::BuiltIn::SampleId:
+    case spv::BuiltIn::SamplePosition: {
+      spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
+      break;
+    }
+    case spv::BuiltIn::SubgroupSize:
+    case spv::BuiltIn::NumSubgroups:
+    case spv::BuiltIn::SubgroupId:
+    case spv::BuiltIn::SubgroupLocalInvocationId: {
+      spvBuilder.requireCapability(spv::Capability::GroupNonUniform, loc);
+      break;
+    }
+    case spv::BuiltIn::BaseVertex: {
+      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
+                              "BaseVertex Builtin", loc);
+      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      break;
+    }
+    case spv::BuiltIn::BaseInstance: {
+      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
+                              "BaseInstance Builtin", loc);
+      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      break;
+    }
+    case spv::BuiltIn::DrawIndex: {
+      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
+                              "DrawIndex Builtin", loc);
+      spvBuilder.requireCapability(spv::Capability::DrawParameters);
+      break;
+    }
+    case spv::BuiltIn::DeviceIndex: {
+      spvBuilder.addExtension(Extension::KHR_device_group,
+                              "DeviceIndex Builtin", loc);
+      spvBuilder.requireCapability(spv::Capability::DeviceGroup);
+      break;
+    }
+    case spv::BuiltIn::FragStencilRefEXT: {
+      spvBuilder.addExtension(Extension::EXT_shader_stencil_export,
+                              "SV_StencilRef", loc);
+      spvBuilder.requireCapability(spv::Capability::StencilExportEXT);
+      break;
+    }
+    case spv::BuiltIn::ViewIndex: {
+      spvBuilder.addExtension(Extension::KHR_multiview, "SV_ViewID", loc);
+      spvBuilder.requireCapability(spv::Capability::MultiView);
+      break;
+    }
+    case spv::BuiltIn::FullyCoveredEXT: {
+      spvBuilder.addExtension(Extension::EXT_fragment_fully_covered,
+                              "SV_InnerCoverage", loc);
+      spvBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
+      break;
+    }
+    case spv::BuiltIn::PrimitiveId: {
+      // PrimitiveID can be used as PSIn
+      if (shaderModel == spv::ExecutionModel::Fragment)
+        spvBuilder.requireCapability(spv::Capability::Geometry);
+      break;
+    }
+    case spv::BuiltIn::Layer: {
+      if (shaderModel == spv::ExecutionModel::Vertex ||
+          shaderModel == spv::ExecutionModel::TessellationControl ||
+          shaderModel == spv::ExecutionModel::TessellationEvaluation) {
+        spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
+                                "SV_RenderTargetArrayIndex", loc);
+        spvBuilder.requireCapability(
+            spv::Capability::ShaderViewportIndexLayerEXT);
+      } else if (shaderModel == spv::ExecutionModel::Fragment) {
+        // SV_RenderTargetArrayIndex can be used as PSIn.
+        spvBuilder.requireCapability(spv::Capability::Geometry);
+      }
+      break;
+    }
+    case spv::BuiltIn::ViewportIndex: {
+      if (shaderModel == spv::ExecutionModel::Vertex ||
+          shaderModel == spv::ExecutionModel::TessellationControl ||
+          shaderModel == spv::ExecutionModel::TessellationEvaluation) {
+        spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
+                                "SV_ViewPortArrayIndex", loc);
+        spvBuilder.requireCapability(
+            spv::Capability::ShaderViewportIndexLayerEXT);
+      } else if (shaderModel == spv::ExecutionModel::Fragment) {
+        // SV_ViewportArrayIndex can be used as PSIn.
+        spvBuilder.requireCapability(spv::Capability::MultiViewport);
+      }
+      break;
+    }
+    case spv::BuiltIn::ClipDistance: {
+      spvBuilder.requireCapability(spv::Capability::ClipDistance);
+      break;
+    }
+    case spv::BuiltIn::CullDistance: {
+      spvBuilder.requireCapability(spv::Capability::CullDistance);
+      break;
+    }
+    default:
+      break;
+    }
+
+    break;
+  }
+  default:
+    break;
+  }
+
+  return true;
+}
+
+spv::Capability
+CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
+  if (!type)
+    return spv::Capability::Max;
+
+  if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
+    return getNonUniformCapability(arrayType->getElementType());
+  }
+  if (SpirvType::isTexture(type) || SpirvType::isSampler(type)) {
+    return spv::Capability::SampledImageArrayNonUniformIndexingEXT;
+  }
+  if (SpirvType::isRWTexture(type)) {
+    return spv::Capability::StorageImageArrayNonUniformIndexingEXT;
+  }
+  if (SpirvType::isBuffer(type)) {
+    return spv::Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
+  }
+  if (SpirvType::isRWBuffer(type)) {
+    return spv::Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
+  }
+  if (SpirvType::isSubpassInput(type) || SpirvType::isSubpassInputMS(type)) {
+    return spv::Capability::InputAttachmentArrayNonUniformIndexingEXT;
+  }
+
+  return spv::Capability::Max;
+}
+
+bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
+  addCapabilityForType(instr->getResultType());
+  spvBuilder.requireCapability(spv::Capability::ImageQuery);
+  return true;
+}
+
+bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
+  addCapabilityForType(instr->getResultType());
+  spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
+  return true;
+}
+
+bool CapabilityVisitor::visit(SpirvImageOp *instr) {
+  addCapabilityForType(instr->getResultType());
+  if (instr->hasOffset() || instr->hasConstOffsets())
+    spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
+  if (instr->hasMinLod())
+    spvBuilder.requireCapability(spv::Capability::MinLod);
+  if (instr->isSparse())
+    spvBuilder.requireCapability(spv::Capability::SparseResidency);
+
+  return true;
+}
+
+bool CapabilityVisitor::visit(SpirvVariable *var) {
+  addCapabilityForType(var->getResultType(), var->getSourceLocation(),
+                       var->getStorageClass());
+  return true;
+}
+
+bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
+  const SpirvType *resultType = instr->getResultType();
+  const auto opcode = instr->getopcode();
+
+  // Add result-type-specific capabilities
+  addCapabilityForType(resultType, instr->getSourceLocation());
+
+  // Add NonUniform capabilities if necessary
+  if (instr->isNonUniform()) {
+    spvBuilder.requireCapability(getNonUniformCapability(resultType));
+  }
+
+  // Add opcode-specific capabilities
+  switch (opcode) {
+  case spv::Op::OpDPdxCoarse:
+  case spv::Op::OpDPdyCoarse:
+  case spv::Op::OpFwidthCoarse:
+  case spv::Op::OpDPdxFine:
+  case spv::Op::OpDPdyFine:
+  case spv::Op::OpFwidthFine:
+    spvBuilder.requireCapability(spv::Capability::DerivativeControl);
+    break;
+  case spv::Op::OpGroupNonUniformElect:
+    spvBuilder.requireCapability(spv::Capability::GroupNonUniform);
+  case spv::Op::OpGroupNonUniformAny:
+  case spv::Op::OpGroupNonUniformAll:
+  case spv::Op::OpGroupNonUniformAllEqual:
+    spvBuilder.requireCapability(spv::Capability::GroupNonUniformVote);
+  case spv::Op::OpGroupNonUniformBallot:
+  case spv::Op::OpGroupNonUniformInverseBallot:
+  case spv::Op::OpGroupNonUniformBallotBitExtract:
+  case spv::Op::OpGroupNonUniformBallotBitCount:
+  case spv::Op::OpGroupNonUniformBallotFindLSB:
+  case spv::Op::OpGroupNonUniformBallotFindMSB:
+  case spv::Op::OpGroupNonUniformBroadcast:
+  case spv::Op::OpGroupNonUniformBroadcastFirst:
+    spvBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
+  case spv::Op::OpGroupNonUniformIAdd:
+  case spv::Op::OpGroupNonUniformFAdd:
+  case spv::Op::OpGroupNonUniformIMul:
+  case spv::Op::OpGroupNonUniformFMul:
+  case spv::Op::OpGroupNonUniformSMax:
+  case spv::Op::OpGroupNonUniformUMax:
+  case spv::Op::OpGroupNonUniformFMax:
+  case spv::Op::OpGroupNonUniformSMin:
+  case spv::Op::OpGroupNonUniformUMin:
+  case spv::Op::OpGroupNonUniformFMin:
+  case spv::Op::OpGroupNonUniformBitwiseAnd:
+  case spv::Op::OpGroupNonUniformBitwiseOr:
+  case spv::Op::OpGroupNonUniformBitwiseXor:
+  case spv::Op::OpGroupNonUniformLogicalAnd:
+  case spv::Op::OpGroupNonUniformLogicalOr:
+  case spv::Op::OpGroupNonUniformLogicalXor:
+    spvBuilder.requireCapability(spv::Capability::GroupNonUniformArithmetic);
+  case spv::Op::OpGroupNonUniformQuadBroadcast:
+  case spv::Op::OpGroupNonUniformQuadSwap:
+    spvBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
+  }
+
+  return true;
+}
+
+bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
+  shaderModel = entryPoint->getExecModel();
+  switch (shaderModel) {
+  case spv::ExecutionModel::Fragment:
+  case spv::ExecutionModel::Vertex:
+  case spv::ExecutionModel::GLCompute:
+    spvBuilder.requireCapability(spv::Capability::Shader);
+    break;
+  case spv::ExecutionModel::Geometry:
+    spvBuilder.requireCapability(spv::Capability::Geometry);
+    break;
+  case spv::ExecutionModel::TessellationControl:
+  case spv::ExecutionModel::TessellationEvaluation:
+    spvBuilder.requireCapability(spv::Capability::Tessellation);
+    break;
+  default:
+    llvm_unreachable("found unknown shader model");
+    break;
+  }
+  return true;
+}
+
+bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
+  if (execMode->getExecutionMode() == spv::ExecutionMode::PostDepthCoverage) {
+    spvBuilder.requireCapability(
+        spv::Capability::SampleMaskPostDepthCoverage,
+        execMode->getEntryPoint()->getSourceLocation());
+  }
+  return true;
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 65 - 0
tools/clang/lib/SPIRV/CapabilityVisitor.h

@@ -0,0 +1,65 @@
+//===--- CapabilityVisitor.h - Capability Visitor ----------------*- C++ -*-==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_SPIRV_CAPABILITYVISITOR_H
+#define LLVM_CLANG_LIB_SPIRV_CAPABILITYVISITOR_H
+
+#include "clang/SPIRV/SPIRVContext.h"
+#include "clang/SPIRV/SpirvVisitor.h"
+#include "llvm/ADT/Optional.h"
+
+namespace clang {
+namespace spirv {
+
+class SpirvBuilder;
+
+class CapabilityVisitor : public Visitor {
+public:
+  CapabilityVisitor(SpirvContext &spvCtx, const SpirvCodeGenOptions &opts,
+                    SpirvBuilder &builder)
+      : Visitor(opts, spvCtx), spvContext(spvCtx), spvBuilder(builder) {}
+
+
+  bool visit(SpirvDecoration *decor);
+  bool visit(SpirvEntryPoint *);
+  bool visit(SpirvExecutionMode *);
+  bool visit(SpirvImageQuery *);
+  bool visit(SpirvImageOp *);
+  bool visit(SpirvImageSparseTexelsResident *);
+  bool visit(SpirvVariable *);
+
+  /// The "sink" visit function for all instructions.
+  ///
+  /// By default, all other visit instructions redirect to this visit function.
+  /// So that you want override this visit function to handle all instructions,
+  /// regardless of their polymorphism.
+  bool visitInstruction(SpirvInstruction *instr);
+
+private:
+  /// Adds necessary capabilities for using the given type.
+  /// The called may also provide the storage class for variable types, because
+  /// in the case of variable types, the storage class may affect the capability
+  /// that is used.
+  void addCapabilityForType(const SpirvType *, SourceLocation loc = {},
+                            spv::StorageClass sc = spv::StorageClass::Max);
+
+  /// Returns the capability required to non-uniformly index into the given
+  /// type.
+  spv::Capability getNonUniformCapability(const SpirvType *);
+
+private:
+  SpirvContext &spvContext;        /// SPIR-V context
+  SpirvBuilder &spvBuilder;        /// SPIR-V builder
+  spv::ExecutionModel shaderModel; /// Execution model
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif // LLVM_CLANG_LIB_SPIRV_CAPABILITYVISITOR_H

+ 0 - 88
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -632,14 +632,6 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
     return cast<SpirvVariable>(astDecls[var].instr);
   }
 
-  if (storageClass == spv::StorageClass::Uniform &&
-      spirvOptions.enable16BitTypes &&
-      isOrContains16BitType(var->getType(), spirvOptions.enable16BitTypes)) {
-    spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                            "16-bit types in resource", var->getLocation());
-    spvBuilder.requireCapability(spv::Capability::StorageUniformBufferBlock16);
-  }
-
   const auto type = var->getType();
   const auto loc = var->getLocation();
   SpirvVariable *varInstr = spvBuilder.addModuleVar(
@@ -704,18 +696,6 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     varType.removeLocalConst();
     HybridStructType::FieldInfo info(varType, declDecl->getName());
     fields.push_back(info);
-
-    if (spirvOptions.enable16BitTypes &&
-        isOrContains16BitType(varType, spirvOptions.enable16BitTypes)) {
-      spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                              "16-bit types in resource",
-                              declDecl->getLocation());
-      spvBuilder.requireCapability(
-          (forCBuffer || forGlobals)
-              ? spv::Capability::StorageUniform16
-              : forPC ? spv::Capability::StoragePushConstant16
-                      : spv::Capability::StorageUniformBufferBlock16);
-    }
   }
 
   // Get the type for the whole struct
@@ -729,10 +709,6 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   if (arraySize > 0) {
     resultType = spvContext.getArrayType(resultType, arraySize);
   } else if (arraySize == -1) {
-    // Runtime arrays of cbuffer/tbuffer needs additional capability.
-    spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
-                            "runtime array of resources", {});
-    spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
     resultType = spvContext.getRuntimeArrayType(resultType);
   }
 
@@ -1670,14 +1646,6 @@ bool DeclResultIdMapper::createStageVars(
     // Mark that we have used one index for this semantic
     ++semanticToUse->index;
 
-    // Require extension and capability if using 16-bit types
-    if (getElementSpirvBitwidth(astContext, type,
-                                spirvOptions.enable16BitTypes) == 16) {
-      spvBuilder.addExtension(Extension::KHR_16bit_storage,
-                              "16-bit stage IO variables", decl->getLocation());
-      spvBuilder.requireCapability(spv::Capability::StorageInputOutput16);
-    }
-
     // TODO: the following may not be correct?
     if (sigPoint->GetSignatureKind() ==
         hlsl::DXIL::SignatureKind::PatchConstant)
@@ -2168,7 +2136,6 @@ void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
     if (decl->getAttr<HLSLNoPerspectiveAttr>())
       spvBuilder.decorateNoPerspective(varInstr, loc);
     if (decl->getAttr<HLSLSampleAttr>()) {
-      spvBuilder.requireCapability(spv::Capability::SampleRateShading, loc);
       spvBuilder.decorateSample(varInstr, loc);
     }
   }
@@ -2191,8 +2158,6 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn) {
     return nullptr;
   }
 
-  spvBuilder.requireCapability(spv::Capability::GroupNonUniform);
-
   // Create a dummy StageVar for this builtin variable
   auto var = spvBuilder.addStageBuiltinVar(spvContext.getUIntType(32),
                                            spv::StorageClass::Input, builtIn);
@@ -2253,26 +2218,6 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
             .Default(BuiltIn::Max);
 
     assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
-
-    switch (spvBuiltIn) {
-    case BuiltIn::BaseVertex:
-    case BuiltIn::BaseInstance:
-    case BuiltIn::DrawIndex:
-      spvBuilder.addExtension(Extension::KHR_shader_draw_parameters,
-                              builtinAttr->getBuiltIn(),
-                              builtinAttr->getLocation());
-      spvBuilder.requireCapability(spv::Capability::DrawParameters);
-      break;
-    case BuiltIn::DeviceIndex:
-      spvBuilder.addExtension(Extension::KHR_device_group,
-                              stageVar->getSemanticStr(), srcLoc);
-      spvBuilder.requireCapability(spv::Capability::DeviceGroup);
-      break;
-    default:
-      // Just seeking builtins requiring extensions. The rest can be ignored.
-      break;
-    }
-
     return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn);
   }
 
@@ -2440,11 +2385,6 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
   // HS/DS/PS In, GS In/Out.
   case hlsl::Semantic::Kind::PrimitiveID: {
-    // PrimitiveId requires either Tessellation or Geometry capability.
-    // Need to require one for PSIn.
-    if (sigPointKind == hlsl::SigPoint::Kind::PSIn)
-      spvBuilder.requireCapability(spv::Capability::Geometry);
-
     // Translate to PrimitiveId BuiltIn for all valid SigPoints.
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId);
@@ -2481,17 +2421,11 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
   // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
   case hlsl::Semantic::Kind::SampleIndex: {
-    spvBuilder.requireCapability(spv::Capability::SampleRateShading);
-
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId);
   }
   // According to DXIL spec, the StencilRef SV can only be used by PSOut.
   case hlsl::Semantic::Kind::StencilRef: {
-    spvBuilder.addExtension(Extension::EXT_shader_stencil_export,
-                            stageVar->getSemanticStr(), srcLoc);
-    spvBuilder.requireCapability(spv::Capability::StencilExportEXT);
-
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT);
   }
@@ -2539,17 +2473,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
       return spvBuilder.addStageIOVar(type, sc, name.str());
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::DSOut:
-      spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
-                              "SV_RenderTargetArrayIndex", srcLoc);
-      spvBuilder.requireCapability(
-          spv::Capability::ShaderViewportIndexLayerEXT);
-
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
-      spvBuilder.requireCapability(spv::Capability::Geometry);
-
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer);
     default:
@@ -2572,17 +2499,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
       return spvBuilder.addStageIOVar(type, sc, name.str());
     case hlsl::SigPoint::Kind::VSOut:
     case hlsl::SigPoint::Kind::DSOut:
-      spvBuilder.addExtension(Extension::EXT_shader_viewport_index_layer,
-                              "SV_ViewPortArrayIndex", srcLoc);
-      spvBuilder.requireCapability(
-          spv::Capability::ShaderViewportIndexLayerEXT);
-
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex);
     case hlsl::SigPoint::Kind::GSOut:
     case hlsl::SigPoint::Kind::PSIn:
-      spvBuilder.requireCapability(spv::Capability::MultiViewport);
-
       stageVar->setIsSpirvBuiltin();
       return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex);
     default:
@@ -2601,10 +2521,6 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
   // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
   // VS/HS/DS/GS/PS input.
   case hlsl::Semantic::Kind::ViewID: {
-    spvBuilder.addExtension(Extension::KHR_multiview,
-                            stageVar->getSemanticStr(), srcLoc);
-    spvBuilder.requireCapability(spv::Capability::MultiView);
-
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex);
   }
@@ -2612,10 +2528,6 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
     // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
     // PSIn.
   case hlsl::Semantic::Kind::InnerCoverage: {
-    spvBuilder.addExtension(Extension::EXT_fragment_fully_covered,
-                            stageVar->getSemanticStr(), srcLoc);
-    spvBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
-
     stageVar->setIsSpirvBuiltin();
     return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT);
   }

+ 0 - 8
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -108,14 +108,6 @@ llvm::SmallVector<SpirvVariable *, 2> GlPerVertex::getStageOutVars() const {
   return vars;
 }
 
-void GlPerVertex::requireCapabilityIfNecessary() {
-  if (!inClipType.empty() || !outClipType.empty())
-    spvBuilder.requireCapability(spv::Capability::ClipDistance);
-
-  if (!inCullType.empty() || !outCullType.empty())
-    spvBuilder.requireCapability(spv::Capability::CullDistance);
-}
-
 bool GlPerVertex::recordGlPerVertexDeclFacts(const DeclaratorDecl *decl,
                                              bool asInput) {
   const QualType type = getTypeOrFnRetType(decl);

+ 0 - 4
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -68,10 +68,6 @@ public:
   /// Returns the stage output variables.
   llvm::SmallVector<SpirvVariable *, 2> getStageOutVars() const;
 
-  /// Requires the ClipDistance/CullDistance capability if we've seen
-  /// definition of SV_ClipDistance/SV_CullDistance.
-  void requireCapabilityIfNecessary();
-
   /// Tries to access the builtin translated from the given HLSL semantic of the
   /// given index.
   ///

+ 4 - 2
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -402,10 +402,12 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
       structName = getAstTypeName(innerType);
 
     const auto *raType = spvContext.getRuntimeArrayType(structType);
+    const bool isReadOnly = (name == "StructuredBuffer");
 
     const std::string typeName = "type." + name.str() + "." + structName;
-    const auto *valType =
-        spvContext.getStructType({StructType::FieldInfo(raType)}, typeName);
+    const auto *valType = spvContext.getStructType(
+        {StructType::FieldInfo(raType)}, typeName, isReadOnly,
+        StructInterfaceType::StorageBuffer);
 
     if (asAlias) {
       // All structured buffers are in the Uniform storage class.

+ 4 - 91
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -501,44 +501,6 @@ void getBaseClassIndices(const CastExpr *expr,
   }
 }
 
-spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) {
-  switch (opcode) {
-  case spv::Op::OpGroupNonUniformElect:
-    return spv::Capability::GroupNonUniform;
-  case spv::Op::OpGroupNonUniformAny:
-  case spv::Op::OpGroupNonUniformAll:
-  case spv::Op::OpGroupNonUniformAllEqual:
-    return spv::Capability::GroupNonUniformVote;
-  case spv::Op::OpGroupNonUniformBallot:
-  case spv::Op::OpGroupNonUniformBallotBitCount:
-  case spv::Op::OpGroupNonUniformBroadcast:
-  case spv::Op::OpGroupNonUniformBroadcastFirst:
-    return spv::Capability::GroupNonUniformBallot;
-  case spv::Op::OpGroupNonUniformIAdd:
-  case spv::Op::OpGroupNonUniformFAdd:
-  case spv::Op::OpGroupNonUniformIMul:
-  case spv::Op::OpGroupNonUniformFMul:
-  case spv::Op::OpGroupNonUniformSMax:
-  case spv::Op::OpGroupNonUniformUMax:
-  case spv::Op::OpGroupNonUniformFMax:
-  case spv::Op::OpGroupNonUniformSMin:
-  case spv::Op::OpGroupNonUniformUMin:
-  case spv::Op::OpGroupNonUniformFMin:
-  case spv::Op::OpGroupNonUniformBitwiseAnd:
-  case spv::Op::OpGroupNonUniformBitwiseOr:
-  case spv::Op::OpGroupNonUniformBitwiseXor:
-    return spv::Capability::GroupNonUniformArithmetic;
-  case spv::Op::OpGroupNonUniformQuadBroadcast:
-  case spv::Op::OpGroupNonUniformQuadSwap:
-    return spv::Capability::GroupNonUniformQuad;
-  default:
-    assert(false && "unhandled opcode");
-    break;
-  }
-  assert(false && "unhandled opcode");
-  return spv::Capability::Max;
-}
-
 std::string getNamespacePrefix(const Decl *decl) {
   std::string nsPrefix = "";
   const DeclContext *dc = decl->getDeclContext();
@@ -691,8 +653,6 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
 
   const spv_target_env targetEnv = featureManager.getTargetEnv();
 
-  AddRequiredCapabilitiesForShaderModel();
-
   // Addressing and memory model are required in a valid SPIR-V module.
   spvBuilder.setMemoryModel(spv::AddressingModel::Logical,
                             spv::MemoryModel::GLSL450);
@@ -1999,9 +1959,6 @@ SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
   auto *info = loadIfAliasVarRef(base);
 
   if (foundNonUniformResourceIndex) {
-    // Add the necessary capability required for indexing into this kind
-    // of resource
-    spvBuilder.requireCapability(getNonUniformCapability(base->getType()));
     info->setNonUniform(); // Carry forward the NonUniformEXT decoration
     foundNonUniformResourceIndex = false;
   }
@@ -6408,13 +6365,6 @@ SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
 
   SpirvInstruction *retVal = nullptr;
 
-#define INTRINSIC_SPIRV_OP_WITH_CAP_CASE(intrinsicOp, spirvOp, doEachVec, cap) \
-  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
-    spvBuilder.requireCapability(cap);                                         \
-    retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,    \
-                                            doEachVec);                        \
-  } break
-
 #define INTRINSIC_SPIRV_OP_CASE(intrinsicOp, spirvOp, doEachVec)               \
   case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
     retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,    \
@@ -6702,15 +6652,11 @@ SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     break;
   }
     INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
-    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_coarse, DPdxCoarse, false,
-                                     spv::Capability::DerivativeControl);
-    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_fine, DPdxFine, false,
-                                     spv::Capability::DerivativeControl);
+    INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
+    INTRINSIC_SPIRV_OP_CASE(ddx_fine, DPdxFine, false);
     INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
-    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_coarse, DPdyCoarse, false,
-                                     spv::Capability::DerivativeControl);
-    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_fine, DPdyFine, false,
-                                     spv::Capability::DerivativeControl);
+    INTRINSIC_SPIRV_OP_CASE(ddy_coarse, DPdyCoarse, false);
+    INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
     INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
     INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
     INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
@@ -6918,10 +6864,6 @@ SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
 SpirvInstruction *
 SPIRVEmitter::processIntrinsicNonUniformResourceIndex(const CallExpr *expr) {
   foundNonUniformResourceIndex = true;
-  spvBuilder.addExtension(Extension::EXT_descriptor_indexing,
-                          "NonUniformResourceIndex", expr->getExprLoc());
-  spvBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);
-
   auto *index = doExpr(expr->getArg(0));
   index->setNonUniform();
 
@@ -7110,7 +7052,6 @@ SpirvInstruction *SPIRVEmitter::processWaveQuery(const CallExpr *callExpr,
   assert(callExpr->getNumArgs() == 0);
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   const QualType retType = callExpr->getCallReturnType(astContext);
   return spvBuilder.createGroupNonUniformElect(opcode, retType,
                                                spv::Scope::Subgroup);
@@ -7125,7 +7066,6 @@ SpirvInstruction *SPIRVEmitter::processWaveVote(const CallExpr *callExpr,
   assert(callExpr->getNumArgs() == 1);
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   auto *predicate = doExpr(callExpr->getArg(0));
   const QualType retType = callExpr->getCallReturnType(astContext);
   return spvBuilder.createGroupNonUniformUnaryOp(
@@ -7211,9 +7151,6 @@ SPIRVEmitter::processWaveCountBits(const CallExpr *callExpr,
 
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(getCapabilityForGroupNonUniform(
-      spv::Op::OpGroupNonUniformBallotBitCount));
-
   auto *predicate = doExpr(callExpr->getArg(0));
   const QualType u32Type = astContext.UnsignedIntTy;
   const QualType v4u32Type = astContext.getExtVectorType(u32Type, 4);
@@ -7244,7 +7181,6 @@ SpirvInstruction *SPIRVEmitter::processWaveReductionOrPrefix(
   assert(callExpr->getNumArgs() == 1);
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
   auto *predicate = doExpr(callExpr->getArg(0));
   const QualType retType = callExpr->getCallReturnType(astContext);
   return spvBuilder.createGroupNonUniformUnaryOp(
@@ -7260,7 +7196,6 @@ SpirvInstruction *SPIRVEmitter::processWaveBroadcast(const CallExpr *callExpr) {
   assert(numArgs == 1 || numArgs == 2);
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
   auto *value = doExpr(callExpr->getArg(0));
   const QualType retType = callExpr->getCallReturnType(astContext);
   if (numArgs == 2)
@@ -7284,7 +7219,6 @@ SPIRVEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
   assert(callExpr->getNumArgs() == 1 || callExpr->getNumArgs() == 2);
   featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                   callExpr->getExprLoc());
-  spvBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
 
   auto *value = doExpr(callExpr->getArg(0));
   const QualType retType = callExpr->getCallReturnType(astContext);
@@ -9154,16 +9088,6 @@ SPIRVEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
   llvm_unreachable("unknown shader model");
 }
 
-void SPIRVEmitter::AddRequiredCapabilitiesForShaderModel() {
-  if (shaderModel.IsHS() || shaderModel.IsDS()) {
-    spvBuilder.requireCapability(spv::Capability::Tessellation);
-  } else if (shaderModel.IsGS()) {
-    spvBuilder.requireCapability(spv::Capability::Geometry);
-  } else {
-    spvBuilder.requireCapability(spv::Capability::Shader);
-  }
-}
-
 bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
                                                    uint32_t *arraySize) {
   bool success = true;
@@ -9271,7 +9195,6 @@ void SPIRVEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
   if (decl->getAttr<VKPostDepthCoverageAttr>()) {
     spvBuilder.addExtension(Extension::KHR_post_depth_coverage,
                             "[[vk::post_depth_coverage]]", decl->getLocation());
-    spvBuilder.requireCapability(spv::Capability::SampleMaskPostDepthCoverage);
     spvBuilder.addExecutionMode(entryFunction,
                                 spv::ExecutionMode::PostDepthCoverage, {},
                                 decl->getLocation());
@@ -9475,16 +9398,6 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
   }
 
-  // Require the ClipDistance/CullDistance capability if necessary.
-  // It is legal to just use the ClipDistance/CullDistance builtin without
-  // requiring the ClipDistance/CullDistance capability, as long as we don't
-  // read or write the builtin variable.
-  // For our CodeGen, that corresponds to not seeing SV_ClipDistance or
-  // SV_CullDistance at all. If we see them, we will generate code to read
-  // them to initialize temporary variable for calling the source code entry
-  // function or write to them after calling the source code entry function.
-  declIdMapper.glPerVertex.requireCapabilityIfNecessary();
-
   // The entry basic block.
   auto *entryLabel = spvBuilder.createBasicBlock();
   spvBuilder.setInsertPoint(entryLabel);

+ 0 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -588,8 +588,6 @@ private:
   static spv::ExecutionModel
   getSpirvShaderStage(const hlsl::ShaderModel &model);
 
-  void AddRequiredCapabilitiesForShaderModel();
-
   /// \brief Adds necessary execution modes for the hull/domain shaders based on
   /// the HLSL attributes of the entry point function.
   /// In the case of hull shaders, also writes the number of output control

+ 9 - 49
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/SPIRV/SpirvBuilder.h"
+#include "CapabilityVisitor.h"
 #include "TypeTranslator.h"
 #include "clang/SPIRV/EmitVisitor.h"
 #include "clang/SPIRV/LowerTypeVisitor.h"
@@ -244,16 +245,6 @@ SpirvUnaryOp *SpirvBuilder::createUnaryOp(spv::Op op, QualType resultType,
   auto *instruction =
       new (context) SpirvUnaryOp(op, resultType, /*id*/ 0, loc, operand);
   insertPoint->addInstruction(instruction);
-  switch (op) {
-  case spv::Op::OpImageQuerySize:
-  case spv::Op::OpImageQueryLevels:
-  case spv::Op::OpImageQuerySamples:
-    requireCapability(spv::Capability::ImageQuery);
-    break;
-  default:
-    // Only checking for ImageQueries, the other Ops can be ignored.
-    break;
-  }
   return instruction;
 }
 
@@ -265,15 +256,6 @@ SpirvBinaryOp *SpirvBuilder::createBinaryOp(spv::Op op, QualType resultType,
   auto *instruction =
       new (context) SpirvBinaryOp(op, resultType, /*id*/ 0, loc, lhs, rhs);
   insertPoint->addInstruction(instruction);
-  switch (op) {
-  case spv::Op::OpImageQueryLod:
-  case spv::Op::OpImageQuerySizeLod:
-    requireCapability(spv::Capability::ImageQuery);
-    break;
-  default:
-    // Only checking for ImageQueries, the other Ops can be ignored.
-    break;
-  }
   return instruction;
 }
 
@@ -287,15 +269,6 @@ SpirvBinaryOp *SpirvBuilder::createBinaryOp(spv::Op op,
       new (context) SpirvBinaryOp(op, /*QualType*/ {}, /*id*/ 0, loc, lhs, rhs);
   instruction->setResultType(resultType);
   insertPoint->addInstruction(instruction);
-  switch (op) {
-  case spv::Op::OpImageQueryLod:
-  case spv::Op::OpImageQuerySizeLod:
-    requireCapability(spv::Capability::ImageQuery);
-    break;
-  default:
-    // Only checking for ImageQueries, the other Ops can be ignored.
-    break;
-  }
   return instruction;
 }
 
@@ -402,17 +375,14 @@ spv::ImageOperandsMask SpirvBuilder::composeImageOperandsMask(
   }
   if (varOffset) {
     mask = mask | ImageOperandsMask::Offset;
-    requireCapability(spv::Capability::ImageGatherExtended);
   }
   if (constOffsets) {
     mask = mask | ImageOperandsMask::ConstOffsets;
-    requireCapability(spv::Capability::ImageGatherExtended);
   }
   if (sample) {
     mask = mask | ImageOperandsMask::Sample;
   }
   if (minLod) {
-    requireCapability(spv::Capability::MinLod);
     mask = mask | ImageOperandsMask::MinLod;
   }
   return mask;
@@ -452,10 +422,6 @@ SpirvInstruction *SpirvBuilder::createImageSample(
   // explicit insturctions. So either lod or minLod or both must be zero.
   assert(lod == nullptr || minLod == nullptr);
 
-  if (isSparse) {
-    requireCapability(spv::Capability::SparseResidency);
-  }
-
   // An OpSampledImage is required to do the image sampling.
   auto *sampledImage =
       new (context) SpirvSampledImage(imageType, /*id*/ 0, loc, image, sampler);
@@ -502,20 +468,12 @@ SpirvInstruction *SpirvBuilder::createImageFetchOrRead(
       varOffset, constOffsets, sample, /*minLod*/ nullptr);
 
   const bool isSparse = (residencyCode != nullptr);
-  if (isSparse) {
-    requireCapability(spv::Capability::SparseResidency);
-  }
 
   spv::Op op =
       doImageFetch
           ? (isSparse ? spv::Op::OpImageSparseFetch : spv::Op::OpImageFetch)
           : (isSparse ? spv::Op::OpImageSparseRead : spv::Op::OpImageRead);
 
-  if (!doImageFetch) {
-    requireCapability(
-        TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
-  }
-
   auto *fetchOrReadInst = new (context) SpirvImageOp(
       op, texelType, /*id*/ 0, loc, image, coordinate, mask,
       /*dref*/ nullptr, /*bias*/ nullptr, lod, /*gradDx*/ nullptr,
@@ -539,8 +497,6 @@ void SpirvBuilder::createImageWrite(QualType imageType, SpirvInstruction *image,
                                     SpirvInstruction *texel,
                                     SourceLocation loc) {
   assert(insertPoint && "null insert point");
-  requireCapability(
-      TypeTranslator::getCapabilityForStorageImageReadWrite(imageType));
   auto *writeInst = new (context) SpirvImageOp(
       spv::Op::OpImageWrite, imageType, /*id*/ 0, loc, image, coord,
       spv::ImageOperandsMask::MaskNone,
@@ -560,10 +516,6 @@ SpirvInstruction *SpirvBuilder::createImageGather(
     SpirvInstruction *residencyCode, SourceLocation loc) {
   assert(insertPoint && "null insert point");
 
-  if (residencyCode) {
-    requireCapability(spv::Capability::SparseResidency);
-  }
-
   // An OpSampledImage is required to do the image sampling.
   auto *sampledImage =
       new (context) SpirvSampledImage(imageType, /*id*/ 0, loc, image, sampler);
@@ -1169,8 +1121,16 @@ SpirvConstant *SpirvBuilder::getConstantNull(QualType type) {
 std::vector<uint32_t> SpirvBuilder::takeModule() {
   // Run necessary visitor passes first
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
+  CapabilityVisitor capabilityVisitor(context, spirvOptions, *this);
   EmitVisitor emitVisitor(astContext, context, spirvOptions, *this);
+
+  // Lower types
   module->invokeVisitor(&lowerTypeVisitor);
+
+  // Add necessary capabilities and extensions
+  module->invokeVisitor(&capabilityVisitor);
+
+  // Emit SPIR-V
   module->invokeVisitor(&emitVisitor);
 
   return emitVisitor.takeBinary();

+ 11 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -773,6 +773,17 @@ SpirvImageOp::SpirvImageOp(
   }
 }
 
+bool SpirvImageOp::isSparse() const {
+  return opcode == spv::Op::OpImageSparseSampleImplicitLod ||
+         opcode == spv::Op::OpImageSparseSampleExplicitLod ||
+         opcode == spv::Op::OpImageSparseSampleDrefImplicitLod ||
+         opcode == spv::Op::OpImageSparseSampleDrefExplicitLod ||
+         opcode == spv::Op::OpImageSparseFetch ||
+         opcode == spv::Op::OpImageSparseGather ||
+         opcode == spv::Op::OpImageSparseDrefGather ||
+         opcode == spv::Op::OpImageSparseRead;
+}
+
 SpirvImageQuery::SpirvImageQuery(spv::Op op, QualType resultType,
                                  uint32_t resultId, SourceLocation loc,
                                  SpirvInstruction *img,

+ 101 - 0
tools/clang/lib/SPIRV/SpirvType.cpp

@@ -27,6 +27,107 @@ bool ScalarType::classof(const SpirvType *t) {
   return false;
 }
 
+bool SpirvType::isTexture(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    const auto dim = imageType->getDimension();
+    const auto withSampler = imageType->withSampler();
+    return (withSampler == ImageType::WithSampler::Yes) &&
+           (dim == spv::Dim::Dim1D || dim == spv::Dim::Dim2D ||
+            dim == spv::Dim::Dim3D || dim == spv::Dim::Cube);
+  }
+  return false;
+}
+
+bool SpirvType::isRWTexture(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    const auto dim = imageType->getDimension();
+    const auto withSampler = imageType->withSampler();
+    return (withSampler == ImageType::WithSampler::No) &&
+           (dim == spv::Dim::Dim1D || dim == spv::Dim::Dim2D ||
+            dim == spv::Dim::Dim3D);
+  }
+  return false;
+}
+
+bool SpirvType::isSampler(const SpirvType *type) {
+  return isa<SamplerType>(type);
+}
+
+bool SpirvType::isBuffer(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    const auto dim = imageType->getDimension();
+    const auto withSampler = imageType->withSampler();
+    return imageType->getDimension() == spv::Dim::Buffer &&
+           imageType->withSampler() == ImageType::WithSampler::Yes;
+  }
+  return false;
+}
+
+bool SpirvType::isRWBuffer(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    return imageType->getDimension() == spv::Dim::Buffer &&
+           imageType->withSampler() == ImageType::WithSampler::No;
+  }
+  return false;
+}
+
+bool SpirvType::isSubpassInput(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    return imageType->getDimension() == spv::Dim::SubpassData &&
+           imageType->isMSImage() == false;
+  }
+  return false;
+}
+
+bool SpirvType::isSubpassInputMS(const SpirvType *type) {
+  if (const auto *imageType = dyn_cast<ImageType>(type)) {
+    return imageType->getDimension() == spv::Dim::SubpassData &&
+           imageType->isMSImage() == true;
+  }
+  return false;
+}
+
+bool SpirvType::isResourceType(const SpirvType *type) {
+  if (isa<ImageType>(type) || isa<SamplerType>(type))
+    return true;
+
+  if (const auto *structType = dyn_cast<StructType>(type))
+    return structType->getInterfaceType() !=
+           StructInterfaceType::InternalStorage;
+
+  if (const auto *pointerType = dyn_cast<SpirvPointerType>(type))
+    return isResourceType(pointerType->getPointeeType());
+
+  return false;
+}
+
+bool SpirvType::isOrContains16BitType(const SpirvType *type) {
+  if (const auto *numericType = dyn_cast<NumericalType>(type))
+    if (numericType->getBitwidth() == 16)
+      return true;
+
+  if (const auto *vecType = dyn_cast<VectorType>(type))
+    return isOrContains16BitType(vecType->getElementType());
+  if (const auto *matType = dyn_cast<MatrixType>(type))
+    return isOrContains16BitType(matType->getElementType());
+  if (const auto *arrType = dyn_cast<MatrixType>(type))
+    return isOrContains16BitType(arrType->getElementType());
+  if (const auto *pointerType = dyn_cast<SpirvPointerType>(type))
+    return isOrContains16BitType(pointerType->getPointeeType());
+  if (const auto *raType = dyn_cast<MatrixType>(type))
+    return isOrContains16BitType(raType->getElementType());
+  if (const auto *imgType = dyn_cast<ImageType>(type))
+    return isOrContains16BitType(imgType->getSampledType());
+  if (const auto *sampledImageType = dyn_cast<SampledImageType>(type))
+    return isOrContains16BitType(sampledImageType->getImageType());
+  if (const auto *structType = dyn_cast<StructType>(type))
+    for (auto &field : structType->getFields())
+      if (isOrContains16BitType(field.type))
+        return true;
+
+  return false;
+}
+
 MatrixType::MatrixType(const VectorType *vecType, uint32_t vecCount,
                        bool rowMajor)
     : SpirvType(TK_Matrix), vectorType(vecType), vectorCount(vecCount),