Pārlūkot izejas kodu

[spirv] 16-bit and 64-bit int, uint, and float. (#966)

* [spirv] 16-bit and 64-bit int, uint, and float.

* Added Int64, Uint64, Int16, Uint16.
* Added 16-bit float constants.
* Get the -enable-16bit-types cmd option.
* Add tests for constant Int64/Uint64/Int16/etc.
Ehsan 7 gadi atpakaļ
vecāks
revīzija
eeab612da9
28 mainītis faili ar 546 papildinājumiem un 174 dzēšanām
  1. 35 25
      docs/SPIR-V.rst
  2. 10 0
      tools/clang/include/clang/SPIRV/Constant.h
  3. 1 0
      tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h
  4. 8 0
      tools/clang/include/clang/SPIRV/ModuleBuilder.h
  5. 88 0
      tools/clang/lib/SPIRV/Constant.cpp
  6. 1 1
      tools/clang/lib/SPIRV/DeclResultIdMapper.h
  7. 15 6
      tools/clang/lib/SPIRV/ModuleBuilder.cpp
  8. 63 38
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  9. 73 21
      tools/clang/lib/SPIRV/TypeTranslator.cpp
  10. 34 7
      tools/clang/lib/SPIRV/TypeTranslator.h
  11. 25 23
      tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv
  12. 1 1
      tools/clang/test/CodeGenSPIRV/binary-op.assign.composite.hlsl
  13. 46 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.disabled.hlsl
  14. 70 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.hlsl
  15. 29 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.64bit.hlsl
  16. 17 33
      tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl
  17. 1 1
      tools/clang/test/CodeGenSPIRV/cs.groupshared.hlsl
  18. 4 4
      tools/clang/test/CodeGenSPIRV/op.array.access.hlsl
  19. 1 1
      tools/clang/test/CodeGenSPIRV/op.cbuffer.access.hlsl
  20. 1 1
      tools/clang/test/CodeGenSPIRV/op.constant-buffer.access.hlsl
  21. 2 2
      tools/clang/test/CodeGenSPIRV/op.rw-structured-buffer.access.hlsl
  22. 2 2
      tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.hlsl
  23. 1 1
      tools/clang/test/CodeGenSPIRV/op.tbuffer.access.hlsl
  24. 1 1
      tools/clang/test/CodeGenSPIRV/op.texture-buffer.access.hlsl
  25. 4 4
      tools/clang/test/CodeGenSPIRV/var.init.array.hlsl
  26. 1 1
      tools/clang/test/CodeGenSPIRV/vk.push-constant.hlsl
  27. 1 1
      tools/clang/tools/dxcompiler/dxcompilerobj.cpp
  28. 11 0
      tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

+ 35 - 25
docs/SPIR-V.rst

@@ -243,24 +243,26 @@ Normal scalar types
 in HLSL are relatively easy to handle and can be mapped directly to SPIR-V
 type instructions:
 
-================== ================== =========== ====================
-      HLSL               SPIR-V       Capability       Decoration
-================== ================== =========== ====================
-``bool``           ``OpTypeBool``
-``int``            ``OpTypeInt 32 1``
-``uint``/``dword`` ``OpTypeInt 32 0``
-``half``           ``OpTypeFloat 32``             ``RelexedPrecision``
-``float``          ``OpTypeFloat 32``
-``snorm float``    ``OpTypeFloat 32``
-``unorm float``    ``OpTypeFloat 32``
-``double``         ``OpTypeFloat 64`` ``Float64``
-================== ================== =========== ====================
+============================== ======================= ================== =========== =================================
+      HLSL                      Command Line Option           SPIR-V       Capability       Extension
+============================== ======================= ================== =========== =================================
+``bool``                                               ``OpTypeBool``
+``int``/``int32_t``                                    ``OpTypeInt 32 1``
+``int16_t``                    ``-enable-16bit-types`` ``OpTypeInt 16 1`` ``Int16``
+``uint``/``dword``/``uin32_t``                         ``OpTypeInt 32 0``
+``uint16_t``                   ``-enable-16bit-types`` ``OpTypeInt 16 0`` ``Int16``
+``half``                                               ``OpTypeFloat 32`` 
+``half``/``float16_t``         ``-enable-16bit-types`` ``OpTypeFloat 16`` ``Float16`` ``SPV_AMD_gpu_shader_half_float``
+``float``/``float32_t``                                ``OpTypeFloat 32``
+``snorm float``                                        ``OpTypeFloat 32``
+``unorm float``                                        ``OpTypeFloat 32``
+``double``/``float64_t``                               ``OpTypeFloat 64`` ``Float64``
+============================== ======================= ================== =========== =================================
 
 Please note that ``half`` is translated into 32-bit floating point numbers
 right now because MSDN says that "this data type is provided only for language
 compatibility. Direct3D 10 shader targets map all ``half`` data types to
-``float`` data types." This may change in the future to map to 16-bit floating
-point numbers (possibly via a command-line option).
+``float`` data types."
 
 Minimal precision scalar types
 ------------------------------
@@ -270,17 +272,25 @@ HLSL also supports various
 which graphics drivers can implement by using any precision greater than or
 equal to their specified bit precision.
 There are no direct mappings in SPIR-V for these types. We translate them into
-the corresponding 32-bit scalar types with the ``RelexedPrecision`` decoration:
-
-============== ================== ====================
-    HLSL            SPIR-V            Decoration
-============== ================== ====================
-``min16float`` ``OpTypeFloat 32`` ``RelexedPrecision``
-``min10float`` ``OpTypeFloat 32`` ``RelexedPrecision``
-``min16int``   ``OpTypeInt 32 1`` ``RelexedPrecision``
-``min12int``   ``OpTypeInt 32 1`` ``RelexedPrecision``
-``min16uint``  ``OpTypeInt 32 0`` ``RelexedPrecision``
-============== ================== ====================
+the corresponding 16-bit or 32-bit scalar types with the ``RelaxedPrecision`` decoration.
+We use the 16-bit variants if '-enable-16bit-types' command line option is present.
+For more information on these types, please refer to:
+https://github.com/Microsoft/DirectXShaderCompiler/wiki/16-Bit-Scalar-Types
+
+============== ======================= ================== ==================== ============ =================================
+    HLSL        Command Line Option          SPIR-V            Decoration       Capability        Extension
+============== ======================= ================== ==================== ============ =================================
+``min16float``                         ``OpTypeFloat 32`` ``RelaxedPrecision``
+``min10float``                         ``OpTypeFloat 32`` ``RelaxedPrecision``
+``min16int``                           ``OpTypeInt 32 1`` ``RelaxedPrecision``
+``min12int``                           ``OpTypeInt 32 1`` ``RelaxedPrecision``
+``min16uint``                          ``OpTypeInt 32 0`` ``RelaxedPrecision``
+``min16float`` ``-enable-16bit-types`` ``OpTypeFloat 16``                      ``Float16``  ``SPV_AMD_gpu_shader_half_float``
+``min10float`` ``-enable-16bit-types`` ``OpTypeFloat 16``                      ``Float16``  ``SPV_AMD_gpu_shader_half_float``
+``min16int``   ``-enable-16bit-types`` ``OpTypeInt 16 1``                      ``Int16``
+``min12int``   ``-enable-16bit-types`` ``OpTypeInt 16 1``                      ``Int16``
+``min16uint``  ``-enable-16bit-types`` ``OpTypeInt 16 0``                      ``Int16``
+============== ======================= ================== ==================== ============ =================================
 
 Vectors and matrices
 --------------------

+ 10 - 0
tools/clang/include/clang/SPIRV/Constant.h

@@ -61,10 +61,20 @@ public:
                                  DecorationSet dec = {});
   static const Constant *getFalse(SPIRVContext &ctx, uint32_t type_id,
                                   DecorationSet dec = {});
+  static const Constant *getInt16(SPIRVContext &ctx, uint32_t type_id,
+                                  int16_t value, DecorationSet dec = {});
   static const Constant *getInt32(SPIRVContext &ctx, uint32_t type_id,
                                   int32_t value, DecorationSet dec = {});
+  static const Constant *getInt64(SPIRVContext &ctx, uint32_t type_id,
+                                  int64_t value, DecorationSet dec = {});
+  static const Constant *getUint16(SPIRVContext &ctx, uint32_t type_id,
+                                   uint16_t value, DecorationSet dec = {});
   static const Constant *getUint32(SPIRVContext &ctx, uint32_t type_id,
                                    uint32_t value, DecorationSet dec = {});
+  static const Constant *getUint64(SPIRVContext &ctx, uint32_t type_id,
+                                   uint64_t value, DecorationSet dec = {});
+  static const Constant *getFloat16(SPIRVContext &ctx, uint32_t type_id,
+                                    int16_t value, DecorationSet dec = {});
   static const Constant *getFloat32(SPIRVContext &ctx, uint32_t type_id,
                                     float value, DecorationSet dec = {});
   static const Constant *getFloat64(SPIRVContext &ctx, uint32_t type_id,

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h

@@ -19,6 +19,7 @@ struct EmitSPIRVOptions {
   bool codeGenHighLevel;
   bool disableValidation;
   bool ignoreUnusedResources;
+  bool enable16BitTypes;
   llvm::StringRef stageIoOrder;
   llvm::SmallVector<uint32_t, 4> bShift;
   llvm::SmallVector<uint32_t, 4> tShift;

+ 8 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -358,10 +358,13 @@ public:
 
   uint32_t getVoidType();
   uint32_t getBoolType();
+  uint32_t getInt16Type();
   uint32_t getInt32Type();
   uint32_t getInt64Type();
+  uint32_t getUint16Type();
   uint32_t getUint32Type();
   uint32_t getUint64Type();
+  uint32_t getFloat16Type();
   uint32_t getFloat32Type();
   uint32_t getFloat64Type();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
@@ -391,8 +394,13 @@ public:
 
   // === Constant ===
   uint32_t getConstantBool(bool value);
+  uint32_t getConstantInt16(int16_t value);
   uint32_t getConstantInt32(int32_t value);
+  uint32_t getConstantInt64(int64_t value);
+  uint32_t getConstantUint16(uint16_t value);
   uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantUint64(uint64_t value);
+  uint32_t getConstantFloat16(int16_t value);
   uint32_t getConstantFloat32(float value);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantComposite(uint32_t typeId,

+ 88 - 0
tools/clang/lib/SPIRV/Constant.cpp

@@ -11,6 +11,37 @@
 #include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SPIRVContext.h"
 
+namespace {
+uint32_t zeroExtendTo32Bits(uint16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    uint16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+
+uint32_t signExtendTo32Bits(int16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    int16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+
+  // Sign bit is 1
+  if (value >> 15) {
+    result.high = 0xffff;
+  }
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+}
+
 namespace clang {
 namespace spirv {
 
@@ -37,6 +68,17 @@ const Constant *Constant::getFalse(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
 }
 
+const Constant *Constant::getFloat16(SPIRVContext &ctx, uint32_t type_id,
+                                     int16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be 0 for a
+  // floating-point type.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {zeroExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getFloat32(SPIRVContext &ctx, uint32_t type_id,
                                      float value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id,
@@ -58,12 +100,46 @@ const Constant *Constant::getFloat64(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
 }
 
+const Constant *Constant::getUint16(SPIRVContext &ctx, uint32_t type_id,
+                                    uint16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be 0 for an
+  // integer type with Signedness of 0.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {zeroExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getUint32(SPIRVContext &ctx, uint32_t type_id,
                                     uint32_t value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id, {value}, dec);
   return getUniqueConstant(ctx, c);
 }
 
+const Constant *Constant::getUint64(SPIRVContext &ctx, uint32_t type_id,
+                                    uint64_t value, DecorationSet dec) {
+  struct wideInt {
+    uint32_t word0;
+    uint32_t word1;
+  };
+  wideInt words = cast::BitwiseCast<wideInt, uint64_t>(value);
+  Constant c =
+      Constant(spv::Op::OpConstant, type_id, {words.word0, words.word1}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getInt16(SPIRVContext &ctx, uint32_t type_id,
+                                    int16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be
+  // sign-extended for integers with Signedness of 1.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {signExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
                                    int32_t value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id,
@@ -71,6 +147,18 @@ const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
 }
 
+const Constant *Constant::getInt64(SPIRVContext &ctx, uint32_t type_id,
+                                   int64_t value, DecorationSet dec) {
+  struct wideInt {
+    uint32_t word0;
+    uint32_t word1;
+  };
+  wideInt words = cast::BitwiseCast<wideInt, int64_t>(value);
+  Constant c =
+      Constant(spv::Op::OpConstant, type_id, {words.word0, words.word1}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getComposite(SPIRVContext &ctx, uint32_t type_id,
                                        llvm::ArrayRef<uint32_t> constituents,
                                        DecorationSet dec) {

+ 1 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -597,7 +597,7 @@ DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        const EmitSPIRVOptions &options)
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
-      typeTranslator(context, builder, diags), entryFunctionId(0),
+      typeTranslator(context, builder, diags, options), entryFunctionId(0),
       needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator) {}
 

+ 15 - 6
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -791,19 +791,23 @@ IMPL_GET_PRIMITIVE_TYPE(Float32)
 #undef IMPL_GET_PRIMITIVE_TYPE
 
 #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap)                       \
-  \
-uint32_t ModuleBuilder::get##ty##Type() {                                      \
-    requireCapability(spv::Capability::cap);                                    \
+                                                                               \
+  uint32_t ModuleBuilder::get##ty##Type() {                                    \
+    requireCapability(spv::Capability::cap);                                   \
+    if (spv::Capability::cap == spv::Capability::Float16)                      \
+      theModule.addExtension("SPV_AMD_gpu_shader_half_float");                 \
     const Type *type = Type::get##ty(theContext);                              \
     const uint32_t typeId = theContext.getResultIdForType(type);               \
     theModule.addType(type, typeId);                                           \
     return typeId;                                                             \
-  \
-}
+  }
 
-IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
 
 #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
 
@@ -1056,10 +1060,15 @@ uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
   \
 }
 
+IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
+IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
 IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Float32, float)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
+IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
+IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
 
 #undef IMPL_GET_PRIMITIVE_VALUE
 

+ 63 - 38
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -379,19 +379,6 @@ const ValueDecl *getReferencedDef(const Expr *expr) {
   return nullptr;
 }
 
-bool isLiteralType(QualType type) {
-  if (type->isSpecificBuiltinType(BuiltinType::LitInt) ||
-      type->isSpecificBuiltinType(BuiltinType::LitFloat))
-    return true;
-
-  // For cases such as 'vector<literal int, 2>'
-  QualType elemType = {};
-  if (TypeTranslator::isVectorType(type, &elemType))
-    return isLiteralType(elemType);
-
-  return false;
-}
-
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -403,9 +390,9 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
           ci.getCodeGenOpts().HLSLProfile.c_str())),
       theContext(), theBuilder(&theContext),
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
-      typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
-      curFunction(nullptr), curThis(0), seenPushConstantAt(),
-      needsLegalization(false) {
+      typeTranslator(astContext, theBuilder, diags, options),
+      entryFunctionId(0), curFunction(nullptr), curThis(0),
+      seenPushConstantAt(), needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
 }
@@ -597,9 +584,9 @@ SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) {
 SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   SpirvEvalInfo result(/*id*/ 0);
 
-  const bool isNonLiteralType = !isLiteralType(expr->getType());
-  if (isNonLiteralType)
-    typeTranslator.pushIntendedLiteralType(expr->getType());
+  // Provide a hint to the typeTranslator that if a literal is discovered, its
+  // intended usage is as this expression type.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, expr->getType());
 
   expr = expr->IgnoreParens();
 
@@ -650,9 +637,6 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
         << expr->getStmtClassName() << expr->getSourceRange();
   }
 
-  if (isNonLiteralType)
-    typeTranslator.popIntendedLiteralType();
-
   return result;
 }
 
@@ -1000,7 +984,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       needsLegalization = true;
   }
 
-  if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
+  if (TypeTranslator::isRelaxedPrecisionType(decl->getType(), spirvOptions)) {
     theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
   }
 
@@ -1551,6 +1535,10 @@ void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
 
 SpirvEvalInfo
 SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
+  // Provide a hint to the TypeTranslator that the integer literal used to
+  // index into the array should be translated as a 32-bit integer.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
+
   llvm::SmallVector<uint32_t, 4> indices;
   auto info = doExpr(collectArrayStructIndices(expr, &indices));
 
@@ -6656,10 +6644,10 @@ uint32_t SPIRVEmitter::getMatElemValueOne(QualType type) {
 uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
                                         const QualType targetType) {
   uint32_t result = 0;
-  const bool isNonLiteralType = !isLiteralType(targetType);
 
-  if (isNonLiteralType)
-    typeTranslator.pushIntendedLiteralType(targetType);
+  // Provide a hint to the typeTranslator that if a literal is discovered, its
+  // intended usage is targetType.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);
 
   if (targetType->isBooleanType()) {
     result = theBuilder.getConstantBool(value.getInt().getBoolValue());
@@ -6684,11 +6672,8 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
     }
   }
 
-  if (result) {
-    if (isNonLiteralType)
-      typeTranslator.popIntendedLiteralType();
+  if (result)
     return result;
-  }
 
   emitError("APValue of type %0 unimplemented", {}) << value.getKind();
   value.dump();
@@ -6698,18 +6683,42 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
 uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                                       QualType targetType) {
   targetType = typeTranslator.getIntendedLiteralType(targetType);
-  if (targetType->isSignedIntegerType()) {
-    // Try to see if this integer can be represented in 32-bit.
-    if (intValue.isSignedIntN(32)) {
+  const auto targetTypeBitWidth = astContext.getTypeSize(targetType);
+  const bool isSigned = targetType->isSignedIntegerType();
+  switch (targetTypeBitWidth) {
+  case 16: {
+    if (spirvOptions.enable16BitTypes) {
+      if (isSigned) {
+        return theBuilder.getConstantInt16(
+            static_cast<int16_t>(intValue.getSExtValue()));
+      } else {
+        return theBuilder.getConstantUint16(
+            static_cast<uint16_t>(intValue.getZExtValue()));
+      }
+    } else {
+      // If enable16BitTypes option is not true, treat as 32-bit integer.
+      if (isSigned)
+        return theBuilder.getConstantInt32(
+            static_cast<int32_t>(intValue.getSExtValue()));
+      else
+        return theBuilder.getConstantUint32(
+            static_cast<uint32_t>(intValue.getZExtValue()));
+    }
+  }
+  case 32: {
+    if (isSigned)
       return theBuilder.getConstantInt32(
           static_cast<int32_t>(intValue.getSExtValue()));
-    }
-  } else {
-    // Try to see if this integer can be represented in 32-bit.
-    if (intValue.isIntN(32)) {
+    else
       return theBuilder.getConstantUint32(
           static_cast<uint32_t>(intValue.getZExtValue()));
-    }
+  }
+  case 64: {
+    if (isSigned)
+      return theBuilder.getConstantInt64(intValue.getSExtValue());
+    else
+      return theBuilder.getConstantUint64(intValue.getZExtValue());
+  }
   }
 
   emitError("APInt for target bitwidth %0 unimplemented", {})
@@ -6761,6 +6770,22 @@ uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
   const auto &semantics = astContext.getFloatTypeSemantics(targetType);
   const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
   switch (bitwidth) {
+  case 16: {
+    if (spirvOptions.enable16BitTypes) {
+      return theBuilder.getConstantFloat16(
+          static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
+    } else {
+      // If 16-bit types are not enabled, treat as 32-bit float.
+      llvm::APFloat f32 = floatValue;
+      bool losesInfo = false;
+      f32.convert(llvm::APFloat::IEEEsingle,
+                  llvm::APFloat::roundingMode::rmTowardZero, &losesInfo);
+      // Conversion from 16-bit float value to 32-bit float value should be
+      // loss-less.
+      assert(!losesInfo);
+      return theBuilder.getConstantFloat32(f32.convertToFloat());
+    }
+  }
   case 32:
     return theBuilder.getConstantFloat32(floatValue.convertToFloat());
   case 64:

+ 73 - 21
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -34,19 +34,29 @@ inline void roundToPow2(uint32_t *val, uint32_t pow2) {
 }
 } // anonymous namespace
 
-bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
+bool TypeTranslator::isRelaxedPrecisionType(QualType type,
+                                            const EmitSPIRVOptions &opts) {
   // Primitive types
   {
     QualType ty = {};
     if (isScalarType(type, &ty))
       if (const auto *builtinType = ty->getAs<BuiltinType>())
         switch (builtinType->getKind()) {
+        // TODO: Figure out why 'min16float' and 'half' share an enum.
+        // 'half' should not get RelaxedPrecision decoration, but due to the
+        // shared enum, we currently do so.
+        case BuiltinType::Half:
         case BuiltinType::Short:
         case BuiltinType::UShort:
         case BuiltinType::Min12Int:
-        case BuiltinType::Min10Float:
-        case BuiltinType::Half:
-          return true;
+        case BuiltinType::Min10Float: {
+          // If '-enable-16bit-types' options is enabled, these types are
+          // translated to real 16-bit type, and therefore are not
+          // RelaxedPrecision.
+          // If the options is not enabled, these types are translated to 32-bit
+          // types with the added RelaxedPrecision decoration.
+          return !opts.enable16BitTypes;
+        }
         }
   }
 
@@ -55,7 +65,7 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
   {
     QualType elemType = {};
     if (isVectorType(type, &elemType) || isMxNMatrix(type, &elemType))
-      return isRelaxedPrecisionType(elemType);
+      return isRelaxedPrecisionType(elemType, opts);
   }
 
   return false;
@@ -108,6 +118,29 @@ bool TypeTranslator::isOpaqueStructType(QualType type) {
   return false;
 }
 
+TypeTranslator::LiteralTypeHint::LiteralTypeHint(TypeTranslator &t, QualType ty)
+    : translator(t), type(ty) {
+  if (!isLiteralType(type))
+    translator.pushIntendedLiteralType(type);
+}
+TypeTranslator::LiteralTypeHint::~LiteralTypeHint() {
+  if (!isLiteralType(type))
+    translator.popIntendedLiteralType();
+}
+
+bool TypeTranslator::LiteralTypeHint::isLiteralType(QualType type) {
+  if (type->isSpecificBuiltinType(BuiltinType::LitInt) ||
+      type->isSpecificBuiltinType(BuiltinType::LitFloat))
+    return true;
+
+  // For cases such as 'vector<literal int, 2>'
+  QualType elemType = {};
+  if (isVectorType(type, &elemType))
+    return isLiteralType(elemType);
+
+  return false;
+}
+
 void TypeTranslator::pushIntendedLiteralType(QualType type) {
   QualType elemType = {};
   if (isVectorType(type, &elemType)) {
@@ -129,8 +162,8 @@ QualType TypeTranslator::getIntendedLiteralType(QualType type) {
 }
 
 void TypeTranslator::popIntendedLiteralType() {
-  if (!intendedLiteralTypes.empty())
-    intendedLiteralTypes.pop();
+  assert(!intendedLiteralTypes.empty());
+  intendedLiteralTypes.pop();
 }
 
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
@@ -154,29 +187,48 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
           return theBuilder.getVoidType();
         case BuiltinType::Bool:
           return theBuilder.getBoolType();
-          // int, min16int (short), and min12int are all translated to 32-bit
-          // signed integers in SPIR-V.
         case BuiltinType::Int:
-        case BuiltinType::Short:
-        case BuiltinType::Min12Int:
           return theBuilder.getInt32Type();
-          // uint and min16uint (ushort) are both translated to 32-bit unsigned
-          // integers in SPIR-V.
-        case BuiltinType::UShort:
         case BuiltinType::UInt:
           return theBuilder.getUint32Type();
+        case BuiltinType::Float:
+          return theBuilder.getFloat32Type();
+        case BuiltinType::Double:
+          return theBuilder.getFloat64Type();
         case BuiltinType::LongLong:
           return theBuilder.getInt64Type();
         case BuiltinType::ULongLong:
           return theBuilder.getUint64Type();
-          // float, min16float (half), and min10float are all translated to
-          // 32-bit float in SPIR-V.
-        case BuiltinType::Float:
+        // min16int (short), and min12int are treated as 16-bit Int if
+        // '-enable-16bit-types' option is enabled. They are treated as 32-bit
+        // Int otherwise.
+        case BuiltinType::Short:
+        case BuiltinType::Min12Int: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getInt16Type();
+          else
+            return theBuilder.getInt32Type();
+        }
+        // min16uint (ushort) is treated as 16-bit Uint if '-enable-16bit-types'
+        // option is enabled. It is treated as 32-bit Uint otherwise.
+        case BuiltinType::UShort: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getUint16Type();
+          else
+            return theBuilder.getUint32Type();
+        }
+        // min16float (half), and min10float are all translated to
+        // 32-bit float in SPIR-V.
+        // min16float (half), and min10float are treated as 16-bit float if
+        // '-enable-16bit-types' option is enabled. They are treated as 32-bit
+        // float otherwise.
         case BuiltinType::Half:
-        case BuiltinType::Min10Float:
-          return theBuilder.getFloat32Type();
-        case BuiltinType::Double:
-          return theBuilder.getFloat64Type();
+        case BuiltinType::Min10Float: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getFloat16Type();
+          else
+            return theBuilder.getFloat32Type();
+        }
         case BuiltinType::LitFloat: {
           // First try to see if there are any hints about how this literal type
           // is going to be used. If so, use the hint.

+ 34 - 7
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -14,6 +14,7 @@
 
 #include "clang/AST/Type.h"
 #include "clang/Basic/Diagnostic.h"
+#include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 
 #include "SpirvEvalInfo.h"
@@ -31,8 +32,14 @@ namespace spirv {
 class TypeTranslator {
 public:
   TypeTranslator(ASTContext &context, ModuleBuilder &builder,
-                 DiagnosticsEngine &diag)
-      : astContext(context), theBuilder(builder), diags(diag) {}
+                 DiagnosticsEngine &diag, const EmitSPIRVOptions &opts)
+      : astContext(context), theBuilder(builder), diags(diag),
+        spirvOptions(opts) {}
+
+  ~TypeTranslator() {
+    // Perform any sanity checks.
+    assert(intendedLiteralTypes.empty());
+  }
 
   /// \brief Generates the corresponding SPIR-V type for the given Clang
   /// frontend type and returns the type's <result-id>. On failure, reports
@@ -154,7 +161,7 @@ public:
   /// \brief Returns true if the given type can use relaxed precision
   /// decoration. Integer and float types with lower than 32 bits can be
   /// operated on with a relaxed precision.
-  static bool isRelaxedPrecisionType(QualType);
+  static bool isRelaxedPrecisionType(QualType, const EmitSPIRVOptions &);
 
   /// Returns true if the given type will be translated into a SPIR-V image,
   /// sampler or struct containing images or samplers.
@@ -231,16 +238,35 @@ public:
                                                     uint32_t *stride);
 
 public:
-  /// \brief Adds the given type to the intendedLiteralTypes stack. This will be
-  /// used as a hint regarding usage of literal types.
-  void pushIntendedLiteralType(QualType type);
-
   /// \brief If a hint exists regarding the usage of literal types, it
   /// is returned. Otherwise, the given type itself is returned.
   /// The hint is the type on top of the intendedLiteralTypes stack. This is the
   /// type we suspect the literal under question should be interpreted as.
   QualType getIntendedLiteralType(QualType type);
 
+public:
+  // A RAII class for maintaining the intendedLiteralTypes stack.
+  // Instantiating an object of this class ensures that as long as the
+  // object lives, the hint lives in the TypeTranslator, and once the object is
+  // destroyed, the hint is automatically removed from the stack.
+  class LiteralTypeHint {
+  public:
+    LiteralTypeHint(TypeTranslator &t, QualType ty);
+    ~LiteralTypeHint();
+
+  private:
+    static bool isLiteralType(QualType type);
+
+  private:
+    QualType type;
+    TypeTranslator &translator;
+  };
+
+private:
+  /// \brief Adds the given type to the intendedLiteralTypes stack. This will be
+  /// used as a hint regarding usage of literal types.
+  void pushIntendedLiteralType(QualType type);
+
   /// \brief Removes the type at the top of the intendedLiteralTypes stack.
   void popIntendedLiteralType();
 
@@ -248,6 +274,7 @@ private:
   ASTContext &astContext;
   ModuleBuilder &theBuilder;
   DiagnosticsEngine &diags;
+  const EmitSPIRVOptions &spirvOptions;
 
   /// \brief This is a stack which is used to track the intended usage type for
   /// literals. For example: while a floating literal is being visited, if the

+ 25 - 23
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -176,7 +176,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %95 = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint
 // %_ptr_Function_HS_CONSTANT_DATA_OUTPUT = OpTypePointer Function %HS_CONSTANT_DATA_OUTPUT
 // %_ptr_Function_float = OpTypePointer Function %float
-// %118 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
+// %120 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
@@ -184,10 +184,12 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %float_1 = OpConstant %float 1
 // %int_0 = OpConstant %int 0
 // %float_2 = OpConstant %float 2
+// %int_1 = OpConstant %int 1
 // %float_3 = OpConstant %float 3
+// %int_2 = OpConstant %int 2
 // %float_4 = OpConstant %float 4
+// %int_3 = OpConstant %int 3
 // %float_5 = OpConstant %float 5
-// %int_1 = OpConstant %int 1
 // %float_6 = OpConstant %float 6
 // %gl_PerVertexIn = OpVariable %_ptr_Input__arr_type_gl_PerVertex_uint_3 Input
 // %gl_PerVertexOut = OpVariable %_ptr_Output__arr_type_gl_PerVertex_uint_3 Output
@@ -263,32 +265,32 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %PatchID = OpFunctionParameter %_ptr_Function_uint
 // %bb_entry = OpLabel
 // %Output = OpVariable %_ptr_Function_HS_CONSTANT_DATA_OUTPUT Function
-// %105 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_0
+// %105 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_0
 // OpStore %105 %float_1
-// %107 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_1
-// OpStore %107 %float_2
-// %109 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_2
-// OpStore %109 %float_3
-// %111 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_3
-// OpStore %111 %float_4
-// %114 = OpAccessChain %_ptr_Function_float %Output %int_1 %uint_0
-// OpStore %114 %float_5
-// %116 = OpAccessChain %_ptr_Function_float %Output %int_1 %uint_1
-// OpStore %116 %float_6
-// %117 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
-// OpReturnValue %117
+// %108 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_1
+// OpStore %108 %float_2
+// %111 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_2
+// OpStore %111 %float_3
+// %114 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_3
+// OpStore %114 %float_4
+// %116 = OpAccessChain %_ptr_Function_float %Output %int_1 %int_0
+// OpStore %116 %float_5
+// %118 = OpAccessChain %_ptr_Function_float %Output %int_1 %int_1
+// OpStore %118 %float_6
+// %119 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
+// OpReturnValue %119
 // OpFunctionEnd
-// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %118
+// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %120
 // %ip_0 = OpFunctionParameter %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3
 // %cpid = OpFunctionParameter %_ptr_Function_uint
 // %PatchID_0 = OpFunctionParameter %_ptr_Function_uint
 // %bb_entry_0 = OpLabel
 // %vsOutput = OpVariable %_ptr_Function_VS_CONTROL_POINT_OUTPUT Function
 // %result = OpVariable %_ptr_Function_BEZIER_CONTROL_POINT Function
-// %128 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
-// %129 = OpLoad %v3float %128
-// %130 = OpAccessChain %_ptr_Function_v3float %result %int_0
-// OpStore %130 %129
-// %131 = OpLoad %BEZIER_CONTROL_POINT %result
-// OpReturnValue %131
-// OpFunctionEnd
+// %130 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
+// %131 = OpLoad %v3float %130
+// %132 = OpAccessChain %_ptr_Function_v3float %result %int_0
+// OpStore %132 %131
+// %133 = OpLoad %BEZIER_CONTROL_POINT %result
+// OpReturnValue %133
+// OpFunctionEnd

+ 1 - 1
tools/clang/test/CodeGenSPIRV/binary-op.assign.composite.hlsl

@@ -74,7 +74,7 @@ void main(uint index: A) {
     BufferType lbuf;                  // %BufferType_0                   & %SubBuffer_1
     sbuf[5]  = lbuf;             // %BufferType <- %BufferType_0
 
-// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer_0 %cbuf %int_3 %uint_0
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer_0 %cbuf %int_3 %int_0
 // CHECK-NEXT: [[cbuf_d0:%\d+]] = OpLoad %SubBuffer_0 [[ptr]]
 
     // sub.a[0] <- cbuf.d[0].a[0]

+ 46 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.disabled.hlsl

@@ -0,0 +1,46 @@
+// Run: %dxc -T ps_6_0 -E main
+
+/////////////////////////////////////////////////////////////////////////////////
+/// Types with fewer than 32 bits, used without '-enable-16bit-type' options  ///
+/////////////////////////////////////////////////////////////////////////////////
+
+// See https://github.com/Microsoft/DirectXShaderCompiler/wiki/16-Bit-Scalar-Types
+// for details about these types.
+
+// CHECK-NOT: OpDecorate %c_half_4_5 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_half_n8_2 RelaxedPrecision
+// CHECK: OpDecorate %c_min10float RelaxedPrecision
+// CHECK: OpDecorate %c_min16float RelaxedPrecision
+// CHECK: OpDecorate %c_min16int_n3 RelaxedPrecision
+// CHECK: OpDecorate %c_min16uint_5 RelaxedPrecision
+// CHECK: OpDecorate %c_min12int RelaxedPrecision
+
+void main() {
+// Note: in the absence of "-enable-16bit-types" option,
+// 'half' is translated to float *without* RelaxedPrecision decoration.
+// CHECK: %float_7_7 = OpConstant %float 7.7
+  half c_half_4_5 = 7.7;
+// CHECK: %float_n8_8 = OpConstant %float -8.8
+  half c_half_n8_2 = -8.8;
+
+// Note: in the absence of "-enable-16bit-type" option,
+// 'min{10|16}float' are translated to
+// 32-bit float in SPIR-V with RelaxedPrecision decoration (checked above).
+// CHECK: %float_1_5 = OpConstant %float 1.5
+  min10float c_min10float = 1.5;
+// CHECK: %float_n1 = OpConstant %float -1
+  min16float c_min16float = -1.0;
+
+// Note: in the absence of "-enable-16bit-type" option,
+// 'min12{uint|int}' and 'min16{uint|int}' are translated to
+// 32-bit uint/int in SPIR-V with RelaxedPrecision decoration (checked above).
+// CHECK: %int_n3 = OpConstant %int -3
+  min16int c_min16int_n3 = -3;
+// CHECK: %uint_5 = OpConstant %uint 5
+  min16uint c_min16uint_5 = 5;
+// CHECK: %int_n9 = OpConstant %int -9
+  min12int c_min12int = -9;
+// It seems that min12uint is still not supported by the front-end.
+// XXXXX: %uint_12 = OpConstant %uint 12 
+//  min12uint c_min12uint = 12;
+}

+ 70 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.hlsl

@@ -0,0 +1,70 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// Handling of 16-bit integers and 16-bit floats.
+// Note that this test runs utilizes "-enable-16bit-types" option above.
+
+// When this option is enabled, 16-bit types are use as outlined below:
+// min10float: float16_t(warning)
+// min16float: float16_t(warning)
+// half:       float16_t
+// float16_t:  float16_t
+// min12int:   int16_t(warning)
+// min16int:   int16_t(warning)
+// int16_t:    int16_t
+// min12uint:  uint16_t(warning)
+// min16uint:  uint16_t(warning)
+// uint16_t:   uint16_t
+
+// CHECK: OpCapability Float16
+// CHECK: OpCapability Int16
+
+// CHECK-NOT: OpDecorate %c_half RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min10float RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16float RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_float16t RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16int_n3 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16int_3 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16uint_5 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min12int_n9 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min12int_9 RelaxedPrecision
+// XXXXX-NOT: OpDecorate %c_min12uint RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_uint16_16 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_int16_n16 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_int16_16 RelaxedPrecision
+
+// CHECK: %short = OpTypeInt 16 1
+// CHECK: %ushort = OpTypeInt 16 0
+// CHECK: %half = OpTypeFloat 16
+
+void main() {
+// CHECK: %half_0x1_2p_0 = OpConstant %half 0x1.2p+0
+  half       c_half = 1.125;
+// CHECK: %half_0x1_ep_3 = OpConstant %half 0x1.ep+3
+  min10float c_min10float = 15.0;
+// CHECK: %half_n0x1p_0 = OpConstant %half -0x1p+0
+  min16float c_min16float = -1.0;
+// CHECK: %half_0x1_8p_0 = OpConstant %half 0x1.8p+0
+  float16_t  c_float16t = 1.5;
+
+// CHECK: %short_n3 = OpConstant %short -3
+  min16int   c_min16int_n3 = -3;
+// CHECK: %short_3 = OpConstant %short 3
+  min16int   c_min16int_3 = 3;
+// CHECK: %ushort_5 = OpConstant %ushort 5
+  min16uint  c_min16uint_5 = 5;
+
+// CHECK: %short_n9 = OpConstant %short -9
+  min12int   c_min12int_n9 = -9;
+// CHECK: %short_9 = OpConstant %short 9
+  min12int   c_min12int_9 = 9;
+// It seems that min12uint is still not supported by the front-end.
+// XXXXX: %short_12 = OpConstant %short 12 
+//  min12uint   c_min12uint = 12;
+
+// CHECK: %ushort_16 = OpConstant %ushort 16
+  uint16_t  c_uint16_16 = 16;
+// CHECK: %short_n16 = OpConstant %short -16
+  int16_t   c_int16_n16 = -16;
+// CHECK: %short_16 = OpConstant %short 16
+  int16_t   c_int16_16 = 16;
+}

+ 29 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.64bit.hlsl

@@ -0,0 +1,29 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+
+// CHECK: %double_0 = OpConstant %double 0
+  double c_double_0 = 0.;
+// CHECK: %double_n0 = OpConstant %double -0
+  float64_t c_double_n0 = -0.;
+// CHECK: %double_4_5 = OpConstant %double 4.5
+  float64_t c_double_4_5 = 4.5;
+// CHECK: %double_n8_2 = OpConstant %double -8.2
+  double c_double_n8_2 = -8.2;
+// CHECK: %double_1234567898765_32 = OpConstant %double 1234567898765.32
+  double c_large  =  1234567898765.32;
+// CHECK: %double_n1234567898765_32 = OpConstant %double -1234567898765.32
+  float64_t c_nlarge = -1234567898765.32;
+
+// CHECK: %long_1 = OpConstant %long 1
+  int64_t  c_int64_small_1  = 1;  
+// CHECK: %long_n1 = OpConstant %long -1
+  int64_t  c_int64_small_n1  = -1;  
+// CHECK: %long_2147483648 = OpConstant %long 2147483648
+  int64_t  c_int64_large  = 2147483648;
+
+// CHECK: %ulong_2 = OpConstant %ulong 2
+  uint64_t c_uint64_small_2 = 2;
+// CHECK: %ulong_4294967296 = OpConstant %ulong 4294967296
+  uint64_t c_uint64_large = 4294967296;
+}

+ 17 - 33
tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl

@@ -1,64 +1,48 @@
 // Run: %dxc -T ps_6_0 -E main
 
 // TODO
-// 16bit & 64bit integer (require additional capability)
-// 16bit floats (require additional capability)
 // float: denormalized numbers, Inf, NaN
 
 void main() {
   // Boolean constants
-// CHECK-DAG: %true = OpConstantTrue %bool
+// CHECK: %true = OpConstantTrue %bool
   bool c_bool_t = true;
-// CHECK-DAG: %false = OpConstantFalse %bool
+// CHECK: %false = OpConstantFalse %bool
   bool c_bool_f = false;
 
   // Signed integer constants
-// CHECK-DAG: %int_0 = OpConstant %int 0
+// CHECK: %int_0 = OpConstant %int 0
   int c_int_0 = 0;
-// CHECK-DAG: %int_1 = OpConstant %int 1
+// CHECK: %int_1 = OpConstant %int 1
   int c_int_1 = 1;
-// CHECK-DAG: %int_n1 = OpConstant %int -1
+// CHECK: %int_n1 = OpConstant %int -1
   int c_int_n1 = -1;
-// CHECK-DAG: %int_42 = OpConstant %int 42
+// CHECK: %int_42 = OpConstant %int 42
   int c_int_42 = 42;
-// CHECK-DAG: %int_n42 = OpConstant %int -42
+// CHECK: %int_n42 = OpConstant %int -42
   int c_int_n42 = -42;
-// CHECK-DAG: %int_2147483647 = OpConstant %int 2147483647
+// CHECK: %int_2147483647 = OpConstant %int 2147483647
   int c_int_max = 2147483647;
-// CHECK-DAG: %int_n2147483648 = OpConstant %int -2147483648
+// CHECK: %int_n2147483648 = OpConstant %int -2147483648
   int c_int_min = -2147483648;
 
   // Unsigned integer constants
-// CHECK-DAG: %uint_0 = OpConstant %uint 0
+// CHECK: %uint_0 = OpConstant %uint 0
   uint c_uint_0 = 0;
-// CHECK-DAG: %uint_1 = OpConstant %uint 1
+// CHECK: %uint_1 = OpConstant %uint 1
   uint c_uint_1 = 1;
-// CHECK-DAG: %uint_38 = OpConstant %uint 38
+// CHECK: %uint_38 = OpConstant %uint 38
   uint c_uint_38 = 38;
-// CHECK-DAG: %uint_4294967295 = OpConstant %uint 4294967295
+// CHECK: %uint_4294967295 = OpConstant %uint 4294967295
   uint c_uint_max = 4294967295;
 
   // Float constants
-// CHECK-DAG: %float_0 = OpConstant %float 0
+// CHECK: %float_0 = OpConstant %float 0
   float c_float_0 = 0.;
-// CHECK-DAG: %float_n0 = OpConstant %float -0
+// CHECK: %float_n0 = OpConstant %float -0
   float c_float_n0 = -0.;
-// CHECK-DAG: %float_4_2 = OpConstant %float 4.2
+// CHECK: %float_4_2 = OpConstant %float 4.2
   float c_float_4_2 = 4.2;
-// CHECK-DAG: %float_n4_2 = OpConstant %float -4.2
+// CHECK: %float_n4_2 = OpConstant %float -4.2
   float c_float_n4_2 = -4.2;
-  
-  // double constants
-// CHECK-DAG: %double_0 = OpConstant %double 0
-  double c_double_0 = 0.;
-// CHECK-DAG: %double_n0 = OpConstant %double -0
-  double c_double_n0 = -0.;
-// CHECK-DAG: %double_4_5 = OpConstant %double 4.5
-  double c_double_4_5 = 4.5;
-// CHECK-DAG: %double_n8_2 = OpConstant %double -8.2
-  double c_double_n8_2 = -8.2;
-// CHECK-DAG: %double_1234567898765_32 = OpConstant %double 1234567898765.32
-  double c_large  =  1234567898765.32;
-// CHECK-DAG: %double_n1234567898765_32 = OpConstant %double -1234567898765.32
-  double c_nlarge = -1234567898765.32;
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/cs.groupshared.hlsl

@@ -20,7 +20,7 @@ groupshared              S        s;
 void main(uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID) {
 // Make sure pointers have the correct storage class
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float %s %int_0
-// CHECK: [[d0:%\d+]] = OpAccessChain %_ptr_Workgroup_v2float %d %uint_0
+// CHECK: [[d0:%\d+]] = OpAccessChain %_ptr_Workgroup_v2float %d %int_0
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float [[d0]] %int_1
     d[0].y = s.f1;
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/op.array.access.hlsl

@@ -17,20 +17,20 @@ float main(float val: A, uint index: B) : C {
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var [[idx]] %uint_1 %int_0 %uint_2
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var [[idx]] %int_1 %int_0 %int_2
 // CHECK-NEXT:                 OpStore [[ptr0]] [[val]]
 
     var[index][1].f[2] = val;
 // CHECK-NEXT: [[idx0:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT: [[idx1:%\d+]] = OpLoad %uint %index
-// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var %uint_0 [[idx0]] %int_1 [[idx1]]
+// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var %int_0 [[idx0]] %int_1 [[idx1]]
 // CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr0]]
 // CHECK-NEXT:                 OpStore %r [[load]]
     r = var[0][index].g[index];
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_v4float %vecvar %uint_3
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_v4float %vecvar %int_3
 // CHECK-NEXT: [[vec4:%\d+]] = OpLoad %v4float [[ptr0]]
 // CHECK-NEXT:  [[res:%\d+]] = OpVectorShuffle %v4float [[vec4]] [[vec2]] 0 1 5 4
 // CHECK-NEXT:                 OpStore [[ptr0]] [[res]]
@@ -42,7 +42,7 @@ float main(float val: A, uint index: B) : C {
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %matvar %uint_2
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %matvar %int_2
 // CHECK-NEXT: [[val0:%\d+]] = OpCompositeExtract %float [[vec2]] 0
 // CHECK-NEXT: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr0]] %int_0 %int_1
 // CHECK-NEXT:                 OpStore [[ptr1]] [[val0]]

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.cbuffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyCbuffer %int_4
-// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %uint_3
+// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
     return a + b.x + c[1][2] + s.f + t[3];
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.constant-buffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_3 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
-// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_4 %uint_3
+// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
     return MyCbuffer.a + MyCbuffer.b.x + MyCbuffer.c[1][2] + MyCbuffer.s.f + MyCbuffer.t[3];
 }

+ 2 - 2
tools/clang/test/CodeGenSPIRV/op.rw-structured-buffer.access.hlsl

@@ -26,7 +26,7 @@ void main(uint index: A) {
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:  [[index:%\d+]] = OpLoad %uint %index
 
-// CHECK-NEXT:  [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %uint_3
+// CHECK-NEXT:  [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %int_3
 // CHECK-NEXT:  OpStore [[t3]] [[val]]
 
 // CHECK:       [[f:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
@@ -35,7 +35,7 @@ void main(uint index: A) {
 // CHECK-NEXT:  [[c212:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
 // CHECK-NEXT:  OpStore [[c212]] [[val]]
 
-// CHECK-NEXT:  [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %uint_1
+// CHECK-NEXT:  [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %int_1
 // CHECK-NEXT:  [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT:  OpStore [[x]] [[val]]
 

+ 2 - 2
tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.hlsl

@@ -19,7 +19,7 @@ float4 main(uint index: A) : SV_Target {
 // CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_0 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[a]]
 
-// CHECK:      [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %uint_1
+// CHECK:      [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %int_1
 // CHECK-NEXT: [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[x]]
 
@@ -30,7 +30,7 @@ float4 main(uint index: A) : SV_Target {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
 // CHECK:      [[index:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %uint_3
+// CHECK-NEXT: [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
     return MySbuffer[0].a + MySbuffer[1].b[1].x + MySbuffer[2].c[2][1][2] +
            MySbuffer[3].s[0].f + MySbuffer[index].t[3];

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.tbuffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyTbuffer %int_4
-// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %uint_3
+// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
     return a + b.x + c[1][2] + s.f + t[3];
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.texture-buffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_3 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
-// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_4 %uint_3
+// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
   return MyTB.a + MyTB.b.x + MyTB.c[1][2] + MyTB.s.f + MyTB.t[3];
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/var.init.array.hlsl

@@ -67,19 +67,19 @@ void main() {
     T2 val2[2] = {val1};
 
 // val3[0]: Construct T3.h from T1.c.b[0]
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_0
 // CHECK-NEXT:   [[h_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 // val3[0]: Construct T3.i from T1.c.b[1]
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_1
 // CHECK-NEXT:   [[i_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 // val3[0]: Construct T3.j from T1.d.b[0]
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_1 %int_0 %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_0
 // CHECK-NEXT:   [[j_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 // val3[0]: Construct T3.k from T1.d.b[1]
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_1 %int_0 %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_1
 // CHECK-NEXT:   [[k_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 // CHECK-NEXT:  [[val3_0:%\d+]] = OpCompositeConstruct %T3 [[h_val]] [[i_val]] [[j_val]] [[k_val]]

+ 1 - 1
tools/clang/test/CodeGenSPIRV/vk.push-constant.hlsl

@@ -32,7 +32,7 @@ float main() : A {
         pcs.f2.z +
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float %pcs %int_2 %uint_1 %uint_2
         pcs.f3[1][2] +
-// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_PushConstant_v2float %pcs %int_3 %int_0 %uint_2
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_PushConstant_v2float %pcs %int_3 %int_0 %int_2
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float [[ptr]] %int_1
         pcs.f4.val[2].y;
 }

+ 1 - 1
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -473,7 +473,7 @@ public:
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.uShift = opts.VkUShift;
-
+          spirvOpts.enable16BitTypes = opts.Enable16BitTypes;
           clang::EmitSPIRVAction action(spirvOpts);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);
           action.BeginSourceFile(compiler, file);

+ 11 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -82,6 +82,17 @@ TEST_F(FileTest, TriangleStreamTypes) {
 
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
+TEST_F(FileTest, 16BitDisabledScalarConstants) {
+  runFileTest("constant.scalar.16bit.disabled.hlsl");
+}
+TEST_F(FileTest, 16BitEnabledScalarConstants) {
+  // TODO: Fix spirv-val to make sure it respects the 16-bit extension.
+  runFileTest("constant.scalar.16bit.enabled.hlsl", FileTest::Expect::Success,
+              /*runValidation*/ false);
+}
+TEST_F(FileTest, 64BitScalarConstants) {
+  runFileTest("constant.scalar.64bit.hlsl");
+}
 TEST_F(FileTest, VectorConstants) { runFileTest("constant.vector.hlsl"); }
 TEST_F(FileTest, MatrixConstants) { runFileTest("constant.matrix.hlsl"); }
 TEST_F(FileTest, StructConstants) { runFileTest("constant.struct.hlsl"); }