Explorar el Código

Spilling (#152)

* Some initial work on variable spilling

* skip restoring with the first state

* bodge saving and restoring

* make arguments appear in the first state usage table

* Don't restore variables in their defined state

* Add loop iteration hoisting test

* another attempt

* sort out arguments

they're always a pain

* round and round we go...

Back to per state vars, was hoping the "tmp used without being initialised" would magically solve itself, but no. Might be related to TVar exprs

* give some vars default type expressions

* bodge it

* disable hanging test so we can look at actual failures

* hack it even more

* Add some comments so I don't forget what this all is again

* attempt at avoiding double wrapping

seems to work on some targets...

* mark them as captured as well

* don't duplicate half of capturedVars

* Fix dodgy merge

* Need to follow abstracts away before getting default values

Not sure why I need to do this now

---------

Co-authored-by: Simon Krajewski <[email protected]>
Aidan Lee hace 4 semanas
padre
commit
5fc672b21e

+ 1 - 0
src/core/tType.ml

@@ -559,6 +559,7 @@ type flag_tvar =
 	| VStatic
 	| VUsedByTyper (* Set if the typer looked up this variable *)
 	| VHxb (* Flag used by hxb *)
+	| VCoroCaptured
 
 let flag_tvar_names = [
 	"VCaptured";"VFinal";"VAnalyzed";"VAssigned";"VCaught";"VStatic";"VUsedByTyper"

+ 10 - 2
src/coro/coroFromTexpr.ml

@@ -2,6 +2,7 @@ open Globals
 open Type
 open CoroTypes
 open CoroFunctions
+open LocalUsage
 
 let e_no_value = Texpr.Builder.make_null t_dynamic null_pos
 
@@ -13,6 +14,13 @@ type coro_ret =
 	| RMapExpr of coro_ret * (texpr -> texpr)
 
 let expr_to_coro ctx etmp_result etmp_error_unwrapped cb_root e =
+
+	(* TODO : Not have this be copy and pasted from capturedVars with slight modifications *)
+	let wrapper = ctx.typer.com.local_wrapper in
+	let scom = SafeCom.of_com ctx.typer.com in
+	let scom = { scom with platform_config = { scom.platform_config with pf_capture_policy = CPWrapRef } } in
+	let e = CapturedVars.captured_vars scom wrapper true e in
+
 	let make_block typepos =
 		make_block ctx typepos
 	in
@@ -42,7 +50,7 @@ let expr_to_coro ctx etmp_result etmp_error_unwrapped cb_root e =
 	in
 	let tmp_local cb t p =
 		let v = alloc_var VGenerated "tmp" t p in
-		add_expr cb (mk (TVar(v,None)) ctx.typer.t.tvoid p);
+		add_expr cb (mk (TVar(v, Some (Texpr.Builder.default_value (Abstract.follow_with_abstracts t) p))) ctx.typer.t.tvoid p);
 		v
 	in
 	let check_complex cb ret t p = match ret with
@@ -164,7 +172,7 @@ let expr_to_coro ctx etmp_result etmp_error_unwrapped cb_root e =
 			add_expr cb e;
 			Some (cb,e_no_value)
 		| TVar(v,Some e1) ->
-			add_expr cb {e with eexpr = TVar(v,None)};
+			add_expr cb {e with eexpr = TVar(v,Some (Texpr.Builder.default_value (Abstract.follow_with_abstracts e1.etype) e1.epos))};
 			let cb = loop_assign cb (RLocal v) e1 in
 			cb
 		(* calls *)

+ 109 - 77
src/coro/coroToTexpr.ml

@@ -9,6 +9,10 @@ open CoroControl
 type coro_state = {
 	cs_id : int;
 	mutable cs_el : texpr list;
+	mutable cs_declarations : tvar list;
+
+	(* a "foreign" variable is one which is not declared in this state but is accessed in it *)
+	cs_foreign_vars : (int, tvar) Hashtbl.t;
 }
 
 type coro_to_texpr_exprs = {
@@ -42,109 +46,135 @@ let handle_locals ctx b cls states tf_args forbidden_vars econtinuation =
 		type t = int
 	end) in
 
-	(* function arguments are accessible from the initial state without hoisting needed, so set that now *)
-	let arg_state_set = IntSet.of_list [ (List.hd states).cs_id ] in
-	let var_usages    = tf_args |> List.map (fun (v, _) -> v.v_id, arg_state_set) |> List.to_seq |> Hashtbl.of_seq in
+	let fst_state     = List.hd states in
+	let arg_state_set = IntSet.of_list [ fst_state.cs_id ] in
+
+	(* Keep an extra table of all vars and what states they appear in, easier check if a var is used across states this way. *)
+	let var_usages = tf_args |> List.map (fun (v, _) -> v.v_id, arg_state_set) |> List.to_seq |> Hashtbl.of_seq in
+
+	(* Treat arguments as "declared" in the initial state, this way they aren't spilled if accessed before the first suspension. *)
+	fst_state.cs_declarations <- List.map (fun (a, _) -> a) tf_args;
 
-	(* First iteration, just add newly discovered local variables *)
-	(* After this var_usages will contain all arguments and local vars and the states sets will be just the creation state *)
-	(* We don't handle locals here so we don't poison the var_usage hashtbl with non local var data *)
 	List.iter (fun state ->
 		let rec loop e =
 			match e.eexpr with
 			| TVar (v, eo) ->
-				Option.may loop eo;
-				Hashtbl.replace var_usages v.v_id (IntSet.of_list [ state.cs_id ])
-			| _ ->
-				Type.iter loop e
-		in
-		List.iter loop state.cs_el
-	) states;
+				state.cs_declarations <- v :: state.cs_declarations;
 
-	(* Second interation, visit all locals and update any local variable state sets *)
-	List.iter (fun state ->
-		let rec loop e =
-			match e.eexpr with
-			| TLocal (v) ->
-				(match Hashtbl.find_opt var_usages v.v_id with
-				| Some set ->
-					Hashtbl.replace var_usages v.v_id (IntSet.add state.cs_id set)
-				| None ->
-					())
+				Hashtbl.replace var_usages v.v_id (IntSet.of_list [ state.cs_id ]);
+
+				Option.may loop eo
+			| TLocal v when Hashtbl.mem var_usages v.v_id ->
+				let existing = Hashtbl.find var_usages v.v_id in
+				
+				Hashtbl.replace var_usages v.v_id (IntSet.add state.cs_id existing)
 			| _ ->
 				Type.iter loop e
 		in
 		List.iter loop state.cs_el
 	) states;
 
-	let is_used_across_states v_id =
-		let many_states set v_id =
-			IntSet.elements set |> List.length > 1 in
-		(* forbidden vars are things like the _hx_continuation variable, they should not be hoisted *)
-		let non_coro_var v_id =
-			forbidden_vars |> List.exists (fun id -> id = v_id) |> not in
-
-		match Hashtbl.find_opt var_usages v_id with
-		| Some set when many_states set v_id && non_coro_var v_id ->
-			true
+	(*
+	 * Each variable which is used across multiple states is given a field in the continuation class to store it's value
+	 * during suspension.
+	 * TODO : Instead of giving each variable a field have a set of "slots" which can be used by a field if no other variable is currently using it.
+	 *)
+	let fields = Hashtbl.create 0 in
+	let is_used_across_multiple_states id =
+		match Hashtbl.find_opt var_usages id with
+		| Some set ->
+			(match IntSet.elements set with
+			| [ _ ] ->
+				false
+			| _ ->
+				true)
 		| _ ->
 			false
 	in
 
-	let fields =
-		tf_args
-		|> List.filter_map (fun (v, _) ->
-			if is_used_across_states v.v_id then
-				Some (v.v_id, mk_field (Printf.sprintf "_hx_hoisted%i" v.v_id) v.v_type v.v_pos v.v_pos)
-			else
-				None)
-		|> List.to_seq
-		|> Hashtbl.of_seq in
+	(* Again, treat function arguments as the special case that they are *)
+	List.iter (fun (v, _) ->
+		if is_used_across_multiple_states v.v_id then begin
+			let field = mk_field (Printf.sprintf "_hx_hoisted%i" v.v_id) v.v_type null_pos null_pos in
+
+			Hashtbl.replace fields v.v_id field;
+		end) tf_args;
 
-	(* Third iteration, create fields for vars used across states and remap access to those fields *)
 	List.iter (fun state ->
-		let rec loop e =
+		let is_not_declared_in_state id =
+			List.exists (fun v -> v.v_id == id) state.cs_declarations |> not in
+
+		let rec mapper e =
 			match e.eexpr with
-			| TVar (v, eo) when is_used_across_states v.v_id ->
-				let name  = Printf.sprintf "_hx_hoisted%i" v.v_id in
-				let field = mk_field name v.v_type v.v_pos v.v_pos in
+			| TVar (v, eo) when is_used_across_multiple_states v.v_id ->
+				let field = mk_field (Printf.sprintf "_hx_hoisted%i" v.v_id) v.v_type v.v_pos v.v_pos in
 
 				Hashtbl.replace fields v.v_id field;
-
-				begin match eo with
+				
+				{ e with eexpr = TVar (v, Option.map mapper eo) }
+			| TLocal v when is_used_across_multiple_states v.v_id && is_not_declared_in_state v.v_id ->
+				(* Each state generates new local variables for variables which are used across states. *)
+				(* Here we generate and store those new variables and remap local access to them *)
+				let new_v =
+					match Hashtbl.find_opt state.cs_foreign_vars v.v_id with
+					| Some v -> v
 					| None ->
-						(* We need an expression, so let's just emit `null`. The analyzer will clean this up. *)
-						b#null t_dynamic e.epos
-					| Some e ->
-						let efield = b#instance_field econtinuation cls [] field field.cf_type in
-						let einit  =
-							match eo with
-							| None -> Builder.default_value v.v_type v.v_pos
-							| Some e -> Type.map_expr loop e in
-						b#assign efield einit
-				end
-			(* A local of a var should never appear before its declaration, right? *)
-			| TLocal (v) when is_used_across_states v.v_id ->
-				let field = Hashtbl.find fields v.v_id in
-
-				b#instance_field econtinuation cls [] field field.cf_type
+						let new_v = alloc_var VGenerated (Printf.sprintf "_hx_restored%i" v.v_id) v.v_type v.v_pos in
+						Hashtbl.replace state.cs_foreign_vars v.v_id new_v;
+						new_v
+				in
+				{ e with eexpr = TLocal new_v }
 			| _ ->
-				Type.map_expr loop e
+				Type.map_expr mapper e
 		in
-		state.cs_el <- List.map loop state.cs_el
+		state.cs_el <- List.map mapper state.cs_el
 	) states;
 
-	(* We need to do this argument copying as the last thing we do *)
-	(* Doing it when the initial fields hashtbl is created will cause the third iterations TLocal to re-write them... *)
-	List.iter (fun (v, _) ->
-		if is_used_across_states v.v_id then
-			let initial = List.hd states in
-			let field   = Hashtbl.find fields v.v_id in
-			let efield  = b#instance_field econtinuation cls [] field field.cf_type in
-			let assign  = b#assign efield (b#local v v.v_pos) in
+	List.iter (fun state ->
+		let restoring =
+			Hashtbl.fold
+				(fun id v acc ->
+					let field   = Hashtbl.find fields id in
+					let access  = b#instance_field econtinuation cls [] field field.cf_type in
+					let var_dec = b#var_init v access in
+					var_dec :: acc
+				)
+				state.cs_foreign_vars
+				[] in
+
+		let initial =
+			List.filter_map
+				(fun v ->
+					if is_used_across_multiple_states v.v_id then
+						let field   = Hashtbl.find fields v.v_id in
+						let access  = b#instance_field econtinuation cls [] field field.cf_type in
+						let local   = b#local v v.v_pos in
+						let assign  = b#assign access local in
+						Some assign
+					else
+						None)
+				state.cs_declarations in
+
+		let saving =
+			Hashtbl.fold
+				(fun id v acc ->
+					let field   = Hashtbl.find fields id in
+					let access  = b#instance_field econtinuation cls [] field field.cf_type in
+					let local   = b#local v v.v_pos in
+					let assign  = b#assign access local in
+					assign :: acc
+				)
+				state.cs_foreign_vars
+				initial in
+
+		let body = List.take ((List.length state.cs_el) - 1) state.cs_el in
+		let tail = [ List.nth state.cs_el ((List.length state.cs_el) - 1) ] in
+		state.cs_el <- restoring @ body @ saving @ tail)
+		states;
 
-			initial.cs_el <- assign :: initial.cs_el) tf_args;
 	fields
+	|> Hashtbl.to_seq_values
+	|> List.of_seq
 
 let block_to_texpr_coroutine ctx cb cont cls params tf_args forbidden_vars exprs p stack_item_inserter start_exception =
 	let {econtinuation;ecompletion;estate;eresult;egoto;eerror;etmp_result;etmp_error;etmp_error_unwrapped} = exprs in
@@ -205,6 +235,8 @@ let block_to_texpr_coroutine ctx cb cont cls params tf_args forbidden_vars exprs
 	let make_state id el = {
 		cs_id = id;
 		cs_el = el;
+		cs_declarations = [];
+		cs_foreign_vars = Hashtbl.create 0;
 	} in
 
 	let get_caught,unwrap_exception = match com.basic.texception with
@@ -348,7 +380,7 @@ let block_to_texpr_coroutine ctx cb cont cls params tf_args forbidden_vars exprs
 	let states = !states in
 	let states = states |> List.sort (fun state1 state2 -> state1.cs_id - state2.cs_id) in
 
-	let fields = handle_locals ctx b cls states tf_args forbidden_vars econtinuation in
+	let fields_and_decls = handle_locals ctx b cls states tf_args forbidden_vars econtinuation in
 
 	let ethrow = b#void_block [
 		b#assign etmp_error (get_caught (b#string "Invalid coroutine state" p));
@@ -433,4 +465,4 @@ let block_to_texpr_coroutine ctx cb cont cls params tf_args forbidden_vars exprs
 		etry
 	in
 
-	eloop, init_state, fields |> Hashtbl.to_seq_values |> List.of_seq
+	eloop, init_state, fields_and_decls

+ 1 - 1
src/filters/filters.ml

@@ -433,7 +433,7 @@ let run_safe_filters ectx com (scom : SafeCom.t) all_types_array new_types_array
 		"reduce_expression",Optimizer.reduce_expression;
 		"inline_constructors",InlineConstructors.inline_constructors;
 		"Exceptions_filter",Exceptions.filter ectx;
-		"captured_vars",(fun scom -> CapturedVars.captured_vars scom cv_wrapper_impl);
+		"captured_vars",(fun scom -> CapturedVars.captured_vars scom cv_wrapper_impl false);
 	] in
 
 	let filters_after_analyzer = [

+ 6 - 2
src/filters/safe/capturedVars.ml

@@ -38,7 +38,7 @@ open LocalUsage
 		funs.push(function(x) { function() return x[0]++; }(x));
 	}
 *)
-let captured_vars scom impl e =
+let captured_vars scom impl for_coro e =
 
 	let mk_var v used =
 		let v2 = alloc_var v.v_kind v.v_name (PMap.find v.v_id used) v.v_pos in
@@ -137,6 +137,8 @@ let captured_vars scom impl e =
 				let vt = v.v_type in
 				v.v_type <- impl#captured_type vt;
 				add_var_flag v VCaptured;
+				if for_coro then
+					add_var_flag v VCoroCaptured;
 				vt
 			) used in
 			wrap used e
@@ -217,6 +219,8 @@ let captured_vars scom impl e =
 			incr depth;
 			f collect_vars;
 			decr depth;
+		| Declare v when not for_coro && has_var_flag v VCoroCaptured ->
+			()
 		| Declare v ->
 			vars := PMap.add v.v_id !depth !vars;
 		| Use v ->
@@ -244,7 +248,7 @@ let captured_vars scom impl e =
 		local_usage collect_vars e;
 
 		(* mark all capture variables - also used in rename_local_vars at later stage *)
-		PMap.iter (fun _ v -> add_var_flag v VCaptured) !used;
+		PMap.iter (fun _ v -> add_var_flag v VCaptured; if for_coro then add_var_flag v VCoroCaptured) !used;
 
 		!assigned
 	in

+ 1 - 1
src/typing/macroContext.ml

@@ -657,7 +657,7 @@ and flush_macro_context mint mctx =
 			"handle_abstract_casts",AbstractCast.handle_abstract_casts;
 			"local_statics",LocalStatic.run;
 			"Exceptions",Exceptions.filter ectx;
-			"captured_vars",(fun scom -> CapturedVars.captured_vars scom mctx.com.local_wrapper);
+			"captured_vars",(fun scom -> CapturedVars.captured_vars scom mctx.com.local_wrapper false);
 		] in
 		let type_filters = [
 			(fun _ -> FiltersCommon.remove_generic_base);

+ 15 - 0
tests/misc/coroutines/src/TestHoisting.hx

@@ -115,4 +115,19 @@ class TestHoisting extends utest.Test {
 
         }));
     }
+
+    function testLoopHoisting() {
+        final expected = [1, 2, 3];
+        final actual   = [];
+
+        CoroRun.runScoped(node -> {
+            for (x in expected) {
+                node.async(_ -> {
+                    actual.push(x);
+                });
+            }
+        });
+
+        Assert.same(expected, actual);
+    }
 }