Browse Source

fix : ILuaIOStream open mode behavior

Akeit0 7 months ago
parent
commit
fd686dafaf

+ 74 - 36
src/Lua/IO/ILuaFileSystem.cs

@@ -1,4 +1,5 @@
 using Lua.Internal;
 using Lua.Internal;
+using System.Text;
 
 
 namespace Lua.IO;
 namespace Lua.IO;
 
 
@@ -16,6 +17,7 @@ public interface ILuaFileSystem
 
 
 public interface ILuaIOStream : IDisposable
 public interface ILuaIOStream : IDisposable
 {
 {
+    public LuaFileOpenMode Mode { get; }
     public ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken);
     public ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken);
     public ValueTask<string> ReadToEndAsync(CancellationToken cancellationToken);
     public ValueTask<string> ReadToEndAsync(CancellationToken cancellationToken);
     public ValueTask<string?> ReadStringAsync(int count, CancellationToken cancellationToken);
     public ValueTask<string?> ReadStringAsync(int count, CancellationToken cancellationToken);
@@ -23,6 +25,11 @@ public interface ILuaIOStream : IDisposable
     public ValueTask FlushAsync(CancellationToken cancellationToken);
     public ValueTask FlushAsync(CancellationToken cancellationToken);
     public void SetVBuf(LuaFileBufferingMode mode, int size);
     public void SetVBuf(LuaFileBufferingMode mode, int size);
     public long Seek(long offset, SeekOrigin origin);
     public long Seek(long offset, SeekOrigin origin);
+
+    public static ILuaIOStream CreateStreamWrapper(LuaFileOpenMode mode, Stream stream)
+    {
+        return new LuaIOStreamWrapper(mode, stream);
+    }
 }
 }
 
 
 public sealed class FileSystem : ILuaFileSystem
 public sealed class FileSystem : ILuaFileSystem
@@ -37,7 +44,7 @@ public sealed class FileSystem : ILuaFileSystem
             LuaFileOpenMode.Write => (FileMode.Create, FileAccess.Write),
             LuaFileOpenMode.Write => (FileMode.Create, FileAccess.Write),
             LuaFileOpenMode.Append => (FileMode.Append, FileAccess.Write),
             LuaFileOpenMode.Append => (FileMode.Append, FileAccess.Write),
             LuaFileOpenMode.ReadWriteOpen => (FileMode.Open, FileAccess.ReadWrite),
             LuaFileOpenMode.ReadWriteOpen => (FileMode.Open, FileAccess.ReadWrite),
-            LuaFileOpenMode.ReadWriteCreate => (FileMode.Create, FileAccess.ReadWrite),
+            LuaFileOpenMode.ReadWriteCreate => (FileMode.Truncate, FileAccess.ReadWrite),
             LuaFileOpenMode.ReadAppend => (FileMode.Append, FileAccess.ReadWrite),
             LuaFileOpenMode.ReadAppend => (FileMode.Append, FileAccess.ReadWrite),
             _ => throw new ArgumentOutOfRangeException(nameof(luaFileOpenMode), luaFileOpenMode, null)
             _ => throw new ArgumentOutOfRangeException(nameof(luaFileOpenMode), luaFileOpenMode, null)
         };
         };
@@ -65,15 +72,17 @@ public sealed class FileSystem : ILuaFileSystem
 
 
     public ILuaIOStream? Open(string path, LuaFileOpenMode luaMode, bool throwError)
     public ILuaIOStream? Open(string path, LuaFileOpenMode luaMode, bool throwError)
     {
     {
-        if (luaMode == LuaFileOpenMode.ReadAppend)
-        {
-            throw new NotSupportedException("a+ mode is not supported.");
-        }
-
         var (mode, access) = GetFileMode(luaMode);
         var (mode, access) = GetFileMode(luaMode);
         try
         try
         {
         {
-            return new LuaIOStreamWrapper(File.Open(path, mode, access));
+            if (luaMode == LuaFileOpenMode.ReadAppend)
+            {
+                var s = new LuaIOStreamWrapper(luaMode, File.Open(path, FileMode.OpenOrCreate, FileAccess.ReadWrite, FileShare.ReadWrite | FileShare.Delete));
+                s.Seek(0, SeekOrigin.End);
+                return s;
+            }
+
+            return new LuaIOStreamWrapper(luaMode, File.Open(path, mode, access, FileShare.ReadWrite | FileShare.Delete));
         }
         }
         catch (Exception)
         catch (Exception)
         {
         {
@@ -108,70 +117,95 @@ public sealed class FileSystem : ILuaFileSystem
 
 
     public ILuaIOStream OpenTempFileStream()
     public ILuaIOStream OpenTempFileStream()
     {
     {
-        return new LuaIOStreamWrapper(File.Open(Path.GetTempFileName(), FileMode.OpenOrCreate, FileAccess.ReadWrite));
+        return new LuaIOStreamWrapper(LuaFileOpenMode.ReadAppend, File.Open(Path.GetTempFileName(), FileMode.Open, FileAccess.ReadWrite));
     }
     }
 }
 }
 
 
-public sealed class LuaIOStreamWrapper(Stream innerStream) : ILuaIOStream
+internal sealed class LuaIOStreamWrapper(LuaFileOpenMode mode, Stream innerStream) : ILuaIOStream
 {
 {
-    StreamReader? reader;
-    StreamWriter? writer;
+    public LuaFileOpenMode Mode => mode;
+    Utf8Reader? reader;
+    ulong flushSize = ulong.MaxValue;
+    ulong nextFlushSize = ulong.MaxValue;
 
 
     public ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken)
     public ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken)
     {
     {
         ThrowIfNotReadable();
         ThrowIfNotReadable();
-        reader ??= new(innerStream);
-
-        return new(reader.ReadLine());
+        reader ??= new();
+        return new(reader.ReadLine(innerStream));
     }
     }
 
 
     public ValueTask<string> ReadToEndAsync(CancellationToken cancellationToken)
     public ValueTask<string> ReadToEndAsync(CancellationToken cancellationToken)
     {
     {
         ThrowIfNotReadable();
         ThrowIfNotReadable();
-        reader ??= new(innerStream);
-
-        return new(reader.ReadToEnd());
+        reader ??= new();
+        return new(reader.ReadToEnd(innerStream));
     }
     }
 
 
     public ValueTask<string?> ReadStringAsync(int count, CancellationToken cancellationToken)
     public ValueTask<string?> ReadStringAsync(int count, CancellationToken cancellationToken)
     {
     {
         ThrowIfNotReadable();
         ThrowIfNotReadable();
-        reader ??= new(innerStream);
-
-        using var byteBuffer = new PooledArray<char>(count);
-        var span = byteBuffer.AsSpan();
-        var ret = reader.Read(span);
-        if (ret != span.Length)
-        {
-            return new(default(string));
-        }
-
-        return new(span.ToString());
+        reader ??= new();
+        return new(reader.Read(innerStream, count));
     }
     }
 
 
     public ValueTask WriteAsync(ReadOnlyMemory<char> buffer, CancellationToken cancellationToken)
     public ValueTask WriteAsync(ReadOnlyMemory<char> buffer, CancellationToken cancellationToken)
     {
     {
         ThrowIfNotWritable();
         ThrowIfNotWritable();
-        writer ??= new(innerStream);
-        writer.Write(buffer.Span);
+        if (mode is LuaFileOpenMode.Append or LuaFileOpenMode.ReadAppend)
+        {
+            innerStream.Seek(0, SeekOrigin.End);
+        }
+
+        using var byteBuffer = new PooledArray<byte>(4096);
+        var encoder = Encoding.UTF8.GetEncoder();
+        var totalBytes = encoder.GetByteCount(buffer.Span, true);
+        var remainingBytes = totalBytes;
+        while (0 < remainingBytes)
+        {
+            var byteCount = encoder.GetBytes(buffer.Span, byteBuffer.AsSpan(), false);
+            innerStream.Write(byteBuffer.AsSpan()[..byteCount]);
+            remainingBytes -= byteCount;
+        }
+
+        if (nextFlushSize < (ulong)totalBytes)
+        {
+            innerStream.Flush();
+            nextFlushSize = flushSize;
+        }
+
+        reader?.Clear();
         return new();
         return new();
     }
     }
 
 
     public ValueTask FlushAsync(CancellationToken cancellationToken)
     public ValueTask FlushAsync(CancellationToken cancellationToken)
     {
     {
-        ThrowIfNotWritable();
-        writer?.Flush();
+        innerStream.Flush();
+        nextFlushSize = flushSize;
         return new();
         return new();
     }
     }
 
 
     public void SetVBuf(LuaFileBufferingMode mode, int size)
     public void SetVBuf(LuaFileBufferingMode mode, int size)
     {
     {
-        writer ??= new(innerStream);
         // Ignore size parameter
         // Ignore size parameter
-        writer.AutoFlush = mode is LuaFileBufferingMode.NoBuffering or LuaFileBufferingMode.LineBuffering;
+        if (mode is LuaFileBufferingMode.NoBuffering or LuaFileBufferingMode.LineBuffering)
+        {
+            nextFlushSize = 0;
+            flushSize = 0;
+        }
+        else
+        {
+            nextFlushSize = (ulong)size;
+            flushSize = (ulong)size;
+        }
+    }
+
+    public long Seek(long offset, SeekOrigin origin)
+    {
+        reader?.Clear();
+        return innerStream.Seek(offset, origin);
     }
     }
 
 
-    public long Seek(long offset, SeekOrigin origin) => innerStream.Seek(offset, origin);
     public bool CanRead => innerStream.CanRead;
     public bool CanRead => innerStream.CanRead;
     public bool CanSeek => innerStream.CanSeek;
     public bool CanSeek => innerStream.CanSeek;
     public bool CanWrite => innerStream.CanWrite;
     public bool CanWrite => innerStream.CanWrite;
@@ -192,5 +226,9 @@ public sealed class LuaIOStreamWrapper(Stream innerStream) : ILuaIOStream
         }
         }
     }
     }
 
 
-    public void Dispose() => innerStream.Dispose();
+    public void Dispose()
+    {
+        innerStream.Dispose();
+        reader?.Dispose();
+    }
 }
 }

+ 191 - 0
src/Lua/Internal/Utf8Reader.cs

@@ -0,0 +1,191 @@
+using System.Buffers;
+using System.Text;
+
+namespace Lua.Internal;
+
+internal sealed class Utf8Reader
+{
+    [ThreadStatic]
+    static byte[]? scratchBuffer;
+
+    [ThreadStatic]
+    internal static bool scratchBufferUsed;
+
+    private readonly byte[] buffer;
+    private int bufPos, bufLen;
+    private Decoder? decoder;
+
+    const int ThreadStaticBufferSize = 1024;
+
+    public Utf8Reader()
+    {
+        if (scratchBufferUsed)
+        {
+            buffer = new byte[ThreadStaticBufferSize];
+            return;
+        }
+
+        scratchBuffer ??= new byte[ThreadStaticBufferSize];
+
+        buffer = scratchBuffer;
+        scratchBufferUsed = true;
+    }
+
+    public string? ReadLine(Stream stream)
+    {
+        var resultBuffer = ArrayPool<byte>.Shared.Rent(1024);
+        var lineLen = 0;
+        try
+        {
+            while (true)
+            {
+                if (bufPos >= bufLen)
+                {
+                    bufLen = stream.Read(buffer, 0, buffer.Length);
+                    bufPos = 0;
+                    if (bufLen == 0)
+                        break; // EOF
+                }
+
+                var span = new Span<byte>(buffer, bufPos, bufLen - bufPos);
+                int idx = span.IndexOfAny((byte)'\r', (byte)'\n');
+
+                if (idx >= 0)
+                {
+                    AppendToBuffer(ref resultBuffer, span[..idx], ref lineLen);
+
+                    byte nl = span[idx];
+                    bufPos += idx + 1;
+
+                    // CRLF
+                    if (nl == (byte)'\r' && bufPos < bufLen && buffer[bufPos] == (byte)'\n')
+                        bufPos++;
+
+                    // 行を返す
+                    return Encoding.UTF8.GetString(resultBuffer, 0, lineLen);
+                }
+                else
+                {
+                    // 改行なし → 全部行バッファへ
+                    AppendToBuffer(ref resultBuffer, span, ref lineLen);
+                    bufPos = bufLen;
+                }
+            }
+
+            if (lineLen == 0)
+                return null;
+            return Encoding.UTF8.GetString(resultBuffer, 0, lineLen);
+        }
+        finally
+        {
+            ArrayPool<byte>.Shared.Return(resultBuffer);
+        }
+    }
+
+    public string ReadToEnd(Stream stream)
+    {
+        var resultBuffer = ArrayPool<byte>.Shared.Rent(1024);
+        var len = 0;
+        try
+        {
+            while (true)
+            {
+                if (bufPos >= bufLen)
+                {
+                    bufLen = stream.Read(buffer, 0, buffer.Length);
+                    bufPos = 0;
+                    if (bufLen == 0)
+                        break; // EOF
+                }
+
+                var span = new Span<byte>(buffer, bufPos, bufLen - bufPos);
+                AppendToBuffer(ref resultBuffer, span, ref len);
+                bufPos = bufLen;
+            }
+
+            if (len == 0)
+                return "";
+            return Encoding.UTF8.GetString(resultBuffer, 0, len);
+        }
+        finally
+        {
+            ArrayPool<byte>.Shared.Return(resultBuffer);
+        }
+    }
+
+    public string? Read(Stream stream, int charCount)
+    {
+        if (charCount < 0) throw new ArgumentOutOfRangeException(nameof(charCount));
+        if (charCount == 0) return string.Empty;
+
+        var len = 0;
+        bool dataRead = false;
+        var resultBuffer = ArrayPool<char>.Shared.Rent(charCount);
+
+        try
+        {
+            while (len < charCount)
+            {
+                if (bufPos >= bufLen)
+                {
+                    bufLen = stream.Read(buffer, 0, buffer.Length);
+                    bufPos = 0;
+                    if (bufLen == 0) break; // EOF
+                }
+
+                var byteSpan = new ReadOnlySpan<byte>(buffer, bufPos, bufLen - bufPos);
+                var charSpan = new Span<char>(resultBuffer, len, charCount - len);
+                decoder ??= Encoding.UTF8.GetDecoder();
+                decoder.Convert(
+                    byteSpan,
+                    charSpan,
+                    flush: false,
+                    out int bytesUsed,
+                    out int charsUsed,
+                    out _);
+
+                if (charsUsed > 0)
+                {
+                    len += charsUsed;
+                    dataRead = true;
+                }
+
+                bufPos += bytesUsed;
+                if (bytesUsed == 0) break;
+            }
+
+            if (!dataRead || len != charCount) return null;
+            return resultBuffer.AsSpan(0, len).ToString();
+        }
+        finally
+        {
+            ArrayPool<char>.Shared.Return(resultBuffer);
+        }
+    }
+
+
+    private static void AppendToBuffer(ref byte[] buffer, ReadOnlySpan<byte> segment, ref int length)
+    {
+        if (length + segment.Length > buffer.Length)
+        {
+            int newSize = Math.Max(buffer.Length * 2, length + segment.Length);
+            var newBuffer = ArrayPool<byte>.Shared.Rent(newSize);
+            Array.Copy(buffer, newBuffer, length);
+            ArrayPool<byte>.Shared.Return(buffer);
+        }
+
+        segment.CopyTo(buffer.AsSpan(length));
+        length += segment.Length;
+    }
+
+    public void Clear()
+    {
+        bufPos = 0;
+        bufLen = 0;
+    }
+
+    public void Dispose()
+    {
+        scratchBufferUsed = false;
+    }
+}

+ 1 - 1
src/Lua/Standard/FileHandle.cs

@@ -48,7 +48,7 @@ public class FileHandle : ILuaUserData
         fileHandleMetatable[Metamethods.Index] = IndexMetamethod;
         fileHandleMetatable[Metamethods.Index] = IndexMetamethod;
     }
     }
 
 
-    public FileHandle(Stream stream) : this(new LuaIOStreamWrapper(stream)) { }
+    public FileHandle(LuaFileOpenMode mode, Stream stream) : this(new LuaIOStreamWrapper(mode,stream)) { }
 
 
     public FileHandle(ILuaIOStream stream)
     public FileHandle(ILuaIOStream stream)
     {
     {

+ 4 - 3
src/Lua/Standard/OpenLibsExtensions.cs

@@ -1,3 +1,4 @@
+using Lua.IO;
 using Lua.Runtime;
 using Lua.Runtime;
 using Lua.Standard.Internal;
 using Lua.Standard.Internal;
 
 
@@ -47,9 +48,9 @@ public static class OpenLibsExtensions
         }
         }
 
 
         var registry = state.Registry;
         var registry = state.Registry;
-        registry["stdin"] = new(new FileHandle(ConsoleHelper.OpenStandardInput()));
-        registry["stdout"] = new(new FileHandle(ConsoleHelper.OpenStandardOutput()));
-        registry["stderr"] = new(new FileHandle(ConsoleHelper.OpenStandardError()));
+        registry["stdin"] = new(new FileHandle(LuaFileOpenMode.Read, ConsoleHelper.OpenStandardInput()));
+        registry["stdout"] = new(new FileHandle(LuaFileOpenMode.Write, ConsoleHelper.OpenStandardOutput()));
+        registry["stderr"] = new(new FileHandle(LuaFileOpenMode.Write, ConsoleHelper.OpenStandardError()));
 
 
         state.Environment["io"] = io;
         state.Environment["io"] = io;
         state.LoadedModules["io"] = io;
         state.LoadedModules["io"] = io;

+ 3 - 1
tests/Lua.Tests/Helpers/NotSupportedStreamBase.cs

@@ -8,6 +8,8 @@ namespace Lua.Tests.Helpers
         {
         {
         }
         }
 
 
+        public virtual LuaFileOpenMode Mode => throw IOThrowHelpers.GetNotSupportedException();
+
         public virtual ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken)
         public virtual ValueTask<string?> ReadLineAsync(CancellationToken cancellationToken)
         {
         {
             throw IOThrowHelpers.GetNotSupportedException();
             throw IOThrowHelpers.GetNotSupportedException();
@@ -38,7 +40,7 @@ namespace Lua.Tests.Helpers
             throw IOThrowHelpers.GetNotSupportedException();
             throw IOThrowHelpers.GetNotSupportedException();
         }
         }
 
 
-        public virtual  long Seek(long offset, SeekOrigin origin)
+        public virtual long Seek(long offset, SeekOrigin origin)
         {
         {
             throw IOThrowHelpers.GetNotSupportedException();
             throw IOThrowHelpers.GetNotSupportedException();
         }
         }