Browse Source

implement EnumValueMap.compareArg properly (#12139)

Simon Krajewski 4 months ago
parent
commit
6d9ef698f2

+ 0 - 21
src/context/abstractCast.ml

@@ -204,27 +204,6 @@ let find_multitype_specialization' platform a pl p =
 	let uctx = default_unification_context () in
 	let m = mk_mono() in
 	let tl,definitive_types = Abstract.find_multitype_params a pl in
-	if platform = Globals.Js && a.a_path = (["haxe";"ds"],"Map") then begin match tl with
-		| t1 :: _ ->
-			let stack = ref [] in
-			let rec loop t =
-				if List.exists (fun t2 -> fast_eq t t2) !stack then
-					t
-				else begin
-					stack := t :: !stack;
-					match follow t with
-					| TAbstract ({ a_path = [],"Class" },_) ->
-						raise_typing_error (Printf.sprintf "Cannot use %s as key type to Map because Class<T> is not comparable on JavaScript" (s_type (print_context()) t1)) p;
-					| TEnum(en,tl) ->
-						PMap.iter (fun _ ef -> ignore(loop ef.ef_type)) en.e_constrs;
-						Type.map loop t
-					| t ->
-						Type.map loop t
-				end
-			in
-			ignore(loop t1)
-		| _ -> die "" __LOC__
-	end;
 	let _,cf =
 		try
 			let t = Abstract.find_to uctx m a tl in

+ 36 - 6
std/haxe/ds/EnumValueMap.hx

@@ -29,6 +29,9 @@ package haxe.ds;
 	parameter is not an enum value, `Reflect.compare` is used to compare them.
 **/
 class EnumValueMap<K:EnumValue, V> extends haxe.ds.BalancedTree<K, V> implements haxe.Constraints.IMap<K, V> {
+	var sortIndices = new Map<{}, Int>();
+	var sortIndicesCounter = 0;
+
 	override function compare(k1:EnumValue, k2:EnumValue):Int {
 		var d = k1.getIndex() - k2.getIndex();
 		if (d != 0)
@@ -66,13 +69,40 @@ class EnumValueMap<K:EnumValue, V> extends haxe.ds.BalancedTree<K, V> implements
 		return 0;
 	}
 
+	function getSortIndex(v:{}) {
+		if (sortIndices.exists(v)) {
+			return sortIndices[(v)];
+		}
+		var sort = sortIndicesCounter++;
+		sortIndices[v] = sort;
+		return sort;
+	}
+
 	function compareArg(v1:Dynamic, v2:Dynamic):Int {
-		return if (Reflect.isEnumValue(v1) && Reflect.isEnumValue(v2)) {
-			compare(v1, v2);
-		} else if (Std.isOfType(v1, Array) && Std.isOfType(v2, Array)) {
-			compareArgs(v1, v2);
-		} else {
-			Reflect.compare(v1, v2);
+		var vt1 = Type.typeof(v1);
+		var vt2 = Type.typeof(v1);
+		return switch [vt1, vt2] {
+			case [TNull, TNull]:
+				// null is always equal to itself
+				0;
+			case [TInt, TInt] | [TFloat, TFloat] | [TBool, TBool]:
+				// Basic types can be compared directly
+				Reflect.compare(v1, v2);
+			case [TClass(String), TClass(String)]:
+				// Strings as well
+				Reflect.compare(v1, v2);
+			case [TEnum(_), TEnum(_)]:
+				// For enum values we recurse
+				compare(v1, v2);
+			case [TObject, TObject] | [TClass(_), TClass(_)]:
+				// Objects get a sort index associated with them which defines the ordering
+				Reflect.compare(getSortIndex(v1), getSortIndex(v2));
+			case [TFunction, TFunction] | [TUnknown, TUnknown]:
+				// We cannot compare functions and the unknown
+				throw 'Unsupported comparison types: $vt1 $vt2';
+			case _:
+				// If the types differ, we sort by the ValueType index
+				Reflect.compare(vt1.getIndex(), vt2.getIndex());
 		}
 	}
 

+ 56 - 0
tests/unit/src/unit/issues/Issue2479.hx

@@ -0,0 +1,56 @@
+package unit.issues;
+
+import Type;
+
+class Issue2479 extends unit.Test {
+	function testAlex() {
+		var map:Map<EFoo, Int> = [];
+		var k1 = new FooParam();
+		var k2 = new FooParam();
+		var k3 = new FooParam();
+		var k4 = new FooParam();
+		map[Foo(k1)] = 1;
+		map[Foo(k2)] = 2;
+		map[Foo(k3)] = 3;
+		map[Foo(k4)] = 4;
+		eq(1, map[Foo(k1)]);
+		eq(2, map[Foo(k2)]);
+		eq(3, map[Foo(k3)]);
+		eq(4, map[Foo(k4)]);
+	}
+
+	function testLptr() {
+		var map:Map<ValueType, String> = new Map();
+		map.set(TInt, "TInt");
+		map.set(TClass(String), "TClass(String)");
+		map.set(TClass(Color), "TClass(Color)");
+		map.set(TClass(Length), "TClass(Length)");
+		map.set(TFloat, "TFloat");
+
+		eq("TInt", map.get(TInt));
+		eq("TFloat", map.get(TFloat));
+		eq("TClass(String)", map.get(TClass(String)));
+		eq("TClass(Color)", map.get(TClass(Color)));
+		eq("TClass(Length)", map.get(TClass(Length)));
+	}
+}
+
+private class FooParam {
+	public function new() {}
+}
+
+private enum EFoo {
+	Foo(a:FooParam);
+}
+
+private class Color {
+	public function clone() {
+		return this;
+	}
+}
+
+private class Length {
+	public function clone() {
+		return this;
+	}
+}

+ 9 - 28
tests/unit/src/unitstd/haxe/ds/EnumValueMap.unit.hx

@@ -1,4 +1,5 @@
 var em = new haxe.ds.EnumValueMap();
+
 var test = [
 	1 => EContinue,
 	2 => EBreak,
@@ -6,8 +7,9 @@ var test = [
 	4 => EConst(CString("foo")),
 	5 => EArray(null, null),
 ];
+
 for (k in test.keys()) {
-	em.set(test[k],k);
+	em.set(test[k], k);
 }
 for (k in test.keys()) {
 	eq(k, em.get(test[k]));
@@ -21,39 +23,30 @@ for (k in test.keys()) {
 for (k in test.keys()) {
 	eq(false, em.exists(test[k]));
 }
-
-var em = [
-	EConst(CIdent("test")) => "test",
-	EArray(null,null) => "bar",
-	EBreak => "baz"
-];
+var em = [EConst(CIdent("test")) => "test", EArray(null, null) => "bar", EBreak => "baz"];
 em.exists(EConst(CIdent("test"))) == true;
 em.exists(EConst(CIdent("test2"))) == false;
 em.get(EConst(CIdent("test"))) == "test";
 em.remove(EConst(CIdent("test"))) == true;
 em.exists(EConst(CIdent("test"))) == false;
 em.get(EConst(CIdent("test"))) == null;
-
 em.exists(EArray(null, null)) == true;
 em.get(EArray(null, null)) == "bar";
 em.remove(EArray(null, null)) == true;
 em.exists(EArray(null, null)) == false;
 em.get(EArray(null, null)) == null;
-
 em.exists(EBreak) == true;
 em.get(EBreak) == "baz";
 em.remove(EBreak) == true;
 em.exists(EBreak) == false;
 em.get(EBreak) == null;
-
 var evm = new haxe.ds.EnumValueMap();
-evm.set(EVMA,1);
-evm.set(EVMA,2);
+evm.set(EVMA, 1);
+evm.set(EVMA, 2);
 evm.exists(EVMA) == true;
 evm.get(EVMA) == 2;
 evm.remove(EVMA) == true;
 evm.exists(EVMA) == false;
-
 evm.set(EVMB(), 8);
 evm.set(EVMB(), 9);
 evm.set(EVMB(null), 10);
@@ -65,40 +58,28 @@ evm.remove(EVMB()) == true;
 evm.remove(EVMB()) == false;
 evm.exists(EVMB()) == false;
 evm.exists(EVMB(null)) == false;
-
 evm.set(EVMC("foo"), 4);
 evm.set(EVMC("foo"), 5);
 evm.exists(EVMC("foo")) == true;
 evm.get(EVMC("foo")) == 5;
 evm.remove(EVMC("foo")) == true;
 evm.exists(EVMC("foo")) == false;
-
-evm.set(EVMD(null),91);
+evm.set(EVMD(null), 91);
 evm.exists(EVMD(null)) == true;
 evm.get(EVMD(null)) == 91;
 evm.remove(EVMD(null)) == true;
 evm.exists(EVMD(null)) == false;
-
 evm.set(EVMD(EVMA), 12);
 evm.exists(EVMD(EVMA)) == true;
 evm.get(EVMD(EVMA)) == 12;
 evm.remove(EVMD(EVMA)) == true;
 evm.exists(EVMD(EVMA)) == false;
-
-evm.set(EVME(null),99);
+evm.set(EVME(null), 99);
 evm.exists(EVME(null)) == true;
 evm.exists(EVME()) == true;
 evm.get(EVME(null)) == 99;
 evm.get(EVME()) == 99;
-
-evm.set(EVMF([EVMA, EVMB()]), 12);
-evm.exists(EVMF([EVMA, EVMB()])) == true;
-evm.exists(EVMF([EVMA, EVMB(null)])) == true;
-evm.get(EVMF([EVMA, EVMB()])) == 12;
-evm.get(EVMF([EVMA, EVMB(null)])) == 12;
-
 evm.clear();
 evm.exists(EVMF([EVMA, EVMB()])) == false;
 evm.exists(EVMF([EVMA, EVMB(null)])) == false;
-
-[for (k=>v in evm) k] == [];
+[for (k => v in evm) k] == [];