Răsfoiți Sursa

improve pattern matching implementation with regard to memory usage

Simon Krajewski acum 11 ani
părinte
comite
a735e3a45c
1 fișier modificat, cu 35 de adăugiri și 36 de ștergeri
  1. 35 36
      matcher.ml

+ 35 - 36
matcher.ml

@@ -59,8 +59,8 @@ and dt =
 	| Switch of st * (con * dt) list
 	| Bind of ((tvar * pos) * st) list * dt
 	| Goto of int
-	| Expr of texpr
-	| Guard of texpr * dt * dt option
+	| Expr of int
+	| Guard of int * dt * dt option
 
 (* Pattern *)
 
@@ -79,8 +79,6 @@ and pat = {
 }
 
 type out = {
-	o_expr : texpr;
-	o_guard : texpr option;
 	o_pos : pos;
 	o_id : int;
 	o_default : bool;
@@ -101,13 +99,14 @@ type pattern_ctx = {
 type matcher = {
 	ctx : typer;
 	need_val : bool;
-	dt_cache : (dt,int) Hashtbl.t;
 	dt_lut : dt DynArray.t;
 	mutable dt_count : int;
 	mutable outcomes : (pat list,out) PMap.t;
 	mutable toplevel_or : bool;
 	mutable used_paths : (int,bool) Hashtbl.t;
 	mutable has_extractor : bool;
+	mutable expr_map : (int,texpr * texpr option) PMap.t;
+	mutable first : int;
 }
 
 exception Not_exhaustive of pat * st
@@ -131,19 +130,24 @@ let mk_st def t p = {
 
 let mk_out mctx id e eg pl is_default p =
 	let out = {
-		o_expr = e;
-		o_guard = eg;
 		o_pos = p;
 		o_id = id;
 		o_default = is_default;
 	} in
 	mctx.outcomes <- PMap.add pl out mctx.outcomes;
+	mctx.expr_map <- PMap.add id (e,eg) mctx.expr_map;
 	out
 
 let clone_out mctx out pl p =
 	let out = {out with o_pos = p; } in
 	out
 
+let get_guard mctx id =
+	snd (PMap.find id mctx.expr_map)
+
+let get_expr mctx id =
+	fst (PMap.find id mctx.expr_map)
+
 let mk_pat pdef t p = {
 	p_def = pdef;
 	p_type = t;
@@ -712,7 +716,7 @@ let column_sigma mctx st pmat =
 		| (pv,out) :: pr ->
 			let rec loop2 out = function
 				| PCon (c,_) ->
-					add c (out.o_guard <> None);
+					add c ((get_guard mctx out.o_id) <> None);
 				| POr(pat1,pat2) ->
 					let out2 = clone_out mctx out [pat2] pat2.p_pos in
 					loop2 out pat1.p_def;
@@ -816,23 +820,18 @@ let bind_remaining out pv stl =
 	in
 	loop stl pv
 
-let get_cache mctx dt =
-	match dt with Goto _ -> dt | _ ->
-	try
-		let i = Hashtbl.find mctx.dt_cache dt in
-		Goto i
-	with Not_found ->
-		Hashtbl.replace mctx.dt_cache dt mctx.dt_count;
-		mctx.dt_count <- mctx.dt_count + 1;
-		DynArray.add mctx.dt_lut dt;
-		dt
+let get_cache mctx toplevel dt =
+	if toplevel then mctx.first <- mctx.dt_count;
+	mctx.dt_count <- mctx.dt_count + 1;
+	DynArray.add mctx.dt_lut dt;
+	dt
 
 let rec compile mctx stl pmat toplevel =
-	let guard e dt1 dt2 = get_cache mctx (Guard(e,dt1,dt2)) in
-	let expr e = get_cache mctx (Expr e) in
-	let bind bl dt = get_cache mctx (Bind(bl,dt)) in
-	let switch st cl = get_cache mctx (Switch(st,cl)) in
-	get_cache mctx (match pmat with
+	let guard id dt1 dt2 = get_cache mctx toplevel (Guard(id,dt1,dt2)) in
+	let expr id = get_cache mctx toplevel (Expr id) in
+	let bind bl dt = get_cache mctx toplevel (Bind(bl,dt)) in
+	let switch st cl = get_cache mctx toplevel (Switch(st,cl)) in
+	(match pmat with
 	| [] ->
 		(match stl with
 		| st :: stl ->
@@ -854,9 +853,9 @@ let rec compile mctx stl pmat toplevel =
 		if i = -1 then begin
 			Hashtbl.replace mctx.used_paths out.o_id true;
 			let bl = bind_remaining out pv stl in
-			let dt = match out.o_guard with
-				| None -> expr out.o_expr
-				| Some e -> guard e (expr out.o_expr) (match pl with [] -> None | _ -> Some (compile mctx stl pl false))
+			let dt = match (get_guard mctx out.o_id) with
+				| None -> expr out.o_id
+				| Some _ -> guard out.o_id (expr out.o_id) (match pl with [] -> None | _ -> Some (compile mctx stl pl false))
 			in
 			(if bl = [] then dt else bind bl dt)
 		end else if i > 0 then begin
@@ -1096,10 +1095,11 @@ let match_expr ctx e cases def with_type p =
 		outcomes = PMap.empty;
 		toplevel_or = false;
 		used_paths = Hashtbl.create 0;
-		dt_cache = Hashtbl.create 0;
 		dt_lut = DynArray.create ();
 		dt_count = 0;
 		has_extractor = false;
+		expr_map = PMap.empty;
+		first = 0;
 	} in
 	(* flatten cases *)
 	let cases = List.map (fun (el,eg,e) ->
@@ -1209,9 +1209,9 @@ let match_expr ctx e cases def with_type p =
 			end
 		) mctx.outcomes;
 	in
-	let dt = try
+	begin try
 		(* compile decision tree *)
-		compile mctx stl pl true
+		ignore(compile mctx stl pl true)
 	with Not_exhaustive(pat,st) ->
  		let rec s_st_r top pre st v = match st.st_def with
  			| SVar v1 ->
@@ -1257,7 +1257,7 @@ let match_expr ctx e cases def with_type p =
 				s_pat pat
 		in
 		error ("Unmatched patterns: " ^ (s_st_r true false st pat)) st.st_pos
-	in
+	end;
 	save();
 	(* check for unused patterns *)
 	if !extractor_depth = 0 then check_unused();
@@ -1267,7 +1267,7 @@ let match_expr ctx e cases def with_type p =
 		mk_mono()
 	else match with_type with
 		| WithType t | WithTypeResume t -> t
-		| _ -> try Typer.unify_min_raise ctx (List.rev_map (fun (_,out) -> out.o_expr) (List.rev pl)) with Error (Unify l,p) -> error (error_msg (Unify l)) p
+		| _ -> try Typer.unify_min_raise ctx (List.rev_map (fun (_,out) -> get_expr mctx out.o_id) (List.rev pl)) with Error (Unify l,p) -> error (error_msg (Unify l)) p
 	in
 	(* unify with expected type if necessary *)
 	begin match tmono with
@@ -1278,9 +1278,8 @@ let match_expr ctx e cases def with_type p =
 	end;
 	(* count usage *)
 	let usage = Array.make (DynArray.length mctx.dt_lut) 0 in
-	let first = (match dt with Goto i -> i | _ -> Hashtbl.find mctx.dt_cache dt) in
 	(* we always want to keep the first part *)
-	Array.set usage first 2;
+	Array.set usage mctx.first 2;
 	let rec loop dt = match dt with
 		| Goto i -> Array.set usage i ((Array.get usage i) + 1)
 		| Switch(st,cl) -> List.iter (fun (_,dt) -> loop dt) cl
@@ -1310,12 +1309,12 @@ let match_expr ctx e cases def with_type p =
 		| Goto i -> if usage.(i) > 1 then DTGoto (map.(i)) else loop (DynArray.get mctx.dt_lut i)
 		| Switch(st,cl) -> convert_switch ctx st cl loop
 		| Bind(bl,dt) -> DTBind(List.map (fun (v,st) -> v,convert_st ctx st) bl,loop dt)
-		| Expr e -> DTExpr e
-		| Guard(e,dt1,dt2) -> DTGuard(e,loop dt1, match dt2 with None -> None | Some dt -> Some (loop dt))
+		| Expr id -> DTExpr (get_expr mctx id)
+		| Guard(id,dt1,dt2) -> DTGuard((match get_guard mctx id with Some e -> e | None -> assert false),loop dt1, match dt2 with None -> None | Some dt -> Some (loop dt))
 	in
 	let lut = DynArray.map loop lut in
 	{
-		dt_first = map.(first);
+		dt_first = map.(mctx.first);
 		dt_dt_lookup = DynArray.to_array lut;
 		dt_type = t;
 		dt_var_init = List.rev !var_inits;