浏览代码

filter and reindex match parts

Simon Krajewski 12 年之前
父节点
当前提交
64b9720c65
共有 1 个文件被更改,包括 48 次插入27 次删除
  1. 48 27
      matcher.ml

+ 48 - 27
matcher.ml

@@ -727,7 +727,8 @@ let bind_remaining out pv stl =
 let get_cache mctx dt =
 	match dt with Goto _ -> dt | _ ->
 	try
-		Goto (Hashtbl.find mctx.dt_cache dt)
+		let i = Hashtbl.find mctx.dt_cache dt in
+		Goto i
 	with Not_found ->
 		Hashtbl.replace mctx.dt_cache dt mctx.dt_count;
 		mctx.dt_count <- mctx.dt_count + 1;
@@ -735,7 +736,11 @@ let get_cache mctx dt =
 		dt
 
 let rec compile mctx stl pmat =
-	let dt = match pmat with
+	let guard e dt1 dt2 = get_cache mctx (Guard(e,dt1,dt2)) in
+	let expr e = get_cache mctx (Expr e) in
+	let bind bl dt = get_cache mctx (Bind(bl,dt)) in
+	let switch st cl = get_cache mctx (Switch(st,cl)) in
+	get_cache mctx (match pmat with
 	| [] ->
 		(match stl with
 		| st :: stl ->
@@ -758,10 +763,10 @@ let rec compile mctx stl pmat =
 			Hashtbl.replace mctx.used_paths out.o_id true;
 			let bl = bind_remaining out pv stl in
 			let dt = match out.o_guard with
-				| None -> get_cache mctx (Expr out.o_expr)
-				| Some e -> Guard (e, Expr out.o_expr, match pl with [] -> None | _ -> Some (compile mctx stl pl))
+				| None -> expr out.o_expr
+				| Some e -> guard e (expr out.o_expr) (match pl with [] -> None | _ -> Some (compile mctx stl pl))
 			in
-			get_cache mctx (if bl = [] then dt else Bind(bl,dt))
+			(if bl = [] then dt else bind bl dt)
 		end else if i > 0 then begin
 			let pmat = swap_pmat_columns i pmat in
 			let stls = swap_columns i stl in
@@ -776,16 +781,16 @@ let rec compile mctx stl pmat =
 				let hsubs = mk_subs st_head c in
 				let subs = hsubs @ st_tail in
 				let dt = compile mctx subs spec in
-				c,get_cache mctx dt
+				c,dt
 			) sigma in
 			let def = default mctx pmat in
 			let dt = match def,cases with
  			| _,[{c_def = CFields _},dt] ->
 				dt
 			| _ when not inf && PMap.is_empty !all ->
-				Switch(st_head,cases)
+				switch st_head cases
 			| [],_ when inf && not mctx.need_val ->
-				Switch(st_head,cases)
+				switch st_head cases
 			| [],_ when inf ->
 				raise (Not_exhaustive(any,st_head))
 			| [],_ ->
@@ -795,13 +800,11 @@ let rec compile mctx stl pmat =
 				compile mctx st_tail def
 			| def,_ ->
 				let cdef = mk_con CAny t_dynamic st_head.st_pos in
-				let cases = cases @ [cdef,get_cache mctx (compile mctx st_tail def)] in
-				Switch(st_head,cases)
+				let cases = cases @ [cdef,compile mctx st_tail def] in
+				switch st_head cases
 			in
-			if bl = [] then dt else Bind(bl,get_cache mctx dt)
-		end
-	in
-	get_cache mctx dt
+			if bl = [] then dt else bind bl dt
+		end)
 
 let rec collapse_case el = match el with
 	| e :: [] ->
@@ -1027,29 +1030,47 @@ let match_expr ctx e cases def with_type p =
 		| Some (WithTypeResume t2) -> (try unify_raise ctx t2 t p with Error (Unify l,p) -> raise (Typer.WithTypeError (l,p)))
 		| _ -> assert false
 	end;
-	let lut = DynArray.to_array mctx.dt_lut in
-	let first = match dt with Goto i -> i | _ -> Hashtbl.find mctx.dt_cache dt in
-	let count = Array.make (Array.length lut) 0 in
+	(* count usage *)
+	let usage = Array.make (DynArray.length mctx.dt_lut) 0 in
+	let first = (match dt with Goto i -> i | _ -> Hashtbl.find mctx.dt_cache dt) in
+	(* we always want to keep the first part *)
+	Array.set usage first 2;
 	let rec loop dt = match dt with
-		| Goto i -> Array.set count i (count.(i) + 1)
-		| Switch(_,cl) -> List.iter (fun (_,dt) -> loop dt) cl
-		| Bind(_,dt) -> loop dt
-		| Expr _ -> ()
-		| Guard (_,dt1,dt2) ->
+		| Goto i -> Array.set usage i ((Array.get usage i) + 1)
+		| Switch(st,cl) -> List.iter (fun (_,dt) -> loop dt) cl
+		| Bind(bl,dt) -> loop dt
+		| Expr e -> ()
+		| Guard(e,dt1,dt2) ->
 			loop dt1;
-			(match dt2 with None -> () | Some dt -> loop dt)
+			match dt2 with None -> () | Some dt -> (loop dt)
+	in
+	DynArray.iter loop mctx.dt_lut;
+	(* filter parts that will be inlined and keep a map to them*)
+	let map = Array.make (DynArray.length mctx.dt_lut) 0 in
+	let lut = DynArray.create() in
+	let rec loop i c =
+		if c < DynArray.length mctx.dt_lut then begin
+			let i' = if usage.(c) > 1 then begin
+				DynArray.add lut (DynArray.get mctx.dt_lut c);
+				i + 1
+			end else i in
+			Array.set map c i;
+		 	loop i' (c + 1)
+		end
 	in
-	Array.iter loop lut;
+	loop 0 0;
+	(* reindex *)
 	let rec loop dt = match dt with
-		| Goto i -> if count.(i) < 2 then lut.(i) else Goto i
+		| Goto i -> if usage.(i) > 1 then Goto (map.(i)) else loop (DynArray.get mctx.dt_lut i)
 		| Switch(st,cl) -> Switch(st, List.map (fun (c,dt) -> c, loop dt) cl)
 		| Bind(bl,dt) -> Bind(bl,loop dt)
 		| Expr e -> Expr e
 		| Guard(e,dt1,dt2) -> Guard(e,loop dt1, match dt2 with None -> None | Some dt -> Some (loop dt))
 	in
+	let lut = DynArray.map loop lut in
 	{
-		dt_first = first;
-		dt_dt_lookup = Array.map loop lut;
+		dt_first = map.(first);
+		dt_dt_lookup = DynArray.to_array lut;
 		dt_type = t;
 		dt_var_init = List.rev !var_inits;
 	}