Browse Source

Reworked AsyncContextTracker. Fixes #474 (#480)

Dmitry Panov 2 years ago
parent
commit
746f7ebdc5
3 changed files with 133 additions and 38 deletions
  1. 12 0
      builtin_promise.go
  2. 13 22
      func.go
  3. 108 16
      func_test.go

+ 12 - 0
builtin_promise.go

@@ -42,6 +42,7 @@ type promiseReaction struct {
 	typ         promiseReactionType
 	handler     *jobCallback
 	asyncRunner *asyncRunner
+	asyncCtx    interface{}
 }
 
 var typePromise = reflect.TypeOf((*Promise)(nil))
@@ -150,6 +151,11 @@ func (p *Promise) export(*objectExportCtx) interface{} {
 
 func (p *Promise) addReactions(fulfillReaction *promiseReaction, rejectReaction *promiseReaction) {
 	r := p.val.runtime
+	if tracker := r.asyncContextTracker; tracker != nil {
+		ctx := tracker.Grab()
+		fulfillReaction.asyncCtx = ctx
+		rejectReaction.asyncCtx = ctx
+	}
 	switch p.state {
 	case PromiseStatePending:
 		p.fulfillReactions = append(p.fulfillReactions, fulfillReaction)
@@ -200,6 +206,9 @@ func (r *Runtime) newPromiseReactionJob(reaction *promiseReaction, argument Valu
 				fulfill = true
 			}
 		} else {
+			if tracker := r.asyncContextTracker; tracker != nil {
+				tracker.Resumed(reaction.asyncCtx)
+			}
 			ex := r.vm.try(func() {
 				handlerResult = r.callJobCallback(reaction.handler, _undefined, argument)
 				fulfill = true
@@ -207,6 +216,9 @@ func (r *Runtime) newPromiseReactionJob(reaction *promiseReaction, argument Valu
 			if ex != nil {
 				handlerResult = ex.val
 			}
+			if tracker := r.asyncContextTracker; tracker != nil {
+				tracker.Exited()
+			}
 		}
 		if reaction.capability != nil {
 			if fulfill {

+ 13 - 22
func.go

@@ -34,15 +34,22 @@ var (
 	yieldEmpty       = &yieldMarker{resultType: resultYield}
 )
 
-// AsyncContextTracker is a handler that allows to track async function's execution context. Every time an async
-// function is suspended on 'await', Suspended() is called. The trackingObject it returns is remembered and
-// the next time just before the context is resumed, Resumed is called with the same trackingObject as argument.
-// Completed is called when an async function returns or throws.
+// AsyncContextTracker is a handler that allows to track an async execution context to ensure it remains
+// consistent across all callback invocations.
+// Whenever a Promise reaction job is scheduled the Grab method is called. It is supposed to return the
+// current context. The same context will be supplied to the Resumed method before the reaction job is
+// executed. The Exited method is called after the reaction job is finished.
+// This means that for each invocation of the Grab method there will be exactly one subsequent invocation
+// of Resumed and then Exited methods (assuming the Promise is fulfilled or rejected). Also, the Resumed/Exited
+// calls cannot be nested, so Exited can simply clear the current context instead of popping from a stack.
+// Note, this works for both async functions and regular Promise.then()/Promise.catch() callbacks.
+// See TestAsyncContextTracker for more insight.
+//
 // To register it call Runtime.SetAsyncContextTracker().
 type AsyncContextTracker interface {
-	Suspended() (trackingObject interface{})
+	Grab() (trackingObject interface{})
 	Resumed(trackingObject interface{})
-	Completed()
+	Exited()
 }
 
 type funcObjectImpl interface {
@@ -664,15 +671,9 @@ type asyncRunner struct {
 	promiseCap *promiseCapability
 	f          *Object
 	vmCall     func(*vm, int)
-
-	trackingObj interface{}
 }
 
 func (ar *asyncRunner) onFulfilled(call FunctionCall) Value {
-	if tracker := ar.f.runtime.asyncContextTracker; tracker != nil {
-		tracker.Resumed(ar.trackingObj)
-		ar.trackingObj = nil
-	}
 	ar.gen.vm.curAsyncRunner = ar
 	defer func() {
 		ar.gen.vm.curAsyncRunner = nil
@@ -684,10 +685,6 @@ func (ar *asyncRunner) onFulfilled(call FunctionCall) Value {
 }
 
 func (ar *asyncRunner) onRejected(call FunctionCall) Value {
-	if tracker := ar.f.runtime.asyncContextTracker; tracker != nil {
-		tracker.Resumed(ar.trackingObj)
-		ar.trackingObj = nil
-	}
 	ar.gen.vm.curAsyncRunner = ar
 	defer func() {
 		ar.gen.vm.curAsyncRunner = nil
@@ -701,9 +698,6 @@ func (ar *asyncRunner) onRejected(call FunctionCall) Value {
 func (ar *asyncRunner) step(res Value, done bool, ex *Exception) {
 	r := ar.f.runtime
 	if done || ex != nil {
-		if tracker := r.asyncContextTracker; tracker != nil {
-			tracker.Completed()
-		}
 		if ex == nil {
 			ar.promiseCap.resolve(res)
 		} else {
@@ -713,9 +707,6 @@ func (ar *asyncRunner) step(res Value, done bool, ex *Exception) {
 	}
 
 	// await
-	if tracker := r.asyncContextTracker; tracker != nil {
-		ar.trackingObj = tracker.Suspended()
-	}
 	promise := r.promiseResolve(r.global.Promise, res)
 	promise.self.(*Promise).addReactions(&promiseReaction{
 		typ:         promiseReactionFulfill,

+ 108 - 16
func_test.go

@@ -161,43 +161,88 @@ func ExampleAssertConstructor() {
 	// Output: Test
 }
 
+type testAsyncCtx struct {
+	group    string
+	refCount int
+}
+
 type testAsyncContextTracker struct {
-	groupNamePtr *string
+	ctx     *testAsyncCtx
+	logFunc func(...interface{})
+	resumed bool
 }
 
-func (s testAsyncContextTracker) Suspended() interface{} {
-	return *s.groupNamePtr
+func (s *testAsyncContextTracker) Grab() interface{} {
+	ctx := s.ctx
+	if ctx != nil {
+		s.logFunc("Grab", ctx.group)
+		ctx.refCount++
+	}
+	return ctx
 }
 
-func (s testAsyncContextTracker) Resumed(trackingObj interface{}) {
-	*s.groupNamePtr = trackingObj.(string)
+func (s *testAsyncContextTracker) Resumed(trackingObj interface{}) {
+	s.logFunc("Resumed", trackingObj)
+	if s.resumed {
+		panic("Nested Resumed() calls")
+	}
+	s.ctx = trackingObj.(*testAsyncCtx)
+	s.resumed = true
 }
 
-func (s testAsyncContextTracker) Completed() {
-	*s.groupNamePtr = ""
+func (s *testAsyncContextTracker) releaseCtx() {
+	s.ctx.refCount--
+	if s.ctx.refCount < 0 {
+		panic("refCount < 0")
+	}
+	if s.ctx.refCount == 0 {
+		s.logFunc(s.ctx.group, "is finished")
+	}
+}
+
+func (s *testAsyncContextTracker) Exited() {
+	s.logFunc("Exited")
+	if s.ctx != nil {
+		s.releaseCtx()
+		s.ctx = nil
+	}
+	s.resumed = false
 }
 
 func TestAsyncContextTracker(t *testing.T) {
 	r := New()
-	var groupName string
+	var tracker testAsyncContextTracker
+	tracker.logFunc = t.Log
 
 	group := func(name string, asyncFunc func(FunctionCall) Value) Value {
-		prevGroupName := groupName
+		prevCtx := tracker.ctx
 		defer func() {
-			groupName = prevGroupName
+			t.Log("Returned", name)
+			tracker.releaseCtx()
+			tracker.ctx = prevCtx
 		}()
-		groupName = name
+		tracker.ctx = &testAsyncCtx{
+			group:    name,
+			refCount: 1,
+		}
+		t.Log("Set", name)
 		return asyncFunc(FunctionCall{})
 	}
-	r.SetAsyncContextTracker(testAsyncContextTracker{groupNamePtr: &groupName})
+	r.SetAsyncContextTracker(&tracker)
 	r.Set("group", group)
 	r.Set("check", func(expectedGroup, msg string) {
+		var groupName string
+		if tracker.ctx != nil {
+			groupName = tracker.ctx.group
+		}
 		if groupName != expectedGroup {
 			t.Fatalf("Unexpected group (%q), expected %q in %s", groupName, expectedGroup, msg)
 		}
+		t.Log("In", msg)
 	})
 
-	_, err := r.RunString(`
+	t.Run("", func(t *testing.T) {
+		_, err := r.RunString(`
 		group("1", async () => {
 		  check("1", "line A");
 		  await 3;
@@ -210,8 +255,55 @@ func TestAsyncContextTracker(t *testing.T) {
 		}).then(() => {
             check("", "line E");
 		})
+		`)
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("", func(t *testing.T) {
+		_, err := r.RunString(`
+		group("some", async () => {
+			check("some", "line A");
+		    (async () => {
+				check("some", "line B");
+		        await 1;
+				check("some", "line C");
+		        await 2;
+				check("some", "line D");
+		    })();
+			check("some", "line E");
+		});
 	`)
-	if err != nil {
-		t.Fatal(err)
-	}
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("", func(t *testing.T) {
+		_, err := r.RunString(`
+	group("Main", async () => {
+		check("Main", "0.1");
+		await Promise.all([
+			group("A", async () => {
+				check("A", "1.1");
+				await 1;
+				check("A", "1.2");
+			}),
+			(async () => {
+				check("Main", "3.1");
+			})(),
+			group("B", async () => {
+				check("B", "2.1");
+				await 2;
+				check("B", "2.2");
+			})
+		]);
+		check("Main", "0.2");
+	});
+	`)
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
 }