LuaCoroutine.cs 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. using System.Buffers;
  2. using System.Threading.Tasks.Sources;
  3. using Lua.Internal;
  4. using Lua.Runtime;
  5. namespace Lua;
  /// <summary>
  /// A Lua coroutine whose body is a <see cref="LuaFunction"/>. Control passes between
  /// <see cref="Resume"/> and <see cref="Yield"/> through two manual-reset value-task sources:
  /// <c>resume</c> is completed by <see cref="Yield"/> to wake the pending resumer, and
  /// <c>yield</c> is completed by <see cref="Resume"/> to wake the suspended body.
  /// </summary>
  public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.YieldContext>, IValueTaskSource<LuaCoroutine.ResumeContext>
  {
      /// <summary>Values handed from <see cref="Resume"/> to a suspended <see cref="Yield"/>.</summary>
      struct YieldContext
      {
          public required LuaValue[] Results;
      }

      /// <summary>Values handed from <see cref="Yield"/> back to the pending <see cref="Resume"/>.</summary>
      struct ResumeContext
      {
          public required LuaValue[] Results;
      }

      // LuaThreadStatus stored as a byte so it can be accessed with Volatile.Read/Write.
      byte status;
      bool isFirstCall = true;
      // Task of the running body; Preserve()d on creation so it can be awaited on every resume.
      ValueTask<int> functionTask;
      // Pooled buffer receiving the body's return values; returned to the pool once the coroutine dies.
      LuaValue[] resultBuffer;
      ManualResetValueTaskSourceCore<ResumeContext> resume;
      ManualResetValueTaskSourceCore<YieldContext> yield;

      /// <summary>Creates a suspended coroutine that will run <paramref name="function"/> on first resume.</summary>
      /// <param name="function">The coroutine body.</param>
      /// <param name="isProtectedMode">
      /// When true, failures are reported as (false, message) results; otherwise they are thrown.
      /// </param>
      public LuaCoroutine(LuaFunction function, bool isProtectedMode)
      {
          IsProtectedMode = isProtectedMode;
          Function = function;
          resultBuffer = ArrayPool<LuaValue>.Shared.Rent(1024);
          resultBuffer.AsSpan().Clear();
      }

      public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;

      public override void UnsafeSetStatus(LuaThreadStatus status)
      {
          this.status = (byte)status;
      }

      /// <summary>True if failures are returned as (false, message) instead of being thrown.</summary>
      public bool IsProtectedMode { get; }

      /// <summary>The function executed as the coroutine body.</summary>
      public LuaFunction Function { get; }

      /// <summary>
      /// Resumes the coroutine. The first argument of <paramref name="context"/> is skipped; on the first
      /// call the remaining arguments are passed to the body, on later calls they become the results of
      /// the pending <see cref="Yield"/>.
      /// </summary>
      /// <returns>
      /// The number of values written to <paramref name="buffer"/>: <c>true</c> followed by the yielded or
      /// returned values on success; in protected mode <c>false</c> and an error message on failure.
      /// </returns>
      /// <remarks>
      /// In non-protected mode a resume of a non-suspended/dead coroutine throws a <c>LuaRuntimeException</c>,
      /// and an exception from the body is rethrown. Either way an erroring body leaves the coroutine dead.
      /// </remarks>
      public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
      {
          var baseThread = context.Thread;
          baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
          context.State.ThreadStack.Push(this);

          try
          {
              switch ((LuaThreadStatus)Volatile.Read(ref status))
              {
                  case LuaThreadStatus.Suspended:
                      Volatile.Write(ref status, (byte)LuaThreadStatus.Running);

                      if (isFirstCall)
                      {
                          // The body starts on top of a snapshot of the resuming thread's value stack and call stack.
                          Stack.EnsureCapacity(baseThread.Stack.Count);
                          baseThread.Stack.AsSpan().CopyTo(Stack.GetBuffer());
                          Stack.NotifyTop(baseThread.Stack.Count);

                          CallStack.EnsureCapacity(baseThread.CallStack.Count);
                          baseThread.CallStack.AsSpan().CopyTo(CallStack.GetBuffer());
                          CallStack.NotifyTop(baseThread.CallStack.Count);
                      }
                      else
                      {
                          // Wake the pending Yield; everything after the first argument becomes its results.
                          yield.SetResult(new()
                          {
                              Results = context.ArgumentCount == 1
                                  ? []
                                  : context.Arguments[1..].ToArray()
                          });
                      }
                      break;

                  case LuaThreadStatus.Normal:
                  case LuaThreadStatus.Running:
                      return ReportResumeFailure(context, buffer, "cannot resume non-suspended coroutine");

                  case LuaThreadStatus.Dead:
                      return ReportResumeFailure(context, buffer, "cannot resume dead coroutine");
              }

              var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);

              CancellationTokenRegistration registration = default;
              if (cancellationToken.CanBeCanceled)
              {
                  // Cancellation faults the body's yield source, so a suspended body observes OperationCanceledException.
                  // NOTE(review): SetException throws if the yield source is already completed; confirm this race cannot occur.
                  registration = cancellationToken.UnsafeRegister(static x =>
                  {
                      var coroutine = (LuaCoroutine)x!;
                      coroutine.yield.SetException(new OperationCanceledException());
                  }, this);
              }

              try
              {
                  if (isFirstCall)
                  {
                      functionTask = StartFunction(context, cancellationToken);
                      Volatile.Write(ref isFirstCall, false);
                  }

                  // Whichever happens first: the body yields (index 0) or the body returns (index 1).
                  var (index, result0, _) = await ValueTaskEx.WhenAny(resumeTask, functionTask);

                  if (index == 0)
                  {
                      var results = result0.Results;
                      buffer.Span[0] = true;
                      for (int i = 0; i < results.Length; i++)
                      {
                          buffer.Span[i + 1] = results[i];
                      }
                      return results.Length + 1;
                  }

                  // The body returned: the coroutine is finished.
                  var resultCount = functionTask.Result;
                  Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                  buffer.Span[0] = true;
                  resultBuffer.AsSpan(0, resultCount).CopyTo(buffer.Span[1..]);
                  ArrayPool<LuaValue>.Shared.Return(resultBuffer);
                  return 1 + resultCount;
              }
              catch (Exception ex) when (ex is not OperationCanceledException)
              {
                  // An error kills the coroutine in both modes, so a later resume reports
                  // "cannot resume dead coroutine" rather than "non-suspended", and the pooled buffer is released.
                  Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
                  ArrayPool<LuaValue>.Shared.Return(resultBuffer);

                  if (!IsProtectedMode)
                  {
                      throw;
                  }

                  buffer.Span[0] = false;
                  buffer.Span[1] = ex.Message;
                  return 2;
              }
              finally
              {
                  registration.Dispose();
                  resume.Reset();
              }
          }
          finally
          {
              context.State.ThreadStack.Pop();
              baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
          }
      }

      /// <summary>
      /// Reports a resume that cannot proceed: in protected mode writes (false, message) into
      /// <paramref name="buffer"/> and returns 2; otherwise throws a <c>LuaRuntimeException</c>
      /// carrying the current traceback.
      /// </summary>
      int ReportResumeFailure(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, string message)
      {
          if (!IsProtectedMode)
          {
              throw new LuaRuntimeException(context.State.GetTraceback(), message);
          }

          buffer.Span[0] = false;
          buffer.Span[1] = message;
          return 2;
      }

      /// <summary>
      /// Pushes the resume arguments (all but the first) onto this coroutine's stack and starts the body.
      /// Variadic arguments are pushed below the frame base and fixed arguments from the frame base upward.
      /// </summary>
      /// <returns>The preserved task of the running body.</returns>
      ValueTask<int> StartFunction(LuaFunctionExecutionContext context, CancellationToken cancellationToken)
      {
          var argumentCount = context.ArgumentCount - 1;
          var variableArgumentCount = Function.GetVariableArgumentCount(argumentCount);

          int frameBase;
          if (variableArgumentCount > 0)
          {
              var fixedArgumentCount = argumentCount - variableArgumentCount;
              for (int i = 0; i < variableArgumentCount; i++)
              {
                  Stack.Push(context.GetArgument(i + fixedArgumentCount + 1));
              }

              frameBase = Stack.Count;
              for (int i = 0; i < fixedArgumentCount; i++)
              {
                  Stack.Push(context.GetArgument(i + 1));
              }
          }
          else
          {
              frameBase = Stack.Count;
              for (int i = 0; i < argumentCount; i++)
              {
                  Stack.Push(context.GetArgument(i + 1));
              }
          }

          return Function.InvokeAsync(new()
          {
              State = context.State,
              Thread = this,
              ArgumentCount = argumentCount,
              FrameBase = frameBase,
              ChunkName = Function.Name,
              RootChunkName = context.RootChunkName,
          }, resultBuffer, cancellationToken).Preserve();
      }

      /// <summary>
      /// Suspends the running coroutine, handing its arguments to the pending <see cref="Resume"/>, and
      /// waits until the next resume supplies values.
      /// </summary>
      /// <returns>The number of values written to <paramref name="buffer"/> by the next resume.</returns>
      /// <remarks>Throws a <c>LuaRuntimeException</c> if the coroutine is not currently running.</remarks>
      public override async ValueTask<int> Yield(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
      {
          if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
          {
              throw new LuaRuntimeException(context.State.GetTraceback(), "cannot call yield on a coroutine that is not currently running");
          }

          var yielded = context.Arguments.ToArray();

          // Publish Suspended BEFORE waking the resumer: ManualResetValueTaskSourceCore may run the resumer's
          // continuation synchronously inside SetResult, and that code must already see this coroutine as resumable.
          Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
          resume.SetResult(new()
          {
              Results = yielded,
          });

          CancellationTokenRegistration registration = default;
          if (cancellationToken.CanBeCanceled)
          {
              registration = cancellationToken.UnsafeRegister(static x =>
              {
                  var coroutine = (LuaCoroutine)x!;
                  coroutine.yield.SetException(new OperationCanceledException());
              }, this);
          }

          try
          {
              YieldContext result;
              while (true)
              {
                  // Only the wait itself is retried (a non-cancellation fault resets the source and waits again);
                  // the registration above stays active across retries because the finally runs only once.
                  try
                  {
                      result = await new ValueTask<YieldContext>(this, yield.Version);
                      break;
                  }
                  catch (Exception ex) when (ex is not OperationCanceledException)
                  {
                      yield.Reset();
                  }
              }

              var values = result.Results;
              for (int i = 0; i < values.Length; i++)
              {
                  buffer.Span[i] = values[i];
              }
              return values.Length;
          }
          finally
          {
              registration.Dispose();
              yield.Reset();
          }
      }

      YieldContext IValueTaskSource<YieldContext>.GetResult(short token) => yield.GetResult(token);

      ValueTaskSourceStatus IValueTaskSource<YieldContext>.GetStatus(short token) => yield.GetStatus(token);

      void IValueTaskSource<YieldContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
          => yield.OnCompleted(continuation, state, token, flags);

      ResumeContext IValueTaskSource<ResumeContext>.GetResult(short token) => resume.GetResult(token);

      ValueTaskSourceStatus IValueTaskSource<ResumeContext>.GetStatus(short token) => resume.GetStatus(token);

      void IValueTaskSource<ResumeContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
          => resume.OnCompleted(continuation, state, token, flags);
  }