Forráskód Böngészése

The function equality situation (#7441)

* [tests] add test for method/closure equality

* make tests more annoying

* [eval] fix field closure equality

* add type-change test

* more tests

* use Reflect.isFunction & Reflect.compareMethods for targets without function equality support

* don't use Reflect.isFunction if exprs types are TFun

* minor

* more minor refactoring

* [hl] disable type change test as it's another issue

* comment [skip ci]

* additional test

* [jvm] call equals on functions in Reflect.compare

* generate equals when a known function type is on the stack

* neko hates this trick

* add more tests

---------

Co-authored-by: Aleksandr Kuzmenko <[email protected]>
Simon Krajewski 1 éve
szülő
commit
f8d283819b

+ 26 - 3
src/generators/genjvm.ml

@@ -64,6 +64,7 @@ type generation_context = {
 	mutable preprocessor : jsignature preprocessor;
 	default_export_config : export_config;
 	typed_functions : JvmFunctions.typed_functions;
+	known_typed_functions : (path,unit) Hashtbl.t;
 	closure_paths : (path * string * jsignature,path) Hashtbl.t;
 	enum_paths : (path,unit) Hashtbl.t;
 	detail_times : bool;
@@ -433,10 +434,16 @@ let associate_functional_interfaces gctx f t =
 		) gctx.functional_interfaces
 	end
 
+let create_typed_function gctx kind jc jm context =
+	let wf = new JvmFunctions.typed_function gctx.typed_functions kind jc jm context in
+	let jc = wf#get_class in
+	Hashtbl.add gctx.known_typed_functions jc#get_this_path ();
+	wf
+
 let create_field_closure gctx jc path_this jm name jsig t =
 	let jsig_this = object_path_sig path_this in
 	let context = ["this",jsig_this] in
-	let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncMember(path_this,name)) jc jm context in
+	let wf = create_typed_function gctx (FuncMember(path_this,name)) jc jm context in
 	begin match t with
 		| None ->
 			()
@@ -597,7 +604,7 @@ class texpr_to_jvm
 			| RValue(_,Some s) -> Some s
 			| _ -> None
 		in
-		let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncLocal name) jc jm context in
+		let wf = create_typed_function gctx (FuncLocal name) jc jm context in
 		associate_functional_interfaces gctx wf e.etype;
 		let jc_closure = wf#get_class in
 		ignore(wf#generate_constructor (env <> []));
@@ -687,7 +694,7 @@ class texpr_to_jvm
 		let closure_path = try
 			Hashtbl.find gctx.closure_paths (path,name,jsig)
 		with Not_found ->
-			let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncStatic(path,name)) jc jm [] in
+			let wf = create_typed_function gctx (FuncStatic(path,name)) jc jm [] in
 			associate_functional_interfaces gctx wf t;
 			let jc_closure = wf#get_class in
 			ignore(wf#generate_constructor false);
@@ -1041,6 +1048,15 @@ class texpr_to_jvm
 	method get_binop_type t1 t2 = self#get_binop_type_sig (jsignature_of_type gctx t1) (jsignature_of_type gctx t2)
 
 	method do_compare op =
+		let fun_compare path1 sig2 = match sig2 with
+			| TObject(path2,_) when path1 = path2 ->
+				jm#invokevirtual path2 "equals" (method_sig [object_sig] (Some TBool));
+				CmpNormal(op,TBool)
+			| _ ->
+				jm#invokestatic haxe_jvm_path "compare" (method_sig [object_sig;object_sig] (Some TInt));
+				let op = flip_cmp_op op in
+				CmpNormal(op,TBool)
+		in
 		match code#get_stack#get_stack_items 2 with
 		| [TInt | TByte | TChar | TShort | TBool;TInt | TByte | TChar | TShort | TBool] ->
 			let op = flip_cmp_op op in
@@ -1054,6 +1070,11 @@ class texpr_to_jvm
 			jm#invokestatic haxe_jvm_path "compare" (method_sig [object_sig;object_sig] (Some TInt));
 			let op = flip_cmp_op op in
 			CmpNormal(op,TBool)
+		| [sig2;TObject(path1,_)] when Hashtbl.mem gctx.known_typed_functions path1 ->
+			fun_compare path1 sig2
+		| [TObject(path1,_);sig2] when Hashtbl.mem gctx.known_typed_functions path1 ->
+			code#swap;
+			fun_compare path1 sig2
 		| [(TObject _ | TArray _ | TMethod _) as t1;(TObject _ | TArray _ | TMethod _) as t2] ->
 			CmpSpecial ((if op = CmpEq then code#if_acmp_ne else code#if_acmp_eq) t1 t2)
 		| [TDouble;TDouble] ->
@@ -3088,6 +3109,7 @@ let generate jvm_flag com =
 		preprocessor = Obj.magic ();
 		typedef_interfaces = Obj.magic ();
 		typed_functions = new JvmFunctions.typed_functions;
+		known_typed_functions = Hashtbl.create 0;
 		closure_paths = Hashtbl.create 0;
 		enum_paths = Hashtbl.create 0;
 		default_export_config = {
@@ -3099,6 +3121,7 @@ let generate jvm_flag com =
 		dynamic_level = dynamic_level;
 		functional_interfaces = [];
 	} in
+	Hashtbl.add gctx.known_typed_functions haxe_function_path ();
 	gctx.preprocessor <- new preprocessor com.basic (jsignature_of_type gctx);
 	gctx.typedef_interfaces <- new typedef_interfaces gctx.preprocessor#get_infos anon_identification;
 	gctx.typedef_interfaces#add_interface_rewrite (["haxe";"root"],"Iterator") (["java";"util"],"Iterator") true;

+ 1 - 0
src/macro/eval/evalValue.ml

@@ -320,6 +320,7 @@ let rec equals a b = match a,b with
 	| VVector vv1,VVector vv2 -> vv1 == vv2
 	| VFunction(vf1,_),VFunction(vf2,_) -> vf1 == vf2
 	| VPrototype proto1,VPrototype proto2 -> proto1.ppath = proto2.ppath
+	| VFieldClosure(v1,f1),VFieldClosure(v2,f2) -> f1 == f2 && equals v1 v2
 	| VNativeString s1,VNativeString s2 -> s1 = s2
 	| VHandle h1,VHandle h2 -> same_handle h1 h2
 	| VLazy f1,_ -> equals (!f1()) b

+ 14 - 9
std/jvm/_std/Reflect.hx

@@ -21,7 +21,6 @@
  */
 
 import jvm.Jvm;
-
 import java.lang.Number;
 import java.lang.Long.LongClass;
 import java.lang.Double.DoubleClass;
@@ -110,24 +109,24 @@ class Reflect {
 			return 1;
 		}
 		if (Jvm.instanceof(a, Number) && Jvm.instanceof(b, Number)) {
-			var a = (cast a:Number);
-			var b = (cast b:Number);
+			var a = (cast a : Number);
+			var b = (cast b : Number);
 			inline function isBig(v:Number)
 				return Jvm.instanceof(v, BigDecimal) || Jvm.instanceof(v, BigInteger);
 			inline function cmpLongTo(long:Number, another:Number) {
-				if(Jvm.instanceof(another, DoubleClass)) {
+				if (Jvm.instanceof(another, DoubleClass)) {
 					return new BigDecimal(long.longValue()).compareTo(new BigDecimal(another.doubleValue()));
-				} else if(Jvm.instanceof(another, FloatClass)) {
+				} else if (Jvm.instanceof(another, FloatClass)) {
 					return new BigDecimal(long.longValue()).compareTo(new BigDecimal(another.floatValue()));
 				} else {
 					return LongClass.compare(long.longValue(), another.longValue());
 				}
 			}
-			if(isBig(a) || isBig(b))
-				return new BigDecimal((cast a:java.lang.Object).toString()).compareTo((cast a:java.lang.Object).toString());
-			if(Jvm.instanceof(a, LongClass))
+			if (isBig(a) || isBig(b))
+				return new BigDecimal((cast a : java.lang.Object).toString()).compareTo((cast a : java.lang.Object).toString());
+			if (Jvm.instanceof(a, LongClass))
 				return cmpLongTo(a, b);
-			if(Jvm.instanceof(b, LongClass))
+			if (Jvm.instanceof(b, LongClass))
 				return -1 * cmpLongTo(b, a);
 			return DoubleClass.compare(a.doubleValue(), b.doubleValue());
 		}
@@ -137,6 +136,12 @@ class Reflect {
 			}
 			return (cast a : java.NativeString).compareTo(cast b);
 		}
+		if (Jvm.instanceof(a, jvm.Function)) {
+			if (!(cast a : jvm.Function).equals(cast b)) {
+				return -1;
+			}
+			return 0;
+		}
 		return -1;
 	}
 

+ 268 - 0
tests/unit/src/unit/issues/Issue6705.hx

@@ -0,0 +1,268 @@
+package unit.issues;
+
+class Issue6705 extends unit.Test {
+	function memberFunction() {}
+
+	static function staticFunction() {}
+
+	function memberFunction1(i:Int) {}
+
+	static function staticFunction1(i:Int) {}
+
+	@:pure(false) static function alias<T>(t:T)
+		return t;
+
+	@:pure(false) static function equalsT<T>(a:T, b:T)
+		return a == b;
+
+	function test() {
+		function localFunction() {}
+
+		var localClosure = localFunction;
+		var memberClosure = memberFunction;
+		var staticClosure = staticFunction;
+
+		t(localFunction == alias(localFunction));
+		t(localFunction == alias(localClosure));
+		#if !neko
+		t(memberFunction == alias(memberFunction));
+		t(memberFunction == alias(memberClosure));
+		#end
+		t(staticFunction == alias(staticFunction));
+		t(staticFunction == alias(staticClosure));
+		t(localClosure == alias(localClosure));
+		t(memberClosure == alias(memberClosure));
+		t(staticClosure == alias(staticClosure));
+		t(localFunction == alias(localFunction));
+
+		t(equalsT(localFunction, localClosure));
+		#if !neko
+		t(equalsT(memberFunction, memberFunction));
+		t(equalsT(memberFunction, memberClosure));
+		#end
+		t(equalsT(staticFunction, staticFunction));
+		t(equalsT(staticFunction, staticClosure));
+		t(equalsT(localClosure, localClosure));
+		t(equalsT(memberClosure, memberClosure));
+		t(equalsT(staticClosure, staticClosure));
+
+		t(Reflect.compareMethods(localFunction, alias(localFunction)));
+		t(Reflect.compareMethods(localFunction, alias(localClosure)));
+		t(Reflect.compareMethods(memberFunction, alias(memberFunction)));
+		t(Reflect.compareMethods(memberFunction, alias(memberClosure)));
+		t(Reflect.compareMethods(staticFunction, alias(staticFunction)));
+		t(Reflect.compareMethods(staticFunction, alias(staticClosure)));
+		t(Reflect.compareMethods(localClosure, alias(localClosure)));
+		t(Reflect.compareMethods(memberClosure, alias(memberClosure)));
+		t(Reflect.compareMethods(staticClosure, alias(staticClosure)));
+
+		var array = [localFunction, memberFunction, staticFunction];
+		eq(0, array.indexOf(localFunction));
+		#if !neko
+		eq(1, array.indexOf(memberFunction));
+		#end
+		eq(2, array.indexOf(staticFunction));
+	}
+
+	function testButEverythingIsDynamic() {
+		function localFunction() {}
+
+		var localClosure:Dynamic = localFunction;
+		var memberClosure:Dynamic = memberFunction;
+		var staticClosure:Dynamic = staticFunction;
+
+		t((localFunction : Dynamic) == alias((localFunction : Dynamic)));
+		t((localFunction : Dynamic) == alias(localClosure));
+		#if !neko
+		t((memberFunction : Dynamic) == alias((memberFunction : Dynamic)));
+		t((memberFunction : Dynamic) == alias(memberClosure));
+		#end
+		t((staticFunction : Dynamic) == alias((staticFunction : Dynamic)));
+		t((staticFunction : Dynamic) == alias(staticClosure));
+		t(localClosure == alias(localClosure));
+		t(memberClosure == alias(memberClosure));
+		t(staticClosure == alias(staticClosure));
+		t((localFunction : Dynamic) == alias((localFunction : Dynamic)));
+
+		t(equalsT((localFunction : Dynamic), localClosure));
+		#if !neko
+		t(equalsT((memberFunction : Dynamic), (memberFunction : Dynamic)));
+		t(equalsT((memberFunction : Dynamic), memberClosure));
+		#end
+		t(equalsT((staticFunction : Dynamic), (staticFunction : Dynamic)));
+		t(equalsT((staticFunction : Dynamic), staticClosure));
+		t(equalsT(localClosure, localClosure));
+		t(equalsT(memberClosure, memberClosure));
+		t(equalsT(staticClosure, staticClosure));
+
+		t(Reflect.compareMethods((localFunction : Dynamic), alias((localFunction : Dynamic))));
+		t(Reflect.compareMethods((localFunction : Dynamic), alias(localClosure)));
+		t(Reflect.compareMethods((memberFunction : Dynamic), alias((memberFunction : Dynamic))));
+		t(Reflect.compareMethods((memberFunction : Dynamic), alias(memberClosure)));
+		t(Reflect.compareMethods((staticFunction : Dynamic), alias((staticFunction : Dynamic))));
+		t(Reflect.compareMethods((staticFunction : Dynamic), alias(staticClosure)));
+		t(Reflect.compareMethods(localClosure, alias(localClosure)));
+		t(Reflect.compareMethods(memberClosure, alias(memberClosure)));
+		t(Reflect.compareMethods(staticClosure, alias(staticClosure)));
+
+		var array = [(localFunction : Dynamic), (memberFunction : Dynamic), (staticFunction : Dynamic)];
+		eq(0, array.indexOf((localFunction : Dynamic)));
+		#if !neko
+		eq(1, array.indexOf((memberFunction : Dynamic)));
+		#end
+		eq(2, array.indexOf((staticFunction : Dynamic)));
+	}
+
+	function testButEverythingIsBackwards() {
+		function localFunction() {}
+
+		var localClosure = localFunction;
+		var memberClosure = memberFunction;
+		var staticClosure = staticFunction;
+
+		t(alias(localFunction) == localFunction);
+		t(alias(localClosure) == localFunction);
+		#if !neko
+		t(alias(memberFunction) == memberFunction);
+		t(alias(memberClosure) == memberFunction);
+		#end
+		t(alias(staticFunction) == staticFunction);
+		t(alias(staticClosure) == staticFunction);
+		t(alias(localClosure) == localClosure);
+		t(alias(memberClosure) == memberClosure);
+		t(alias(staticClosure) == staticClosure);
+		t(alias(localFunction) == localFunction);
+
+		t(equalsT(localClosure, localFunction));
+		#if !neko
+		t(equalsT(memberFunction, memberFunction));
+		t(equalsT(memberClosure, memberFunction));
+		#end
+		t(equalsT(staticFunction, staticFunction));
+		t(equalsT(staticClosure, staticFunction));
+		t(equalsT(localClosure, localClosure));
+		t(equalsT(memberClosure, memberClosure));
+		t(equalsT(staticClosure, staticClosure));
+
+		t(Reflect.compareMethods(alias(localFunction), localFunction));
+		t(Reflect.compareMethods(alias(localClosure), localFunction));
+		t(Reflect.compareMethods(alias(memberFunction), memberFunction));
+		t(Reflect.compareMethods(alias(memberClosure), memberFunction));
+		t(Reflect.compareMethods(alias(staticFunction), staticFunction));
+		t(Reflect.compareMethods(alias(staticClosure), staticFunction));
+		t(Reflect.compareMethods(alias(localClosure), localClosure));
+		t(Reflect.compareMethods(alias(memberClosure), memberClosure));
+		t(Reflect.compareMethods(alias(staticClosure), staticClosure));
+	}
+
+	function testButEverythingIsBackwardsAndDynamic() {
+		function localFunction() {}
+
+		var localClosure:Dynamic = localFunction;
+		var memberClosure:Dynamic = memberFunction;
+		var staticClosure:Dynamic = staticFunction;
+
+		t(alias((localFunction : Dynamic)) == (localFunction : Dynamic));
+		t(alias(localClosure) == (localFunction : Dynamic));
+		#if !neko
+		t(alias((memberFunction : Dynamic)) == (memberFunction : Dynamic));
+		t(alias(memberClosure) == (memberFunction : Dynamic));
+		#end
+		t(alias((staticFunction : Dynamic)) == (staticFunction : Dynamic));
+		t(alias(staticClosure) == (staticFunction : Dynamic));
+		t(alias(localClosure) == localClosure);
+		t(alias(memberClosure) == memberClosure);
+		t(alias(staticClosure) == staticClosure);
+		t(alias((localFunction : Dynamic)) == (localFunction : Dynamic));
+
+		t(equalsT(localClosure, (localFunction : Dynamic)));
+		#if !neko
+		t(equalsT((memberFunction : Dynamic), (memberFunction : Dynamic)));
+		t(equalsT(memberClosure, (memberFunction : Dynamic)));
+		#end
+		t(equalsT((staticFunction : Dynamic), (staticFunction : Dynamic)));
+		t(equalsT(staticClosure, (staticFunction : Dynamic)));
+		t(equalsT(localClosure, localClosure));
+		t(equalsT(memberClosure, memberClosure));
+		t(equalsT(staticClosure, staticClosure));
+
+		t(Reflect.compareMethods(alias((localFunction : Dynamic)), (localFunction : Dynamic)));
+		t(Reflect.compareMethods(alias(localClosure), (localFunction : Dynamic)));
+		t(Reflect.compareMethods(alias((memberFunction : Dynamic)), (memberFunction : Dynamic)));
+		t(Reflect.compareMethods(alias(memberClosure), (memberFunction : Dynamic)));
+		t(Reflect.compareMethods(alias((staticFunction : Dynamic)), (staticFunction : Dynamic)));
+		t(Reflect.compareMethods(alias(staticClosure), (staticFunction : Dynamic)));
+		t(Reflect.compareMethods(alias(localClosure), localClosure));
+		t(Reflect.compareMethods(alias(memberClosure), memberClosure));
+		t(Reflect.compareMethods(alias(staticClosure), staticClosure));
+	}
+
+	function test1() {
+		function localFunction1(i:Int) {}
+
+		var localClosure1 = localFunction1;
+		var memberClosure1 = memberFunction1;
+		var staticClosure1 = staticFunction1;
+
+		t(localFunction1 == alias(localFunction1));
+		t(localFunction1 == alias(localClosure1));
+		#if !neko
+		t(memberFunction1 == alias(memberFunction1));
+		t(memberFunction1 == alias(memberClosure1));
+		#end
+		t(staticFunction1 == alias(staticFunction1));
+		t(staticFunction1 == alias(staticClosure1));
+		t(localClosure1 == alias(localClosure1));
+		t(memberClosure1 == alias(memberClosure1));
+		t(staticClosure1 == alias(staticClosure1));
+
+		t(equalsT(localFunction1, localFunction1));
+		t(equalsT(localFunction1, localClosure1));
+		#if !neko
+		t(equalsT(memberFunction1, memberFunction1));
+		t(equalsT(memberFunction1, memberClosure1));
+		#end
+		t(equalsT(staticFunction1, staticFunction1));
+		t(equalsT(staticFunction1, staticClosure1));
+		t(equalsT(localClosure1, localClosure1));
+		t(equalsT(memberClosure1, memberClosure1));
+		t(equalsT(staticClosure1, staticClosure1));
+
+		t(Reflect.compareMethods(localFunction1, alias(localFunction1)));
+		t(Reflect.compareMethods(localFunction1, alias(localClosure1)));
+		t(Reflect.compareMethods(memberFunction1, alias(memberFunction1)));
+		t(Reflect.compareMethods(memberFunction1, alias(memberClosure1)));
+		t(Reflect.compareMethods(staticFunction1, alias(staticFunction1)));
+		t(Reflect.compareMethods(staticFunction1, alias(staticClosure1)));
+		t(Reflect.compareMethods(localClosure1, alias(localClosure1)));
+		t(Reflect.compareMethods(memberClosure1, alias(memberClosure1)));
+		t(Reflect.compareMethods(staticClosure1, alias(staticClosure1)));
+
+		var array = [localFunction1, memberFunction1, staticFunction1];
+		eq(0, array.indexOf(localFunction1));
+		#if !neko
+		eq(1, array.indexOf(memberFunction1));
+		#end
+		eq(2, array.indexOf(staticFunction1));
+	}
+
+	#if !neko
+	function testCallsEqualityCheck_tempvarCallExprs() {
+		var callCount = 0;
+		function getFn():() -> Void {
+			callCount++;
+			return memberFunction;
+		}
+		t(getFn() == getFn());
+		eq(2, callCount);
+	}
+	#end
+
+	#if !hl // @see https://github.com/HaxeFoundation/haxe/issues/10031
+	function testTypeChange() {
+		function f1(x:Float) {}
+		var f2:Int->Void = f1;
+		t(f1 == f2);
+	}
+	#end
+}