|
@@ -638,10 +638,16 @@ module Decision_tree = struct
|
|
|
| CompileTimeFinite (* type is considered finite only at compile-time but has inifite possible run-time values (enum abstracts) *)
|
|
|
| RunTimeFinite (* type is truly finite (Bool, enums) *)
|
|
|
|
|
|
+ type bind = {
|
|
|
+ b_var : tvar;
|
|
|
+ b_pos : pos;
|
|
|
+ b_expr : texpr;
|
|
|
+ }
|
|
|
+
|
|
|
type t =
|
|
|
| Leaf of Case.t
|
|
|
- | Switch of subject * (Constructor.t * bool * dt) list * dt
|
|
|
- | Bind of (tvar * pos * texpr) list * dt
|
|
|
+ | Switch of subject * switch_case list * dt
|
|
|
+ | Bind of bind list * dt
|
|
|
| Guard of texpr * dt * dt
|
|
|
| GuardNull of texpr * dt * dt
|
|
|
| Fail
|
|
@@ -654,6 +660,18 @@ module Decision_tree = struct
|
|
|
mutable dt_texpr : texpr option;
|
|
|
}
|
|
|
|
|
|
+ and switch_case = {
|
|
|
+ sc_con : Constructor.t;
|
|
|
+ sc_unguarded : bool;
|
|
|
+ sc_dt : dt;
|
|
|
+ }
|
|
|
+
|
|
|
+ let make_bind v p e = {
|
|
|
+ b_var = v;
|
|
|
+ b_pos = p;
|
|
|
+ b_expr = e;
|
|
|
+ }
|
|
|
+
|
|
|
let tab_string = " "
|
|
|
|
|
|
let to_string dt =
|
|
@@ -696,21 +714,21 @@ module Decision_tree = struct
|
|
|
print_case_expr tabs case
|
|
|
| Switch(e,cases,dt) ->
|
|
|
add_line tabs (Printf.sprintf "switch (%s)" (s_expr tabs e));
|
|
|
- List.iter (fun (con,unguarded,dt) ->
|
|
|
+ List.iter (fun sc ->
|
|
|
add_line (tabs ^ tab_string) "case ";
|
|
|
- add (Constructor.to_string con);
|
|
|
- add (if unguarded then "(unguarded)" else "guarded");
|
|
|
+ add (Constructor.to_string sc.sc_con);
|
|
|
+ add (if sc.sc_unguarded then "(unguarded)" else "guarded");
|
|
|
add ":";
|
|
|
- loop (tabs ^ tab_string ^ tab_string) dt;
|
|
|
+ loop (tabs ^ tab_string ^ tab_string) sc.sc_dt;
|
|
|
) cases;
|
|
|
add_line (tabs ^ tab_string) "default";
|
|
|
loop (tabs ^ tab_string ^ tab_string) dt;
|
|
|
| Bind(bl,dt) ->
|
|
|
- List.iter (fun (v,_,e) ->
|
|
|
+ List.iter (fun bind ->
|
|
|
add_line tabs "var ";
|
|
|
- add v.v_name;
|
|
|
+ add bind.b_var.v_name;
|
|
|
add " = ";
|
|
|
- add (s_expr tabs e);
|
|
|
+ add (s_expr tabs bind.b_expr);
|
|
|
) bl;
|
|
|
loop tabs dt
|
|
|
| Guard(e,dt1,dt2) ->
|
|
@@ -738,10 +756,10 @@ module Decision_tree = struct
|
|
|
case1 == case2
|
|
|
| Switch(subject1,cases1,dt1),Switch(subject2,cases2,dt2) ->
|
|
|
Texpr.equal subject1 subject2 &&
|
|
|
- safe_for_all2 (fun (con1,b1,dt1) (con2,b2,dt2) -> Constructor.equal con1 con2 && b1 = b2 && equal_dt dt1 dt2) cases1 cases2 &&
|
|
|
+ safe_for_all2 (fun sc1 sc2 -> Constructor.equal sc1.sc_con sc2.sc_con && sc1.sc_unguarded = sc2.sc_unguarded && equal_dt sc1.sc_dt sc2.sc_dt) cases1 cases2 &&
|
|
|
equal_dt dt1 dt2
|
|
|
| Bind(l1,dt1),Bind(l2,dt2) ->
|
|
|
- safe_for_all2 (fun (v1,_,e1) (v2,_,e2) -> v1 == v2 && Texpr.equal e1 e2) l1 l2 &&
|
|
|
+ safe_for_all2 (fun bind1 bind2 -> bind1.b_var == bind2.b_var && Texpr.equal bind1.b_expr bind2.b_expr) l1 l2 &&
|
|
|
equal_dt dt1 dt2
|
|
|
| Fail,Fail ->
|
|
|
true
|
|
@@ -1022,11 +1040,11 @@ module Compile = struct
|
|
|
| (PatConstructor(con',patterns1),_) :: patterns2 when Constructor.equal con con' ->
|
|
|
Some (case,bindings,patterns1 @ patterns2)
|
|
|
| (PatVariable v,p) :: patterns2 ->
|
|
|
- Some (case,(v,p,subject) :: bindings,ExtList.List.make arity (PatAny,p) @ patterns2)
|
|
|
+ Some (case,(make_bind v p subject) :: bindings,ExtList.List.make arity (PatAny,p) @ patterns2)
|
|
|
| (PatAny,_) as pat :: patterns2 ->
|
|
|
Some (case,bindings,ExtList.List.make arity pat @ patterns2)
|
|
|
| (PatBind(v,pat1),p) :: patterns ->
|
|
|
- specialize (case,(v,p,subject) :: bindings,pat1 :: patterns)
|
|
|
+ specialize (case,(make_bind v p subject) :: bindings,pat1 :: patterns)
|
|
|
| _ ->
|
|
|
None
|
|
|
in
|
|
@@ -1035,11 +1053,11 @@ module Compile = struct
|
|
|
let default subject cases =
|
|
|
let rec default (case,bindings,patterns) = match patterns with
|
|
|
| (PatVariable v,p) :: patterns ->
|
|
|
- Some (case,((v,p,subject) :: bindings),patterns)
|
|
|
+ Some (case,((make_bind v p subject) :: bindings),patterns)
|
|
|
| (PatAny,_) :: patterns ->
|
|
|
Some (case,bindings,patterns)
|
|
|
| (PatBind(v,pat1),p) :: patterns ->
|
|
|
- default (case,((v,p,subject) :: bindings),pat1 :: patterns)
|
|
|
+ default (case,((make_bind v p subject) :: bindings),pat1 :: patterns)
|
|
|
| _ ->
|
|
|
None
|
|
|
in
|
|
@@ -1069,7 +1087,7 @@ module Compile = struct
|
|
|
String.concat " " (List.map s_expr_pretty subjects)
|
|
|
|
|
|
let s_case (case,bindings,patterns) =
|
|
|
- let s_bindings = String.concat ", " (List.map (fun (v,_,e) -> Printf.sprintf "%s<%i> = %s" v.v_name v.v_id (s_expr_pretty e)) bindings) in
|
|
|
+ let s_bindings = String.concat ", " (List.map (fun bind -> Printf.sprintf "%s<%i> = %s" bind.b_var.v_name bind.b_var.v_id (s_expr_pretty bind.b_expr)) bindings) in
|
|
|
let s_patterns = String.concat " " (List.map Pattern.to_string patterns) in
|
|
|
let s_expr = match case.case_expr with None -> "" | Some e -> Type.s_expr_pretty false "\t\t" false s_type e in
|
|
|
let s_guard = match case.case_guard with None -> "" | Some e -> Type.s_expr_pretty false "\t\t" false s_type e in
|
|
@@ -1134,9 +1152,9 @@ module Compile = struct
|
|
|
| [PatAny,_],_ ->
|
|
|
bindings
|
|
|
| (PatVariable v,p) :: patterns,e :: el ->
|
|
|
- loop patterns el ((v,p,e) :: bindings)
|
|
|
+ loop patterns el ((make_bind v p e) :: bindings)
|
|
|
| (PatBind(v,pat1),p) :: patterns,e :: el ->
|
|
|
- loop (pat1 :: patterns) (e :: el) ((v,p,e) :: bindings)
|
|
|
+ loop (pat1 :: patterns) (e :: el) ((make_bind v p e) :: bindings)
|
|
|
| _ :: patterns,_ :: el ->
|
|
|
loop patterns el bindings
|
|
|
| [],[] ->
|
|
@@ -1166,7 +1184,7 @@ module Compile = struct
|
|
|
if case.case_guard = None then ConTable.replace unguarded con true;
|
|
|
let arg_positions = snd (List.split patterns) in
|
|
|
ConTable.replace sigma con arg_positions;
|
|
|
- | PatBind(v,pat1) -> loop ((v,pos pat,subject) :: bindings) pat1
|
|
|
+ | PatBind(v,pat1) -> loop ((make_bind v (pos pat) subject) :: bindings) pat1
|
|
|
| PatVariable _ | PatAny -> ()
|
|
|
| PatExtractor _ -> raise Extractor
|
|
|
| _ -> typing_error ("Unexpected pattern: " ^ (Pattern.to_string pat)) case.case_pos;
|
|
@@ -1184,7 +1202,7 @@ module Compile = struct
|
|
|
let rec loop bindings locals sub_subjects = match sub_subjects with
|
|
|
| (name,e) :: sub_subjects ->
|
|
|
let v = add_local mctx.ctx VGenerated (Printf.sprintf "%s%s" gen_local_prefix name) e.etype e.epos in
|
|
|
- loop ((v,v.v_pos,e) :: bindings) ((mk (TLocal v) v.v_type v.v_pos) :: locals) sub_subjects
|
|
|
+ loop ((make_bind v v.v_pos e) :: bindings) ((mk (TLocal v) v.v_type v.v_pos) :: locals) sub_subjects
|
|
|
| [] ->
|
|
|
List.rev bindings,List.rev locals
|
|
|
in
|
|
@@ -1193,7 +1211,11 @@ module Compile = struct
|
|
|
let spec = specialize subject con cases in
|
|
|
let dt = compile mctx subjects spec in
|
|
|
let dt = bind mctx bindings dt in
|
|
|
- con,unguarded,dt
|
|
|
+ {
|
|
|
+ sc_con = con;
|
|
|
+ sc_unguarded = unguarded;
|
|
|
+ sc_dt = dt;
|
|
|
+ }
|
|
|
) sigma in
|
|
|
let default = default subject cases in
|
|
|
let switch_default = compile mctx subjects default in
|
|
@@ -1217,7 +1239,7 @@ module Compile = struct
|
|
|
let num_extractors,extractors = List.fold_left (fun (i,extractors) (_,_,patterns) ->
|
|
|
let rec loop bindings pat = match pat with
|
|
|
| (PatExtractor(v,e1,pat),_) -> i + 1,Some (v,e1,pat,bindings) :: extractors
|
|
|
- | (PatBind(v,pat1),p) -> loop ((v,p,subject) :: bindings) pat1
|
|
|
+ | (PatBind(v,pat1),p) -> loop ((make_bind v p subject) :: bindings) pat1
|
|
|
| _ -> i,None :: extractors
|
|
|
in
|
|
|
loop [] (List.hd patterns)
|
|
@@ -1250,7 +1272,7 @@ module Compile = struct
|
|
|
die "" __LOC__
|
|
|
) (0,num_extractors,[],[],[]) cases (List.rev extractors) in
|
|
|
let dt = compile mctx ((subject :: List.rev ex_subjects) @ subjects) (List.rev cases) in
|
|
|
- let bindings = List.map (fun (a,b,c,_,_) -> (a,b,c)) bindings in
|
|
|
+ let bindings = List.map (fun (a,b,c,_,_) -> (make_bind a b c)) bindings in
|
|
|
bind mctx bindings dt
|
|
|
|
|
|
let compile ctx match_debug subjects cases p =
|
|
@@ -1271,7 +1293,7 @@ module Compile = struct
|
|
|
| _ ->
|
|
|
let v = gen_local ctx e.etype e.epos in
|
|
|
let ev = mk (TLocal v) e.etype e.epos in
|
|
|
- (ev :: subjects,(v,e.epos,e) :: vars)
|
|
|
+ (ev :: subjects,(make_bind v e.epos e) :: vars)
|
|
|
in
|
|
|
loop (subjects,vars) el
|
|
|
in
|
|
@@ -1358,12 +1380,12 @@ module TexprConverter = struct
|
|
|
let all_ctors ctx e cases =
|
|
|
let infer_type() = match cases with
|
|
|
| [] -> e,e.etype,false
|
|
|
- | (con,_,_) :: _ ->
|
|
|
+ | sc :: _ ->
|
|
|
let fail() =
|
|
|
(* error "Could not determine switch kind, make sure the type is known" e.epos; *)
|
|
|
t_dynamic
|
|
|
in
|
|
|
- let t = match fst con with
|
|
|
+ let t = match fst sc.sc_con with
|
|
|
| ConEnum(en,_) -> TEnum(en,extract_param_types en.e_params)
|
|
|
| ConArray _ -> ctx.t.tarray t_dynamic
|
|
|
| ConConst ct ->
|
|
@@ -1430,9 +1452,9 @@ module TexprConverter = struct
|
|
|
| ConArray _ -> kind = SKLength
|
|
|
| _ -> kind = SKValue
|
|
|
in
|
|
|
- List.iter (fun (con,unguarded,dt) ->
|
|
|
- if not (compatible_kind con) then typing_error "Incompatible pattern" dt.dt_pos;
|
|
|
- if unguarded then ConTable.remove h con
|
|
|
+ List.iter (fun sc ->
|
|
|
+ if not (compatible_kind sc.sc_con) then typing_error "Incompatible pattern" sc.sc_dt.dt_pos;
|
|
|
+ if sc.sc_unguarded then ConTable.remove h sc.sc_con
|
|
|
) cases;
|
|
|
let unmatched = ConTable.fold (fun con _ acc -> con :: acc) h [] in
|
|
|
e,unmatched,kind,finiteness
|
|
@@ -1499,8 +1521,8 @@ module TexprConverter = struct
|
|
|
| Some e -> Some e
|
|
|
| None -> Some (mk (TBlock []) ctx.t.tvoid case.case_pos)
|
|
|
end
|
|
|
- | Switch(_,[(ConFields _,_),_,dt],_) -> (* TODO: Can we improve this by making it more general? *)
|
|
|
- loop dt_rec params dt
|
|
|
+ | Switch(_,[{sc_con = (ConFields _,_)} as sc],_) -> (* TODO: Can we improve this by making it more general? *)
|
|
|
+ loop dt_rec params sc.sc_dt
|
|
|
| Switch(e_subject,cases,default) ->
|
|
|
let dt_rec',toplevel = match dt_rec with
|
|
|
| Toplevel -> AfterSwitch,true
|
|
@@ -1519,8 +1541,8 @@ module TexprConverter = struct
|
|
|
| Some e ->
|
|
|
Some e
|
|
|
in
|
|
|
- let cases = ExtList.List.filter_map (fun (con,_,dt) -> match unify_constructor ctx params e_subject.etype con with
|
|
|
- | Some(_,params) -> Some (con,dt,params)
|
|
|
+ let cases = ExtList.List.filter_map (fun sc -> match unify_constructor ctx params e_subject.etype sc.sc_con with
|
|
|
+ | Some(_,params) -> Some (sc.sc_con,sc.sc_dt,params)
|
|
|
| None -> None
|
|
|
) cases in
|
|
|
let group cases =
|
|
@@ -1643,9 +1665,9 @@ module TexprConverter = struct
|
|
|
end
|
|
|
end
|
|
|
| Bind(bl,dt) ->
|
|
|
- let el = List.map (fun (v,p,e) ->
|
|
|
- v_lookup := IntMap.add v.v_id e !v_lookup;
|
|
|
- mk (TVar(v,Some e)) com.basic.tvoid p
|
|
|
+ let el = List.map (fun bind ->
|
|
|
+ v_lookup := IntMap.add bind.b_var.v_id bind.b_expr !v_lookup;
|
|
|
+ mk (TVar(bind.b_var,Some bind.b_expr)) com.basic.tvoid p
|
|
|
) bl in
|
|
|
let e = loop dt_rec params dt in
|
|
|
Option.map (fun e -> mk (TBlock (el @ [e])) e.etype dt.dt_pos) e;
|