瀏覽代碼

Enable the vk::counter_binding attribute for arrays (#5456)

The `vk::counter_binding` binding attribute is defined to only apply to
scalars because until recently arrays of RWStructuredBuffers and the
like could not be arrays. That has changed, so the attribute can apply
to the arrays of RWStructuredBuffers.
Steven Perron 2 年之前
父節點
當前提交
6c9592f24a

+ 14 - 1
tools/clang/include/clang/Basic/Attr.td

@@ -943,6 +943,19 @@ def CounterStructuredBuffer : SubsetSubject<
        S->getType()->getAs<RecordType>()->getDecl()->getName() == "AppendStructuredBuffer" ||
        S->getType()->getAs<RecordType>()->getDecl()->getName() == "AppendStructuredBuffer" ||
        S->getType()->getAs<RecordType>()->getDecl()->getName() == "ConsumeStructuredBuffer")}]>;
        S->getType()->getAs<RecordType>()->getDecl()->getName() == "ConsumeStructuredBuffer")}]>;
 
 
+// Array of StructuredBuffer types that can have associated counters 
+def ArrayOfCounterStructuredBuffer
+    : SubsetSubject<
+          Var, [{S->hasGlobalStorage() && S->getType()->getAsArrayTypeUnsafe() &&
+                 S->getType()->getAsArrayTypeUnsafe()->getElementType()->getAs<RecordType>() &&
+                 S->getType()->getAsArrayTypeUnsafe()->getElementType()->getAs<RecordType>()->getDecl() &&
+                  (S->getType()->getAsArrayTypeUnsafe()->getElementType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "RWStructuredBuffer" ||
+                  S->getType()->getAsArrayTypeUnsafe()->getElementType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "AppendStructuredBuffer" ||
+                  S->getType()->getAsArrayTypeUnsafe()->getElementType()->getAs<RecordType>()->getDecl()->getName() ==
+                      "ConsumeStructuredBuffer")}]>;
+
 // Global variable with "ConstantBuffer" type
 // Global variable with "ConstantBuffer" type
 def ConstantBuffer
 def ConstantBuffer
     : SubsetSubject<
     : SubsetSubject<
@@ -1067,7 +1080,7 @@ def VKCapabilityExt : InheritableAttr {
 
 
 def VKCounterBinding : InheritableAttr {
 def VKCounterBinding : InheritableAttr {
   let Spellings = [CXX11<"vk", "counter_binding">];
   let Spellings = [CXX11<"vk", "counter_binding">];
-  let Subjects = SubjectList<[CounterStructuredBuffer], ErrorDiag, "ExpectedCounterStructuredBuffer">;
+  let Subjects = SubjectList<[ArrayOfCounterStructuredBuffer, CounterStructuredBuffer], ErrorDiag, "ExpectedCounterStructuredBuffer">;
   let Args = [IntArgument<"Binding">];
   let Args = [IntArgument<"Binding">];
   let LangOpts = [SPIRV];
   let LangOpts = [SPIRV];
   let Documentation = [Undocumented];
   let Documentation = [Undocumented];

+ 30 - 0
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.binding.attributes.hlsl

@@ -0,0 +1,30 @@
+// RUN: %dxc -T ps_6_6 -E main -fvk-allow-rwstructuredbuffer-arrays
+
+struct PSInput
+{
+	uint idx : COLOR;
+};
+
+// CHECK: OpDecorate %g_rwbuffer DescriptorSet 4
+// CHECK: OpDecorate %g_rwbuffer Binding 3
+// CHECK: OpDecorate %counter_var_g_rwbuffer DescriptorSet 4
+// CHECK: OpDecorate %counter_var_g_rwbuffer Binding 4
+
+// CHECK: %g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_RWStructuredBuffer_uint_uint_5 Uniform
+// CHECK: %counter_var_g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_5 Uniform
+[[vk::binding(3,4), vk::counter_binding(4)]] RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
+
+float4 main(PSInput input) : SV_TARGET
+{
+// Correctly increment the counter.
+// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
+    g_rwbuffer[input.idx].IncrementCounter();
+
+// Correctly access the buffer.
+// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_RWStructuredBuffer_uint %g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_uint [[ac1]] %int_0 %uint_0
+// CHECK: OpLoad %uint [[ac2]]
+    return g_rwbuffer[input.idx][0];
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -164,6 +164,9 @@ TEST_F(FileTest, RWStructuredBufferArrayCounterFlattened) {
 TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect) {
 TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect) {
   runFileTest("type.rwstructured-buffer.array.counter.indirect.hlsl");
   runFileTest("type.rwstructured-buffer.array.counter.indirect.hlsl");
 }
 }
+TEST_F(FileTest, RWStructuredBufferArrayBindAttributes) {
+  runFileTest("type.rwstructured-buffer.array.binding.attributes.hlsl");
+}
 TEST_F(FileTest, AppendStructuredBufferArrayError) {
 TEST_F(FileTest, AppendStructuredBufferArrayError) {
   runFileTest("type.append-structured-buffer.array.hlsl");
   runFileTest("type.append-structured-buffer.array.hlsl");
 }
 }