Browse Source

Added AsyncContextTracker.Completed. Closes #472.

Dmitry Panov 2 years ago
parent
commit
17fd568758
2 changed files with 81 additions and 20 deletions
  1. 26 20
      func.go
  2. 55 0
      func_test.go

+ 26 - 20
func.go

@@ -22,10 +22,12 @@ var (
 // AsyncContextTracker is a handler that allows to track async function's execution context. Every time an async
 // AsyncContextTracker is a handler that allows to track async function's execution context. Every time an async
 // function is suspended on 'await', Suspended() is called. The trackingObject it returns is remembered and
 // function is suspended on 'await', Suspended() is called. The trackingObject it returns is remembered and
 // the next time just before the context is resumed, Resumed is called with the same trackingObject as argument.
 // the next time just before the context is resumed, Resumed is called with the same trackingObject as argument.
+// Completed is called when an async function returns or throws.
 // To register it call Runtime.SetAsyncContextTracker().
 // To register it call Runtime.SetAsyncContextTracker().
 type AsyncContextTracker interface {
 type AsyncContextTracker interface {
 	Suspended() (trackingObject interface{})
 	Suspended() (trackingObject interface{})
 	Resumed(trackingObject interface{})
 	Resumed(trackingObject interface{})
+	Completed()
 }
 }
 
 
 type funcObjectImpl interface {
 type funcObjectImpl interface {
@@ -652,29 +654,33 @@ func (ar *asyncRunner) onRejected(call FunctionCall) Value {
 }
 }
 
 
 func (ar *asyncRunner) step(res Value, done bool, ex *Exception) {
 func (ar *asyncRunner) step(res Value, done bool, ex *Exception) {
-	if ex != nil {
-		ar.promiseCap.reject(ex.val)
-		return
-	}
-	if done {
-		ar.promiseCap.resolve(res)
-	} else {
-		// await
-		r := ar.f.runtime
+	r := ar.f.runtime
+	if done || ex != nil {
 		if tracker := r.asyncContextTracker; tracker != nil {
 		if tracker := r.asyncContextTracker; tracker != nil {
-			ar.trackingObj = tracker.Suspended()
+			tracker.Completed()
 		}
 		}
-		promise := r.promiseResolve(r.global.Promise, res)
-		promise.self.(*Promise).addReactions(&promiseReaction{
-			typ:         promiseReactionFulfill,
-			handler:     &jobCallback{callback: ar.onFulfilled},
-			asyncRunner: ar,
-		}, &promiseReaction{
-			typ:         promiseReactionReject,
-			handler:     &jobCallback{callback: ar.onRejected},
-			asyncRunner: ar,
-		})
+		if ex == nil {
+			ar.promiseCap.resolve(res)
+		} else {
+			ar.promiseCap.reject(ex.val)
+		}
+		return
 	}
 	}
+
+	// await
+	if tracker := r.asyncContextTracker; tracker != nil {
+		ar.trackingObj = tracker.Suspended()
+	}
+	promise := r.promiseResolve(r.global.Promise, res)
+	promise.self.(*Promise).addReactions(&promiseReaction{
+		typ:         promiseReactionFulfill,
+		handler:     &jobCallback{callback: ar.onFulfilled},
+		asyncRunner: ar,
+	}, &promiseReaction{
+		typ:         promiseReactionReject,
+		handler:     &jobCallback{callback: ar.onRejected},
+		asyncRunner: ar,
+	})
 }
 }
 
 
 func (ar *asyncRunner) start(nArgs int) {
 func (ar *asyncRunner) start(nArgs int) {

+ 55 - 0
func_test.go

@@ -160,3 +160,58 @@ func ExampleAssertConstructor() {
 	}
 	}
 	// Output: Test
 	// Output: Test
 }
 }
+
+type testAsyncContextTracker struct {
+	groupNamePtr *string
+}
+
+func (s testAsyncContextTracker) Suspended() interface{} {
+	return *s.groupNamePtr
+}
+
+func (s testAsyncContextTracker) Resumed(trackingObj interface{}) {
+	*s.groupNamePtr = trackingObj.(string)
+}
+
+func (s testAsyncContextTracker) Completed() {
+	*s.groupNamePtr = ""
+}
+
+func TestAsyncContextTracker(t *testing.T) {
+	r := New()
+	var groupName string
+
+	group := func(name string, asyncFunc func(FunctionCall) Value) Value {
+		prevGroupName := groupName
+		defer func() {
+			groupName = prevGroupName
+		}()
+		groupName = name
+		return asyncFunc(FunctionCall{})
+	}
+	r.SetAsyncContextTracker(testAsyncContextTracker{groupNamePtr: &groupName})
+	r.Set("group", group)
+	r.Set("check", func(expectedGroup, msg string) {
+		if groupName != expectedGroup {
+			t.Fatalf("Unexpected group (%q), expected %q in %s", groupName, expectedGroup, msg)
+		}
+	})
+
+	_, err := r.RunString(`
+		group("1", async () => {
+		  check("1", "line A");
+		  await 3;
+		  check("1", "line B");
+		  group("2", async () => {
+		     check("2", "line C");
+		     await 4;
+		     check("2", "line D");
+		 })
+		}).then(() => {
+            check("", "line E");
+		})
+	`)
+	if err != nil {
+		t.Fatal(err)
+	}
+}