Browse Source

put Task.run around the entire callback instead of individual parallelism calls (#12121)

Simon Krajewski 5 months ago
parent
commit
f88be3b1fa
2 changed files with 4 additions and 5 deletions
  1. 3 3
      src/context/parallel.ml
  2. 1 2
      src/generators/genjvm.ml

+ 3 - 3
src/context/parallel.ml

@@ -6,7 +6,7 @@ let run_parallel_for num_domains ?(chunk_size=0) length f =
 module ParallelArray = struct
 	let iter pool f a =
 		let f' idx = f a.(idx) in
-		Domainslib.Task.run pool (fun _ -> Domainslib.Task.parallel_for pool ~start:0 ~finish:(Array.length a - 1) ~body:f')
+		Domainslib.Task.parallel_for pool ~start:0 ~finish:(Array.length a - 1) ~body:f'
 
 	let map pool f a x =
 		let length = Array.length a in
@@ -14,7 +14,7 @@ module ParallelArray = struct
 		let f' idx =
 			Array.unsafe_set a_out idx (f (Array.unsafe_get a idx))
 		in
-		Domainslib.Task.run pool (fun _ -> Domainslib.Task.parallel_for pool ~start:0 ~finish:(length - 1) ~body:f');
+		Domainslib.Task.parallel_for pool ~start:0 ~finish:(length - 1) ~body:f';
 		a_out
 end
 
@@ -25,4 +25,4 @@ end
 
 let run_in_new_pool timer_ctx f =
 	let pool = Timer.time timer_ctx ["domainslib";"setup"] (Domainslib.Task.setup_pool ~num_domains:(Domain.recommended_domain_count() - 1)) () in
-	Std.finally (fun () -> Timer.time timer_ctx ["domainslib";"teardown"] Domainslib.Task.teardown_pool pool) f pool
+	Std.finally (fun () -> Timer.time timer_ctx ["domainslib";"teardown"] Domainslib.Task.teardown_pool pool) (Domainslib.Task.run pool) (fun () -> f pool)

+ 1 - 2
src/generators/genjvm.ml

@@ -3229,8 +3229,7 @@ let generate jvm_flag gctx =
 		run_timed gctx false "anons" (fun () -> generate_anons gctx pool);
 		run_timed gctx false "typed_functions" (fun () -> generate_typed_functions gctx);
 	in
-	let pool = Domainslib.Task.setup_pool ~num_domains:(Domain.recommended_domain_count()) () in
-	Std.finally (fun () -> Domainslib.Task.teardown_pool pool) generate pool;
+	Parallel.run_in_new_pool gctx.gctx.timer_ctx generate;
 
 	let manifest_content =
 		"Manifest-Version: 1.0\n" ^