Browse Source

add EnumValue.match (closes #1783)

Simon Krajewski 12 years ago
parent
commit
30b48c2b47
4 changed files with 19 additions and 4 deletions
  1. 2 2
      interp.ml
  2. 1 1
      matcher.ml
  3. 1 1
      typecore.ml
  4. 15 0
      typer.ml

+ 2 - 2
interp.ml

@@ -113,7 +113,7 @@ type extern_api = {
 	get_local_using : unit -> tclass list;
 	get_local_vars : unit -> (string, Type.tvar) PMap.t;
 	get_build_fields : unit -> value;
-	get_pattern_locals : Ast.expr -> Type.t -> (string,Type.tvar) PMap.t;
+	get_pattern_locals : Ast.expr -> Type.t -> (string,(Type.tvar * Ast.pos)) PMap.t;
 	define_type : value -> unit;
 	module_dependency : string -> string -> bool -> unit;
 	current_module : unit -> module_def;
@@ -2432,7 +2432,7 @@ let macro_lib =
 		"pattern_locals", Fun2 (fun e t ->
 			let loc = (get_ctx()).curapi.get_pattern_locals (decode_expr e) (decode_type t) in
 			let h = Hashtbl.create 0 in
-			PMap.iter (fun n v -> Hashtbl.replace h (VString n) (encode_type v.v_type)) loc;
+			PMap.iter (fun n (v,_) -> Hashtbl.replace h (VString n) (encode_type v.v_type)) loc;
 			enc_hash h
 		);
 		"macro_context_reused", Fun1 (fun c ->

+ 1 - 1
matcher.ml

@@ -546,7 +546,7 @@ let to_pattern ctx e t =
 let get_pattern_locals ctx e t =
 	try
 		let _,locals = to_pattern ctx e t in
-		PMap.foldi (fun n (v,_) acc -> PMap.add n v acc) locals PMap.empty
+		PMap.foldi (fun n v acc -> PMap.add n v acc) locals PMap.empty
 	with Unrecognized_pattern _ ->
 		PMap.empty
 

+ 1 - 1
typecore.ml

@@ -141,7 +141,7 @@ let type_expr_ref : (typer -> Ast.expr -> with_type -> texpr) ref = ref (fun _ _
 let type_module_type_ref : (typer -> module_type -> t list option -> pos -> texpr) ref = ref (fun _ _ _ _ -> assert false)
 let unify_min_ref : (typer -> texpr list -> t) ref = ref (fun _ _ -> assert false)
 let match_expr_ref : (typer -> Ast.expr -> (Ast.expr list * Ast.expr option * Ast.expr option) list -> Ast.expr option option -> with_type -> Ast.pos -> decision_tree) ref = ref (fun _ _ _ _ _ _ -> assert false)
-let get_pattern_locals_ref : (typer -> Ast.expr -> Type.t -> (string, tvar) PMap.t) ref = ref (fun _ _ _ -> assert false)
+let get_pattern_locals_ref : (typer -> Ast.expr -> Type.t -> (string, (tvar * pos)) PMap.t) ref = ref (fun _ _ _ -> assert false)
 let get_constructor_ref : (typer -> tclass -> t list -> Ast.pos -> (t * tclass_field)) ref = ref (fun _ _ _ _ -> assert false)
 let check_abstract_cast_ref : (typer -> t -> texpr -> Ast.pos -> texpr) ref = ref (fun _ _ _ _ -> assert false)
 

+ 15 - 0
typer.ml

@@ -2831,6 +2831,9 @@ and type_expr ctx (e,p) (with_type:with_type) =
 				if field_name fa = "bind" then (match follow e1.etype with
 					| TFun(args,ret) -> {e1 with etype = opt_args args ret}
 					| _ -> e)
+				else if field_name fa = "match" then (match follow e1.etype with
+					| TEnum _ as t -> {e1 with etype = tfun [t] ctx.t.tbool }
+					| _ -> e)
 				else if mode = "position" then (match extract_field fa with
 					| None -> e
 					| Some cf -> raise (Typecore.DisplayPosition [cf.cf_pos]))
@@ -2904,6 +2907,9 @@ and type_expr ctx (e,p) (with_type:with_type) =
 					PMap.fold (fun f acc -> if can_access ctx c f true then PMap.add f.cf_name { f with cf_public = true; cf_type = opt_type f.cf_type } acc else acc) a.a_fields PMap.empty
 				| _ ->
 					a.a_fields)
+			| TEnum(_) as t ->
+				let cf = mk_field "match" (tfun [t] ctx.t.tbool) p in
+				PMap.add "match" cf PMap.empty
 			| TFun (args,ret) ->
 				let t = opt_args args ret in
 				let cf = mk_field "bind" (tfun [t] t) p in
@@ -3027,6 +3033,15 @@ and type_call ctx e el (with_type:with_type) p =
 		(match follow e.etype with
 			| TFun _ -> type_bind ctx e args p
 			| _ -> def ())
+	| (EField(e,"match"),p), [epat] ->
+		let et = type_expr ctx e Value in
+		(match follow et.etype with
+			| TEnum _ as t ->
+				let e = type_switch ctx e [[epat],None,Some (EConst(Ident "true"),p)] (Some (Some (EConst(Ident "false"),p))) (WithType ctx.t.tbool) p in
+				let locals = !get_pattern_locals_ref ctx epat t in
+				PMap.iter (fun _ (_,p) -> display_error ctx "Capture variables are not allowed" p) locals;
+				e
+			| _ -> def ())
 	| (EConst (Ident "$type"),_) , [e] ->
 		let e = type_expr ctx e Value in
 		ctx.com.warning (s_type (print_context()) e.etype) e.epos;