Parcourir la source

Spilling Reads and Writes (#170)

* track local read and writes for states

* Use Texpr.skip
Aidan Lee il y a 3 semaines
Parent
commit
3d3d89e152
3 fichiers modifiés avec 85 ajouts et 55 suppressions
  1. 2 2
      src/coro/coroFromTexpr.ml
  2. 69 53
      src/coro/coroToTexpr.ml
  3. 14 0
      tests/misc/coroutines/src/TestHoisting.hx

+ 2 - 2
src/coro/coroFromTexpr.ml

@@ -50,7 +50,7 @@ let expr_to_coro ctx etmp_result etmp_error_unwrapped cb_root e =
 	in
 	let tmp_local cb t p =
 		let v = alloc_var VGenerated "tmp" t p in
-		add_expr cb (mk (TVar(v, Some (Texpr.Builder.default_value (Abstract.follow_with_abstracts t) p))) ctx.typer.t.tvoid p);
+		add_expr cb (mk (TVar(v, None)) ctx.typer.t.tvoid p);
 		v
 	in
 	let check_complex cb ret t p = match ret with
@@ -172,7 +172,7 @@ let expr_to_coro ctx etmp_result etmp_error_unwrapped cb_root e =
 			add_expr cb e;
 			Some (cb,e_no_value)
 		| TVar(v,Some e1) ->
-			add_expr cb {e with eexpr = TVar(v,Some (Texpr.Builder.default_value (Abstract.follow_with_abstracts e1.etype) e1.epos))};
+			add_expr cb {e with eexpr = TVar(v,None)};
 			let cb = loop_assign cb (RLocal v) e1 in
 			cb
 		(* calls *)

+ 69 - 53
src/coro/coroToTexpr.ml

@@ -6,13 +6,19 @@ open ContTypes
 open Texpr
 open CoroControl
 
+module IntSet = Set.Make(struct
+	let compare a b = b - a
+	type t = int
+end)
+
 type coro_state = {
 	cs_id : int;
 	mutable cs_el : texpr list;
 	mutable cs_declarations : tvar list;
 
-	(* a "foreign" variable is one which is not declared in this state but is accessed in it *)
-	cs_foreign_vars : (int, tvar) Hashtbl.t;
+	cs_mapped_local : (int, tvar) Hashtbl.t;
+	mutable cs_reads : IntSet.t;
+	mutable cs_writes : IntSet.t;
 }
 
 type coro_to_texpr_exprs = {
@@ -41,11 +47,6 @@ let make_suspending_call basic cont call econtinuation =
 	mk (TCall (efun, args)) (cont.suspension_result basic.tany) call.cs_pos
 
 let handle_locals ctx b cls states tf_args forbidden_vars econtinuation =
-	let module IntSet = Set.Make(struct
-		let compare a b = b - a
-		type t = int
-	end) in
-
 	let fst_state     = List.hd states in
 	let arg_state_set = IntSet.of_list [ fst_state.cs_id ] in
 
@@ -95,35 +96,64 @@ let handle_locals ctx b cls states tf_args forbidden_vars econtinuation =
 	(* Again, treat function arguments as the special case that they are *)
 	List.iter (fun (v, _) ->
 		if is_used_across_multiple_states v.v_id then begin
+			fst_state.cs_writes <- IntSet.add v.v_id fst_state.cs_writes;
+
 			let field = mk_field (Printf.sprintf "_hx_hoisted%i" v.v_id) v.v_type null_pos null_pos in
 
 			Hashtbl.replace fields v.v_id field;
+			Hashtbl.replace fst_state.cs_mapped_local v.v_id v;
 		end) tf_args;
 
 	List.iter (fun state ->
-		let is_not_declared_in_state id =
-			List.exists (fun v -> v.v_id == id) state.cs_declarations |> not in
+
+		let get_or_create_local_mapping v =
+			match Hashtbl.find_opt state.cs_mapped_local v.v_id with
+			| Some v -> v
+			| None ->
+				let new_v = alloc_var VGenerated (Printf.sprintf "_hx_restored%i" v.v_id) v.v_type v.v_pos in
+				Hashtbl.replace state.cs_mapped_local v.v_id new_v;
+				new_v
+		in
 
 		let rec mapper e =
 			match e.eexpr with
 			| TVar (v, eo) when is_used_across_multiple_states v.v_id ->
+				if Option.is_some eo then
+					state.cs_writes <- IntSet.add v.v_id state.cs_writes;
+
 				let field = mk_field (Printf.sprintf "_hx_hoisted%i" v.v_id) v.v_type v.v_pos v.v_pos in
 
 				Hashtbl.replace fields v.v_id field;
+				Hashtbl.replace state.cs_mapped_local v.v_id v;
 				
 				{ e with eexpr = TVar (v, Option.map mapper eo) }
-			| TLocal v when is_used_across_multiple_states v.v_id && is_not_declared_in_state v.v_id ->
+			| TBinop ((OpAssign | OpAssignOp _) as op, elhs, erhs) ->
+				(match Texpr.skip elhs with
+				| { eexpr = TLocal v } when is_used_across_multiple_states v.v_id ->
+					state.cs_writes <- IntSet.add v.v_id state.cs_writes;
+
+					let new_local = { elhs with eexpr = TLocal (get_or_create_local_mapping v) } in
+					let new_rhs   = mapper erhs in
+
+					{ e with eexpr = TBinop (op, new_local, new_rhs) }
+				| _ ->
+					Type.map_expr mapper e)
+			| TUnop ((Increment | Decrement) as mode, flag, erhs) ->
+				(match Texpr.skip erhs with
+				| { eexpr = TLocal v  } when is_used_across_multiple_states v.v_id ->
+					state.cs_writes <- IntSet.add v.v_id state.cs_writes;
+
+					let new_rhs = { erhs with eexpr = TLocal (get_or_create_local_mapping v) } in
+					{ e with eexpr = TUnop (mode, flag, new_rhs) }
+				| _ ->
+					Type.map_expr mapper e)
+			| TLocal v when is_used_across_multiple_states v.v_id ->
 				(* Each state generates new local variables for variables which are used across states. *)
 				(* Here we generate and store those new variables and remap local access to them *)
-				let new_v =
-					match Hashtbl.find_opt state.cs_foreign_vars v.v_id with
-					| Some v -> v
-					| None ->
-						let new_v = alloc_var VGenerated (Printf.sprintf "_hx_restored%i" v.v_id) v.v_type v.v_pos in
-						Hashtbl.replace state.cs_foreign_vars v.v_id new_v;
-						new_v
-				in
-				{ e with eexpr = TLocal new_v }
+
+				state.cs_reads <- IntSet.add v.v_id state.cs_reads;
+
+				{ e with eexpr = TLocal (get_or_create_local_mapping v) }
 			| _ ->
 				Type.map_expr mapper e
 		in
@@ -131,41 +161,25 @@ let handle_locals ctx b cls states tf_args forbidden_vars econtinuation =
 	) states;
 
 	List.iter (fun state ->
-		let restoring =
-			Hashtbl.fold
-				(fun id v acc ->
-					let field   = Hashtbl.find fields id in
-					let access  = b#instance_field econtinuation cls [] field field.cf_type in
-					let var_dec = b#var_init v access in
-					var_dec :: acc
-				)
-				state.cs_foreign_vars
-				[] in
-
-		let initial =
-			List.filter_map
-				(fun v ->
-					if is_used_across_multiple_states v.v_id then
-						let field   = Hashtbl.find fields v.v_id in
-						let access  = b#instance_field econtinuation cls [] field field.cf_type in
-						let local   = b#local v v.v_pos in
-						let assign  = b#assign access local in
-						Some assign
-					else
-						None)
-				state.cs_declarations in
+		let restoring = IntSet.union state.cs_writes state.cs_reads |> IntSet.to_list |> List.filter_map (fun id ->
+			(* We don't want to restore a variable which is declared in this state *)
+			(* Doing so would mean if the var is an argument the arguments value would be overwritten by whatever is in the hoisted field *)
+			if List.exists (fun v -> v.v_id = id) state.cs_declarations then
+				None
+			else
+				let v       = Hashtbl.find state.cs_mapped_local id in
+				let field   = Hashtbl.find fields id in
+				let access  = b#instance_field econtinuation cls [] field field.cf_type in
+				Some (b#var_init v access)
+		) in
 
 		let saving =
-			Hashtbl.fold
-				(fun id v acc ->
-					let field   = Hashtbl.find fields id in
-					let access  = b#instance_field econtinuation cls [] field field.cf_type in
-					let local   = b#local v v.v_pos in
-					let assign  = b#assign access local in
-					assign :: acc
-				)
-				state.cs_foreign_vars
-				initial in
+			state.cs_writes |> IntSet.to_list |> List.map (fun id ->
+				let v       = Hashtbl.find state.cs_mapped_local id in
+				let field   = Hashtbl.find fields id in
+				let access  = b#instance_field econtinuation cls [] field field.cf_type in
+				let local   = b#local v v.v_pos in
+				b#assign access local) in
 
 		let body = List.take ((List.length state.cs_el) - 1) state.cs_el in
 		let tail = [ List.nth state.cs_el ((List.length state.cs_el) - 1) ] in
@@ -236,7 +250,9 @@ let block_to_texpr_coroutine ctx cb cont cls params tf_args forbidden_vars exprs
 		cs_id = id;
 		cs_el = el;
 		cs_declarations = [];
-		cs_foreign_vars = Hashtbl.create 0;
+		cs_mapped_local = Hashtbl.create 0;
+		cs_reads = IntSet.empty;
+		cs_writes = IntSet.empty;
 	} in
 
 	let get_caught,unwrap_exception = match com.basic.texception with

+ 14 - 0
tests/misc/coroutines/src/TestHoisting.hx

@@ -130,4 +130,18 @@ class TestHoisting extends utest.Test {
 
         Assert.same(expected, actual);
     }
+
+    function testUninitialisedVariable() {
+        Assert.equals(7, CoroRun.run(() -> {
+            var i;
+
+            yield();
+
+            i = 7;
+
+            yield();
+
+            return i;
+        }));
+    }
 }