Browse Source

[spirv] Add initial support for specialization constant (#1009)

This commit add support for generating OpSpecConstant* instructions
with SpecId decorations. Spec constants are only allowed to be of
scalar boolean/integer/float types. Using spec constant as the array
size does not work at the moment.
Lei Zhang 7 years ago
parent
commit
14c3c0d92c

+ 14 - 0
docs/SPIR-V.rst

@@ -214,6 +214,18 @@ annotated with the ``[[vk::push_constant]]`` attribute.
 Please note as per the requirements of Vulkan, "there must be no more than one
 push constant block statically used per shader entry point."
 
+Specialization constants
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+To use Vulkan specialization constants, annotate global constants with the
+``[[vk::constant_id(X)]]`` attribute. For example,
+
+.. code:: hlsl
+
+  [[vk::constant_id(1)]] const bool  specConstBool  = true;
+  [[vk::constant_id(2)]] const int   specConstInt   = 42;
+  [[vk::constant_id(3)]] const float specConstFloat = 1.5;
+
 Builtin variables
 ~~~~~~~~~~~~~~~~~
 
@@ -247,6 +259,8 @@ The namespace ``vk`` will be used for all Vulkan attributes:
 - ``push_constant``: For marking a variable as the push constant block. Allowed
   on global variables of struct type. At most one variable can be marked as
   ``push_constant`` in a shader.
+- ``constant_id``: For marking a global constant as a specialization constant.
+  Allowed on global variables of boolean/integer/float types.
 - ``input_attachment_index(X)``: To associate the Xth entry in the input pass
   list to the annotated object. Only allowed on objects whose type are
   ``SubpassInput`` or ``SubpassInputMS``.

+ 1 - 1
external/effcee

@@ -1 +1 @@
-Subproject commit 4a6edb2f740b9b87b04306a7815f42de5ca149a4
+Subproject commit 2741bade14f1ab23f3b90f0e5c77c6b935fc2fff

+ 11 - 0
tools/clang/include/clang/Basic/Attr.td

@@ -936,6 +936,17 @@ def VKInputAttachmentIndex : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+// Global variables that are of scalar type
+def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;
+
+def VKConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Subjects = SubjectList<[ScalarGlobalVar], ErrorDiag, "ExpectedScalarGlobalVar">;
+  let Args = [IntArgument<"SpecConstId">];
+  let LangOpts = [SPIRV];
+  let Documentation = [Undocumented];
+}
+
 // SPIRV Change Ends
 
 def C11NoReturn : InheritableAttr {

+ 1 - 0
tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

@@ -2327,6 +2327,7 @@ def warn_attribute_wrong_decl_type : Warning<
   "interface or protocol declarations|kernel functions|"
   // SPIRV Change Starts
   "fields|"
+  "global variables of scalar type|"
   "global variables of struct type|"
   "global variables, cbuffers, and tbuffers|"
   "RWStructuredBuffers, AppendStructuredBuffers, and ConsumeStructuredBuffers|"

+ 5 - 0
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -882,10 +882,15 @@ public:
                        uint32_t operand);
   InstBuilder &binaryOp(spv::Op op, uint32_t result_type, uint32_t result_id,
                         uint32_t lhs, uint32_t rhs);
+  InstBuilder &specConstantBinaryOp(spv::Op op, uint32_t result_type,
+                                    uint32_t result_id, uint32_t lhs,
+                                    uint32_t rhs);
 
   // Methods for building constants.
   InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
                           uint32_t value);
+  InstBuilder &opSpecConstant(uint32_t result_type, uint32_t result_id,
+                              uint32_t value);
 
   // All-in-one method for creating different types of OpImageSample*.
   InstBuilder &

+ 9 - 4
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -151,6 +151,8 @@ public:
   /// the <result-id> for the result.
   uint32_t createBinaryOp(spv::Op op, uint32_t resultType, uint32_t lhs,
                           uint32_t rhs);
+  uint32_t createSpecConstantBinaryOp(spv::Op op, uint32_t resultType,
+                                      uint32_t lhs, uint32_t rhs);
 
   /// \brief Creates an atomic instruction with the given parameters.
   /// Returns the <result-id> for the result.
@@ -357,6 +359,9 @@ public:
   void decorateDSetBinding(uint32_t targetId, uint32_t setNumber,
                            uint32_t bindingNumber);
 
+  /// \brief Decorates the given target <result-id> with the given SpecId.
+  void decorateSpecId(uint32_t targetId, uint32_t specId);
+
   /// \brief Decorates the given target <result-id> with the given input
   /// attchment index number.
   void decorateInputAttachmentIndex(uint32_t targetId, uint32_t indexNumber);
@@ -404,15 +409,15 @@ public:
   uint32_t getSparseResidencyStructType(uint32_t type);
 
   // === Constant ===
-  uint32_t getConstantBool(bool value);
+  uint32_t getConstantBool(bool value, bool isSpecConst = false);
   uint32_t getConstantInt16(int16_t value);
-  uint32_t getConstantInt32(int32_t value);
+  uint32_t getConstantInt32(int32_t value, bool isSpecConst = false);
   uint32_t getConstantInt64(int64_t value);
   uint32_t getConstantUint16(uint16_t value);
-  uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantUint32(uint32_t value, bool isSpecConst = false);
   uint32_t getConstantUint64(uint64_t value);
   uint32_t getConstantFloat16(int16_t value);
-  uint32_t getConstantFloat32(float value);
+  uint32_t getConstantFloat32(float value, bool isSpecConst = false);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantComposite(uint32_t typeId,
                                 llvm::ArrayRef<uint32_t> constituents);

+ 1 - 0
tools/clang/include/clang/Sema/AttributeList.h

@@ -857,6 +857,7 @@ enum AttributeDeclKind {
   ExpectedKernelFunction
   // SPIRV Change Begins
   ,ExpectedField
+  ,ExpectedScalarGlobalVar
   ,ExpectedStructGlobalVar
   ,ExpectedGlobalVarOrCTBuffer
   ,ExpectedCounterStructuredBuffer

+ 6 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -283,7 +283,7 @@ DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
   return nullptr;
 }
 
-SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const ValueDecl *decl,
+SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
                                                   bool checkRegistered) {
   if (const auto *info = getDeclSpirvInfo(decl))
     if (info->indexInCTBuffer >= 0) {
@@ -631,6 +631,11 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
   return nullptr;
 }
 
+void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
+                                              uint32_t specConstant) {
+  astDecls[decl].info.setResultId(specConstant).setRValue().setSpecConstant();
+}
+
 void DeclResultIdMapper::createCounterVar(
     const DeclaratorDecl *decl, bool isAlias,
     const llvm::SmallVector<uint32_t, 4> *indices) {

+ 5 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -369,7 +369,7 @@ public:
   ///
   /// This method will emit a fatal error if checkRegistered is true and the
   /// decl is not registered.
-  SpirvEvalInfo getDeclResultId(const ValueDecl *decl,
+  SpirvEvalInfo getDeclEvalInfo(const ValueDecl *decl,
                                 bool checkRegistered = true);
 
   /// \brief Returns the <result-id> for the given function if already
@@ -377,6 +377,10 @@ public:
   /// returns a newly assigned <result-id> for it.
   uint32_t getOrRegisterFnResultId(const FunctionDecl *fn);
 
+  /// Registers that the given decl should be translated into the given spec
+  /// constant.
+  void registerSpecConstant(const VarDecl *decl, uint32_t specConstant);
+
   /// \brief Returns the associated counter's (<result-id>, is-alias-or-not)
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
   /// If indices is not nullptr, walks trhough the fields of the decl, expected

+ 37 - 0
tools/clang/lib/SPIRV/InstBuilderManual.cpp

@@ -64,6 +64,27 @@ InstBuilder &InstBuilder::binaryOp(spv::Op op, uint32_t result_type,
   return *this;
 }
 
+InstBuilder &InstBuilder::specConstantBinaryOp(spv::Op op, uint32_t result_type,
+                                               uint32_t result_id, uint32_t lhs,
+                                               uint32_t rhs) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(6);
+  TheInst.emplace_back(static_cast<uint32_t>(spv::Op::OpSpecConstantOp));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(lhs);
+  TheInst.emplace_back(rhs);
+
+  return *this;
+}
+
 InstBuilder &InstBuilder::opConstant(uint32_t resultType, uint32_t resultId,
                                      uint32_t value) {
   if (!TheInst.empty()) {
@@ -134,6 +155,22 @@ InstBuilder &InstBuilder::opImageFetchRead(
   return *this;
 }
 
+InstBuilder &InstBuilder::opSpecConstant(uint32_t resultType, uint32_t resultId,
+                                         uint32_t value) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  TheInst.reserve(4);
+  TheInst.emplace_back(static_cast<uint32_t>(spv::Op::OpSpecConstant));
+  TheInst.emplace_back(resultType);
+  TheInst.emplace_back(resultId);
+  TheInst.emplace_back(value);
+
+  return *this;
+}
+
 void InstBuilder::encodeString(std::string value) {
   const auto &words = string::encodeSPIRVString(value);
   TheInst.insert(TheInst.end(), words.begin(), words.end());

+ 54 - 4
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -11,6 +11,7 @@
 
 #include "TypeTranslator.h"
 #include "spirv/unified1//spirv.hpp11"
+#include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "llvm/llvm_assert/assert.h"
 
@@ -237,6 +238,15 @@ uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
   return id;
 }
 
+uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
+                                                   uint32_t resultType,
+                                                   uint32_t lhs, uint32_t rhs) {
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.specConstantBinaryOp(op, resultType, id, lhs, rhs).x();
+  theModule.addVariable(std::move(constructSite));
+  return id;
+}
+
 uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
                                        uint32_t orignalValuePtr,
                                        uint32_t scopeId,
@@ -773,6 +783,11 @@ void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
   theModule.addDecoration(d, targetId);
 }
 
+void ModuleBuilder::decorateSpecId(uint32_t targetId, uint32_t specId) {
+  const Decoration *d = Decoration::getSpecId(theContext, specId);
+  theModule.addDecoration(d, targetId);
+}
+
 void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
   const Decoration *d = nullptr;
   switch (decoration) {
@@ -1068,7 +1083,19 @@ uint32_t ModuleBuilder::getByteAddressBufferType(bool isRW) {
   return typeId;
 }
 
-uint32_t ModuleBuilder::getConstantBool(bool value) {
+uint32_t ModuleBuilder::getConstantBool(bool value, bool isSpecConst) {
+  if (isSpecConst) {
+    const uint32_t constId = theContext.takeNextId();
+    if (value) {
+      instBuilder.opSpecConstantTrue(getBoolType(), constId).x();
+    } else {
+      instBuilder.opSpecConstantFalse(getBoolType(), constId).x();
+    }
+
+    theModule.addVariable(std::move(constructSite));
+    return constId;
+  }
+
   const uint32_t typeId = getBoolType();
   const Constant *constant = value ? Constant::getTrue(theContext, typeId)
                                    : Constant::getFalse(theContext, typeId);
@@ -1089,17 +1116,40 @@ uint32_t ModuleBuilder::getConstantBool(bool value) {
     return constId;                                                            \
   }
 
+#define IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(builderTy, cppTy)                  \
+                                                                               \
+  uint32_t ModuleBuilder::getConstant##builderTy(cppTy value,                  \
+                                                 bool isSpecConst) {           \
+    if (isSpecConst) {                                                         \
+      const uint32_t constId = theContext.takeNextId();                        \
+      instBuilder                                                              \
+          .opSpecConstant(get##builderTy##Type(), constId,                     \
+                          cast::BitwiseCast<uint32_t>(value))                  \
+          .x();                                                                \
+      theModule.addVariable(std::move(constructSite));                         \
+      return constId;                                                          \
+    }                                                                          \
+                                                                               \
+    const uint32_t typeId = get##builderTy##Type();                            \
+    const Constant *constant =                                                 \
+        Constant::get##builderTy(theContext, typeId, value);                   \
+    const uint32_t constId = theContext.getResultIdForConstant(constant);      \
+    theModule.addConstant(constant, constId);                                  \
+    return constId;                                                            \
+  }
+
 IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
-IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Int32, int32_t)
 IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
-IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Uint32, uint32_t)
 IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
-IMPL_GET_PRIMITIVE_CONST(Float32, float)
+IMPL_GET_PRIMITIVE_CONST_SPEC_CONST(Float32, float)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
 IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
 IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
 
 #undef IMPL_GET_PRIMITIVE_CONST
+#undef IMPL_GET_PRIMITIVE_CONST_SPEC_CONST
 
 uint32_t
 ModuleBuilder::getConstantComposite(uint32_t typeId,

+ 168 - 27
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -307,6 +307,65 @@ spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
   return Op::Max;
 }
 
+// Returns true if the given opcode is an accepted binary opcode in
+// OpSpecConstantOp.
+bool isAcceptedSpecConstantBinaryOp(spv::Op op) {
+  switch (op) {
+  case spv::Op::OpIAdd:
+  case spv::Op::OpISub:
+  case spv::Op::OpIMul:
+  case spv::Op::OpUDiv:
+  case spv::Op::OpSDiv:
+  case spv::Op::OpUMod:
+  case spv::Op::OpSRem:
+  case spv::Op::OpSMod:
+  case spv::Op::OpShiftRightLogical:
+  case spv::Op::OpShiftRightArithmetic:
+  case spv::Op::OpShiftLeftLogical:
+  case spv::Op::OpBitwiseOr:
+  case spv::Op::OpBitwiseXor:
+  case spv::Op::OpBitwiseAnd:
+  case spv::Op::OpVectorShuffle:
+  case spv::Op::OpCompositeExtract:
+  case spv::Op::OpCompositeInsert:
+  case spv::Op::OpLogicalOr:
+  case spv::Op::OpLogicalAnd:
+  case spv::Op::OpLogicalNot:
+  case spv::Op::OpLogicalEqual:
+  case spv::Op::OpLogicalNotEqual:
+  case spv::Op::OpIEqual:
+  case spv::Op::OpINotEqual:
+  case spv::Op::OpULessThan:
+  case spv::Op::OpSLessThan:
+  case spv::Op::OpUGreaterThan:
+  case spv::Op::OpSGreaterThan:
+  case spv::Op::OpULessThanEqual:
+  case spv::Op::OpSLessThanEqual:
+  case spv::Op::OpUGreaterThanEqual:
+  case spv::Op::OpSGreaterThanEqual:
+    return true;
+  }
+  return false;
+}
+
+/// Returns true if the given expression is an accepted initializer for a spec
+/// constant.
+bool isAcceptedSpecConstantInit(const Expr *init) {
+  // Allow numeric casts
+  init = init->IgnoreParenCasts();
+
+  if (isa<CXXBoolLiteralExpr>(init) || isa<IntegerLiteral>(init) ||
+      isa<FloatingLiteral>(init))
+    return true;
+
+  // Allow the minus operator which is used to specify negative values
+  if (const auto *unaryOp = dyn_cast<UnaryOperator>(init))
+    return unaryOp->getOpcode() == UO_Minus &&
+           isAcceptedSpecConstantInit(unaryOp->getSubExpr());
+
+  return false;
+}
+
 /// Returns true if the given function parameter can act as shader stage
 /// input parameter.
 inline bool canActAsInParmVar(const ParmVarDecl *param) {
@@ -403,7 +462,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, options),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
-      seenPushConstantAt(), needsLegalization(false) {
+      seenPushConstantAt(), isSpecConstantMode(false),
+      needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
@@ -583,14 +643,14 @@ void SPIRVEmitter::doStmt(const Stmt *stmt,
 
 SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) {
   const auto *decl = expr->getDecl();
-  auto id = declIdMapper.getDeclResultId(decl, false);
+  auto id = declIdMapper.getDeclEvalInfo(decl, false);
 
   if (spirvOptions.ignoreUnusedResources && !id) {
     // First time referencing a Decl inside TranslationUnit. Register
     // into DeclResultIdMapper and emit SPIR-V for it and then query
     // again.
     doDecl(decl);
-    id = declIdMapper.getDeclResultId(decl);
+    id = declIdMapper.getDeclEvalInfo(decl);
   }
 
   return id;
@@ -614,7 +674,8 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
     result = doInitListExpr(initListExpr);
   } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
-    const auto value = theBuilder.getConstantBool(boolLiteral->getValue());
+    const auto value =
+        theBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode);
     result = SpirvEvalInfo(value).setConstant().setRValue();
   } else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
     const auto value = translateAPInt(intLiteral->getValue(), expr->getType());
@@ -800,7 +861,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     // Non-entry functions are added to the work queue following function
     // calls. We have already assigned <result-id>s for it when translating
     // its call site. Query it here.
-    funcId = declIdMapper.getDeclResultId(decl);
+    funcId = declIdMapper.getDeclEvalInfo(decl);
   }
 
   const uint32_t retType =
@@ -1013,6 +1074,12 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   if (!validateVKAttributes(decl))
     return;
 
+  if (decl->hasAttr<VKConstantIdAttr>()) {
+    // This is a VarDecl for specialization constant.
+    createSpecConstant(decl);
+    return;
+  }
+
   if (decl->hasAttr<VKPushConstantAttr>()) {
     // This is a VarDecl for PushConstant block.
     (void)declIdMapper.createPushConstant(decl);
@@ -1808,7 +1875,7 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
   }
 
   // Inherit the SpirvEvalInfo from the function definition
-  return declIdMapper.getDeclResultId(callee).setResultId(retVal);
+  return declIdMapper.getDeclEvalInfo(callee).setResultId(retVal);
 }
 
 SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
@@ -2981,8 +3048,6 @@ SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
           ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
           : getReferencedDef(base);
   return declIdMapper.getCounterIdAliasPair(decl, &indices);
-
-  return nullptr;
 }
 
 const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
@@ -4505,17 +4570,26 @@ SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
   case BO_XorAssign:
   case BO_ShlAssign:
   case BO_ShrAssign: {
+    // To evaluate this expression as an OpSpecConstantOp, we need to make sure
+    // both operands are constant and at least one of them is a spec constant.
+    if (lhsVal.isConstant() && rhsVal.isConstant() &&
+        (lhsVal.isSpecConstant() || rhsVal.isSpecConstant()) &&
+        isAcceptedSpecConstantBinaryOp(spvOp)) {
+      const auto valId = theBuilder.createSpecConstantBinaryOp(
+          spvOp, resultTypeId, lhsVal, rhsVal);
+      return SpirvEvalInfo(valId).setRValue().setSpecConstant();
+    }
+
+    // Normal binary operation
     const auto valId =
         theBuilder.createBinaryOp(spvOp, resultTypeId, lhsVal, rhsVal);
     auto result = SpirvEvalInfo(valId).setRValue();
-    return lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision()
-               ? result.setRelaxedPrecision()
-               : result;
+    if (lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision())
+      result.setRelaxedPrecision();
+    return result;
   }
   case BO_Assign:
     llvm_unreachable("assignment should not be handled here");
-  default:
-    break;
   }
 
   emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
@@ -5098,6 +5172,71 @@ SpirvEvalInfo SPIRVEmitter::processEachVectorInMatrix(
   return SpirvEvalInfo(valId).setRValue();
 }
 
+void SPIRVEmitter::createSpecConstant(const VarDecl *varDecl) {
+  class SpecConstantEnvRAII {
+  public:
+    // Creates a new instance which sets mode to true on creation,
+    // and resets mode to false on destruction.
+    SpecConstantEnvRAII(bool *mode) : modeSlot(mode) { *modeSlot = true; }
+    ~SpecConstantEnvRAII() { *modeSlot = false; }
+
+  private:
+    bool *modeSlot;
+  };
+
+  const QualType varType = varDecl->getType();
+
+  bool hasError = false;
+
+  if (!varDecl->isExternallyVisible()) {
+    emitError("specialization constant must be externally visible",
+              varDecl->getLocation());
+    hasError = true;
+  }
+
+  if (const auto *builtinType = varType->getAs<BuiltinType>()) {
+    switch (builtinType->getKind()) {
+    case BuiltinType::Bool:
+    case BuiltinType::Int:
+    case BuiltinType::UInt:
+    case BuiltinType::Float:
+      break;
+    default:
+      emitError("unsupported specialization constant type",
+                varDecl->getLocStart());
+      hasError = true;
+    }
+  }
+
+  const auto *init = varDecl->getInit();
+
+  if (!init) {
+    emitError("missing default value for specialization constant",
+              varDecl->getLocation());
+    hasError = true;
+  } else if (!isAcceptedSpecConstantInit(init)) {
+    emitError("unsupported specialization constant initializer",
+              init->getLocStart())
+        << init->getSourceRange();
+    hasError = true;
+  }
+
+  if (hasError)
+    return;
+
+  SpecConstantEnvRAII specConstantEnvRAII(&isSpecConstantMode);
+
+  const auto specConstant = doExpr(init);
+
+  // We are not creating a variable to hold the spec constant, instead, we
+  // translate the varDecl directly into the spec constant here.
+
+  theBuilder.decorateSpecId(
+      specConstant, varDecl->getAttr<VKConstantIdAttr>()->getSpecConstId());
+
+  declIdMapper.registerSpecConstant(varDecl, specConstant);
+}
+
 SpirvEvalInfo
 SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                     const BinaryOperatorKind opcode,
@@ -7252,7 +7391,8 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
   TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);
 
   if (targetType->isBooleanType()) {
-    result = theBuilder.getConstantBool(value.getInt().getBoolValue());
+    result = theBuilder.getConstantBool(value.getInt().getBoolValue(),
+                                        isSpecConstantMode);
   } else if (targetType->isIntegerType()) {
     result = translateAPInt(value.getInt(), targetType);
   } else if (targetType->isFloatingType()) {
@@ -7301,10 +7441,10 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
       // If enable16BitTypes option is not true, treat as 32-bit integer.
       if (isSigned)
         return theBuilder.getConstantInt32(
-            static_cast<int32_t>(intValue.getSExtValue()));
+            static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
       else
         return theBuilder.getConstantUint32(
-            static_cast<uint32_t>(intValue.getZExtValue()));
+            static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
     }
   }
   case 32: {
@@ -7317,7 +7457,7 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
         return 0;
       }
       return theBuilder.getConstantInt32(
-          static_cast<int32_t>(intValue.getSExtValue()));
+          static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
     } else {
       if (!intValue.isIntN(32)) {
         emitError("evaluating integer literal %0 as a 32-bit integer loses "
@@ -7327,7 +7467,7 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
         return 0;
       }
       return theBuilder.getConstantUint32(
-          static_cast<uint32_t>(intValue.getZExtValue()));
+          static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
     }
   }
   case 64: {
@@ -7392,7 +7532,8 @@ uint32_t SPIRVEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
   const auto &semantics = floatValue.getSemantics();
   // If the given value is already a 32-bit float, there is no need to convert.
   if (&semantics == &llvm::APFloat::IEEEsingle) {
-    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
+                                         isSpecConstantMode);
   }
 
   // Try to see if this literal float can be represented in 32-bit.
@@ -7437,12 +7578,11 @@ uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
       emitError(
           "evaluating float literal %0 at a lower bitwidth loses information",
           {})
-          << std::to_string(
-                 valueBitwidth == 16
-                     ? static_cast<float>(
-                           originalValue.bitcastToAPInt().getZExtValue())
-                     : valueBitwidth == 32 ? originalValue.convertToFloat()
-                                           : originalValue.convertToDouble());
+          // Converting from 16bit to 32/64-bit won't lose information.
+          // So only 32/64-bit values can reach here.
+          << std::to_string(valueBitwidth == 32
+                                ? originalValue.convertToFloat()
+                                : originalValue.convertToDouble());
       return 0;
     }
   }
@@ -7452,7 +7592,8 @@ uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
     return theBuilder.getConstantFloat16(
         static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
   case 32:
-    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
+                                         isSpecConstantMode);
   case 64:
     return theBuilder.getConstantFloat64(floatValue.convertToDouble());
   default:
@@ -7803,7 +7944,7 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
 
   // Initialize all global variables at the beginning of the wrapper
   for (const VarDecl *varDecl : toInitGloalVars) {
-    const auto varInfo = declIdMapper.getDeclResultId(varDecl);
+    const auto varInfo = declIdMapper.getDeclEvalInfo(varDecl);
     if (const auto *init = varDecl->getInit()) {
       storeValue(varInfo, doExpr(init), varDecl->getType());
 

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -243,6 +243,9 @@ private:
       llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
           actOnEachVector);
 
+  /// Translates the given varDecl into a spec constant.
+  void createSpecConstant(const VarDecl *varDecl);
+
   /// Generates the necessary instructions for conducting the given binary
   /// operation on lhs and rhs.
   ///
@@ -850,6 +853,10 @@ private:
   /// Invalid means no push constant blocks defined thus far.
   SourceLocation seenPushConstantAt;
 
+  /// Indicates whether the current emitter is in specialization constant mode:
+  /// all 32-bit scalar constants will be translated into OpSpecConstant.
+  bool isSpecConstantMode;
+
   /// Whether the translated SPIR-V binary needs legalization.
   ///
   /// The following cases will require legalization:

+ 10 - 0
tools/clang/lib/SPIRV/SpirvEvalInfo.h

@@ -94,6 +94,9 @@ public:
   inline SpirvEvalInfo &setConstant();
   bool isConstant() const { return isConstant_; }
 
+  inline SpirvEvalInfo &setSpecConstant();
+  bool isSpecConstant() const { return isSpecConstant_; }
+
   inline SpirvEvalInfo &setRelaxedPrecision();
   bool isRelaxedPrecision() const { return isRelaxedPrecision_; }
 
@@ -114,6 +117,7 @@ private:
 
   bool isRValue_;
   bool isConstant_;
+  bool isSpecConstant_;
   bool isRelaxedPrecision_;
 };
 
@@ -158,6 +162,12 @@ SpirvEvalInfo &SpirvEvalInfo::setConstant() {
   return *this;
 }
 
+SpirvEvalInfo &SpirvEvalInfo::setSpecConstant() {
+  // Specialization constant is also a kind of constant.
+  isConstant_ = isSpecConstant_ = true;
+  return *this;
+}
+
 SpirvEvalInfo &SpirvEvalInfo::setRelaxedPrecision() {
   isRelaxedPrecision_ = true;
   return *this;

+ 4 - 0
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -10470,6 +10470,10 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
         A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
         A.getAttributeSpellingListIndex());
     break;
+  case AttributeList::AT_VKConstantId:
+    declAttr = ::new (S.Context) VKConstantIdAttr(A.getRange(), S.Context,
+      ValidateAttributeIntArg(S, A), A.getAttributeSpellingListIndex());
+    break;
   default:
     Handled = false;
     return;

+ 29 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.error.hlsl

@@ -0,0 +1,29 @@
+// Run: %dxc -T vs_6_0 -E main
+
+[[vk::constant_id(0)]]
+const bool sc0 = true;
+
+[[vk::constant_id(1)]]
+const bool sc1 = sc0; // error
+
+[[vk::constant_id(2)]]
+const double sc2 = 42; // error
+
+[[vk::constant_id(3)]]
+const int sc3; // error
+
+[[vk::constant_id(4)]]
+const int sc4 = sc3 + sc3; // error
+
+[[vk::constant_id(5)]]
+static const int sc5 = 1; // error
+
+float main() : A {
+    return 1.0;
+}
+
+// CHECK:  :7:18: error: unsupported specialization constant initializer
+// CHECK:  :10:1: error: unsupported specialization constant type
+// CHECK: :13:11: error: missing default value for specialization constant
+// CHECK: :16:17: error: unsupported specialization constant initializer
+// CHECK: :19:18: error: specialization constant must be externally visible

+ 63 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.init.hlsl

@@ -0,0 +1,63 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpDecorate [[b0:%\d+]] SpecId 0
+// CHECK: OpDecorate [[b1:%\d+]] SpecId 1
+// CHECK: OpDecorate [[b2:%\d+]] SpecId 2
+
+// CHECK: OpDecorate [[i0:%\d+]] SpecId 10
+// CHECK: OpDecorate [[i1:%\d+]] SpecId 11
+// CHECK: OpDecorate [[i2:%\d+]] SpecId 12
+// CHECK: OpDecorate [[i3:%\d+]] SpecId 13
+
+// CHECK: OpDecorate [[u0:%\d+]] SpecId 20
+
+// CHECK: OpDecorate [[f0:%\d+]] SpecId 30
+// CHECK: OpDecorate [[f1:%\d+]] SpecId 31
+// CHECK: OpDecorate [[f2:%\d+]] SpecId 32
+// CHECK: OpDecorate [[f3:%\d+]] SpecId 33
+
+// CHECK: [[b0]] = OpSpecConstantTrue %bool
+[[vk::constant_id(0)]]
+bool b0 = true;
+// CHECK: [[b1]] = OpSpecConstantFalse %bool
+[[vk::constant_id(1)]]
+bool b1 = 0;
+// CHECK: [[b2]] = OpSpecConstantTrue %bool
+[[vk::constant_id(2)]]
+bool b2 = 1.5;
+
+
+// CHECK:  [[i0]] = OpSpecConstant %int 42
+[[vk::constant_id(10)]]
+int i0 = 42;
+// CHECK:  [[i1]] = OpSpecConstant %int -42
+[[vk::constant_id(11)]]
+int i1 = -42;
+// CHECK:  [[i2]] = OpSpecConstant %int 1
+[[vk::constant_id(12)]]
+int i2 = (true);
+// CHECK:  [[i3]] = OpSpecConstant %int 2
+[[vk::constant_id(13)]]
+int i3 = 2.5;
+
+// CHECK: [[u0]] = OpSpecConstant %uint 56
+[[vk::constant_id(20)]]
+uint uintConst1 = 56;
+
+// CHECK: [[f0]] = OpSpecConstant %float 4.2
+[[vk::constant_id(30)]]
+float f0 = (4.2);
+// CHECK: [[f1]] = OpSpecConstant %float -4.2
+[[vk::constant_id(31)]]
+float f1 = -4.2;
+// CHECK: [[f2]] = OpSpecConstant %float 1
+[[vk::constant_id(32)]]
+float f2 = true;
+// CHECK: [[f3]] = OpSpecConstant %float 20
+[[vk::constant_id(33)]]
+float f3 = 20;
+
+
+float main() : A {
+    return 1.0;
+}

+ 24 - 0
tools/clang/test/CodeGenSPIRV/vk.spec-constant.usage.hlsl

@@ -0,0 +1,24 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: OpDecorate [[sc:%\d+]] SpecId 10
+[[vk::constant_id(10)]]
+// CHECK: [[sc]] = OpSpecConstant %int 12
+const int specConst = 12;
+
+// TODO: The frontend parsing hits assertion failures saying cannot evaluating
+// as constant int for the following usages.
+/*
+cbuffer Data {
+    float4 pos[specConst];
+    float4 tex[specConst + 5];
+};
+*/
+
+// CHECK: [[add:%\d+]] = OpSpecConstantOp %int IAdd [[sc]] %int_3
+static const int val = specConst + 3;
+
+// CHECK-LABEL:  %main = OpFunction
+// CHECK:                OpStore %val [[add]]
+void main() {
+
+}

+ 10 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1116,6 +1116,16 @@ TEST_F(FileTest, VulkanMultiplePushConstant) {
   runFileTest("vk.push-constant.multiple.hlsl", Expect::Failure);
 }
 
+TEST_F(FileTest, VulkanSpecConstantInit) {
+  runFileTest("vk.spec-constant.init.hlsl");
+}
+TEST_F(FileTest, VulkanSpecConstantUsage) {
+  runFileTest("vk.spec-constant.usage.hlsl");
+}
+TEST_F(FileTest, VulkanSpecConstantError) {
+  runFileTest("vk.spec-constant.error.hlsl", Expect::Failure);
+}
+
 TEST_F(FileTest, VulkanLayoutCBufferStd140) {
   runFileTest("vk.layout.cbuffer.std140.hlsl");
 }