LuaThreadAccessExtensions.cs 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. using System.Runtime.CompilerServices;
  2. // ReSharper disable MethodHasAsyncOverloadWithCancellation
  3. namespace Lua.Runtime;
  /// <summary>
  /// Extension methods that expose common Lua operations on a <see cref="LuaThreadAccess"/>:
  /// running chunks, pushing/popping stack values, arithmetic, unary and comparison operators
  /// (with a metamethod slow path), table get/set, concatenation and function calls.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the type name repeats "Access" and does not match the file name
  /// (LuaThreadAccessExtensions.cs). It is kept unchanged because renaming a public type is a breaking change.
  /// </remarks>
  public static class LuaThreadAccessAccessExtensions
  {
      /// <summary>
      /// Loads <paramref name="source"/> as a chunk, runs it, and copies as many of its return values
      /// as fit into <paramref name="buffer"/>.
      /// </summary>
      /// <param name="access">Thread access to run on; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="source">Lua source code to load.</param>
      /// <param name="buffer">Destination for return values; values beyond its length are dropped.</param>
      /// <param name="chunkName">Chunk name passed to <c>Load</c>; defaults to <paramref name="source"/> itself.</param>
      /// <param name="cancellationToken">Forwarded to <c>RunAsync</c>.</param>
      /// <returns>The total number of values the chunk returned, which may exceed <c>buffer.Length</c>.</returns>
      public static async ValueTask<int> DoStringAsync(this LuaThreadAccess access, string source, Memory<LuaValue> buffer, string? chunkName = null, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var closure = access.State.Load(source, chunkName ?? source);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          // NOTE(review): disposing the reader presumably removes the return values from the stack — confirm.
          using var results = access.ReadReturnValues(count);
          results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span);
          return count;
      }

      /// <summary>
      /// Loads <paramref name="source"/> as a chunk, runs it, and returns all of its return values as a new array.
      /// </summary>
      /// <param name="access">Thread access to run on; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="source">Lua source code to load.</param>
      /// <param name="chunkName">Chunk name passed to <c>Load</c>; defaults to <paramref name="source"/> itself.</param>
      /// <param name="cancellationToken">Forwarded to <c>RunAsync</c>.</param>
      /// <returns>A copy of every value the chunk returned.</returns>
      public static async ValueTask<LuaValue[]> DoStringAsync(this LuaThreadAccess access, string source, string? chunkName = null, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var closure = access.State.Load(source, chunkName ?? source);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          return results.AsSpan().ToArray();
      }

      /// <summary>
      /// Loads the file at <paramref name="path"/> (mode <c>"bt"</c>), runs it, and copies as many of its
      /// return values as fit into <paramref name="buffer"/>.
      /// </summary>
      /// <param name="access">Thread access to run on; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="path">Path passed to <c>LoadFileAsync</c>.</param>
      /// <param name="buffer">Destination for return values; values beyond its length are dropped.</param>
      /// <param name="cancellationToken">Forwarded to loading and running.</param>
      /// <returns>The total number of values the chunk returned, which may exceed <c>buffer.Length</c>.</returns>
      public static async ValueTask<int> DoFileAsync(this LuaThreadAccess access, string path, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var closure = await access.State.LoadFileAsync(path, "bt", null, cancellationToken);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          // Use the local count, as DoStringAsync does; the reader was created over exactly `count` values.
          results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span);
          return count;
      }

      /// <summary>
      /// Loads the file at <paramref name="path"/> (mode <c>"bt"</c>), runs it, and returns all of its
      /// return values as a new array.
      /// </summary>
      /// <param name="access">Thread access to run on; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="path">Path passed to <c>LoadFileAsync</c>.</param>
      /// <param name="cancellationToken">Forwarded to loading and running.</param>
      /// <returns>A copy of every value the chunk returned.</returns>
      public static async ValueTask<LuaValue[]> DoFileAsync(this LuaThreadAccess access, string path, CancellationToken cancellationToken = default)
      {
          // Validate up front like every other entry point in this class (this check was previously missing here).
          access.ThrowIfInvalid();
          var closure = await access.State.LoadFileAsync(path, "bt", null, cancellationToken);
          var count = await access.RunAsync(closure, 0, cancellationToken);
          using var results = access.ReadReturnValues(count);
          return results.AsSpan().ToArray();
      }

      /// <summary>Pushes a single value onto the thread's stack.</summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="value">Value to push.</param>
      public static void Push(this LuaThreadAccess access, LuaValue value)
      {
          access.ThrowIfInvalid();
          access.Stack.Push(value);
      }

      /// <summary>Pushes every value in <paramref name="span"/> onto the thread's stack, in order.</summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="span">Values to push.</param>
      public static void Push(this LuaThreadAccess access, params ReadOnlySpan<LuaValue> span)
      {
          access.ThrowIfInvalid();
          access.Stack.PushRange(span);
      }

      /// <summary>Pops <paramref name="count"/> values from the thread's stack.</summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="count">Number of values to pop.</param>
      public static void Pop(this LuaThreadAccess access, int count)
      {
          access.ThrowIfInvalid();
          access.Stack.Pop(count);
      }

      /// <summary>Pops and returns the top value of the thread's stack.</summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <returns>The value that was on top of the stack.</returns>
      public static LuaValue Pop(this LuaThreadAccess access)
      {
          access.ThrowIfInvalid();
          return access.Stack.Pop();
      }

      /// <summary>
      /// Creates a reader over the top <paramref name="argumentCount"/> values of the thread's stack.
      /// </summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="argumentCount">
      /// Number of values at the top of the stack to expose. Despite the name, callers in this class pass the
      /// number of values returned by a run or call.
      /// </param>
      /// <returns>A reader whose base index is <c>Stack.Count - argumentCount</c>.</returns>
      /// <exception cref="ArgumentOutOfRangeException">
      /// <paramref name="argumentCount"/> is negative or larger than the current stack size.
      /// </exception>
      public static LuaReturnValuesReader ReadReturnValues(this LuaThreadAccess access, int argumentCount)
      {
          access.ThrowIfInvalid();
          var stack = access.Stack;
          // Fail fast: otherwise the reader's base index would be silently out of range.
          if (argumentCount < 0 || argumentCount > stack.Count)
          {
              throw new ArgumentOutOfRangeException(nameof(argumentCount), argumentCount, "Count must be between 0 and the current stack size.");
          }

          return new LuaReturnValuesReader(stack, stack.Count - argumentCount);
      }

      /// <summary>
      /// Applies the binary arithmetic operator <paramref name="opCode"/> to <paramref name="x"/> and <paramref name="y"/>.
      /// </summary>
      /// <remarks>
      /// When both operands read as doubles, the result is computed directly without touching the thread.
      /// Otherwise the access is validated and the operation is delegated to
      /// <c>LuaVirtualMachine.ExecuteBinaryOperationMetaMethod</c>.
      /// </remarks>
      /// <param name="access">Thread access; validated only when the slow path is taken.</param>
      /// <param name="x">Left operand.</param>
      /// <param name="y">Right operand.</param>
      /// <param name="opCode">One of Add, Sub, Mul, Div, Mod or Pow on the fast path.</param>
      /// <param name="cancellationToken">Forwarded to the slow path.</param>
      /// <returns>The result of the operation.</returns>
      /// <exception cref="InvalidOperationException">The fast path was taken with an unsupported <paramref name="opCode"/>.</exception>
      public static async ValueTask<LuaValue> Arithmetic(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default)
      {
          // Floored modulo: C#'s % truncates toward zero, so shift the remainder by b whenever its sign
          // differs from the divisor's. This makes the result take the sign of b, as Lua's a % b does.
          [MethodImpl(MethodImplOptions.NoInlining)]
          static double Mod(double a, double b)
          {
              var mod = a % b;
              if ((b > 0 && mod < 0) || (b < 0 && mod > 0))
              {
                  mod += b;
              }

              return mod;
          }

          [MethodImpl(MethodImplOptions.AggressiveInlining)]
          static double ArithmeticOperation(OpCode code, double a, double b)
          {
              return code switch
              {
                  OpCode.Add => a + b,
                  OpCode.Sub => a - b,
                  OpCode.Mul => a * b,
                  OpCode.Div => a / b,
                  OpCode.Mod => Mod(a, b),
                  OpCode.Pow => Math.Pow(a, b),
                  _ => throw new InvalidOperationException($"Unsupported arithmetic operation: {code}"),
              };
          }

          if (x.TryReadDouble(out var numX) && y.TryReadDouble(out var numY))
          {
              return ArithmeticOperation(opCode, numX, numY);
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteBinaryOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken);
      }

      /// <summary>
      /// Applies the unary operator <paramref name="opCode"/> (<c>Unm</c> or <c>Len</c>) to <paramref name="value"/>.
      /// </summary>
      /// <remarks>
      /// The fast paths are: negation of a double, the length of a string (<c>string.Length</c>), and the
      /// <c>ArrayLength</c> of a table. Anything else is delegated to
      /// <c>LuaVirtualMachine.ExecuteUnaryOperationMetaMethod</c>.
      /// NOTE(review): the table fast path returns <c>ArrayLength</c> without consulting a <c>__len</c>
      /// metamethod, which Lua 5.2 honours for tables. Confirm this matches the VM's <c>Len</c> opcode.
      /// </remarks>
      /// <param name="access">Thread access; validated only when the slow path is taken.</param>
      /// <param name="value">Operand.</param>
      /// <param name="opCode"><c>OpCode.Unm</c> or <c>OpCode.Len</c>.</param>
      /// <param name="cancellationToken">Forwarded to the slow path.</param>
      /// <returns>The result of the operation.</returns>
      /// <exception cref="InvalidOperationException"><paramref name="opCode"/> is neither <c>Unm</c> nor <c>Len</c>.</exception>
      public static async ValueTask<LuaValue> Unary(this LuaThreadAccess access, LuaValue value, OpCode opCode, CancellationToken cancellationToken = default)
      {
          if (opCode == OpCode.Unm)
          {
              if (value.TryReadDouble(out var numB))
              {
                  return -numB;
              }
          }
          else if (opCode == OpCode.Len)
          {
              if (value.TryReadString(out var str))
              {
                  return str.Length;
              }

              if (value.TryReadTable(out var table))
              {
                  return table.ArrayLength;
              }
          }
          else
          {
              throw new InvalidOperationException($"Unsupported unary operation: {opCode}");
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteUnaryOperationMetaMethod(access.Thread, value, opCode, cancellationToken);
      }

      /// <summary>
      /// Evaluates the comparison <paramref name="opCode"/> (<c>Eq</c>, <c>Lt</c> or <c>Le</c>) between
      /// <paramref name="x"/> and <paramref name="y"/>.
      /// </summary>
      /// <remarks>
      /// The fast paths are: <c>Eq</c> when <c>x == y</c>; ordering of two numbers; and ordering of two strings
      /// using ordinal comparison. Everything else is delegated to
      /// <c>LuaVirtualMachine.ExecuteCompareOperationMetaMethod</c>.
      /// </remarks>
      /// <param name="access">Thread access; validated only when the slow path is taken.</param>
      /// <param name="x">Left operand.</param>
      /// <param name="y">Right operand.</param>
      /// <param name="opCode"><c>OpCode.Eq</c>, <c>OpCode.Lt</c> or <c>OpCode.Le</c>.</param>
      /// <param name="cancellationToken">Forwarded to the slow path.</param>
      /// <returns>The boolean result of the comparison.</returns>
      /// <exception cref="InvalidOperationException"><paramref name="opCode"/> is not a comparison opcode.</exception>
      public static async ValueTask<bool> Compare(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default)
      {
          if (opCode is not (OpCode.Eq or OpCode.Lt or OpCode.Le))
          {
              throw new InvalidOperationException($"Unsupported compare operation: {opCode}");
          }

          if (opCode == OpCode.Eq)
          {
              if (x == y)
              {
                  return true;
              }
          }
          else
          {
              if (x.TryReadNumber(out var numX) && y.TryReadNumber(out var numY))
              {
                  return opCode == OpCode.Lt ? numX < numY : numX <= numY;
              }

              if (x.TryReadString(out var strX) && y.TryReadString(out var strY))
              {
                  var c = StringComparer.Ordinal.Compare(strX, strY);
                  return opCode == OpCode.Lt ? c < 0 : c <= 0;
              }
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteCompareOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken);
      }

      /// <summary>
      /// Reads <c>table[key]</c>. A value found directly in a table is returned immediately; otherwise the
      /// lookup is delegated to <c>LuaVirtualMachine.ExecuteGetTableSlowPath</c>.
      /// </summary>
      /// <param name="access">Thread access; validated only when the slow path is taken.</param>
      /// <param name="table">Value being indexed.</param>
      /// <param name="key">Index key.</param>
      /// <param name="cancellationToken">Forwarded to the slow path.</param>
      /// <returns>The indexed value.</returns>
      public static async ValueTask<LuaValue> GetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, CancellationToken cancellationToken = default)
      {
          if (table.TryReadTable(out var luaTable) && luaTable.TryGetValue(key, out var value))
          {
              return value;
          }

          access.ThrowIfInvalid();
          return await LuaVirtualMachine.ExecuteGetTableSlowPath(access.Thread, table, key, cancellationToken);
      }

      /// <summary>
      /// Performs <c>table[key] = value</c>.
      /// </summary>
      /// <remarks>
      /// The fast path only overwrites an entry that already exists with a non-nil value. New keys,
      /// non-table targets, and all other cases go through <c>LuaVirtualMachine.ExecuteSetTableSlowPath</c>,
      /// which presumably handles <c>__newindex</c>.
      /// </remarks>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="table">Value being indexed.</param>
      /// <param name="key">Index key; a NaN number is rejected.</param>
      /// <param name="value">Value to store.</param>
      /// <param name="cancellationToken">Forwarded to the slow path.</param>
      /// <exception cref="LuaRuntimeException"><paramref name="key"/> is NaN.</exception>
      public static async ValueTask SetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          if (key.TryReadNumber(out var numB) && double.IsNaN(numB))
          {
              throw new LuaRuntimeException(access.Thread, "table index is NaN");
          }

          if (table.TryReadTable(out var luaTable))
          {
              ref var valueRef = ref luaTable.FindValue(key);
              if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil)
              {
                  valueRef = value;
                  return;
              }
          }

          await LuaVirtualMachine.ExecuteSetTableSlowPath(access.Thread, table, key, value, cancellationToken);
      }

      /// <summary>
      /// Pushes <paramref name="values"/> onto the stack and concatenates them via
      /// <see cref="Concat(LuaThreadAccess, int, CancellationToken)"/>.
      /// </summary>
      /// <remarks>Not <c>async</c> because span parameters cannot be captured by an async state machine.</remarks>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="values">Values to concatenate, in order.</param>
      /// <param name="cancellationToken">Forwarded to the concatenation.</param>
      /// <returns>The concatenation result.</returns>
      public static ValueTask<LuaValue> Concat(this LuaThreadAccess access, ReadOnlySpan<LuaValue> values, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          access.Stack.PushRange(values);
          return Concat(access, values.Length, cancellationToken);
      }

      /// <summary>
      /// Concatenates values already on the stack by delegating to <c>LuaVirtualMachine.Concat</c>
      /// (presumably the top <paramref name="concatCount"/> values — confirm).
      /// </summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="concatCount">Number of stack values to concatenate.</param>
      /// <param name="cancellationToken">Forwarded to the VM.</param>
      /// <returns>The concatenation result.</returns>
      public static async ValueTask<LuaValue> Concat(this LuaThreadAccess access, int concatCount, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          return await LuaVirtualMachine.Concat(access.Thread, concatCount, cancellationToken);
      }

      /// <summary>
      /// Calls the function at stack index <paramref name="funcIndex"/>, with results placed starting at
      /// <paramref name="returnBase"/>. This is a thin wrapper over <c>LuaVirtualMachine.Call</c>.
      /// </summary>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="funcIndex">Stack index of the function to call.</param>
      /// <param name="returnBase">Stack index where return values are written.</param>
      /// <param name="cancellationToken">Forwarded to the VM.</param>
      /// <returns>The value returned by <c>LuaVirtualMachine.Call</c>.</returns>
      public static ValueTask<int> Call(this LuaThreadAccess access, int funcIndex, int returnBase, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          return LuaVirtualMachine.Call(access.Thread, funcIndex, returnBase, cancellationToken);
      }

      /// <summary>
      /// Pushes <paramref name="function"/> followed by <paramref name="arguments"/>, calls it with results
      /// placed at the function's former slot, and returns every value from that slot to the stack top as an array.
      /// </summary>
      /// <remarks>Split into a synchronous prologue and an async local function because span parameters cannot cross an await.</remarks>
      /// <param name="access">Thread access; validated with <c>ThrowIfInvalid</c>.</param>
      /// <param name="function">Value to call.</param>
      /// <param name="arguments">Arguments passed to the call, in order.</param>
      /// <param name="cancellationToken">Forwarded to the VM.</param>
      /// <returns>A copy of the call's return values.</returns>
      public static ValueTask<LuaValue[]> Call(this LuaThreadAccess access, LuaValue function, ReadOnlySpan<LuaValue> arguments, CancellationToken cancellationToken = default)
      {
          access.ThrowIfInvalid();
          var thread = access.Thread;
          var funcIndex = thread.Stack.Count;
          thread.Stack.Push(function);
          thread.Stack.PushRange(arguments);
          return Impl(access, funcIndex, cancellationToken);

          static async ValueTask<LuaValue[]> Impl(LuaThreadAccess access, int funcIndex, CancellationToken cancellationToken)
          {
              await LuaVirtualMachine.Call(access.Thread, funcIndex, funcIndex, cancellationToken);
              // Everything from funcIndex up to the current top is treated as the call's results.
              var count = access.Stack.Count - funcIndex;
              using var results = access.ReadReturnValues(count);
              return results.AsSpan().ToArray();
          }
      }
  }