Browse Source

Improve null safety control flow in binops (#12197)

* Improve null safety control flow in binops

* Update process_or
RblSb 4 months ago
parent
commit
9b0c1dc8d5
2 changed files with 260 additions and 7 deletions
  1. 52 6
      src/typing/nullSafety.ml
  2. 208 1
      tests/nullsafety/src/cases/TestStrict.hx

+ 52 - 6
src/typing/nullSafety.ml

@@ -367,9 +367,26 @@ let rec process_condition mode condition (is_nullable_expr:texpr->bool) callback
 		if to_nulls then nulls := expr :: !nulls
 		else not_nulls := expr :: !not_nulls
 	in
+	let remove expr =
+		let expr = reveal_expr expr in
+		let subj = get_subject mode expr in
+		nulls := List.filter (fun e ->
+			let e_subj = get_subject mode (reveal_expr e) in
+			e_subj <> subj
+		) !nulls;
+		not_nulls := List.filter (fun e ->
+			let e_subj = get_subject mode (reveal_expr e) in
+			e_subj <> subj
+		) !not_nulls;
+	in
 	let rec traverse positive e =
 		match e.eexpr with
 			| TUnop (Not, Prefix, e) -> traverse (not positive) e
+			| TBinop (OpAssign, checked_expr, e) when is_suitable mode checked_expr && (is_nullable_expr e) ->
+				(* remove expr from both list if there is `e = nullable` in condition *)
+				remove checked_expr
+			| TBlock exprs ->
+				List.iter (fun e -> traverse positive e) exprs
 			| TBinop (OpEq, { eexpr = TConst TNull }, checked_expr) when is_suitable mode checked_expr ->
 				add positive checked_expr
 			| TBinop (OpEq, checked_expr, { eexpr = TConst TNull }) when is_suitable mode checked_expr ->
@@ -834,12 +851,13 @@ class local_safety (mode:safety_mode) =
 				| TWhile (condition, body, NormalWhile) ->
 					condition_callback condition;
 					let (nulls, not_nulls) = process_condition mode condition is_nullable_expr (fun _ -> ()) in
+					let original_safe = self#get_safe_locals_copy in
 					(** execute `body` with known not-null variables *)
 					List.iter self#get_current_scope#add_to_safety not_nulls;
 					body_callback
 						(fun () -> List.iter self#get_current_scope#add_to_safety not_nulls)
 						body;
-					List.iter self#get_current_scope#remove_from_safety not_nulls;
+					self#get_current_scope#filter_safety original_safe;
 				| _ -> fail ~msg:"Expected TWhile" expr.epos __POS__
 		(**
 			Should be called for bodies of loops (for, while)
@@ -955,18 +973,46 @@ class local_safety (mode:safety_mode) =
 		*)
 		method process_and left_expr right_expr is_nullable_expr (callback:texpr->unit) =
 			callback left_expr;
-			let (_, not_nulls) = process_condition mode left_expr is_nullable_expr (fun e -> ()) in
-			List.iter self#get_current_scope#add_to_safety not_nulls;
+			let original_safe = self#get_safe_locals_copy in
+			(* save not_nulls for `a != null && a > 0` *)
+			let (_, not_nulls) = process_condition mode left_expr is_nullable_expr (fun _ -> ()) in
+			(* create temp scope for right_expr *)
+			let temp_scope = new safety_scope mode STNormal (Hashtbl.copy original_safe) (Hashtbl.create 10) in
+			List.iter temp_scope#add_to_safety not_nulls;
+			scopes <- temp_scope :: scopes;
 			callback right_expr;
-			List.iter self#get_current_scope#remove_from_safety not_nulls
+			self#scope_closed;
+
+			let safe_after_rhs = temp_scope#get_safe_locals in
+			let final_safe = Hashtbl.create (Hashtbl.length original_safe) in
+			Hashtbl.iter (fun subj e ->
+				if Hashtbl.mem original_safe subj && Hashtbl.mem safe_after_rhs subj then
+					Hashtbl.replace final_safe subj e
+			) original_safe;
+
+			self#get_current_scope#reset_to final_safe;
 		(**
 			Handle boolean OR outside of `if` condition.
 		*)
 		method process_or left_expr right_expr is_nullable_expr (callback:texpr->unit) =
+			let original_safe = self#get_safe_locals_copy in
+			(* save nulls for `a == null || a > 0` *)
 			let (nulls, _) = process_condition mode left_expr is_nullable_expr callback in
-			List.iter self#get_current_scope#add_to_safety nulls;
+			(* create temp scope for right_expr *)
+			let temp_scope = new safety_scope mode STNormal (Hashtbl.copy original_safe) (Hashtbl.create 10) in
+			List.iter temp_scope#add_to_safety nulls;
+			scopes <- temp_scope :: scopes;
 			callback right_expr;
-			List.iter self#get_current_scope#remove_from_safety nulls
+			self#scope_closed;
+
+			let safe_after_rhs = temp_scope#get_safe_locals in
+			let final_safe = Hashtbl.create (Hashtbl.length original_safe) in
+			Hashtbl.iter (fun subj e ->
+				if Hashtbl.mem original_safe subj && Hashtbl.mem safe_after_rhs subj then
+					Hashtbl.replace final_safe subj e
+			) original_safe;
+
+			self#get_current_scope#reset_to final_safe;
 		(**
 			Remove subject from the safety list if a nullable value is assigned or if an object with safe field is reassigned.
 		*)

+ 208 - 1
tests/nullsafety/src/cases/TestStrict.hx

@@ -615,6 +615,14 @@ class TestStrict {
 		shouldFail(a.value);
 	}
 
+	static function nullable_doWhile_shouldPass(?a:Int) {
+		do {
+			if (a == null) return;
+			a++;
+		} while (true);
+		a++;
+	}
+
 	static function throw_nullableValue_shouldFail() {
 		var s:Null<String> = null;
 		shouldFail(throw s);
@@ -827,7 +835,8 @@ class TestStrict {
 
 	static public function closure_storedSomewhere_shouldFail(?s:String) {
 		if(s != null) {
-			passesSomewhereElse(() -> shouldFail(s.length));
+			// unstable, see #12187
+			// passesSomewhereElse(() -> shouldFail(s.length));
 			storesSomewhere(() -> shouldFail(s.length));
 		}
 	}
@@ -1113,3 +1122,201 @@ abstract NullFloat(Null<Float>) from Null<Float> to Null<Float> {
 		return lhs != null ? lhs.val() + rhs : rhs;
 	}
 }
+
+class BinopFlow {
+	function ifAndTrue_shoudPass(?a:Int):Void {
+		if (a == null) return;
+		if (a == 2 && true) {}
+		a++;
+	}
+
+	function ifOrTrue_shoudPass(?a:Int):Void {
+		if (a == null) return;
+		if (a == null || true) {}
+		a++;
+	}
+
+	function ifWithBlock_after_return_shouldPass(?a:Int):Void {
+		var safe = 0;
+		if (a == null) return;
+		if (a == 2 && {safe = a;true;}) {
+			safe = a;
+			return;
+		}
+		a++;
+	}
+
+	function ifWithNullableAssign_shouldFail(?a:Int):Void {
+		var safe = 0;
+		if (a == 2 || {shouldFail(safe = a);true;}) {}
+		if (a != 2 && {shouldFail(safe = a);true;}) {}
+		if (a == 2 && {safe = a;true;}) {}
+		shouldFail(safe = a);
+	}
+
+	function ifNonNullableCondition_shouldPass(?a:Int, ?b:Int):Void {
+		if (a != null && (b != null && true)) {
+			var sum:Int = a + b;
+		}
+
+		if (a != null && (b != null && a + b > 0)) {
+			final sum:Int = a + b;
+		}
+
+		if (a == null || (b != null && a + b > 0)) {}
+	}
+
+	function ifOrCondition_shouldFail(?a:Int, ?b:Int):Void {
+		if (a != null || (b != null && shouldFail(a + b) > 0)) {
+			final sum:Int = shouldFail(a + b);
+		}
+	}
+
+	function ifBlockNonNullableAssign_shouldPass(?a:Int, ?b:Int):Void {
+		var safe = 1;
+		if (a == null) {}
+		else safe = a;
+		if (a != null && ({safe = a; true;})) {
+			var sum:Int = a;
+		}
+	}
+
+	function ifBlockNullableAssign_shouldFail(?a:Int):Void {
+		if (a == null) return;
+		if (a == 2 && { a = null; false; }) return;
+		shouldFail(var safe:Int = a);
+	}
+
+	function ifFirstBlockNullableAssign_shouldFail(?a:Int):Void {
+		if (a == null) return;
+		if (({ a = null; false; }) && false) return;
+		shouldFail(var safe:Int = a);
+	}
+
+	function ifBlockNullableCall_shouldFail(?a:Int):Void {
+		inline function setNull():Void {
+			a = null;
+		}
+		if (a == null && { setNull(); false; }) return;
+		shouldFail(var safe:Int = a);
+	}
+
+	function ifMutableFunction_with_return_shouldFail(?a:Int):Void {
+		function setNull():Void {
+			a = null;
+		}
+		if (a == null) return;
+		// fails in Strict
+		shouldFail(var safe:Int = a);
+	}
+
+	function ifBlockInlineNullableCall_with_return_shouldPass(?a:Int):Void {
+		inline function setNull():Void {
+			a = null;
+		}
+		if (a == null && { setNull(); false; }) return;
+		if (a == null) return;
+		var safe:Int = a;
+	}
+
+	function ifMultipleAssigns_shouldFail(?a:Int, ?b:Int, ?c:Int, ?d:Int):Void {
+		if (c == null || d == null) return;
+		if (a == 2 || ({ a = 1; b = 1; c = null; d = null; false; })) return;
+		// `a` cannot be null here, but this is hard quest
+		shouldFail(var safe:Int = a);
+		shouldFail(var safe:Int = b);
+		shouldFail(var safe:Int = c);
+		shouldFail(var safe:Int = d);
+	}
+
+	function ifMultipleAssigns2_shouldFail(?a:Int, ?b:Int, ?c:Int, ?d:Int):Void {
+		if (c == null || d == null) return;
+		if (a == 2 && ({ a = 1; b = 1; c = null; d = null; false; })) return;
+		shouldFail(var safe:Int = a);
+		shouldFail(var safe:Int = b);
+		shouldFail(var safe:Int = c);
+		shouldFail(var safe:Int = d);
+		if (a == 2 || ({ a = 1; b = 1; c = null; d = null; false; })) return;
+		shouldFail(var safe:Int = a);
+		shouldFail(var safe:Int = b);
+		shouldFail(var safe:Int = c);
+		shouldFail(var safe:Int = d);
+	}
+
+	function ifMultipleAssigns3_shouldFail(?a:Int, ?b:Int, ?c:Int, ?d:Int):Void {
+		if (c == null || d == null) return;
+		if (({ a = 1; b = 1; c = null; d = null; false; }) && a == 2) return;
+		var safe:Int = a;
+		var safe:Int = b;
+		shouldFail(var safe:Int = c);
+		shouldFail(var safe:Int = d);
+	}
+
+	function nullable_ifAndTrue_shouldPass(?a:Int) {
+		while (a != null) {
+			a++;
+		}
+		if (a == null) return;
+		if (a == 2 && true) return;
+		while (a == 3) {}
+		a++;
+		while ({a = null; shouldFail(a++); false;}) {}
+		shouldFail(a++);
+	}
+
+	function nullableVar_after_while_shouldFail(?a:Int):Void {
+		if (a == null) return;
+		while (a != null) {
+			a = null;
+			break;
+		}
+		shouldFail(a++);
+	}
+
+	function if_nullableString_shouldFail(?s:String):Void {
+		if (true && s != null && true) {
+			s.length;
+		} else {
+			shouldFail(s.length);
+			if (s == null) return;
+			s.length;
+		}
+	}
+
+	function if_else_nullableVars_shouldFail(?a:Int) {
+		var b:Null<Int> = 1;
+		if (a != null && true) {
+			a++;
+			b++;
+		} else {
+			shouldFail(a++);
+			b++;
+		}
+		shouldFail(a++);
+		b++;
+	}
+
+	function if_assignNullable_shouldFail(?a:Int, ?b:Int):Void {
+		if (a != null && true) {
+			a++;
+		}
+		if (a != null && true && {a = b; true;}) {
+			shouldFail(a++);
+		}
+		if (a != null && {a = null; true;}) {
+			shouldFail(a++);
+		}
+		if (a != null && {a = b; true;}) {
+			shouldFail(a++);
+		}
+		if (a != null && {shouldFail(a += b); true;}) {
+			a++;
+		}
+	}
+
+	function if_orAssignBlock_shouldFail(?a:Int, ?b:Int):Void {
+		var safe = 1;
+		if (a == null || {safe = a; true;}) {}
+		if (a != null || {shouldFail(safe = a); true;}) {}
+	}
+}