Browse Source

introduce StringInternPool.cs to reduce allocation and seq comp

Akeit0 3 months ago
parent
commit
2d77780e64

+ 9 - 3
src/Lua/CodeAnalysis/Compilation/Dump.cs

@@ -280,7 +280,7 @@ unsafe ref struct DumpState(IBufferWriter<byte> writer, bool reversedEndian)
     }
 }
 
-unsafe ref struct UnDumpState(ReadOnlySpan<byte> span, ReadOnlySpan<char> name)
+unsafe ref struct UnDumpState(ReadOnlySpan<byte> span, ReadOnlySpan<char> name, StringInternPool internPool)
 {
     public ReadOnlySpan<byte> Unread = span;
     bool otherEndian;
@@ -436,6 +436,7 @@ unsafe ref struct UnDumpState(ReadOnlySpan<byte> span, ReadOnlySpan<char> name)
 
         len--;
         var arrayPooled = ArrayPool<byte>.Shared.Rent(len);
+        char[]? charArrayPooled = null;
         try
         {
             var span = arrayPooled.AsSpan(0, len);
@@ -443,12 +444,17 @@ unsafe ref struct UnDumpState(ReadOnlySpan<byte> span, ReadOnlySpan<char> name)
 
             var l = ReadByte();
             Debug.Assert(l == 0);
-            var str = Encoding.UTF8.GetString(span);
-            return str;
+            var chars = len <= 128 ? stackalloc char[len*2] : (charArrayPooled = ArrayPool<char>.Shared.Rent(len * 2));
+            var count = Encoding.UTF8.GetChars(span, chars);
+            return internPool.Intern(chars[..count]);
         }
         finally
         {
             ArrayPool<byte>.Shared.Return(arrayPooled);
+            if (charArrayPooled != null)
+            {
+                ArrayPool<char>.Shared.Return(charArrayPooled);
+            }
         }
     }
 

+ 5 - 2
src/Lua/CodeAnalysis/Compilation/Parser.cs

@@ -937,6 +937,7 @@ class Parser : IPoolNode<Parser>, IDisposable
 
     public static Prototype Parse(LuaState l, TextReader r, string name)
     {
+        using var internPool = new StringInternPool(4);
         using var p = Get(new()
         {
             R = r,
@@ -945,7 +946,8 @@ class Parser : IPoolNode<Parser>, IDisposable
             LookAheadToken = new(0, TkEos),
             L = l,
             Source = name,
-            Buffer = new(r.Length)
+            Buffer = new(r.Length),
+            StringPool = internPool
         });
         var f = Function.Get(p, PrototypeBuilder.Get(name));
         p.Function = f;
@@ -980,7 +982,8 @@ class Parser : IPoolNode<Parser>, IDisposable
             };
         }
 
-        UnDumpState state = new(span, name);
+        using var internPool = new StringInternPool(4);
+        UnDumpState state = new(span, name, internPool);
         return state.UnDump();
     }
 }

+ 11 - 8
src/Lua/CodeAnalysis/Compilation/Scanner.cs

@@ -17,6 +17,9 @@ struct Scanner
     public string Source;
     public Token LookAheadToken;
     int lastNewLinePos;
+    public StringInternPool StringPool;
+    
+    string Intern(ReadOnlySpan<char> s) => StringPool.Intern(s);
 
     ///inline
     public Token Token;
@@ -94,7 +97,7 @@ struct Scanner
     public void NumberError(int numberStartPosition, int position)
     {
         Buffer.Clear();
-        Token = new(numberStartPosition, TkString, R.Span[numberStartPosition..(position - 1)].ToString());
+        Token = new(numberStartPosition, TkString, Intern(R.Span[numberStartPosition..(position - 1)]));
         ScanError(position, "malformed number", TkString);
     }
 
@@ -147,7 +150,7 @@ struct Scanner
     {
         var shortSourceBuffer = (stackalloc char[59]);
         var len = LuaDebug.WriteShortSource(Source, shortSourceBuffer);
-        var buff = shortSourceBuffer[..len].ToString();
+        var buff = Intern(shortSourceBuffer[..len]);
         string? nearToken = null;
         if (token != 0)
         {
@@ -246,7 +249,7 @@ struct Scanner
                         SaveAndAdvance();
                         if (!comment)
                         {
-                            var s = Buffer.AsSpan().Slice(2 + sep, Buffer.Length - (4 + (2 * sep))).ToString();
+                            var s = Intern( Buffer.AsSpan().Slice(2 + sep, Buffer.Length - (4 + (2 * sep))));
                             Buffer.Clear();
                             return s;
                         }
@@ -465,7 +468,7 @@ struct Scanner
 
         Save('\'');
 
-        Token = new(pos - Buffer.Length, TkString, Buffer.AsSpan().ToString());
+        Token = new(pos - Buffer.Length, TkString, Intern(Buffer.AsSpan()));
         Buffer.Clear();
         ScanError(pos, message, TkString);
     }
@@ -532,11 +535,11 @@ struct Scanner
             switch (Current)
             {
                 case EndOfStream:
-                    Token = new(R.Position - Buffer.Length, TkString, Buffer.AsSpan().ToString());
+                    Token = new(R.Position - Buffer.Length, TkString, Intern(Buffer.AsSpan()));
                     ScanError(R.Position, "unfinished string", TkEos);
                     break;
                 case '\n' or '\r':
-                    Token = new(R.Position - Buffer.Length, TkString, Buffer.AsSpan().ToString());
+                    Token = new(R.Position - Buffer.Length, TkString,Intern( Buffer.AsSpan()));
                     ScanError(R.Position, "unfinished string", TkString);
                     break;
                 case '\\':
@@ -594,7 +597,7 @@ struct Scanner
         // {
         //     length--;
         // }
-        var str = Buffer.AsSpan().Slice(1, length).ToString();
+        var str = Intern(Buffer.AsSpan().Slice(1, length));
         Buffer.Clear();
         return new(pos, TkString, str);
     }
@@ -615,7 +618,7 @@ struct Scanner
     public Token ReservedOrName()
     {
         var pos = R.Position - Buffer.Length;
-        var str = Buffer.AsSpan().ToString();
+        var str = Intern(Buffer.AsSpan());
         Buffer.Clear();
         for (var i = 0; i < Tokens.Length; i++)
         {

+ 181 - 0
src/Lua/CodeAnalysis/Compilation/StringInternPool.cs

@@ -0,0 +1,181 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Buffers;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
+namespace Lua.CodeAnalysis.Compilation;
+
+[SuppressMessage("ReSharper", "InconsistentNaming")]
+class StringInternPool : IDisposable
+{
+    int[] _buckets;
+    Entry[] _entries;
+    int _count;
+
+    public StringInternPool(int capacity = 16)
+    {
+        var size = Math.Max(16, capacity);
+        var buckets = ArrayPool<int>.Shared.Rent(size);
+        buckets.AsSpan().Clear();
+        var entries = ArrayPool<Entry>.Shared.Rent(size);
+
+        // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
+
+        _buckets = buckets;
+        _entries = entries;
+    }
+
+    public int Count => _count;
+
+
+    public void Clear()
+    {
+        var count = _count;
+        if (count > 0)
+        {
+            Debug.Assert(_buckets != null, "_buckets should be non-null");
+            Debug.Assert(_entries != null, "_entries should be non-null");
+            _buckets.AsSpan().Clear();
+
+            _count = 0;
+            _entries.AsSpan(0, count).Clear();
+        }
+    }
+
+    static int GetHashCode(ReadOnlySpan<char> span)
+    {
+        unchecked
+        {
+            int hash = 5381;
+            foreach (var t in span)
+            {
+                hash = ((hash << 5) + hash) ^ t;
+            }
+
+            return hash & 0x7FFFFFFF;
+        }
+    }
+
+    public string Intern(ReadOnlySpan<char> value)
+    {
+        Debug.Assert(_buckets != null);
+
+        Entry[] entries = _entries;
+
+        int hashCode;
+
+        ref int bucket = ref Unsafe.NullRef<int>();
+        {
+            hashCode = GetHashCode(value);
+            bucket = ref GetBucketRef(hashCode);
+            int i = bucket - 1; // Value in _buckets is 1-based
+            while (i >= 0)
+            {
+                ref Entry entry = ref entries[i];
+                if (entry.HashCode == hashCode && (value.SequenceEqual(entry.Value)))
+                {
+                    return entry.Value;
+                }
+
+                i = entry.Next;
+            }
+        }
+
+        int index;
+
+
+        int count = _count;
+        if (count == entries.Length)
+        {
+            Resize();
+            bucket = ref GetBucketRef(hashCode);
+        }
+
+        index = count;
+        _count = count + 1;
+        entries = _entries;
+
+
+        var stringValue = value.ToString();
+        stringValue = string.IsInterned(stringValue) ?? stringValue;
+        {
+            ref Entry entry = ref entries![index];
+            entry.HashCode = hashCode;
+            entry.Next = bucket - 1; // Value in _buckets is 1-based
+            entry.Value = stringValue;
+            bucket = index + 1;
+        }
+
+        return stringValue;
+    }
+
+
+    void Resize()
+    {
+        Resize(_entries!.Length * 2);
+    }
+
+    void Resize(int newSize)
+    {
+        // Value types never rehash
+        Debug.Assert(newSize >= _entries.Length);
+
+        var entries = ArrayPool<Entry>.Shared.Rent(newSize);
+
+        var count = _count;
+        Array.Copy(_entries, entries, count);
+
+        ArrayPool<Entry>.Shared.Return(_entries, true);
+        ArrayPool<int>.Shared.Return(_buckets);
+
+        // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
+        _buckets = ArrayPool<int>.Shared.Rent(newSize);
+        _buckets.AsSpan().Clear();
+        for (var i = 0; i < count; i++)
+        {
+            if (entries[i].Next >= -1)
+            {
+                ref var bucket = ref GetBucketRef(entries[i].HashCode);
+                entries[i].Next = bucket - 1; // Value in _buckets is 1-based
+                bucket = i + 1;
+            }
+        }
+
+        _entries = entries;
+    }
+
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    ref int GetBucketRef(int hashCode)
+    {
+        var buckets = _buckets;
+        return ref buckets[(uint)hashCode & (uint)(buckets.Length - 1)];
+    }
+
+    public void Dispose()
+    {
+        ArrayPool<int>.Shared.Return(_buckets);
+        _buckets = null!;
+
+
+        ArrayPool<Entry>.Shared.Return(_entries, true);
+        _entries = null!;
+    }
+
+    struct Entry
+    {
+        public int HashCode;
+
+        /// <summary>
+        /// 0-based index of next entry in chain: -1 means end of chain
+        /// also encodes whether this entry _itself_ is part of the free list by changing sign and subtracting 3,
+        /// so -2 means end of free list, -3 means index 0 but on free list, -4 means index 1 but on free list, etc.
+        /// </summary>
+        public int Next;
+
+        public string Value;
+    }
+}