Browse Source

[cs] generate method overloads for each optional arg (closes #4623)

Aleksandr Kuzmenko 6 years ago
parent
commit
36f0ecfbf2

+ 57 - 1
src/generators/gencs.ml

@@ -148,6 +148,56 @@ let rec change_md = function
 		TAbstractDecl a
 	| md -> md
 
+(**
+	Generates method overloads for a method with trailing optional arguments.
+	E.g. for `function method(a:Int, b:Bool = false) {...}`
+	generates `function method(a:Int) { method(a, false); }`
+*)
+let get_overloads_for_optional_args gen cl cf is_static =
+	match cf.cf_params,cf.cf_kind with
+	| [],Method (MethNormal | MethDynamic | MethInline) ->
+		(match cf.cf_expr, follow cf.cf_type with
+		| Some ({ eexpr = TFunction fn } as method_expr), TFun (args, return_type) ->
+			let rec collect_overloads tf_args_rev args_rev default_values_rev =
+				match tf_args_rev, args_rev with
+				| (_, Some default_value) :: rest_tf_args_rev, _ :: rest_args_rev ->
+					let field_expr =
+						if is_static then
+							make_static_field cl cf cf.cf_pos
+						else begin
+							let this_expr =
+								mk (TConst TThis) (type_of_module_type (TClassDecl cl)) cf.cf_pos
+							in
+							mk_field_access gen this_expr cf.cf_name cf.cf_pos
+						end
+					in
+					let default_values_rev = default_values_rev @ [default_value] in
+					let args_exprs =
+						List.rev (
+							default_values_rev
+							@ (List.map (fun (v,_) -> mk_local v v.v_pos ) rest_tf_args_rev)
+						)
+					in
+					let call = { fn.tf_expr with eexpr = TCall (field_expr, args_exprs) } in
+					let fn_body =
+						if ExtType.is_void (follow return_type) then call
+						else { fn.tf_expr with eexpr = TReturn (Some call) }
+					in
+					let fn =
+						{ fn with tf_args = List.rev rest_tf_args_rev; tf_expr = mk_block fn_body }
+					in
+					{ cf with
+						cf_overloads = [];
+						cf_type = TFun (List.rev rest_args_rev, return_type);
+						cf_expr = Some { method_expr with eexpr = TFunction fn };
+					} :: collect_overloads rest_tf_args_rev rest_args_rev default_values_rev
+				| _ -> []
+			in
+			collect_overloads (List.rev fn.tf_args) (List.rev args) []
+		| _ -> []
+		)
+	| _ -> []
+
 (* used in c#-specific filters to skip some of them for the special haxe.lang.Runtime class *)
 let in_runtime_class gen =
 	match gen.gcurrent_class with
@@ -2153,10 +2203,16 @@ let generate con =
 								gen_class_field w ~is_overload:true is_static cl (has_class_field_flag cf CfFinal) cf
 						) cf.cf_overloads;
 				| Method mkind ->
+					let overloads =
+						match cf.cf_overloads with
+						| [] when is_overload -> []
+						| [] -> get_overloads_for_optional_args gen cl cf is_static
+						| overloads -> overloads
+					in
 					List.iter (fun cf ->
 						if cl.cl_interface || cf.cf_expr <> None then
 							gen_class_field w ~is_overload:true is_static cl (has_class_field_flag cf CfFinal) cf
-					) cf.cf_overloads;
+					) overloads;
 					let is_virtual = not is_final && match mkind with | MethInline -> false | _ when not is_new -> true | _ -> false in
 					let is_virtual = if not is_virtual || (has_class_field_flag cf CfFinal) then false else is_virtual in
 					let is_override = List.memq cf cl.cl_overrides in

+ 37 - 0
tests/misc/cs/projects/Issue4623/Main.hx

@@ -0,0 +1,37 @@
+class Main {
+	static var voidResult:String;
+
+	static function testVoid(a:Int, b:String = 'hello', ?c:String):Void {
+		voidResult = '$a $b $c';
+	}
+
+	static function test(a:Int, b:String = 'hello', ?c:String):String {
+		return '$a $b $c';
+	}
+
+	public static function main() {
+		untyped __cs__('global::Main.testVoid(1, "foo")');
+		var expected = '1 foo null';
+		if(voidResult != expected) {
+			throw 'Invalid result of testVoid(1, "foo"). Expected: $expected. Got: $voidResult';
+		}
+
+		untyped __cs__('global::Main.testVoid(2)');
+		var expected = '2 hello null';
+		if(voidResult != expected) {
+			throw 'Invalid result of testVoid(2). Expected: $expected. Got: $voidResult';
+		}
+
+		var expected = '3 bar null';
+		var result = untyped __cs__('global::Main.test(3, "bar")');
+		if(expected != result) {
+			throw 'Invalid result of test(3, "bar"). Expected: $expected. Got: $result';
+		}
+
+		var expected = '4 hello null';
+		var result = untyped __cs__('global::Main.test(4)');
+		if(expected != result) {
+			throw 'Invalid result of test(4). Expected: $expected. Got: $result';
+		}
+	}
+}

+ 12 - 0
tests/misc/cs/projects/Issue4623/Run.hx

@@ -0,0 +1,12 @@
+class Run {
+	static public function run() {
+		var exe = 'bin/bin/Main.exe';
+		var exitCode = switch (Sys.systemName()) {
+			case 'Windows':
+				Sys.command(exe);
+			case _:
+				Sys.command('mono', [exe]);
+		}
+		Sys.exit(exitCode);
+	}
+}

+ 4 - 0
tests/misc/cs/projects/Issue4623/compile.hxml

@@ -0,0 +1,4 @@
+-cs bin
+-main Main
+--next
+--macro Run.run()