Browse Source

Tolerate mismatching number of arguments when calling Go from JS

Dmitry Panov 8 years ago
parent
commit
033ece8aec
2 changed files with 111 additions and 15 deletions
  1. 24 15
      runtime.go
  2. 87 0
      runtime_test.go

+ 24 - 15
runtime.go

@@ -984,18 +984,25 @@ func (r *Runtime) wrapReflectFunc(value reflect.Value) func(FunctionCall) Value
 	return func(call FunctionCall) Value {
 		typ := value.Type()
 		nargs := typ.NumIn()
-		if len(call.Arguments) != nargs {
+		var in []reflect.Value
+
+		if l := len(call.Arguments); l < nargs {
+			// fill missing arguments with zero values
+			n := nargs
 			if typ.IsVariadic() {
-				if len(call.Arguments) < nargs-1 {
-					panic(r.newError(r.global.TypeError, "expected at least %d arguments; got %d", nargs-1, len(call.Arguments)))
-				}
-			} else {
-				panic(r.newError(r.global.TypeError, "expected %d argument(s); got %d", nargs, len(call.Arguments)))
+				n--
+			}
+			in = make([]reflect.Value, n)
+			for i := l; i < n; i++ {
+				in[i] = reflect.Zero(typ.In(i))
+			}
+		} else {
+			if l > nargs && !typ.IsVariadic() {
+				l = nargs
 			}
+			in = make([]reflect.Value, l)
 		}
 
-		in := make([]reflect.Value, len(call.Arguments))
-
 		callSlice := false
 		for i, a := range call.Arguments {
 			var t reflect.Type
@@ -1007,6 +1014,8 @@ func (r *Runtime) wrapReflectFunc(value reflect.Value) func(FunctionCall) Value
 				}
 
 				t = typ.In(n).Elem()
+			} else if n > nargs-1 { // ignore extra arguments
+				break
 			} else {
 				t = typ.In(n)
 			}
@@ -1106,13 +1115,6 @@ func (r *Runtime) toReflectValue(v Value, typ reflect.Type) (reflect.Value, erro
 		return reflect.ValueOf(uint8(i)).Convert(typ), nil
 	}
 
-	t := reflect.TypeOf(v)
-	if t.AssignableTo(typ) {
-		return reflect.ValueOf(v), nil
-	} else if t.ConvertibleTo(typ) {
-		return reflect.ValueOf(v).Convert(typ), nil
-	}
-
 	if typ == typeCallable {
 		if fn, ok := AssertFunction(v); ok {
 			return reflect.ValueOf(fn), nil
@@ -1129,6 +1131,13 @@ func (r *Runtime) toReflectValue(v Value, typ reflect.Type) (reflect.Value, erro
 		return reflect.ValueOf(v.Export()).Convert(typ), nil
 	}
 
+	t := reflect.TypeOf(v)
+	if t.AssignableTo(typ) {
+		return reflect.ValueOf(v), nil
+	} else if t.ConvertibleTo(typ) {
+		return reflect.ValueOf(v).Convert(typ), nil
+	}
+
 	switch typ.Kind() {
 	case reflect.Slice:
 		if o, ok := v.(*Object); ok {

+ 87 - 0
runtime_test.go

@@ -2,6 +2,7 @@ package goja
 
 import (
 	"errors"
+	"fmt"
 	"reflect"
 	"testing"
 	"time"
@@ -806,6 +807,92 @@ func TestObjectKeys(t *testing.T) {
 	}
 }
 
+func TestReflectCallExtraArgs(t *testing.T) {
+	const SCRIPT = `
+	f(41, "extra")
+	`
+	f := func(x int) int {
+		return x + 1
+	}
+
+	vm := New()
+	vm.Set("f", f)
+
+	prg := MustCompile("test.js", SCRIPT, false)
+
+	res, err := vm.RunProgram(prg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !res.StrictEquals(intToValue(42)) {
+		t.Fatalf("Unexpected result: %v", res)
+	}
+}
+
+func TestReflectCallNotEnoughArgs(t *testing.T) {
+	const SCRIPT = `
+	f(42)
+	`
+	vm := New()
+
+	f := func(x, y int, z *int, s string) (int, error) {
+		if z != nil {
+			return 0, fmt.Errorf("z is not nil")
+		}
+		if s != "" {
+			return 0, fmt.Errorf("s is not \"\"")
+		}
+		return x + y, nil
+	}
+
+	vm.Set("f", f)
+
+	prg := MustCompile("test.js", SCRIPT, false)
+
+	res, err := vm.RunProgram(prg)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !res.StrictEquals(intToValue(42)) {
+		t.Fatalf("Unexpected result: %v", res)
+	}
+}
+
+func TestReflectCallVariadic(t *testing.T) {
+	const SCRIPT = `
+	var r = f("Hello %s, %d", "test", 42);
+	if (r !== "Hello test, 42") {
+		throw new Error("test 1 has failed: " + r);
+	}
+
+	r = f("Hello %s, %d", ["test", 42]);
+	if (r !== "Hello test, 42") {
+		throw new Error("test 2 has failed: " + r);
+	}
+
+	r = f("Hello %s, %s", "test");
+	if (r !== "Hello test, %!s(MISSING)") {
+		throw new Error("test 3 has failed: " + r);
+	}
+
+	r = f();
+	if (r !== "") {
+		throw new Error("test 4 has failed: " + r);
+	}
+
+	`
+
+	vm := New()
+	vm.Set("f", fmt.Sprintf)
+
+	prg := MustCompile("test.js", SCRIPT, false)
+
+	_, err := vm.RunProgram(prg)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
 /*
 func TestArrayConcatSparse(t *testing.T) {
 function foo(a,b,c)