| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651 |
(* class of an operation; it only shows up as
 * the w, l, s or d suffix of the printed name *)
type cls =
  | Kw
  | Kl
  | Ks
  | Kd

(* the binary operations a pattern can use *)
type op_base = Oadd | Osub | Omul | Oor | Oshl | Oshr

(* an operation is a base operation tagged with
 * its class *)
type op = cls * op_base
(* all the base operations, listed in the order
 * of their declaration *)
let op_bases =
  Oadd :: Osub :: Omul :: Oor :: Oshl :: [Oshr]
(* true when the two operands of the operation can
 * be swapped; the class plays no role *)
let commutative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | _ -> false
(* true when nested uses of the operation can be
 * regrouped; the class plays no role *)
let associative o =
  match snd o with
  | Oadd | Omul | Oor -> true
  | _ -> false
(* leaves of a pattern; with the polymorphic
 * comparison the constructors are ordered as
 * Tmp < AnyCon < Con k *)
type atomic_pattern = Tmp | AnyCon | Con of int64
(* a pattern is a binary tree of operations; its
 * leaves are atomic patterns, and a Var leaf also
 * carries a name *)
type pattern =
    Bnr of op * pattern * pattern
  | Atm of atomic_pattern
  | Var of string * atomic_pattern
(* true for the leaves (Atm and Var), false for
 * binary nodes *)
let is_atomic p =
  match p with
  | Bnr _ -> false
  | Atm _ | Var _ -> true
(* mnemonic of a base operation, without any
 * class suffix *)
let show_op_base = function
  | Oshr -> "shr"
  | Oshl -> "shl"
  | Oor -> "or"
  | Omul -> "mul"
  | Osub -> "sub"
  | Oadd -> "add"
(* mnemonic of an operation: the base name
 * followed by the one-letter class suffix *)
let show_op (k, o) =
  let suffix =
    match k with
    | Kw -> "w"
    | Kl -> "l"
    | Ks -> "s"
    | Kd -> "d"
  in
  show_op_base o ^ suffix
(* textual form of a pattern: % for Tmp, $ for
 * AnyCon, the number for a constant, atom'name
 * for a variable and (op left right) for a
 * binary node *)
let rec show_pattern = function
  | Bnr (o, pl, pr) ->
    String.concat ""
      [ "("; show_op o
      ; " "; show_pattern pl
      ; " "; show_pattern pr
      ; ")" ]
  | Var (v, a) -> show_pattern (Atm a) ^ "'" ^ v
  | Atm (Con n) -> Int64.to_string n
  | Atm AnyCon -> "$"
  | Atm Tmp -> "%"
(* the atomic pattern held by a leaf, or None for
 * a binary node *)
let get_atomic = function
  | Atm a -> Some a
  | Var (_, a) -> Some a
  | Bnr _ -> None
(* [pattern_match p w] tells if the pattern w is
 * accepted by the pattern p:
 * - a Var in p behaves like its atomic pattern;
 * - Tmp accepts everything but the constant
 *   leaves (Con and AnyCon);
 * - AnyCon accepts exactly what Tmp rejects;
 * - Con k accepts only a w structurally equal to p;
 * - a binary node accepts a binary node with the
 *   same operation whose arms are accepted *)
let rec pattern_match p w =
  match p, w with
  | Var (_, a), _ ->
    pattern_match (Atm a) w
  | Atm Tmp, _ ->
    (match get_atomic w with
     | Some (Con _ | AnyCon) -> false
     | Some Tmp | None -> true)
  | Atm AnyCon, _ ->
    not (pattern_match (Atm Tmp) w)
  | Atm (Con _), _ ->
    w = p
  | Bnr (o, pl, pr), Bnr (o', wl, wr) ->
    o' = o
    && pattern_match pl wl
    && pattern_match pr wr
  | Bnr _, (Atm _ | Var _) ->
    false
(* a position inside a pattern, described from the
 * position up to the root:
 * - Bnrl (o, c, r): in the left arm of a node o
 *   whose right arm is r and whose own position
 *   is c;
 * - Bnrr (o, l, c): in the right arm of a node o
 *   whose left arm is l and whose own position
 *   is c;
 * - Top x: at the root, x being a payload *)
type +'a cursor =
    Bnrl of op * 'a cursor * pattern
  | Bnrr of op * pattern * 'a cursor
  | Top of 'a
(* rebuild the whole pattern obtained by plugging
 * p at the position c *)
let rec fold_cursor c p =
  match c with
  | Top _ -> p
  | Bnrl (o, up, right) ->
    fold_cursor up (Bnr (o, p, right))
  | Bnrr (o, left, up) ->
    fold_cursor up (Bnr (o, left, p))
(* list the leaves of p, each one paired with the
 * cursor of its position; the cursors are rooted
 * at Top x and the Var leaves are returned as
 * plain Atm patterns *)
let peel p x =
  (* replace a binary node by its two arms, turn
   * a Var into an Atm, and keep an Atm as is *)
  let step acc (pat, cur) =
    match pat with
    | Bnr (o, pl, pr) ->
      (pl, Bnrl (o, cur, pr)) ::
      (pr, Bnrr (o, pl, cur)) :: acc
    | Var (_, a) -> (Atm a, cur) :: acc
    | Atm _ -> (pat, cur) :: acc
  in
  (* the list grows as long as binary nodes
   * remain; stop when a pass keeps its length *)
  let rec expand items =
    let items' = List.fold_left step [] items in
    if List.length items' <> List.length items
    then expand items'
    else items'
  in
  expand [(p, Top x)]
(* fold f over all the pairs (a, b) with a in l1
 * and b in l2, starting from ini; l1 is the outer
 * loop and l2 the inner one *)
let fold_pairs l1 l2 ini f =
  List.fold_left
    (fun acc a ->
      List.fold_left
        (fun acc b -> f (a, b) acc)
        acc l2)
    ini l1
(* call f on every pair of elements of l, the
 * first component being the outer loop *)
let iter_pairs l f =
  List.iter
    (fun a -> List.iter (fun b -> f (a, b)) l)
    l
(* swap the components of every pair of l *)
let inverse l =
  let swap (a, b) = (b, a) in
  List.map swap l
(* a state of the matching automaton:
 * - id is the number of the state; states are
 *   built with -1 and numbered by StateSet.add;
 * - seen is a pattern attached to the state
 *   (next_binary builds it with Bnr from the
 *   seen patterns of the two operand states);
 * - point is the list of cursors of the state *)
type 'a state =
  { id: int
  ; seen: pattern
  ; point: 'a cursor list }
(* select the cursors of a state that sit in one
 * arm of a binary node: with `L the ones in a
 * left arm, with `R the ones in a right arm;
 * each result is ((op, cursor of the node), p)
 * where p is the pattern of the other arm.
 * The function is not recursive, so it is bound
 * with a plain let *)
let binops side {point; _} =
  List.filter_map
    (fun cur ->
      match cur, side with
      | Bnrl (o, up, right), `L -> Some ((o, up), right)
      | Bnrr (o, left, up), `R -> Some ((o, up), left)
      | _ -> None)
    point
(* group the second components of the pairs of l
 * by first component; the groups come out in
 * decreasing key order and, inside a group, the
 * values are in the reverse of their input
 * order *)
let group_by_fst l =
  let by_key (a, _) (b, _) = compare a b in
  (* key is the key of the group being built and
   * vals the values gathered for it so far *)
  let rec collect key vals groups = function
    | [] -> (key, vals) :: groups
    | (k, v) :: rest ->
      if key = k
      then collect key (v :: vals) groups rest
      else collect k [v] ((key, vals) :: groups) rest
  in
  match List.fast_sort by_key l with
  | [] -> []
  | (k, v) :: rest -> collect k [v] [] rest
(* sort l with cmp and drop the duplicates; among
 * elements that cmp deems equal, the first one of
 * the stable sort is kept *)
let sort_uniq cmp l =
  let rec dedup kept = function
    | [] -> List.rev kept
    | e :: rest ->
      (match kept with
       | last :: _ when cmp last e = 0 -> dedup kept rest
       | _ -> dedup (e :: kept) rest)
  in
  dedup [] (List.fast_sort cmp l)
(* sorted list of the distinct elements of l, using
 * the polymorphic comparison *)
let setify l =
  l |> sort_uniq compare
(* canonical form of a list of cursors: sorted and
 * without duplicates *)
let normalize (point: 'a cursor list) =
  point |> setify
(* states reached by combining s1 (left operand)
 * and s2 (right operand) with a binary operation;
 * the result has one (op, state) pair per
 * operation, each state being unnumbered
 * (id = -1); the cursors in tmp are added to the
 * point of every new state *)
let next_binary tmp s1 s2 =
  (* cursors of s in the given arm whose other arm
   * is a pattern that accepts what the sibling
   * state has seen, moved up to the binary node *)
  let select side s sibling =
    binops side s
    |> List.filter_map (fun (oc, other) ->
         if pattern_match other sibling.seen
         then Some oc
         else None)
  in
  let from_left = select `L s1 s2 in
  let from_right = select `R s2 s1 in
  group_by_fst (from_left @ from_right)
  |> List.map (fun (o, cursors) ->
       ( o
       , { id = -1
         ; seen = Bnr (o, s1.seen, s2.seen)
         ; point = normalize (cursors @ tmp) } ))
(* payload of the cursors used by the states
 * below *)
type p = string
(* a mutable set of states; two states are the
 * same when their points are equal, and every
 * state receives a number when it first enters
 * the set *)
module StateSet : sig
  type t
  val create: unit -> t
  val add: t -> p state ->
           [> `Added | `Found ] * p state
  val iter: t -> (p state -> unit) -> unit
  val elems: t -> (p state) list
end = struct
  (* hash table keyed by the point of the states
   * and holding their numbers *)
  module H = Hashtbl.Make(struct
    type t = p state
    let equal s1 s2 = s1.point = s2.point
    let hash s = Hashtbl.hash s.point
  end)

  type t =
    { h: int H.t
    ; mutable next_id: int }

  let create () =
    { h = H.create 500; next_id = 0 }

  (* return the state with its number in the set,
   * tagged `Found when a state with the same point
   * was already there and `Added otherwise; the
   * point must be normalized *)
  let add set s =
    assert (s.point = normalize s.point);
    match H.find_opt set.h s with
    | Some id -> `Found, {s with id}
    | None ->
      let id = set.next_id in
      set.next_id <- id + 1;
      H.add set.h s id;
      `Added, {s with id}

  (* call f on every state, with its id field set
   * to its number in the set *)
  let iter set f =
    H.iter (fun s id -> f {s with id}) set.h

  (* all the states of the set, numbered *)
  let elems set =
    H.fold (fun s id acc -> {s with id} :: acc) set.h []
end
(* key of the transition table: an operation with
 * the states of its left and right operands *)
type table_key = K of op * p state * p state
(* maps indexed by table keys; the states of a key
 * are only compared through their ids *)
module StateMap = struct
  include Map.Make(struct
    type t = table_key
    let compare (K (o, sl, sr)) (K (o', sl', sr')) =
      compare (o, sl.id, sr.id) (o', sl'.id, sr'.id)
  end)

  (* reverse map: an array of size n that gives,
   * for each target state id, the pairs of operand
   * state ids leading to it, grouped by
   * operation *)
  let invert n sm =
    let rmap = Array.make n [] in
    iter
      (fun (K (o, sl, sr)) {id; _} ->
        rmap.(id) <- (o, (sl.id, sr.id)) :: rmap.(id))
      sm;
    Array.map group_by_fst rmap

  (* the transitions grouped by operation; each one
   * is ((left id, right id), target id) *)
  let by_ops sm =
    fold
      (fun (K (op, l, r)) s ops ->
        (op, ((l.id, r.id), s.id)) :: ops)
      sm []
    |> group_by_fst
end
(* a named pattern together with a list of
 * variable names *)
type rule =
  { name: string
  ; vars: string list
  ; pattern: pattern }
(* build the matching automaton of the rules rl;
 * the result is (states, atoms, map) where states
 * is the array of the states sorted by id, atoms
 * associates atomic patterns with their state and
 * map is the transition table; may raise if
 * rl contains rules that next_binary cannot
 * handle (see the helpers) *)
let generate_table rl =
  let states = StateSet.create () in
  (* these atomic patterns must occur in
   * rules so that we are able to number
   * all possible refs *)
  let rl =
    { name = "$"; vars = []; pattern = Atm AnyCon }
    :: { name = "%"; vars = []; pattern = Atm Tmp }
    :: rl
  in
  (* initial states: one per distinct leaf of the
   * rules *)
  let ground =
    rl
    |> List.concat_map (fun r -> peel r.pattern r.name)
    |> group_by_fst
  in
  let tmp = List.assoc (Atm Tmp) ground in
  let con = List.assoc (Atm AnyCon) ground in
  let atoms = ref [] in
  List.iter
    (fun (seen, l) ->
      (* a leaf accepted by Tmp also sits at the
       * Tmp cursors, any other one at the AnyCon
       * cursors *)
      let base =
        if pattern_match (Atm Tmp) seen then tmp else con
      in
      let s = {id = -1; seen; point = normalize (base @ l)} in
      let _, s = StateSet.add states s in
      match get_atomic seen with
      | Some atm -> atoms := (atm, s) :: !atoms
      | None -> ())
    ground;
  (* combine pairs of known states until a whole
   * round adds no new state *)
  let map = ref StateMap.empty in
  let changed = ref true in
  while !changed do
    changed := false;
    let statel = StateSet.elems states in
    iter_pairs statel (fun (sl, sr) ->
      next_binary tmp sl sr
      |> List.iter (fun (o, s') ->
           let status, s' = StateSet.add states s' in
           (match status with
            | `Added -> changed := true
            | `Found -> ());
           map := StateMap.add (K (o, sl, sr)) s' !map))
  done;
  let by_id s s' = compare s.id s'.id in
  let states =
    StateSet.elems states
    |> List.sort by_id
    |> Array.of_list
  in
  (states, !atoms, !map)
(* all the lists obtained by inserting x at one
 * position of l; the insertion at the end comes
 * first and the one at the front comes last *)
let intersperse x l =
  let rec go before after acc =
    (* before holds the prefix in reverse order *)
    let acc = List.rev_append before (x :: after) :: acc in
    match after with
    | [] -> acc
    | y :: after' -> go (y :: before) after' acc
  in
  go [] l []
(* all the permutations of a list *)
let rec permute = function
  | [] -> [[]]
  | x :: rest ->
    List.concat_map (intersperse x) (permute rest)
(* build all the binary trees whose leaves are the
 * elements of l, in that order; build makes a
 * node out of two subtrees *)
let rec bins build l =
  (* left and right are the two sides of a split
   * of the list; every split with two non-empty
   * sides contributes its trees to acc *)
  let rec split left right acc =
    match right with
    | [] -> acc
    | x :: right' ->
      let acc =
        fold_pairs
          (bins build left)
          (bins build right)
          acc
          (fun (tl, tr) acc -> build tl tr :: acc)
      in
      split (left @ [x]) right' acc
  in
  match l with
  | [] -> []
  | [x] -> [x]
  | x :: rest -> split [x] rest []
(* fold f, starting from ini, over every way of
 * choosing one element in each list of l; a
 * choice is given to f as a list that follows the
 * order of l *)
let products l ini f =
  (* picked holds the current choice, reversed *)
  let rec walk picked acc = function
    | [] -> f (List.rev picked) acc
    | choices :: rest ->
      List.fold_left
        (fun acc c -> walk (c :: picked) acc rest)
        acc choices
  in
  walk [] ini l
(* combinatorial nuke: list the patterns obtained
 * from a pattern by regrouping the operands of
 * its associative operations and by swapping the
 * operands of its commutative ones *)
let rec ac_equiv p =
  (* operands of the maximal chain of nodes o at
   * the top of a pattern, from left to right *)
  let rec alevel o = function
    | Bnr (o', l, r) when o' = o ->
      alevel o l @ alevel o r
    | x -> [x]
  in
  match p with
  | Bnr (o, _, _) when associative o ->
    let node l r = Bnr (o, l, r) in
    products
      (List.map ac_equiv (alevel o p)) []
      (fun choice out ->
        let orders =
          if commutative o
          then permute choice
          else [choice]
        in
        List.concat_map (bins node) orders @ out)
  | Bnr (o, l, r) ->
    let swap = commutative o in
    fold_pairs
      (ac_equiv l) (ac_equiv r) []
      (fun (l, r) out ->
        if swap
        then Bnr (o, l, r) :: Bnr (o, r, l) :: out
        else Bnr (o, l, r) :: out)
  | (Atm _ | Var _) as x -> [x]
(* hash-consed actions of a matcher; an action can
 * only be built with the mk_ functions, and two
 * structurally equal nodes get the same id *)
module Action: sig
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t = private
    { id: int; node: node }
  val equal: t -> t -> bool
  val size: t -> int
  val stop: t
  val mk_push: sym:bool -> t -> t
  val mk_pop: t -> t
  val mk_set: string -> t -> t
  val mk_switch: int list -> (int -> t) -> t
  val pp: Format.formatter -> t -> unit
end = struct
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t =
    { id: int; node: node }

  (* actions are equal when their ids are *)
  let equal a a' = a.id = a'.id

  (* number of distinct actions reachable from a,
   * a itself included *)
  let size a =
    let seen = Hashtbl.create 10 in
    let rec visit {id; node} =
      if Hashtbl.mem seen id then 0
      else begin
        Hashtbl.add seen id ();
        1 + children node
      end
    and children = function
      | Stop -> 0
      | Push (_, a) | Pop a | Set (_, a) -> visit a
      | Switch l ->
        List.fold_left
          (fun n (_, a) -> n + visit a) 0 l
    in
    visit a

  (* the only constructor of actions: a node seen
   * before gets its old id, a new one gets the
   * next free id *)
  let mk =
    let hcons = Hashtbl.create 100 in
    let fresh = ref 0 in
    fun node ->
      match Hashtbl.find_opt hcons node with
      | Some id -> {id; node}
      | None ->
        let id = !fresh in
        Hashtbl.add hcons node id;
        incr fresh;
        {id; node}

  let stop = mk Stop

  let mk_push ~sym a = mk (Push (sym, a))

  (* popping before a stop is useless *)
  let mk_pop a =
    match a.node with
    | Stop -> a
    | Switch _ | Push _ | Pop _ | Set _ -> mk (Pop a)

  let mk_set v a = mk (Set (v, a))

  (* switch over ids, the case of an id being given
   * by f; when all the cases are equal no switch
   * is needed and the common case is returned;
   * fails on an empty list of ids *)
  let mk_switch ids f =
    match List.map f ids with
    | [] -> failwith "empty switch"
    | first :: others as cases ->
      if List.for_all (equal first) others
      then first
      else mk (Switch (List.combine ids cases))

  (* the cases of a switch that lead to the same
   * action are printed together *)
  let rec pp_node fmt = function
    | Stop -> Format.fprintf fmt "•"
    | Set (v, a) -> Format.fprintf fmt "set(%s)@ %a" v pp a
    | Pop a -> Format.fprintf fmt "pop@ %a" pp a
    | Push (sym, a) ->
      if sym
      then Format.fprintf fmt "pushsym@ %a" pp a
      else Format.fprintf fmt "push@ %a" pp a
    | Switch l ->
      let pp_sep fmt () = Format.fprintf fmt "," in
      let pp_ids =
        Format.pp_print_list ~pp_sep Format.pp_print_int
      in
      let pp_case (ids, a) =
        Format.fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
          pp_ids ids pp a
      in
      Format.fprintf fmt "@[<v>@[<v2>switch{";
      inverse l |> group_by_fst |> inverse
      |> List.iter pp_case;
      Format.fprintf fmt "@]@,}@]"

  and pp fmt a = pp_node fmt a.node
end
(* a state is commutative if (a op b) enters
 * it iff (b op a) enters it as well; rmap is the
 * reverse transition map and id the state *)
let symmetric rmap id =
  let balanced (_, pairs) =
    (* the pairs (a, a) are their own mirror; the
     * other ones are split according to the order
     * of their components *)
    let below, above =
      pairs
      |> List.filter (fun (a, b) -> a <> b)
      |> List.partition (fun (a, b) -> a < b)
    in
    setify below = setify (inverse above)
  in
  List.for_all balanced rmap.(id)
(* left-to-right matching of a set of patterns;
 * may raise if there is no lr matcher for the
 * input rule; the patterns are the ones of the
 * rules called name *)
let lr_matcher statemap states rules name =
  let rmap =
    StateMap.invert (Array.length states) statemap
  in
  let exception Stuck in
  (* the list of ids represents a class of terms
   * whose root ends up being labelled with one
   * such id; the gen function generates a matcher
   * that will, given any such term, assign values
   * for the Var nodes of one pattern in pats *)
  let rec gen
    : 'a. int list -> (pattern * 'a) list
      -> (int -> (pattern * 'a) list -> Action.t)
      -> Action.t
    = fun ids pats k ->
      Action.mk_switch (setify ids) (fun id_top ->
        let sym = symmetric rmap id_top in
        (* for a symmetric state, one pair of
         * operand ids out of (a, b) and (b, a) is
         * enough *)
        let id_ops =
          if not sym then rmap.(id_top)
          else
            List.map
              (fun (o, l) ->
                (o, List.filter (fun (a, b) -> a <= b) l))
              rmap.(id_top)
        in
        (* consider only the patterns that are
         * compatible with the current id *)
        let compatible = function
          | Bnr (o, _, _), _ ->
            List.exists (fun (o', _) -> o' = o) id_ops
          | _ -> true
        in
        let atm_pats, bin_pats =
          pats
          |> List.filter compatible
          |> List.partition (fun (pat, _) -> is_atomic pat)
        in
        try
          if bin_pats = [] then raise Stuck;
          (* the payload travels with the arms that
           * are not being matched *)
          let pats_l =
            List.map
              (function
                | Bnr (o, l, r), x -> (l, (o, x, r))
                | _ -> assert false)
              bin_pats
          in
          let pats_r pats =
            List.map
              (fun (l, (o, x, r)) -> (r, (o, l, x)))
              pats
          in
          let patstop pats =
            List.map
              (fun (r, (o, l, x)) -> (Bnr (o, l, r), x))
              pats
          in
          let id_pairs = List.concat_map snd id_ops in
          let ids_l = List.map fst id_pairs in
          let ids_r id_left =
            List.filter_map
              (fun (l, r) ->
                if l = id_left then Some r else None)
              id_pairs
          in
          (* match the left arm *)
          Action.mk_push ~sym
            (gen ids_l pats_l (fun lid pats ->
               (* then the right arm, considering
                * only the remaining possible
                * patterns and knowing that the
                * left arm was numbered 'lid' *)
               Action.mk_pop
                 (gen (ids_r lid) (pats_r pats)
                    (fun _rid pats ->
                      (* continue with the parent *)
                      k id_top (patstop pats)))))
        with Stuck ->
          (* no binary pattern works: fall back on
           * the atomic ones that accept what the
           * state has seen *)
          let seen = states.(id_top).seen in
          let atm_pats =
            List.filter
              (fun (pat, _) -> pattern_match pat seen)
              atm_pats
          in
          if atm_pats = [] then raise Stuck;
          let vars =
            atm_pats
            |> List.filter_map (function
                 | (Var (v, _), _) -> Some v
                 | _ -> None)
            |> setify
          in
          (match vars with
           | [] -> k id_top atm_pats
           | [v] -> Action.mk_set v (k id_top atm_pats)
           | _ -> failwith "ambiguous var match"))
  in
  (* generate a matcher for the rule: start from
   * the states that have a cursor at the top of
   * the rule *)
  let ids_top =
    Array.to_list states
    |> List.filter_map (fun {id; point = p; _} ->
         if List.exists ((=) (Top name)) p
         then Some id
         else None)
  in
  (* drop a pattern when a later one accepts it *)
  let rec filter_dups = function
    | [] -> []
    | p :: rest ->
      if List.exists (pattern_match p) rest
      then filter_dups rest
      else p :: filter_dups rest
  in
  let pats_top =
    rules
    |> List.filter_map (fun r ->
         if r.name = name then Some r.pattern else None)
    |> filter_dups
    |> List.map (fun p -> (p, ()))
  in
  gen ids_top pats_top (fun _ pats ->
    assert (pats <> []);
    Action.stop)
(* what is needed to number terms: the states of
 * the atomic patterns, the transition table and
 * the array of the states *)
type numberer =
  { atoms: (atomic_pattern * p state) list
  ; statemap: p state StateMap.t
  ; states: p state array
  ; mutable ops: op list
    (* memoizes the list of possible operations
     * according to the statemap *) }
(* numberer for the state array sa, the atom
 * association list am and the state map sm; the
 * memoized list of operations starts empty *)
let make_numberer sa am sm =
  { ops = []
  ; statemap = sm
  ; states = sa
  ; atoms = am }
(* state associated with an atomic pattern; raises
 * Not_found when the numberer has none for it *)
let atom_state {atoms; _} atm =
  List.assoc atm atoms
|