2
0
Эх сурвалжийг харах

[nullsafety] fix inferring safety from the conditions of while loops (fixes #7852)

Alexander Kuzmenko 6 жил өмнө
parent
commit
030601e262

+ 42 - 15
src/typing/nullSafety.ml

@@ -717,28 +717,45 @@ class local_safety (mode:safety_mode) =
 			This method should be called upon passing `while`.
 			This method should be called upon passing `while`.
 			It collects locals which are checked against `null` and executes callbacks for expressions with proper statuses of locals.
 			It collects locals which are checked against `null` and executes callbacks for expressions with proper statuses of locals.
 		*)
 		*)
-		method process_while expr is_nullable_expr (condition_callback:texpr->unit) (body_callback:texpr->unit) =
+		method process_while expr is_nullable_expr (condition_callback:texpr->unit) (body_callback:(unit->unit)->texpr->unit) =
 			match expr.eexpr with
 			match expr.eexpr with
 				| TWhile (condition, body, DoWhile) ->
 				| TWhile (condition, body, DoWhile) ->
+					let original_safe_locals = self#get_safe_locals_copy in
 					condition_callback condition;
 					condition_callback condition;
-					body_callback body
+					let (_, not_nulls) = process_condition (mode <> SMStrict) condition is_nullable_expr (fun _ -> ()) in
+					body_callback
+						(fun () ->
+							List.iter
+								(fun not_null ->
+									match get_subject (mode <> SMStrict) not_null with
+										| SNotSuitable -> ()
+										| subj ->
+											if Hashtbl.mem original_safe_locals subj then
+												self#get_current_scope#add_to_safety not_null
+								)
+								not_nulls
+						)
+						body
 				| TWhile (condition, body, NormalWhile) ->
 				| TWhile (condition, body, NormalWhile) ->
 					condition_callback condition;
 					condition_callback condition;
 					let (nulls, not_nulls) = process_condition (mode <> SMStrict) condition is_nullable_expr (fun _ -> ()) in
 					let (nulls, not_nulls) = process_condition (mode <> SMStrict) condition is_nullable_expr (fun _ -> ()) in
 					(** execute `body` with known not-null variables *)
 					(** execute `body` with known not-null variables *)
 					List.iter self#get_current_scope#add_to_safety not_nulls;
 					List.iter self#get_current_scope#add_to_safety not_nulls;
-					body_callback body;
+					body_callback
+						(fun () -> List.iter self#get_current_scope#add_to_safety not_nulls)
+						body;
 					List.iter self#get_current_scope#remove_from_safety not_nulls;
 					List.iter self#get_current_scope#remove_from_safety not_nulls;
 				| _ -> fail ~msg:"Expected TWhile" expr.epos __POS__
 				| _ -> fail ~msg:"Expected TWhile" expr.epos __POS__
 		(**
 		(**
 			Should be called for bodies of loops (for, while)
 			Should be called for bodies of loops (for, while)
 		*)
 		*)
-		method process_loop_body (first_check:unit->unit) (second_check:unit->unit) =
+		method process_loop_body (first_check:unit->unit) (intermediate_action:(unit->unit) option) (second_check:unit->unit) =
 			let original_safe_locals = self#get_safe_locals_copy in
 			let original_safe_locals = self#get_safe_locals_copy in
 			(** The first check to find out which vars will become unsafe in a loop *)
 			(** The first check to find out which vars will become unsafe in a loop *)
 			first_check();
 			first_check();
 			(* If local var became safe in a loop, then we need to remove it from safety to make it unsafe outside of a loop again *)
 			(* If local var became safe in a loop, then we need to remove it from safety to make it unsafe outside of a loop again *)
 			self#get_current_scope#filter_safety original_safe_locals;
 			self#get_current_scope#filter_safety original_safe_locals;
+			Option.may (fun action -> action()) intermediate_action;
 			(** The second check with unsafe vars removed from safety *)
 			(** The second check with unsafe vars removed from safety *)
 			second_check()
 			second_check()
 		(**
 		(**
@@ -1009,14 +1026,23 @@ class expr_checker mode immediate_execution report =
 		*)
 		*)
 		method private check_while e =
 		method private check_while e =
 			match e.eexpr with
 			match e.eexpr with
-				| TWhile _ ->
+				| TWhile (_, _, while_flag) ->
 					let check_condition condition =
 					let check_condition condition =
 						if self#is_nullable_expr condition then
 						if self#is_nullable_expr condition then
 							self#error "Cannot use nullable value as a condition in \"while\"." condition.epos;
 							self#error "Cannot use nullable value as a condition in \"while\"." condition.epos;
 						self#check_expr condition
 						self#check_expr condition
 					in
 					in
 					local_safety#loop_declared e;
 					local_safety#loop_declared e;
-					local_safety#process_while e self#is_nullable_expr check_condition self#check_loop_body;
+					local_safety#process_while
+						e
+						self#is_nullable_expr
+						check_condition
+						(* self#check_loop_body; *)
+						(fun handle_condition_effect body ->
+							self#check_loop_body
+								(Some handle_condition_effect)
+								body
+						);
 					local_safety#scope_closed
 					local_safety#scope_closed
 				| _ -> fail ~msg:"Expected TWhile." e.epos __POS__
 				| _ -> fail ~msg:"Expected TWhile." e.epos __POS__
 		(**
 		(**
@@ -1030,19 +1056,20 @@ class expr_checker mode immediate_execution report =
 					self#check_expr iterable;
 					self#check_expr iterable;
 					local_safety#declare_var v;
 					local_safety#declare_var v;
 					local_safety#loop_declared e;
 					local_safety#loop_declared e;
-					self#check_loop_body body;
+					self#check_loop_body None body;
 					local_safety#scope_closed
 					local_safety#scope_closed
 				| _ -> fail ~msg:"Expected TFor." e.epos __POS__
 				| _ -> fail ~msg:"Expected TFor." e.epos __POS__
 		(**
 		(**
 			Handle safety inside of loops
 			Handle safety inside of loops
 		*)
 		*)
-		method private check_loop_body body =
+		method private check_loop_body (handle_condition_effect:(unit->unit) option) body =
 			local_safety#process_loop_body
 			local_safety#process_loop_body
 				(* Start pretending to ignore errors *)
 				(* Start pretending to ignore errors *)
 				(fun () ->
 				(fun () ->
 					is_pretending <- true;
 					is_pretending <- true;
 					self#check_expr body
 					self#check_expr body
 				)
 				)
+				handle_condition_effect
 				(* Now we know, which vars will become unsafe in this loop. Stop pretending and check again *)
 				(* Now we know, which vars will become unsafe in this loop. Stop pretending and check again *)
 				(fun () ->
 				(fun () ->
 					is_pretending <- false;
 					is_pretending <- false;
@@ -1261,13 +1288,13 @@ class class_checker cls immediate_execution report  =
 			if is_safe_class && (not cls.cl_extern) && (not cls.cl_interface) then
 			if is_safe_class && (not cls.cl_extern) && (not cls.cl_interface) then
 				self#check_var_fields;
 				self#check_var_fields;
 			let check_field is_static f =
 			let check_field is_static f =
-				if self#is_in_safety f then begin
-					(* if f.cf_name = "return_assignNonNullable_shouldPass" then
-						Option.may (fun e -> print_endline (s_expr str_type e)) f.cf_expr; *)
-					let mode = safety_mode (cls.cl_meta @ f.cf_meta) in
-					Option.may ((self#get_checker mode)#check_root_expr) f.cf_expr;
-					self#check_accessors is_static f
-				end
+				(* if f.cf_name = "return_assignNonNullable_shouldPass" then
+					Option.may (fun e -> print_endline (s_expr str_type e)) f.cf_expr; *)
+				match (safety_mode (cls.cl_meta @ f.cf_meta)) with
+					| SMOff -> ()
+					| mode ->
+						Option.may ((self#get_checker mode)#check_root_expr) f.cf_expr;
+						self#check_accessors is_static f
 			in
 			in
 			if is_safe_class then
 			if is_safe_class then
 				Option.may ((self#get_checker (safety_mode cls.cl_meta))#check_root_expr) cls.cl_init;
 				Option.may ((self#get_checker (safety_mode cls.cl_meta))#check_root_expr) cls.cl_init;

+ 21 - 0
tests/nullsafety/src/cases/TestStrict.hx

@@ -543,6 +543,27 @@ class TestStrict {
 		while(a == null) shouldFail(b = a);
 		while(a == null) shouldFail(b = a);
 	}
 	}
 
 
+	static function while_checkAgainstNullInConditionAndReassignInBody(?a:{value:String, ?parent:Dynamic}) {
+		while (a != null) {
+			var s:String = a.value;
+			a = a.parent;
+			shouldFail(a.value);
+		}
+
+		do {
+			a = shouldFail(a.parent);
+		} while (a != null);
+
+		a = {value:'hello', parent:null};
+		do {
+			var s:String = a.value;
+			a = a.parent;
+			shouldFail(a.value);
+		} while (a != null);
+
+		shouldFail(a.value);
+	}
+
 	static function throw_nullableValue_shouldFail() {
 	static function throw_nullableValue_shouldFail() {
 		var s:Null<String> = null;
 		var s:Null<String> = null;
 		shouldFail(throw s);
 		shouldFail(throw s);