瀏覽代碼

modulo ac matching and more tests

Quentin Carbonneaux 8 年之前
父節點
當前提交
a374da3c2e
共有 2 個文件被更改,包括 392 次插入67 次删除
  1. 290 67
      tools/match.ml
  2. 102 0
      tools/match_test.ml

+ 290 - 67
tools/match.ml

@@ -5,26 +5,36 @@ type op_base =
   | Omul
   | Omul
 type op = cls * op_base
 type op = cls * op_base
 
 
+let commutative = function
+  | (_, (Oadd | Omul)) -> true
+  | (_, _) -> false
+
+let associative = function
+  | (_, (Oadd | Omul)) -> true
+  | (_, _) -> false
+
 type atomic_pattern =
 type atomic_pattern =
-  | Any
+  | Tmp
+  | AnyCon
   | Con of int64
   | Con of int64
 
 
 type pattern =
 type pattern =
   | Bnr of op * pattern * pattern
   | Bnr of op * pattern * pattern
-  | Unr of op * pattern
   | Atm of atomic_pattern
   | Atm of atomic_pattern
+  | Var of string * atomic_pattern
 
 
 let rec pattern_match p w =
 let rec pattern_match p w =
   match p with
   match p with
-  | Atm (Any) -> true
-  | Atm (Con _) -> w = p
-  | Unr (o, pa) ->
+  | Var _ ->
+    failwith "variable not allowed"
+  | Atm (Tmp) ->
       begin match w with
       begin match w with
-      | Unr (o', wa) ->
-          o' = o &&
-          pattern_match pa wa
-      | _ -> false
+      | Atm (Con _ | AnyCon) -> false
+      | _ -> true
       end
       end
+  | Atm (Con _) -> w = p
+  | Atm (AnyCon) ->
+      not (pattern_match (Atm Tmp) w)
   | Bnr (o, pl, pr) ->
   | Bnr (o, pl, pr) ->
       begin match w with
       begin match w with
       | Bnr (o', wl, wr) ->
       | Bnr (o', wl, wr) ->
@@ -34,75 +44,288 @@ let rec pattern_match p w =
       | _ -> false
       | _ -> false
       end
       end
 
 
-let test_pattern_match =
-  let pm = pattern_match
-  and nm = fun x y -> not (pattern_match x y)
-  and o = (Kw, Oadd) in
-  begin
-    assert (pm (Atm Any) (Atm (Con 42L)));
-    assert (pm (Atm Any) (Unr (o, Atm Any)));
-    assert (nm (Atm (Con 42L)) (Atm Any));
-    assert (pm (Unr (o, Atm Any))
-               (Unr (o, Atm (Con 42L))));
-    assert (nm (Unr (o, Atm Any))
-               (Unr ((Kl, Oadd), Atm (Con 42L))));
-    assert (nm (Unr (o, Atm Any))
-               (Bnr (o, Atm (Con 42L), Atm Any)));
-  end
-
-type cursor = (* a position inside a pattern *)
-  | Bnrl of op * cursor * pattern
-  | Bnrr of op * pattern * cursor
-  | Unra of op * cursor
-  | Top
+type 'a cursor = (* a position inside a pattern *)
+  | Bnrl of op * 'a cursor * pattern
+  | Bnrr of op * pattern * 'a cursor
+  | Top of 'a
 
 
 let rec fold_cursor c p =
 let rec fold_cursor c p =
   match c with
   match c with
   | Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
   | Bnrl (o, c', p') -> fold_cursor c' (Bnr (o, p, p'))
   | Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
   | Bnrr (o, p', c') -> fold_cursor c' (Bnr (o, p', p))
-  | Unra (o, c') -> fold_cursor c' (Unr (o, p))
-  | Top -> p
+  | Top _ -> p
 
 
-let peel p =
-  let once out (c, p) =
+let peel p x =
+  let once out (p, c) =
     match p with
     match p with
-    | Atm _ -> (c, p) :: out
-    | Unr (o, pa) ->
-        (Unra (o, c), pa) :: out
+    | Var _ -> failwith "variable not allowed"
+    | Atm _ -> (p, c) :: out
     | Bnr (o, pl, pr) ->
     | Bnr (o, pl, pr) ->
-        (Bnrl (o, c, pr), pl) ::
-        (Bnrr (o, pl, c), pr) :: out
+        (pl, Bnrl (o, c, pr)) ::
+        (pr, Bnrr (o, pl, c)) :: out
   in
   in
   let rec go l =
   let rec go l =
     let l' = List.fold_left once [] l in
     let l' = List.fold_left once [] l in
     if List.length l' = List.length l
     if List.length l' = List.length l
     then l
     then l
     else go l'
     else go l'
-  in go [(Top, p)]
-
-let test_peel =
-  let o = Kw, Oadd in
-  let p = Bnr (o, Bnr (o, Atm Any, Atm Any),
-                  Atm (Con 42L)) in
-  let l = peel p in
-  let () = assert (List.length l = 3) in
-  let atomic_p (_, p) =
-    match p with Atm _ -> true | _ -> false in
-  let () = assert (List.for_all atomic_p l) in
-  let l = List.map (fun (c, p) -> fold_cursor c p) l in
-  let () = assert (List.for_all ((=) p) l) in
-  ()
-
-(* we want to compute all the configurations we could
- * possibly be in when processing a block of instructions;
- * to do so, we start with all the possible cursors for
- * the list of patterns we are given, this will be our
- * main "initial state"; each constant (used in the
- * patterns) also generates a state of its own
- *
- * to create new states we can take pairs of states, and
- * combine them with binary operations, we keep the
- * result if it is non-trivial (non-empty) and new (we
- * have not seen this cursor combination yet); we can
- * also do the same with unary operations
- * *)
+  in go [(p, Top x)]
+
+let fold_pairs l1 l2 ini f =
+  let rec go acc = function
+    | [] -> acc
+    | a :: l1' ->
+        go (List.fold_left
+          (fun acc b -> f (a, b) acc)
+          acc l2) l1'
+  in go ini l1
+
+let iter_pairs l f =
+  fold_pairs l l () (fun x () -> f x)
+
+type 'a state =
+  { id: int
+  ; seen: pattern
+  ; point: ('a cursor) list }
+
+let rec binops side {point; _} =
+  List.fold_left (fun res c ->
+      match c, side with
+      | Bnrl (o, c, r), `L -> ((o, c), r) :: res
+      | Bnrr (o, l, c), `R -> ((o, c), l) :: res
+    | _ -> res)
+    [] point
+
+let group_by_fst l =
+  List.fast_sort (fun (a, _) (b, _) ->
+    compare a b) l |>
+  List.fold_left (fun (oo, l, res) (o', c) ->
+      match oo with
+      | None -> (Some o', [c], [])
+      | Some o when o = o' -> (oo, c :: l, res)
+      | Some o -> (Some o', [c], (o, l) :: res))
+    (None, [], []) |>
+  (function
+    | (None, _, _) -> []
+    | (Some o, l, res) -> (o, l) :: res)
+
+let sort_uniq cmp l =
+  List.fast_sort cmp l |>
+  List.fold_left (fun (eo, l) e' ->
+      match eo with
+      | None -> (Some e', l)
+      | Some e ->
+        if cmp e e' = 0
+        then (eo, l)
+        else (Some e', e :: l)
+    ) (None, []) |>
+  (function
+    | (None, _) -> []
+    | (Some e, l) -> List.rev (e :: l))
+
+let normalize (point: ('a cursor) list) =
+  sort_uniq compare point
+
+let nextbnr tmp s1 s2 =
+  let pm w (_, p) = pattern_match p w in
+  let o1 = binops `L s1 |>
+           List.filter (pm s2.seen) |>
+           List.map fst
+  and o2 = binops `R s2 |>
+           List.filter (pm s1.seen) |>
+           List.map fst
+  in
+  List.map (fun (o, l) ->
+    o,
+    { id = 0
+    ; seen = Bnr (o, s1.seen, s2.seen)
+    ; point = normalize (l @ tmp)
+    }) (group_by_fst (o1 @ o2))
+
+type p = string
+
+module StateSet : sig
+  type set
+  val create: unit -> set
+  val add: set -> p state ->
+           [> `Added | `Found ] * p state
+  val iter: set -> (p state -> unit) -> unit
+  val elems: set -> (p state) list
+end = struct
+  include Hashtbl.Make(struct
+    type t = p state
+    let equal s1 s2 = s1.point = s2.point
+    let hash s = Hashtbl.hash s.point
+  end)
+  type set =
+    { h: int t
+    ; mutable next_id: int }
+  let create () =
+    { h = create 500; next_id = 1 }
+  let add set s =
+    (* delete the check later *)
+    assert (s.point = normalize s.point);
+    try
+      let id = find set.h s in
+      `Found, {s with id}
+    with Not_found -> begin
+      let id = set.next_id in
+      set.next_id <- id + 1;
+      add set.h s id;
+      `Added, {s with id}
+    end
+  let iter set f =
+    let f s id = f {s with id} in
+    iter f set.h
+  let elems set =
+    let res = ref [] in
+    iter set (fun s -> res := s :: !res);
+    !res
+end
+
+type table_key =
+  | K of op * p state * p state
+
+module StateMap = Map.Make(struct
+  type t = table_key
+  let compare ka kb =
+    match ka, kb with
+    | K (o, sl, sr), K (o', sl', sr') ->
+      compare (o, sl.id, sr.id)
+              (o', sl'.id, sr'.id)
+end)
+
+type rule =
+  { name: string
+  ; pattern: pattern
+  (* TODO access pattern *)
+  }
+
+let generate_table rl =
+  let states = StateSet.create () in
+  (* initialize states *)
+  let ground =
+    List.fold_left
+      (fun ini r ->
+        peel r.pattern r.name @ ini)
+      [] rl |>
+    group_by_fst
+  in
+  let find x d l =
+    try List.assoc x l with Not_found -> d in
+  let tmp = find (Atm Tmp) [] ground in
+  let con = find (Atm AnyCon) [] ground in
+  let () =
+    List.iter (fun (seen, l) ->
+      let point =
+        if pattern_match (Atm Tmp) seen
+        then normalize (tmp @ l)
+        else normalize (con @ l)
+      in
+      let s = {id = 0; seen; point} in
+      let flag, _ = StateSet.add states s in
+      assert (flag = `Added)
+    ) ground
+  in
+  (* setup loop state *)
+  let map = ref StateMap.empty in
+  let map_add k s' =
+    map := StateMap.add k s' !map
+  in
+  let flag = ref `Added in
+  let flagmerge = function
+    | `Added -> flag := `Added
+    | _ -> ()
+  in
+  (* iterate until fixpoint *)
+  while !flag = `Added do
+    flag := `Stop;
+    let statel = StateSet.elems states in
+    iter_pairs statel (fun (sl, sr) ->
+      nextbnr tmp sl sr |>
+      List.iter (fun (o, s') ->
+        let flag', s' =
+          StateSet.add states s' in
+        flagmerge flag';
+        map_add (K (o, sl, sr)) s';
+    ));
+  done;
+  (StateSet.elems states, !map)
+
+let intersperse x l =
+  let rec go left right out =
+    let out =
+      (List.rev left @ [x] @ right) ::
+      out in
+    match right with
+    | x :: right' ->
+      go (x :: left) right' out
+    | [] -> out
+  in go [] l []
+
+let rec permute = function
+  | [] -> [[]]
+  | x :: l ->
+    List.concat (List.map
+      (intersperse x) (permute l))
+
+(* build all binary trees with ordered
+ * leaves l *)
+let rec bins build l =
+  let rec go l r out =
+    match r with
+    | [] -> out
+    | x :: r' ->
+      go (l @ [x]) r'
+        (fold_pairs
+          (bins build l)
+          (bins build r)
+          out (fun (l, r) out ->
+                 build l r :: out))
+  in
+  match l with
+  | [] -> []
+  | [x] -> [x]
+  | x :: l -> go [x] l []
+
+let products l ini f =
+  let rec go acc la = function
+    | [] -> f (List.rev la) acc
+    | xs :: l ->
+      List.fold_left (fun acc x ->
+          go acc (x :: la) l)
+        acc xs
+  in go ini [] l
+
+(* combinatorial nuke... *)
+let rec ac_equiv =
+  let rec alevel o = function
+    | Bnr (o', l, r) when o' = o ->
+      alevel o l @ alevel o r
+    | x -> [x]
+  in function
+  | Bnr (o, _, _) as p
+  when associative o ->
+    products
+      (List.map ac_equiv (alevel o p)) []
+      (fun choice out ->
+        List.map
+          (bins (fun l r -> Bnr (o, l, r)))
+          (if commutative o
+            then permute choice
+            else [choice]) |>
+        List.concat |>
+        (fun l -> List.rev_append l out))
+  | Bnr (o, l, r)
+  when commutative o ->
+    fold_pairs
+      (ac_equiv l) (ac_equiv r) []
+      (fun (l, r) out ->
+        Bnr (o, l, r) ::
+        Bnr (o, r, l) :: out)
+  | Bnr (o, l, r) ->
+    fold_pairs
+      (ac_equiv l) (ac_equiv r) []
+      (fun (l, r) out ->
+        Bnr (o, l, r) :: out)
+  | x -> [x]

+ 102 - 0
tools/match_test.ml

@@ -0,0 +1,102 @@
+#use "match.ml"
+
+let test_pattern_match =
+  let pm = pattern_match
+  and nm = fun x y -> not (pattern_match x y) in
+  begin
+    assert (nm (Atm Tmp) (Atm (Con 42L)));
+    assert (pm (Atm AnyCon) (Atm (Con 42L)));
+    assert (nm (Atm (Con 42L)) (Atm AnyCon));
+    assert (nm (Atm (Con 42L)) (Atm Tmp));
+  end
+
+let test_peel =
+  let o = Kw, Oadd in
+  let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
+                  Atm (Con 42L)) in
+  let l = peel p () in
+  let () = assert (List.length l = 3) in
+  let atomic_p (p, _) =
+    match p with Atm _ -> true | _ -> false in
+  let () = assert (List.for_all atomic_p l) in
+  let l = List.map (fun (p, c) -> fold_cursor c p) l in
+  let () = assert (List.for_all ((=) p) l) in
+  ()
+
+let test_fold_pairs =
+  let l = [1; 2; 3; 4; 5] in
+  let p = fold_pairs l l [] (fun a b -> a :: b) in
+  let () = assert (List.length p = 25) in
+  let p = sort_uniq compare p in
+  let () = assert (List.length p = 25) in
+  ()
+
+(* test pattern & state *)
+let tp =
+  let o = Kw, Oadd in
+  Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
+                  Atm (Con 0L))
+let ts =
+  { id = 0
+  ; seen = Atm Tmp
+  ; point =
+    List.map snd
+      (List.filter (fun (p, _) -> p = Atm Tmp)
+        (peel tp ()))
+  }
+
+let print_sm =
+  let op_str (k, o) =
+    Printf.sprintf "%s%s"
+      (match o with
+       | Oadd -> "add"
+       | Osub -> "sub"
+       | Omul -> "mul")
+      (match k with
+       | Kw -> "w"
+       | Kl -> "l"
+       | Ks -> "s"
+       | Kd -> "d")
+  in
+  StateMap.iter (fun k s' ->
+    match k with
+    | K (o, sl, sr) ->
+        Printf.printf
+          "(%s %d %d) -> %d\n"
+          (op_str o)
+          sl.id sr.id s'.id
+  )
+
+let address_rules =
+  let oa = Kl, Oadd in
+  let om = Kl, Omul in
+  let rule name pattern = { name; pattern; } in
+  (* o + b *)
+  [ rule "ob1" (Bnr (oa, Atm Tmp, Atm AnyCon))
+  ; rule "ob2" (Bnr (oa, Atm AnyCon, Atm Tmp))
+
+  (* b + s * i *)
+  ; rule "bs1" (Bnr (oa, Atm Tmp, Bnr (om, Atm AnyCon, Atm Tmp)))
+  ; rule "bs2" (Bnr (oa, Atm Tmp, Bnr (om, Atm Tmp, Atm AnyCon)))
+  ; rule "bs3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm Tmp))
+  ; rule "bs4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm Tmp))
+
+  (* o + s * i *)
+  ; rule "os1" (Bnr (oa, Atm AnyCon, Bnr (om, Atm AnyCon, Atm Tmp)))
+  ; rule "os2" (Bnr (oa, Atm AnyCon, Bnr (om, Atm Tmp, Atm AnyCon)))
+  ; rule "os3" (Bnr (oa, Bnr (om, Atm AnyCon, Atm Tmp), Atm AnyCon))
+  ; rule "os4" (Bnr (oa, Bnr (om, Atm Tmp, Atm AnyCon), Atm AnyCon))
+  ]
+
+(*
+let sl, sm = generate_table address_rules
+let s n = List.find (fun {id; _} -> id = n) sl
+let () = print_sm sm
+*)
+
+let tp0 =
+  let o = Kw, Oadd in
+  Bnr (o, Atm Tmp, Atm (Con 0L))
+let tp1 =
+  let o = Kw, Oadd in
+  Bnr (o, tp0, Atm (Con 1L))