فهرست منبع

Use thread-local-storage for eval (#12154)

* use thread_local_storage

* use ThreadSafeHashtbl for ctx.evals

* move exception stack to eval

* opam
Simon Krajewski 3 ماه پیش
والد
کامیت
1e1828f49a

+ 1 - 0
haxe.opam

@@ -35,4 +35,5 @@ depends: [
   "ipaddr"
   "terminal_size"
   "domainslib"
+  "thread-local-storage"
 ]

+ 23 - 22
src/context/display/displayMemory.ml

@@ -36,6 +36,7 @@ let get_memory_json (cs : CompilationCache.t) mreq =
 				"additionalSizes",jarray (
 					(match !MacroContext.macro_interp_cache with
 					| Some interp ->
+						let eval = Thread_local_storage.get_exn interp.eval in
 						jobject ["name",jstring "macro interpreter";"size",jint (mem_size MacroContext.macro_interp_cache);"child",jarray [
 							jobject ["name",jstring "builtins";"size",jint (mem_size_2 interp.builtins [Obj.repr interp])];
 							jobject ["name",jstring "debug";"size",jint (mem_size_2 interp.debug [Obj.repr interp])];
@@ -51,35 +52,35 @@ let get_memory_json (cs : CompilationCache.t) mreq =
 							jobject ["name",jstring "file_keys";"size",jint (mem_size_2 interp.file_keys [Obj.repr interp])];
 							jobject ["name",jstring "toplevel";"size",jint (mem_size_2 interp.toplevel [Obj.repr interp])];
 							jobject ["name",jstring "eval";"size",jint (mem_size_2 interp.eval [Obj.repr interp]);"child", jarray [
-								(match interp.eval.env with
+								(match eval.env with
 								| Some env ->
-									jobject ["name",jstring "env";"size",jint (mem_size_2 interp.eval.env [Obj.repr interp; Obj.repr interp.eval]);"child", jarray [
-										jobject ["name",jstring "env_info";"size",jint (mem_size_2 env.env_info [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_debug";"size",jint (mem_size_2 env.env_debug [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_locals";"size",jint (mem_size_2 env.env_locals [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_captures";"size",jint (mem_size_2 env.env_captures [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_extra_locals";"size",jint (mem_size_2 env.env_extra_locals [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_parent";"size",jint (mem_size_2 env.env_parent [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
-										jobject ["name",jstring "env_eval";"size",jint (mem_size_2 env.env_eval [Obj.repr interp; Obj.repr interp.eval; Obj.repr env])];
+									jobject ["name",jstring "env";"size",jint (mem_size_2 eval.env [Obj.repr interp; Obj.repr eval]);"child", jarray [
+										jobject ["name",jstring "env_info";"size",jint (mem_size_2 env.env_info [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_debug";"size",jint (mem_size_2 env.env_debug [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_locals";"size",jint (mem_size_2 env.env_locals [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_captures";"size",jint (mem_size_2 env.env_captures [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_extra_locals";"size",jint (mem_size_2 env.env_extra_locals [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_parent";"size",jint (mem_size_2 env.env_parent [Obj.repr interp; Obj.repr eval; Obj.repr env])];
+										jobject ["name",jstring "env_eval";"size",jint (mem_size_2 env.env_eval [Obj.repr interp; Obj.repr eval; Obj.repr env])];
 									]];
 								| None ->
-									jobject ["name",jstring "env";"size",jint (mem_size_2 interp.eval.env [Obj.repr interp; Obj.repr interp.eval])];
+									jobject ["name",jstring "env";"size",jint (mem_size_2 eval.env [Obj.repr interp; Obj.repr eval])];
 								);
-								jobject ["name",jstring "thread";"size",jint (mem_size_2 interp.eval.thread [Obj.repr interp; Obj.repr interp.eval]);"child", jarray [
-									jobject ["name",jstring "tthread";"size",jint (mem_size_2 interp.eval.thread.tthread [Obj.repr interp; Obj.repr interp.eval; Obj.repr interp.eval.thread])];
-									jobject ["name",jstring "tdeque";"size",jint (mem_size_2 interp.eval.thread.tdeque [Obj.repr interp; Obj.repr interp.eval; Obj.repr interp.eval.thread])];
-									jobject ["name",jstring "tevents";"size",jint (mem_size_2 interp.eval.thread.tevents [Obj.repr interp; Obj.repr interp.eval; Obj.repr interp.eval.thread])];
-									jobject ["name",jstring "tstorage";"size",jint (mem_size_2 interp.eval.thread.tstorage [Obj.repr interp; Obj.repr interp.eval; Obj.repr interp.eval.thread])];
+								jobject ["name",jstring "thread";"size",jint (mem_size_2 eval.thread [Obj.repr interp; Obj.repr eval]);"child", jarray [
+									jobject ["name",jstring "tthread";"size",jint (mem_size_2 eval.thread.tthread [Obj.repr interp; Obj.repr eval; Obj.repr eval.thread])];
+									jobject ["name",jstring "tdeque";"size",jint (mem_size_2 eval.thread.tdeque [Obj.repr interp; Obj.repr eval; Obj.repr eval.thread])];
+									jobject ["name",jstring "tevents";"size",jint (mem_size_2 eval.thread.tevents [Obj.repr interp; Obj.repr eval; Obj.repr eval.thread])];
+									jobject ["name",jstring "tstorage";"size",jint (mem_size_2 eval.thread.tstorage [Obj.repr interp; Obj.repr eval; Obj.repr eval.thread])];
 								]];
-								jobject ["name",jstring "debug_state";"size",jint (mem_size_2 interp.eval.debug_state [Obj.repr interp; Obj.repr interp.eval])];
-								jobject ["name",jstring "breakpoint";"size",jint (mem_size_2 interp.eval.breakpoint [Obj.repr interp; Obj.repr interp.eval])];
-								jobject ["name",jstring "caught_types";"size",jint (mem_size_2 interp.eval.caught_types [Obj.repr interp; Obj.repr interp.eval])];
-								jobject ["name",jstring "caught_exception";"size",jint (mem_size_2 interp.eval.caught_exception [Obj.repr interp; Obj.repr interp.eval])];
-								jobject ["name",jstring "last_return";"size",jint (mem_size_2 interp.eval.last_return [Obj.repr interp; Obj.repr interp.eval])];
-								jobject ["name",jstring "debug_channel";"size",jint (mem_size_2 interp.eval.debug_channel [Obj.repr interp; Obj.repr interp.eval])];
+								jobject ["name",jstring "debug_state";"size",jint (mem_size_2 eval.debug_state [Obj.repr interp; Obj.repr eval])];
+								jobject ["name",jstring "breakpoint";"size",jint (mem_size_2 eval.breakpoint [Obj.repr interp; Obj.repr eval])];
+								jobject ["name",jstring "caught_types";"size",jint (mem_size_2 eval.caught_types [Obj.repr interp; Obj.repr eval])];
+								jobject ["name",jstring "caught_exception";"size",jint (mem_size_2 eval.caught_exception [Obj.repr interp; Obj.repr eval])];
+								jobject ["name",jstring "last_return";"size",jint (mem_size_2 eval.last_return [Obj.repr interp; Obj.repr eval])];
+								jobject ["name",jstring "debug_channel";"size",jint (mem_size_2 eval.debug_channel [Obj.repr interp; Obj.repr eval])];
 							]];
 							jobject ["name",jstring "evals";"size",jint (mem_size_2 interp.evals [Obj.repr interp])];
-							jobject ["name",jstring "exception_stack";"size",jint (mem_size_2 interp.exception_stack [Obj.repr interp])];
+							jobject ["name",jstring "exception_stack";"size",jint (mem_size_2 eval.exception_stack [Obj.repr interp])];
 						]];
 					| None ->
 						jobject ["name",jstring "macro interpreter";"size",jint (mem_size MacroContext.macro_interp_cache)];

+ 10 - 1
src/core/ds/threadSafeHashtbl.ml

@@ -15,4 +15,13 @@ let replace h k v =
 	Mutex.protect h.mutex (fun () -> Hashtbl.replace h.h k) v
 
 let find h k =
-	Mutex.protect h.mutex (fun () -> Hashtbl.find h.h) k
+	Mutex.protect h.mutex (fun () -> Hashtbl.find h.h) k
+
+let mem h k =
+	Mutex.protect h.mutex (fun () -> Hashtbl.mem h.h) k
+
+let remove h k =
+	Mutex.protect h.mutex (fun () -> Hashtbl.remove h.h) k
+
+let fold f h acc =
+	Mutex.protect h.mutex (fun () -> Hashtbl.fold f h.h) acc

+ 1 - 1
src/dune

@@ -22,7 +22,7 @@
 		unix ipaddr str bigarray threads dynlink
 		xml-light extlib sha terminal_size
 		luv
-		domainslib
+		domainslib thread-local-storage
 	)
 	(modules (:standard \ haxe prebuild))
 	(preprocess (per_module

+ 4 - 11
src/macro/eval/evalContext.ml

@@ -109,6 +109,7 @@ type env = {
 and eval = {
 	mutable env : env option;
 	thread : vthread;
+	mutable exception_stack : (pos * env_kind) list;
 	(* The threads current debug state *)
 	mutable debug_state : debug_state;
 	(* The currently active breakpoint. Set to a dummy value initially. *)
@@ -285,9 +286,8 @@ and context = {
 	get_object_prototype : 'a . context -> (int * 'a) list -> vprototype * (int * 'a) list;
 	(* eval *)
 	toplevel : value;
-	eval : eval;
-	mutable evals : eval IntMap.t;
-	mutable exception_stack : (pos * env_kind) list;
+	eval : eval Thread_local_storage.t;
+	evals : (int,eval) ThreadSafeHashtbl.t;
 	max_stack_depth : int;
 	max_print_depth : int;
 	print_indentation : string option;
@@ -321,14 +321,7 @@ let s_debug_state = function
 (* Misc *)
 
 let get_eval ctx =
-	let id = Thread.id (Thread.self()) in
-	if id = 0 then
-		ctx.eval
-	else
-		try
-			IntMap.find id ctx.evals
-		with Not_found ->
-			die "Cannot run Haxe code in a non-Haxe thread" __LOC__
+	Thread_local_storage.get_exn ctx.eval
 
 let kind_name eval kind =
 	let rec loop kind env = match kind with

+ 4 - 4
src/macro/eval/evalDebugSocket.ml

@@ -199,7 +199,7 @@ let output_threads ctx =
 			"name",JString (Printf.sprintf "Thread %i" (Thread.id eval.thread.tthread));
 		]) :: acc
 	in
-	let threads = IntMap.fold fold ctx.evals [] in
+	let threads = ThreadSafeHashtbl.fold fold ctx.evals [] in
 	JArray threads
 
 let is_simn = false
@@ -566,7 +566,7 @@ let handler =
 	in
 	let select_thread hctx =
 		let id = hctx.jsonrpc#get_opt_param (fun () -> hctx.jsonrpc#get_int_param "threadId") 0 in
-		let eval = try IntMap.find id hctx.ctx.evals with Not_found -> hctx.send_error "Invalid thread id" in
+		let eval = try ThreadSafeHashtbl.find hctx.ctx.evals id with Not_found -> hctx.send_error "Invalid thread id" in
 		eval
 	in
 	let h = Hashtbl.create 0 in
@@ -751,7 +751,7 @@ let handler =
 		);
 		"evaluate",(fun hctx ->
 			let ctx = hctx.ctx in
-			let env = try select_frame hctx with _ -> expect_env hctx ctx.eval.env in
+			let env = try select_frame hctx with _ -> expect_env hctx (Thread_local_storage.get_exn ctx.eval).env in
 			let s = hctx.jsonrpc#get_string_param "expr" in
 			begin try
 				let e = parse_expr ctx s env.env_debug.debug_pos in
@@ -765,7 +765,7 @@ let handler =
 			end
 		);
 		"getCompletion",(fun hctx ->
-			let env = expect_env hctx hctx.ctx.eval.env in
+			let env = expect_env hctx (Thread_local_storage.get_exn hctx.ctx.eval).env in
 			let text = hctx.jsonrpc#get_string_param "text" in
 			let column = hctx.jsonrpc#get_int_param "column" in
 			try

+ 1 - 1
src/macro/eval/evalEmitter.ml

@@ -231,7 +231,7 @@ let emit_try exec catches env =
 	with RunTimeException(v,_,_) as exc ->
 		eval.caught_exception <- vnull;
 		restore();
-		build_exception_stack ctx env;
+		build_exception_stack eval env;
 		let rec loop () = match eval.env with
 			| Some env' when env' != env ->
 				pop_environment ctx env';

+ 1 - 1
src/macro/eval/evalExceptions.ml

@@ -120,7 +120,7 @@ let catch_exceptions ctx ?(final=(fun() -> ())) f p =
 	with
 	| RunTimeException(v,eval_stack,p') ->
 		eval.caught_exception <- vnull;
-		Option.may (build_exception_stack ctx) env;
+		Option.may (build_exception_stack eval) env;
 		eval.env <- env;
 
 		(* Careful: We have to get the message before resetting the context because toString() might access it. *)

+ 1 - 1
src/macro/eval/evalLuv.ml

@@ -566,7 +566,7 @@ let uv_error_fields = [
 						ExtLib.String.join "\n" (List.rev !messages)
 				| _ -> Printexc.to_string ex
 			in
-			let e = create_haxe_exception ~stack:(get_ctx()).exception_stack msg in
+			let e = create_haxe_exception ~stack:((get_eval (get_ctx()))).exception_stack msg in
 			ignore(cb [e])
 		);
 		vnull

+ 3 - 3
src/macro/eval/evalMain.ml

@@ -97,7 +97,7 @@ let create com api is_macro =
 		tdeque = EvalThread.Deque.create();
 	} in
 	let eval = EvalThread.create_eval thread in
-	let evals = IntMap.singleton 0 eval in
+	let evals = ThreadSafeHashtbl.create 1 in
 	let ctx = {
 		ctx_id = !GlobalState.sid;
 		is_macro = is_macro;
@@ -122,15 +122,15 @@ let create com api is_macro =
 			ofields = [||];
 			oproto = OProto (fake_proto key_eval_toplevel);
 		};
-		eval = eval;
+		eval = Thread_local_storage.create ();
 		evals = evals;
-		exception_stack = [];
 		timer_ctx = com.timer_ctx;
 		max_stack_depth = int_of_string (Common.defined_value_safe ~default:"1000" com Define.EvalCallStackDepth);
 		max_print_depth = int_of_string (Common.defined_value_safe ~default:"5" com Define.EvalPrintDepth);
 		print_indentation = match Common.defined_value_safe com Define.EvalPrettyPrint
 			with | "" -> None | "1" -> Some "  " | indent -> Some indent;
 	} in
+	Thread_local_storage.set ctx.eval eval;
 	if debug.support_debugger && not !GlobalState.debugger_initialized then begin
 		(* Let's wait till the debugger says we're good to continue. This allows it to finish configuration.
 		   Note that configuration is shared between macro and interpreter contexts, which is why the check

+ 1 - 1
src/macro/eval/evalStackTrace.ml

@@ -41,6 +41,6 @@ let getCallStack = vfun0 (fun () ->
 
 let getExceptionStack = vfun0 (fun () ->
 	let ctx = get_ctx() in
-	let envs = ctx.exception_stack in
+	let envs = (get_eval ctx).exception_stack in
 	make_stack (List.rev envs)
 )

+ 1 - 2
src/macro/eval/evalStdLib.ml

@@ -2782,8 +2782,7 @@ module StdTls = struct
 	let get_value = vifun0 (fun vthis ->
 		let this = this vthis in
 		try
-			let id = Thread.id (Thread.self()) in
-			let eval = IntMap.find id (get_ctx()).evals in
+			let eval = get_eval (get_ctx()) in
 			IntMap.find this eval.thread.tstorage
 		with Not_found ->
 			vnull

+ 5 - 3
src/macro/eval/evalThread.ml

@@ -69,6 +69,7 @@ end
 let create_eval thread = {
 	env = None;
 	thread = thread;
+	exception_stack = [];
 	debug_channel = Event.new_channel ();
 	debug_state = DbgRunning;
 	breakpoint = make_breakpoint 0 0 BPDisabled BPAny None;
@@ -86,9 +87,10 @@ let run ctx f thread =
 			()
 	in
 	let new_eval = create_eval thread in
-	ctx.evals <- IntMap.add id new_eval ctx.evals;
+	ThreadSafeHashtbl.add ctx.evals id new_eval;
+	Thread_local_storage.set ctx.eval new_eval;
 	let close () =
-		ctx.evals <- IntMap.remove id ctx.evals;
+		ThreadSafeHashtbl.remove ctx.evals id;
 		maybe_send_thread_event "exited";
 	in
 	try
@@ -124,7 +126,7 @@ let spawn ctx f =
 *)
 let run ctx f =
 	let id = Thread.id (Thread.self()) in
-	if IntMap.mem id ctx.evals then
+	if ThreadSafeHashtbl.mem ctx.evals id then
 		ignore(f())
 	else begin
 		let thread = {