Pārlūkot izejas kodu

[wip] Tail recursion elimination (#8908)

* draft tre

* don't descend into anon functions

* local named functions

* tests

* don't eliminate tail recursion in user-defined loops

* fix abstract member methods

* fix it even more

* handle arguments captured in closures

* "internal" flag for @:tailRecursion in meta.json

* wip

* tre for void functions

* fix tre for void functions

* cleanup

* don't apply tre to overridden methods
Aleksandr Kuzmenko 6 gadi atpakaļ
vecāks
revīzija
27c349a525

+ 5 - 0
src-json/define.json

@@ -10,6 +10,11 @@
 		"doc": "Allow the SWF to be measured with Monocle tool.",
 		"platforms": ["flash"]
 	},
+	{
+		"name": "AnalyzerOptimize",
+		"define": "analyzer_optimize",
+		"doc": "Perform advanced optimizations."
+	},
 	{
 		"name": "AnnotateSource",
 		"define": "annotate_source",

+ 6 - 0
src-json/meta.json

@@ -1101,6 +1101,12 @@
 		"platforms": ["java"],
 		"targets": ["TClass"]
 	},
+	{
+		"name": "TailRecursion",
+		"metadata": ":tailRecursion",
+		"doc": "Internally used for tail recursion elimination.",
+		"internal": true
+	},
 	{
 		"name": "TemplatedCall",
 		"metadata": ":templatedCall",

+ 1 - 0
src/filters/filters.ml

@@ -820,6 +820,7 @@ let run com tctx main =
 		fix_return_dynamic_from_void_function tctx true;
 		check_local_vars_init;
 		check_abstract_as_value;
+		if defined com Define.AnalyzerOptimize then Tre.run tctx else (fun e -> e);
 		Optimizer.reduce_expression tctx;
 		if Common.defined com Define.OldConstructorInline then Optimizer.inline_constructors tctx else InlineConstructors.inline_constructors tctx;
 		CapturedVars.captured_vars com;

+ 10 - 0
src/filters/filtersCommon.ml

@@ -38,6 +38,16 @@ let rec is_removable_class c =
 	| _ ->
 		false
 
+(**
+	Check if `field` is overridden in subclasses
+*)
+let is_overridden cls field =
+	let rec loop_inheritance c =
+		(PMap.mem field.cf_name c.cl_fields)
+		|| List.exists (fun d -> loop_inheritance d) c.cl_descendants;
+	in
+	List.exists (fun d -> loop_inheritance d) cls.cl_descendants
+
 let run_expression_filters ctx filters t =
 	let run e =
 		List.fold_left (fun e f -> f e) e filters

+ 199 - 0
src/filters/tre.ml

@@ -0,0 +1,199 @@
+open Type
+open Typecore
+open Globals
+
+let rec collect_new_args_values ctx args declarations values n =
+	match args with
+	| [] -> declarations, values
+	| arg :: rest ->
+		let v = alloc_var VGenerated ("`tmp" ^ (string_of_int n)) arg.etype arg.epos in
+		let decl = { eexpr = TVar (v, Some arg); etype = ctx.t.tvoid; epos = v.v_pos }
+		and value = { arg with eexpr = TLocal v } in
+		collect_new_args_values ctx rest (decl :: declarations) (value :: values) (n + 1)
+
+let rec assign_args vars exprs =
+	match vars, exprs with
+	| [], [] -> []
+	| (v, _) :: rest_vars, e :: rest_exprs
+	| (v, Some e) :: rest_vars, rest_exprs ->
+		let arg = { e with eexpr = TLocal v } in
+		{ e with eexpr = TBinop (OpAssign, arg, e) } :: assign_args rest_vars rest_exprs
+	| _ -> assert false
+
+let replacement_for_TReturn ctx fn args p =
+	let temps_rev, args_rev = collect_new_args_values ctx args [] [] 0
+	and continue = mk TContinue ctx.t.tvoid Globals.null_pos in
+	{
+		etype = ctx.t.tvoid;
+		epos = p;
+		eexpr = TMeta ((Meta.TailRecursion, [], null_pos), {
+			eexpr = TBlock ((List.rev temps_rev) @ (assign_args fn.tf_args (List.rev args_rev)) @ [continue]);
+			etype = ctx.t.tvoid;
+			epos = p;
+		});
+	}
+
+let collect_captured_args args e =
+	let result = ref [] in
+	let rec loop in_closure e =
+		match e.eexpr with
+		| TLocal ({ v_kind = VUser TVOArgument } as v) when in_closure && not (List.memq v !result) && List.memq v args ->
+			result := v :: !result
+		| TFunction { tf_expr = e } ->
+			loop true e
+		| _ ->
+			iter (loop in_closure) e
+	in
+	loop false e;
+	!result
+
+let rec redeclare_vars ctx vars declarations replace_list =
+	match vars with
+	| [] -> declarations, replace_list
+	| v :: rest ->
+		let new_v = alloc_var VGenerated ("`" ^ v.v_name) v.v_type v.v_pos in
+		let decl =
+			{
+				eexpr = TVar (new_v, Some { eexpr = TLocal v; etype = v.v_type; epos = v.v_pos; });
+				etype = ctx.t.tvoid;
+				epos = v.v_pos;
+			}
+		in
+		redeclare_vars ctx rest (decl :: declarations) ((v, new_v) :: replace_list)
+
+let rec replace_vars replace_list in_tail_recursion e =
+	match e.eexpr with
+	| TBinop (OpAssign, ({ eexpr = TLocal { v_kind = VUser TVOArgument } } as arg), value) when in_tail_recursion ->
+		let value = replace_vars replace_list in_tail_recursion value in
+		{ e with eexpr = TBinop (OpAssign, arg, value) }
+	| TLocal v ->
+		(try
+			let v = List.assq v replace_list in
+			{ e with eexpr = TLocal v }
+		with Not_found ->
+			e
+		)
+	| TMeta ((Meta.TailRecursion, _, _), _) -> map_expr (replace_vars replace_list true) e
+	| _ -> map_expr (replace_vars replace_list in_tail_recursion) e
+
+let wrap_loop ctx args body =
+	let wrap e =
+		let cond = mk (TConst (TBool true)) ctx.t.tbool Globals.null_pos in
+		{ e with eexpr = TWhile (cond, e, Ast.NormalWhile) }
+	in
+	match collect_captured_args args body with
+	| [] -> wrap body
+	| captured_args ->
+		let declarations, replace_list = redeclare_vars ctx captured_args [] [] in
+		wrap { body with eexpr = TBlock (declarations @ [replace_vars replace_list false body]) }
+
+let fn_args_vars fn = List.map (fun (v,_) -> v) fn.tf_args
+
+let is_recursive_named_local_call fn_var callee args =
+	match callee.eexpr with
+	(* named local function*)
+	| TLocal v ->
+		v == fn_var
+	| _ -> false
+
+let is_recursive_method_call cls field callee args =
+	match callee.eexpr, args with
+	(* member abstract function*)
+	| TField (_, FStatic (_, cf)), { eexpr = TLocal v } :: _ when has_meta Meta.Impl cf.cf_meta ->
+		cf == field && has_meta Meta.This v.v_meta
+	(* static method *)
+	| TField (_, FStatic (_, cf)), _
+	(* instance method *)
+	| TField ({ eexpr = TConst TThis }, FInstance (_, _, cf)), _ ->
+		cf == field && not (FiltersCommon.is_overridden cls field)
+	| _ -> false
+
+let rec transform_function ctx is_recursive_call fn =
+	let add_loop = ref false in
+	let rec transform_expr in_loop function_end e =
+		match e.eexpr with
+		| TWhile _ | TFor _ ->
+			map_expr (transform_expr true false) e
+		(* named local function *)
+		| TBinop (OpAssign, ({ eexpr = TLocal ({ v_kind = VUser TVOLocalFunction } as v) } as e_var), ({ eexpr = TFunction fn } as e_fn)) ->
+			let fn = transform_function ctx (is_recursive_named_local_call v) fn in
+			{ e with eexpr = TBinop (OpAssign, e_var, { e_fn with eexpr = TFunction fn }) }
+		(* anonymous function *)
+		| TFunction _ ->
+			e
+		(* return a recursive call to current function *)
+		| TReturn (Some { eexpr = TCall (callee, args) }) when not in_loop && is_recursive_call callee args ->
+			add_loop := true;
+			replacement_for_TReturn ctx fn args e.epos
+		| TReturn (Some e_return) ->
+			{ e with eexpr = TReturn (Some (transform_expr in_loop function_end e_return)) }
+		| TBlock exprs ->
+			let rec loop exprs =
+				match exprs with
+				| [] -> []
+				| [{ eexpr = TCall (callee, args) } as e] when not in_loop && function_end && is_recursive_call callee args ->
+					add_loop := true;
+					[replacement_for_TReturn ctx fn args e.epos]
+				| { eexpr = TCall (callee, args) } :: [{ eexpr = TReturn None }] when not in_loop && is_recursive_call callee args ->
+					add_loop := true;
+					[replacement_for_TReturn ctx fn args e.epos]
+				| e :: rest ->
+					let function_end = function_end && rest = [] in
+					transform_expr in_loop function_end e :: loop rest
+			in
+			{ e with eexpr = TBlock (loop exprs) }
+		| _ ->
+			map_expr (transform_expr in_loop function_end) e
+	in
+	let body = transform_expr false true fn.tf_expr in
+	let body =
+		if !add_loop then
+			let body =
+				if ExtType.is_void (follow fn.tf_type) then
+					mk (TBlock [body; mk (TReturn None) ctx.t.tvoid null_pos]) ctx.t.tvoid null_pos
+				else
+					body
+			in
+			wrap_loop ctx (fn_args_vars fn) body
+		else
+			body
+	in
+	{ fn with tf_expr = body }
+
+let rec has_tail_recursion is_recursive_call in_loop function_end e =
+	match e.eexpr with
+	| TFor _ | TWhile _ ->
+		check_expr (has_tail_recursion is_recursive_call true false) e
+	(* named local function *)
+	| TBinop (OpAssign, { eexpr = TLocal ({ v_kind = VUser TVOLocalFunction } as v) }, { eexpr = TFunction fn }) ->
+		has_tail_recursion (is_recursive_named_local_call v) false true fn.tf_expr
+	(* anonymous function *)
+	| TFunction _ ->
+		false
+	| TReturn (Some { eexpr = TCall (callee, args)}) ->
+		not in_loop && is_recursive_call callee args
+	| TBlock exprs ->
+		let rec loop exprs =
+			match exprs with
+			| [] -> false
+			| [{ eexpr = TCall (callee, args) }] when not in_loop && function_end ->
+				is_recursive_call callee args
+			| { eexpr = TCall (callee, args) } :: [{ eexpr = TReturn None }] when not in_loop ->
+				is_recursive_call callee args
+			| e :: rest ->
+				let function_end = function_end && rest = [] in
+				has_tail_recursion is_recursive_call in_loop function_end e
+				|| loop rest
+		in
+		loop exprs
+	| _ ->
+		check_expr (has_tail_recursion is_recursive_call in_loop function_end) e
+
+let run ctx e =
+	let is_recursive_call = is_recursive_method_call ctx.curclass ctx.curfield in
+	match e.eexpr with
+	| TFunction fn when has_tail_recursion is_recursive_call false true fn.tf_expr ->
+		(* print_endline ("TRE: " ^ ctx.curfield.cf_pos.pfile ^ ": " ^ ctx.curfield.cf_name); *)
+		let fn = transform_function ctx is_recursive_call fn in
+		{ e with eexpr = TFunction fn }
+	| _ -> e

+ 1 - 1
src/optimization/analyzerConfig.ml

@@ -64,7 +64,7 @@ let is_ignored meta =
 
 let get_base_config com =
 	{
-		optimize = Common.raw_defined com "analyzer-optimize";
+		optimize = Common.defined com Define.AnalyzerOptimize;
 		const_propagation = not (Common.raw_defined com "analyzer-no-const-propagation");
 		copy_propagation = not (Common.raw_defined com "analyzer-no-copy-propagation");
 		local_dce = not (Common.raw_defined com "analyzer-no-local-dce");

+ 1 - 11
src/typing/nullSafety.ml

@@ -450,16 +450,6 @@ let should_be_initialized field =
 		| Var _ -> Meta.has Meta.IsVar field.cf_meta
 		| _ -> false
 
-(**
-	Check if `field` is overridden in subclasses
-*)
-let is_overridden cls field =
-	let rec loop_inheritance c =
-		(PMap.mem field.cf_name c.cl_fields)
-		|| List.exists (fun d -> loop_inheritance d) c.cl_descendants;
-	in
-	List.exists (fun d -> loop_inheritance d) cls.cl_descendants
-
 (**
 	Check if all items of the `needle` list exist in the same order in the beginning of the `haystack` list.
 *)
@@ -504,7 +494,7 @@ class immediate_execution =
 							(* known to be pure *)
 							| { cl_path = ([], "Array") }, _ -> true
 							(* try to analyze function code *)
-							| _, ({ cf_expr = (Some { eexpr = TFunction fn }) } as field) when (has_class_field_flag field CfFinal) || not (is_overridden cls field) ->
+							| _, ({ cf_expr = (Some { eexpr = TFunction fn }) } as field) when (has_class_field_flag field CfFinal) || not (FiltersCommon.is_overridden cls field) ->
 								if arg_num < 0 || arg_num >= List.length fn.tf_args then
 									false
 								else begin

+ 5 - 0
tests/optimization/run.hxml

@@ -11,10 +11,15 @@
 -D analyzer-check-null
 --interp
 
+--next
+--main TestTreBehavior
+--interp
+
 --next
 -js testopt.js
 --macro Macro.register('Test')
 --macro Macro.register('TestJs')
 --macro Macro.register('TestLocalDce')
+--macro Macro.register('TestTreGeneration')
 --macro Macro.register('issues')
 --dce std

+ 6 - 1
tests/optimization/src/TestBaseMacro.hx

@@ -19,7 +19,12 @@ class TestBaseMacro {
 				acc.push(macro $i{field.name}());
 			}
 		}
-		acc.push(macro trace("Done " +numTests+ " tests (" +numFailures+ " failures)"));
+		acc.push(macro {
+			trace("Done " +numTests+ " tests (" +numFailures+ " failures)");
+			if(numFailures > 0) {
+				Sys.exit(1);
+			}
+		});
 		Context.onGenerate(check);
 		return macro $b{acc};
 	}

+ 53 - 0
tests/optimization/src/TestTreBehavior.hx

@@ -0,0 +1,53 @@
+package ;
+
+class TestTreBehavior extends TestBase {
+
+	static function main() {
+		new TestTreBehavior();
+	}
+
+	public function new() {
+		super();
+		TestBaseMacro.run();
+	}
+
+	function testClosureCapturedArgs() {
+		var steps = [];
+
+		function loop(a:Int):Int {
+			steps.push(() -> a);
+			--a;
+			return a <= 0 ? 0 : loop(a - 1);
+		}
+		loop(5);
+
+		var actual = steps.map(fn -> fn());
+		switch actual {
+			case [4, 2, 0]:
+			case _: assertEquals(actual, [4, 2, 0]);
+		}
+	}
+
+	function testOverriddenMethod() {
+		var parent = new Parent();
+		var child = new Child();
+
+		assertEquals(2, parent.rec(2));
+		assertEquals(5, child.rec(2));
+	}
+}
+
+private class Parent {
+	public function new() {}
+
+	public function rec(n:Int, cnt:Int = 0):Int {
+		if(n <= 0) return cnt;
+		return rec(n - 1, cnt + 1);
+	}
+}
+
+private class Child extends Parent {
+	override public function rec(n:Int, cnt:Int = 0):Int {
+		return super.rec(n, cnt + 1);
+	}
+}

+ 129 - 0
tests/optimization/src/TestTreGeneration.hx

@@ -0,0 +1,129 @@
+class TestTreGeneration {
+	@:js('
+		if(b == null) {
+			b = 10;
+		}
+		while(true) {
+			if(Std.random(2) == 0) {
+				var _gtmp1 = a;
+				a = b + a;
+				b = _gtmp1;
+				s += "?";
+				continue;
+			}
+			if(s == null) {
+				return a;
+			} else {
+				return b;
+			}
+		}
+	')
+	static function testStaticMethod(a:Int, b:Int = 10, ?s:String):Int {
+		if(Std.random(2) == 0) {
+			return testStaticMethod(b + a, a, s + '?');
+		}
+		return s == null ? a : b;
+	}
+
+	@:js('
+		if(b == null) {
+			b = 10;
+		}
+		while(true) {
+			if(Std.random(2) == 0) {
+				var _gtmp1 = a;
+				a = b + a;
+				b = _gtmp1;
+				s += "?";
+				continue;
+			}
+			if(s == null) {
+				return a;
+			} else {
+				return b;
+			}
+		}
+	')
+	function testInstanceMethod(a:Int, b:Int = 10, ?s:String):Int {
+		if(Std.random(2) == 0) {
+			return testInstanceMethod(b + a, a, s + '?');
+		}
+		return s == null ? a : b;
+	}
+
+	@:js('
+		var local = null;
+		local = function(a,b,s) {
+			if(b == null) {
+				b = 10;
+			}
+			while(true) {
+				if(Std.random(2) == 0) {
+					var _gtmp1 = a;
+					a = b + a;
+					b = _gtmp1;
+					s += "?";
+					continue;
+				}
+				if(s == null) {
+					return a;
+				} else {
+					return b;
+				}
+			}
+		};
+		local(1,2);
+	')
+	static function testLocalNamedFunction() {
+		function local(a:Int, b:Int = 10, ?s:String):Int {
+			if(Std.random(2) == 0) {
+				return local(b + a, a, s + '?');
+			}
+			return s == null ? a : b;
+		}
+		local(1, 2);
+	}
+
+	@:js('
+		var _g = 0;
+		var _g1 = Std.random(10);
+		while(_g < _g1) {
+			++_g;
+			if(Std.random(2) == 0) {
+				return TestTreGeneration.testTailRecursionInsideLoop();
+			}
+		}
+		return Std.random(10);
+	')
+	static function testTailRecursionInsideLoop():Int {
+		for(i in 0...Std.random(10)) {
+			if(Std.random(2) == 0) {
+				return testTailRecursionInsideLoop();
+			}
+		}
+		return Std.random(10);
+	}
+
+	@:js('
+		while(true) {
+			if(Std.random(2) == 0) {
+				a -= 1;
+				continue;
+			}
+			if(a < 10) {
+				a += 1;
+				continue;
+			}
+			return;
+		}
+	')
+	static function testVoid(a:Int):Void {
+		if(Std.random(2) == 0) {
+			testVoid(a - 1);
+			return;
+		}
+		if(a < 10) {
+			testVoid(a + 1);
+		}
+	}
+}