Jelajahi Sumber

Handle passing an element of a RWStructureBuffer array to a function (#5447)

getFinalACSBufferCounter currently returns a CounterIdAliasPair. This is
because all cases in the original design could return the entire
counter variable. When using arrays for RWStructuredBuffers, we will
need to be able to return something the represents a portion of the
array that is needed. The best method of doing that is to write reusable
functions that return a SpirvInstruction whose result is a pointer to
the approparite element of the array.

This is then used to allow copying individual counter variables.

It does not handle copying the entire array of counter variables.
Steven Perron 2 tahun lalu
induk
melakukan
b97f9b9388

+ 12 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -661,8 +661,14 @@ std::string StageVar::getSemanticStr() const {
   return ss.str();
   return ss.str();
 }
 }
 
 
-SpirvInstruction *CounterIdAliasPair::get(SpirvBuilder &builder,
-                                          SpirvContext &spvContext) const {
+SpirvInstruction *CounterIdAliasPair::getAliasAddress() const {
+  assert(isAlias);
+  return counterVar;
+}
+
+SpirvInstruction *
+CounterIdAliasPair::getCounterVariable(SpirvBuilder &builder,
+                                       SpirvContext &spvContext) const {
   if (isAlias) {
   if (isAlias) {
     const auto *counterType = spvContext.getACSBufferCounterType();
     const auto *counterType = spvContext.getACSBufferCounterType();
     const auto *counterVarType =
     const auto *counterVarType =
@@ -689,7 +695,8 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
     if (!srcField)
     if (!srcField)
       return false;
       return false;
 
 
-    field.counterVar.assign(*srcField, builder, context);
+    field.counterVar.assign(srcField->getCounterVariable(builder, context),
+                            builder);
   }
   }
 
 
   return true;
   return true;
@@ -729,7 +736,8 @@ bool CounterVarFields::assign(const CounterVarFields &srcFields,
       if (!srcField)
       if (!srcField)
         return false;
         return false;
 
 
-      field.counterVar.assign(*srcField, builder, context);
+      field.counterVar.assign(srcField->getCounterVariable(builder, context),
+                              builder);
       for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
       for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
         srcIndices.pop_back();
         srcIndices.pop_back();
     }
     }

+ 12 - 9
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -71,14 +71,19 @@ public:
   CounterIdAliasPair(SpirvVariable *var, bool alias)
   CounterIdAliasPair(SpirvVariable *var, bool alias)
       : counterVar(var), isAlias(alias) {}
       : counterVar(var), isAlias(alias) {}
 
 
+  /// Returns the pointer to the counter variable alias. This returns a pointer
+  /// that can be used as the address to a store instruction when storing to an
+  /// alias counter.
+  SpirvInstruction *getAliasAddress() const;
+
   /// Returns the pointer to the counter variable. Dereferences first if this is
   /// Returns the pointer to the counter variable. Dereferences first if this is
   /// an alias to a counter variable.
   /// an alias to a counter variable.
-  SpirvInstruction *get(SpirvBuilder &builder, SpirvContext &spvContext) const;
+  SpirvInstruction *getCounterVariable(SpirvBuilder &builder,
+                                       SpirvContext &spvContext) const;
 
 
-  /// Stores the counter variable's pointer in srcPair to the curent counter
+  /// Stores the counter variable pointed to by src to the curent counter
   /// variable. The current counter variable must be an alias.
   /// variable. The current counter variable must be an alias.
-  inline void assign(const CounterIdAliasPair &srcPair, SpirvBuilder &,
-                     SpirvContext &) const;
+  inline void assign(SpirvInstruction *src, SpirvBuilder &) const;
 
 
 private:
 private:
   SpirvVariable *counterVar;
   SpirvVariable *counterVar;
@@ -906,12 +911,10 @@ bool SemanticInfo::isTarget() const {
   return semantic && semantic->GetKind() == hlsl::Semantic::Kind::Target;
   return semantic && semantic->GetKind() == hlsl::Semantic::Kind::Target;
 }
 }
 
 
-void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
-                                SpirvBuilder &builder,
-                                SpirvContext &context) const {
+void CounterIdAliasPair::assign(SpirvInstruction *src,
+                                SpirvBuilder &builder) const {
   assert(isAlias);
   assert(isAlias);
-  builder.createStore(counterVar, srcPair.get(builder, context),
-                      /* SourceLocation */ {});
+  builder.createStore(counterVar, src, /* SourceLocation */ {});
 }
 }
 
 
 DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,
 DeclResultIdMapper::DeclResultIdMapper(ASTContext &context,

+ 43 - 22
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -4539,26 +4539,16 @@ SpirvEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
     (void)doExpr(object);
     (void)doExpr(object);
   }
   }
 
 
-  const auto *counterPair = getFinalACSBufferCounter(object);
-  if (!counterPair) {
+  auto *counter = getFinalACSBufferCounterInstruction(object);
+  if (!counter) {
     emitFatalError("cannot find the associated counter variable",
     emitFatalError("cannot find the associated counter variable",
                    object->getExprLoc());
                    object->getExprLoc());
     return nullptr;
     return nullptr;
   }
   }
 
 
-  llvm::SmallVector<SpirvInstruction *, 2> indexes;
-  if(const auto *arraySubscriptExpr = dyn_cast<ArraySubscriptExpr>(object)) {
-    // TODO(5440): This codes does not handle multi-dimensional arrays. We need
-    // to look at specific example to determine the best way to do it.
-    indexes.push_back(doExpr(arraySubscriptExpr->getIdx()));
-  }
-
   // Add an extra 0 because the counter is wrapped in a struct.
   // Add an extra 0 because the counter is wrapped in a struct.
-  indexes.push_back(zero);
-
-  auto *counterPtr = spvBuilder.createAccessChain(
-      astContext.IntTy, counterPair->get(spvBuilder, spvContext), indexes,
-      srcLoc, srcRange);
+  auto *counterPtr = spvBuilder.createAccessChain(astContext.IntTy, counter,
+                                                  {zero}, srcLoc, srcRange);
 
 
   SpirvInstruction *index = nullptr;
   SpirvInstruction *index = nullptr;
   if (isInc) {
   if (isInc) {
@@ -4596,13 +4586,13 @@ bool SpirvEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
   // Handle AssocCounter#1 (see CounterVarFields comment)
   // Handle AssocCounter#1 (see CounterVarFields comment)
   if (const auto *dstPair =
   if (const auto *dstPair =
           declIdMapper.createOrGetCounterIdAliasPair(dstDecl)) {
           declIdMapper.createOrGetCounterIdAliasPair(dstDecl)) {
-    const auto *srcPair = getFinalACSBufferCounter(srcExpr);
-    if (!srcPair) {
+    auto *srcCounter = getFinalACSBufferCounterInstruction(srcExpr);
+    if (!srcCounter) {
       emitFatalError("cannot find the associated counter variable",
       emitFatalError("cannot find the associated counter variable",
                      srcExpr->getExprLoc());
                      srcExpr->getExprLoc());
       return false;
       return false;
     }
     }
-    dstPair->assign(*srcPair, spvBuilder, spvContext);
+    dstPair->assign(srcCounter, spvBuilder);
     return true;
     return true;
   }
   }
 
 
@@ -4633,18 +4623,18 @@ bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr,
   dstExpr = dstExpr->IgnoreParenCasts();
   dstExpr = dstExpr->IgnoreParenCasts();
   srcExpr = srcExpr->IgnoreParenCasts();
   srcExpr = srcExpr->IgnoreParenCasts();
 
 
-  const auto *dstPair = getFinalACSBufferCounter(dstExpr);
-  const auto *srcPair = getFinalACSBufferCounter(srcExpr);
+  auto *dstCounter = getFinalACSBufferCounterAliasAddressInstruction(dstExpr);
+  auto *srcCounter = getFinalACSBufferCounterInstruction(srcExpr);
 
 
-  if ((dstPair == nullptr) != (srcPair == nullptr)) {
+  if ((dstCounter == nullptr) != (srcCounter == nullptr)) {
     emitFatalError("cannot handle associated counter variable assignment",
     emitFatalError("cannot handle associated counter variable assignment",
                    srcExpr->getExprLoc());
                    srcExpr->getExprLoc());
     return false;
     return false;
   }
   }
 
 
   // Handle AssocCounter#1 & AssocCounter#2
   // Handle AssocCounter#1 & AssocCounter#2
-  if (dstPair && srcPair) {
-    dstPair->assign(*srcPair, spvBuilder, spvContext);
+  if (dstCounter && srcCounter) {
+    spvBuilder.createStore(dstCounter, srcCounter, /* SourceLocation */ {});
     return true;
     return true;
   }
   }
 
 
@@ -4662,6 +4652,37 @@ bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr,
   return false;
   return false;
 }
 }
 
 
+SpirvInstruction *SpirvEmitter::getFinalACSBufferCounterAliasAddressInstruction(
+    const Expr *expr) {
+  const CounterIdAliasPair *counter = getFinalACSBufferCounter(expr);
+  return (counter ? counter->getAliasAddress() : nullptr);
+}
+
+SpirvInstruction *
+SpirvEmitter::getFinalACSBufferCounterInstruction(const Expr *expr) {
+  const CounterIdAliasPair *counterPair = getFinalACSBufferCounter(expr);
+  if (!counterPair)
+    return nullptr;
+
+  SpirvInstruction *counter =
+      counterPair->getCounterVariable(spvBuilder, spvContext);
+  const auto srcLoc = expr->getExprLoc();
+
+  // TODO(5440): This codes does not handle multi-dimensional arrays. We need
+  // to look at specific example to determine the best way to do it. Could a
+  // call to collectArrayStructIndices handle that for us?
+  llvm::SmallVector<SpirvInstruction *, 2> indexes;
+  if (const auto *arraySubscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
+    indexes.push_back(doExpr(arraySubscriptExpr->getIdx()));
+  }
+
+  if (!indexes.empty()) {
+    counter = spvBuilder.createAccessChain(spvContext.getACSBufferCounterType(),
+                                           counter, indexes, srcLoc);
+  }
+  return counter;
+}
+
 const CounterIdAliasPair *
 const CounterIdAliasPair *
 SpirvEmitter::getFinalACSBufferCounter(const Expr *expr) {
 SpirvEmitter::getFinalACSBufferCounter(const Expr *expr) {
   // AssocCounter#1: referencing some stand-alone variable
   // AssocCounter#1: referencing some stand-alone variable

+ 15 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1051,6 +1051,21 @@ private:
                              const Expr *srcExpr);
                              const Expr *srcExpr);
   bool tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr);
   bool tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr);
 
 
+  /// Returns an instruction that points to the alias counter variable with the
+  /// entity represented by expr.
+  ///
+  /// This method only handles final alias structured buffers, which means
+  /// AssocCounter#1 and AssocCounter#2.
+  SpirvInstruction *
+  getFinalACSBufferCounterAliasAddressInstruction(const Expr *expr);
+
+  /// Returns an instruction that points to the counter variable with the entity
+  /// represented by expr.
+  ///
+  /// This method only handles final alias structured buffers, which means
+  /// AssocCounter#1 and AssocCounter#2.
+  SpirvInstruction *getFinalACSBufferCounterInstruction(const Expr *expr);
+
   /// Returns the counter variable's information associated with the entity
   /// Returns the counter variable's information associated with the entity
   /// represented by the given decl.
   /// represented by the given decl.
   ///
   ///

+ 3 - 2
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.const.index.hlsl

@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
 float4 main(PSInput input) : SV_TARGET
 float4 main(PSInput input) : SV_TARGET
 {
 {
 // Correctly increment the counter.
 // Correctly increment the counter.
-// CHECK: [[ac:%\d+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer %int_3 %uint_0
-// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
+// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer %int_3
+// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
     g_rwbuffer[3].IncrementCounter();
     g_rwbuffer[3].IncrementCounter();
 
 
 // Correctly access the buffer.
 // Correctly access the buffer.

+ 3 - 2
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.flatten.hlsl

@@ -21,8 +21,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
 float4 main(PSInput input) : SV_TARGET
 float4 main(PSInput input) : SV_TARGET
 {
 {
 // Correctly increment the counter.
 // Correctly increment the counter.
-// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
-// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
+// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
     g_rwbuffer[input.idx].IncrementCounter();
     g_rwbuffer[input.idx].IncrementCounter();
 
 
 // Correctly access the buffer.
 // Correctly access the buffer.

+ 3 - 2
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.hlsl

@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
 float4 main(PSInput input) : SV_TARGET
 float4 main(PSInput input) : SV_TARGET
 {
 {
 // Correctly increment the counter.
 // Correctly increment the counter.
-// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
-// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
+// CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\w+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
     g_rwbuffer[input.idx].IncrementCounter();
     g_rwbuffer[input.idx].IncrementCounter();
 
 
 // Correctly access the buffer.
 // Correctly access the buffer.

+ 32 - 0
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.array.counter.indirect.hlsl

@@ -0,0 +1,32 @@
+// RUN: %dxc -T ps_6_6 -E main -O0 -fvk-allow-rwstructuredbuffer-arrays
+
+struct PSInput
+{
+	uint idx : COLOR;
+};
+
+// CHECK: OpDecorate %g_rwbuffer DescriptorSet 2
+// CHECK: OpDecorate %g_rwbuffer Binding 0
+// CHECK: OpDecorate %counter_var_g_rwbuffer DescriptorSet 2
+// CHECK: OpDecorate %counter_var_g_rwbuffer Binding 1
+
+// CHECK: %g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_RWStructuredBuffer_uint_uint_5 Uniform
+// CHECK: %counter_var_g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_5 Uniform
+RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
+
+void func(RWStructuredBuffer<uint> local) {
+      local.IncrementCounter();
+}
+
+float4 main(PSInput input) : SV_TARGET
+{
+// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
+    func(g_rwbuffer[input.idx]);
+
+// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_RWStructuredBuffer_uint %g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ac1]] %int_0 %uint_0
+// CHECK: OpLoad %uint [[ac2]]
+    return g_rwbuffer[input.idx][0];
+}

+ 3 - 2
tools/clang/test/CodeGenSPIRV/type.rwstructured-buffer.unbounded.array.counter.hlsl

@@ -17,8 +17,9 @@ RWStructuredBuffer<uint> g_rwbuffer[] : register(u0, space2);
 float4 main(PSInput input) : SV_TARGET
 float4 main(PSInput input) : SV_TARGET
 {
 {
 // Correctly increment the counter.
 // Correctly increment the counter.
-// CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_int %counter_var_g_rwbuffer {{%\d+}} %uint_0
-// CHECK: OpAtomicIAdd %int [[ac]] %uint_1 %uint_0 %int_1
+// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer {{%\d+}}
+// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
+// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
     g_rwbuffer[input.idx].IncrementCounter();
     g_rwbuffer[input.idx].IncrementCounter();
 
 
 // Correctly access the buffer.
 // Correctly access the buffer.

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -161,6 +161,9 @@ TEST_F(FileTest, RWStructuredBufferArrayCounterConstIndex) {
 TEST_F(FileTest, RWStructuredBufferArrayCounterFlattened) {
 TEST_F(FileTest, RWStructuredBufferArrayCounterFlattened) {
   runFileTest("type.rwstructured-buffer.array.counter.flatten.hlsl");
   runFileTest("type.rwstructured-buffer.array.counter.flatten.hlsl");
 }
 }
+TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect) {
+  runFileTest("type.rwstructured-buffer.array.counter.indirect.hlsl");
+}
 TEST_F(FileTest, AppendStructuredBufferArrayError) {
 TEST_F(FileTest, AppendStructuredBufferArrayError) {
   runFileTest("type.append-structured-buffer.array.hlsl");
   runFileTest("type.append-structured-buffer.array.hlsl");
 }
 }