Browse Source

[spirv] Remove BufferBlock for array types (#3039)

* [spirv] Remove BufferBlock for array types

* Add a unit test

* code review
Jaebaek Seo 5 years ago
parent
commit
1ca25d515b

+ 22 - 11
tools/clang/lib/SPIRV/RemoveBufferBlockVisitor.cpp

@@ -33,6 +33,23 @@ bool RemoveBufferBlockVisitor::visit(SpirvModule *mod, Phase phase) {
   return true;
 }
 
+bool RemoveBufferBlockVisitor::hasStorageBufferInterfaceType(
+    const SpirvType *type) {
+  while (type != nullptr) {
+    if (const auto *structType = dyn_cast<StructType>(type)) {
+      return structType->getInterfaceType() ==
+             StructInterfaceType::StorageBuffer;
+    } else if (const auto *elemType = dyn_cast<ArrayType>(type)) {
+      type = elemType->getElementType();
+    } else if (const auto *elemType = dyn_cast<RuntimeArrayType>(type)) {
+      type = elemType->getElementType();
+    } else {
+      return false;
+    }
+  }
+  return false;
+}
+
 bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
   if (!inst->getResultType())
     return true;
@@ -60,17 +77,11 @@ bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
   // For all instructions, if the result type is a pointer pointing to a struct
   // with StorageBuffer interface, the storage class must be updated.
   if (auto *ptrResultType = dyn_cast<SpirvPointerType>(inst->getResultType())) {
-    if (auto *structPointeeType =
-            dyn_cast<StructType>(ptrResultType->getPointeeType())) {
-      // Update the instruction's storage class if necessary
-      if (structPointeeType->getInterfaceType() ==
-              StructInterfaceType::StorageBuffer &&
-          ptrResultType->getStorageClass() !=
-              spv::StorageClass::StorageBuffer) {
-        inst->setStorageClass(spv::StorageClass::StorageBuffer);
-        inst->setResultType(context.getPointerType(
-            ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
-      }
+    if (hasStorageBufferInterfaceType(ptrResultType->getPointeeType()) &&
+        ptrResultType->getStorageClass() != spv::StorageClass::StorageBuffer) {
+      inst->setStorageClass(spv::StorageClass::StorageBuffer);
+      inst->setResultType(context.getPointerType(
+          ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
     }
   }
 

+ 5 - 0
tools/clang/lib/SPIRV/RemoveBufferBlockVisitor.h

@@ -33,6 +33,11 @@ public:
   /// So that you want override this visit function to handle all instructions,
   /// regardless of their polymorphism.
   bool visitInstruction(SpirvInstruction *instr) override;
+
+private:
+  /// Returns true if |type| is a SPIR-V type whose interface type is
+  /// StorageBuffer.
+  bool hasStorageBufferInterfaceType(const SpirvType *type);
 };
 
 } // end namespace spirv

+ 26 - 0
tools/clang/test/CodeGenSPIRV/vk.1p2.remove.bufferblock.runtimearray.hlsl

@@ -0,0 +1,26 @@
+// Run: %dxc -T cs_6_6 -E main -fspv-target-env=vulkan1.2
+
+struct MeshPart {
+  uint indexOffset;
+  uint positionOffset;
+  uint normalOffset;
+  uint texCoord0Offset;
+};
+
+// CHECK: %_ptr_StorageBuffer_type_StructuredBuffer_MeshPart = OpTypePointer StorageBuffer %type_StructuredBuffer_MeshPart
+// CHECK: %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+
+StructuredBuffer<MeshPart> g_meshParts[] : register(t2, space1);
+RWStructuredBuffer<uint> g_output : register(u1, space2);
+
+// CHECK: %g_meshParts = OpVariable %_ptr_StorageBuffer__runtimearr_type_StructuredBuffer_MeshPart StorageBuffer
+// CHECK: %g_output = OpVariable %_ptr_StorageBuffer_type_RWStructuredBuffer_uint StorageBuffer
+
+[numthreads(64, 1, 1)]
+void main()
+{
+  MeshPart meshPart = g_meshParts[0][0];
+  g_output[0] = meshPart.indexOffset;
+// CHECK: OpAccessChain %_ptr_StorageBuffer_type_StructuredBuffer_MeshPart %g_meshParts %int_0
+// CHECK: OpAccessChain %_ptr_StorageBuffer_uint %g_output %int_0 %uint_0
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -2336,6 +2336,10 @@ TEST_F(FileTest, Vk1p2BlockDecoration) {
   useVulkan1p2();
   runFileTest("vk.1p2.block-decoration.hlsl");
 }
+TEST_F(FileTest, Vk1p2RemoveBufferBlockRuntimeArray) {
+  useVulkan1p2();
+  runFileTest("vk.1p2.remove.bufferblock.runtimearray.hlsl");
+}
 
 // Test shaders that require Vulkan1.1 support with
 // -fspv-target-env=vulkan1.2 option to make sure that enabling