Преглед изворни кода

[spirv] Legalization: associated counter in nested structs (#1001)

This commit supports assigning nested structs as a whole,
which means to go over all fields' associated counters and
assign to the corresponding fields. The front end parsing
and semantic analysis should guarantee the type matching.
Lei Zhang пре 7 година
родитељ
комит
690b44ec16

+ 52 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -126,15 +126,60 @@ CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
   return nullptr;
 }
 
-void CounterVarFields::assign(const CounterVarFields &srcFields,
+bool CounterVarFields::assign(const CounterVarFields &srcFields,
                               ModuleBuilder &builder,
                               TypeTranslator &translator) const {
   for (const auto &field : fields) {
     const auto *srcField = srcFields.get(field.indices);
-    // TODO: this will fail for AssocCounter#4.
-    assert(srcField);
+    if (!srcField)
+      return false;
+
     field.counterVar.assign(*srcField, builder, translator);
   }
+
+  return true;
+}
+
+bool CounterVarFields::assign(const CounterVarFields &srcFields,
+                              const llvm::SmallVector<uint32_t, 4> &dstPrefix,
+                              const llvm::SmallVector<uint32_t, 4> &srcPrefix,
+                              ModuleBuilder &builder,
+                              TypeTranslator &translator) const {
+  if (dstPrefix.empty() && srcPrefix.empty())
+    return assign(srcFields, builder, translator);
+
+  llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
+
+  // If whole has the given prefix, appends all elements after the prefix in
+  // whole to srcIndices.
+  const auto applyDiff =
+      [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
+                    const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
+    uint32_t i = 0;
+    for (; i < prefix.size(); ++i)
+      if (whole[i] != prefix[i]) {
+        break;
+      }
+    if (i == prefix.size()) {
+      for (; i < whole.size(); ++i)
+        srcIndices.push_back(whole[i]);
+      return true;
+    }
+    return false;
+  };
+
+  for (const auto &field : fields)
+    if (applyDiff(field.indices, dstPrefix)) {
+      const auto *srcField = srcFields.get(srcIndices);
+      if (!srcField)
+        return false;
+
+      field.counterVar.assign(*srcField, builder, translator);
+      for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
+        srcIndices.pop_back();
+    }
+
+  return true;
 }
 
 DeclResultIdMapper::SemanticInfo
@@ -568,9 +613,13 @@ const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
 
 const CounterVarFields *
 DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) const {
+  if (!decl)
+    return nullptr;
+
   const auto found = fieldCounterVars.find(decl);
   if (found != fieldCounterVars.end())
     return &found->second;
+
   return nullptr;
 }
 

+ 10 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -208,10 +208,17 @@ public:
   get(const llvm::SmallVectorImpl<uint32_t> &indices) const;
 
   /// Assigns to all the fields' associated counter from the srcFields.
-  /// This is for assigning a struct as whole: we need to update all the
-  /// associated counters in the target struct.
-  void assign(const CounterVarFields &srcFields, ModuleBuilder &builder,
+  /// Returns true if there are no errors during the assignment.
+  ///
+  /// This first overload is for assigning a struct as whole: we need to update
+  /// all the associated counters in the target struct. This second overload is
+  /// for assigning a potentially nested struct.
+  bool assign(const CounterVarFields &srcFields, ModuleBuilder &builder,
               TypeTranslator &translator) const;
+  bool assign(const CounterVarFields &srcFields,
+              const llvm::SmallVector<uint32_t, 4> &dstPrefix,
+              const llvm::SmallVector<uint32_t, 4> &srcPrefix,
+              ModuleBuilder &builder, TypeTranslator &translator) const;
 
 private:
   struct IndexCounterPair {

+ 28 - 36
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2821,26 +2821,26 @@ bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
     return true;
   }
 
-  // AssocCounter#2 for the lhs cannot happen since the lhs is a stand-alone
-  // decl in this method.
-
   // Handle AssocCounter#3
-  if (const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl)) {
-    if (const auto *srcDecl = getReferencedDef(srcExpr)) {
-      const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
-      if (!srcFields) {
-        emitFatalError("cannot find the associated counter variable",
-                       srcExpr->getExprLoc());
-        return false;
-      }
-      dstFields->assign(*srcFields, theBuilder, typeTranslator);
-      return true;
+  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
+  llvm::SmallVector<uint32_t, 4> srcIndices;
+  const auto *srcDecl = getReferencedDef(
+      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
+  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
+
+  if (dstFields && srcFields) {
+    if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) {
+      emitFatalError("cannot handle associated counter variable assignment",
+                     srcExpr->getExprLoc());
+      return false;
     }
+    return true;
   }
 
-  // Handle AssocCounter#4: TODO
+  // AssocCounter#2 and AssocCounter#4 for the lhs cannot happen since the lhs
+  // is a stand-alone decl in this method.
 
-  return true;
+  return false;
 }
 
 bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
@@ -2863,28 +2863,20 @@ bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
     return true;
   }
 
-  // Handle AssocCounter#3
-  if (const auto *dstDecl = getReferencedDef(dstExpr))
-    if (const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl)) {
-      const auto *srcDecl = getReferencedDef(srcExpr);
-      if (!srcDecl) {
-        emitFatalError("cannot find the associated counter variable",
-                       srcExpr->getExprLoc());
-        return false;
-      }
-
-      const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
-      if (!srcFields) {
-        emitFatalError("cannot find the associated counter variable",
-                       srcExpr->getExprLoc());
-        return false;
-      }
+  // Handle AssocCounter#3 & AssocCounter#4
+  llvm::SmallVector<uint32_t, 4> dstIndices;
+  llvm::SmallVector<uint32_t, 4> srcIndices;
+  const auto *dstDecl = getReferencedDef(
+      collectArrayStructIndices(dstExpr, &dstIndices, /*rawIndex=*/true));
+  const auto *srcDecl = getReferencedDef(
+      collectArrayStructIndices(srcExpr, &srcIndices, /*rawIndex=*/true));
+  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
+  const auto *srcFields = declIdMapper.getCounterVarFields(srcDecl);
 
-      dstFields->assign(*srcFields, theBuilder, typeTranslator);
-      return true;
-    }
-
-  // Handle AssocCounter#4: TODO
+  if (dstFields && srcFields) {
+    return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder,
+                             typeTranslator);
+  }
 
   return false;
 }

+ 22 - 4
tools/clang/test/CodeGenSPIRV/spirv.legal.sbuffer.counter.struct.hlsl

@@ -177,13 +177,31 @@ Wrapper CreateWrapper() {
 // CHECK-NEXT:                OpStore %counter_var_w_0_0_2 [[src]]
     w.b.b1.consume = staticBundle.consume;
 
-    // TODO:
-
     // Assign to intermediate structs whose fields have associated counters
-    //w.b.b2         = staticBundle;
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_0
+// CHECK-NEXT:                OpStore %counter_var_w_0_1_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_1
+// CHECK-NEXT:                OpStore %counter_var_w_0_1_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_staticBundle_2
+// CHECK-NEXT:                OpStore %counter_var_w_0_1_2 [[src]]
+    w.b.b2         = staticBundle;
 
     // Assign from intermediate structs whose fields have associated counters
-    //staticBundle   = w.b.b1;
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_0_0
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_0_1
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_0_2
+// CHECK-NEXT:                OpStore %counter_var_staticBundle_2 [[src]]
+    staticBundle   = w.b.b1;
+
+// CHECK:      [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_1_0
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_0 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_1_1
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_1 [[src]]
+// CHECK-NEXT: [[src:%\d+]] = OpLoad %_ptr_Uniform_type_ACSBuffer_counter %counter_var_w_0_1_2
+// CHECK-NEXT:                OpStore %counter_var_w_0_0_2 [[src]]
+    w.b.b1         = w.b.b2;
 
     return w;
 }