Browse Source

[spirv] Add initial support for specialization constant (#1009)

This commit add support for generating OpSpecConstant* instructions
with SpecId decorations. Spec constants are only allowed to be of
scalar boolean/integer/float types. Using spec constant as the array
size does not work at the moment.
Lei Zhang 7 years ago
parent
commit
14c3c0d92c

+ 14 - 0
docs/SPIR-V.rst

@@ -214,6 +214,18 @@ annotated with the ``[[vk::push_constant]]`` attribute.
 Please note as per the requirements of Vulkan, "there must be no more than one
 Please note as per the requirements of Vulkan, "there must be no more than one
 push constant block statically used per shader entry point."
 push constant block statically used per shader entry point."
 
 
+Specialization constants
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+To use Vulkan specialization constants, annotate global constants with the
+``[[vk::constant_id(X)]]`` attribute. For example,
+
+.. code:: hlsl
+
+  [[vk::constant_id(1)]] const bool  specConstBool  = true;
+  [[vk::constant_id(2)]] const int   specConstInt   = 42;
+  [[vk::constant_id(3)]] const float specConstFloat = 1.5;
+
 Builtin variables
 Builtin variables
 ~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~
 
 
@@ -247,6 +259,8 @@ The namespace ``vk`` will be used for all Vulkan attributes:
 - ``push_constant``: For marking a variable as the push constant block. Allowed
 - ``push_constant``: For marking a variable as the push constant block. Allowed
   on global variables of struct type. At most one variable can be marked as
   on global variables of struct type. At most one variable can be marked as
   ``push_constant`` in a shader.
   ``push_constant`` in a shader.
+- ``constant_id``: For marking a global constant as a specialization constant.
+  Allowed on global variables of boolean/integer/float types.
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
   list to the annotated object. Only allowed on objects whose type are
   list to the annotated object. Only allowed on objects whose type are
   ``SubpassInput`` or ``SubpassInputMS``.
   ``SubpassInput`` or ``SubpassInputMS``.

+ 1 - 1
external/effcee

@@ -1 +1 @@
-Subproject commit 4a6edb2f740b9b87b04306a7815f42de5ca149a4
+Subproject commit 2741bade14f1ab23f3b90f0e5c77c6b935fc2fff

+ 11 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -936,6 +936,17 @@ def VKInputAttachmentIndex : InheritableAttr {
   let Documentation = [Undocumented];
   let Documentation = [Undocumented];
 }
 }
 
 
+// Global variables that are of scalar type
+def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;
+
+def VKConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Subjects = SubjectList<[ScalarGlobalVar], ErrorDiag, "ExpectedScalarGlobalVar">;
+  let Args = [IntArgument<"SpecConstId">];
+  let LangOpts = [SPIRV];
+  let Documentation = [Undocumented];
+}
+
 // SPIRV Change Ends
 // SPIRV Change Ends
 
 
 def C11NoReturn : InheritableAttr {
 def C11NoReturn : InheritableAttr {

+ 1 - 0
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -2327,6 +2327,7 @@ def warn_attribute_wrong_decl_type : Warning<
   "interface or protocol declarations|kernel functions|"
   "interface or protocol declarations|kernel functions|"
   // SPIRV Change Starts
   // SPIRV Change Starts
   "fields|"
   "fields|"
+  "global variables of scalar type|"
   "global variables of struct type|"
   "global variables of struct type|"
   "global variables, cbuffers, and tbuffers|"
   "global variables, cbuffers, and tbuffers|"
   "RWStructuredBuffers, AppendStructuredBuffers, and ConsumeStructuredBuffers|"
   "RWStructuredBuffers, AppendStructuredBuffers, and ConsumeStructuredBuffers|"

+ 5 - 0
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -882,10 +882,15 @@ public:
                        uint32_t operand);
                        uint32_t operand);
   InstBuilder &binaryOp(spv::Op op, uint32_t result_type, uint32_t result_id,
   InstBuilder &binaryOp(spv::Op op, uint32_t result_type, uint32_t result_id,
                         uint32_t lhs, uint32_t rhs);
                         uint32_t lhs, uint32_t rhs);
+  InstBuilder &specConstantBinaryOp(spv::Op op, uint32_t result_type,
+                                    uint32_t result_id, uint32_t lhs,
+                                    uint32_t rhs);
 
 
   // Methods for building constants.
   // Methods for building constants.
   InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
   InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
                           uint32_t value);
                           uint32_t value);
+  InstBuilder &opSpecConstant(uint32_t result_type, uint32_t result_id,
+                              uint32_t value);
 
 
   // All-in-one method for creating different types of OpImageSample*.
   // All-in-one method for creating different types of OpImageSample*.
   InstBuilder &
   InstBuilder &

+ 9 - 4
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -151,6 +151,8 @@ public:
   /// the <result-id> for the result.
   /// the <result-id> for the result.
   uint32_t createBinaryOp(spv::Op op, uint32_t resultType, uint32_t lhs,
   uint32_t createBinaryOp(spv::Op op, uint32_t resultType, uint32_t lhs,
                           uint32_t rhs);
                           uint32_t rhs);
+  uint32_t createSpecConstantBinaryOp(spv::Op op, uint32_t resultType,
+                                      uint32_t lhs, uint32_t rhs);
 
 
   /// \brief Creates an atomic instruction with the given parameters.
   /// \brief Creates an atomic instruction with the given parameters.
   /// Returns the <result-id> for the result.
   /// Returns the <result-id> for the result.
@@ -357,6 +359,9 @@ public:
   void decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
   void decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
                            uint32_t bindingNumber);
                            uint32_t bindingNumber);
 
 
+  /// \brief Decorates the given target <result-id> with the given SpecId.
+  void decorateSpecId(uint32_t targetId, uint32_t specId);
+
   /// \brief Decorates the given target <result-id> with the given input
   /// \brief Decorates the given target <result-id> with the given input
   /// attchment index number.
   /// attchment index number.
   void decorateInputAttachmentIndex(uint32_t targetId, uint32_t indexNumber);
   void decorateInputAttachmentIndex(uint32_t targetId, uint32_t indexNumber);
@@ -404,15 +409,15 @@ public:
   uint32_t getSparseResidencyStructType(uint32_t type);
   uint32_t getSparseResidencyStructType(uint32_t type);
 
 
   // === Constant ===
   // === Constant ===
-  uint32_t getConstantBool(bool value);
+  uint32_t getConstantBool(bool value, bool isSpecConst = false);
   uint32_t getConstantInt16(int16_t value);
   uint32_t getConstantInt16(int16_t value);
-  uint32_t getConstantInt32(int32_t value);
+  uint32_t getConstantInt32(int32_t value, bool isSpecConst = false);
   uint32_t getConstantInt64(int64_t value);
   uint32_t getConstantInt64(int64_t value);
   uint32_t getConstantUint16(uint16_t value);
   uint32_t getConstantUint16(uint16_t value);
-  uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantUint32(uint32_t value, bool isSpecConst = false);
   uint32_t getConstantUint64(uint64_t value);
   uint32_t getConstantUint64(uint64_t value);
   uint32_t getConstantFloat16(int16_t value);
   uint32_t getConstantFloat16(int16_t value);
-  uint32_t getConstantFloat32(float value);
+  uint32_t getConstantFloat32(float value, bool isSpecConst = false);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantComposite(uint32_t typeId,
   uint32_t getConstantComposite(uint32_t typeId,
                                 llvm::ArrayRef<uint32_t> constituents);
                                 llvm::ArrayRef<uint32_t> constituents);

+ 1 - 0
tools/clang/include/clang/Sema/AttributeList.h

@@ -857,6 +857,7 @@ enum AttributeDeclKind {
   ExpectedKernelFunction
   ExpectedKernelFunction
   // SPIRV Change Begins
   // SPIRV Change Begins
   ,ExpectedField
   ,ExpectedField
+  ,ExpectedScalarGlobalVar
   ,ExpectedStructGlobalVar
   ,ExpectedStructGlobalVar
   ,ExpectedGlobalVarOrCTBuffer
   ,ExpectedGlobalVarOrCTBuffer
   ,ExpectedCounterStructuredBuffer
   ,ExpectedCounterStructuredBuffer

+ 6 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -283,7 +283,7 @@ DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
   return nullptr;
   return nullptr;
 }
 }
 
 
-SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const ValueDecl *decl,
+SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
                                                   bool checkRegistered) {
                                                   bool checkRegistered) {
   if (const auto *info = getDeclSpirvInfo(decl))
   if (const auto *info = getDeclSpirvInfo(decl))
     if (info->indexInCTBuffer >= 0) {
     if (info->indexInCTBuffer >= 0) {
@@ -631,6 +631,11 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
   return nullptr;
   return nullptr;
 }
 }
 
 
+void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
+                                              uint32_t specConstant) {
+  astDecls[decl].info.setResultId(specConstant).setRValue().setSpecConstant();
+}
+
 void DeclResultIdMapper::createCounterVar(
 void DeclResultIdMapper::createCounterVar(
     const DeclaratorDecl *decl, bool isAlias,
     const DeclaratorDecl *decl, bool isAlias,
     const llvm::SmallVector<uint32_t, 4> *indices) {
     const llvm::SmallVector<uint32_t, 4> *indices) {

+ 5 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -369,7 +369,7 @@ public:
   ///
   ///
   /// This method will emit a fatal error if checkRegistered is true and the
   /// This method will emit a fatal error if checkRegistered is true and the
   /// decl is not registered.
   /// decl is not registered.
-  SpirvEvalInfo getDeclResultId(const ValueDecl *decl,
+  SpirvEvalInfo getDeclEvalInfo(const ValueDecl *decl,
                                 bool checkRegistered = true);
                                 bool checkRegistered = true);
 
 
   /// \brief Returns the <result-id> for the given function if already
   /// \brief Returns the <result-id> for the given function if already
@@ -377,6 +377,10 @@ public:
   /// returns a newly assigned <result-id> for it.
   /// returns a newly assigned <result-id> for it.
   uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
   uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
 
 
+  /// Registers that the given decl should be translated into the given spec
+  /// constant.
+  void registerSpecConstant(const VarDecl *decl, uint32_t specConstant);
+
   /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
   /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
   /// If indices is not nullptr, walks trhough the fields of the decl, expected
   /// If indices is not nullptr, walks trhough the fields of the decl, expected

+ 37 - 0
tools/clang/lib/SPIRV/InstBuilderManual.cpp

@@ -64,6 +64,27 @@ InstBuilder &InstBuilder::binaryOp(spv::Op op, uint32_t result_type,
   return *this;
   return *this;
 }
 }
 
 
+InstBuilder &InstBuilder::specConstantBinaryOp(spv::Op op, uint32_t result_type,
+                                               uint32_t result_id, uint32_t lhs,
+                                               uint32_t rhs) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(6);
+  TheInst.emplace_back(static_cast<uint32_t>(spv::Op::OpSpecConstantOp));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(lhs);
+  TheInst.emplace_back(rhs);
+
+  return *this;
+}
+
 InstBuilder &InstBuilder::opConstant(uint32_t resultType, uint32_t resultId,
 InstBuilder &InstBuilder::opConstant(uint32_t resultType, uint32_t resultId,
                                      uint32_t value) {
                                      uint32_t value) {
   if (!TheInst.empty()) {
   if (!TheInst.empty()) {
@@ -134,6 +155,22 @@ InstBuilder &InstBuilder::opImageFetchRead(
   return *this;
   return *this;
 }
 }
 
 
+InstBuilder &InstBuilder::opSpecConstant(uint32_t resultType, uint32_t resultId,
+                                         uint32_t value) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  TheInst.reserve(4);
+  TheInst.emplace_back(static_cast<uint32_t>(spv::Op::OpSpecConstant));
+  TheInst.emplace_back(resultType);
+  TheInst.emplace_back(resultId);
+  TheInst.emplace_back(value);
+
+  return *this;
+}
+
 void InstBuilder::encodeString(std::string value) {
 void InstBuilder::encodeString(std::string value) {
   const auto &words = string::encodeSPIRVString(value);
   const auto &words = string::encodeSPIRVString(value);
   TheInst.insert(TheInst.end(), words.begin(), words.end());
   TheInst.insert(TheInst.end(), words.begin(), words.end());

+ 54 - 4
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -11,6 +11,7 @@
 
 
 #include "TypeTranslator.h"
 #include "TypeTranslator.h"
 #include "spirv/unified1//spirv.hpp11"
 #include "spirv/unified1//spirv.hpp11"
+#include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "llvm/llvm_assert/assert.h"
 #include "llvm/llvm_assert/assert.h"
 
 
@@ -237,6 +238,15 @@ uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
   return id;
   return id;
 }
 }
 
 
+uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
+                                                   uint32_t resultType,
+                                                   uint32_t lhs, uint32_t rhs) {
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
+  theModule.addVariable(std::move(constructSite));
+  return id;
+}
+
 uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
 uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
                                        uint32_t orignalValuePtr,
                                        uint32_t orignalValuePtr,
                                        uint32_t scopeId,
                                        uint32_t scopeId,
@@ -773,6 +783,11 @@ void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
   theModule.addDecoration(d, targetId);
   theModule.addDecoration(d, targetId);
 }
 }
 
 
+void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
+  const Decoration *d = Decoration::getSpecId(theContext, specId);
+  theModule.addDecoration(d, targetId);
+}
+
 void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
 void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
   const Decoration *d = nullptr;
   const Decoration *d = nullptr;
   switch (decoration) {
   switch (decoration) {
@@ -1068,7 +1083,19 @@ uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getConstantBool(bool value) {
+uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
+  if (isSpecConst) {
+    const uint32_t constId = theContext.takeNextId();
+    if (value) {
+      instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
+    } else {
+      instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
+    }
+
+    theModule.addVariable(std::move(constructSite));
+    return constId;
+  }
+
   const uint32_t typeId = getBoolType();
   const uint32_t typeId = getBoolType();
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)
                                    : Constant::getFalse(theContext, typeId);
                                    : Constant::getFalse(theContext, typeId);
@@ -1089,17 +1116,40 @@ uint32_t ModuleBuilder::getConstantBool(bool value) {
     return constId;                                                            \
     return constId;                                                            \
   }
   }
 
 
+#define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy)                  \
+                                                                               \
+  uint32_t ModuleBuilder::getConstant##builderTy(cppTy value,                  \
+                                                 bool isSpecConst) {           \
+    if (isSpecConst) {                                                         \
+      const uint32_t constId = theContext.takeNextId();                        \
+      instBuilder                                                              \
+          .opSpecConstant(get##builderTy##Type(), constId,                     \
+                          cast::BitwiseCast<uint32_t>(value))                  \
+          .x();                                                                \
+      theModule.addVariable(std::move(constructSite));                         \
+      return constId;                                                          \
+    }                                                                          \
+                                                                               \
+    const uint32_t typeId = get##builderTy##Type();                            \
+    const Constant *constant =                                                 \
+        Constant::get##builderTy(theContext, typeId, value);                   \
+    const uint32_t constId = theContext.getResultIdForConstant(constant);      \
+    theModule.addConstant(constant, constId);                                  \
+    return constId;                                                            \
+  }
+
 IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
-IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
 IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
 IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
-IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
 IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
-IMPL_GET_PRIMITIVE_CONST(Float32, float)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
 IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
 IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
 IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
 IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
 
 
 #undef IMPL_GET_PRIMITIVE_CONST
 #undef IMPL_GET_PRIMITIVE_CONST
+#undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
 
 
 uint32_t
 uint32_t
 ModuleBuilder::getConstantComposite(uint32_t typeId,
 ModuleBuilder::getConstantComposite(uint32_t typeId,

+ 168 - 27
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -307,6 +307,65 @@ spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
   return Op::Max;
   return Op::Max;
 }
 }
 
 
+// Returns true if the given opcode is an accepted binary opcode in
+// OpSpecConstantOp.
+bool isAcceptedSpecConstantBinaryOp(spv::Op op) {
+  switch (op) {
+  case spv::Op::OpIAdd:
+  case spv::Op::OpISub:
+  case spv::Op::OpIMul:
+  case spv::Op::OpUDiv:
+  case spv::Op::OpSDiv:
+  case spv::Op::OpUMod:
+  case spv::Op::OpSRem:
+  case spv::Op::OpSMod:
+  case spv::Op::OpShiftRightLogical:
+  case spv::Op::OpShiftRightArithmetic:
+  case spv::Op::OpShiftLeftLogical:
+  case spv::Op::OpBitwiseOr:
+  case spv::Op::OpBitwiseXor:
+  case spv::Op::OpBitwiseAnd:
+  case spv::Op::OpVectorShuffle:
+  case spv::Op::OpCompositeExtract:
+  case spv::Op::OpCompositeInsert:
+  case spv::Op::OpLogicalOr:
+  case spv::Op::OpLogicalAnd:
+  case spv::Op::OpLogicalNot:
+  case spv::Op::OpLogicalEqual:
+  case spv::Op::OpLogicalNotEqual:
+  case spv::Op::OpIEqual:
+  case spv::Op::OpINotEqual:
+  case spv::Op::OpULessThan:
+  case spv::Op::OpSLessThan:
+  case spv::Op::OpUGreaterThan:
+  case spv::Op::OpSGreaterThan:
+  case spv::Op::OpULessThanEqual:
+  case spv::Op::OpSLessThanEqual:
+  case spv::Op::OpUGreaterThanEqual:
+  case spv::Op::OpSGreaterThanEqual:
+    return true;
+  }
+  return false;
+}
+
+/// Returns true if the given expression is an accepted initializer for a spec
+/// constant.
+bool isAcceptedSpecConstantInit(const Expr *init) {
+  // Allow numeric casts
+  init = init->IgnoreParenCasts();
+
+  if (isa<CXXBoolLiteralExpr>(init) || isa<IntegerLiteral>(init) ||
+      isa<FloatingLiteral>(init))
+    return true;
+
+  // Allow the minus operator which is used to specify negative values
+  if (const auto *unaryOp = dyn_cast<UnaryOperator>(init))
+    return unaryOp->getOpcode() == UO_Minus &&
+           isAcceptedSpecConstantInit(unaryOp->getSubExpr());
+
+  return false;
+}
+
 /// Returns true if the given function parameter can act as shader stage
 /// Returns true if the given function parameter can act as shader stage
 /// input parameter.
 /// input parameter.
 inline bool canActAsInParmVar(const ParmVarDecl *param) {
 inline bool canActAsInParmVar(const ParmVarDecl *param) {
@@ -403,7 +462,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, options),
       typeTranslator(astContext, theBuilder, diags, options),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
-      seenPushConstantAt(), needsLegalization(false) {
+      seenPushConstantAt(), isSpecConstantMode(false),
+      needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
@@ -583,14 +643,14 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
 
 
 SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) {
 SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) {
   const auto *decl = expr->getDecl();
   const auto *decl = expr->getDecl();
-  auto id = declIdMapper.getDeclResultId(decl, false);
+  auto id = declIdMapper.getDeclEvalInfo(decl, false);
 
 
   if (spirvOptions.ignoreUnusedResources && !id) {
   if (spirvOptions.ignoreUnusedResources && !id) {
     // First time referencing a Decl inside TranslationUnit. Register
     // First time referencing a Decl inside TranslationUnit. Register
     // into DeclResultIdMapper and emit SPIR-V for it and then query
     // into DeclResultIdMapper and emit SPIR-V for it and then query
     // again.
     // again.
     doDecl(decl);
     doDecl(decl);
-    id = declIdMapper.getDeclResultId(decl);
+    id = declIdMapper.getDeclEvalInfo(decl);
   }
   }
 
 
   return id;
   return id;
@@ -614,7 +674,8 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
   } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
     result = doInitListExpr(initListExpr);
     result = doInitListExpr(initListExpr);
   } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
   } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
-    const auto value = theBuilder.getConstantBool(boolLiteral->getValue());
+    const auto value =
+        theBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode);
     result = SpirvEvalInfo(value).setConstant().setRValue();
     result = SpirvEvalInfo(value).setConstant().setRValue();
   } else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
   } else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
     const auto value = translateAPInt(intLiteral->getValue(), expr->getType());
     const auto value = translateAPInt(intLiteral->getValue(), expr->getType());
@@ -800,7 +861,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     // Non-entry functions are added to the work queue following function
     // Non-entry functions are added to the work queue following function
     // calls. We have already assigned <result-id>s for it when translating
     // calls. We have already assigned <result-id>s for it when translating
     // its call site. Query it here.
     // its call site. Query it here.
-    funcId = declIdMapper.getDeclResultId(decl);
+    funcId = declIdMapper.getDeclEvalInfo(decl);
   }
   }
 
 
   const uint32_t retType =
   const uint32_t retType =
@@ -1013,6 +1074,12 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   if (!validateVKAttributes(decl))
   if (!validateVKAttributes(decl))
     return;
     return;
 
 
+  if (decl->hasAttr<VKConstantIdAttr>()) {
+    // This is a VarDecl for specialization constant.
+    createSpecConstant(decl);
+    return;
+  }
+
   if (decl->hasAttr<VKPushConstantAttr>()) {
   if (decl->hasAttr<VKPushConstantAttr>()) {
     // This is a VarDecl for PushConstant block.
     // This is a VarDecl for PushConstant block.
     (void)declIdMapper.createPushConstant(decl);
     (void)declIdMapper.createPushConstant(decl);
@@ -1808,7 +1875,7 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   }
   }
 
 
   // Inherit the SpirvEvalInfo from the function definition
   // Inherit the SpirvEvalInfo from the function definition
-  return declIdMapper.getDeclResultId(callee).setResultId(retVal);
+  return declIdMapper.getDeclEvalInfo(callee).setResultId(retVal);
 }
 }
 
 
 SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
 SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
@@ -2981,8 +3048,6 @@ SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
           ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
           ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
           : getReferencedDef(base);
           : getReferencedDef(base);
   return declIdMapper.getCounterIdAliasPair(decl, &indices);
   return declIdMapper.getCounterIdAliasPair(decl, &indices);
-
-  return nullptr;
 }
 }
 
 
 const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
 const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
@@ -4505,17 +4570,26 @@ SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
   case BO_XorAssign:
   case BO_XorAssign:
   case BO_ShlAssign:
   case BO_ShlAssign:
   case BO_ShrAssign: {
   case BO_ShrAssign: {
+    // To evaluate this expression as an OpSpecConstantOp, we need to make sure
+    // both operands are constant and at least one of them is a spec constant.
+    if (lhsVal.isConstant() && rhsVal.isConstant() &&
+        (lhsVal.isSpecConstant() || rhsVal.isSpecConstant()) &&
+        isAcceptedSpecConstantBinaryOp(spvOp)) {
+      const auto valId = theBuilder.createSpecConstantBinaryOp(
+          spvOp, resultTypeId, lhsVal, rhsVal);
+      return SpirvEvalInfo(valId).setRValue().setSpecConstant();
+    }
+
+    // Normal binary operation
     const auto valId =
     const auto valId =
         theBuilder.createBinaryOp(spvOp, resultTypeId, lhsVal, rhsVal);
         theBuilder.createBinaryOp(spvOp, resultTypeId, lhsVal, rhsVal);
     auto result = SpirvEvalInfo(valId).setRValue();
     auto result = SpirvEvalInfo(valId).setRValue();
-    return lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision()
-               ? result.setRelaxedPrecision()
-               : result;
+    if (lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision())
+      result.setRelaxedPrecision();
+    return result;
   }
   }
   case BO_Assign:
   case BO_Assign:
     llvm_unreachable("assignment should not be handled here");
     llvm_unreachable("assignment should not be handled here");
-  default:
-    break;
   }
   }
 
 
   emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
   emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
@@ -5098,6 +5172,71 @@ SpirvEvalInfo SPIRVEmitter::processEachVectorInMatrix(
   return SpirvEvalInfo(valId).setRValue();
   return SpirvEvalInfo(valId).setRValue();
 }
 }
 
 
+void SPIRVEmitter::createSpecConstant(const VarDecl *varDecl) {
+  class SpecConstantEnvRAII {
+  public:
+    // Creates a new instance which sets mode to true on creation,
+    // and resets mode to false on destruction.
+    SpecConstantEnvRAII(bool *mode) : modeSlot(mode) { *modeSlot = true; }
+    ~SpecConstantEnvRAII() { *modeSlot = false; }
+
+  private:
+    bool *modeSlot;
+  };
+
+  const QualType varType = varDecl->getType();
+
+  bool hasError = false;
+
+  if (!varDecl->isExternallyVisible()) {
+    emitError("specialization constant must be externally visible",
+              varDecl->getLocation());
+    hasError = true;
+  }
+
+  if (const auto *builtinType = varType->getAs<BuiltinType>()) {
+    switch (builtinType->getKind()) {
+    case BuiltinType::Bool:
+    case BuiltinType::Int:
+    case BuiltinType::UInt:
+    case BuiltinType::Float:
+      break;
+    default:
+      emitError("unsupported specialization constant type",
+                varDecl->getLocStart());
+      hasError = true;
+    }
+  }
+
+  const auto *init = varDecl->getInit();
+
+  if (!init) {
+    emitError("missing default value for specialization constant",
+              varDecl->getLocation());
+    hasError = true;
+  } else if (!isAcceptedSpecConstantInit(init)) {
+    emitError("unsupported specialization constant initializer",
+              init->getLocStart())
+        << init->getSourceRange();
+    hasError = true;
+  }
+
+  if (hasError)
+    return;
+
+  SpecConstantEnvRAII specConstantEnvRAII(&isSpecConstantMode);
+
+  const auto specConstant = doExpr(init);
+
+  // We are not creating a variable to hold the spec constant, instead, we
+  // translate the varDecl directly into the spec constant here.
+
+  theBuilder.decorateSpecId(
+      specConstant, varDecl->getAttr<VKConstantIdAttr>()->getSpecConstId());
+
+  declIdMapper.registerSpecConstant(varDecl, specConstant);
+}
+
 SpirvEvalInfo
 SpirvEvalInfo
 SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                     const BinaryOperatorKind opcode,
                                     const BinaryOperatorKind opcode,
@@ -7252,7 +7391,8 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
   TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);
   TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);
 
 
   if (targetType->isBooleanType()) {
   if (targetType->isBooleanType()) {
-    result = theBuilder.getConstantBool(value.getInt().getBoolValue());
+    result = theBuilder.getConstantBool(value.getInt().getBoolValue(),
+                                        isSpecConstantMode);
   } else if (targetType->isIntegerType()) {
   } else if (targetType->isIntegerType()) {
     result = translateAPInt(value.getInt(), targetType);
     result = translateAPInt(value.getInt(), targetType);
   } else if (targetType->isFloatingType()) {
   } else if (targetType->isFloatingType()) {
@@ -7301,10 +7441,10 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
       // If enable16BitTypes option is not true, treat as 32-bit integer.
       // If enable16BitTypes option is not true, treat as 32-bit integer.
       if (isSigned)
       if (isSigned)
         return theBuilder.getConstantInt32(
         return theBuilder.getConstantInt32(
-            static_cast<int32_t>(intValue.getSExtValue()));
+            static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
       else
       else
         return theBuilder.getConstantUint32(
         return theBuilder.getConstantUint32(
-            static_cast<uint32_t>(intValue.getZExtValue()));
+            static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
     }
     }
   }
   }
   case 32: {
   case 32: {
@@ -7317,7 +7457,7 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
         return 0;
         return 0;
       }
       }
       return theBuilder.getConstantInt32(
       return theBuilder.getConstantInt32(
-          static_cast<int32_t>(intValue.getSExtValue()));
+          static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
     } else {
     } else {
       if (!intValue.isIntN(32)) {
       if (!intValue.isIntN(32)) {
         emitError("evaluating integer literal %0 as a 32-bit integer loses "
         emitError("evaluating integer literal %0 as a 32-bit integer loses "
@@ -7327,7 +7467,7 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
         return 0;
         return 0;
       }
       }
       return theBuilder.getConstantUint32(
       return theBuilder.getConstantUint32(
-          static_cast<uint32_t>(intValue.getZExtValue()));
+          static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
     }
     }
   }
   }
   case 64: {
   case 64: {
@@ -7392,7 +7532,8 @@ uint32_t SPIRVEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
   const auto &semantics = floatValue.getSemantics();
   const auto &semantics = floatValue.getSemantics();
   // If the given value is already a 32-bit float, there is no need to convert.
   // If the given value is already a 32-bit float, there is no need to convert.
   if (&semantics == &llvm::APFloat::IEEEsingle) {
   if (&semantics == &llvm::APFloat::IEEEsingle) {
-    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
+                                         isSpecConstantMode);
   }
   }
 
 
   // Try to see if this literal float can be represented in 32-bit.
   // Try to see if this literal float can be represented in 32-bit.
@@ -7437,12 +7578,11 @@ uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
       emitError(
       emitError(
           "evaluating float literal %0 at a lower bitwidth loses information",
           "evaluating float literal %0 at a lower bitwidth loses information",
           {})
           {})
-          << std::to_string(
-                 valueBitwidth == 16
-                     ? static_cast<float>(
-                           originalValue.bitcastToAPInt().getZExtValue())
-                     : valueBitwidth == 32 ? originalValue.convertToFloat()
-                                           : originalValue.convertToDouble());
+          // Converting from 16bit to 32/64-bit won't lose information.
+          // So only 32/64-bit values can reach here.
+          << std::to_string(valueBitwidth == 32
+                                ? originalValue.convertToFloat()
+                                : originalValue.convertToDouble());
       return 0;
       return 0;
     }
     }
   }
   }
@@ -7452,7 +7592,8 @@ uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
     return theBuilder.getConstantFloat16(
     return theBuilder.getConstantFloat16(
         static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
         static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
   case 32:
   case 32:
-    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
+                                         isSpecConstantMode);
   case 64:
   case 64:
     return theBuilder.getConstantFloat64(floatValue.convertToDouble());
     return theBuilder.getConstantFloat64(floatValue.convertToDouble());
   default:
   default:
@@ -7803,7 +7944,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
 
 
   // Initialize all global variables at the beginning of the wrapper
   // Initialize all global variables at the beginning of the wrapper
   for (const VarDecl *varDecl : toInitGloalVars) {
   for (const VarDecl *varDecl : toInitGloalVars) {
-    const auto varInfo = declIdMapper.getDeclResultId(varDecl);
+    const auto varInfo = declIdMapper.getDeclEvalInfo(varDecl);
     if (const auto *init = varDecl->getInit()) {
     if (const auto *init = varDecl->getInit()) {
       storeValue(varInfo, doExpr(init), varDecl->getType());
       storeValue(varInfo, doExpr(init), varDecl->getType());
 
 

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -243,6 +243,9 @@ private:
       llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
       llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
           actOnEachVector);
           actOnEachVector);
 
 
+  /// Translates the given varDecl into a spec constant.
+  void createSpecConstant(const VarDecl *varDecl);
+
   /// Generates the necessary instructions for conducting the given binary
   /// Generates the necessary instructions for conducting the given binary
   /// operation on lhs and rhs.
   /// operation on lhs and rhs.
   ///
   ///
@@ -850,6 +853,10 @@ private:
   /// Invalid means no push constant blocks defined thus far.
   /// Invalid means no push constant blocks defined thus far.
   SourceLocation seenPushConstantAt;
   SourceLocation seenPushConstantAt;
 
 
+  /// Indicates whether the current emitter is in specialization constant mode:
+  /// all 32-bit scalar constants will be translated into OpSpecConstant.
+  bool isSpecConstantMode;
+
   /// Whether the translated SPIR-V binary needs legalization.
   /// Whether the translated SPIR-V binary needs legalization.
   ///
   ///
   /// The following cases will require legalization:
   /// The following cases will require legalization:

+ 10 - 0
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -94,6 +94,9 @@ public:
   inline SpirvEvalInfo &setConstant();
   inline SpirvEvalInfo &setConstant();
   bool isConstant() const { return isConstant_; }
   bool isConstant() const { return isConstant_; }
 
 
+  inline SpirvEvalInfo &setSpecConstant();
+  bool isSpecConstant() const { return isSpecConstant_; }
+
   inline SpirvEvalInfo &setRelaxedPrecision();
   inline SpirvEvalInfo &setRelaxedPrecision();
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
 
 
@@ -114,6 +117,7 @@ private:
 
 
   bool isRValue_;
   bool isRValue_;
   bool isConstant_;
   bool isConstant_;
+  bool isSpecConstant_;
   bool isRelaxedPrecision_;
   bool isRelaxedPrecision_;
 };
 };
 
 
@@ -158,6 +162,12 @@ SpirvEvalInfo &SpirvEvalInfo::setConstant() {
   return *this;
   return *this;
 }
 }
 
 
+SpirvEvalInfo &SpirvEvalInfo::setSpecConstant() {
+  // Specialization constant is also a kind of constant.
+  isConstant_ = isSpecConstant_ = true;
+  return *this;
+}
+
 SpirvEvalInfo &SpirvEvalInfo::setRelaxedPrecision() {
 SpirvEvalInfo &SpirvEvalInfo::setRelaxedPrecision() {
   isRelaxedPrecision_ = true;
   isRelaxedPrecision_ = true;
   return *this;
   return *this;

+ 4 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10470,6 +10470,10 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
         A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
         A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
         A.getAttributeSpellingListIndex());
         A.getAttributeSpellingListIndex());
     break;
     break;
+  case AttributeList::AT_VKConstantId:
+    declAttr = ::new (S.Context) VKConstantIdAttr(A.getRange(), S.Context,
+      ValidateAttributeIntArg(S, A), A.getAttributeSpellingListIndex());
+    break;
   default:
   default:
     Handled = false;
     Handled = false;
     return;
     return;

+ 29 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.error.hlsl

@@ -0,0 +1,29 @@
+// Run: %dxc -T vs_6_0 -E main
+
+[[vk::constant_id(0)]]
+const bool sc0 = true;
+
+[[vk::constant_id(1)]]
+const bool sc1 = sc0; // error
+
+[[vk::constant_id(2)]]
+const double sc2 = 42; // error
+
+[[vk::constant_id(3)]]
+const int sc3; // error
+
+[[vk::constant_id(4)]]
+const int sc4 = sc3 + sc3; // error
+
+[[vk::constant_id(5)]]
+static const int sc5 = 1; // error
+
+float main() : A {
+    return 1.0;
+}
+
+// CHECK:  :7:18: error: unsupported specialization constant initializer
+// CHECK:  :10:1: error: unsupported specialization constant type
+// CHECK: :13:11: error: missing default value for specialization constant
+// CHECK: :16:17: error: unsupported specialization constant initializer
+// CHECK: :19:18: error: specialization constant must be externally visible

+ 63 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.init.hlsl

@@ -0,0 +1,63 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpDecorate [[b0:%\d+]] SpecId 0
+// CHECK: OpDecorate [[b1:%\d+]] SpecId 1
+// CHECK: OpDecorate [[b2:%\d+]] SpecId 2
+
+// CHECK: OpDecorate [[i0:%\d+]] SpecId 10
+// CHECK: OpDecorate [[i1:%\d+]] SpecId 11
+// CHECK: OpDecorate [[i2:%\d+]] SpecId 12
+// CHECK: OpDecorate [[i3:%\d+]] SpecId 13
+
+// CHECK: OpDecorate [[u0:%\d+]] SpecId 20
+
+// CHECK: OpDecorate [[f0:%\d+]] SpecId 30
+// CHECK: OpDecorate [[f1:%\d+]] SpecId 31
+// CHECK: OpDecorate [[f2:%\d+]] SpecId 32
+// CHECK: OpDecorate [[f3:%\d+]] SpecId 33
+
+// CHECK: [[b0]] = OpSpecConstantTrue %bool
+[[vk::constant_id(0)]]
+bool b0 = true;
+// CHECK: [[b1]] = OpSpecConstantFalse %bool
+[[vk::constant_id(1)]]
+bool b1 = 0;
+// CHECK: [[b2]] = OpSpecConstantTrue %bool
+[[vk::constant_id(2)]]
+bool b2 = 1.5;
+
+
+// CHECK:  [[i0]] = OpSpecConstant %int 42
+[[vk::constant_id(10)]]
+int i0 = 42;
+// CHECK:  [[i1]] = OpSpecConstant %int -42
+[[vk::constant_id(11)]]
+int i1 = -42;
+// CHECK:  [[i2]] = OpSpecConstant %int 1
+[[vk::constant_id(12)]]
+int i2 = (true);
+// CHECK:  [[i3]] = OpSpecConstant %int 2
+[[vk::constant_id(13)]]
+int i3 = 2.5;
+
+// CHECK: [[u0]] = OpSpecConstant %uint 56
+[[vk::constant_id(20)]]
+uint uintConst1 = 56;
+
+// CHECK: [[f0]] = OpSpecConstant %float 4.2
+[[vk::constant_id(30)]]
+float f0 = (4.2);
+// CHECK: [[f1]] = OpSpecConstant %float -4.2
+[[vk::constant_id(31)]]
+float f1 = -4.2;
+// CHECK: [[f2]] = OpSpecConstant %float 1
+[[vk::constant_id(32)]]
+float f2 = true;
+// CHECK: [[f3]] = OpSpecConstant %float 20
+[[vk::constant_id(33)]]
+float f3 = 20;
+
+
+float main() : A {
+    return 1.0;
+}

+ 24 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.usage.hlsl

@@ -0,0 +1,24 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpDecorate [[sc:%\d+]] SpecId 10
+[[vk::constant_id(10)]]
+// CHECK: [[sc]] = OpSpecConstant %int 12
+const int specConst = 12;
+
+// TODO: The frontend parsing hits assertion failures saying cannot evaluating
+// as constant int for the following usages.
+/*
+cbuffer Data {
+    float4 pos[specConst];
+    float4 tex[specConst + 5];
+};
+*/
+
+// CHECK: [[add:%\d+]] = OpSpecConstantOp %int IAdd [[sc]] %int_3
+static const int val = specConst + 3;
+
+// CHECK-LABEL:  %main = OpFunction
+// CHECK:                OpStore %val [[add]]
+void main() {
+
+}

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1116,6 +1116,16 @@ TEST_F(FileTest, VulkanMultiplePushConstant) {
   runFileTest("vk.push-constant.multiple.hlsl", Expect::Failure);
   runFileTest("vk.push-constant.multiple.hlsl", Expect::Failure);
 }
 }
 
 
+TEST_F(FileTest, VulkanSpecConstantInit) {
+  runFileTest("vk.spec-constant.init.hlsl");
+}
+TEST_F(FileTest, VulkanSpecConstantUsage) {
+  runFileTest("vk.spec-constant.usage.hlsl");
+}
+TEST_F(FileTest, VulkanSpecConstantError) {
+  runFileTest("vk.spec-constant.error.hlsl", Expect::Failure);
+}
+
 TEST_F(FileTest, VulkanLayoutCBufferStd140) {
 TEST_F(FileTest, VulkanLayoutCBufferStd140) {
   runFileTest("vk.layout.cbuffer.std140.hlsl");
   runFileTest("vk.layout.cbuffer.std140.hlsl");
 }
 }