LuaCoroutine.cs 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. using System.Buffers;
  2. using System.Threading.Tasks.Sources;
  3. using Lua.Internal;
  4. using Lua.Runtime;
  5. namespace Lua;
  6. public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.YieldContext>, IValueTaskSource<LuaCoroutine.ResumeContext>
  7. {
  8. struct YieldContext
  9. {
  10. public required LuaValue[] Results;
  11. }
  12. struct ResumeContext
  13. {
  14. public required LuaValue[] Results;
  15. }
  16. byte status;
  17. bool isFirstCall = true;
  18. ValueTask<int> functionTask;
  19. int returnFrameBase;
  20. ManualResetValueTaskSourceCore<ResumeContext> resume;
  21. ManualResetValueTaskSourceCore<YieldContext> yield;
  22. Traceback? traceback;
  23. internal int ReturnFrameBase => returnFrameBase;
  24. public LuaCoroutine(LuaFunction function, bool isProtectedMode)
  25. {
  26. IsProtectedMode = isProtectedMode;
  27. Function = function;
  28. }
  29. public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;
  30. public override void UnsafeSetStatus(LuaThreadStatus status)
  31. {
  32. this.status = (byte)status;
  33. }
  34. public bool IsProtectedMode { get; }
  35. public LuaFunction Function { get; }
  36. internal Traceback? LuaTraceback => traceback;
  37. public override async ValueTask<int> ResumeAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
  38. {
  39. var baseThread = context.Thread;
  40. baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
  41. context.State.ThreadStack.Push(this);
  42. try
  43. {
  44. switch ((LuaThreadStatus)Volatile.Read(ref status))
  45. {
  46. case LuaThreadStatus.Suspended:
  47. Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
  48. if (isFirstCall)
  49. {
  50. // copy stack value
  51. var argCount = context.ArgumentCount;
  52. Stack.EnsureCapacity(argCount);
  53. baseThread.Stack.AsSpan()[^argCount..].CopyTo(Stack.GetBuffer());
  54. Stack.NotifyTop(argCount);
  55. }
  56. else
  57. {
  58. yield.SetResult(new()
  59. {
  60. Results = context.ArgumentCount == 1
  61. ? []
  62. : context.Arguments[1..].ToArray()
  63. });
  64. }
  65. break;
  66. case LuaThreadStatus.Normal:
  67. case LuaThreadStatus.Running:
  68. if (IsProtectedMode)
  69. {
  70. return context.Return(false, "cannot resume non-suspended coroutine");
  71. }
  72. else
  73. {
  74. throw new LuaRuntimeException(context.State.GetTraceback(), "cannot resume non-suspended coroutine");
  75. }
  76. case LuaThreadStatus.Dead:
  77. if (IsProtectedMode)
  78. {
  79. return context.Return(false, "cannot resume dead coroutine");
  80. }
  81. else
  82. {
  83. throw new LuaRuntimeException(context.State.GetTraceback(), "cannot resume dead coroutine");
  84. }
  85. }
  86. var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
  87. CancellationTokenRegistration registration = default;
  88. if (cancellationToken.CanBeCanceled)
  89. {
  90. registration = cancellationToken.UnsafeRegister(static x =>
  91. {
  92. var coroutine = (LuaCoroutine)x!;
  93. coroutine.yield.SetException(new OperationCanceledException());
  94. }, this);
  95. }
  96. try
  97. {
  98. if (isFirstCall)
  99. {
  100. returnFrameBase = Stack.Count;
  101. int frameBase;
  102. var variableArgumentCount = Function.GetVariableArgumentCount(context.ArgumentCount - 1);
  103. if (variableArgumentCount > 0)
  104. {
  105. var fixedArgumentCount = context.ArgumentCount - 1 - variableArgumentCount;
  106. var args = context.Arguments;
  107. Stack.PushRange(args.Slice(1 + fixedArgumentCount, variableArgumentCount));
  108. frameBase = Stack.Count;
  109. Stack.PushRange(args.Slice(1, fixedArgumentCount));
  110. }
  111. else
  112. {
  113. frameBase = Stack.Count;
  114. Stack.PushRange(context.Arguments[1..]);
  115. }
  116. functionTask = Function.InvokeAsync(new()
  117. {
  118. State = context.State,
  119. Thread = this,
  120. ArgumentCount = context.ArgumentCount - 1,
  121. FrameBase = frameBase,
  122. ReturnFrameBase = returnFrameBase
  123. }, cancellationToken).Preserve();
  124. Volatile.Write(ref isFirstCall, false);
  125. }
  126. var (index, result0, result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
  127. if (index == 0)
  128. {
  129. var results = result0.Results;
  130. return context.Return(true, results.AsSpan());
  131. }
  132. else
  133. {
  134. Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
  135. var count = context.Return(true, Stack.AsSpan()[returnFrameBase..]);
  136. Stack.PopUntil(returnFrameBase);
  137. return count;
  138. }
  139. }
  140. catch (Exception ex) when (ex is not OperationCanceledException)
  141. {
  142. if (IsProtectedMode)
  143. {
  144. traceback = (ex as LuaRuntimeException)?.LuaTraceback;
  145. Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
  146. return context.Return(false, ex is LuaRuntimeException luaEx ? luaEx.ErrorObject : ex.Message);
  147. }
  148. else
  149. {
  150. throw;
  151. }
  152. }
  153. finally
  154. {
  155. registration.Dispose();
  156. resume.Reset();
  157. }
  158. }
  159. finally
  160. {
  161. context.State.ThreadStack.Pop();
  162. baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
  163. }
  164. }
  165. public override async ValueTask<int> YieldAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
  166. {
  167. if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
  168. {
  169. throw new LuaRuntimeException(context.State.GetTraceback(), "cannot call yield on a coroutine that is not currently running");
  170. }
  171. if (context.Thread.GetCallStackFrames()[^2].Function is not LuaClosure)
  172. {
  173. throw new LuaRuntimeException(context.State.GetTraceback(), "attempt to yield across a C#-call boundary");
  174. }
  175. resume.SetResult(new() { Results = context.Arguments.ToArray(), });
  176. Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
  177. CancellationTokenRegistration registration = default;
  178. if (cancellationToken.CanBeCanceled)
  179. {
  180. registration = cancellationToken.UnsafeRegister(static x =>
  181. {
  182. var coroutine = (LuaCoroutine)x!;
  183. coroutine.yield.SetException(new OperationCanceledException());
  184. }, this);
  185. }
  186. RETRY:
  187. try
  188. {
  189. var result = await new ValueTask<YieldContext>(this, yield.Version);
  190. return (context.Return(result.Results));
  191. }
  192. catch (Exception ex) when (ex is not OperationCanceledException)
  193. {
  194. yield.Reset();
  195. goto RETRY;
  196. }
  197. finally
  198. {
  199. registration.Dispose();
  200. yield.Reset();
  201. }
  202. }
  203. YieldContext IValueTaskSource<YieldContext>.GetResult(short token)
  204. {
  205. return yield.GetResult(token);
  206. }
  207. ValueTaskSourceStatus IValueTaskSource<YieldContext>.GetStatus(short token)
  208. {
  209. return yield.GetStatus(token);
  210. }
  211. void IValueTaskSource<YieldContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
  212. {
  213. yield.OnCompleted(continuation, state, token, flags);
  214. }
  215. ResumeContext IValueTaskSource<ResumeContext>.GetResult(short token)
  216. {
  217. return resume.GetResult(token);
  218. }
  219. ValueTaskSourceStatus IValueTaskSource<ResumeContext>.GetStatus(short token)
  220. {
  221. return resume.GetStatus(token);
  222. }
  223. void IValueTaskSource<ResumeContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
  224. {
  225. resume.OnCompleted(continuation, state, token, flags);
  226. }
  227. }