Browse Source

Copy var flags when duplicating local variables (#11803)

* copy var flags when duplicating

Also don't unroll loops that have static vars
closes #11800

* hoist static locals when unrolling loops

see #11800

* hoist all var declarations when unrolling loops

* awkwardly deal with captured locals

* clean up a bit, but don't hoist non-statics after all

* don't need this now

* remove test
Simon Krajewski 9 months ago
parent
commit
0ab8de0161

+ 1 - 0
src/core/texpr.ml

@@ -307,6 +307,7 @@ let duplicate_tvars f_this e =
 		let v2 = alloc_var v.v_kind v.v_name v.v_type v.v_pos in
 		v2.v_meta <- v.v_meta;
 		v2.v_extra <- v.v_extra;
+		v2.v_flags <- v.v_flags;
 		Hashtbl.add vars v.v_id v2;
 		v2;
 	in

+ 1 - 1
src/filters/filters.ml

@@ -700,7 +700,7 @@ module ForRemap = struct
 		| TFor(v,e1,e2) ->
 			let e1 = loop e1 in
 			let e2 = loop e2 in
-			let iterator = ForLoop.IterationKind.of_texpr ctx e1 (ForLoop.is_cheap_enough_t ctx e2) e.epos in
+			let iterator = ForLoop.IterationKind.of_texpr ctx e1 (ForLoop.get_unroll_params_t ctx e2) e.epos in
 			let restore = save_locals ctx in
 			let e = ForLoop.IterationKind.to_texpr ctx v iterator e2 e.epos in
 			restore();

+ 63 - 34
src/typing/forLoop.ml

@@ -55,10 +55,14 @@ let optimize_for_loop_iterator ctx v e1 e2 p =
 		mk (TWhile (ehasnext,eblock,NormalWhile)) ctx.t.tvoid p
 	]) ctx.t.tvoid p
 
+type unroll_parameters = {
+	expression_weight : int;
+}
+
 module IterationKind = struct
 	type t_kind =
 		| IteratorIntConst of texpr * texpr * bool (* ascending? *)
-		| IteratorIntUnroll of int * int * bool
+		| IteratorIntUnroll of int * int * bool * unroll_parameters
 		| IteratorInt of texpr * texpr
 		| IteratorArrayDecl of texpr list
 		| IteratorArray
@@ -170,7 +174,22 @@ module IterationKind = struct
 			)
 	 	| _ -> raise Not_found
 
-	let of_texpr ?(resume=false) ctx e unroll p =
+	let map_unroll_params ctx unroll_params i = match unroll_params with
+		| None ->
+			None
+		| Some unroll_params ->
+			let cost = i * unroll_params.expression_weight in
+			let max_cost = try
+				int_of_string (Common.defined_value ctx.com Define.LoopUnrollMaxCost)
+			with Not_found ->
+				250
+			in
+			if cost <= max_cost then
+				Some unroll_params
+			else
+				None
+
+	let of_texpr ?(resume=false) ctx e unroll_params p =
 		let dynamic_iterator e =
 			display_error ctx.com "You can't iterate on a Dynamic value, please specify Iterator or Iterable" e.epos;
 			IteratorDynamic,e,t_dynamic
@@ -211,9 +230,12 @@ module IterationKind = struct
 			let it = match efrom.eexpr,eto.eexpr with
 				| TConst (TInt a),TConst (TInt b) ->
 					let diff = Int32.to_int (Int32.sub a b) in
-					let unroll = unroll (abs diff) in
-					if unroll then IteratorIntUnroll(Int32.to_int a,abs(diff),diff <= 0)
-					else IteratorIntConst(efrom,eto,diff <= 0)
+					begin match map_unroll_params ctx unroll_params (abs diff) with
+					| Some unroll_params ->
+						IteratorIntUnroll(Int32.to_int a,abs(diff),diff <= 0,unroll_params)
+					| None ->
+						IteratorIntConst(efrom,eto,diff <= 0)
+					end
 				| _ ->
 					let eto = match follow eto.etype with
 						| TAbstract ({ a_path = ([],"Int") }, []) -> eto
@@ -223,8 +245,10 @@ module IterationKind = struct
 			in
 			it,e,ctx.t.tint
 		| TArrayDecl el,TInst({ cl_path = [],"Array" },[pt]) ->
-			let it = if unroll (List.length el) then IteratorArrayDecl el
-			else IteratorArray in
+			let it = match map_unroll_params ctx unroll_params (List.length el) with
+				| Some _ -> IteratorArrayDecl el
+				| None -> IteratorArray
+			in
 			(it,e,pt)
 		| _,TInst({ cl_path = [],"Array" },[pt])
 		| _,TInst({ cl_path = ["flash"],"Vector" },[pt]) ->
@@ -317,18 +341,31 @@ module IterationKind = struct
 		match iterator.it_kind with
 		| _ when not ctx.allow_transform ->
 			mk (TFor(v,e1,e2)) t_void p
-		| IteratorIntUnroll(offset,length,ascending) ->
+		| IteratorIntUnroll(offset,length,ascending,unroll_params) ->
 			check_loop_var_modification [v] e2;
 			if not ascending then typing_error "Cannot iterate backwards" p;
-			let el = ExtList.List.init length (fun i ->
-				let ei = make_int ctx.t (if ascending then i + offset else offset - i) p in
-				let rec loop e = match e.eexpr with
-					| TLocal v' when v == v' -> {ei with epos = e.epos}
-					| _ -> map_expr loop e
+			let rec unroll acc i =
+				if i = length then
+					List.rev acc
+				else begin
+					let ei = make_int ctx.t (if ascending then i + offset else offset - i) p in
+					let local_vars = ref [] in
+					let rec loop e = match e.eexpr with
+					| TLocal v' when v == v' ->
+						{ei with epos = e.epos}
+					| TVar(v,eo) when has_var_flag v VStatic ->
+						if acc = [] then
+							local_vars := {e with eexpr = TVar(v,eo)} :: !local_vars;
+						mk (TConst TNull) t_dynamic null_pos
+					| _ ->
+						map_expr loop e
 				in
 				let e2 = loop e2 in
-				Texpr.duplicate_tvars e_identity e2
-			) in
+				let acc = acc @ !local_vars in
+				let e2 = Texpr.duplicate_tvars e_identity e2 in
+				unroll (e2 :: acc) (i + 1)
+			end in
+			let el = unroll [] 0 in
 			mk (TBlock el) t_void p
 		| IteratorIntConst(a,b,ascending) ->
 			check_loop_var_modification [v] e2;
@@ -408,7 +445,7 @@ module IterationKind = struct
 			mk (TFor(v,e1,e2)) t_void p
 end
 
-let is_cheap_enough ctx e2 i =
+let get_unroll_params ctx e2 =
 	let num_expr = ref 0 in
 	let rec loop e = match fst e with
 		| EContinue | EBreak ->
@@ -420,17 +457,13 @@ let is_cheap_enough ctx e2 i =
 	try
 		if ctx.com.display.dms_kind <> DMNone then raise Exit;
 		ignore(loop e2);
-		let cost = i * !num_expr in
-		let max_cost = try
-			int_of_string (Common.defined_value ctx.com Define.LoopUnrollMaxCost)
-		with Not_found ->
-			250
-		in
-		cost <= max_cost
+		Some {
+			expression_weight = !num_expr;
+		}
 	with Exit ->
-		false
+		None
 
-let is_cheap_enough_t ctx e2 i =
+let get_unroll_params_t ctx e2 =
 	let num_expr = ref 0 in
 	let rec loop e = match e.eexpr with
 		| TContinue | TBreak ->
@@ -442,15 +475,11 @@ let is_cheap_enough_t ctx e2 i =
 	try
 		if ctx.com.display.dms_kind <> DMNone then raise Exit;
 		ignore(loop e2);
-		let cost = i * !num_expr in
-		let max_cost = try
-			int_of_string (Common.defined_value ctx.com Define.LoopUnrollMaxCost)
-		with Not_found ->
-			250
-		in
-		cost <= max_cost
+		Some {
+			expression_weight = !num_expr;
+		}
 	with Exit ->
-		false
+		None
 
 type iteration_ident = string * pos * display_kind option
 
@@ -491,7 +520,7 @@ let type_for_loop ctx handle_display it e2 p =
 	in
 	match ik with
 	| IKNormal(i,pi,dko) ->
-		let iterator = IterationKind.of_texpr ctx e1 (is_cheap_enough ctx e2) p in
+		let iterator = IterationKind.of_texpr ctx e1 (get_unroll_params ctx e2) p in
 		let i = add_local_with_origin ctx TVOForVariable i iterator.it_type pi in
 		let e2 = type_expr ctx e2 NoValue in
 		check_display (i,pi,dko);

+ 1 - 1
src/typing/typerDisplay.ml

@@ -518,7 +518,7 @@ and display_expr ctx e_ast e dk mode with_type p =
 		let fields = DisplayFields.collect ctx e_ast e dk with_type p in
 		let item = completion_item_of_expr ctx e in
 		let iterator = try
-			let it = (ForLoop.IterationKind.of_texpr ~resume:true ctx e (fun _ -> false) e.epos) in
+			let it = (ForLoop.IterationKind.of_texpr ~resume:true ctx e None e.epos) in
 			match follow it.it_type with
 				| TDynamic _ ->  None
 				| t -> Some t

+ 25 - 0
tests/optimization/src/issues/Issue11800.hx

@@ -0,0 +1,25 @@
+package issues;
+
+class Issue11800 {
+	@:js('
+		++issues_Issue11800.test_a;
+		++issues_Issue11800.test_b;
+		++issues_Issue11800.test_a;
+		++issues_Issue11800.test_b;
+	')
+	static function test() {
+		static var a = 0;
+
+		for (i in 0...3) {
+			switch i {
+				case n if (n < 2):
+					use(++a);
+					static var b = 0;
+					use(++b);
+				case _:
+			}
+		}
+	}
+
+	static function use(v:Int) {}
+}

+ 21 - 0
tests/unit/src/unit/issues/Issue11800.hx

@@ -0,0 +1,21 @@
+package unit.issues;
+
+class Issue11800 extends unit.Test {
+	public function test() {
+		static var a = 0; // Works.
+		var buf = new StringBuf();
+		function append(v:Int) {
+			buf.add(Std.string(v));
+		}
+		for (i in 0...3) {
+			switch i {
+				case n if (n < 2):
+					append(++a);
+					static var b = 0; // Not static.
+					append(++b); // Always `1`.
+				case _:
+			}
+		}
+		eq("1122", buf.toString());
+	}
+}