Explorar o código

added get_pattern_locals to Interp

Simon Krajewski %!s(int64=12) %!d(string=hai) anos
pai
achega
3d2308c94b
Modificáronse 4 ficheiros con 73 adicións e 59 borrados
  1. 7 0
      interp.ml
  2. 62 59
      matcher.ml
  3. 1 0
      typecore.ml
  4. 3 0
      typer.ml

+ 7 - 0
interp.ml

@@ -108,6 +108,7 @@ type extern_api = {
 	get_local_using : unit -> tclass list;
 	get_local_vars : unit -> (string, Type.tvar) PMap.t;
 	get_build_fields : unit -> value;
+	get_pattern_locals : Ast.expr -> Type.t -> (string,Type.tvar) PMap.t;
 	define_type : value -> unit;
 	module_dependency : string -> string -> bool -> unit;
 	current_module : unit -> module_def;
@@ -2410,6 +2411,12 @@ let macro_lib =
 			else
 				VObject (obj (hash_field (get_ctx())) ["file",VString p.Ast.pfile;"pos",VInt p.Ast.pmin])
 		);
+		"pattern_locals", Fun2 (fun e t ->
+			let loc = (get_ctx()).curapi.get_pattern_locals (decode_expr e) (decode_type t) in
+			let h = Hashtbl.create 0 in
+			PMap.iter (fun n v -> Hashtbl.replace h (VString n) (encode_type v.v_type)) loc;
+			enc_hash h
+		);
 	]
 
 (* ---------------------------------------------------------------------- *)

+ 62 - 59
matcher.ml

@@ -17,19 +17,6 @@ and con = {
 	c_pos : pos;
 }
 
-type st_def =
-	| SVar of tvar
-	| SField of st * string
-	| SEnum of st * string * int
-	| SArray of st * int
-	| STuple of st * int * int
-
-and st = {
-	st_def : st_def;
-	st_type : t;
-	st_pos : pos;
-}
-
 type pat_def =
 	| PAny
 	| PVar of tvar
@@ -43,6 +30,19 @@ and pat = {
 	p_pos : pos;
 }
 
+type st_def =
+	| SVar of tvar
+	| SField of st * string
+	| SEnum of st * string * int
+	| SArray of st * int
+	| STuple of st * int * int
+
+and st = {
+	st_def : st_def;
+	st_type : t;
+	st_pos : pos;
+}
+
 type out = {
 	o_expr : texpr;
 	o_guard : texpr option;
@@ -258,8 +258,7 @@ let rec is_value_type = function
 	| _ ->
 		false
 
-let to_pattern mctx e st =
-	let ctx = mctx.ctx in
+let to_pattern ctx e t =
 	let perror p = error "Unrecognized pattern" p in
 	let verror n p = error ("Variable " ^ n ^ " must appear exactly once in each sub-pattern") p in
 	let mk_var tctx s t p =
@@ -267,40 +266,40 @@ let to_pattern mctx e st =
 			| Some vmap -> fst (try PMap.find s vmap with Not_found -> verror s p)
 			| None -> alloc_var s t
 		in
-		unify mctx.ctx t v.v_type p;
+		unify ctx t v.v_type p;
 		if PMap.mem s tctx.pc_locals then verror s p;
 		tctx.pc_locals <- PMap.add s (v,p) tctx.pc_locals;
 		v
 	in
-	let rec loop pctx e st =
+	let rec loop pctx e t =
 		let p = pos e in
 		match fst e with
 		| EConst(Ident "null") ->
 			error "null-patterns are not allowed" p
 		| EParenthesis e ->
-			loop pctx e st
+			loop pctx e t
 		| ECast(e1,None) ->
-			loop pctx e1 st
+			loop pctx e1 t
 		| EConst((Ident ("false" | "true") | Int _ | String _ | Float _) as c) ->
 			let e = Codegen.type_constant ctx.com c p in
-			unify ctx e.etype st.st_type p;
+			unify ctx e.etype t p;
 			let c = match e.eexpr with TConst c -> c | _ -> assert false in
-			mk_con_pat (CConst c) [] st.st_type p
+			mk_con_pat (CConst c) [] t p
 		| EField _ ->
-			let e = type_expr_with_type ctx e (Some st.st_type) false in
+			let e = type_expr_with_type ctx e (Some t) false in
 			let e = match Optimizer.make_constant_expression ctx e with Some e -> e | None -> e in
 			(match e.eexpr with
-			| TConst c -> mk_con_pat (CConst c) [] st.st_type p
-			| TTypeExpr mt -> mk_con_pat (CType mt) [] st.st_type p
+			| TConst c -> mk_con_pat (CConst c) [] t p
+			| TTypeExpr mt -> mk_con_pat (CType mt) [] t p
 			| TField(_, FStatic(_,cf)) when is_value_type cf.cf_type ->
 				mk_con_pat (CExpr e) [] cf.cf_type p
 			| TField(_, FEnum(en,ef)) ->
-				let tc = monomorphs ctx.type_params (st.st_type) in
+				let tc = monomorphs ctx.type_params (t) in
 				unify_enum_field en (List.map (fun _ -> mk_mono()) en.e_types) ef tc;
-				mk_con_pat (CEnum(en,ef)) [] st.st_type p
+				mk_con_pat (CEnum(en,ef)) [] t p
 			| _ -> error "Constant expression expected" p)
 		| ECall(ec,el) ->
-			let tc = monomorphs ctx.type_params (st.st_type) in
+			let tc = monomorphs ctx.type_params (t) in
 			let ec = type_expr_with_type ctx ec (Some tc) false in
 			(match follow ec.etype with
 			| TEnum(en,pl)
@@ -327,8 +326,7 @@ let to_pattern mctx e st =
 						let pat = mk_pat PAny t_dynamic pany in
 						(ExtList.List.make ((List.length tl) + 1) pat)
 					| e :: el, t :: tl ->
-						let st = mk_st (SEnum(st,ef.ef_name,i)) t (pos e) in
-						let pat = loop pctx e st in
+						let pat = loop pctx e t in
 						pat :: loop2 (i + 1) el tl
 					| e :: _, [] ->
 						error "Too many arguments" (pos e);
@@ -337,13 +335,13 @@ let to_pattern mctx e st =
 					| [],[] ->
 						[]
 				in
-				mk_con_pat (CEnum(en,ef)) (loop2 0 el tl) st.st_type p
+				mk_con_pat (CEnum(en,ef)) (loop2 0 el tl) t p
 			| _ -> perror p)
 		| EConst(Ident "_") ->
-			mk_any st.st_type p
+			mk_any t p
 		| EConst(Ident s) ->
 			begin try
-				let tc = monomorphs ctx.type_params (st.st_type) in
+				let tc = monomorphs ctx.type_params (t) in
 				let ec = match tc with
 					| TEnum(en,pl) ->
 						let ef = PMap.find s en.e_constrs in
@@ -362,7 +360,7 @@ let to_pattern mctx e st =
 				(match ec.eexpr with
 					| TField (_,FEnum (en,ef)) ->
 						unify_enum_field en (List.map (fun _ -> mk_mono()) en.e_types) ef tc;
-						mk_con_pat (CEnum(en,ef)) [] st.st_type p
+						mk_con_pat (CEnum(en,ef)) [] t p
                     | TConst c ->
                         unify ctx ec.etype tc p;
                         mk_con_pat (CConst c) [] tc p
@@ -374,45 +372,43 @@ let to_pattern mctx e st =
 						raise Not_found);
 			with Not_found ->
 				if not (is_lower_ident s) then error "Capture variables must be lower-case" p;
-				let v = mk_var pctx s st.st_type p in
+				let v = mk_var pctx s t p in
 				mk_pat (PVar v) v.v_type p
 			end
 		| (EObjectDecl fl) ->
-			begin match follow st.st_type with
+			begin match follow t with
 			| TAnon {a_fields = fields}
 			| TInst({cl_fields = fields},_) ->
-				List.iter (fun (n,(_,p)) -> if not (PMap.mem n fields) then error (unify_error_msg (print_context()) (has_extra_field st.st_type n)) p) fl;
+				List.iter (fun (n,(_,p)) -> if not (PMap.mem n fields) then error (unify_error_msg (print_context()) (has_extra_field t n)) p) fl;
 				let sl,pl,i = PMap.foldi (fun n cf (sl,pl,i) ->
-					let st = mk_st (SField(st,n)) cf.cf_type (pos e) in
-					let pat = try loop pctx (List.assoc n fl) st with Not_found -> (mk_any cf.cf_type p) in
+					let pat = try loop pctx (List.assoc n fl) cf.cf_type with Not_found -> (mk_any cf.cf_type p) in
 					(n,cf) :: sl,pat :: pl,i + 1
 				) fields ([],[],0) in
-				mk_con_pat (CFields(i,sl)) pl st.st_type p
+				mk_con_pat (CFields(i,sl)) pl t p
 			| _ ->
-				error ((s_type st.st_type) ^ " should be { }") p
+				error ((s_type t) ^ " should be { }") p
 			end
 		| EArrayDecl [] ->
-			mk_con_pat (CArray 0) [] st.st_type p
+			mk_con_pat (CArray 0) [] t p
 		| EArrayDecl el ->
-			begin match follow st.st_type with
+			begin match follow t with
 				| TInst({cl_path=[],"Array"},[t2]) ->
 					let pl = ExtList.List.mapi (fun i e ->
-						let st = mk_st (SArray(st,i)) t2 p in
-						loop pctx e st
+						loop pctx e t2
 					) el in
-					mk_con_pat (CArray (List.length el)) pl st.st_type p
+					mk_con_pat (CArray (List.length el)) pl t p
 				| _ ->
-					error ((s_type st.st_type) ^ " should be Array") p
+					error ((s_type t) ^ " should be Array") p
 			end
 		| EBinop(OpAssign,(EConst(Ident s),p2),e1) ->
-			let v = mk_var pctx s st.st_type p in
-			let pat1 = loop pctx e1 st in
-			mk_pat (PBind(v,pat1)) st.st_type p2
+			let v = mk_var pctx s t p in
+			let pat1 = loop pctx e1 t in
+			mk_pat (PBind(v,pat1)) t p2
 		| EBinop(OpOr,(EBinop(OpOr,e1,e2),p2),e3) ->
-			loop pctx (EBinop(OpOr,e1,(EBinop(OpOr,e2,e3),p2)),p) st
+			loop pctx (EBinop(OpOr,e1,(EBinop(OpOr,e2,e3),p2)),p) t
 		| EBinop(OpOr,e1,e2) ->
 			let old = pctx.pc_locals in
-			let pat1 = loop pctx e1 st in
+			let pat1 = loop pctx e1 t in
 			begin match pat1.p_def with
 				| PAny | PVar _ ->
 					ctx.com.warning "This pattern is unused" (pos e2);
@@ -422,7 +418,7 @@ let to_pattern mctx e st =
 					pc_sub_vars = Some pctx.pc_locals;
 					pc_locals = old;
 				} in
-				let pat2 = loop pctx2 e2 st in
+				let pat2 = loop pctx2 e2 t in
 				PMap.iter (fun s (_,p) -> if not (PMap.mem s pctx2.pc_locals) then verror s p) pctx.pc_locals;
 				unify ctx pat1.p_type pat2.p_type pat1.p_pos;
 				mk_pat (POr(pat1,pat2)) pat2.p_type (punion pat1.p_pos pat2.p_pos);
@@ -434,9 +430,11 @@ let to_pattern mctx e st =
 		pc_locals = PMap.empty;
 		pc_sub_vars = None;
 	} in
-	let e = loop pctx e st in
-	PMap.iter (fun n (v,p) -> ctx.locals <- PMap.add n v ctx.locals) pctx.pc_locals;
-	e
+	loop pctx e t, pctx.pc_locals
+
+let get_pattern_locals ctx e t =
+	let _,locals = to_pattern ctx e t in
+	PMap.foldi (fun n (v,_) acc -> PMap.add n v acc) locals PMap.empty
 
 (* Match compilation *)
 
@@ -919,20 +917,24 @@ let match_expr ctx e cases def need_val with_type p =
 		subtree_index = Hashtbl.create 0;
 		num_subtrees = 0;
 	} in
+	let add_pattern_locals (pat,locals) =
+		PMap.iter (fun n (v,p) -> ctx.locals <- PMap.add n v ctx.locals) locals;
+		pat
+	in
 	let pl = List.map (fun (el,eg,e) ->
 		let ep = collapse_case el in
 		let save = save_locals ctx in
 		let pl = match fst ep,stl with
 			| EArrayDecl el,[st] when (match follow st.st_type with TInst({cl_path=[],"Array"},[_]) -> true | _ -> false) ->
-				[to_pattern mctx ep st]
+				[add_pattern_locals (to_pattern ctx ep st.st_type)]
 			| EArrayDecl el,stl ->
 				begin try
-					List.map2 (fun e st -> to_pattern mctx e st) el stl
+					List.map2 (fun e st -> add_pattern_locals (to_pattern ctx e st.st_type)) el stl
 				with Invalid_argument _ ->
 					error ("Invalid number of arguments: expected " ^ (string_of_int (List.length stl)) ^ ", found " ^ (string_of_int (List.length el))) (pos ep)
 				end
 			| _,[st] ->
-				[to_pattern mctx ep st]
+				[add_pattern_locals (to_pattern ctx ep st.st_type)]
 			| EConst(Ident "_"),stl ->
 				List.map (fun st -> mk_any st.st_type st.st_pos) stl
 			| _,_ ->
@@ -986,4 +988,5 @@ let match_expr ctx e cases def need_val with_type p =
 		error ("Unmatched patterns: " ^ (s_st_r false (s_pat pat) st)) p
 	end;
 ;;
-match_expr_ref := match_expr
+match_expr_ref := match_expr;
+get_pattern_locals_ref := get_pattern_locals

+ 1 - 0
typecore.ml

@@ -124,6 +124,7 @@ let type_expr_ref : (typer -> Ast.expr -> bool -> texpr) ref = ref (fun _ _ _ ->
 let unify_min_ref : (typer -> texpr list -> t) ref = ref (fun _ _ -> assert false)
 let type_expr_with_type_ref : (typer -> Ast.expr -> t option -> bool -> texpr) ref = ref (fun _ _ _ -> assert false)
 let match_expr_ref : (typer -> Ast.expr -> (Ast.expr list * Ast.expr option * Ast.expr option) list -> Ast.expr option option -> bool -> t option -> Ast.pos -> texpr) ref = ref (fun _ _ _ _ _ _ _ -> assert false)
+let get_pattern_locals_ref : (typer -> Ast.expr -> Type.t -> (string, tvar) PMap.t) ref = ref (fun _ _ _ -> assert false)
 
 let short_type ctx t =
 	let tstr = s_type ctx t in

+ 3 - 0
typer.ml

@@ -3111,6 +3111,9 @@ let make_macro_api ctx p =
 			| None -> Interp.VNull
 			| Some (_,fields) -> Interp.enc_array (List.map Interp.encode_field fields)
 		);
+		Interp.get_pattern_locals = (fun e t ->
+			!get_pattern_locals_ref ctx e t
+		);
 		Interp.define_type = (fun v ->
 			let m, tdef, pos = (try Interp.decode_type_def v with Interp.Invalid_expr -> Interp.exc (Interp.VString "Invalid type definition")) in
 			let mdep = Typeload.type_module ctx m ctx.m.curmod.m_extra.m_file [tdef,pos] pos in