Pārlūkot izejas kodu

[spirv] Support 16-bit types in stage IO (#1227)

This feature will require the SPV_KHR_16bit_storage extension.
Lei Zhang 7 gadi atpakaļ
vecāks
revīzija
b6d65cbe27

+ 1 - 1
tools/clang/include/clang/SPIRV/FeatureManager.h

@@ -15,7 +15,6 @@
 
 #include <string>
 
-
 #include "spirv-tools/libspirv.h"
 
 #include "clang/Basic/Diagnostic.h"
@@ -31,6 +30,7 @@ namespace spirv {
 /// A list of SPIR-V extensions known to our CodeGen.
 enum class Extension {
   KHR = 0,
+  KHR_16bit_storage,
   KHR_device_group,
   KHR_multiview,
   KHR_shader_draw_parameters,

+ 7 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1322,6 +1322,13 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
     // Mark that we have used one index for this semantic
     ++semanticToUse->index;
 
+    // Require extension and capability if using 16-bit types
+    if (typeTranslator.getElementSpirvBitwidth(type) == 16) {
+      theBuilder.addExtension(Extension::KHR_16bit_storage,
+                              "16-bit stage IO variables", decl->getLocation());
+      theBuilder.requireCapability(spv::Capability::StorageInputOutput16);
+    }
+
     // TODO: the following may not be correct?
     if (sigPoint->GetSignatureKind() ==
         hlsl::DXIL::SignatureKind::PatchConstant)

+ 0 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -39,7 +39,6 @@ public:
   inline StageVar(const hlsl::SigPoint *sig, llvm::StringRef semaStr,
                   const hlsl::Semantic *sema, llvm::StringRef semaName,
                   uint32_t semaIndex, const VKBuiltInAttr *builtin,
-
                   uint32_t type, uint32_t locCount)
       : sigPoint(sig), semanticStr(semaStr), semantic(sema),
         semanticName(semaName), semanticIndex(semaIndex), builtinAttr(builtin),

+ 11 - 11
tools/clang/lib/SPIRV/FeatureManager.cpp

@@ -84,7 +84,8 @@ bool FeatureManager::requestTargetEnv(spv_target_env requestedEnv,
   if (targetEnv == SPV_ENV_VULKAN_1_0 && requestedEnv == SPV_ENV_VULKAN_1_1) {
     emitError("Vulkan 1.1 is required for %0 but not permitted to use", srcLoc)
         << target;
-    emitNote("please specify your target environment via command line option -fspv-target-env=",
+    emitNote("please specify your target environment via command line option "
+             "-fspv-target-env=",
              {});
     return false;
   }
@@ -94,6 +95,7 @@ bool FeatureManager::requestTargetEnv(spv_target_env requestedEnv,
 Extension FeatureManager::getExtensionSymbol(llvm::StringRef name) {
   return llvm::StringSwitch<Extension>(name)
       .Case("KHR", Extension::KHR)
+      .Case("SPV_KHR_16bit_storage", Extension::KHR_16bit_storage)
       .Case("SPV_KHR_device_group", Extension::KHR_device_group)
       .Case("SPV_KHR_multiview", Extension::KHR_multiview)
       .Case("SPV_KHR_shader_draw_parameters",
@@ -115,6 +117,8 @@ const char *FeatureManager::getExtensionName(Extension symbol) {
   switch (symbol) {
   case Extension::KHR:
     return "KHR";
+  case Extension::KHR_16bit_storage:
+    return "SPV_KHR_16bit_storage";
   case Extension::KHR_device_group:
     return "SPV_KHR_device_group";
   case Extension::KHR_multiview:
@@ -164,19 +168,15 @@ bool FeatureManager::isExtensionRequiredForTargetEnv(Extension ext) {
   bool required = true;
   if (targetEnv == SPV_ENV_VULKAN_1_1) {
     // The following extensions are incorporated into Vulkan 1.1, and are
-    // therefore not required to be emitted for that target environment. The
-    // last 3 are currently not supported by the FeatureManager.
-    // TODO: Add the last 3 extensions to the list if we start to support them.
-    // SPV_KHR_shader_draw_parameters
-    // SPV_KHR_device_group
-    // SPV_KHR_multiview
-    // SPV_KHR_16bit_storage
-    // SPV_KHR_storage_buffer_storage_class
-    // SPV_KHR_variable_pointers
+    // therefore not required to be emitted for that target environment.
+    // TODO: Also add the following extensions  if we start to support them.
+    // * SPV_KHR_storage_buffer_storage_class
+    // * SPV_KHR_variable_pointers
     switch (ext) {
-    case Extension::KHR_shader_draw_parameters:
+    case Extension::KHR_16bit_storage:
     case Extension::KHR_device_group:
     case Extension::KHR_multiview:
+    case Extension::KHR_shader_draw_parameters:
       required = false;
     }
   }

+ 22 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -395,12 +395,34 @@ uint32_t TypeTranslator::getElementSpirvBitwidth(QualType type) {
       return getElementSpirvBitwidth(elemType);
   }
 
+  // Matrix types
+  if (hlsl::IsHLSLMatType(type))
+    return getElementSpirvBitwidth(hlsl::GetHLSLMatElementType(type));
+
+  // Array types
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
+    return getElementSpirvBitwidth(arrayType->getElementType());
+  }
+
+  // Typedefs
+  if (const auto *typedefType = type->getAs<TypedefType>())
+    return getLocationCount(typedefType->desugar());
+
+  // Reference types
+  if (const auto *refType = type->getAs<ReferenceType>())
+    return getLocationCount(refType->getPointeeType());
+
+  // Pointer types
+  if (const auto *ptrType = type->getAs<PointerType>())
+    return getLocationCount(ptrType->getPointeeType());
+
   // Scalar types
   QualType ty = {};
   const bool isScalar = isScalarType(type, &ty);
   assert(isScalar);
   if (const auto *builtinType = ty->getAs<BuiltinType>()) {
     switch (builtinType->getKind()) {
+    case BuiltinType::Bool:
     case BuiltinType::Int:
     case BuiltinType::UInt:
     case BuiltinType::Float:

+ 3 - 3
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -140,9 +140,9 @@ public:
   uint32_t getTypeWithCustomBitwidth(QualType type, uint32_t bitwidth);
 
   /// \brief Returns the realized bitwidth of the given type when represented in
-  /// SPIR-V. Panics if the given type is not a scalar or vector of float or
-  /// integer. In case of vectors, it returns the realized SPIR-V bitwidth of
-  /// the vector elements.
+  /// SPIR-V. Panics if the given type is not a scalar, a vector/matrix of float
+  /// or integer, or an array of them. In case of vectors, it returns the
+  /// realized SPIR-V bitwidth of the vector elements.
   uint32_t getElementSpirvBitwidth(QualType type);
 
   /// \brief Returns true if the given type will be translated into a SPIR-V

+ 53 - 0
tools/clang/test/CodeGenSPIRV/spirv.stage-io.16bit.hlsl

@@ -0,0 +1,53 @@
+// Run: %dxc -T vs_6_2 -E main -enable-16bit-types
+
+// CHECK: OpCapability StorageInputOutput16
+
+// CHECK: OpExtension "SPV_KHR_16bit_storage"
+
+// CHECK: OpDecorate %in_var_A Location 0
+// CHECK: OpDecorate %in_var_B Location 4
+// CHECK: OpDecorate %in_var_C Location 6
+// CHECK: OpDecorate %in_var_D Location 7
+// CHECK: OpDecorate %in_var_E Location 8
+
+// CHECK: OpDecorate %out_var_A Location 0
+// CHECK: OpDecorate %out_var_B Location 2
+// CHECK: OpDecorate %out_var_C Location 6
+// CHECK: OpDecorate %out_var_D Location 7
+// CHECK: OpDecorate %out_var_E Location 8
+
+// CHECK:  %in_var_A = OpVariable %_ptr_Input__arr_v2half_uint_4 Input
+// CHECK:  %in_var_B = OpVariable %_ptr_Input__arr_v3ushort_uint_2 Input
+// CHECK:  %in_var_C = OpVariable %_ptr_Input_short Input
+// CHECK:  %in_var_D = OpVariable %_ptr_Input_v2ushort Input
+// CHECK:  %in_var_E = OpVariable %_ptr_Input_mat3v2half Input
+
+// CHECK: %out_var_A = OpVariable %_ptr_Output_mat2v3half Output
+// CHECK: %out_var_B = OpVariable %_ptr_Output__arr_v2short_uint_4 Output
+// CHECK: %out_var_C = OpVariable %_ptr_Output_half Output
+// CHECK: %out_var_D = OpVariable %_ptr_Output_v2short Output
+// CHECK: %out_var_E = OpVariable %_ptr_Output_v3ushort Output
+
+struct VSOut {
+    half2x3   outA    : A; // 2 locations: 0, 1
+    int16_t2  outB[4] : B; // 4 locations: 2, 3, 4, 5
+    half      outC    : C; // 1 location : 6
+    int16_t2  outD    : D; // 1 location : 7
+    uint16_t3 outE    : E; // 1 location : 8
+};
+
+VSOut main(
+    half2        inA[4] : A, // 4 locations: 0, 1, 2, 3
+    uint16_t2x3  inB    : B, // 2 locations: 4, 5
+    int16_t      inC    : C, // 1 location : 6
+    uint16_t2    inD    : D, // 1 location : 7
+    float16_t3x2 inE    : E  // 3 location : 8, 9, 10
+) {
+    VSOut o;
+    o.outA    = inA[0].x;
+    o.outB[0] = inB[0][0];
+    o.outC    = inC.x;
+    o.outD    = inD[0];
+    o.outE    = inE[0][0];
+    return o;
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1216,6 +1216,10 @@ TEST_F(FileTest, SpirvStageIOInterfacePS) {
   runFileTest("spirv.interface.ps.hlsl");
 }
 
+TEST_F(FileTest, SpirvStageIO16bitTypes) {
+  runFileTest("spirv.stage-io.16bit.hlsl");
+}
+
 TEST_F(FileTest, SpirvInterpolation) {
   runFileTest("spirv.interpolation.hlsl");
 }