LuaThreadAccessExtensions.cs 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. using System.Runtime.CompilerServices;
  2. // ReSharper disable MethodHasAsyncOverloadWithCancellation
  3. namespace Lua.Runtime;
  /// <summary>
  /// Extension methods over <see cref="LuaThreadAccess"/> for loading and running chunks,
  /// manipulating the thread's stack, and performing Lua operations (arithmetic, unary,
  /// comparison, table get/set, concatenation, calls). Most operations try a direct fast path
  /// first and otherwise defer to the corresponding <c>LuaVirtualMachine</c> slow-path /
  /// metamethod routine.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the class name repeats "Access" and does not match the file name
  /// (LuaThreadAccessExtensions.cs). It is kept unchanged because renaming it would break any
  /// caller that invokes these methods through the class name.
  /// </remarks>
  public static class LuaThreadAccessAccessExtensions
  {
      /// <summary>
      /// Loads <paramref name="source"/> as a chunk, runs it on this thread and copies its
      /// return values into <paramref name="buffer"/>.
      /// </summary>
      /// <param name="access">The thread access to run on; validated before use.</param>
      /// <param name="source">Lua source text.</param>
      /// <param name="buffer">Destination for return values; values that do not fit are not copied.</param>
      /// <param name="chunkName">Chunk name passed to <c>Load</c>; defaults to <paramref name="source"/> itself.</param>
      /// <param name="cancellationToken">Token forwarded to <c>RunAsync</c>.</param>
      /// <returns>The value count reported by <c>RunAsync</c>, which may exceed <c>buffer.Length</c>.</returns>
      public static async ValueTask<int> DoStringAsync(this LuaThreadAccess access, string source, Memory<LuaValue> buffer, string? chunkName = null, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var closure = access.State.Load(source, chunkName ?? source);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          // Copy only what fits; the caller learns the full count from the return value.
          results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span);
          return count;
      }

      /// <summary>
      /// Loads <paramref name="source"/> as a chunk, runs it on this thread and returns all of
      /// its return values as a new array.
      /// </summary>
      /// <param name="access">The thread access to run on; validated before use.</param>
      /// <param name="source">Lua source text.</param>
      /// <param name="chunkName">Chunk name passed to <c>Load</c>; defaults to <paramref name="source"/> itself.</param>
      /// <param name="cancellationToken">Token forwarded to <c>RunAsync</c>.</param>
      public static async ValueTask<LuaValue[]> DoStringAsync(this LuaThreadAccess access, string source, string? chunkName = null, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var closure = access.State.Load(source, chunkName ?? source);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          return results.AsSpan().ToArray();
      }

      /// <summary>
      /// Reads the file at <paramref name="path"/>, loads it as a chunk named <c>"@" + path</c>,
      /// runs it on this thread and copies its return values into <paramref name="buffer"/>.
      /// </summary>
      /// <param name="access">The thread access to run on; validated before use.</param>
      /// <param name="path">Path of the file to load; read synchronously with <see cref="File.ReadAllBytes(string)"/>.</param>
      /// <param name="buffer">Destination for return values; values that do not fit are not copied.</param>
      /// <param name="cancellationToken">Token forwarded to <c>RunAsync</c>.</param>
      /// <returns>The value count reported by <c>RunAsync</c>, which may exceed <c>buffer.Length</c>.</returns>
      public static async ValueTask<int> DoFileAsync(this LuaThreadAccess access, string path, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var bytes = File.ReadAllBytes(path);
          var fileName = "@" + path;
          var closure = access.State.Load(bytes, fileName);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          // Same copy/return convention as DoStringAsync: copy what fits, report the full count.
          results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span);
          return count;
      }

      /// <summary>
      /// Reads the file at <paramref name="path"/>, loads it as a chunk named <c>"@" + path</c>,
      /// runs it on this thread and returns all of its return values as a new array.
      /// </summary>
      /// <param name="access">The thread access to run on; validated before use.</param>
      /// <param name="path">Path of the file to load; read synchronously with <see cref="File.ReadAllBytes(string)"/>.</param>
      /// <param name="cancellationToken">Token forwarded to <c>RunAsync</c>.</param>
      public static async ValueTask<LuaValue[]> DoFileAsync(this LuaThreadAccess access, string path, CancellationToken cancellationToken = default)
      {
          // Validate before touching access.State, like every other entry point in this class.
          access.ThrowIfInvalid();
          var bytes = File.ReadAllBytes(path);
          var fileName = "@" + path;
          var closure = access.State.Load(bytes, fileName);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          return results.AsSpan().ToArray();
      }

      /// <summary>Pushes a single value onto the thread's stack.</summary>
      public static void Push(this LuaThreadAccess access, LuaValue value)
      {
          access.ThrowIfInvalid();
          access.Stack.Push(value);
      }

      /// <summary>Pushes every value in <paramref name="span"/> onto the thread's stack, in order.</summary>
      public static void Push(this LuaThreadAccess access, params ReadOnlySpan<LuaValue> span)
      {
          access.ThrowIfInvalid();
          access.Stack.PushRange(span);
      }

      /// <summary>Pops <paramref name="count"/> values from the thread's stack.</summary>
      public static void Pop(this LuaThreadAccess access, int count)
      {
          access.ThrowIfInvalid();
          access.Stack.Pop(count);
      }

      /// <summary>Pops and returns the top value of the thread's stack.</summary>
      public static LuaValue Pop(this LuaThreadAccess access)
      {
          access.ThrowIfInvalid();
          return access.Stack.Pop();
      }

      /// <summary>
      /// Creates a <see cref="LuaReturnValuesReader"/> over the top <paramref name="argumentCount"/>
      /// values of the thread's stack, i.e. starting at index <c>Stack.Count - argumentCount</c>.
      /// </summary>
      /// <remarks>
      /// Callers in this class dispose the reader with <c>using</c>; presumably disposal releases
      /// (pops) those values — confirm against <see cref="LuaReturnValuesReader"/>.
      /// </remarks>
      public static LuaReturnValuesReader ReadReturnValues(this LuaThreadAccess access, int argumentCount)
      {
          access.ThrowIfInvalid();
          var stack = access.Stack;
          return new LuaReturnValuesReader(stack, stack.Count - argumentCount);
      }

      /// <summary>
      /// Applies the binary arithmetic <paramref name="opCode"/> (Add, Sub, Mul, Div, Mod, Pow) to
      /// <paramref name="x"/> and <paramref name="y"/>. When both operands read as doubles the
      /// result is computed directly; otherwise the operation is delegated to
      /// <c>LuaVirtualMachine.ExecuteBinaryOperationMetaMethod</c>.
      /// </summary>
      /// <exception cref="InvalidOperationException"><paramref name="opCode"/> is not an arithmetic opcode.</exception>
      public static async ValueTask<LuaValue> Arithmetic(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default)
      {
          // Floored modulo: a non-zero result takes the sign of the divisor (as Lua's % does),
          // unlike C#'s truncating %.
          [MethodImpl(MethodImplOptions.NoInlining)]
          static double Mod(double a, double b)
          {
              var mod = a % b;
              if ((b > 0 && mod < 0) || (b < 0 && mod > 0))
              {
                  mod += b;
              }

              return mod;
          }

          [MethodImpl(MethodImplOptions.AggressiveInlining)]
          static double ArithmeticOperation(OpCode code, double a, double b)
          {
              return code switch
              {
                  OpCode.Add => a + b,
                  OpCode.Sub => a - b,
                  OpCode.Mul => a * b,
                  OpCode.Div => a / b,
                  OpCode.Mod => Mod(a, b),
                  OpCode.Pow => Math.Pow(a, b),
                  _ => throw new InvalidOperationException($"Unsupported arithmetic operation: {code}"),
              };
          }

          // Reject non-arithmetic opcodes up front (as Compare and Unary do) so an invalid opcode
          // fails the same way whether or not the operands happen to be numbers.
          if (opCode is not (OpCode.Add or OpCode.Sub or OpCode.Mul or OpCode.Div or OpCode.Mod or OpCode.Pow))
          {
              throw new InvalidOperationException($"Unsupported arithmetic operation: {opCode}");
          }

          if (x.TryReadDouble(out var numX) && y.TryReadDouble(out var numY))
          {
              return ArithmeticOperation(opCode, numX, numY);
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteBinaryOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken);
      }

      /// <summary>
      /// Applies the unary <paramref name="opCode"/> (<c>Unm</c> or <c>Len</c>) to <paramref name="value"/>.
      /// Numbers are negated directly, strings report their <c>Length</c>, and tables without a
      /// metatable report <c>ArrayLength</c>; everything else is delegated to
      /// <c>LuaVirtualMachine.ExecuteUnaryOperationMetaMethod</c>.
      /// </summary>
      /// <exception cref="InvalidOperationException"><paramref name="opCode"/> is neither <c>Unm</c> nor <c>Len</c>.</exception>
      public static async ValueTask<LuaValue> Unary(this LuaThreadAccess access, LuaValue value, OpCode opCode, CancellationToken cancellationToken = default)
      {
          if (opCode == OpCode.Unm)
          {
              if (value.TryReadDouble(out var numB))
              {
                  return -numB;
              }
          }
          else if (opCode == OpCode.Len)
          {
              if (value.TryReadString(out var str))
              {
                  return str.Length;
              }

              // Only a table with no metatable may skip the slow path: a metatable can supply a
              // __len metamethod (consulted for tables since Lua 5.2), which must take precedence
              // over the raw array length.
              if (value.TryReadTable(out var table) && table.Metatable is null)
              {
                  return table.ArrayLength;
              }
          }
          else
          {
              throw new InvalidOperationException($"Unsupported unary operation: {opCode}");
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteUnaryOperationMetaMethod(access.Thread, value, opCode, cancellationToken);
      }

      /// <summary>
      /// Evaluates the comparison <paramref name="opCode"/> (<c>Eq</c>, <c>Lt</c> or <c>Le</c>).
      /// <c>Eq</c> returns <see langword="true"/> immediately when <c>x == y</c>; <c>Lt</c>/<c>Le</c>
      /// compare two numbers numerically or two strings ordinally. All other cases are delegated
      /// to <c>LuaVirtualMachine.ExecuteCompareOperationMetaMethod</c>.
      /// </summary>
      /// <exception cref="InvalidOperationException"><paramref name="opCode"/> is not a comparison opcode.</exception>
      public static async ValueTask<bool> Compare(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default)
      {
          if (opCode is not (OpCode.Eq or OpCode.Lt or OpCode.Le))
          {
              throw new InvalidOperationException($"Unsupported compare operation: {opCode}");
          }

          if (opCode == OpCode.Eq)
          {
              if (x == y)
              {
                  return true;
              }
          }
          else
          {
              // Numbers are read with TryReadNumber (not TryReadDouble) so that ordering does not
              // mix a number with a string.
              if (x.TryReadNumber(out var numX) && y.TryReadNumber(out var numY))
              {
                  return opCode == OpCode.Lt ? numX < numY : numX <= numY;
              }

              if (x.TryReadString(out var strX) && y.TryReadString(out var strY))
              {
                  var c = StringComparer.Ordinal.Compare(strX, strY);
                  return opCode == OpCode.Lt ? c < 0 : c <= 0;
              }
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteCompareOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken);
      }

      /// <summary>
      /// Reads <c>table[key]</c>. If <paramref name="table"/> is a table and
      /// <c>TryGetValue</c> finds the key, that value is returned directly; otherwise the lookup
      /// is delegated to <c>LuaVirtualMachine.ExecuteGetTableSlowPath</c>.
      /// </summary>
      public static async ValueTask<LuaValue> GetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, CancellationToken cancellationToken = default)
      {
          if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(key, out var value))
          {
              // Return the stored value itself; in an async method a target-typed `new(value)`
              // would construct a LuaValue from a LuaValue rather than wrap it in a ValueTask.
              return value;
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteGetTableSlowPath(access.Thread, table, key, cancellationToken);
      }

      /// <summary>
      /// Performs <c>table[key] = value</c>. A numeric NaN key raises a
      /// <see cref="LuaRuntimeException"/> ("table index is NaN"). If <paramref name="table"/> is a
      /// table that already holds a non-nil value for <paramref name="key"/>, that slot is
      /// overwritten in place; otherwise the store is delegated to
      /// <c>LuaVirtualMachine.ExecuteSetTableSlowPath</c>.
      /// </summary>
      public static async ValueTask SetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          if (key.TryReadNumber(out var numB) && double.IsNaN(numB))
          {
              throw new LuaRuntimeException(access.Thread, "table index is NaN");
          }

          if (table.TryReadTable(out var luaTable))
          {
              ref var valueRef = ref luaTable.FindValue(key);
              // Only an existing, non-nil slot is written directly; new keys (and nil slots) go
              // through the slow path.
              if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
              {
                  valueRef = value;
                  return;
              }
          }

          await LuaVirtualMachine.ExecuteSetTableSlowPath(access.Thread, table, key, value, cancellationToken);
      }

      /// <summary>
      /// Pushes <paramref name="values"/> onto the stack and concatenates them via
      /// <see cref="Concat(LuaThreadAccess, int, CancellationToken)"/>.
      /// </summary>
      public static ValueTask<LuaValue> Concat(this LuaThreadAccess access, ReadOnlySpan<LuaValue> values, CancellationToken cancellationToken = default)
      {
          // Not async: the span must be consumed synchronously before any await.
          access.ThrowIfInvalid();
          access.Stack.PushRange(values);
          return Concat(access, values.Length, cancellationToken);
      }

      /// <summary>
      /// Concatenates <paramref name="concatCount"/> values from the top of the stack via
      /// <c>LuaVirtualMachine.Concat</c>.
      /// </summary>
      public static async ValueTask<LuaValue> Concat(this LuaThreadAccess access, int concatCount, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          return await LuaVirtualMachine.Concat(access.Thread, concatCount, cancellationToken);
      }

      /// <summary>
      /// Calls the function at stack index <paramref name="funcIndex"/> via
      /// <c>LuaVirtualMachine.Call</c>, placing results starting at <paramref name="returnBase"/>.
      /// </summary>
      /// <returns>The value produced by <c>LuaVirtualMachine.Call</c> (presumably the number of results — confirm).</returns>
      public static ValueTask<int> Call(this LuaThreadAccess access, int funcIndex, int returnBase, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          return LuaVirtualMachine.Call(access.Thread, funcIndex, returnBase, cancellationToken);
      }

      /// <summary>
      /// Pushes <paramref name="function"/> followed by <paramref name="arguments"/>, calls it with
      /// results placed where the function was, and returns every value from that position to
      /// the top of the stack as a new array.
      /// </summary>
      public static ValueTask<LuaValue[]> Call(this LuaThreadAccess access, LuaValue function, ReadOnlySpan<LuaValue> arguments, CancellationToken cancellationToken = default)
      {
          // Not async: the argument span is pushed synchronously, then the awaiting part runs in Impl.
          access.ThrowIfInvalid();
          var thread = access.Thread;
          var funcIndex = thread.Stack.Count;
          thread.Stack.Push(function);
          thread.Stack.PushRange(arguments);
          return Impl(access, funcIndex, cancellationToken);

          static async ValueTask<LuaValue[]> Impl(LuaThreadAccess access, int funcIndex, CancellationToken cancellationToken)
          {
              await LuaVirtualMachine.Call(access.Thread, funcIndex, funcIndex, cancellationToken);
              // Results start at funcIndex (returnBase == funcIndex) and run to the stack top.
              var count = access.Stack.Count - funcIndex;
              using var results = access.ReadReturnValues(count);
              return results.AsSpan().ToArray();
          }
      }
  }