瀏覽代碼

[spirv] 16-bit and 64-bit int, uint, and float. (#966)

* [spirv] 16-bit and 64-bit int, uint, and float.

* Added Int64, Uint64, Int16, Uint16.
* Added 16-bit float constants.
* Get the -enable-16bit-types cmd option.
* Add tests for constant Int64/Uint64/Int16/etc.
Ehsan 7 年之前
父節點
當前提交
eeab612da9
共有 28 個文件被更改,包括 546 次插入174 次删除
  1. 35 25
      docs/SPIR-V.rst
  2. 10 0
      tools/clang/include/clang/SPIRV/Constant.h
  3. 1 0
      tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h
  4. 8 0
      tools/clang/include/clang/SPIRV/ModuleBuilder.h
  5. 88 0
      tools/clang/lib/SPIRV/Constant.cpp
  6. 1 1
      tools/clang/lib/SPIRV/DeclResultIdMapper.h
  7. 15 6
      tools/clang/lib/SPIRV/ModuleBuilder.cpp
  8. 63 38
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  9. 73 21
      tools/clang/lib/SPIRV/TypeTranslator.cpp
  10. 34 7
      tools/clang/lib/SPIRV/TypeTranslator.h
  11. 25 23
      tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv
  12. 1 1
      tools/clang/test/CodeGenSPIRV/binary-op.assign.composite.hlsl
  13. 46 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.disabled.hlsl
  14. 70 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.hlsl
  15. 29 0
      tools/clang/test/CodeGenSPIRV/constant.scalar.64bit.hlsl
  16. 17 33
      tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl
  17. 1 1
      tools/clang/test/CodeGenSPIRV/cs.groupshared.hlsl
  18. 4 4
      tools/clang/test/CodeGenSPIRV/op.array.access.hlsl
  19. 1 1
      tools/clang/test/CodeGenSPIRV/op.cbuffer.access.hlsl
  20. 1 1
      tools/clang/test/CodeGenSPIRV/op.constant-buffer.access.hlsl
  21. 2 2
      tools/clang/test/CodeGenSPIRV/op.rw-structured-buffer.access.hlsl
  22. 2 2
      tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.hlsl
  23. 1 1
      tools/clang/test/CodeGenSPIRV/op.tbuffer.access.hlsl
  24. 1 1
      tools/clang/test/CodeGenSPIRV/op.texture-buffer.access.hlsl
  25. 4 4
      tools/clang/test/CodeGenSPIRV/var.init.array.hlsl
  26. 1 1
      tools/clang/test/CodeGenSPIRV/vk.push-constant.hlsl
  27. 1 1
      tools/clang/tools/dxcompiler/dxcompilerobj.cpp
  28. 11 0
      tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

+ 35 - 25
docs/SPIR-V.rst

@@ -243,24 +243,26 @@ Normal scalar types
 in HLSL are relatively easy to handle and can be mapped directly to SPIR-V
 in HLSL are relatively easy to handle and can be mapped directly to SPIR-V
 type instructions:
 type instructions:
 
 
-================== ================== =========== ====================
-      HLSL               SPIR-V       Capability       Decoration
-================== ================== =========== ====================
-``bool``           ``OpTypeBool``
-``int``            ``OpTypeInt 32 1``
-``uint``/``dword`` ``OpTypeInt 32 0``
-``half``           ``OpTypeFloat 32``             ``RelexedPrecision``
-``float``          ``OpTypeFloat 32``
-``snorm float``    ``OpTypeFloat 32``
-``unorm float``    ``OpTypeFloat 32``
-``double``         ``OpTypeFloat 64`` ``Float64``
-================== ================== =========== ====================
+============================== ======================= ================== =========== =================================
+      HLSL                      Command Line Option           SPIR-V       Capability       Extension
+============================== ======================= ================== =========== =================================
+``bool``                                               ``OpTypeBool``
+``int``/``int32_t``                                    ``OpTypeInt 32 1``
+``int16_t``                    ``-enable-16bit-types`` ``OpTypeInt 16 1`` ``Int16``
+``uint``/``dword``/``uin32_t``                         ``OpTypeInt 32 0``
+``uint16_t``                   ``-enable-16bit-types`` ``OpTypeInt 16 0`` ``Int16``
+``half``                                               ``OpTypeFloat 32`` 
+``half``/``float16_t``         ``-enable-16bit-types`` ``OpTypeFloat 16`` ``Float16`` ``SPV_AMD_gpu_shader_half_float``
+``float``/``float32_t``                                ``OpTypeFloat 32``
+``snorm float``                                        ``OpTypeFloat 32``
+``unorm float``                                        ``OpTypeFloat 32``
+``double``/``float64_t``                               ``OpTypeFloat 64`` ``Float64``
+============================== ======================= ================== =========== =================================
 
 
 Please note that ``half`` is translated into 32-bit floating point numbers
 Please note that ``half`` is translated into 32-bit floating point numbers
 right now because MSDN says that "this data type is provided only for language
 right now because MSDN says that "this data type is provided only for language
 compatibility. Direct3D 10 shader targets map all ``half`` data types to
 compatibility. Direct3D 10 shader targets map all ``half`` data types to
-``float`` data types." This may change in the future to map to 16-bit floating
-point numbers (possibly via a command-line option).
+``float`` data types."
 
 
 Minimal precision scalar types
 Minimal precision scalar types
 ------------------------------
 ------------------------------
@@ -270,17 +272,25 @@ HLSL also supports various
 which graphics drivers can implement by using any precision greater than or
 which graphics drivers can implement by using any precision greater than or
 equal to their specified bit precision.
 equal to their specified bit precision.
 There are no direct mappings in SPIR-V for these types. We translate them into
 There are no direct mappings in SPIR-V for these types. We translate them into
-the corresponding 32-bit scalar types with the ``RelexedPrecision`` decoration:
-
-============== ================== ====================
-    HLSL            SPIR-V            Decoration
-============== ================== ====================
-``min16float`` ``OpTypeFloat 32`` ``RelexedPrecision``
-``min10float`` ``OpTypeFloat 32`` ``RelexedPrecision``
-``min16int``   ``OpTypeInt 32 1`` ``RelexedPrecision``
-``min12int``   ``OpTypeInt 32 1`` ``RelexedPrecision``
-``min16uint``  ``OpTypeInt 32 0`` ``RelexedPrecision``
-============== ================== ====================
+the corresponding 16-bit or 32-bit scalar types with the ``RelaxedPrecision`` decoration.
+We use the 16-bit variants if '-enable-16bit-types' command line option is present.
+For more information on these types, please refer to:
+https://github.com/Microsoft/DirectXShaderCompiler/wiki/16-Bit-Scalar-Types
+
+============== ======================= ================== ==================== ============ =================================
+    HLSL        Command Line Option          SPIR-V            Decoration       Capability        Extension
+============== ======================= ================== ==================== ============ =================================
+``min16float``                         ``OpTypeFloat 32`` ``RelaxedPrecision``
+``min10float``                         ``OpTypeFloat 32`` ``RelaxedPrecision``
+``min16int``                           ``OpTypeInt 32 1`` ``RelaxedPrecision``
+``min12int``                           ``OpTypeInt 32 1`` ``RelaxedPrecision``
+``min16uint``                          ``OpTypeInt 32 0`` ``RelaxedPrecision``
+``min16float`` ``-enable-16bit-types`` ``OpTypeFloat 16``                      ``Float16``  ``SPV_AMD_gpu_shader_half_float``
+``min10float`` ``-enable-16bit-types`` ``OpTypeFloat 16``                      ``Float16``  ``SPV_AMD_gpu_shader_half_float``
+``min16int``   ``-enable-16bit-types`` ``OpTypeInt 16 1``                      ``Int16``
+``min12int``   ``-enable-16bit-types`` ``OpTypeInt 16 1``                      ``Int16``
+``min16uint``  ``-enable-16bit-types`` ``OpTypeInt 16 0``                      ``Int16``
+============== ======================= ================== ==================== ============ =================================
 
 
 Vectors and matrices
 Vectors and matrices
 --------------------
 --------------------

+ 10 - 0
tools/clang/include/clang/SPIRV/Constant.h

@@ -61,10 +61,20 @@ public:
                                  DecorationSet dec = {});
                                  DecorationSet dec = {});
   static const Constant *getFalse(SPIRVContext &ctx, uint32_t type_id,
   static const Constant *getFalse(SPIRVContext &ctx, uint32_t type_id,
                                   DecorationSet dec = {});
                                   DecorationSet dec = {});
+  static const Constant *getInt16(SPIRVContext &ctx, uint32_t type_id,
+                                  int16_t value, DecorationSet dec = {});
   static const Constant *getInt32(SPIRVContext &ctx, uint32_t type_id,
   static const Constant *getInt32(SPIRVContext &ctx, uint32_t type_id,
                                   int32_t value, DecorationSet dec = {});
                                   int32_t value, DecorationSet dec = {});
+  static const Constant *getInt64(SPIRVContext &ctx, uint32_t type_id,
+                                  int64_t value, DecorationSet dec = {});
+  static const Constant *getUint16(SPIRVContext &ctx, uint32_t type_id,
+                                   uint16_t value, DecorationSet dec = {});
   static const Constant *getUint32(SPIRVContext &ctx, uint32_t type_id,
   static const Constant *getUint32(SPIRVContext &ctx, uint32_t type_id,
                                    uint32_t value, DecorationSet dec = {});
                                    uint32_t value, DecorationSet dec = {});
+  static const Constant *getUint64(SPIRVContext &ctx, uint32_t type_id,
+                                   uint64_t value, DecorationSet dec = {});
+  static const Constant *getFloat16(SPIRVContext &ctx, uint32_t type_id,
+                                    int16_t value, DecorationSet dec = {});
   static const Constant *getFloat32(SPIRVContext &ctx, uint32_t type_id,
   static const Constant *getFloat32(SPIRVContext &ctx, uint32_t type_id,
                                     float value, DecorationSet dec = {});
                                     float value, DecorationSet dec = {});
   static const Constant *getFloat64(SPIRVContext &ctx, uint32_t type_id,
   static const Constant *getFloat64(SPIRVContext &ctx, uint32_t type_id,

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitSPIRVOptions.h

@@ -19,6 +19,7 @@ struct EmitSPIRVOptions {
   bool codeGenHighLevel;
   bool codeGenHighLevel;
   bool disableValidation;
   bool disableValidation;
   bool ignoreUnusedResources;
   bool ignoreUnusedResources;
+  bool enable16BitTypes;
   llvm::StringRef stageIoOrder;
   llvm::StringRef stageIoOrder;
   llvm::SmallVector<uint32_t, 4> bShift;
   llvm::SmallVector<uint32_t, 4> bShift;
   llvm::SmallVector<uint32_t, 4> tShift;
   llvm::SmallVector<uint32_t, 4> tShift;

+ 8 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -358,10 +358,13 @@ public:
 
 
   uint32_t getVoidType();
   uint32_t getVoidType();
   uint32_t getBoolType();
   uint32_t getBoolType();
+  uint32_t getInt16Type();
   uint32_t getInt32Type();
   uint32_t getInt32Type();
   uint32_t getInt64Type();
   uint32_t getInt64Type();
+  uint32_t getUint16Type();
   uint32_t getUint32Type();
   uint32_t getUint32Type();
   uint32_t getUint64Type();
   uint32_t getUint64Type();
+  uint32_t getFloat16Type();
   uint32_t getFloat32Type();
   uint32_t getFloat32Type();
   uint32_t getFloat64Type();
   uint32_t getFloat64Type();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
@@ -391,8 +394,13 @@ public:
 
 
   // === Constant ===
   // === Constant ===
   uint32_t getConstantBool(bool value);
   uint32_t getConstantBool(bool value);
+  uint32_t getConstantInt16(int16_t value);
   uint32_t getConstantInt32(int32_t value);
   uint32_t getConstantInt32(int32_t value);
+  uint32_t getConstantInt64(int64_t value);
+  uint32_t getConstantUint16(uint16_t value);
   uint32_t getConstantUint32(uint32_t value);
   uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantUint64(uint64_t value);
+  uint32_t getConstantFloat16(int16_t value);
   uint32_t getConstantFloat32(float value);
   uint32_t getConstantFloat32(float value);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantFloat64(double value);
   uint32_t getConstantComposite(uint32_t typeId,
   uint32_t getConstantComposite(uint32_t typeId,

+ 88 - 0
tools/clang/lib/SPIRV/Constant.cpp

@@ -11,6 +11,37 @@
 #include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SPIRVContext.h"
 
 
+namespace {
+uint32_t zeroExtendTo32Bits(uint16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    uint16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+
+uint32_t signExtendTo32Bits(int16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    int16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+
+  // Sign bit is 1
+  if (value >> 15) {
+    result.high = 0xffff;
+  }
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+}
+
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
 
 
@@ -37,6 +68,17 @@ const Constant *Constant::getFalse(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
   return getUniqueConstant(ctx, c);
 }
 }
 
 
+const Constant *Constant::getFloat16(SPIRVContext &ctx, uint32_t type_id,
+                                     int16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be 0 for a
+  // floating-point type.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {zeroExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getFloat32(SPIRVContext &ctx, uint32_t type_id,
 const Constant *Constant::getFloat32(SPIRVContext &ctx, uint32_t type_id,
                                      float value, DecorationSet dec) {
                                      float value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id,
   Constant c = Constant(spv::Op::OpConstant, type_id,
@@ -58,12 +100,46 @@ const Constant *Constant::getFloat64(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
   return getUniqueConstant(ctx, c);
 }
 }
 
 
+const Constant *Constant::getUint16(SPIRVContext &ctx, uint32_t type_id,
+                                    uint16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be 0 for an
+  // integer type with Signedness of 0.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {zeroExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getUint32(SPIRVContext &ctx, uint32_t type_id,
 const Constant *Constant::getUint32(SPIRVContext &ctx, uint32_t type_id,
                                     uint32_t value, DecorationSet dec) {
                                     uint32_t value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id, {value}, dec);
   Constant c = Constant(spv::Op::OpConstant, type_id, {value}, dec);
   return getUniqueConstant(ctx, c);
   return getUniqueConstant(ctx, c);
 }
 }
 
 
+const Constant *Constant::getUint64(SPIRVContext &ctx, uint32_t type_id,
+                                    uint64_t value, DecorationSet dec) {
+  struct wideInt {
+    uint32_t word0;
+    uint32_t word1;
+  };
+  wideInt words = cast::BitwiseCast<wideInt, uint64_t>(value);
+  Constant c =
+      Constant(spv::Op::OpConstant, type_id, {words.word0, words.word1}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getInt16(SPIRVContext &ctx, uint32_t type_id,
+                                    int16_t value, DecorationSet dec) {
+  // According to the SPIR-V Spec:
+  // When the type's bit width is less than 32-bits, the literal's value appears
+  // in the low-order bits of the word, and the high-order bits must be
+  // sign-extended for integers with Signedness of 1.
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {signExtendTo32Bits(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
 const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
                                    int32_t value, DecorationSet dec) {
                                    int32_t value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id,
   Constant c = Constant(spv::Op::OpConstant, type_id,
@@ -71,6 +147,18 @@ const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
   return getUniqueConstant(ctx, c);
 }
 }
 
 
+const Constant *Constant::getInt64(SPIRVContext &ctx, uint32_t type_id,
+                                   int64_t value, DecorationSet dec) {
+  struct wideInt {
+    uint32_t word0;
+    uint32_t word1;
+  };
+  wideInt words = cast::BitwiseCast<wideInt, int64_t>(value);
+  Constant c =
+      Constant(spv::Op::OpConstant, type_id, {words.word0, words.word1}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getComposite(SPIRVContext &ctx, uint32_t type_id,
 const Constant *Constant::getComposite(SPIRVContext &ctx, uint32_t type_id,
                                        llvm::ArrayRef<uint32_t> constituents,
                                        llvm::ArrayRef<uint32_t> constituents,
                                        DecorationSet dec) {
                                        DecorationSet dec) {

+ 1 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -597,7 +597,7 @@ DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        const EmitSPIRVOptions &options)
                                        const EmitSPIRVOptions &options)
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
       astContext(context), diags(context.getDiagnostics()),
       astContext(context), diags(context.getDiagnostics()),
-      typeTranslator(context, builder, diags), entryFunctionId(0),
+      typeTranslator(context, builder, diags, options), entryFunctionId(0),
       needsLegalization(false),
       needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator) {}
       glPerVertex(model, context, builder, typeTranslator) {}
 
 

+ 15 - 6
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -791,19 +791,23 @@ IMPL_GET_PRIMITIVE_TYPE(Float32)
 #undef IMPL_GET_PRIMITIVE_TYPE
 #undef IMPL_GET_PRIMITIVE_TYPE
 
 
 #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap)                       \
 #define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty, cap)                       \
-  \
-uint32_t ModuleBuilder::get##ty##Type() {                                      \
-    requireCapability(spv::Capability::cap);                                    \
+                                                                               \
+  uint32_t ModuleBuilder::get##ty##Type() {                                    \
+    requireCapability(spv::Capability::cap);                                   \
+    if (spv::Capability::cap == spv::Capability::Float16)                      \
+      theModule.addExtension("SPV_AMD_gpu_shader_half_float");                 \
     const Type *type = Type::get##ty(theContext);                              \
     const Type *type = Type::get##ty(theContext);                              \
     const uint32_t typeId = theContext.getResultIdForType(type);               \
     const uint32_t typeId = theContext.getResultIdForType(type);               \
     theModule.addType(type, typeId);                                           \
     theModule.addType(type, typeId);                                           \
     return typeId;                                                             \
     return typeId;                                                             \
-  \
-}
+  }
 
 
-IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int64, Int64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
 IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint64, Int64)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64, Float64)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Int16, Int16)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Uint16, Int16)
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float16, Float16)
 
 
 #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
 #undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
 
 
@@ -1056,10 +1060,15 @@ uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
   \
   \
 }
 }
 
 
+IMPL_GET_PRIMITIVE_CONST(Int16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
 IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
+IMPL_GET_PRIMITIVE_CONST(Uint16, uint16_t)
 IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
 IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
+IMPL_GET_PRIMITIVE_CONST(Float16, int16_t)
 IMPL_GET_PRIMITIVE_CONST(Float32, float)
 IMPL_GET_PRIMITIVE_CONST(Float32, float)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
 IMPL_GET_PRIMITIVE_CONST(Float64, double)
+IMPL_GET_PRIMITIVE_CONST(Int64, int64_t)
+IMPL_GET_PRIMITIVE_CONST(Uint64, uint64_t)
 
 
 #undef IMPL_GET_PRIMITIVE_VALUE
 #undef IMPL_GET_PRIMITIVE_VALUE
 
 

+ 63 - 38
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -379,19 +379,6 @@ const ValueDecl *getReferencedDef(const Expr *expr) {
   return nullptr;
   return nullptr;
 }
 }
 
 
-bool isLiteralType(QualType type) {
-  if (type->isSpecificBuiltinType(BuiltinType::LitInt) ||
-      type->isSpecificBuiltinType(BuiltinType::LitFloat))
-    return true;
-
-  // For cases such as 'vector<literal int, 2>'
-  QualType elemType = {};
-  if (TypeTranslator::isVectorType(type, &elemType))
-    return isLiteralType(elemType);
-
-  return false;
-}
-
 } // namespace
 } // namespace
 
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -403,9 +390,9 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
           ci.getCodeGenOpts().HLSLProfile.c_str())),
           ci.getCodeGenOpts().HLSLProfile.c_str())),
       theContext(), theBuilder(&theContext),
       theContext(), theBuilder(&theContext),
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
-      typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
-      curFunction(nullptr), curThis(0), seenPushConstantAt(),
-      needsLegalization(false) {
+      typeTranslator(astContext, theBuilder, diags, options),
+      entryFunctionId(0), curFunction(nullptr), curThis(0),
+      seenPushConstantAt(), needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
 }
 }
@@ -597,9 +584,9 @@ SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) {
 SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
 SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
   SpirvEvalInfo result(/*id*/ 0);
   SpirvEvalInfo result(/*id*/ 0);
 
 
-  const bool isNonLiteralType = !isLiteralType(expr->getType());
-  if (isNonLiteralType)
-    typeTranslator.pushIntendedLiteralType(expr->getType());
+  // Provide a hint to the typeTranslator that if a literal is discovered, its
+  // intended usage is as this expression type.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, expr->getType());
 
 
   expr = expr->IgnoreParens();
   expr = expr->IgnoreParens();
 
 
@@ -650,9 +637,6 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
         << expr->getStmtClassName() << expr->getSourceRange();
         << expr->getStmtClassName() << expr->getSourceRange();
   }
   }
 
 
-  if (isNonLiteralType)
-    typeTranslator.popIntendedLiteralType();
-
   return result;
   return result;
 }
 }
 
 
@@ -1000,7 +984,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       needsLegalization = true;
       needsLegalization = true;
   }
   }
 
 
-  if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
+  if (TypeTranslator::isRelaxedPrecisionType(decl->getType(), spirvOptions)) {
     theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
     theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
   }
   }
 
 
@@ -1551,6 +1535,10 @@ void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
 
 
 SpirvEvalInfo
 SpirvEvalInfo
 SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
 SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
+  // Provide a hint to the TypeTranslator that the integer literal used to
+  // index into the array should be translated as a 32-bit integer.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
+
   llvm::SmallVector<uint32_t, 4> indices;
   llvm::SmallVector<uint32_t, 4> indices;
   auto info = doExpr(collectArrayStructIndices(expr, &indices));
   auto info = doExpr(collectArrayStructIndices(expr, &indices));
 
 
@@ -6656,10 +6644,10 @@ uint32_t SPIRVEmitter::getMatElemValueOne(QualType type) {
 uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
 uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
                                         const QualType targetType) {
                                         const QualType targetType) {
   uint32_t result = 0;
   uint32_t result = 0;
-  const bool isNonLiteralType = !isLiteralType(targetType);
 
 
-  if (isNonLiteralType)
-    typeTranslator.pushIntendedLiteralType(targetType);
+  // Provide a hint to the typeTranslator that if a literal is discovered, its
+  // intended usage is targetType.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);
 
 
   if (targetType->isBooleanType()) {
   if (targetType->isBooleanType()) {
     result = theBuilder.getConstantBool(value.getInt().getBoolValue());
     result = theBuilder.getConstantBool(value.getInt().getBoolValue());
@@ -6684,11 +6672,8 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
     }
     }
   }
   }
 
 
-  if (result) {
-    if (isNonLiteralType)
-      typeTranslator.popIntendedLiteralType();
+  if (result)
     return result;
     return result;
-  }
 
 
   emitError("APValue of type %0 unimplemented", {}) << value.getKind();
   emitError("APValue of type %0 unimplemented", {}) << value.getKind();
   value.dump();
   value.dump();
@@ -6698,18 +6683,42 @@ uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
 uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
 uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                                       QualType targetType) {
                                       QualType targetType) {
   targetType = typeTranslator.getIntendedLiteralType(targetType);
   targetType = typeTranslator.getIntendedLiteralType(targetType);
-  if (targetType->isSignedIntegerType()) {
-    // Try to see if this integer can be represented in 32-bit.
-    if (intValue.isSignedIntN(32)) {
+  const auto targetTypeBitWidth = astContext.getTypeSize(targetType);
+  const bool isSigned = targetType->isSignedIntegerType();
+  switch (targetTypeBitWidth) {
+  case 16: {
+    if (spirvOptions.enable16BitTypes) {
+      if (isSigned) {
+        return theBuilder.getConstantInt16(
+            static_cast<int16_t>(intValue.getSExtValue()));
+      } else {
+        return theBuilder.getConstantUint16(
+            static_cast<uint16_t>(intValue.getZExtValue()));
+      }
+    } else {
+      // If enable16BitTypes option is not true, treat as 32-bit integer.
+      if (isSigned)
+        return theBuilder.getConstantInt32(
+            static_cast<int32_t>(intValue.getSExtValue()));
+      else
+        return theBuilder.getConstantUint32(
+            static_cast<uint32_t>(intValue.getZExtValue()));
+    }
+  }
+  case 32: {
+    if (isSigned)
       return theBuilder.getConstantInt32(
       return theBuilder.getConstantInt32(
           static_cast<int32_t>(intValue.getSExtValue()));
           static_cast<int32_t>(intValue.getSExtValue()));
-    }
-  } else {
-    // Try to see if this integer can be represented in 32-bit.
-    if (intValue.isIntN(32)) {
+    else
       return theBuilder.getConstantUint32(
       return theBuilder.getConstantUint32(
           static_cast<uint32_t>(intValue.getZExtValue()));
           static_cast<uint32_t>(intValue.getZExtValue()));
-    }
+  }
+  case 64: {
+    if (isSigned)
+      return theBuilder.getConstantInt64(intValue.getSExtValue());
+    else
+      return theBuilder.getConstantUint64(intValue.getZExtValue());
+  }
   }
   }
 
 
   emitError("APInt for target bitwidth %0 unimplemented", {})
   emitError("APInt for target bitwidth %0 unimplemented", {})
@@ -6761,6 +6770,22 @@ uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
   const auto &semantics = astContext.getFloatTypeSemantics(targetType);
   const auto &semantics = astContext.getFloatTypeSemantics(targetType);
   const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
   const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
   switch (bitwidth) {
   switch (bitwidth) {
+  case 16: {
+    if (spirvOptions.enable16BitTypes) {
+      return theBuilder.getConstantFloat16(
+          static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
+    } else {
+      // If 16-bit types are not enabled, treat as 32-bit float.
+      llvm::APFloat f32 = floatValue;
+      bool losesInfo = false;
+      f32.convert(llvm::APFloat::IEEEsingle,
+                  llvm::APFloat::roundingMode::rmTowardZero, &losesInfo);
+      // Conversion from 16-bit float value to 32-bit float value should be
+      // loss-less.
+      assert(!losesInfo);
+      return theBuilder.getConstantFloat32(f32.convertToFloat());
+    }
+  }
   case 32:
   case 32:
     return theBuilder.getConstantFloat32(floatValue.convertToFloat());
     return theBuilder.getConstantFloat32(floatValue.convertToFloat());
   case 64:
   case 64:

+ 73 - 21
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -34,19 +34,29 @@ inline void roundToPow2(uint32_t *val, uint32_t pow2) {
 }
 }
 } // anonymous namespace
 } // anonymous namespace
 
 
-bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
+bool TypeTranslator::isRelaxedPrecisionType(QualType type,
+                                            const EmitSPIRVOptions &opts) {
   // Primitive types
   // Primitive types
   {
   {
     QualType ty = {};
     QualType ty = {};
     if (isScalarType(type, &ty))
     if (isScalarType(type, &ty))
       if (const auto *builtinType = ty->getAs<BuiltinType>())
       if (const auto *builtinType = ty->getAs<BuiltinType>())
         switch (builtinType->getKind()) {
         switch (builtinType->getKind()) {
+        // TODO: Figure out why 'min16float' and 'half' share an enum.
+        // 'half' should not get RelaxedPrecision decoration, but due to the
+        // shared enum, we currently do so.
+        case BuiltinType::Half:
         case BuiltinType::Short:
         case BuiltinType::Short:
         case BuiltinType::UShort:
         case BuiltinType::UShort:
         case BuiltinType::Min12Int:
         case BuiltinType::Min12Int:
-        case BuiltinType::Min10Float:
-        case BuiltinType::Half:
-          return true;
+        case BuiltinType::Min10Float: {
+          // If '-enable-16bit-types' options is enabled, these types are
+          // translated to real 16-bit type, and therefore are not
+          // RelaxedPrecision.
+          // If the options is not enabled, these types are translated to 32-bit
+          // types with the added RelaxedPrecision decoration.
+          return !opts.enable16BitTypes;
+        }
         }
         }
   }
   }
 
 
@@ -55,7 +65,7 @@ bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
   {
   {
     QualType elemType = {};
     QualType elemType = {};
     if (isVectorType(type, &elemType) || isMxNMatrix(type, &elemType))
     if (isVectorType(type, &elemType) || isMxNMatrix(type, &elemType))
-      return isRelaxedPrecisionType(elemType);
+      return isRelaxedPrecisionType(elemType, opts);
   }
   }
 
 
   return false;
   return false;
@@ -108,6 +118,29 @@ bool TypeTranslator::isOpaqueStructType(QualType type) {
   return false;
   return false;
 }
 }
 
 
+TypeTranslator::LiteralTypeHint::LiteralTypeHint(TypeTranslator &t, QualType ty)
+    : translator(t), type(ty) {
+  if (!isLiteralType(type))
+    translator.pushIntendedLiteralType(type);
+}
+TypeTranslator::LiteralTypeHint::~LiteralTypeHint() {
+  if (!isLiteralType(type))
+    translator.popIntendedLiteralType();
+}
+
+bool TypeTranslator::LiteralTypeHint::isLiteralType(QualType type) {
+  if (type->isSpecificBuiltinType(BuiltinType::LitInt) ||
+      type->isSpecificBuiltinType(BuiltinType::LitFloat))
+    return true;
+
+  // For cases such as 'vector<literal int, 2>'
+  QualType elemType = {};
+  if (isVectorType(type, &elemType))
+    return isLiteralType(elemType);
+
+  return false;
+}
+
 void TypeTranslator::pushIntendedLiteralType(QualType type) {
 void TypeTranslator::pushIntendedLiteralType(QualType type) {
   QualType elemType = {};
   QualType elemType = {};
   if (isVectorType(type, &elemType)) {
   if (isVectorType(type, &elemType)) {
@@ -129,8 +162,8 @@ QualType TypeTranslator::getIntendedLiteralType(QualType type) {
 }
 }
 
 
 void TypeTranslator::popIntendedLiteralType() {
 void TypeTranslator::popIntendedLiteralType() {
-  if (!intendedLiteralTypes.empty())
-    intendedLiteralTypes.pop();
+  assert(!intendedLiteralTypes.empty());
+  intendedLiteralTypes.pop();
 }
 }
 
 
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
@@ -154,29 +187,48 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
           return theBuilder.getVoidType();
           return theBuilder.getVoidType();
         case BuiltinType::Bool:
         case BuiltinType::Bool:
           return theBuilder.getBoolType();
           return theBuilder.getBoolType();
-          // int, min16int (short), and min12int are all translated to 32-bit
-          // signed integers in SPIR-V.
         case BuiltinType::Int:
         case BuiltinType::Int:
-        case BuiltinType::Short:
-        case BuiltinType::Min12Int:
           return theBuilder.getInt32Type();
           return theBuilder.getInt32Type();
-          // uint and min16uint (ushort) are both translated to 32-bit unsigned
-          // integers in SPIR-V.
-        case BuiltinType::UShort:
         case BuiltinType::UInt:
         case BuiltinType::UInt:
           return theBuilder.getUint32Type();
           return theBuilder.getUint32Type();
+        case BuiltinType::Float:
+          return theBuilder.getFloat32Type();
+        case BuiltinType::Double:
+          return theBuilder.getFloat64Type();
         case BuiltinType::LongLong:
         case BuiltinType::LongLong:
           return theBuilder.getInt64Type();
           return theBuilder.getInt64Type();
         case BuiltinType::ULongLong:
         case BuiltinType::ULongLong:
           return theBuilder.getUint64Type();
           return theBuilder.getUint64Type();
-          // float, min16float (half), and min10float are all translated to
-          // 32-bit float in SPIR-V.
-        case BuiltinType::Float:
+        // min16int (short), and min12int are treated as 16-bit Int if
+        // '-enable-16bit-types' option is enabled. They are treated as 32-bit
+        // Int otherwise.
+        case BuiltinType::Short:
+        case BuiltinType::Min12Int: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getInt16Type();
+          else
+            return theBuilder.getInt32Type();
+        }
+        // min16uint (ushort) is treated as 16-bit Uint if '-enable-16bit-types'
+        // option is enabled. It is treated as 32-bit Uint otherwise.
+        case BuiltinType::UShort: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getUint16Type();
+          else
+            return theBuilder.getUint32Type();
+        }
+        // min16float (half), and min10float are all translated to
+        // 32-bit float in SPIR-V.
+        // min16float (half), and min10float are treated as 16-bit float if
+        // '-enable-16bit-types' option is enabled. They are treated as 32-bit
+        // float otherwise.
         case BuiltinType::Half:
         case BuiltinType::Half:
-        case BuiltinType::Min10Float:
-          return theBuilder.getFloat32Type();
-        case BuiltinType::Double:
-          return theBuilder.getFloat64Type();
+        case BuiltinType::Min10Float: {
+          if (spirvOptions.enable16BitTypes)
+            return theBuilder.getFloat16Type();
+          else
+            return theBuilder.getFloat32Type();
+        }
         case BuiltinType::LitFloat: {
         case BuiltinType::LitFloat: {
           // First try to see if there are any hints about how this literal type
           // First try to see if there are any hints about how this literal type
           // is going to be used. If so, use the hint.
           // is going to be used. If so, use the hint.

+ 34 - 7
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -14,6 +14,7 @@
 
 
 #include "clang/AST/Type.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/Diagnostic.h"
+#include "clang/SPIRV/EmitSPIRVOptions.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 #include "clang/SPIRV/ModuleBuilder.h"
 
 
 #include "SpirvEvalInfo.h"
 #include "SpirvEvalInfo.h"
@@ -31,8 +32,14 @@ namespace spirv {
 class TypeTranslator {
 class TypeTranslator {
 public:
 public:
   TypeTranslator(ASTContext &context, ModuleBuilder &builder,
   TypeTranslator(ASTContext &context, ModuleBuilder &builder,
-                 DiagnosticsEngine &diag)
-      : astContext(context), theBuilder(builder), diags(diag) {}
+                 DiagnosticsEngine &diag, const EmitSPIRVOptions &opts)
+      : astContext(context), theBuilder(builder), diags(diag),
+        spirvOptions(opts) {}
+
+  ~TypeTranslator() {
+    // Perform any sanity checks.
+    assert(intendedLiteralTypes.empty());
+  }
 
 
   /// \brief Generates the corresponding SPIR-V type for the given Clang
   /// \brief Generates the corresponding SPIR-V type for the given Clang
   /// frontend type and returns the type's <result-id>. On failure, reports
   /// frontend type and returns the type's <result-id>. On failure, reports
@@ -154,7 +161,7 @@ public:
   /// \brief Returns true if the given type can use relaxed precision
   /// \brief Returns true if the given type can use relaxed precision
   /// decoration. Integer and float types with lower than 32 bits can be
   /// decoration. Integer and float types with lower than 32 bits can be
   /// operated on with a relaxed precision.
   /// operated on with a relaxed precision.
-  static bool isRelaxedPrecisionType(QualType);
+  static bool isRelaxedPrecisionType(QualType, const EmitSPIRVOptions &);
 
 
   /// Returns true if the given type will be translated into a SPIR-V image,
   /// Returns true if the given type will be translated into a SPIR-V image,
   /// sampler or struct containing images or samplers.
   /// sampler or struct containing images or samplers.
@@ -231,16 +238,35 @@ public:
                                                     uint32_t *stride);
                                                     uint32_t *stride);
 
 
 public:
 public:
-  /// \brief Adds the given type to the intendedLiteralTypes stack. This will be
-  /// used as a hint regarding usage of literal types.
-  void pushIntendedLiteralType(QualType type);
-
   /// \brief If a hint exists regarding the usage of literal types, it
   /// \brief If a hint exists regarding the usage of literal types, it
   /// is returned. Otherwise, the given type itself is returned.
   /// is returned. Otherwise, the given type itself is returned.
   /// The hint is the type on top of the intendedLiteralTypes stack. This is the
   /// The hint is the type on top of the intendedLiteralTypes stack. This is the
   /// type we suspect the literal under question should be interpreted as.
   /// type we suspect the literal under question should be interpreted as.
   QualType getIntendedLiteralType(QualType type);
   QualType getIntendedLiteralType(QualType type);
 
 
+public:
+  // A RAII class for maintaining the intendedLiteralTypes stack.
+  // Instantiating an object of this class ensures that as long as the
+  // object lives, the hint lives in the TypeTranslator, and once the object is
+  // destroyed, the hint is automatically removed from the stack.
+  class LiteralTypeHint {
+  public:
+    LiteralTypeHint(TypeTranslator &t, QualType ty);
+    ~LiteralTypeHint();
+
+  private:
+    static bool isLiteralType(QualType type);
+
+  private:
+    QualType type;
+    TypeTranslator &translator;
+  };
+
+private:
+  /// \brief Adds the given type to the intendedLiteralTypes stack. This will be
+  /// used as a hint regarding usage of literal types.
+  void pushIntendedLiteralType(QualType type);
+
   /// \brief Removes the type at the top of the intendedLiteralTypes stack.
   /// \brief Removes the type at the top of the intendedLiteralTypes stack.
   void popIntendedLiteralType();
   void popIntendedLiteralType();
 
 
@@ -248,6 +274,7 @@ private:
   ASTContext &astContext;
   ASTContext &astContext;
   ModuleBuilder &theBuilder;
   ModuleBuilder &theBuilder;
   DiagnosticsEngine &diags;
   DiagnosticsEngine &diags;
+  const EmitSPIRVOptions &spirvOptions;
 
 
   /// \brief This is a stack which is used to track the intended usage type for
   /// \brief This is a stack which is used to track the intended usage type for
   /// literals. For example: while a floating literal is being visited, if the
   /// literals. For example: while a floating literal is being visited, if the

+ 25 - 23
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -176,7 +176,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %95 = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint
 // %95 = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint
 // %_ptr_Function_HS_CONSTANT_DATA_OUTPUT = OpTypePointer Function %HS_CONSTANT_DATA_OUTPUT
 // %_ptr_Function_HS_CONSTANT_DATA_OUTPUT = OpTypePointer Function %HS_CONSTANT_DATA_OUTPUT
 // %_ptr_Function_float = OpTypePointer Function %float
 // %_ptr_Function_float = OpTypePointer Function %float
-// %118 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
+// %120 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
@@ -184,10 +184,12 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %float_1 = OpConstant %float 1
 // %float_1 = OpConstant %float 1
 // %int_0 = OpConstant %int 0
 // %int_0 = OpConstant %int 0
 // %float_2 = OpConstant %float 2
 // %float_2 = OpConstant %float 2
+// %int_1 = OpConstant %int 1
 // %float_3 = OpConstant %float 3
 // %float_3 = OpConstant %float 3
+// %int_2 = OpConstant %int 2
 // %float_4 = OpConstant %float 4
 // %float_4 = OpConstant %float 4
+// %int_3 = OpConstant %int 3
 // %float_5 = OpConstant %float 5
 // %float_5 = OpConstant %float 5
-// %int_1 = OpConstant %int 1
 // %float_6 = OpConstant %float 6
 // %float_6 = OpConstant %float 6
 // %gl_PerVertexIn = OpVariable %_ptr_Input__arr_type_gl_PerVertex_uint_3 Input
 // %gl_PerVertexIn = OpVariable %_ptr_Input__arr_type_gl_PerVertex_uint_3 Input
 // %gl_PerVertexOut = OpVariable %_ptr_Output__arr_type_gl_PerVertex_uint_3 Output
 // %gl_PerVertexOut = OpVariable %_ptr_Output__arr_type_gl_PerVertex_uint_3 Output
@@ -263,32 +265,32 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %PatchID = OpFunctionParameter %_ptr_Function_uint
 // %PatchID = OpFunctionParameter %_ptr_Function_uint
 // %bb_entry = OpLabel
 // %bb_entry = OpLabel
 // %Output = OpVariable %_ptr_Function_HS_CONSTANT_DATA_OUTPUT Function
 // %Output = OpVariable %_ptr_Function_HS_CONSTANT_DATA_OUTPUT Function
-// %105 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_0
+// %105 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_0
 // OpStore %105 %float_1
 // OpStore %105 %float_1
-// %107 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_1
-// OpStore %107 %float_2
-// %109 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_2
-// OpStore %109 %float_3
-// %111 = OpAccessChain %_ptr_Function_float %Output %int_0 %uint_3
-// OpStore %111 %float_4
-// %114 = OpAccessChain %_ptr_Function_float %Output %int_1 %uint_0
-// OpStore %114 %float_5
-// %116 = OpAccessChain %_ptr_Function_float %Output %int_1 %uint_1
-// OpStore %116 %float_6
-// %117 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
-// OpReturnValue %117
+// %108 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_1
+// OpStore %108 %float_2
+// %111 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_2
+// OpStore %111 %float_3
+// %114 = OpAccessChain %_ptr_Function_float %Output %int_0 %int_3
+// OpStore %114 %float_4
+// %116 = OpAccessChain %_ptr_Function_float %Output %int_1 %int_0
+// OpStore %116 %float_5
+// %118 = OpAccessChain %_ptr_Function_float %Output %int_1 %int_1
+// OpStore %118 %float_6
+// %119 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
+// OpReturnValue %119
 // OpFunctionEnd
 // OpFunctionEnd
-// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %118
+// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %120
 // %ip_0 = OpFunctionParameter %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3
 // %ip_0 = OpFunctionParameter %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3
 // %cpid = OpFunctionParameter %_ptr_Function_uint
 // %cpid = OpFunctionParameter %_ptr_Function_uint
 // %PatchID_0 = OpFunctionParameter %_ptr_Function_uint
 // %PatchID_0 = OpFunctionParameter %_ptr_Function_uint
 // %bb_entry_0 = OpLabel
 // %bb_entry_0 = OpLabel
 // %vsOutput = OpVariable %_ptr_Function_VS_CONTROL_POINT_OUTPUT Function
 // %vsOutput = OpVariable %_ptr_Function_VS_CONTROL_POINT_OUTPUT Function
 // %result = OpVariable %_ptr_Function_BEZIER_CONTROL_POINT Function
 // %result = OpVariable %_ptr_Function_BEZIER_CONTROL_POINT Function
-// %128 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
-// %129 = OpLoad %v3float %128
-// %130 = OpAccessChain %_ptr_Function_v3float %result %int_0
-// OpStore %130 %129
-// %131 = OpLoad %BEZIER_CONTROL_POINT %result
-// OpReturnValue %131
-// OpFunctionEnd
+// %130 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
+// %131 = OpLoad %v3float %130
+// %132 = OpAccessChain %_ptr_Function_v3float %result %int_0
+// OpStore %132 %131
+// %133 = OpLoad %BEZIER_CONTROL_POINT %result
+// OpReturnValue %133
+// OpFunctionEnd

+ 1 - 1
tools/clang/test/CodeGenSPIRV/binary-op.assign.composite.hlsl

@@ -74,7 +74,7 @@ void main(uint index: A) {
     BufferType lbuf;                  // %BufferType_0                   & %SubBuffer_1
     BufferType lbuf;                  // %BufferType_0                   & %SubBuffer_1
     sbuf[5]  = lbuf;             // %BufferType <- %BufferType_0
     sbuf[5]  = lbuf;             // %BufferType <- %BufferType_0
 
 
-// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer_0 %cbuf %int_3 %uint_0
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_SubBuffer_0 %cbuf %int_3 %int_0
 // CHECK-NEXT: [[cbuf_d0:%\d+]] = OpLoad %SubBuffer_0 [[ptr]]
 // CHECK-NEXT: [[cbuf_d0:%\d+]] = OpLoad %SubBuffer_0 [[ptr]]
 
 
     // sub.a[0] <- cbuf.d[0].a[0]
     // sub.a[0] <- cbuf.d[0].a[0]

+ 46 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.disabled.hlsl

@@ -0,0 +1,46 @@
+// Run: %dxc -T ps_6_0 -E main
+
+/////////////////////////////////////////////////////////////////////////////////
+/// Types with fewer than 32 bits, used without '-enable-16bit-type' options  ///
+/////////////////////////////////////////////////////////////////////////////////
+
+// See https://github.com/Microsoft/DirectXShaderCompiler/wiki/16-Bit-Scalar-Types
+// for details about these types.
+
+// CHECK-NOT: OpDecorate %c_half_4_5 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_half_n8_2 RelaxedPrecision
+// CHECK: OpDecorate %c_min10float RelaxedPrecision
+// CHECK: OpDecorate %c_min16float RelaxedPrecision
+// CHECK: OpDecorate %c_min16int_n3 RelaxedPrecision
+// CHECK: OpDecorate %c_min16uint_5 RelaxedPrecision
+// CHECK: OpDecorate %c_min12int RelaxedPrecision
+
+void main() {
+// Note: in the absence of "-enable-16bit-types" option,
+// 'half' is translated to float *without* RelaxedPrecision decoration.
+// CHECK: %float_7_7 = OpConstant %float 7.7
+  half c_half_4_5 = 7.7;
+// CHECK: %float_n8_8 = OpConstant %float -8.8
+  half c_half_n8_2 = -8.8;
+
+// Note: in the absence of "-enable-16bit-type" option,
+// 'min{10|16}float' are translated to
+// 32-bit float in SPIR-V with RelaxedPrecision decoration (checked above).
+// CHECK: %float_1_5 = OpConstant %float 1.5
+  min10float c_min10float = 1.5;
+// CHECK: %float_n1 = OpConstant %float -1
+  min16float c_min16float = -1.0;
+
+// Note: in the absence of "-enable-16bit-type" option,
+// 'min12{uint|int}' and 'min16{uint|int}' are translated to
+// 32-bit uint/int in SPIR-V with RelaxedPrecision decoration (checked above).
+// CHECK: %int_n3 = OpConstant %int -3
+  min16int c_min16int_n3 = -3;
+// CHECK: %uint_5 = OpConstant %uint 5
+  min16uint c_min16uint_5 = 5;
+// CHECK: %int_n9 = OpConstant %int -9
+  min12int c_min12int = -9;
+// It seems that min12uint is still not supported by the front-end.
+// XXXXX: %uint_12 = OpConstant %uint 12 
+//  min12uint c_min12uint = 12;
+}

+ 70 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.hlsl

@@ -0,0 +1,70 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// Handling of 16-bit integers and 16-bit floats.
+// Note that this test runs utilizes "-enable-16bit-types" option above.
+
+// When this option is enabled, 16-bit types are use as outlined below:
+// min10float: float16_t(warning)
+// min16float: float16_t(warning)
+// half:       float16_t
+// float16_t:  float16_t
+// min12int:   int16_t(warning)
+// min16int:   int16_t(warning)
+// int16_t:    int16_t
+// min12uint:  uint16_t(warning)
+// min16uint:  uint16_t(warning)
+// uint16_t:   uint16_t
+
+// CHECK: OpCapability Float16
+// CHECK: OpCapability Int16
+
+// CHECK-NOT: OpDecorate %c_half RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min10float RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16float RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_float16t RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16int_n3 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16int_3 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min16uint_5 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min12int_n9 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_min12int_9 RelaxedPrecision
+// XXXXX-NOT: OpDecorate %c_min12uint RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_uint16_16 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_int16_n16 RelaxedPrecision
+// CHECK-NOT: OpDecorate %c_int16_16 RelaxedPrecision
+
+// CHECK: %short = OpTypeInt 16 1
+// CHECK: %ushort = OpTypeInt 16 0
+// CHECK: %half = OpTypeFloat 16
+
+void main() {
+// CHECK: %half_0x1_2p_0 = OpConstant %half 0x1.2p+0
+  half       c_half = 1.125;
+// CHECK: %half_0x1_ep_3 = OpConstant %half 0x1.ep+3
+  min10float c_min10float = 15.0;
+// CHECK: %half_n0x1p_0 = OpConstant %half -0x1p+0
+  min16float c_min16float = -1.0;
+// CHECK: %half_0x1_8p_0 = OpConstant %half 0x1.8p+0
+  float16_t  c_float16t = 1.5;
+
+// CHECK: %short_n3 = OpConstant %short -3
+  min16int   c_min16int_n3 = -3;
+// CHECK: %short_3 = OpConstant %short 3
+  min16int   c_min16int_3 = 3;
+// CHECK: %ushort_5 = OpConstant %ushort 5
+  min16uint  c_min16uint_5 = 5;
+
+// CHECK: %short_n9 = OpConstant %short -9
+  min12int   c_min12int_n9 = -9;
+// CHECK: %short_9 = OpConstant %short 9
+  min12int   c_min12int_9 = 9;
+// It seems that min12uint is still not supported by the front-end.
+// XXXXX: %short_12 = OpConstant %short 12 
+//  min12uint   c_min12uint = 12;
+
+// CHECK: %ushort_16 = OpConstant %ushort 16
+  uint16_t  c_uint16_16 = 16;
+// CHECK: %short_n16 = OpConstant %short -16
+  int16_t   c_int16_n16 = -16;
+// CHECK: %short_16 = OpConstant %short 16
+  int16_t   c_int16_16 = 16;
+}

+ 29 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.64bit.hlsl

@@ -0,0 +1,29 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+
+// CHECK: %double_0 = OpConstant %double 0
+  double c_double_0 = 0.;
+// CHECK: %double_n0 = OpConstant %double -0
+  float64_t c_double_n0 = -0.;
+// CHECK: %double_4_5 = OpConstant %double 4.5
+  float64_t c_double_4_5 = 4.5;
+// CHECK: %double_n8_2 = OpConstant %double -8.2
+  double c_double_n8_2 = -8.2;
+// CHECK: %double_1234567898765_32 = OpConstant %double 1234567898765.32
+  double c_large  =  1234567898765.32;
+// CHECK: %double_n1234567898765_32 = OpConstant %double -1234567898765.32
+  float64_t c_nlarge = -1234567898765.32;
+
+// CHECK: %long_1 = OpConstant %long 1
+  int64_t  c_int64_small_1  = 1;  
+// CHECK: %long_n1 = OpConstant %long -1
+  int64_t  c_int64_small_n1  = -1;  
+// CHECK: %long_2147483648 = OpConstant %long 2147483648
+  int64_t  c_int64_large  = 2147483648;
+
+// CHECK: %ulong_2 = OpConstant %ulong 2
+  uint64_t c_uint64_small_2 = 2;
+// CHECK: %ulong_4294967296 = OpConstant %ulong 4294967296
+  uint64_t c_uint64_large = 4294967296;
+}

+ 17 - 33
tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl

@@ -1,64 +1,48 @@
 // Run: %dxc -T ps_6_0 -E main
 // Run: %dxc -T ps_6_0 -E main
 
 
 // TODO
 // TODO
-// 16bit & 64bit integer (require additional capability)
-// 16bit floats (require additional capability)
 // float: denormalized numbers, Inf, NaN
 // float: denormalized numbers, Inf, NaN
 
 
 void main() {
 void main() {
   // Boolean constants
   // Boolean constants
-// CHECK-DAG: %true = OpConstantTrue %bool
+// CHECK: %true = OpConstantTrue %bool
   bool c_bool_t = true;
   bool c_bool_t = true;
-// CHECK-DAG: %false = OpConstantFalse %bool
+// CHECK: %false = OpConstantFalse %bool
   bool c_bool_f = false;
   bool c_bool_f = false;
 
 
   // Signed integer constants
   // Signed integer constants
-// CHECK-DAG: %int_0 = OpConstant %int 0
+// CHECK: %int_0 = OpConstant %int 0
   int c_int_0 = 0;
   int c_int_0 = 0;
-// CHECK-DAG: %int_1 = OpConstant %int 1
+// CHECK: %int_1 = OpConstant %int 1
   int c_int_1 = 1;
   int c_int_1 = 1;
-// CHECK-DAG: %int_n1 = OpConstant %int -1
+// CHECK: %int_n1 = OpConstant %int -1
   int c_int_n1 = -1;
   int c_int_n1 = -1;
-// CHECK-DAG: %int_42 = OpConstant %int 42
+// CHECK: %int_42 = OpConstant %int 42
   int c_int_42 = 42;
   int c_int_42 = 42;
-// CHECK-DAG: %int_n42 = OpConstant %int -42
+// CHECK: %int_n42 = OpConstant %int -42
   int c_int_n42 = -42;
   int c_int_n42 = -42;
-// CHECK-DAG: %int_2147483647 = OpConstant %int 2147483647
+// CHECK: %int_2147483647 = OpConstant %int 2147483647
   int c_int_max = 2147483647;
   int c_int_max = 2147483647;
-// CHECK-DAG: %int_n2147483648 = OpConstant %int -2147483648
+// CHECK: %int_n2147483648 = OpConstant %int -2147483648
   int c_int_min = -2147483648;
   int c_int_min = -2147483648;
 
 
   // Unsigned integer constants
   // Unsigned integer constants
-// CHECK-DAG: %uint_0 = OpConstant %uint 0
+// CHECK: %uint_0 = OpConstant %uint 0
   uint c_uint_0 = 0;
   uint c_uint_0 = 0;
-// CHECK-DAG: %uint_1 = OpConstant %uint 1
+// CHECK: %uint_1 = OpConstant %uint 1
   uint c_uint_1 = 1;
   uint c_uint_1 = 1;
-// CHECK-DAG: %uint_38 = OpConstant %uint 38
+// CHECK: %uint_38 = OpConstant %uint 38
   uint c_uint_38 = 38;
   uint c_uint_38 = 38;
-// CHECK-DAG: %uint_4294967295 = OpConstant %uint 4294967295
+// CHECK: %uint_4294967295 = OpConstant %uint 4294967295
   uint c_uint_max = 4294967295;
   uint c_uint_max = 4294967295;
 
 
   // Float constants
   // Float constants
-// CHECK-DAG: %float_0 = OpConstant %float 0
+// CHECK: %float_0 = OpConstant %float 0
   float c_float_0 = 0.;
   float c_float_0 = 0.;
-// CHECK-DAG: %float_n0 = OpConstant %float -0
+// CHECK: %float_n0 = OpConstant %float -0
   float c_float_n0 = -0.;
   float c_float_n0 = -0.;
-// CHECK-DAG: %float_4_2 = OpConstant %float 4.2
+// CHECK: %float_4_2 = OpConstant %float 4.2
   float c_float_4_2 = 4.2;
   float c_float_4_2 = 4.2;
-// CHECK-DAG: %float_n4_2 = OpConstant %float -4.2
+// CHECK: %float_n4_2 = OpConstant %float -4.2
   float c_float_n4_2 = -4.2;
   float c_float_n4_2 = -4.2;
-  
-  // double constants
-// CHECK-DAG: %double_0 = OpConstant %double 0
-  double c_double_0 = 0.;
-// CHECK-DAG: %double_n0 = OpConstant %double -0
-  double c_double_n0 = -0.;
-// CHECK-DAG: %double_4_5 = OpConstant %double 4.5
-  double c_double_4_5 = 4.5;
-// CHECK-DAG: %double_n8_2 = OpConstant %double -8.2
-  double c_double_n8_2 = -8.2;
-// CHECK-DAG: %double_1234567898765_32 = OpConstant %double 1234567898765.32
-  double c_large  =  1234567898765.32;
-// CHECK-DAG: %double_n1234567898765_32 = OpConstant %double -1234567898765.32
-  double c_nlarge = -1234567898765.32;
 }
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/cs.groupshared.hlsl

@@ -20,7 +20,7 @@ groupshared              S        s;
 void main(uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID) {
 void main(uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID) {
 // Make sure pointers have the correct storage class
 // Make sure pointers have the correct storage class
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float %s %int_0
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float %s %int_0
-// CHECK: [[d0:%\d+]] = OpAccessChain %_ptr_Workgroup_v2float %d %uint_0
+// CHECK: [[d0:%\d+]] = OpAccessChain %_ptr_Workgroup_v2float %d %int_0
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float [[d0]] %int_1
 // CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float [[d0]] %int_1
     d[0].y = s.f1;
     d[0].y = s.f1;
 }
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/op.array.access.hlsl

@@ -17,20 +17,20 @@ float main(float val: A, uint index: B) : C {
 
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT:  [[idx:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var [[idx]] %uint_1 %int_0 %uint_2
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var [[idx]] %int_1 %int_0 %int_2
 // CHECK-NEXT:                 OpStore [[ptr0]] [[val]]
 // CHECK-NEXT:                 OpStore [[ptr0]] [[val]]
 
 
     var[index][1].f[2] = val;
     var[index][1].f[2] = val;
 // CHECK-NEXT: [[idx0:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT: [[idx0:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT: [[idx1:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT: [[idx1:%\d+]] = OpLoad %uint %index
-// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var %uint_0 [[idx0]] %int_1 [[idx1]]
+// CHECK:      [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_float %var %int_0 [[idx0]] %int_1 [[idx1]]
 // CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr0]]
 // CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr0]]
 // CHECK-NEXT:                 OpStore %r [[load]]
 // CHECK-NEXT:                 OpStore %r [[load]]
     r = var[0][index].g[index];
     r = var[0][index].g[index];
 
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_v4float %vecvar %uint_3
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_v4float %vecvar %int_3
 // CHECK-NEXT: [[vec4:%\d+]] = OpLoad %v4float [[ptr0]]
 // CHECK-NEXT: [[vec4:%\d+]] = OpLoad %v4float [[ptr0]]
 // CHECK-NEXT:  [[res:%\d+]] = OpVectorShuffle %v4float [[vec4]] [[vec2]] 0 1 5 4
 // CHECK-NEXT:  [[res:%\d+]] = OpVectorShuffle %v4float [[vec4]] [[vec2]] 0 1 5 4
 // CHECK-NEXT:                 OpStore [[ptr0]] [[res]]
 // CHECK-NEXT:                 OpStore [[ptr0]] [[res]]
@@ -42,7 +42,7 @@ float main(float val: A, uint index: B) : C {
 
 
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
 // CHECK-NEXT: [[vec2:%\d+]] = OpCompositeConstruct %v2float [[val]] [[val]]
-// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %matvar %uint_2
+// CHECK-NEXT: [[ptr0:%\d+]] = OpAccessChain %_ptr_Function_mat2v3float %matvar %int_2
 // CHECK-NEXT: [[val0:%\d+]] = OpCompositeExtract %float [[vec2]] 0
 // CHECK-NEXT: [[val0:%\d+]] = OpCompositeExtract %float [[vec2]] 0
 // CHECK-NEXT: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr0]] %int_0 %int_1
 // CHECK-NEXT: [[ptr1:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr0]] %int_0 %int_1
 // CHECK-NEXT:                 OpStore [[ptr1]] [[val0]]
 // CHECK-NEXT:                 OpStore [[ptr1]] [[val0]]

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.cbuffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 
 
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyCbuffer %int_4
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyCbuffer %int_4
-// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %uint_3
+// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
     return a + b.x + c[1][2] + s.f + t[3];
     return a + b.x + c[1][2] + s.f + t[3];
 }
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.constant-buffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_3 %int_0
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_3 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
 
-// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_4 %uint_3
+// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyCbuffer %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
     return MyCbuffer.a + MyCbuffer.b.x + MyCbuffer.c[1][2] + MyCbuffer.s.f + MyCbuffer.t[3];
     return MyCbuffer.a + MyCbuffer.b.x + MyCbuffer.c[1][2] + MyCbuffer.s.f + MyCbuffer.t[3];
 }
 }

+ 2 - 2
tools/clang/test/CodeGenSPIRV/op.rw-structured-buffer.access.hlsl

@@ -26,7 +26,7 @@ void main(uint index: A) {
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK:       [[val:%\d+]] = OpLoad %float %val
 // CHECK-NEXT:  [[index:%\d+]] = OpLoad %uint %index
 // CHECK-NEXT:  [[index:%\d+]] = OpLoad %uint %index
 
 
-// CHECK-NEXT:  [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %uint_3
+// CHECK-NEXT:  [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %int_3
 // CHECK-NEXT:  OpStore [[t3]] [[val]]
 // CHECK-NEXT:  OpStore [[t3]] [[val]]
 
 
 // CHECK:       [[f:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
 // CHECK:       [[f:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
@@ -35,7 +35,7 @@ void main(uint index: A) {
 // CHECK-NEXT:  [[c212:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
 // CHECK-NEXT:  [[c212:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
 // CHECK-NEXT:  OpStore [[c212]] [[val]]
 // CHECK-NEXT:  OpStore [[c212]] [[val]]
 
 
-// CHECK-NEXT:  [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %uint_1
+// CHECK-NEXT:  [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %int_1
 // CHECK-NEXT:  [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT:  [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT:  OpStore [[x]] [[val]]
 // CHECK-NEXT:  OpStore [[x]] [[val]]
 
 

+ 2 - 2
tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.hlsl

@@ -19,7 +19,7 @@ float4 main(uint index: A) : SV_Target {
 // CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_0 %int_0
 // CHECK:      [[a:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_0 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[a]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[a]]
 
 
-// CHECK:      [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %uint_1
+// CHECK:      [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %int_1
 // CHECK-NEXT: [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT: [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[x]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[x]]
 
 
@@ -30,7 +30,7 @@ float4 main(uint index: A) : SV_Target {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
 
 // CHECK:      [[index:%\d+]] = OpLoad %uint %index
 // CHECK:      [[index:%\d+]] = OpLoad %uint %index
-// CHECK-NEXT: [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %uint_3
+// CHECK-NEXT: [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
     return MySbuffer[0].a + MySbuffer[1].b[1].x + MySbuffer[2].c[2][1][2] +
     return MySbuffer[0].a + MySbuffer[1].b[1].x + MySbuffer[2].c[2][1][2] +
            MySbuffer[3].s[0].f + MySbuffer[index].t[3];
            MySbuffer[3].s[0].f + MySbuffer[index].t[3];

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.tbuffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s0]]
 
 
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyTbuffer %int_4
 // CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform__arr_float_uint_4 %var_MyTbuffer %int_4
-// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %uint_3
+// CHECK-NEXT: [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float [[t]] %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t3]]
     return a + b.x + c[1][2] + s.f + t[3];
     return a + b.x + c[1][2] + s.f + t[3];
 }
 }

+ 1 - 1
tools/clang/test/CodeGenSPIRV/op.texture-buffer.access.hlsl

@@ -29,7 +29,7 @@ float main() : A {
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_3 %int_0
 // CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_3 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
 
-// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_4 %uint_3
+// CHECK:      [[t:%\d+]] = OpAccessChain %_ptr_Uniform_float %MyTB %int_4 %int_3
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[t]]
   return MyTB.a + MyTB.b.x + MyTB.c[1][2] + MyTB.s.f + MyTB.t[3];
   return MyTB.a + MyTB.b.x + MyTB.c[1][2] + MyTB.s.f + MyTB.t[3];
 }
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/var.init.array.hlsl

@@ -67,19 +67,19 @@ void main() {
     T2 val2[2] = {val1};
     T2 val2[2] = {val1};
 
 
 // val3[0]: Construct T3.h from T1.c.b[0]
 // val3[0]: Construct T3.h from T1.c.b[0]
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_0
 // CHECK-NEXT:   [[h_val:%\d+]] = OpLoad %v2float [[b_0]]
 // CHECK-NEXT:   [[h_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 
 // val3[0]: Construct T3.i from T1.c.b[1]
 // val3[0]: Construct T3.i from T1.c.b[1]
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_0 %int_0 %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_0 %int_0 %uint_1
 // CHECK-NEXT:   [[i_val:%\d+]] = OpLoad %v2float [[b_1]]
 // CHECK-NEXT:   [[i_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 
 // val3[0]: Construct T3.j from T1.d.b[0]
 // val3[0]: Construct T3.j from T1.d.b[0]
-// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_1 %int_0 %uint_0
+// CHECK-NEXT:     [[b_0:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_0
 // CHECK-NEXT:   [[j_val:%\d+]] = OpLoad %v2float [[b_0]]
 // CHECK-NEXT:   [[j_val:%\d+]] = OpLoad %v2float [[b_0]]
 
 
 // val3[0]: Construct T3.k from T1.d.b[1]
 // val3[0]: Construct T3.k from T1.d.b[1]
-// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %uint_0 %int_1 %int_0 %uint_1
+// CHECK-NEXT:     [[b_1:%\d+]] = OpAccessChain %_ptr_Function_v2float %val1 %int_0 %int_1 %int_0 %uint_1
 // CHECK-NEXT:   [[k_val:%\d+]] = OpLoad %v2float [[b_1]]
 // CHECK-NEXT:   [[k_val:%\d+]] = OpLoad %v2float [[b_1]]
 
 
 // CHECK-NEXT:  [[val3_0:%\d+]] = OpCompositeConstruct %T3 [[h_val]] [[i_val]] [[j_val]] [[k_val]]
 // CHECK-NEXT:  [[val3_0:%\d+]] = OpCompositeConstruct %T3 [[h_val]] [[i_val]] [[j_val]] [[k_val]]

+ 1 - 1
tools/clang/test/CodeGenSPIRV/vk.push-constant.hlsl

@@ -32,7 +32,7 @@ float main() : A {
         pcs.f2.z +
         pcs.f2.z +
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float %pcs %int_2 %uint_1 %uint_2
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float %pcs %int_2 %uint_1 %uint_2
         pcs.f3[1][2] +
         pcs.f3[1][2] +
-// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_PushConstant_v2float %pcs %int_3 %int_0 %uint_2
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_PushConstant_v2float %pcs %int_3 %int_0 %int_2
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float [[ptr]] %int_1
 // CHECK:     {{%\d+}} = OpAccessChain %_ptr_PushConstant_float [[ptr]] %int_1
         pcs.f4.val[2].y;
         pcs.f4.val[2].y;
 }
 }

+ 1 - 1
tools/clang/tools/dxcompiler/dxcompilerobj.cpp

@@ -473,7 +473,7 @@ public:
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.tShift = opts.VkTShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.sShift = opts.VkSShift;
           spirvOpts.uShift = opts.VkUShift;
           spirvOpts.uShift = opts.VkUShift;
-
+          spirvOpts.enable16BitTypes = opts.Enable16BitTypes;
           clang::EmitSPIRVAction action(spirvOpts);
           clang::EmitSPIRVAction action(spirvOpts);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);
           FrontendInputFile file(utf8SourceName.m_psz, IK_HLSL);
           action.BeginSourceFile(compiler, file);
           action.BeginSourceFile(compiler, file);

+ 11 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -82,6 +82,17 @@ TEST_F(FileTest, TriangleStreamTypes) {
 
 
 // For constants
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
+TEST_F(FileTest, 16BitDisabledScalarConstants) {
+  runFileTest("constant.scalar.16bit.disabled.hlsl");
+}
+TEST_F(FileTest, 16BitEnabledScalarConstants) {
+  // TODO: Fix spirv-val to make sure it respects the 16-bit extension.
+  runFileTest("constant.scalar.16bit.enabled.hlsl", FileTest::Expect::Success,
+              /*runValidation*/ false);
+}
+TEST_F(FileTest, 64BitScalarConstants) {
+  runFileTest("constant.scalar.64bit.hlsl");
+}
 TEST_F(FileTest, VectorConstants) { runFileTest("constant.vector.hlsl"); }
 TEST_F(FileTest, VectorConstants) { runFileTest("constant.vector.hlsl"); }
 TEST_F(FileTest, MatrixConstants) { runFileTest("constant.matrix.hlsl"); }
 TEST_F(FileTest, MatrixConstants) { runFileTest("constant.matrix.hlsl"); }
 TEST_F(FileTest, StructConstants) { runFileTest("constant.struct.hlsl"); }
 TEST_F(FileTest, StructConstants) { runFileTest("constant.struct.hlsl"); }