(* match.ml (16 KB) *)
  (* Operand class of an operation, printed by [show_op] as the
   * suffix "w", "l", "s" or "d".  Keep the constructor order:
   * ops are ordered with polymorphic [compare] when sorting,
   * grouping and building map keys. *)
  type cls =
    | Kw
    | Kl
    | Ks
    | Kd
  (* Base binary operations understood by the matcher; printed
   * by [show_op_base].  Constructor order is significant for
   * polymorphic [compare]. *)
  type op_base = Oadd | Osub | Omul | Oor | Oshl | Oshr
  (* A concrete operation: a base operation at a given class. *)
  type op = (cls * op_base)
  (* Every base operation, in declaration order. *)
  let op_bases : op_base list =
    [ Oadd
    ; Osub
    ; Omul
    ; Oor
    ; Oshl
    ; Oshr ]
  (* [commutative (k, o)] holds for add, mul and or; the class
   * [k] plays no role. *)
  let commutative (_, base) =
    match base with
    | Oadd | Omul | Oor -> true
    | Osub | Oshl | Oshr -> false
  (* [associative (k, o)] holds for add, mul and or; the class
   * [k] plays no role. *)
  let associative (_, base) =
    match base with
    | Oadd | Omul | Oor -> true
    | Osub | Oshl | Oshr -> false
  (* Pattern leaves (see [pattern_match] for their semantics):
   *  - [Tmp]: anything that is not a constant;
   *  - [AnyCon]: any constant;
   *  - [Con k]: exactly the constant [k].
   * The declaration order gives Tmp < AnyCon < Con k under
   * polymorphic [compare]; do not reorder. *)
  type atomic_pattern =
    | Tmp
    | AnyCon
    | Con of int64
  (* Rule patterns:
   *  - [Bnr (op, l, r)]: binary operation node;
   *  - [Atm a]: anonymous leaf;
   *  - [Var (v, a)]: leaf named [v]; the generated matcher
   *    emits a [Set v] action when it matches. *)
  type pattern =
    | Bnr of op * pattern * pattern
    | Atm of atomic_pattern
    | Var of string * atomic_pattern
  (* A pattern is atomic unless it is a binary node. *)
  let is_atomic p =
    match p with
    | Bnr _ -> false
    | Atm _ | Var _ -> true
  (* Lower-case mnemonic of a base operation. *)
  let show_op_base = function
    | Oadd -> "add"
    | Osub -> "sub"
    | Omul -> "mul"
    | Oor -> "or"
    | Oshl -> "shl"
    | Oshr -> "shr"
  (* Mnemonic followed by the class suffix, e.g. "addw". *)
  let show_op (k, o) =
    let suffix =
      match k with
      | Kw -> "w"
      | Kl -> "l"
      | Ks -> "s"
      | Kd -> "d"
    in
    show_op_base o ^ suffix
  (* Render a pattern: "%" for Tmp, "$" for AnyCon, the decimal
   * value for Con, a "'name" suffix for variables and
   * "(op l r)" for binary nodes. *)
  let rec show_pattern p =
    let atom = function
      | Tmp -> "%"
      | AnyCon -> "$"
      | Con n -> Int64.to_string n
    in
    match p with
    | Atm a -> atom a
    | Var (v, a) -> atom a ^ "'" ^ v
    | Bnr (o, pl, pr) ->
      Printf.sprintf "(%s %s %s)"
        (show_op o) (show_pattern pl) (show_pattern pr)
  (* The atomic pattern of a leaf, [None] for a binary node. *)
  let get_atomic = function
    | Bnr _ -> None
    | Atm a | Var (_, a) -> Some a
  (* [pattern_match p w] tells whether pattern [p] matches the
   * pattern/term [w]:
   *  - a [Var] in [p] matches like its atomic pattern;
   *  - [Tmp] matches anything whose atomic part is not a
   *    constant, including binary nodes;
   *  - [AnyCon] matches exactly what [Tmp] rejects;
   *  - [Con k] matches a leaf whose atomic part is [Con k];
   *  - [Bnr] matches a [Bnr] with the same op and matching
   *    operands.
   * A [Var] wrapper on [w] is transparent for every atomic
   * case.  Previously [Con k] used [w = p], so it failed on
   * [Var (_, Con k)] while [Tmp]/[AnyCon] looked through it. *)
  let rec pattern_match p w =
    match p with
    | Var (_, a) -> pattern_match (Atm a) w
    | Atm Tmp ->
      (match get_atomic w with
       | Some (Con _ | AnyCon) -> false
       | Some Tmp | None -> true)
    | Atm (Con _ as c) -> get_atomic w = Some c
    | Atm AnyCon -> not (pattern_match (Atm Tmp) w)
    | Bnr (o, pl, pr) ->
      (match w with
       | Bnr (o', wl, wr) ->
         o' = o && pattern_match pl wl && pattern_match pr wr
       | Atm _ | Var _ -> false)
  (* A position (hole) inside a pattern, zipper style:
   *  - [Bnrl (o, up, r)]: the hole is the left operand of an
   *    [o] node whose right operand is [r]; [up] locates that
   *    node;
   *  - [Bnrr (o, l, up)]: symmetric, for the right operand;
   *  - [Top x]: the root, tagged with [x]. *)
  type +'a cursor =
    | Bnrl of op * 'a cursor * pattern
    | Bnrr of op * pattern * 'a cursor
    | Top of 'a
  (* Plug [p] into the hole of cursor [c] and rebuild the
   * enclosing pattern up to the root. *)
  let rec fold_cursor c p =
    match c with
    | Top _ -> p
    | Bnrl (o, up, right) -> fold_cursor up (Bnr (o, p, right))
    | Bnrr (o, left, up) -> fold_cursor up (Bnr (o, left, p))
  (* Decompose [p] into its leaves: returns each atomic
   * sub-pattern paired with the cursor locating it in [p],
   * whose root is [Top x].  [Var] leaves are returned as their
   * bare [Atm] pattern.  Expansion is repeated until a pass no
   * longer changes the number of entries, i.e. until no binary
   * node remains. *)
  let peel p x =
    let expand acc (q, cur) =
      match q with
      | Atm _ -> (q, cur) :: acc
      | Var (_, a) -> (Atm a, cur) :: acc
      | Bnr (o, ql, qr) ->
        (ql, Bnrl (o, cur, qr)) :: (qr, Bnrr (o, ql, cur)) :: acc
    in
    let rec fix frontier =
      let next = List.fold_left expand [] frontier in
      if List.compare_lengths next frontier = 0 then next
      else fix next
    in
    fix [ (p, Top x) ]
  (* Fold [f] over the cartesian product of [l1] and [l2]
   * starting from [ini]; [l1] is the outer loop and both lists
   * are walked left to right. *)
  let fold_pairs l1 l2 ini f =
    List.fold_left
      (fun acc a ->
        List.fold_left (fun acc b -> f (a, b) acc) acc l2)
      ini l1
  (* Apply [f] to every ordered pair (a, b) of elements of [l],
   * including (a, a); [a] varies slowest. *)
  let iter_pairs l f =
    List.iter (fun a -> List.iter (fun b -> f (a, b)) l) l
  (* Swap the components of every pair, keeping list order. *)
  let inverse l =
    let swap (a, b) = (b, a) in
    List.map swap l
  (* An automaton state:
   *  - [id]: number given by [StateSet.add] (-1 before that);
   *  - [seen]: the pattern that led to this state;
   *  - [point]: normalized list of cursors the state stands
   *    for. *)
  type 'a state = {
    id : int;
    seen : pattern;
    point : 'a cursor list;
  }
  (* [binops side s] collects, for every cursor of [s] that is
   * the [side] operand (`L or `R) of a binary node, the pair
   * ((op, cursor of that node), sibling operand pattern).
   * The former [rec] flag was unused (warning 39). *)
  let binops side { point; _ } =
    List.filter_map
      (fun cur ->
        match cur, side with
        | Bnrl (o, up, r), `L -> Some ((o, up), r)
        | Bnrr (o, l, up), `R -> Some ((o, up), l)
        | _ -> None)
      point
  (* Group pairs by their first component.  Keys come out in
   * descending order; within a group, values appear in reverse
   * of their (stable) sorted order. *)
  let group_by_fst l =
    let sorted =
      List.fast_sort (fun (a, _) (b, _) -> compare a b) l
    in
    let rec collect groups = function
      | [] -> groups
      | (k, v) :: rest ->
        (match groups with
         | (k', vs) :: older when k' = k ->
           collect ((k', v :: vs) :: older) rest
         | _ -> collect ((k, [ v ]) :: groups) rest)
    in
    collect [] sorted
  (* Sort [l] with [cmp] and keep only the first element of each
   * run of [cmp]-equal elements. *)
  let sort_uniq cmp l =
    let rec go kept = function
      | [] -> List.rev kept
      | x :: rest ->
        (match kept with
         | last :: _ when cmp last x = 0 -> go kept rest
         | _ -> go (x :: kept) rest)
    in
    go [] (List.fast_sort cmp l)
  (* Sort with polymorphic [compare] and drop duplicates. *)
  let setify l = l |> sort_uniq compare
  (* Canonical form of a cursor list (sorted, duplicate free), so
   * that equal sets of cursors are structurally equal. *)
  let normalize : 'a cursor list -> 'a cursor list =
    fun point -> setify point
  (* [next_binary tmp s1 s2] lists the (op, state) pairs reached
   * by combining [s1] as left operand with [s2] as right
   * operand.  An op qualifies when a cursor of [s1] is the left
   * operand of an op-node whose right pattern matches
   * [s2.seen], or symmetrically for [s2].  Each new state has
   * [seen = Bnr (op, s1.seen, s2.seen)], a point made of the
   * parent cursors plus [tmp] (normalized), and id -1 until
   * [StateSet.add] numbers it. *)
  let next_binary tmp s1 s2 =
    let fits seen (_, sibling) = pattern_match sibling seen in
    let from_left =
      List.map fst (List.filter (fits s2.seen) (binops `L s1))
    in
    let from_right =
      List.map fst (List.filter (fits s1.seen) (binops `R s2))
    in
    group_by_fst (from_left @ from_right)
    |> List.map (fun (o, parents) ->
         let point = normalize (parents @ tmp) in
         (o, { id = -1; seen = Bnr (o, s1.seen, s2.seen); point }))
  (* Payload of [Top] cursors; [generate_table] stores rule
   * names there. *)
  type p = (string)
  (* Set of states identified by their [point] only.  Adding a
   * state either numbers it with a fresh id (`Added) or returns
   * it carrying the id of the state already stored with the
   * same point (`Found). *)
  module StateSet : sig
    type t
    val create: unit -> t
    val add: t -> p state ->
      [> `Added | `Found ] * p state
    val iter: t -> (p state -> unit) -> unit
    val elems: t -> (p state) list
  end = struct
    module H = Hashtbl.Make (struct
      type t = p state
      let equal a b = a.point = b.point
      let hash s = Hashtbl.hash s.point
    end)

    type t = { h : int H.t; mutable next_id : int }

    let create () = { h = H.create 500; next_id = 0 }

    (* [s.point] must already be normalized *)
    let add set s =
      assert (s.point = normalize s.point);
      match H.find_opt set.h s with
      | Some id -> (`Found, { s with id })
      | None ->
        let id = set.next_id in
        set.next_id <- id + 1;
        H.add set.h s id;
        (`Added, { s with id })

    (* visit every stored state with its id filled in *)
    let iter set f = H.iter (fun s id -> f { s with id }) set.h

    let elems set =
      let out = ref [] in
      iter set (fun s -> out := s :: !out);
      !out
  end
  (* Transition key: operation plus left and right operand
   * states.  [StateMap] compares keys by op and state ids. *)
  type table_key =
    | K of (op * p state * p state)
  (* Map from transition keys to the resulting state.  Keys are
   * ordered by (op, left id, right id), so the states in them
   * must already be numbered. *)
  module StateMap = struct
    include Map.Make (struct
      type t = table_key
      let compare (K (o, l, r)) (K (o', l', r')) =
        compare (o, l.id, r.id) (o', l'.id, r'.id)
    end)

    (* [invert n sm]: for each state id below [n], the
     * transitions entering it, as (op, [(left id, right id);
     * ...]) groups built with [group_by_fst]. *)
    let invert n sm =
      let rmap =
        fold
          (fun (K (o, l, r)) { id; _ } acc ->
            acc.(id) <- (o, (l.id, r.id)) :: acc.(id);
            acc)
          sm (Array.make n [])
      in
      Array.map group_by_fst rmap

    (* transitions grouped by op: (op, [((left id, right id),
     * result id); ...]) *)
    let by_ops sm =
      let entries =
        fold
          (fun (K (op, l, r)) s acc ->
            (op, ((l.id, r.id), s.id)) :: acc)
          sm []
      in
      group_by_fst entries
  end
  (* A matching rule: its [name], a [vars] list (presumably the
   * rule's variable names; not read in the code visible here),
   * and the [pattern] to recognise. *)
  type rule = {
    name : string;
    vars : string list;
    pattern : pattern;
  }
  (* Build the automaton for rule list [rl].  Returns
   * (states, atoms, table):
   *  - states: every state found, as an array sorted by id;
   *  - atoms: the state of each atomic leaf occurring in the
   *    rules (always including $ and %);
   *  - table: K (op, left, right) -> resulting state.
   * Binary transitions are recomputed over all ordered pairs of
   * known states until no new state is added. *)
  let generate_table rl =
    let states = StateSet.create () in
    (* these atomic patterns must occur in rules so that we are
     * able to number all possible refs *)
    let rl =
      { name = "$"; vars = []; pattern = Atm AnyCon }
      :: { name = "%"; vars = []; pattern = Atm Tmp }
      :: rl
    in
    (* leaves of every rule, grouped by atomic pattern *)
    let ground =
      group_by_fst (List.concat_map (fun r -> peel r.pattern r.name) rl)
    in
    let tmp = List.assoc (Atm Tmp) ground in
    let con = List.assoc (Atm AnyCon) ground in
    (* one initial state per leaf pattern; a leaf also stands for
     * every bare % or $ position it is compatible with *)
    let atoms = ref [] in
    ground
    |> List.iter (fun (seen, cursors) ->
         let base = if pattern_match (Atm Tmp) seen then tmp else con in
         let _, s =
           StateSet.add states
             { id = -1; seen; point = normalize (base @ cursors) }
         in
         match get_atomic seen with
         | None -> ()
         | Some atm -> atoms := (atm, s) :: !atoms);
    let table = ref StateMap.empty in
    let changed = ref true in
    (* iterate until fixpoint *)
    while !changed do
      changed := false;
      let known = StateSet.elems states in
      iter_pairs known (fun (sl, sr) ->
        List.iter
          (fun (o, s') ->
            let status, s' = StateSet.add states s' in
            (match status with
             | `Added -> changed := true
             | _ -> ());
            table := StateMap.add (K (o, sl, sr)) s' !table)
          (next_binary tmp sl sr))
    done;
    let by_id =
      StateSet.elems states
      |> List.sort (fun a b -> compare a.id b.id)
      |> Array.of_list
    in
    (by_id, !atoms, !table)
  (* Every way of inserting [x] into [l], listed from [x]
   * placed last down to [x] placed first. *)
  let intersperse x l =
    let rec go before after acc =
      let acc = List.rev_append before (x :: after) :: acc in
      match after with
      | [] -> acc
      | y :: after' -> go (y :: before) after' acc
    in
    go [] l []
  (* All permutations of [l]: the head inserted at every position
   * of every permutation of the tail. *)
  let rec permute l =
    match l with
    | [] -> [ [] ]
    | x :: rest -> List.concat_map (intersperse x) (permute rest)
  (* Build every binary tree whose leaves, read left to right,
   * are [l]; [build] makes an inner node.  A single leaf is its
   * own tree and the empty list yields no tree. *)
  let rec bins build l =
    match l with
    | [] -> []
    | [ x ] -> [ x ]
    | x :: rest ->
      (* try every split point: left keeps a non-empty prefix,
       * right keeps the non-empty remainder *)
      let rec split left right acc =
        match right with
        | [] -> acc
        | y :: right' ->
          let lefts = bins build left in
          let rights = bins build right in
          let acc =
            List.fold_left
              (fun acc a ->
                List.fold_left (fun acc b -> build a b :: acc) acc rights)
              acc lefts
          in
          split (left @ [ y ]) right' acc
      in
      split [ x ] rest []
  (* Fold [f] over every way of picking one element from each
   * list of [l] (the pick is passed in the order of [l]),
   * threading the accumulator from [ini].  The first list
   * varies slowest. *)
  let products l ini f =
    let rec pick acc chosen_rev remaining =
      match remaining with
      | [] -> f (List.rev chosen_rev) acc
      | options :: rest ->
        List.fold_left
          (fun acc opt -> pick acc (opt :: chosen_rev) rest)
          acc options
    in
    pick ini [] l
  (* All patterns equivalent to the argument up to associativity
   * and commutativity of its ops.  Associative chains are
   * flattened, their operands expanded recursively, permuted
   * when the op is commutative and re-bracketed in every way.
   * Combinatorial: the output grows very quickly. *)
  let rec ac_equiv =
    (* operands of a maximal chain of [o] nodes *)
    let rec operands o = function
      | Bnr (o', l, r) when o' = o -> operands o l @ operands o r
      | q -> [ q ]
    in
    function
    | Bnr (o, _, _) as p when associative o ->
      let orders choice =
        if commutative o then permute choice else [ choice ]
      in
      products (List.map ac_equiv (operands o p)) []
        (fun choice acc ->
          List.concat_map (bins (fun l r -> Bnr (o, l, r))) (orders choice)
          @ acc)
    | Bnr (o, l, r) when commutative o ->
      fold_pairs (ac_equiv l) (ac_equiv r) []
        (fun (l, r) acc -> Bnr (o, l, r) :: Bnr (o, r, l) :: acc)
    | Bnr (o, l, r) ->
      fold_pairs (ac_equiv l) (ac_equiv r) []
        (fun (l, r) acc -> Bnr (o, l, r) :: acc)
    | q -> [ q ]
  (* Hash-consed decision trees produced by the matcher
   * generator.  The record type is private: values are only
   * built through [stop] and the [mk_*] functions, which give
   * structurally equal nodes the same [id]. *)
  module Action: sig
    type node =
      | Switch of (int * t) list
      | Push of bool * t
      | Pop of t
      | Set of string * t
      | Stop
    and t = private
      { id: int; node: node }
    val equal: t -> t -> bool
    val size: t -> int
    val stop: t
    val mk_push: sym:bool -> t -> t
    val mk_pop: t -> t
    val mk_set: string -> t -> t
    val mk_switch: int list -> (int -> t) -> t
    val pp: Format.formatter -> t -> unit
  end = struct
    type node =
      | Switch of (int * t) list
      | Push of bool * t
      | Pop of t
      | Set of string * t
      | Stop
    and t =
      { id: int; node: node }

    (* hash-consing makes id equality coincide with structural
     * equality *)
    let equal a b = Int.equal a.id b.id

    (* number of distinct nodes reachable from the root; a node
     * shared by several parents is counted once *)
    let size root =
      let visited = Hashtbl.create 10 in
      let rec count a =
        if Hashtbl.mem visited a.id then 0
        else begin
          Hashtbl.add visited a.id ();
          1 + children a.node
        end
      and children = function
        | Stop -> 0
        | Push (_, a) | Pop a | Set (_, a) -> count a
        | Switch cases ->
          List.fold_left (fun n (_, a) -> n + count a) 0 cases
      in
      count root

    (* smart constructor: reuse the id of an identical node,
     * otherwise allocate the next one *)
    let mk =
      let table = Hashtbl.create 100 in
      let next_id = ref 0 in
      fun node ->
        let id =
          match Hashtbl.find_opt table node with
          | Some id -> id
          | None ->
            let id = !next_id in
            Hashtbl.add table node id;
            next_id := id + 1;
            id
        in
        { id; node }

    let stop = mk Stop

    let mk_push ~sym a = mk (Push (sym, a))

    (* a pop in front of [Stop] is dropped: [Stop] is returned *)
    let mk_pop a =
      match a.node with
      | Stop -> a
      | Switch _ | Push _ | Pop _ | Set _ -> mk (Pop a)

    let mk_set v a = mk (Set (v, a))

    (* [mk_switch ids f] builds [f id] for every id, in order.
     * When all cases are the same action the switch is elided
     * and that action is returned.
     * Raises [Failure "empty switch"] when [ids] is empty. *)
    let mk_switch ids f =
      let cases = List.map f ids in
      match cases with
      | [] -> failwith "empty switch"
      | first :: others when List.for_all (equal first) others -> first
      | _ :: _ -> mk (Switch (List.combine ids cases))

    (* switch cases leading to the same action are printed
     * together as "→id,id,...: action" *)
    let rec pp_node fmt = function
      | Switch cases ->
        begin
          Format.fprintf fmt "@[<v>@[<v2>switch{";
          let pp_sep fmt () = Format.fprintf fmt "," in
          let pp_case (ids, a) =
            Format.fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
              (Format.pp_print_list ~pp_sep Format.pp_print_int)
              ids pp a
          in
          List.iter pp_case (inverse (group_by_fst (inverse cases)));
          Format.fprintf fmt "@]@,}@]"
        end
      | Push (true, a) -> Format.fprintf fmt "pushsym@ %a" pp a
      | Push (false, a) -> Format.fprintf fmt "push@ %a" pp a
      | Pop a -> Format.fprintf fmt "pop@ %a" pp a
      | Set (v, a) -> Format.fprintf fmt "set(%s)@ %a" v pp a
      | Stop -> Format.fprintf fmt "•"

    and pp fmt a = pp_node fmt a.node
  end
  (* A state is symmetric (commutative) when, for every op
   * entering it, (a op b) enters it iff (b op a) does.  Pairs
   * (a, a) are ignored; [rmap] is indexed by state id (as
   * built by [StateMap.invert]). *)
  let symmetric rmap id =
    let balanced (_, pairs) =
      let below, above =
        pairs
        |> List.filter (fun (a, b) -> a <> b)
        |> List.partition (fun (a, b) -> a < b)
      in
      setify below = setify (inverse above)
    in
    List.for_all balanced rmap.(id)
  (* Build the left-to-right matcher for the rules named [name].
   * [statemap] maps K (op, left, right) to the resulting state,
   * [states] is the state array indexed by id and [rules] the
   * full rule list.  The resulting [Action.t] switches on state
   * ids; [Push] precedes the left-operand matcher, [Pop] the
   * right-operand matcher, and [Set v] marks a leaf bound to the
   * variable [v].
   * May raise [Failure "ambiguous var match"] when a leaf could
   * bind two variables, [Failure "empty switch"] when no
   * candidate state exists, or the local [Stuck] exception when
   * there is no lr matcher for the rule. *)
  let lr_matcher statemap states rules name =
    let rmap = StateMap.invert (Array.length states) statemap in
    let exception Stuck in
    (* [gen ids pats k]: [ids] is the class of ids that may label
     * the root of the current subterm.  The generated action
     * assigns the [Var] nodes of one pattern of [pats], then
     * continues with [k id pats'], where [pats'] are the
     * patterns still possible. *)
    let rec gen
      : 'a. int list -> (pattern * 'a) list
        -> (int -> (pattern * 'a) list -> Action.t)
        -> Action.t
      = fun ids pats k ->
        Action.mk_switch (setify ids) (fun id_top ->
          let sym = symmetric rmap id_top in
          (* in a symmetric state (a, b) and (b, a) lead to the
           * same place: keep only pairs with a <= b *)
          let id_ops =
            if not sym then rmap.(id_top)
            else
              List.map
                (fun (o, pairs) ->
                  (o, List.filter (fun (a, b) -> a <= b) pairs))
                rmap.(id_top)
          in
          (* only keep binary patterns whose op can reach id_top *)
          let possible = function
            | Bnr (o, _, _), _ ->
              List.exists (fun (o', _) -> o' = o) id_ops
            | _ -> true
          in
          let atm_pats, bin_pats =
            List.partition
              (fun (pat, _) -> is_atomic pat)
              (List.filter possible pats)
          in
          (* descend into both operands of the binary patterns *)
          let match_binary () =
            if bin_pats = [] then raise Stuck;
            let left_pats =
              List.map
                (function
                  | Bnr (o, l, r), x -> (l, (o, x, r))
                  | _ -> assert false)
                bin_pats
            in
            let right_pats ps =
              List.map (fun (l, (o, x, r)) -> (r, (o, l, x))) ps
            in
            let rebuild ps =
              List.map (fun (r, (o, l, x)) -> (Bnr (o, l, r), x)) ps
            in
            let pairs = List.concat_map snd id_ops in
            let left_ids = List.map fst pairs in
            let right_ids lid =
              List.filter_map
                (fun (l, r) -> if l = lid then Some r else None)
                pairs
            in
            (* left arm first; the right arm then only considers
             * the remaining patterns, knowing the left arm was
             * numbered [lid]; finally continue with the parent *)
            Action.mk_push ~sym
              (gen left_ids left_pats (fun lid ps ->
                   Action.mk_pop
                     (gen (right_ids lid) (right_pats ps)
                        (fun _rid ps -> k id_top (rebuild ps)))))
          in
          (* fallback: atomic patterns matching the state's seen
           * pattern, binding at most one variable *)
          let match_leaf () =
            let seen = states.(id_top).seen in
            let leaves =
              List.filter (fun (pat, _) -> pattern_match pat seen) atm_pats
            in
            if leaves = [] then raise Stuck;
            let vars =
              leaves
              |> List.filter_map (function
                   | Var (v, _), _ -> Some v
                   | _ -> None)
              |> setify
            in
            match vars with
            | [] -> k id_top leaves
            | [ v ] -> Action.mk_set v (k id_top leaves)
            | _ -> failwith "ambiguous var match"
          in
          match match_binary () with
          | action -> action
          | exception Stuck -> match_leaf ())
    in
    (* states whose point contains the root cursor of [name] *)
    let ids_top =
      List.filter_map
        (fun { id; point; _ } ->
          if List.exists (fun c -> Top name = c) point then Some id
          else None)
        (Array.to_list states)
    in
    (* drop [p] when [pattern_match p q] holds for some later
     * pattern [q] *)
    let rec filter_dups = function
      | [] -> []
      | p :: rest ->
        if List.exists (pattern_match p) rest then filter_dups rest
        else p :: filter_dups rest
    in
    let pats_top =
      rules
      |> List.filter_map (fun r ->
             if r.name = name then Some r.pattern else None)
      |> filter_dups
      |> List.map (fun p -> (p, ()))
    in
    gen ids_top pats_top (fun _ pats ->
        assert (pats <> []);
        Action.stop)
  (* Everything needed to number terms: the atom states, the
   * transition table and the id-indexed state array.  [ops]
   * memoizes the list of possible operations according to the
   * statemap. *)
  type numberer = {
    atoms : (atomic_pattern * p state) list;
    statemap : p state StateMap.t;
    states : p state array;
    mutable ops : op list;
  }
  (* Package the states [sa], atom states [am] and transition
   * map [sm]; the [ops] memo starts empty. *)
  let make_numberer sa am sm =
    { states = sa; atoms = am; statemap = sm; ops = [] }
  606. let atom_state n atm =
  607. List.assoc atm n.atoms