Browse Source

More name safety for BuiltInFunctions

CPKreuz 11 months ago
parent
commit
ff75b34acf

+ 107 - 114
src/PixiEditor.DrawingApi.Core/Shaders/Generation/BuiltInFunctions.cs

@@ -1,6 +1,4 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
+using System.Collections.Generic;
 using System.Text;
 using System.Text;
 using PixiEditor.DrawingApi.Core.Shaders.Generation.Expressions;
 using PixiEditor.DrawingApi.Core.Shaders.Generation.Expressions;
 
 
@@ -8,161 +6,156 @@ namespace PixiEditor.DrawingApi.Core.Shaders.Generation;
 
 
 public class BuiltInFunctions
 public class BuiltInFunctions
 {
 {
-    private readonly List<BuiltInFunctionType> usedFunctions = new(Enum.GetValues(typeof(BuiltInFunctionType)).Length);
+    private readonly List<IBuiltInFunction> usedFunctions = new(6);
 
 
-    public Expression GetRgbToHsv(Expression rgba)
-    {
-        Require(BuiltInFunctionType.RgbToHsv);
+    public Expression GetRgbToHsv(Expression rgba) => Call(RgbToHsv, rgba);
 
 
-        return new Expression($"{nameof(RgbToHsv)}({rgba.ExpressionValue})");
-    }
-    
-    public Expression GetRgbToHsv(Expression r, Expression g, Expression b, Expression a) =>
-        GetRgbToHsv(Half4.Constructor(r, g, b, a));
-
-    public Expression GetHsvToRgb(Expression hsva)
-    {
-        Require(BuiltInFunctionType.HsvToRgb);
+    public Expression GetRgbToHsl(Expression rgba) => Call(RgbToHsl, rgba);
 
 
-        return new Expression($"{nameof(HsvToRgb)}({hsva.ExpressionValue})");
-    }
+    public Expression GetHsvToRgb(Expression hsva) => Call(HsvToRgb, hsva);
     
     
     public Expression GetHsvToRgb(Expression h, Expression s, Expression v, Expression a) =>
     public Expression GetHsvToRgb(Expression h, Expression s, Expression v, Expression a) =>
         GetHsvToRgb(Half4.Constructor(h, s, v, a));
         GetHsvToRgb(Half4.Constructor(h, s, v, a));
 
 
-    public Expression GetHslToRgb(Expression hsla)
-    {
-        Require(BuiltInFunctionType.HslToRgb);
-        
-        return new Expression($"{nameof(HslToRgb)}({hsla.ExpressionValue})");
-    }
-
-    public Expression GetRgbToHsl(Expression rgba)
-    {
-        Require(BuiltInFunctionType.RgbToHsl);
-
-        return new Expression($"{nameof(RgbToHsl)}({rgba.ExpressionValue})");
-    }
+    public Expression GetHslToRgb(Expression hsla) => Call(HslToRgb, hsla);
 
 
     public Expression GetHslToRgb(Expression h, Expression s, Expression l, Expression a) =>
     public Expression GetHslToRgb(Expression h, Expression s, Expression l, Expression a) =>
         GetHslToRgb(Half4.Constructor(h, s, l, a));
         GetHslToRgb(Half4.Constructor(h, s, l, a));
-    
+
     public string BuildFunctions()
     public string BuildFunctions()
     {
     {
         var builder = new StringBuilder();
         var builder = new StringBuilder();
 
 
-        AppendIf(BuiltInFunctionType.HueToRgb, HueToRgb);
-        AppendIf(BuiltInFunctionType.RgbToHcv, RgbToHcv);
+        AppendIf(HueToRgb);
+        AppendIf(RgbToHcv);
         
         
-        AppendIf(BuiltInFunctionType.HsvToRgb, HsvToRgb);
-        AppendIf(BuiltInFunctionType.RgbToHsv, RgbToHsv);
+        AppendIf(HsvToRgb);
+        AppendIf(RgbToHsv);
 
 
-        AppendIf(BuiltInFunctionType.HslToRgb, HslToRgb);
-        AppendIf(BuiltInFunctionType.RgbToHsl, RgbToHsl);
+        AppendIf(HslToRgb);
+        AppendIf(RgbToHsl);
         
         
         return builder.ToString();
         return builder.ToString();
 
 
-        void AppendIf(BuiltInFunctionType type, string source)
+        void AppendIf(IBuiltInFunction function)
         {
         {
-            if (usedFunctions.Contains(type))
+            if (usedFunctions.Contains(function))
             {
             {
-                builder.AppendLine(source);
+                builder.AppendLine(function.FullSource);
             }
             }
         }
         }
     }
     }
 
 
-    private void Require(BuiltInFunctionType type)
+    private Expression Call(IBuiltInFunction function, Expression expression)
     {
     {
-        if (usedFunctions.Contains(type))
+        Require(function);
+
+        return new Expression(function.Call(expression.ExpressionValue));
+    }
+
+    private void Require(IBuiltInFunction function)
+    {
+        if (usedFunctions.Contains(function))
         {
         {
             return;
             return;
         }
         }
 
 
-        switch (type)
+        foreach (var dependency in function.Dependencies)
         {
         {
-            case BuiltInFunctionType.HsvToRgb:
-            case BuiltInFunctionType.HslToRgb:
-                Require(BuiltInFunctionType.HueToRgb);
-                break;
-            case BuiltInFunctionType.RgbToHsv:
-            case BuiltInFunctionType.RgbToHsl:
-                Require(BuiltInFunctionType.RgbToHcv);
-                break;
+            Require(dependency);
         }
         }
 
 
-        usedFunctions.Add(type);
+        usedFunctions.Add(function);
     }
     }
 
 
     // Taken from here https://www.shadertoy.com/view/4dKcWK
     // Taken from here https://www.shadertoy.com/view/4dKcWK
-    private const string HueToRgb =
-        $$"""
-          half3 {{nameof(HueToRgb)}}(float hue)
-          {
-              half3 rgb = abs(hue * 6. - half3(3, 2, 4)) * half3(1, -1, -1) + half3(-1, 2, 2);
-              return clamp(rgb, 0., 1.);
-          }
-          """;
-    
-    private const string RgbToHcv =
-        $$"""
-          half3 {{nameof(RgbToHcv)}}(half3 rgba)
-          {
-              half4 p = (rgba.g < rgba.b) ? half4(rgba.bg, -1., 2. / 3.) : half4(rgba.gb, 0., -1. / 3.);
-              half4 q = (rgba.r < p.x) ? half4(p.xyw, rgba.r) : half4(rgba.r, p.yzx);
-              float c = q.x - min(q.w, q.y);
-              float h = abs((q.w - q.y) / (6. * c) + q.z);
-              return half3(h, c, q.x);
-          }
-          """;
+    private static readonly BuiltInFunction<Half3> HueToRgb = new(
+        "float hue",
+        nameof(HueToRgb),
+        """
+        half3 rgb = abs(hue * 6. - half3(3, 2, 4)) * half3(1, -1, -1) + half3(-1, 2, 2);
+        return clamp(rgb, 0., 1.);
+        """);
+
+    private static readonly BuiltInFunction<Half3> RgbToHcv = new(
+        "half3 rgba",
+        nameof(RgbToHcv),
+        """
+        half4 p = (rgba.g < rgba.b) ? half4(rgba.bg, -1., 2. / 3.) : half4(rgba.gb, 0., -1. / 3.);
+        half4 q = (rgba.r < p.x) ? half4(p.xyw, rgba.r) : half4(rgba.r, p.yzx);
+        float c = q.x - min(q.w, q.y);
+        float h = abs((q.w - q.y) / (6. * c) + q.z);
+        return half3(h, c, q.x);
+        """);
+
+    private static readonly BuiltInFunction<Half4> RgbToHsv = new(
+        "half4 rgba",
+        nameof(RgbToHsv),
+        $"""
+         half3 hcv = {RgbToHcv.Call("rgba.rgb")};
+         float s = hcv.y / (hcv.z);
+         return half4(hcv.x, s, hcv.z, rgba.w);
+         """,
+        RgbToHcv);
+
+    private static readonly BuiltInFunction<Half4> HsvToRgb = new(
+        "half4 hsva",
+        nameof(HsvToRgb),
+        $"""
+         half3 rgb = {HueToRgb.Call("hsva.r")};
+         return half4(((rgb - 1.) * hsva.y + 1.) * hsva.z, hsva.w);
+         """,
+        HueToRgb);
+
+    private static readonly BuiltInFunction<Half4> RgbToHsl = new(
+        "half4 rgba", 
+        nameof(RgbToHsl), 
+        $"""
+         half3 hcv = {RgbToHcv.Call("rgba.rgb")};
+         half z = hcv.z - hcv.y * 0.5;
+         half s = hcv.y / (1. - abs(z * 2. - 1.));
+         return half4(hcv.x, s, z, rgba.w);
+         """,
+        RgbToHcv);
+
+    private static readonly BuiltInFunction<Half4> HslToRgb = new(
+        "half4 hsla",
+        nameof(HslToRgb),
+        $"""
+         half3 rgb = {HueToRgb.Call("hsla.r")};
+         float c = (1. - abs(2. * hsla.z - 1.)) * hsla.y;
+         return half4((rgb - 0.5) * c + hsla.z, hsla.w);
+         """,
+        HueToRgb);
+
+    private class BuiltInFunction<TReturn>(string argumentList, string name, string body, params IBuiltInFunction[] dependencies) : IBuiltInFunction where TReturn : ShaderExpressionVariable
+    {
+        public string ArgumentList { get; } = argumentList;
 
 
-    private const string RgbToHsv =
-        $$"""
-          half4 {{nameof(RgbToHsv)}}(half4 rgba)
-          {
-              half3 hcv = {{nameof(RgbToHcv)}}(rgba.xyz);
-              float s = hcv.y / (hcv.z);
-              return half4(hcv.x, s, hcv.z, rgba.w);
-          }
-          """;
+        public string Name { get; } = name;
 
 
-    private const string HsvToRgb =
-        $$"""
-          half4 {{nameof(HsvToRgb)}}(half4 hsva)
-          {
-              half3 rgb = {{nameof(HueToRgb)}}(hsva.x);
-              return half4(((rgb - 1.) * hsva.y + 1.) * hsva.z, hsva.w);
-          }
-          """;
+        public string Body { get; } = body;
 
 
-    private const string HslToRgb = 
-        $$"""
-          half4 {{nameof(HslToRgb)}}(half4 hsla)
-          {
-              half3 rgb = {{nameof(HueToRgb)}}(hsla.x);
-              float c = (1. - abs(2. * hsla.z - 1.)) * hsla.y;
-              return half4((rgb - 0.5) * c + hsla.z, hsla.w);
-          }
-          """;
+        public IBuiltInFunction[] Dependencies { get; } = dependencies;
 
 
-    private const string RgbToHsl =
-        $$"""
-          half4 {{nameof(RgbToHsl)}}(half4 rgba)
-          {
-              half3 hcv = {{nameof(RgbToHcv)}}(rgba.xyz);
-              half z = hcv.z - hcv.y * 0.5;
-              half s = hcv.y / (1. - abs(z * 2. - 1.));
-              return half4(hcv.x, s, z, rgba.w);
+        public string FullSource =>
+         $$"""
+          {{typeof(TReturn).Name.ToLower()}} {{Name}}({{ArgumentList}}) {
+          {{Body}}
           }
           }
           """;
           """;
+
+        public string Call(string arguments) => $"{Name}({arguments})";
+    }
     
     
-    enum BuiltInFunctionType
+    private interface IBuiltInFunction
     {
     {
-        HueToRgb,
-        RgbToHcv,
+        IBuiltInFunction[] Dependencies { get; }
+        
+        string Name { get; }
+        
+        string FullSource { get; }
 
 
-        RgbToHsv,
-        HsvToRgb,
-        RgbToHsl,
-        HslToRgb
+        string Call(string arguments);
     }
     }
 }
 }

+ 23 - 0
src/PixiEditor.DrawingApi.Core/Shaders/Generation/Expressions/Float1.cs

@@ -42,3 +42,26 @@ public class Half4Float1Accessor : Float1
         return false;
         return false;
     }
     }
 }
 }
+
+public class Half3Float1Accessor : Float1
+{
+    public Half3Float1Accessor(Half3 accessTo, char name) : base(string.IsNullOrEmpty(accessTo.VariableName) ? string.Empty : $"{accessTo.VariableName}.{name}")
+    {
+        Accesses = accessTo;
+    }
+    
+    public Half3 Accesses { get; }
+
+    public static bool AllAccessSame(Expression r, Expression g, Expression b, out Half3? half3)
+    {
+        if (r is Half3Float1Accessor rA && g is Half3Float1Accessor gA && b is Half3Float1Accessor bA &&
+            rA.Accesses == gA.Accesses && rA.Accesses == bA.Accesses)
+        {
+            half3 = rA.Accesses;
+            return true;
+        }
+
+        half3 = null;
+        return false;
+    }
+}

+ 40 - 0
src/PixiEditor.DrawingApi.Core/Shaders/Generation/Expressions/Half3.cs

@@ -0,0 +1,40 @@
+using System;
+using PixiEditor.Numerics;
+
+namespace PixiEditor.DrawingApi.Core.Shaders.Generation.Expressions;
+
+public class Half3(string name) : ShaderExpressionVariable<VecD3>(name), IMultiValueVariable
+{
+    private Expression? _overrideExpression;
+    public override string ConstantValueString => $"half3({ConstantValue.X}, {ConstantValue.Y}, {ConstantValue.Z})";
+    
+    public Float1 R => new Half3Float1Accessor(this, 'r') { ConstantValue = ConstantValue.X, OverrideExpression = _overrideExpression};
+    public Float1 G => new Half3Float1Accessor(this, 'g') { ConstantValue = ConstantValue.X, OverrideExpression = _overrideExpression};
+    public Float1 B => new Half3Float1Accessor(this, 'b') { ConstantValue = ConstantValue.Z, OverrideExpression = _overrideExpression};
+
+    public override Expression? OverrideExpression
+    {
+        get => _overrideExpression;
+        set
+        {
+            _overrideExpression = value;
+        }
+    }
+
+    public ShaderExpressionVariable GetValueAt(int index)
+    {
+        return index switch
+        {
+            0 => R,
+            1 => G,
+            2 => B,
+            _ => throw new IndexOutOfRangeException()
+        };
+    }
+
+    public static string ConstructorText(Expression r, Expression g, Expression b) =>
+        $"half4({r.ExpressionValue}, {g.ExpressionValue}, {b.ExpressionValue})";
+
+    public static Expression Constructor(Expression r, Expression g, Expression b) =>
+        new Expression(ConstructorText(r, g, b));
+}