Browse Source

[spirv] Support 16-bit types in resources (#1272)

* Added layout rules and debug names for 16-bit types.
* Generated necessary capabilities for using 16-bit types.
Lei Zhang 7 years ago
parent
commit
0abc68b119

+ 25 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -408,6 +408,15 @@ SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
 
   uint32_t varType = typeTranslator.translateType(var->getType(), rule);
 
+  // Require corresponding capability for accessing 16-bit data.
+  if (storageClass == spv::StorageClass::Uniform &&
+      spirvOptions.enable16BitTypes &&
+      typeTranslator.isOrContains16BitType(var->getType())) {
+    theBuilder.addExtension(Extension::KHR_16bit_storage,
+                            "16-bit types in resource", var->getLocation());
+    theBuilder.requireCapability(spv::Capability::StorageUniformBufferBlock16);
+  }
+
   const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
                                               var->getName(), llvm::None);
   const auto info =
@@ -473,6 +482,7 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   const bool forCBuffer = usageKind == ContextUsageKind::CBuffer;
   const bool forTBuffer = usageKind == ContextUsageKind::TBuffer;
   const bool forGlobals = usageKind == ContextUsageKind::Globals;
+  const bool forPC = usageKind == ContextUsageKind::PushConstant;
 
   auto &context = *theBuilder.getSPIRVContext();
   const LayoutRule layoutRule =
@@ -506,6 +516,19 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
     fieldTypes.push_back(typeTranslator.translateType(varType, layoutRule));
     fieldNames.push_back(declDecl->getName());
 
+    // Require corresponding capability for accessing 16-bit data.
+    if (spirvOptions.enable16BitTypes &&
+        typeTranslator.isOrContains16BitType(varType)) {
+      theBuilder.addExtension(Extension::KHR_16bit_storage,
+                              "16-bit types in resource",
+                              declDecl->getLocation());
+      theBuilder.requireCapability(
+          (forCBuffer || forGlobals)
+              ? spv::Capability::StorageUniform16
+              : forPC ? spv::Capability::StoragePushConstant16
+                      : spv::Capability::StorageUniformBufferBlock16);
+    }
+
     // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
     // NonWritable must be applied to all fields.
     if (forTBuffer) {
@@ -534,9 +557,8 @@ uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   // Register the <type-id> for this decl
   ctBufferPCTypeIds[decl] = resultType;
 
-  const auto sc = usageKind == ContextUsageKind::PushConstant
-                      ? spv::StorageClass::PushConstant
-                      : spv::StorageClass::Uniform;
+  const auto sc =
+      forPC ? spv::StorageClass::PushConstant : spv::StorageClass::Uniform;
 
   // Create the variable for the whole struct / struct array.
   return theBuilder.addModuleVar(resultType, sc, varName);

+ 99 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -761,6 +761,79 @@ bool TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
   return false;
 }
 
+bool TypeTranslator::isOrContains16BitType(QualType type) {
+  // Primitive types
+  {
+    QualType ty = {};
+    if (isScalarType(type, &ty)) {
+      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
+        switch (builtinType->getKind()) {
+        case BuiltinType::Short:
+        case BuiltinType::UShort:
+        case BuiltinType::Min12Int:
+        case BuiltinType::Half:
+        case BuiltinType::Min10Float: {
+          return spirvOptions.enable16BitTypes;
+        }
+        default:
+          return false;
+        }
+      }
+    }
+  }
+
+  // Vector types
+  {
+    QualType elemType = {};
+    if (isVectorType(type, &elemType))
+      return isOrContains16BitType(elemType);
+  }
+
+  // Matrix types
+  {
+    QualType elemType = {};
+    if (isMxNMatrix(type, &elemType)) {
+      return isOrContains16BitType(elemType);
+    }
+  }
+
+  // Struct type
+  if (const auto *structType = type->getAs<RecordType>()) {
+    const auto *decl = structType->getDecl();
+
+    for (const auto *field : decl->fields()) {
+      if (isOrContains16BitType(field->getType()))
+        return true;
+    }
+
+    return false;
+  }
+
+  // Array type
+  if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
+    return isOrContains16BitType(arrayType->getElementType());
+  }
+
+  // Reference types
+  if (const auto *refType = type->getAs<ReferenceType>()) {
+    return isOrContains16BitType(refType->getPointeeType());
+  }
+
+  // Pointer types
+  if (const auto *ptrType = type->getAs<PointerType>()) {
+    return isOrContains16BitType(ptrType->getPointeeType());
+  }
+
+  if (const auto *typedefType = type->getAs<TypedefType>()) {
+    return isOrContains16BitType(typedefType->desugar());
+  }
+
+  emitError("checking 16-bit type for %0 unimplemented")
+      << type->getTypeClassName();
+  type->dump();
+  return 0;
+}
+
 bool TypeTranslator::isStructuredBuffer(QualType type) {
   const auto *recordType = type->getAs<RecordType>();
   if (!recordType)
@@ -1640,6 +1713,16 @@ TypeTranslator::getAlignmentAndSize(QualType type, LayoutRule rule,
         case BuiltinType::LongLong:
         case BuiltinType::ULongLong:
           return {8, 8};
+        case BuiltinType::Short:
+        case BuiltinType::UShort:
+        case BuiltinType::Min12Int:
+        case BuiltinType::Half:
+        case BuiltinType::Min10Float: {
+          if (spirvOptions.enable16BitTypes)
+            return {2, 2};
+          else
+            return {4, 4};
+        }
         default:
           emitError("alignment and size calculation for type %0 unimplemented")
               << type;
@@ -1798,6 +1881,22 @@ std::string TypeTranslator::getName(QualType type) {
           return "uint";
         case BuiltinType::Float:
           return "float";
+        case BuiltinType::Double:
+          return "double";
+        case BuiltinType::LongLong:
+          return "int64";
+        case BuiltinType::ULongLong:
+          return "uint64";
+        case BuiltinType::Short:
+          return "short";
+        case BuiltinType::UShort:
+          return "ushort";
+        case BuiltinType::Half:
+          return "half";
+        case BuiltinType::Min12Int:
+          return "min12int";
+        case BuiltinType::Min10Float:
+          return "min10float";
         default:
           return "";
         }

+ 3 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -88,6 +88,9 @@ public:
   /// containing one of the above.
   static bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type);
 
+  /// \brief Returns true if the given type is or contains 16-bit type.
+  bool isOrContains16BitType(QualType type);
+
   /// \brief Returns true if the given type is the HLSL Buffer type.
   static bool isBuffer(QualType type);
 

+ 22 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.16bit-types.cbuffer.hlsl

@@ -0,0 +1,22 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: OpCapability StorageUniform16
+
+// CHECK: OpExtension "SPV_KHR_16bit_storage"
+
+// CHECK: OpMemberDecorate %type_MyCBuffer 0 Offset 0
+// CHECK: OpMemberDecorate %type_MyCBuffer 1 Offset 4
+// CHECK: OpMemberDecorate %type_MyCBuffer 2 Offset 16
+// CHECK: OpMemberDecorate %type_MyCBuffer 2 MatrixStride 16
+// CHECK: OpMemberDecorate %type_MyCBuffer 2 RowMajor
+// CHECK: OpDecorate %type_MyCBuffer Block
+
+cbuffer MyCBuffer {
+    uint16_t2 gVal1; // 16-bit vector
+    int16_t   gVal2; // 16-bit scalar
+    half3x2   gVal3; // 16-bit matrix
+};
+
+float4 main() : SV_Target {
+    return gVal2;
+}

+ 33 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.16bit-types.pc.hlsl

@@ -0,0 +1,33 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: OpCapability StoragePushConstant16
+
+// CHECK: OpExtension "SPV_KHR_16bit_storage"
+
+// CHECK: OpMemberDecorate %S 0 Offset 0
+// CHECK: OpMemberDecorate %S 1 Offset 2
+// CHECK: OpMemberDecorate %S 2 Offset 8
+// CHECK: OpMemberDecorate %S 2 MatrixStride 4
+// CHECK: OpMemberDecorate %S 2 RowMajor
+// CHECK: OpMemberDecorate %type_PushConstant_T 0 Offset 0
+// CHECK: OpMemberDecorate %type_PushConstant_T 1 Offset 32
+// CHECK: OpDecorate %type_PushConstant_T Block
+
+struct S {
+    int16_t      val1; // Nested 16-bit scalar
+    uint16_t2    val2; // Nested 16-bit vector
+    float16_t2x3 val3; // Nested 16-bit matrix
+
+};
+
+struct T {
+    S      nested;
+    float4 val;
+};
+
+[[vk::push_constant]]
+T MyPC;
+
+float4 main() : SV_Target {
+    return MyPC.val;
+}

+ 39 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.16bit-types.sbuffer.hlsl

@@ -0,0 +1,39 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: OpCapability StorageBuffer16BitAccess
+
+// CHECK: OpExtension "SPV_KHR_16bit_storage"
+
+// CHECK: OpMemberDecorate %S 0 Offset 0
+// CHECK: OpMemberDecorate %S 0 MatrixStride 8
+// CHECK: OpMemberDecorate %S 0 RowMajor
+// CHECK: OpMemberDecorate %S 1 Offset 16
+// CHECK: OpMemberDecorate %S 2 Offset 20
+
+// CHECK: OpMemberDecorate %T 0 Offset 0
+// CHECK: OpMemberDecorate %T 1 Offset 32
+
+// CHECK: OpDecorate %_runtimearr_T ArrayStride 48
+
+// CHECK: OpMemberDecorate %type_StructuredBuffer_T 0 Offset 0
+// CHECK: OpMemberDecorate %type_StructuredBuffer_T 0 NonWritable
+
+// CHECK: OpDecorate %type_StructuredBuffer_T BufferBlock
+
+struct S {
+    float16_t3x2 val1; // Nested 16-bit matrix
+    uint16_t2    val2; // Nested 16-bit vector
+    int16_t      val3; // Nested 16-bit scalar
+
+};
+
+struct T {
+    S      nested;
+    float4 val;
+};
+
+StructuredBuffer<T> MySBuffer;
+
+float4 main() : SV_Target {
+    return MySBuffer[0].val;
+}

+ 25 - 0
tools/clang/test/CodeGenSPIRV/vk.layout.16bit-types.tbuffer.hlsl

@@ -0,0 +1,25 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: OpCapability StorageBuffer16BitAccess
+
+// CHECK: OpExtension "SPV_KHR_16bit_storage"
+
+// CHECK: OpMemberDecorate %type_MyTBuffer 0 Offset 0
+// CHECK: OpMemberDecorate %type_MyTBuffer 1 Offset 4
+// CHECK: OpMemberDecorate %type_MyTBuffer 2 Offset 8
+// CHECK: OpMemberDecorate %type_MyTBuffer 2 MatrixStride 8
+// CHECK: OpMemberDecorate %type_MyTBuffer 2 RowMajor
+// CHECK: OpDecorate %type_MyTBuffer BufferBlock
+// CHECK: OpMemberDecorate %type_MyTBuffer 0 NonWritable
+// CHECK: OpMemberDecorate %type_MyTBuffer 1 NonWritable
+// CHECK: OpMemberDecorate %type_MyTBuffer 2 NonWritable
+
+tbuffer MyTBuffer {
+    uint16_t2 gVal1; // 16-bit vector
+    int16_t   gVal2; // 16-bit scalar
+    half3x2   gVal3; // 16-bit matrix
+};
+
+float4 main() : SV_Target {
+    return gVal2;
+}

+ 12 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -1415,6 +1415,18 @@ TEST_F(FileTest, VulkanLayout64BitTypesStd430) {
 TEST_F(FileTest, VulkanLayout64BitTypesStd140) {
   runFileTest("vk.layout.64bit-types.std140.hlsl");
 }
+TEST_F(FileTest, VulkanLayout16BitTypesPushConstant) {
+  runFileTest("vk.layout.16bit-types.pc.hlsl");
+}
+TEST_F(FileTest, VulkanLayout16BitTypesCBuffer) {
+  runFileTest("vk.layout.16bit-types.cbuffer.hlsl");
+}
+TEST_F(FileTest, VulkanLayout16BitTypesTBuffer) {
+  runFileTest("vk.layout.16bit-types.tbuffer.hlsl");
+}
+TEST_F(FileTest, VulkanLayout16BitTypesStructuredBuffer) {
+  runFileTest("vk.layout.16bit-types.sbuffer.hlsl");
+}
 TEST_F(FileTest, VulkanLayoutVectorRelaxedLayout) {
   // Allows vectors to be aligned according to their element types, if not
   // causing improper straddle