|
@@ -59,8 +59,8 @@ and dt =
|
|
|
| Switch of st * (con * dt) list
|
|
|
| Bind of ((tvar * pos) * st) list * dt
|
|
|
| Goto of int
|
|
|
- | Expr of texpr
|
|
|
- | Guard of texpr * dt * dt option
|
|
|
+ | Expr of int
|
|
|
+ | Guard of int * dt * dt option
|
|
|
|
|
|
(* Pattern *)
|
|
|
|
|
@@ -79,8 +79,6 @@ and pat = {
|
|
|
}
|
|
|
|
|
|
type out = {
|
|
|
- o_expr : texpr;
|
|
|
- o_guard : texpr option;
|
|
|
o_pos : pos;
|
|
|
o_id : int;
|
|
|
o_default : bool;
|
|
@@ -101,13 +99,14 @@ type pattern_ctx = {
|
|
|
type matcher = {
|
|
|
ctx : typer;
|
|
|
need_val : bool;
|
|
|
- dt_cache : (dt,int) Hashtbl.t;
|
|
|
dt_lut : dt DynArray.t;
|
|
|
mutable dt_count : int;
|
|
|
mutable outcomes : (pat list,out) PMap.t;
|
|
|
mutable toplevel_or : bool;
|
|
|
mutable used_paths : (int,bool) Hashtbl.t;
|
|
|
mutable has_extractor : bool;
|
|
|
+ mutable expr_map : (int,texpr * texpr option) PMap.t;
|
|
|
+ mutable first : int;
|
|
|
}
|
|
|
|
|
|
exception Not_exhaustive of pat * st
|
|
@@ -131,19 +130,24 @@ let mk_st def t p = {
|
|
|
|
|
|
let mk_out mctx id e eg pl is_default p =
|
|
|
let out = {
|
|
|
- o_expr = e;
|
|
|
- o_guard = eg;
|
|
|
o_pos = p;
|
|
|
o_id = id;
|
|
|
o_default = is_default;
|
|
|
} in
|
|
|
mctx.outcomes <- PMap.add pl out mctx.outcomes;
|
|
|
+ mctx.expr_map <- PMap.add id (e,eg) mctx.expr_map;
|
|
|
out
|
|
|
|
|
|
let clone_out mctx out pl p =
|
|
|
let out = {out with o_pos = p; } in
|
|
|
out
|
|
|
|
|
|
+let get_guard mctx id =
|
|
|
+ snd (PMap.find id mctx.expr_map)
|
|
|
+
|
|
|
+let get_expr mctx id =
|
|
|
+ fst (PMap.find id mctx.expr_map)
|
|
|
+
|
|
|
let mk_pat pdef t p = {
|
|
|
p_def = pdef;
|
|
|
p_type = t;
|
|
@@ -712,7 +716,7 @@ let column_sigma mctx st pmat =
|
|
|
| (pv,out) :: pr ->
|
|
|
let rec loop2 out = function
|
|
|
| PCon (c,_) ->
|
|
|
- add c (out.o_guard <> None);
|
|
|
+ add c ((get_guard mctx out.o_id) <> None);
|
|
|
| POr(pat1,pat2) ->
|
|
|
let out2 = clone_out mctx out [pat2] pat2.p_pos in
|
|
|
loop2 out pat1.p_def;
|
|
@@ -816,23 +820,18 @@ let bind_remaining out pv stl =
|
|
|
in
|
|
|
loop stl pv
|
|
|
|
|
|
-let get_cache mctx dt =
|
|
|
- match dt with Goto _ -> dt | _ ->
|
|
|
- try
|
|
|
- let i = Hashtbl.find mctx.dt_cache dt in
|
|
|
- Goto i
|
|
|
- with Not_found ->
|
|
|
- Hashtbl.replace mctx.dt_cache dt mctx.dt_count;
|
|
|
- mctx.dt_count <- mctx.dt_count + 1;
|
|
|
- DynArray.add mctx.dt_lut dt;
|
|
|
- dt
|
|
|
+let get_cache mctx toplevel dt =
|
|
|
+ if toplevel then mctx.first <- mctx.dt_count;
|
|
|
+ mctx.dt_count <- mctx.dt_count + 1;
|
|
|
+ DynArray.add mctx.dt_lut dt;
|
|
|
+ dt
|
|
|
|
|
|
let rec compile mctx stl pmat toplevel =
|
|
|
- let guard e dt1 dt2 = get_cache mctx (Guard(e,dt1,dt2)) in
|
|
|
- let expr e = get_cache mctx (Expr e) in
|
|
|
- let bind bl dt = get_cache mctx (Bind(bl,dt)) in
|
|
|
- let switch st cl = get_cache mctx (Switch(st,cl)) in
|
|
|
- get_cache mctx (match pmat with
|
|
|
+ let guard id dt1 dt2 = get_cache mctx toplevel (Guard(id,dt1,dt2)) in
|
|
|
+ let expr id = get_cache mctx toplevel (Expr id) in
|
|
|
+ let bind bl dt = get_cache mctx toplevel (Bind(bl,dt)) in
|
|
|
+ let switch st cl = get_cache mctx toplevel (Switch(st,cl)) in
|
|
|
+ (match pmat with
|
|
|
| [] ->
|
|
|
(match stl with
|
|
|
| st :: stl ->
|
|
@@ -854,9 +853,9 @@ let rec compile mctx stl pmat toplevel =
|
|
|
if i = -1 then begin
|
|
|
Hashtbl.replace mctx.used_paths out.o_id true;
|
|
|
let bl = bind_remaining out pv stl in
|
|
|
- let dt = match out.o_guard with
|
|
|
- | None -> expr out.o_expr
|
|
|
- | Some e -> guard e (expr out.o_expr) (match pl with [] -> None | _ -> Some (compile mctx stl pl false))
|
|
|
+ let dt = match (get_guard mctx out.o_id) with
|
|
|
+ | None -> expr out.o_id
|
|
|
+ | Some _ -> guard out.o_id (expr out.o_id) (match pl with [] -> None | _ -> Some (compile mctx stl pl false))
|
|
|
in
|
|
|
(if bl = [] then dt else bind bl dt)
|
|
|
end else if i > 0 then begin
|
|
@@ -1096,10 +1095,11 @@ let match_expr ctx e cases def with_type p =
|
|
|
outcomes = PMap.empty;
|
|
|
toplevel_or = false;
|
|
|
used_paths = Hashtbl.create 0;
|
|
|
- dt_cache = Hashtbl.create 0;
|
|
|
dt_lut = DynArray.create ();
|
|
|
dt_count = 0;
|
|
|
has_extractor = false;
|
|
|
+ expr_map = PMap.empty;
|
|
|
+ first = 0;
|
|
|
} in
|
|
|
(* flatten cases *)
|
|
|
let cases = List.map (fun (el,eg,e) ->
|
|
@@ -1209,9 +1209,9 @@ let match_expr ctx e cases def with_type p =
|
|
|
end
|
|
|
) mctx.outcomes;
|
|
|
in
|
|
|
- let dt = try
|
|
|
+ begin try
|
|
|
(* compile decision tree *)
|
|
|
- compile mctx stl pl true
|
|
|
+ ignore(compile mctx stl pl true)
|
|
|
with Not_exhaustive(pat,st) ->
|
|
|
let rec s_st_r top pre st v = match st.st_def with
|
|
|
| SVar v1 ->
|
|
@@ -1257,7 +1257,7 @@ let match_expr ctx e cases def with_type p =
|
|
|
s_pat pat
|
|
|
in
|
|
|
error ("Unmatched patterns: " ^ (s_st_r true false st pat)) st.st_pos
|
|
|
- in
|
|
|
+ end;
|
|
|
save();
|
|
|
(* check for unused patterns *)
|
|
|
if !extractor_depth = 0 then check_unused();
|
|
@@ -1267,7 +1267,7 @@ let match_expr ctx e cases def with_type p =
|
|
|
mk_mono()
|
|
|
else match with_type with
|
|
|
| WithType t | WithTypeResume t -> t
|
|
|
- | _ -> try Typer.unify_min_raise ctx (List.rev_map (fun (_,out) -> out.o_expr) (List.rev pl)) with Error (Unify l,p) -> error (error_msg (Unify l)) p
|
|
|
+ | _ -> try Typer.unify_min_raise ctx (List.rev_map (fun (_,out) -> get_expr mctx out.o_id) (List.rev pl)) with Error (Unify l,p) -> error (error_msg (Unify l)) p
|
|
|
in
|
|
|
(* unify with expected type if necessary *)
|
|
|
begin match tmono with
|
|
@@ -1278,9 +1278,8 @@ let match_expr ctx e cases def with_type p =
|
|
|
end;
|
|
|
(* count usage *)
|
|
|
let usage = Array.make (DynArray.length mctx.dt_lut) 0 in
|
|
|
- let first = (match dt with Goto i -> i | _ -> Hashtbl.find mctx.dt_cache dt) in
|
|
|
(* we always want to keep the first part *)
|
|
|
- Array.set usage first 2;
|
|
|
+ Array.set usage mctx.first 2;
|
|
|
let rec loop dt = match dt with
|
|
|
| Goto i -> Array.set usage i ((Array.get usage i) + 1)
|
|
|
| Switch(st,cl) -> List.iter (fun (_,dt) -> loop dt) cl
|
|
@@ -1310,12 +1309,12 @@ let match_expr ctx e cases def with_type p =
|
|
|
| Goto i -> if usage.(i) > 1 then DTGoto (map.(i)) else loop (DynArray.get mctx.dt_lut i)
|
|
|
| Switch(st,cl) -> convert_switch ctx st cl loop
|
|
|
| Bind(bl,dt) -> DTBind(List.map (fun (v,st) -> v,convert_st ctx st) bl,loop dt)
|
|
|
- | Expr e -> DTExpr e
|
|
|
- | Guard(e,dt1,dt2) -> DTGuard(e,loop dt1, match dt2 with None -> None | Some dt -> Some (loop dt))
|
|
|
+ | Expr id -> DTExpr (get_expr mctx id)
|
|
|
+ | Guard(id,dt1,dt2) -> DTGuard((match get_guard mctx id with Some e -> e | None -> assert false),loop dt1, match dt2 with None -> None | Some dt -> Some (loop dt))
|
|
|
in
|
|
|
let lut = DynArray.map loop lut in
|
|
|
{
|
|
|
- dt_first = map.(first);
|
|
|
+ dt_first = map.(mctx.first);
|
|
|
dt_dt_lookup = DynArray.to_array lut;
|
|
|
dt_type = t;
|
|
|
dt_var_init = List.rev !var_inits;
|