Prechádzať zdrojové kódy

fuse ac rules in ins-tree matching

The initial plan was to have one
matcher per ac-variant, but that
leads to way too much generated
code. Instead, we can fuse ac
variants of the rules and have
a smarter matching algorithm to
recover bound variables.
Quentin Carbonneaux 3 rokov pred
rodič
commit
56e2263ca4
2 zmenil súbory, kde vykonal 86 pridanie a 72 odobranie
  1. 48 32
      tools/match.ml
  2. 38 40
      tools/match_test.ml

+ 48 - 32
tools/match.ml

@@ -23,11 +23,32 @@ type pattern =
   | Atm of atomic_pattern
   | Var of string * atomic_pattern
 
+let show_op (k, o) =
+  (match o with
+   | Oadd -> "add"
+   | Osub -> "sub"
+   | Omul -> "mul") ^
+  (match k with
+   | Kw -> "w"
+   | Kl -> "l"
+   | Ks -> "s"
+   | Kd -> "d")
+
+let rec show_pattern p =
+  match p with
+  | Var _ -> failwith "variable not allowed"
+  | Atm Tmp -> "%"
+  | Atm AnyCon -> "$"
+  | Atm (Con n) -> Int64.to_string n
+  | Bnr (o, pl, pr) ->
+    "(" ^ show_op o ^
+    " " ^ show_pattern pl ^
+    " " ^ show_pattern pr ^ ")"
+
 let rec pattern_match p w =
   match p with
-  | Var _ ->
-    failwith "variable not allowed"
-  | Atm (Tmp) ->
+  | Var _ -> failwith "variable not allowed"
+  | Atm Tmp ->
       begin match w with
       | Atm (Con _ | AnyCon) -> false
       | _ -> true
@@ -89,12 +110,12 @@ type 'a state =
   ; point: ('a cursor) list }
 
 let rec binops side {point; _} =
-  List.fold_left (fun res c ->
+  List.filter_map (fun c ->
       match c, side with
-      | Bnrl (o, c, r), `L -> ((o, c), r) :: res
-      | Bnrr (o, l, c), `R -> ((o, c), l) :: res
-    | _ -> res)
-    [] point
+      | Bnrl (o, c, r), `L -> Some ((o, c), r)
+      | Bnrr (o, l, c), `R -> Some ((o, c), l)
+      | _ -> None)
+    point
 
 let group_by_fst l =
   List.fast_sort (fun (a, _) (b, _) ->
@@ -114,11 +135,9 @@ let sort_uniq cmp l =
   List.fold_left (fun (eo, l) e' ->
       match eo with
       | None -> (Some e', l)
-      | Some e ->
-        if cmp e e' = 0
-        then (eo, l)
-        else (Some e', e :: l)
-    ) (None, []) |>
+      | Some e when cmp e e' = 0 -> (eo, l)
+      | Some e -> (Some e', e :: l))
+    (None, []) |>
   (function
     | (None, _) -> []
     | (Some e, l) -> List.rev (e :: l))
@@ -126,15 +145,14 @@ let sort_uniq cmp l =
 let normalize (point: ('a cursor) list) =
   sort_uniq compare point
 
-let nextbnr tmp s1 s2 =
+let next_binary tmp s1 s2 =
   let pm w (_, p) = pattern_match p w in
   let o1 = binops `L s1 |>
            List.filter (pm s2.seen) |>
-           List.map fst
-  and o2 = binops `R s2 |>
+           List.map fst in
+  let o2 = binops `R s2 |>
            List.filter (pm s1.seen) |>
-           List.map fst
-  in
+           List.map fst in
   List.map (fun (o, l) ->
     o,
     { id = 0
@@ -145,25 +163,24 @@ let nextbnr tmp s1 s2 =
 type p = string
 
 module StateSet : sig
-  type set
-  val create: unit -> set
-  val add: set -> p state ->
+  type t
+  val create: unit -> t
+  val add: t -> p state ->
            [> `Added | `Found ] * p state
-  val iter: set -> (p state -> unit) -> unit
-  val elems: set -> (p state) list
+  val iter: t -> (p state -> unit) -> unit
+  val elems: t -> (p state) list
 end = struct
-  include Hashtbl.Make(struct
+  open Hashtbl.Make(struct
     type t = p state
     let equal s1 s2 = s1.point = s2.point
     let hash s = Hashtbl.hash s.point
   end)
-  type set =
+  type nonrec t =
     { h: int t
     ; mutable next_id: int }
   let create () =
     { h = create 500; next_id = 1 }
   let add set s =
-    (* delete the check later *)
     assert (s.point = normalize s.point);
     try
       let id = find set.h s in
@@ -171,6 +188,8 @@ end = struct
     with Not_found -> begin
       let id = set.next_id in
       set.next_id <- id + 1;
+      Printf.printf "adding: %d [%s]\n"
+        id (show_pattern s.seen);
       add set.h s id;
       `Added, {s with id}
     end
@@ -198,17 +217,14 @@ end)
 type rule =
   { name: string
   ; pattern: pattern
-  (* TODO access pattern *)
   }
 
 let generate_table rl =
   let states = StateSet.create () in
   (* initialize states *)
   let ground =
-    List.fold_left
-      (fun ini r ->
-        peel r.pattern r.name @ ini)
-      [] rl |>
+    List.concat_map
+      (fun r -> peel r.pattern r.name) rl |>
     group_by_fst
   in
   let find x d l =
@@ -242,7 +258,7 @@ let generate_table rl =
     flag := `Stop;
     let statel = StateSet.elems states in
     iter_pairs statel (fun (sl, sr) ->
-      nextbnr tmp sl sr |>
+      next_binary tmp sl sr |>
       List.iter (fun (o, s') ->
         let flag', s' =
           StateSet.add states s' in

+ 38 - 40
tools/match_test.ml

@@ -46,54 +46,52 @@ let ts =
   }
 
 let print_sm =
-  let op_str (k, o) =
-    Printf.sprintf "%s%s"
-      (match o with
-       | Oadd -> "add"
-       | Osub -> "sub"
-       | Omul -> "mul")
-      (match k with
-       | Kw -> "w"
-       | Kl -> "l"
-       | Ks -> "s"
-       | Kd -> "d")
-  in
   StateMap.iter (fun k s' ->
     match k with
     | K (o, sl, sr) ->
+        let top =
+          List.fold_left (fun top c ->
+            match c with
+            | Top r -> top ^ " " ^ r
+            | _ -> top) "" s'.point
+        in
         Printf.printf
-          "(%s %d %d) -> %d\n"
-          (op_str o)
-          sl.id sr.id s'.id
-  )
+          "(%s %d %d) -> %d%s\n"
+          (show_op o)
+          sl.id sr.id s'.id top)
 
-let address_rules =
+let rules =
   let oa = Kl, Oadd in
   let om = Kl, Omul in
-  let rule name pattern =
-    List.mapi (fun i pattern ->
-        { name = Printf.sprintf "%s%d" name (i+1)
-        ; pattern; })
-      (ac_equiv pattern) in
-
+  match `X64Addr with
+  (* ------------------------------- *)
+  | `X64Addr ->
+    let rule name pattern =
+      List.mapi (fun i pattern ->
+          { name (* = Printf.sprintf "%s%d" name (i+1) *)
+          ; pattern })
+        (ac_equiv pattern) in
     (* o + b *)
-  rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
-  @ (* b + s * i *)
-  rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
-  @ (* o + s * i *)
-  rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
-  @ (* b + o + s * i *)
-  rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm AnyCon, Atm Tmp)))
+    rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
+    @ (* b + s * i *)
+    rule "bs" (Bnr (oa, Atm Tmp, Bnr (om, Atm (Con 4L), Atm Tmp)))
+    @ (* o + s * i *)
+    rule "os" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp)))
+    @ (* b + o + s * i *)
+    rule "bos" (Bnr (oa, Bnr (oa, Atm AnyCon, Atm Tmp), Bnr (om, Atm (Con 4L), Atm Tmp)))
+  (* ------------------------------- *)
+  | `Add3 ->
+    [ { name = "add"
+      ; pattern = Bnr (oa, Atm Tmp, Bnr (oa, Atm Tmp, Atm Tmp)) } ] @
+    [ { name = "add"
+      ; pattern = Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp), Atm Tmp) } ] @
+    [ { name = "mul"
+      ; pattern = Bnr (om, Bnr (oa, Bnr (oa, Atm Tmp, Atm Tmp),
+                                    Atm Tmp),
+                           Bnr (oa, Atm Tmp,
+                                    Bnr (oa, Atm Tmp, Atm Tmp))) } ]
+
 
-let sl, sm = generate_table address_rules
+let sl, sm = generate_table rules
 let s n = List.find (fun {id; _} -> id = n) sl
 let () = print_sm sm
-
-(*
-let tp0 =
-  let o = Kw, Oadd in
-  Bnr (o, Atm Tmp, Atm (Con 0L))
-let tp1 =
-  let o = Kw, Oadd in
-  Bnr (o, tp0, Atm (Con 1L))
-*)