Browse Source

[matcher] detuple a little

Simon Krajewski 2 years ago
parent
commit
c540d72cfd
1 changed files with 58 additions and 36 deletions
  1. 58 36
      src/typing/matcher.ml

+ 58 - 36
src/typing/matcher.ml

@@ -638,10 +638,16 @@ module Decision_tree = struct
 		| CompileTimeFinite (* type is considered finite only at compile-time but has inifite possible run-time values (enum abstracts) *)
 		| RunTimeFinite     (* type is truly finite (Bool, enums) *)
 
+	type bind = {
+		b_var : tvar;
+		b_pos : pos;
+		b_expr : texpr;
+	}
+
 	type t =
 		| Leaf of Case.t
-		| Switch of subject * (Constructor.t * bool * dt) list * dt
-		| Bind of (tvar * pos * texpr) list * dt
+		| Switch of subject * switch_case list * dt
+		| Bind of bind list * dt
 		| Guard of texpr * dt * dt
 		| GuardNull of texpr * dt * dt
 		| Fail
@@ -654,6 +660,18 @@ module Decision_tree = struct
 		mutable dt_texpr : texpr option;
 	}
 
+	and switch_case = {
+		sc_con : Constructor.t;
+		sc_unguarded : bool;
+		sc_dt : dt;
+	}
+
+	let make_bind v p e = {
+		b_var = v;
+		b_pos = p;
+		b_expr = e;
+	}
+
 	let tab_string = "    "
 
 	let to_string dt =
@@ -696,21 +714,21 @@ module Decision_tree = struct
 				print_case_expr tabs case
 			| Switch(e,cases,dt) ->
 				add_line tabs (Printf.sprintf "switch (%s)" (s_expr tabs e));
-				List.iter (fun (con,unguarded,dt) ->
+				List.iter (fun sc ->
 					add_line (tabs ^ tab_string) "case ";
-					add (Constructor.to_string con);
-					add (if unguarded then "(unguarded)" else "guarded");
+					add (Constructor.to_string sc.sc_con);
+					add (if sc.sc_unguarded then "(unguarded)" else "guarded");
 					add ":";
-					loop (tabs ^ tab_string ^ tab_string) dt;
+					loop (tabs ^ tab_string ^ tab_string) sc.sc_dt;
 				) cases;
 				add_line (tabs ^ tab_string) "default";
 				loop (tabs ^ tab_string ^ tab_string) dt;
 			| Bind(bl,dt) ->
-				List.iter (fun (v,_,e) ->
+				List.iter (fun bind ->
 					add_line tabs "var ";
-					add v.v_name;
+					add bind.b_var.v_name;
 					add " = ";
-					add (s_expr tabs e);
+					add (s_expr tabs bind.b_expr);
 				) bl;
 				loop tabs dt
 			| Guard(e,dt1,dt2) ->
@@ -738,10 +756,10 @@ module Decision_tree = struct
 			case1 == case2
 		| Switch(subject1,cases1,dt1),Switch(subject2,cases2,dt2) ->
 			Texpr.equal subject1 subject2 &&
-			safe_for_all2 (fun (con1,b1,dt1) (con2,b2,dt2) -> Constructor.equal con1 con2 && b1 = b2 && equal_dt dt1 dt2) cases1 cases2 &&
+			safe_for_all2 (fun sc1 sc2 -> Constructor.equal sc1.sc_con sc2.sc_con && sc1.sc_unguarded = sc2.sc_unguarded && equal_dt sc1.sc_dt sc2.sc_dt) cases1 cases2 &&
 			equal_dt dt1 dt2
 		| Bind(l1,dt1),Bind(l2,dt2) ->
-			safe_for_all2 (fun (v1,_,e1) (v2,_,e2) -> v1 == v2 && Texpr.equal e1 e2) l1 l2 &&
+			safe_for_all2 (fun bind1 bind2 -> bind1.b_var == bind2.b_var && Texpr.equal bind1.b_expr bind2.b_expr) l1 l2 &&
 			equal_dt dt1 dt2
 		| Fail,Fail ->
 			true
@@ -1022,11 +1040,11 @@ module Compile = struct
 			| (PatConstructor(con',patterns1),_) :: patterns2 when Constructor.equal con con' ->
 				Some (case,bindings,patterns1 @ patterns2)
 			| (PatVariable v,p) :: patterns2 ->
-				Some (case,(v,p,subject) :: bindings,ExtList.List.make arity (PatAny,p) @ patterns2)
+				Some (case,(make_bind v p subject) :: bindings,ExtList.List.make arity (PatAny,p) @ patterns2)
 			| (PatAny,_) as pat :: patterns2 ->
 				Some (case,bindings,ExtList.List.make arity pat @ patterns2)
 			| (PatBind(v,pat1),p) :: patterns ->
-				specialize (case,(v,p,subject) :: bindings,pat1 :: patterns)
+				specialize (case,(make_bind v p subject) :: bindings,pat1 :: patterns)
 			| _ ->
 				None
 		in
@@ -1035,11 +1053,11 @@ module Compile = struct
 	let default subject cases =
 		let rec default (case,bindings,patterns) = match patterns with
 			| (PatVariable v,p) :: patterns ->
-				Some (case,((v,p,subject) :: bindings),patterns)
+				Some (case,((make_bind v p subject) :: bindings),patterns)
 			| (PatAny,_) :: patterns ->
 				Some (case,bindings,patterns)
 			| (PatBind(v,pat1),p) :: patterns ->
-				default (case,((v,p,subject) :: bindings),pat1 :: patterns)
+				default (case,((make_bind v p subject) :: bindings),pat1 :: patterns)
 			| _ ->
 				None
 		in
@@ -1069,7 +1087,7 @@ module Compile = struct
 		String.concat " " (List.map s_expr_pretty subjects)
 
 	let s_case (case,bindings,patterns) =
-		let s_bindings = String.concat ", " (List.map (fun (v,_,e) -> Printf.sprintf "%s<%i> = %s" v.v_name v.v_id (s_expr_pretty e)) bindings) in
+		let s_bindings = String.concat ", " (List.map (fun bind -> Printf.sprintf "%s<%i> = %s" bind.b_var.v_name bind.b_var.v_id (s_expr_pretty bind.b_expr)) bindings) in
 		let s_patterns = String.concat " " (List.map Pattern.to_string patterns) in
 		let s_expr = match case.case_expr with None -> "" | Some e -> Type.s_expr_pretty false "\t\t" false s_type e in
 		let s_guard = match case.case_guard with None -> "" | Some e -> Type.s_expr_pretty false "\t\t" false s_type e in
@@ -1134,9 +1152,9 @@ module Compile = struct
 			| [PatAny,_],_ ->
 				bindings
 			| (PatVariable v,p) :: patterns,e :: el ->
-				loop patterns el ((v,p,e) :: bindings)
+				loop patterns el ((make_bind v p e) :: bindings)
 			| (PatBind(v,pat1),p) :: patterns,e :: el ->
-				loop (pat1 :: patterns) (e :: el) ((v,p,e) :: bindings)
+				loop (pat1 :: patterns) (e :: el) ((make_bind v p e) :: bindings)
 			| _ :: patterns,_ :: el ->
 				loop patterns el bindings
 			| [],[] ->
@@ -1166,7 +1184,7 @@ module Compile = struct
 						if case.case_guard = None then ConTable.replace unguarded con true;
 						let arg_positions = snd (List.split patterns) in
 						ConTable.replace sigma con arg_positions;
-					| PatBind(v,pat1) -> loop ((v,pos pat,subject) :: bindings) pat1
+					| PatBind(v,pat1) -> loop ((make_bind v (pos pat) subject) :: bindings) pat1
 					| PatVariable _ | PatAny -> ()
 					| PatExtractor _ -> raise Extractor
 					| _ -> typing_error ("Unexpected pattern: " ^ (Pattern.to_string pat)) case.case_pos;
@@ -1184,7 +1202,7 @@ module Compile = struct
 			let rec loop bindings locals sub_subjects = match sub_subjects with
 				| (name,e) :: sub_subjects ->
 					let v = add_local mctx.ctx VGenerated (Printf.sprintf "%s%s" gen_local_prefix name) e.etype e.epos in
-					loop ((v,v.v_pos,e) :: bindings) ((mk (TLocal v) v.v_type v.v_pos) :: locals) sub_subjects
+					loop ((make_bind v v.v_pos e) :: bindings) ((mk (TLocal v) v.v_type v.v_pos) :: locals) sub_subjects
 				| [] ->
 					List.rev bindings,List.rev locals
 			in
@@ -1193,7 +1211,11 @@ module Compile = struct
 			let spec = specialize subject con cases in
 			let dt = compile mctx subjects spec in
 			let dt = bind mctx bindings dt in
-			con,unguarded,dt
+			{
+				sc_con = con;
+				sc_unguarded = unguarded;
+				sc_dt = dt;
+			}
 		) sigma in
 		let default = default subject cases in
 		let switch_default = compile mctx subjects default in
@@ -1217,7 +1239,7 @@ module Compile = struct
 		let num_extractors,extractors = List.fold_left (fun (i,extractors) (_,_,patterns) ->
 			let rec loop bindings pat = match pat with
 				| (PatExtractor(v,e1,pat),_) -> i + 1,Some (v,e1,pat,bindings) :: extractors
-				| (PatBind(v,pat1),p) -> loop ((v,p,subject) :: bindings) pat1
+				| (PatBind(v,pat1),p) -> loop ((make_bind v p subject) :: bindings) pat1
 				| _ -> i,None :: extractors
 			in
 			loop [] (List.hd patterns)
@@ -1250,7 +1272,7 @@ module Compile = struct
 				die "" __LOC__
 		) (0,num_extractors,[],[],[]) cases (List.rev extractors) in
 		let dt = compile mctx ((subject :: List.rev ex_subjects) @ subjects) (List.rev cases) in
-		let bindings = List.map (fun (a,b,c,_,_) -> (a,b,c)) bindings in
+		let bindings = List.map (fun (a,b,c,_,_) -> (make_bind a b c)) bindings in
 		bind mctx bindings dt
 
 	let compile ctx match_debug subjects cases p =
@@ -1271,7 +1293,7 @@ module Compile = struct
 				| _ ->
 					let v = gen_local ctx e.etype e.epos in
 					let ev = mk (TLocal v) e.etype e.epos in
-					(ev :: subjects,(v,e.epos,e) :: vars)
+					(ev :: subjects,(make_bind v e.epos e) :: vars)
 				in
 				loop (subjects,vars) el
 		in
@@ -1358,12 +1380,12 @@ module TexprConverter = struct
 	let all_ctors ctx e cases =
 		let infer_type() = match cases with
 			| [] -> e,e.etype,false
-			| (con,_,_) :: _ ->
+			| sc :: _ ->
 				let fail() =
 					(* error "Could not determine switch kind, make sure the type is known" e.epos; *)
 					t_dynamic
 				in
-				let t = match fst con with
+				let t = match fst sc.sc_con with
 					| ConEnum(en,_) -> TEnum(en,extract_param_types en.e_params)
 					| ConArray _ -> ctx.t.tarray t_dynamic
 					| ConConst ct ->
@@ -1430,9 +1452,9 @@ module TexprConverter = struct
 			| ConArray _ -> kind = SKLength
 			| _ -> kind = SKValue
 		in
-		List.iter (fun (con,unguarded,dt) ->
-			if not (compatible_kind con) then typing_error "Incompatible pattern" dt.dt_pos;
-			if unguarded then ConTable.remove h con
+		List.iter (fun sc ->
+			if not (compatible_kind sc.sc_con) then typing_error "Incompatible pattern" sc.sc_dt.dt_pos;
+			if sc.sc_unguarded then ConTable.remove h sc.sc_con
 		) cases;
 		let unmatched = ConTable.fold (fun con _ acc -> con :: acc) h [] in
 		e,unmatched,kind,finiteness
@@ -1499,8 +1521,8 @@ module TexprConverter = struct
 							| Some e -> Some e
 							| None -> Some (mk (TBlock []) ctx.t.tvoid case.case_pos)
 						end
-					| Switch(_,[(ConFields _,_),_,dt],_) -> (* TODO: Can we improve this by making it more general? *)
-						loop dt_rec params dt
+					| Switch(_,[{sc_con = (ConFields _,_)} as sc],_) -> (* TODO: Can we improve this by making it more general? *)
+						loop dt_rec params sc.sc_dt
 					| Switch(e_subject,cases,default) ->
 						let dt_rec',toplevel = match dt_rec with
 							| Toplevel -> AfterSwitch,true
@@ -1519,8 +1541,8 @@ module TexprConverter = struct
 							| Some e ->
 								Some e
 						in
-						let cases = ExtList.List.filter_map (fun (con,_,dt) -> match unify_constructor ctx params e_subject.etype con with
-							| Some(_,params) -> Some (con,dt,params)
+						let cases = ExtList.List.filter_map (fun sc -> match unify_constructor ctx params e_subject.etype sc.sc_con with
+							| Some(_,params) -> Some (sc.sc_con,sc.sc_dt,params)
 							| None -> None
 						) cases in
 						let group cases =
@@ -1643,9 +1665,9 @@ module TexprConverter = struct
 							end
 						end
 					| Bind(bl,dt) ->
-						let el = List.map (fun (v,p,e) ->
-							v_lookup := IntMap.add v.v_id e !v_lookup;
-							mk (TVar(v,Some e)) com.basic.tvoid p
+						let el = List.map (fun bind ->
+							v_lookup := IntMap.add bind.b_var.v_id bind.b_expr !v_lookup;
+							mk (TVar(bind.b_var,Some bind.b_expr)) com.basic.tvoid p
 						) bl in
 						let e = loop dt_rec params dt in
 						Option.map (fun e -> mk (TBlock (el @ [e])) e.etype dt.dt_pos) e;