LuaCoroutine.cs 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. using System.Threading.Tasks.Sources;
  2. using Lua.Internal;
  3. namespace Lua;
  4. public sealed class LuaCoroutine : LuaThread, IValueTaskSource<LuaCoroutine.YieldContext>, IValueTaskSource<LuaCoroutine.ResumeContext>
  5. {
  6. struct YieldContext
  7. {
  8. }
  9. struct ResumeContext
  10. {
  11. public LuaValue[] Results;
  12. }
  13. byte status;
  14. bool isFirstCall = true;
  15. LuaState threadState;
  16. ValueTask<int> functionTask;
  17. ManualResetValueTaskSourceCore<ResumeContext> resume;
  18. ManualResetValueTaskSourceCore<YieldContext> yield;
  19. public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status;
  20. public override void UnsafeSetStatus(LuaThreadStatus status)
  21. {
  22. this.status = (byte)status;
  23. }
  24. public bool IsProtectedMode { get; }
  25. public LuaFunction Function { get; }
  26. internal LuaCoroutine(LuaState state, LuaFunction function, bool isProtectedMode)
  27. {
  28. IsProtectedMode = isProtectedMode;
  29. threadState = state;
  30. Function = function;
  31. }
  32. public override async ValueTask<int> Resume(LuaFunctionExecutionContext context, Memory<LuaValue> buffer, CancellationToken cancellationToken = default)
  33. {
  34. var baseThread = context.Thread;
  35. baseThread.UnsafeSetStatus(LuaThreadStatus.Normal);
  36. context.State.ThreadStack.Push(this);
  37. try
  38. {
  39. switch ((LuaThreadStatus)Volatile.Read(ref status))
  40. {
  41. case LuaThreadStatus.Suspended:
  42. Volatile.Write(ref status, (byte)LuaThreadStatus.Running);
  43. if (isFirstCall)
  44. {
  45. // copy stack value
  46. Stack.EnsureCapacity(baseThread.Stack.Count);
  47. baseThread.Stack.AsSpan().CopyTo(Stack.GetBuffer());
  48. Stack.NotifyTop(baseThread.Stack.Count);
  49. functionTask = Function.InvokeAsync(new()
  50. {
  51. State = threadState,
  52. Thread = this,
  53. ArgumentCount = context.ArgumentCount - 1,
  54. ChunkName = Function.Name,
  55. RootChunkName = context.RootChunkName,
  56. }, buffer[1..], cancellationToken).Preserve();
  57. }
  58. else
  59. {
  60. yield.SetResult(new());
  61. }
  62. break;
  63. case LuaThreadStatus.Normal:
  64. case LuaThreadStatus.Running:
  65. if (IsProtectedMode)
  66. {
  67. buffer.Span[0] = false;
  68. buffer.Span[1] = "cannot resume non-suspended coroutine";
  69. return 2;
  70. }
  71. else
  72. {
  73. throw new InvalidOperationException("cannot resume non-suspended coroutine");
  74. }
  75. case LuaThreadStatus.Dead:
  76. if (IsProtectedMode)
  77. {
  78. buffer.Span[0] = false;
  79. buffer.Span[1] = "cannot resume dead coroutine";
  80. return 2;
  81. }
  82. else
  83. {
  84. throw new InvalidOperationException("cannot resume dead coroutine");
  85. }
  86. }
  87. var resumeTask = new ValueTask<ResumeContext>(this, resume.Version);
  88. CancellationTokenRegistration registration = default;
  89. if (cancellationToken.CanBeCanceled)
  90. {
  91. registration = cancellationToken.UnsafeRegister(static x =>
  92. {
  93. var coroutine = (LuaCoroutine)x!;
  94. coroutine.yield.SetException(new OperationCanceledException());
  95. }, this);
  96. }
  97. try
  98. {
  99. if (isFirstCall)
  100. {
  101. for (int i = 0; i < context.ArgumentCount - 1; i++)
  102. {
  103. threadState.Push(context.Arguments[i + 1]);
  104. }
  105. Volatile.Write(ref isFirstCall, false);
  106. }
  107. (var index, var result0, var result1) = await ValueTaskEx.WhenAny(resumeTask, functionTask!);
  108. if (index == 0)
  109. {
  110. var results = result0.Results;
  111. buffer.Span[0] = true;
  112. for (int i = 0; i < results.Length; i++)
  113. {
  114. buffer.Span[i + 1] = results[i];
  115. }
  116. return results.Length + 1;
  117. }
  118. else
  119. {
  120. Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
  121. buffer.Span[0] = true;
  122. return 1 + functionTask!.Result;
  123. }
  124. }
  125. catch (Exception ex) when (ex is not OperationCanceledException)
  126. {
  127. if (IsProtectedMode)
  128. {
  129. Volatile.Write(ref status, (byte)LuaThreadStatus.Dead);
  130. buffer.Span[0] = false;
  131. buffer.Span[1] = ex.Message;
  132. return 2;
  133. }
  134. else
  135. {
  136. throw;
  137. }
  138. }
  139. finally
  140. {
  141. registration.Dispose();
  142. resume.Reset();
  143. }
  144. }
  145. finally
  146. {
  147. context.State.ThreadStack.Pop();
  148. baseThread.UnsafeSetStatus(LuaThreadStatus.Running);
  149. }
  150. }
  151. public override async ValueTask Yield(LuaFunctionExecutionContext context, CancellationToken cancellationToken = default)
  152. {
  153. if (Volatile.Read(ref status) != (byte)LuaThreadStatus.Running)
  154. {
  155. throw new InvalidOperationException("cannot call yield on a coroutine that is not currently running");
  156. }
  157. resume.SetResult(new()
  158. {
  159. Results = context.Arguments.ToArray(),
  160. });
  161. Volatile.Write(ref status, (byte)LuaThreadStatus.Suspended);
  162. CancellationTokenRegistration registration = default;
  163. if (cancellationToken.CanBeCanceled)
  164. {
  165. registration = cancellationToken.UnsafeRegister(static x =>
  166. {
  167. var coroutine = (LuaCoroutine)x!;
  168. coroutine.yield.SetException(new OperationCanceledException());
  169. }, this);
  170. }
  171. RETRY:
  172. try
  173. {
  174. await new ValueTask<YieldContext>(this, yield.Version);
  175. }
  176. catch (Exception ex) when (ex is not OperationCanceledException)
  177. {
  178. yield.Reset();
  179. goto RETRY;
  180. }
  181. registration.Dispose();
  182. yield.Reset();
  183. }
  184. YieldContext IValueTaskSource<YieldContext>.GetResult(short token)
  185. {
  186. return yield.GetResult(token);
  187. }
  188. ValueTaskSourceStatus IValueTaskSource<YieldContext>.GetStatus(short token)
  189. {
  190. return yield.GetStatus(token);
  191. }
  192. void IValueTaskSource<YieldContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
  193. {
  194. yield.OnCompleted(continuation, state, token, flags);
  195. }
  196. ResumeContext IValueTaskSource<ResumeContext>.GetResult(short token)
  197. {
  198. return resume.GetResult(token);
  199. }
  200. ValueTaskSourceStatus IValueTaskSource<ResumeContext>.GetStatus(short token)
  201. {
  202. return resume.GetStatus(token);
  203. }
  204. void IValueTaskSource<ResumeContext>.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
  205. {
  206. resume.OnCompleted(continuation, state, token, flags);
  207. }
  208. }