Bladeren bron

Loop unrolling wip

Krzysztof Krysiński 2 maanden geleden
bovenliggende
commit
c35da3829d

+ 27 - 3
src/PixiEditor.ChangeableDocument/Changeables/Graph/InputProperty.cs

@@ -4,6 +4,7 @@ using PixiEditor.ChangeableDocument.Changeables.Graph.Nodes;
 using PixiEditor.ChangeableDocument.Changes.NodeGraph;
 using PixiEditor.Common;
 using Drawie.Backend.Core.Shaders.Generation;
+using PixiEditor.ChangeableDocument.Rendering;
 
 namespace PixiEditor.ChangeableDocument.Changeables.Graph;
 
@@ -14,6 +15,7 @@ public class InputProperty : IInputProperty
     protected int lastConnectionHash = -1;
     private PropertyValidator? validator;
     private IOutputProperty? connection;
+    private Dictionary<Guid, IOutputProperty> virtualConnections = new();
 
     public event Action ConnectionChanged;
     public event Action<object> NonOverridenValueChanged;
@@ -21,16 +23,22 @@ public class InputProperty : IInputProperty
     public string InternalPropertyName { get; }
     public string DisplayName { get; }
 
+    public Guid? ActiveVirtualSession { get; set; } = null;
+
     public object? Value
     {
         get
         {
-            if (Connection == null)
+            var connection = ActiveVirtualSession is not null && virtualConnections.TryGetValue(ActiveVirtualSession.Value, out var virtualConnection)
+                ? virtualConnection
+                : Connection;
+
+            if (connection == null)
             {
                 return _internalValue;
             }
 
-            var connectionValue = Connection.Value;
+            var connectionValue = connection.Value;
 
             if (connectionValue is null)
             {
@@ -106,6 +114,11 @@ public class InputProperty : IInputProperty
         }
     }
 
+    public void RemoveVirtualConnection(Guid virtualConnectionId)
+    {
+        virtualConnections.Remove(virtualConnectionId);
+    }
+
     public PropertyValidator Validator
     {
         get
@@ -197,9 +210,20 @@ public class InputProperty : IInputProperty
 
     IReadOnlyNode INodeProperty.Node => Node;
 
+    public IOutputProperty? GetVirtualConnection(Guid virtualConnectionId)
+    {
+        return virtualConnections.GetValueOrDefault(virtualConnectionId);
+    }
+
+    public void SetVirtualConnection(OutputProperty outputProperty, Guid virtualConnectionId, RenderContext context)
+    {
+        virtualConnections[virtualConnectionId] = outputProperty;
+        context.RecordVirtualConnection(this, virtualConnectionId);
+    }
+
     public IOutputProperty? Connection
     {
-        get => connection;
+        get => ActiveVirtualSession != null ? virtualConnections.GetValueOrDefault(ActiveVirtualSession.Value) : connection;
         set
         {
             if (connection != value)

+ 7 - 1
src/PixiEditor.ChangeableDocument/Changeables/Graph/Interfaces/INodeProperty.cs

@@ -1,4 +1,5 @@
-using PixiEditor.Common;
+using PixiEditor.ChangeableDocument.Rendering;
+using PixiEditor.Common;
 
 namespace PixiEditor.ChangeableDocument.Changeables.Graph.Interfaces;
 
@@ -20,14 +21,19 @@ public interface INodeProperty<T> : INodeProperty
 
 public interface IInputProperty : INodeProperty
 {
+    public IOutputProperty? GetVirtualConnection(Guid virtualConnectionId);
+    public void SetVirtualConnection(OutputProperty outputProperty, Guid virtualConnectionId, RenderContext context);
     public IOutputProperty? Connection { get; set; }
     public object NonOverridenValue { get; set;  }
+    public void RemoveVirtualConnection(Guid virtualConnectionId);
 }
 
 public interface IOutputProperty : INodeProperty
 {
+    public void VirtualConnectTo(IInputProperty property, Guid virtualConnectionId, RenderContext context);
     public void ConnectTo(IInputProperty property);
     public void DisconnectFrom(IInputProperty property);
+    public void DisconnectFromVirtual(IInputProperty property, Guid virtualConnectionId);
     IReadOnlyCollection<IInputProperty> Connections { get; }
 }
 

+ 2 - 0
src/PixiEditor.ChangeableDocument/Changeables/Graph/NodeGraph.cs

@@ -117,6 +117,8 @@ public class NodeGraph : IReadOnlyNodeGraph, IDisposable
                 node.Execute(context);
             }
         }
+
+        context.CleanupVirtualConnectionScopes();
     }
     
     private bool CanExecute()

+ 7 - 24
src/PixiEditor.ChangeableDocument/Changeables/Graph/Nodes/Utility/RepeatNodeEnd.cs

@@ -17,6 +17,8 @@ public class RepeatNodeEnd : Node, IPairNode, IExecutionFlowNode
 
     private RepeatNodeStart startNode;
 
+    public HashSet<IReadOnlyNode> HandledNodes => CalculateHandledNodes();
+
     public RepeatNodeEnd()
     {
         Input = CreateInput<object>("Input", "INPUT", null);
@@ -35,30 +37,13 @@ public class RepeatNodeEnd : Node, IPairNode, IExecutionFlowNode
             }
         }
 
-        if(startNode.Iterations.Value <= 0)
+        if(startNode.Iterations.Value == 0)
         {
             Output.Value = DefaultOfType(Input.Value);
+            return;
         }
-        else
-        {
-            if (Input.Value is Delegate del)
-            {
-                var result = del.DynamicInvoke(ShaderFuncContext.NoContext);
-                if (result is ShaderExpressionVariable expressionVariable)
-                {
-                    var constant = expressionVariable.GetConstant();
-                    Output.Value = constant;
-                }
-                else
-                {
-                    Output.Value = result;
-                }
-            }
-            else
-            {
-                Output.Value = Input.Value;
-            }
-        }
+
+        Output.Value = Input.Value;
     }
 
     private object DefaultOfType(object? val)
@@ -74,6 +59,7 @@ public class RepeatNodeEnd : Node, IPairNode, IExecutionFlowNode
                 return DefaultOfType(expressionVariable.GetConstant());
             }
         }
+
         return null;
     }
 
@@ -110,8 +96,6 @@ public class RepeatNodeEnd : Node, IPairNode, IExecutionFlowNode
         return startNode;
     }
 
-    public HashSet<IReadOnlyNode> HandledNodes => CalculateHandledNodes();
-
     private HashSet<IReadOnlyNode> CalculateHandledNodes()
     {
         HashSet<IReadOnlyNode> handled = new();
@@ -132,7 +116,6 @@ public class RepeatNodeEnd : Node, IPairNode, IExecutionFlowNode
             {
                 if (leftNode == this)
                 {
-                    handled.Add(node);
                     break;
                 }
 

+ 118 - 16
src/PixiEditor.ChangeableDocument/Changeables/Graph/Nodes/Utility/RepeatNodeStart.cs

@@ -17,7 +17,10 @@ public class RepeatNodeStart : Node, IPairNode
     public Guid OtherNode { get; set; }
     private RepeatNodeEnd? endNode;
 
-    private bool iterationInProgress = false;
+
+    private Guid virtualSessionId;
+    private Queue<IReadOnlyNode> unrolledQueue;
+    private List<IReadOnlyNode> clonedNodes = new List<IReadOnlyNode>();
 
     public RepeatNodeStart()
     {
@@ -29,11 +32,6 @@ public class RepeatNodeStart : Node, IPairNode
 
     protected override void OnExecute(RenderContext context)
     {
-        if (iterationInProgress)
-        {
-            return;
-        }
-
         endNode = FindEndNode();
         if (endNode == null)
         {
@@ -46,26 +44,130 @@ public class RepeatNodeStart : Node, IPairNode
         var queue = GraphUtils.CalculateExecutionQueue(endNode, true, true,
             property => property.Connection?.Node != this);
 
+        if (iterations == 0)
+        {
+            Output.Value = null;
+            CurrentIteration.Value = 0;
+            return;
+        }
+
+        if (iterations > 1)
+        {
+            virtualSessionId = Guid.NewGuid();
+            context.BeginVirtualConnectionScope(virtualSessionId);
+            ClearLastUnrolledNodes();
+            queue = UnrollLoop(iterations, queue, context);
+        }
+
         Output.Value = Input.Value;
-        iterationInProgress = true;
-        for (int i = 0; i < iterations; i++)
+        CurrentIteration.Value = 0; // TODO: Increment iteration in unrolled nodes
+
+        foreach (var node in queue)
         {
-            CurrentIteration.Value = i + 1;
-            foreach (var node in queue)
+            node.Execute(context);
+        }
+    }
+
+    private void ClearLastUnrolledNodes()
+    {
+        if (clonedNodes.Count > 0)
+        {
+            foreach (var node in clonedNodes)
             {
-                if (node == this)
+                if (node is IDisposable disposable) disposable.Dispose();
+            }
+
+            clonedNodes.Clear();
+        }
+    }
+
+    private Queue<IReadOnlyNode> UnrollLoop(int iterations, Queue<IReadOnlyNode> executionQueue, RenderContext context)
+    {
+        var connectToNextStart = endNode.Input.Connection;
+        var connectPreviousTo = Output.Connections;
+
+        Queue<IReadOnlyNode> lastQueue = new Queue<IReadOnlyNode>(executionQueue.Where(x => x != this && x != endNode));
+        for (int i = 0; i < iterations - 1; i++)
+        {
+            var mapping = new Dictionary<Guid, Node>();
+            CloneNodes(lastQueue, mapping);
+            connectPreviousTo =
+                ReplaceConnections(connectToNextStart, connectPreviousTo, mapping, virtualSessionId, context);
+            connectToNextStart = mapping[connectToNextStart.Node.Id].OutputProperties
+                .FirstOrDefault(y => y.InternalPropertyName == connectToNextStart.InternalPropertyName);
+
+            clonedNodes.AddRange(mapping.Values);
+            lastQueue = new Queue<IReadOnlyNode>(mapping.Values);
+        }
+
+        connectToNextStart.VirtualConnectTo(endNode.Input, virtualSessionId, context);
+
+        return GraphUtils.CalculateExecutionQueue(endNode, true, true,
+            property => property.Connection?.Node != this);
+    }
+
+    private IReadOnlyCollection<IInputProperty> ReplaceConnections(IOutputProperty? connectToNextStart,
+        IReadOnlyCollection<IInputProperty> connectPreviousTo, Dictionary<Guid, Node> mapping, Guid virtualConnectionId,
+        RenderContext context)
+    {
+        var connectPreviousToMapped = new List<IInputProperty>();
+        foreach (var input in connectPreviousTo)
+        {
+            if (mapping.TryGetValue(input.Node.Id, out var mappedNode))
+            {
+                var mappedInput =
+                    mappedNode.InputProperties.FirstOrDefault(i =>
+                        i.InternalPropertyName == input.InternalPropertyName);
+                if (mappedInput != null)
                 {
-                    continue;
+                    connectPreviousToMapped.Add(mappedInput);
                 }
-
-                node.Execute(context);
             }
+        }
+
+        foreach (var input in connectPreviousToMapped)
+        {
+            connectToNextStart?.VirtualConnectTo(input, virtualConnectionId, context);
+        }
 
+        return connectPreviousToMapped;
+    }
 
-            Output.Value = endNode.Output.Value;
+    private void CloneNodes(Queue<IReadOnlyNode> originalQueue, Dictionary<Guid, Node> mapping)
+    {
+        foreach (var node in originalQueue)
+        {
+            if (node is not Node n) continue;
+            var clonedNode = n.Clone();
+            mapping[node.Id] = clonedNode;
         }
 
-        iterationInProgress = false;
+        ConnectRelatedNodes(originalQueue, mapping);
+    }
+
+    private void ConnectRelatedNodes(Queue<IReadOnlyNode> originalQueue, Dictionary<Guid, Node> mapping)
+    {
+        foreach (var node in originalQueue)
+        {
+            if (node is not Node n) continue;
+            var clonedNode = mapping[node.Id];
+
+            foreach (var input in n.InputProperties)
+            {
+                if (input.Connection != null &&
+                    mapping.TryGetValue(input.Connection.Node.Id, out var connectedClonedNode))
+                {
+                    var output = connectedClonedNode.OutputProperties.FirstOrDefault(o =>
+                        o.InternalPropertyName == input.Connection.InternalPropertyName);
+                    if (output != null)
+                    {
+                        var inputProp = clonedNode.InputProperties.FirstOrDefault(i =>
+                            i.InternalPropertyName == input.InternalPropertyName);
+                        output.ConnectTo(inputProp);
+                    }
+                }
+            }
+        }
     }
 
     private RepeatNodeEnd FindEndNode()

+ 41 - 0
src/PixiEditor.ChangeableDocument/Changeables/Graph/OutputProperty.cs

@@ -1,5 +1,6 @@
 using PixiEditor.ChangeableDocument.Changeables.Graph.Interfaces;
 using PixiEditor.ChangeableDocument.Changeables.Graph.Nodes;
+using PixiEditor.ChangeableDocument.Rendering;
 using PixiEditor.Common;
 
 namespace PixiEditor.ChangeableDocument.Changeables.Graph;
@@ -8,6 +9,7 @@ public delegate void InputConnectedEvent(IInputProperty input, IOutputProperty o
 
 public class OutputProperty : IOutputProperty
 {
+    private Dictionary<Guid, List<IInputProperty>> _virtualConnections = new();
     private List<IInputProperty> _connections = new();
     private object _value;
     public string InternalPropertyName { get; }
@@ -27,6 +29,7 @@ public class OutputProperty : IOutputProperty
     }
 
     public IReadOnlyCollection<IInputProperty> Connections => _connections;
+    public IReadOnlyCollection<IInputProperty> GetVirtualConnections(Guid virtualSession) => _virtualConnections[virtualSession];
 
     public event InputConnectedEvent Connected;
     public event InputConnectedEvent Disconnected;
@@ -40,6 +43,26 @@ public class OutputProperty : IOutputProperty
         ValueType = valueType;
     }
 
+    public void VirtualConnectTo(IInputProperty property, Guid virtualConnectionId, RenderContext context)
+    {
+        if (!_virtualConnections.ContainsKey(virtualConnectionId))
+        {
+            _virtualConnections[virtualConnectionId] = new List<IInputProperty>();
+        }
+
+        if (property.GetVirtualConnection(virtualConnectionId) != null)
+        {
+            property.GetVirtualConnection(virtualConnectionId).DisconnectFromVirtual(property, virtualConnectionId);
+        }
+
+        property.SetVirtualConnection(this, virtualConnectionId, context);
+
+        if (_virtualConnections[virtualConnectionId].Contains(property)) return;
+
+        _virtualConnections[virtualConnectionId].Add(property);
+        context.RecordVirtualConnection(this, virtualConnectionId);
+    }
+
     public void ConnectTo(IInputProperty property)
     {
         if (Connections.Contains(property)) return;
@@ -67,6 +90,19 @@ public class OutputProperty : IOutputProperty
         Disconnected?.Invoke(property, this);
     }
 
+
+    public void DisconnectFromVirtual(IInputProperty property, Guid virtualConnectionId)
+    {
+        if (!_virtualConnections.TryGetValue(virtualConnectionId, out var connection)) return;
+        if (connection != property) return;
+
+        _virtualConnections.Remove(virtualConnectionId);
+        if (property.GetVirtualConnection(virtualConnectionId) == this)
+        {
+            property.RemoveVirtualConnection(virtualConnectionId);
+        }
+    }
+
     public int GetCacheHash()
     {
         if (Value is ICacheable cacheable)
@@ -76,6 +112,11 @@ public class OutputProperty : IOutputProperty
 
         return 0;
     }
+
+    public void RemoveAllVirtualConnections(Guid virtualSessionId)
+    {
+        _virtualConnections.Remove(virtualSessionId);
+    }
 }
 
 public class OutputProperty<T> : OutputProperty, INodeProperty<T>

+ 74 - 0
src/PixiEditor.ChangeableDocument/Rendering/RenderContext.cs

@@ -2,6 +2,7 @@
 using Drawie.Backend.Core.Surfaces;
 using Drawie.Backend.Core.Surfaces.ImageData;
 using Drawie.Numerics;
+using PixiEditor.ChangeableDocument.Changeables.Graph;
 using BlendMode = PixiEditor.ChangeableDocument.Enums.BlendMode;
 using DrawingApiBlendMode = Drawie.Backend.Core.Surfaces.BlendMode;
 
@@ -23,6 +24,9 @@ public class RenderContext
     public ColorSpace ProcessingColorSpace { get; set; }
     public string? TargetOutput { get; set; }   
 
+    private List<Guid> virtualGraphSessions = new List<Guid>();
+    private Dictionary<Guid, List<InputProperty>> recordedVirtualInputs = new();
+    private Dictionary<Guid, List<OutputProperty>> recordedVirtualOutputs = new();
 
     public RenderContext(DrawingSurface renderSurface, KeyFrameTime frameTime, ChunkResolution chunkResolution,
         VecI renderOutputSize, VecI documentSize, ColorSpace processingColorSpace, SamplingOptions desiredSampling, double opacity = 1)
@@ -71,4 +75,74 @@ public class RenderContext
             TargetOutput = TargetOutput,
         };
     }
+
+    public void BeginVirtualConnectionScope(Guid virtualSessionId)
+    {
+        if (virtualGraphSessions.Contains(virtualSessionId))
+            return;
+
+        virtualGraphSessions.Add(virtualSessionId);
+    }
+
+    public void EndVirtualConnectionScope(Guid virtualSessionId)
+    {
+        if (!virtualGraphSessions.Contains(virtualSessionId))
+            return;
+
+        virtualGraphSessions.Remove(virtualSessionId);
+
+        foreach (var inputProperty in recordedVirtualInputs)
+        {
+            foreach (var input in inputProperty.Value)
+            {
+                input.RemoveVirtualConnection(virtualSessionId);
+                input.ActiveVirtualSession = null;
+            }
+        }
+
+        foreach (var outputProperty in recordedVirtualOutputs)
+        {
+            foreach (var output in outputProperty.Value)
+            {
+                output.RemoveAllVirtualConnections(virtualSessionId);
+            }
+        }
+    }
+
+    public void RecordVirtualConnection(InputProperty inputProperty, Guid virtualSessionId)
+    {
+        if (virtualGraphSessions.Count == 0)
+            return;
+
+        if (!recordedVirtualInputs.TryGetValue(virtualSessionId, out var inputs))
+        {
+            inputs = new List<InputProperty>();
+            recordedVirtualInputs[virtualSessionId] = inputs;
+        }
+
+        inputs.Add(inputProperty);
+        inputProperty.ActiveVirtualSession = virtualSessionId;
+    }
+
+    public void RecordVirtualConnection(OutputProperty outputProperty, Guid virtualConnectionId)
+    {
+        if (virtualGraphSessions.Count == 0)
+            return;
+
+        if (!recordedVirtualOutputs.TryGetValue(virtualConnectionId, out var outputs))
+        {
+            outputs = new List<OutputProperty>();
+            recordedVirtualOutputs[virtualConnectionId] = outputs;
+        }
+
+        outputs.Add(outputProperty);
+    }
+
+    public void CleanupVirtualConnectionScopes()
+    {
+        foreach (var virtualSessionId in virtualGraphSessions.ToArray())
+        {
+            EndVirtualConnectionScope(virtualSessionId);
+        }
+    }
 }