Browse Source

add extractor support (implementation not perfect yet)

Simon Krajewski 12 years ago
parent
commit
c8ae56cc08
3 changed files with 219 additions and 20 deletions
  1. 53 3
      matcher.ml
  2. 10 0
      std/haxe/macro/TypeTools.hx
  3. 156 17
      tests/unit/TestMatch.hx

+ 53 - 3
matcher.ml

@@ -106,6 +106,7 @@ type matcher = {
 	mutable outcomes : (pat list,out) PMap.t;
 	mutable outcomes : (pat list,out) PMap.t;
 	mutable toplevel_or : bool;
 	mutable toplevel_or : bool;
 	mutable used_paths : (int,bool) Hashtbl.t;
 	mutable used_paths : (int,bool) Hashtbl.t;
+	mutable has_extractor : bool;
 }
 }
 
 
 exception Not_exhaustive of pat * st
 exception Not_exhaustive of pat * st
@@ -946,7 +947,7 @@ let convert_switch ctx st cases loop =
 	| TAbstract(a,_) when Meta.has Meta.FakeEnum a.a_meta ->
 	| TAbstract(a,_) when Meta.has Meta.FakeEnum a.a_meta ->
 		mk (TMeta((Meta.Exhaustive,[],p), e_st)) e_st.etype e_st.epos
 		mk (TMeta((Meta.Exhaustive,[],p), e_st)) e_st.etype e_st.epos
 	| TAbstract({a_path = [],"Bool"},_) ->
 	| TAbstract({a_path = [],"Bool"},_) ->
-		mk (TMeta((Meta.Exhaustive,[],p), e_st)) e_st.etype e_st.epos		
+		mk (TMeta((Meta.Exhaustive,[],p), e_st)) e_st.etype e_st.epos
 	| _ ->
 	| _ ->
 		e_st
 		e_st
 	in
 	in
@@ -975,6 +976,51 @@ let convert_switch ctx st cases loop =
 
 
 (* Decision tree compilation *)
 (* Decision tree compilation *)
 
 
+let transform_extractors mctx stl cases =
+	let rec loop cl = match cl with
+		| (epat,eg,e) :: cl ->
+			let ex = ref [] in
+			let exc = ref 0 in
+			let rec find_ex e = match fst e with
+				| EBinop(OpArrow, e1, e2) ->
+					let p = pos e in
+					let ec = EConst (Ident ("__ex" ^ string_of_int (!exc))),snd e in
+					let ecall = match fst e1 with
+						| EConst(Ident s) -> ECall((EField(ec,s),p),[]),p
+						| _ -> ECall(e1,[ec]),p
+					in
+					ex := (ecall,e2) :: !ex;
+					incr exc;
+					ec
+				| _ ->
+					Ast.map_expr find_ex e
+			in
+			let p = pos epat in
+			let epat = find_ex epat in
+			if !exc = 0 then (epat,eg,e) :: loop cl else begin
+				mctx.has_extractor <- true;
+				let esubjects = EArrayDecl (List.map fst !ex),p in
+				let case1 = [EArrayDecl (List.map snd !ex),p],eg,e in
+				let cases = match cl with
+					| [] -> [case1]
+					| [(EConst (Ident "_"),_),_,e] -> case1 :: [[(EConst (Ident "_"),p)],None,e]
+					| _ ->
+						let cl2 = List.map (fun (epat,eg,e) -> [epat],eg,e) (loop cl) in
+						let st = match stl with st :: stl -> st | _ -> error "Unsupported" p in
+						let subj = convert_st mctx.ctx st in
+						let e_subj = Interp.make_ast subj in
+						case1 :: [[(EConst (Ident "_"),p)],None,Some (ESwitch(e_subj,cl2,None),p)]
+				in
+				let eswitch = (ESwitch(esubjects,cases,None)),p in
+				(epat,None,Some eswitch) :: loop cl
+			end
+		| [] ->
+			[]
+	in
+	loop cases
+
+let extractor_depth = ref 0
+
 let match_expr ctx e cases def with_type p =
 let match_expr ctx e cases def with_type p =
 	let need_val,with_type,tmono = match with_type with
 	let need_val,with_type,tmono = match with_type with
 		| NoValue -> false,NoValue,None
 		| NoValue -> false,NoValue,None
@@ -1042,6 +1088,7 @@ let match_expr ctx e cases def with_type p =
 		dt_cache = Hashtbl.create 0;
 		dt_cache = Hashtbl.create 0;
 		dt_lut = DynArray.create ();
 		dt_lut = DynArray.create ();
 		dt_count = 0;
 		dt_count = 0;
+		has_extractor = false;
 	} in
 	} in
 	(* flatten cases *)
 	(* flatten cases *)
 	let cases = List.map (fun (el,eg,e) ->
 	let cases = List.map (fun (el,eg,e) ->
@@ -1049,6 +1096,8 @@ let match_expr ctx e cases def with_type p =
 		collapse_case el,eg,e
 		collapse_case el,eg,e
 	) cases in
 	) cases in
 	let is_complex = ref false in
 	let is_complex = ref false in
+	let cases = transform_extractors mctx stl cases in
+	if mctx.has_extractor then incr extractor_depth;
 	let add_pattern_locals (pat,locals,complex) =
 	let add_pattern_locals (pat,locals,complex) =
 		PMap.iter (fun n (v,p) -> ctx.locals <- PMap.add n v ctx.locals) locals;
 		PMap.iter (fun n (v,p) -> ctx.locals <- PMap.add n v ctx.locals) locals;
 		if complex then is_complex := true;
 		if complex then is_complex := true;
@@ -1116,7 +1165,7 @@ let match_expr ctx e cases def with_type p =
 	let check_unused () =
 	let check_unused () =
 		let unused p =
 		let unused p =
 			display_error ctx "This pattern is unused" p;
 			display_error ctx "This pattern is unused" p;
-			let old_error = ctx.on_error in
+ 			let old_error = ctx.on_error in
 			ctx.on_error <- (fun ctx s p -> ctx.on_error <- old_error; raise Exit);
 			ctx.on_error <- (fun ctx s p -> ctx.on_error <- old_error; raise Exit);
 	 		let check_expr e p =
 	 		let check_expr e p =
 				try begin match fst e with
 				try begin match fst e with
@@ -1175,7 +1224,8 @@ let match_expr ctx e cases def with_type p =
 		error ("Unmatched patterns: " ^ (s_st_r true false st (s_pat pat))) st.st_pos
 		error ("Unmatched patterns: " ^ (s_st_r true false st (s_pat pat))) st.st_pos
 	in
 	in
 	(* check for unused patterns *)
 	(* check for unused patterns *)
-	check_unused();
+	if !extractor_depth = 0 then check_unused();
+	if mctx.has_extractor then decr extractor_depth;
 	(* determine type of switch statement *)
 	(* determine type of switch statement *)
 	let t = if not need_val then
 	let t = if not need_val then
 		mk_mono()
 		mk_mono()

+ 10 - 0
std/haxe/macro/TypeTools.hx

@@ -150,4 +150,14 @@ class TypeTools {
 			else if (c.superClass != null) findField(c.superClass.t.get(), name, isStatic);
 			else if (c.superClass != null) findField(c.superClass.t.get(), name, isStatic);
 			else null;
 			else null;
 	}
 	}
+	
+	/**
+		Gets the value of a reference `r`.
+		
+		If `r` is null, the result is unspecified. Otherwise `r.get()` is
+		called.
+	**/
+	static inline function deref<T>(r:Ref<T>):T {
+		return r.get();
+	}
 }
 }

+ 156 - 17
tests/unit/TestMatch.hx

@@ -1,6 +1,9 @@
 package unit;
 package unit;
+import haxe.ds.Option;
 import haxe.macro.Expr;
 import haxe.macro.Expr;
 
 
+using unit.TestMatch;
+
 enum Tree<T> {
 enum Tree<T> {
 	Leaf(t:T);
 	Leaf(t:T);
 	Node(l:Tree<T>, r:Tree<T>);
 	Node(l:Tree<T>, r:Tree<T>);
@@ -21,14 +24,26 @@ enum NE {
 	A(?x:Int);
 	A(?x:Int);
 }
 }
 
 
-class TestMatch extends Test {
-	static macro function getErrorMessage(e:Expr) {
+enum MiniType {
+	MTString(t:MiniRef<String>, tl:Array<MiniType>);
+	MTInt(t:MiniRef<Int>, tl:Array<MiniType>);
+}
+
+typedef MiniRef<T> = {
+	public function get():T;
+}
+
+class TestMatchMacro {
+	static public macro function getErrorMessage(e:Expr) {
 		var result = try {
 		var result = try {
 			haxe.macro.Context.typeof(e);
 			haxe.macro.Context.typeof(e);
 			"no error";
 			"no error";
 		} catch (e:Dynamic) Std.string(e.message);
 		} catch (e:Dynamic) Std.string(e.message);
 		return macro $v{result};
 		return macro $v{result};
 	}
 	}
+}
+
+class TestMatch extends Test {
 
 
 	static function switchNormal(e:Expr):String {
 	static function switchNormal(e:Expr):String {
 		return switch(e.expr) {
 		return switch(e.expr) {
@@ -314,58 +329,58 @@ class TestMatch extends Test {
 	}
 	}
 
 
 	function testNonExhaustiveness() {
 	function testNonExhaustiveness() {
-		eq("Unmatched patterns: false", getErrorMessage(switch(true) {
+		eq("Unmatched patterns: false", TestMatchMacro.getErrorMessage(switch(true) {
 			case true:
 			case true:
 		}));
 		}));
-		eq("Unmatched patterns: OpNegBits | OpNeg", getErrorMessage(switch(OpIncrement) {
+		eq("Unmatched patterns: OpNegBits | OpNeg", TestMatchMacro.getErrorMessage(switch(OpIncrement) {
 			case OpIncrement:
 			case OpIncrement:
 			case OpDecrement:
 			case OpDecrement:
 			case OpNot:
 			case OpNot:
 		}));
 		}));
-		eq("Unmatched patterns: Node(Leaf(_),_)", getErrorMessage(switch(Leaf("foo")) {
+		eq("Unmatched patterns: Node(Leaf(_),_)", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(Leaf("foo"), _):
 			case Node(Leaf("foo"), _):
 			case Leaf(_):
 			case Leaf(_):
 		}));
 		}));
-		eq("Unmatched patterns: Leaf", getErrorMessage(switch(Leaf("foo")) {
+		eq("Unmatched patterns: Leaf", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(_, _):
 			case Node(_, _):
 			case Leaf(_) if (false):
 			case Leaf(_) if (false):
 		}));
 		}));
-		eq("Unmatched patterns: Leaf(_)", getErrorMessage(switch(Leaf("foo")) {
+		eq("Unmatched patterns: Leaf(_)", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(_, _):
 			case Node(_, _):
 			case Leaf("foo"):
 			case Leaf("foo"):
 		}));
 		}));
-		eq("Unmatched patterns: [_,false,_]", getErrorMessage(switch [1, true, "foo"] {
+		eq("Unmatched patterns: [_,false,_]", TestMatchMacro.getErrorMessage(switch [1, true, "foo"] {
 			case [_, true, _]:
 			case [_, true, _]:
 		}));
 		}));
 		//var x:Null<Bool> = true;
 		//var x:Null<Bool> = true;
-		//eq("Unmatched patterns: null", getErrorMessage(switch x {
+		//eq("Unmatched patterns: null", TestMatchMacro.getErrorMessage(switch x {
 			//case true:
 			//case true:
 			//case false:
 			//case false:
 		//}));
 		//}));
 		//var t:Null<Tree<String>> = null;
 		//var t:Null<Tree<String>> = null;
-		//eq("Unmatched patterns: null", getErrorMessage(switch t {
+		//eq("Unmatched patterns: null", TestMatchMacro.getErrorMessage(switch t {
 			//case Leaf(_):
 			//case Leaf(_):
 			//case Node(_):
 			//case Node(_):
 		//}));
 		//}));
 	}
 	}
 
 
 	function testInvalidBinding() {
 	function testInvalidBinding() {
-		eq("Variable y must appear exactly once in each sub-pattern", getErrorMessage(switch(Leaf("foo")) {
+		eq("Variable y must appear exactly once in each sub-pattern", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Leaf(x) | Leaf(y):
 			case Leaf(x) | Leaf(y):
 		}));
 		}));
-		eq("Variable y must appear exactly once in each sub-pattern", getErrorMessage(switch(Leaf("foo")) {
+		eq("Variable y must appear exactly once in each sub-pattern", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Leaf(x) | Leaf(x) | Leaf(y):
 			case Leaf(x) | Leaf(x) | Leaf(y):
 		}));
 		}));
-		eq("Variable x must appear exactly once in each sub-pattern", getErrorMessage(switch(Leaf("foo")) {
+		eq("Variable x must appear exactly once in each sub-pattern", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Leaf(x) | Leaf(x) | Leaf(_):
 			case Leaf(x) | Leaf(x) | Leaf(_):
 		}));
 		}));
-		eq("Variable l must appear exactly once in each sub-pattern", getErrorMessage(switch(Leaf("foo")) {
+		eq("Variable l must appear exactly once in each sub-pattern", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(l = Leaf(x),_) | Node(Leaf(x), _):
 			case Node(l = Leaf(x),_) | Node(Leaf(x), _):
 		}));
 		}));
-		eq("Variable l must appear exactly once in each sub-pattern", getErrorMessage(switch(Leaf("foo")) {
+		eq("Variable l must appear exactly once in each sub-pattern", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(l = Leaf(l), _):
 			case Node(l = Leaf(l), _):
 		}));
 		}));
-		eq("String should be unit.Tree<String>", getErrorMessage(switch(Leaf("foo")) {
+		eq("String should be unit.Tree<String>", TestMatchMacro.getErrorMessage(switch(Leaf("foo")) {
 			case Node(l = Leaf(_), _) | Leaf(l):
 			case Node(l = Leaf(_), _) | Leaf(l):
 		}));
 		}));
 	}
 	}
@@ -437,12 +452,136 @@ class TestMatch extends Test {
 		}
 		}
 		eq(r, 1);
 		eq(r, 1);
 		
 		
-		eq("Unmatched patterns: 405", getErrorMessage(switch(a) {
+		eq("Unmatched patterns: 405", TestMatchMacro.getErrorMessage(switch(a) {
 			case NotFound:
 			case NotFound:
 		}));
 		}));
 		#end
 		#end
 	}
 	}
 
 
+	function testExtractors() {
+		function f(i) {
+			return switch(i) {
+				case 1,2,3: 1;
+				case even => true: 2;
+				case 4: throw "unreachable";
+				case _: 3;
+			}
+		}
+		
+		eq(1, f(1));
+		eq(1, f(2));
+		eq(1, f(3));
+		eq(2, f(4));
+		eq(3, f(5));
+		eq(3, f(7));
+		eq(3, f(9));
+		eq(2, f(6));
+		eq(2, f(8));
+		
+		function ref<T>(t:T):MiniRef<T> return {
+			get: function() return t
+		}
+		
+		function f(t:MiniType) {
+			return switch (t) {
+				case MTString(deref => "Foo", []): "Foo";
+				case MTString(deref => "Bar" | "Baz", _): "BarBaz";
+				case MTInt(deref => i, []): 'Int:$i';
+				case MTString(_): "OtherString";
+				case _: "Other";
+			}
+		}
+		
+		eq("Foo", f(MTString(ref("Foo"), [])));
+		eq("BarBaz", f(MTString(ref("Bar"), [])));
+		eq("BarBaz", f(MTString(ref("Baz"), [])));
+		eq("OtherString", f(MTString(ref("a"), [])));
+		eq("OtherString", f(MTString(ref(""), [])));
+		eq("Int:12", f(MTInt(ref(12), [])));
+		eq("Other", f(MTInt(ref(12), [MTInt(ref(10),[])])));
+		
+		function g(i : Array<Int>) {
+			return switch(i) {
+				case [x]: 1;
+				case isPair => Some(p) : p.a+p.b;
+				case arr: 3;
+			}
+		}
+
+		eq(3, g([]));
+		eq(1, g([1]));
+		eq(5, g([2, 3]));
+		eq(3, g([2, 3, 4]));
+		
+		var anon = {
+			odd: function(i) return i & 1 != 0
+		};
+		
+		var i = 9;
+		var r = switch(i) {
+			case 1: 1;
+			case anon.odd => true: 2;
+			case 9: 3;
+			case _: 4;
+		}
+		eq(2, r);
+		
+		function mul(i1,i2) return i1 * i2;
+		
+		function check(i) {
+			return switch(i) {
+				case 1: 1;
+				case mul.bind(4) => 8: 2;
+				case mul.bind(5) => 15: 3;
+				case _: 4;
+			}
+		}
+		
+		eq(1, check(1));
+		eq(2, check(2));
+		eq(3, check(3));
+		eq(4, check(4));
+		
+		function is<T>(pred : T -> Bool) return function (x : T) {
+			return pred(x)?Some(x):None;
+		}
+
+		function isNot<T>(pred : T -> Bool) return function (x : T) {
+			return (!pred(x))?Some(x):None;
+		}
+
+		function testArgs<T>(i:Int, s:String, t:T) {
+			return Std.string(t);
+		}
+		function h(i : Array<Int>) {
+			return switch(i) {
+				case [x]: 1;
+				case isPair => Some({ a : a, b : b }) if (a < 0): 42;
+				case isPair => Some({ a : is(even) => Some(a), b : b }) : a+b;
+				case isPair => Some({ a : isNot(even) => Some(a), b : b }) : a*b;
+				case testArgs.bind(1, "foo") => "[99,98,97]": 99;
+				case arr: 3;
+			}
+		}
+		
+		eq(3, h([]));
+		eq(1, h([1]));
+		eq(1, h([2]));
+		eq(5, h([2, 3]));
+		eq(3, h([1, 3]));
+		eq(3, h([2, 3, 4]));
+		eq(42, h([-1, 3]));
+		eq(99, h([99,98,97]));
+	}
+	
+	static function isPair<T>(t:Array<T>) return t.length == 2 ? Some({a:t[0], b:t[1]}) : None;
+	
+	static function even(i:Int) {
+		return i & 1 == 0;
+	}
+	
+	static function deref<T>(ref:MiniRef<T>) return ref.get();
+	
 	#if false
 	#if false
 	 //all lines marked as // unused should give an error
 	 //all lines marked as // unused should give an error
 	function testRedundance() {
 	function testRedundance() {