소스 검색

Added type cache, support for embedded structs, and custom field name mapping

Dmitry Panov 9 년 전
부모
커밋
c838446be9
3개의 변경된 파일326개의 추가작업 그리고 21개의 파일을 삭제
  1. 158 20
      object_goreflect.go
  2. 165 1
      object_goreflect_test.go
  3. 3 0
      runtime.go

+ 158 - 20
object_goreflect.go

@@ -6,9 +6,33 @@ import (
 	"reflect"
 )
 
+// FieldNameMapper provides custom mapping between Go and JavaScript property names.
+type FieldNameMapper interface {
+	// FieldName returns a JavaScript name for the given struct field in the given type.
+	// If this method returns "" the field becomes hidden.
+	FieldName(t reflect.Type, f reflect.StructField) string
+
+	// FieldName returns a JavaScript name for the given method in the given type.
+	// If this method returns "" the method becomes hidden.
+	MethodName(t reflect.Type, m reflect.Method) string
+}
+
+type reflectFieldInfo struct {
+	Index     []int
+	Anonymous bool
+}
+
+type reflectTypeInfo struct {
+	Fields                  map[string]reflectFieldInfo
+	Methods                 map[string]int
+	FieldNames, MethodNames []string
+}
+
 type objectGoReflect struct {
 	baseObject
 	origValue, value reflect.Value
+
+	valueTypeInfo, origValueTypeInfo *reflectTypeInfo
 }
 
 func (o *objectGoReflect) init() {
@@ -33,6 +57,9 @@ func (o *objectGoReflect) init() {
 
 	o.baseObject._putProp("toString", o.val.runtime.newNativeFunc(o.toStringFunc, nil, "toString", nil, 0), true, false, true)
 	o.baseObject._putProp("valueOf", o.val.runtime.newNativeFunc(o.valueOfFunc, nil, "valueOf", nil, 0), true, false, true)
+
+	o.valueTypeInfo = o.val.runtime.typeInfo(o.value.Type())
+	o.origValueTypeInfo = o.val.runtime.typeInfo(o.origValue.Type())
 }
 
 func (o *objectGoReflect) toStringFunc(call FunctionCall) Value {
@@ -47,16 +74,37 @@ func (o *objectGoReflect) get(n Value) Value {
 	return o.getStr(n.String())
 }
 
+func (o *objectGoReflect) _getField(jsName string) reflect.Value {
+	if info, exists := o.valueTypeInfo.Fields[jsName]; exists {
+		v := o.value.FieldByIndex(info.Index)
+		if info.Anonymous {
+			v = v.Addr()
+		}
+		return v
+	}
+
+	return reflect.Value{}
+}
+
+func (o *objectGoReflect) _getMethod(jsName string) reflect.Value {
+	if idx, exists := o.origValueTypeInfo.Methods[jsName]; exists {
+		return o.origValue.Method(idx)
+	}
+
+	return reflect.Value{}
+}
+
 func (o *objectGoReflect) _get(name string) Value {
 	if o.value.Kind() == reflect.Struct {
-		if v := o.value.FieldByName(name); v.IsValid() {
+		if v := o._getField(name); v.IsValid() {
 			return o.val.runtime.ToValue(v.Interface())
 		}
 	}
 
-	if v := o.origValue.MethodByName(name); v.IsValid() {
+	if v := o._getMethod(name); v.IsValid() {
 		return o.val.runtime.ToValue(v.Interface())
 	}
+
 	return nil
 }
 
@@ -84,7 +132,7 @@ func (o *objectGoReflect) getPropStr(name string) Value {
 
 func (o *objectGoReflect) getOwnProp(name string) Value {
 	if o.value.Kind() == reflect.Struct {
-		if v := o.value.FieldByName(name); v.IsValid() {
+		if v := o._getField(name); v.IsValid() {
 			return &valueProperty{
 				value:      o.val.runtime.ToValue(v.Interface()),
 				writable:   true,
@@ -93,7 +141,7 @@ func (o *objectGoReflect) getOwnProp(name string) Value {
 		}
 	}
 
-	if v := o.origValue.MethodByName(name); v.IsValid() {
+	if v := o._getMethod(name); v.IsValid() {
 		return &valueProperty{
 			value:      o.val.runtime.ToValue(v.Interface()),
 			enumerable: true,
@@ -115,7 +163,7 @@ func (o *objectGoReflect) putStr(name string, val Value, throw bool) {
 
 func (o *objectGoReflect) _put(name string, val Value, throw bool) bool {
 	if o.value.Kind() == reflect.Struct {
-		if v := o.value.FieldByName(name); v.IsValid() {
+		if v := o._getField(name); v.IsValid() {
 			vv, err := o.val.runtime.toReflectValue(val, v.Type())
 			if err != nil {
 				o.val.runtime.typeErrorResult(throw, "Go struct conversion error: %v", err)
@@ -155,7 +203,7 @@ func (o *objectGoReflect) defineOwnProperty(n Value, descr objectImpl, throw boo
 	name := n.String()
 	if ast.IsExported(name) {
 		if o.value.Kind() == reflect.Struct {
-			if v := o.value.FieldByName(name); v.IsValid() {
+			if v := o._getField(name); v.IsValid() {
 				if !o.val.runtime.checkHostObjectPropertyDescr(name, descr, throw) {
 					return false
 				}
@@ -182,11 +230,11 @@ func (o *objectGoReflect) _has(name string) bool {
 		return false
 	}
 	if o.value.Kind() == reflect.Struct {
-		if v := o.value.FieldByName(name); v.IsValid() {
+		if v := o._getField(name); v.IsValid() {
 			return true
 		}
 	}
-	if v := o.origValue.MethodByName(name); v.IsValid() {
+	if v := o._getMethod(name); v.IsValid() {
 		return true
 	}
 	return false
@@ -291,13 +339,11 @@ type goreflectPropIter struct {
 }
 
 func (i *goreflectPropIter) nextField() (propIterItem, iterNextFunc) {
-	l := i.o.value.NumField()
-	for i.idx < l {
-		name := i.o.value.Type().Field(i.idx).Name
+	names := i.o.valueTypeInfo.FieldNames
+	if i.idx < len(names) {
+		name := names[i.idx]
 		i.idx++
-		if ast.IsExported(name) {
-			return propIterItem{name: name, enumerable: _ENUM_TRUE}, i.nextField
-		}
+		return propIterItem{name: name, enumerable: _ENUM_TRUE}, i.nextField
 	}
 
 	i.idx = 0
@@ -305,13 +351,11 @@ func (i *goreflectPropIter) nextField() (propIterItem, iterNextFunc) {
 }
 
 func (i *goreflectPropIter) nextMethod() (propIterItem, iterNextFunc) {
-	l := i.o.origValue.NumMethod()
-	for i.idx < l {
-		name := i.o.origValue.Type().Method(i.idx).Name
+	names := i.o.origValueTypeInfo.MethodNames
+	if i.idx < len(names) {
+		name := names[i.idx]
 		i.idx++
-		if ast.IsExported(name) {
-			return propIterItem{name: name, enumerable: _ENUM_TRUE}, i.nextMethod
-		}
+		return propIterItem{name: name, enumerable: _ENUM_TRUE}, i.nextMethod
 	}
 
 	if i.recursive {
@@ -354,3 +398,97 @@ func (o *objectGoReflect) equal(other objectImpl) bool {
 	}
 	return false
 }
+
+func (r *Runtime) buildFieldInfo(t reflect.Type, index []int, info *reflectTypeInfo) {
+	n := t.NumField()
+	for i := 0; i < n; i++ {
+		field := t.Field(i)
+		var name string
+		if r.fieldNameMapper == nil {
+			name = field.Name
+			if !ast.IsExported(name) {
+				continue
+			}
+		} else {
+			name = r.fieldNameMapper.FieldName(t, field)
+			if name == "" {
+				continue
+			}
+		}
+
+		idx := make([]int, 0, len(index)+1)
+		idx = append(idx, index...)
+		idx = append(idx, i)
+		if _, exists := info.Fields[name]; !exists {
+			info.FieldNames = append(info.FieldNames, name)
+		}
+		info.Fields[name] = reflectFieldInfo{
+			Index:     idx,
+			Anonymous: field.Anonymous,
+		}
+		if field.Anonymous {
+			idx := make([]int, 0, len(index)+1)
+			idx = append(idx, index...)
+			idx = append(idx, i)
+			r.buildFieldInfo(field.Type, idx, info)
+		}
+	}
+}
+
+func (r *Runtime) buildTypeInfo(t reflect.Type) (info *reflectTypeInfo) {
+	info = new(reflectTypeInfo)
+	if t.Kind() == reflect.Struct {
+		info.Fields = make(map[string]reflectFieldInfo)
+		n := t.NumField()
+		info.FieldNames = make([]string, 0, n)
+		r.buildFieldInfo(t, nil, info)
+	}
+
+	info.Methods = make(map[string]int)
+	n := t.NumMethod()
+	info.MethodNames = make([]string, 0, n)
+	for i := 0; i < n; i++ {
+		method := t.Method(i)
+		var name string
+		if r.fieldNameMapper == nil {
+			name = method.Name
+			if !ast.IsExported(name) {
+				continue
+			}
+		} else {
+			name = r.fieldNameMapper.MethodName(t, method)
+			if name == "" {
+				continue
+			}
+		}
+
+		if _, exists := info.Methods[name]; !exists {
+			info.MethodNames = append(info.MethodNames, name)
+		}
+
+		info.Methods[name] = i
+	}
+	return
+}
+
+func (r *Runtime) typeInfo(t reflect.Type) (info *reflectTypeInfo) {
+	var exists bool
+	if info, exists = r.typeInfoCache[t]; !exists {
+		info = r.buildTypeInfo(t)
+		if r.typeInfoCache == nil {
+			r.typeInfoCache = make(map[reflect.Type]*reflectTypeInfo)
+		}
+		r.typeInfoCache[t] = info
+	}
+
+	return
+}
+
+// Sets a custom field name mapper for Go types. It can be called at any time, however
+// the mapping for any given value is fixed at the point of creation.
+// Setting this to nil restores the default behaviour which is all exported fields and methods are mapped to their
+// original unchanged names.
+func (r *Runtime) SetFieldNameMapper(mapper FieldNameMapper) {
+	r.fieldNameMapper = mapper
+	r.typeInfoCache = nil
+}

+ 165 - 1
object_goreflect_test.go

@@ -1,6 +1,9 @@
 package goja
 
-import "testing"
+import (
+	"reflect"
+	"testing"
+)
 
 func TestGoReflectGet(t *testing.T) {
 	const SCRIPT = `
@@ -397,3 +400,164 @@ func TestGoReflectRedefineMethod(t *testing.T) {
 		t.Fatalf("Expected true, got %v", v)
 	}
 }
+
+func TestGoReflectEmbeddedStruct(t *testing.T) {
+	const SCRIPT = `
+	if (o.ParentField2 !== "ParentField2") {
+		throw new Error("ParentField2 = " + o.ParentField2);
+	}
+
+	if (o.Parent.ParentField2 !== 2) {
+		throw new Error("o.Parent.ParentField2 = " + o.Parent.ParentField2);
+	}
+
+	if (o.ParentField1 !== 1) {
+		throw new Error("o.ParentField1 = " + o.ParentField1);
+
+	}
+
+	if (o.ChildField !== 3) {
+		throw new Error("o.ChildField = " + o.ChildField);
+	}
+
+	var keys = {};
+	for (var k in o) {
+		if (keys[k]) {
+			throw new Error("Duplicate key: " + k);
+		}
+		keys[k] = true;
+	}
+
+	var expectedKeys = ["ParentField2", "ParentField1", "Parent", "ChildField"];
+	for (var i in expectedKeys) {
+		if (!keys[expectedKeys[i]]) {
+			throw new Error("Missing key in enumeration: " + expectedKeys[i]);
+		}
+		delete keys[expectedKeys[i]];
+	}
+
+	var remainingKeys = Object.keys(keys);
+	if (remainingKeys.length > 0) {
+		throw new Error("Unexpected keys: " + remainingKeys);
+	}
+
+	o.ParentField2 = "ParentField22";
+	o.Parent.ParentField2 = 22;
+	o.ParentField1 = 11;
+	o.ChildField = 33;
+	`
+
+	type Parent struct {
+		ParentField1 int
+		ParentField2 int
+	}
+
+	type Child struct {
+		Parent
+		ParentField2 string
+		ChildField   int
+	}
+
+	vm := New()
+	o := Child{
+		Parent: Parent{
+			ParentField1: 1,
+			ParentField2: 2,
+		},
+		ParentField2: "ParentField2",
+		ChildField:   3,
+	}
+	vm.Set("o", &o)
+
+	_, err := vm.RunString(SCRIPT)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if o.ParentField2 != "ParentField22" {
+		t.Fatalf("ParentField2 = %q", o.ParentField2)
+	}
+
+	if o.Parent.ParentField2 != 22 {
+		t.Fatalf("Parent.ParentField2 = %d", o.Parent.ParentField2)
+	}
+
+	if o.ParentField1 != 11 {
+		t.Fatalf("ParentField1 = %d", o.ParentField1)
+	}
+
+	if o.ChildField != 33 {
+		t.Fatalf("ChildField = %d", o.ChildField)
+	}
+}
+
+type jsonTagNamer struct{}
+
+func (*jsonTagNamer) FieldName(t reflect.Type, field reflect.StructField) string {
+	if jsonTag := field.Tag.Get("json"); jsonTag != "" {
+		return jsonTag
+	}
+	return field.Name
+}
+
+func (*jsonTagNamer) MethodName(t reflect.Type, method reflect.Method) string {
+	return method.Name
+}
+
+func TestGoReflectCustomNaming(t *testing.T) {
+
+	type testStructWithJsonTags struct {
+		A string `json:"b"` // <-- script sees field "A" as property "b"
+	}
+
+	o := &testStructWithJsonTags{"Hello world"}
+	r := New()
+	r.SetFieldNameMapper(&jsonTagNamer{})
+	r.Set("fn", func() *testStructWithJsonTags { return o })
+
+	t.Run("get property", func(t *testing.T) {
+		v, err := r.RunString(`fn().b`)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !v.StrictEquals(newStringValue(o.A)) {
+			t.Fatalf("Expected %q, got %v", o.A, v)
+		}
+	})
+
+	t.Run("set property", func(t *testing.T) {
+		_, err := r.RunString(`fn().b = "Hello universe"`)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if o.A != "Hello universe" {
+			t.Fatalf("Expected \"Hello universe\", got %q", o.A)
+		}
+	})
+
+	t.Run("enumerate properties", func(t *testing.T) {
+		v, err := r.RunString(`Object.keys(fn())`)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !reflect.DeepEqual(v.Export(), []interface{}{"b"}) {
+			t.Fatalf("Expected [\"b\"], got %v", v.Export())
+		}
+	})
+}
+
+type testGoReflectMethod_Bench struct {
+	field                                   string
+	Test1, Test2, Test3, Test4, Test5, Test string
+}
+
+func BenchmarkGoReflectGet(b *testing.B) {
+	b.StopTimer()
+	vm := New()
+
+	b.StartTimer()
+	for i := 0; i < b.N; i++ {
+		v := vm.ToValue(testGoReflectMethod_O{Test: "Test"}).(*Object)
+		v.Get("Test")
+	}
+}

+ 3 - 0
runtime.go

@@ -73,6 +73,9 @@ type Runtime struct {
 	stringSingleton *stringObject
 	rand            RandSource
 
+	typeInfoCache   map[reflect.Type]*reflectTypeInfo
+	fieldNameMapper FieldNameMapper
+
 	vm *vm
 }