Ver Fonte

[nullsafety] fix false-positive for final fields checked against null (fixes #7729)

Alexander Kuzmenko há 6 anos atrás
pai
commit
e0f0370139
2 ficheiros alterados com 206 adições e 61 exclusões
  1. 178 61
      src/typing/nullSafety.ml
  2. 28 0
      tests/nullsafety/src/cases/Test.hx

+ 178 - 61
src/typing/nullSafety.ml

@@ -83,6 +83,65 @@ let symbol_name expr =
 		| TNew _ -> "new"
 		| _ -> ""
 
+type safety_subject =
+	(*
+		Fields accessed through a static access are identified by the class path and the field name.
+		E.g.
+			`pack.MyClass.field` is `((["pack"], "MyClass"), ["field"])`
+			`pack.MyClass.field.sub` is `((["pack"], "MyClass"), ["field"; "sub"])`
+	*)
+	| SFieldOfClass of (path * (string list))
+	(*
+		Fields accessed through a local variable are identified by the var id and the field name.
+		E.g.
+			`v.field` is `(v.v_id, ["field"])`
+			`v.field.sub` is `(v.v_id, ["field"; "sub"])`
+	*)
+	| SFieldOfLocalVar of (int * (string list))
+	(*
+		Fields accessed through `this` are identified by their names.
+		E.g.
+			`this.field` is `["field"]`
+			`this.field.sub` is `["field"; "sub"]`
+	*)
+	| SFieldOfThis of (string list)
+	(*
+		Local variables - by tvar.v_id
+	*)
+	| SLocalVar of int
+	(*
+		For expressions, which cannot be checked agains `null` to become safe
+	*)
+	| SNotSuitable
+
+let get_subject expr =
+	match expr.eexpr with
+		| TLocal v ->
+			SLocalVar v.v_id
+		| TField ({ eexpr = TTypeExpr _ }, FStatic (cls, field)) when field.cf_final ->
+			SFieldOfClass (cls.cl_path, [field.cf_name])
+		| TField ({ eexpr = TConst TThis }, (FInstance (_, _, field) | FAnon field)) when field.cf_final ->
+			SFieldOfThis [field.cf_name]
+		| TField ({ eexpr = TLocal v }, (FInstance (_, _, field) | FAnon field)) when field.cf_final ->
+			SFieldOfLocalVar (v.v_id, [field.cf_name])
+		(* | TField (e, (FInstance (_, _, field) | FAnon field)) ->
+			(match get_subject e with
+				| SFieldOfClass (path, fields) -> SFieldOfClass (path, field.cf_name :: fields)
+				| SFieldOfThis fields -> SFieldOfThis (field.cf_name :: fields)
+				| SFieldOfLocalVar (var_id, fields) -> SFieldOfLocalVar (var_id, field.cf_name :: fields)
+				|_ -> SNotSuitable
+			) *)
+		|_ -> SNotSuitable
+
+let rec is_suitable expr =
+	match expr.eexpr with
+		(* | TField (target, (FInstance _ | FStatic _ | FAnon _)) -> is_suitable target *)
+		| TField ({ eexpr = TConst TThis }, FInstance _)
+		| TField ({ eexpr = TLocal _ }, (FInstance _ | FAnon _))
+		| TField ({ eexpr = TTypeExpr _ }, FStatic _)
+		| TLocal _ -> true
+		|_ -> false
+
 class unificator =
 	object(self)
 		val stack = new_rec_stack()
@@ -254,23 +313,25 @@ let rec can_pass_type src dst =
 let process_condition condition (is_nullable_expr:texpr->bool) callback =
 	let nulls = ref []
 	and not_nulls = ref [] in
-	let add to_nulls v =
-		if to_nulls then nulls := v :: !nulls
-		else not_nulls := v :: !not_nulls
+	let add to_nulls expr =
+		if to_nulls then nulls := expr :: !nulls
+		else not_nulls := expr :: !not_nulls
 	in
 	let rec traverse positive e =
 		match e.eexpr with
 			| TUnop (Not, Prefix, e) -> traverse (not positive) e
-			| TBinop (OpEq, { eexpr = TConst TNull }, { eexpr = TLocal v })
-			| TBinop (OpEq, { eexpr = TLocal v }, { eexpr = TConst TNull }) ->
-				add positive v
-			| TBinop (OpNotEq, { eexpr = TConst TNull }, { eexpr = TLocal v })
-			| TBinop (OpNotEq, { eexpr = TLocal v }, { eexpr = TConst TNull }) ->
-				add (not positive) v
-			| TBinop (OpEq, e, { eexpr = TLocal v }) when not (is_nullable_expr e) ->
-				if positive then not_nulls := v :: !not_nulls
-			| TBinop (OpEq, { eexpr = TLocal v }, e) when not (is_nullable_expr e) ->
-				if positive then not_nulls := v :: !not_nulls
+			| TBinop (OpEq, { eexpr = TConst TNull }, checked_expr) when is_suitable checked_expr ->
+				add positive checked_expr
+			| TBinop (OpEq, checked_expr, { eexpr = TConst TNull }) when is_suitable checked_expr ->
+				add positive checked_expr
+			| TBinop (OpNotEq, { eexpr = TConst TNull }, checked_expr) when is_suitable checked_expr ->
+				add (not positive) checked_expr
+			| TBinop (OpNotEq, checked_expr, { eexpr = TConst TNull }) when is_suitable checked_expr ->
+				add (not positive) checked_expr
+			| TBinop (OpEq, e, checked_expr) when is_suitable checked_expr && not (is_nullable_expr e) ->
+				if positive then not_nulls := checked_expr :: !not_nulls
+			| TBinop (OpEq, checked_expr, e) when is_suitable checked_expr && not (is_nullable_expr e) ->
+				if positive then not_nulls := checked_expr :: !not_nulls
 			| TBinop (OpBoolAnd, left_expr, right_expr) when positive ->
 				traverse positive left_expr;
 				traverse positive right_expr
@@ -324,6 +385,16 @@ let is_overridden cls field =
 	in
 	List.exists (fun d -> loop_inheritance d) cls.cl_descendants
 
+(**
+	Check if all items of the `needle` list exist in the same order in the beginning of the `haystack` list.
+*)
+let rec list_starts_with_list (haystack:string list) (needle:string list) =
+	match haystack, needle with
+		| _, [] -> true
+		| [], _ -> false
+		| current_haystack :: rest_haystack, current_needle :: rest_needle ->
+			current_haystack = current_needle && list_starts_with_list rest_haystack rest_needle
+
 (**
 	A class which is used to check if an anonymous function passed to a method will be executed
 	before that method execution is finished.
@@ -416,9 +487,9 @@ class immediate_execution =
 	end
 
 (**
-	Each loop or function should have its own scope.
+	Each loop or function should have its own safety scope.
 *)
-class safety_scope (scope_type:scope_type) (safe_locals:(int,tvar) Hashtbl.t) (never_safe:(int,tvar) Hashtbl.t) =
+class safety_scope (scope_type:scope_type) (safe_locals:(safety_subject,texpr) Hashtbl.t) (never_safe:(safety_subject,texpr) Hashtbl.t) =
 	object (self)
 		(** Local vars declared in current scope *)
 		val declarations = Hashtbl.create 100
@@ -428,7 +499,7 @@ class safety_scope (scope_type:scope_type) (safe_locals:(int,tvar) Hashtbl.t) (n
 		(**
 			Reset local vars safety to the specified state
 		*)
-		method reset_to (state:(int,tvar) Hashtbl.t) =
+		method reset_to (state:(safety_subject,texpr) Hashtbl.t) =
 			Hashtbl.clear safe_locals;
 			Hashtbl.iter (Hashtbl.add safe_locals) state
 		(**
@@ -444,40 +515,80 @@ class safety_scope (scope_type:scope_type) (safe_locals:(int,tvar) Hashtbl.t) (n
 		(**
 			Check if local variable declared in this scope is guaranteed to not have a `null` value.
 		*)
-		method is_safe local_var =
-			not (Hashtbl.mem never_safe local_var.v_id)
+		method is_safe (expr:texpr) =
+			not (is_nullable_type expr.etype)
+			|| match get_subject expr with
+				| SNotSuitable ->
+					false
+				| subj ->
+					not (Hashtbl.mem never_safe subj)
+					&& Hashtbl.mem safe_locals subj
+			(* not (Hashtbl.mem never_safe local_var.v_id)
 			&& (
 				Hashtbl.mem safe_locals local_var.v_id
 				|| not (is_nullable_type local_var.v_type)
-			)
+			) *)
 		(**
 			Add variable to the list of safe locals.
 		*)
-		method add_to_safety v =
-			Hashtbl.replace safe_locals v.v_id v
+		method add_to_safety expr =
+			match get_subject expr with
+				| SNotSuitable -> ()
+				| subj -> Hashtbl.replace safe_locals subj expr
 		(**
 			Remove variable from the list of safe locals.
 		*)
-		method remove_from_safety ?(forever=false) v =
-			Hashtbl.remove safe_locals v.v_id;
-			if forever then
-				Hashtbl.replace never_safe v.v_id v
+		method remove_from_safety ?(forever=false) expr =
+			match get_subject expr with
+				| SNotSuitable -> ()
+				| subj ->
+					Hashtbl.remove safe_locals subj;
+					if forever then
+						Hashtbl.replace never_safe subj expr
 		(**
 			Remove locals, which don't exist in `sample`, from safety.
 		*)
-		method filter_safety (sample:(int,tvar) Hashtbl.t) =
+		method filter_safety (sample:(safety_subject,texpr) Hashtbl.t) =
 			Hashtbl.iter
-				(fun var_id v ->
-					if not (Hashtbl.mem sample var_id) then
-						self#remove_from_safety v
+				(fun subj expr ->
+					if not (Hashtbl.mem sample subj) then
+						Hashtbl.remove safe_locals subj
 				)
 				(Hashtbl.copy safe_locals);
+		(**
+			Should be called upon assigning a value to `expr`.
+			Removes subjects like `expr.subField` from safety.
+		*)
+		method reassigned (expr:texpr) =
+			match get_subject expr with
+				| SNotSuitable -> ()
+				| subj ->
+					let remove safe_subj safe_fields fields =
+						if list_starts_with_list (List.rev safe_fields) (List.rev fields) then
+							Hashtbl.remove safe_locals safe_subj
+					in
+					Hashtbl.iter
+						(fun safe_subj safe_expr ->
+							match safe_subj, subj with
+								| SFieldOfLocalVar (safe_id, _), SLocalVar v_id when safe_id = v_id ->
+									Hashtbl.remove safe_locals safe_subj
+								| SFieldOfLocalVar (safe_id, safe_fields), SFieldOfLocalVar (v_id, fields) when safe_id = v_id ->
+									remove safe_subj safe_fields fields
+								| SFieldOfClass (safe_path, safe_fields), SFieldOfClass (path, fields) when safe_path = path ->
+									remove safe_subj safe_fields fields
+								| SFieldOfClass (safe_path, safe_fields), SFieldOfClass (path, fields) when safe_path = path ->
+									remove safe_subj safe_fields fields
+								| SFieldOfThis safe_fields, SFieldOfThis fields ->
+									remove safe_subj safe_fields fields
+								| _ -> ()
+						)
+						(Hashtbl.copy safe_locals)
 	end
 
 (**
-	Class to simplify collecting lists of local vars checked against `null`.
+	Class to simplify collecting lists of local vars, fields and other symbols checked against `null`.
 *)
-class local_vars =
+class local_safety =
 	object (self)
 		val mutable scopes = [new safety_scope STNormal (Hashtbl.create 100) (Hashtbl.create 100)]
 		(**
@@ -534,27 +645,32 @@ class local_vars =
 		method declare_var ?(is_safe=false) (v:tvar) =
 			let scope = self#get_current_scope in
 			scope#declare_var v;
-			if is_safe then scope#add_to_safety v
+			if is_safe then scope#add_to_safety { eexpr = TVar (v, None); etype = v.v_type; epos = v.v_pos }
 		(**
 			Check if local variable is guaranteed to not have a `null` value.
 		*)
-		method is_safe local_var =
-			if not (is_nullable_type local_var.v_type) then
+		method is_safe expr =
+			if not (is_nullable_type expr.etype) then
 				true
 			else
-				let rec traverse scopes =
-					match scopes with
-						| [] -> false
-						| current :: rest ->
-							if current#owns_var local_var then
-								false
-							else if current#get_type = STClosure then
-								true
-							else
-								traverse rest
+				let captured =
+					match expr.eexpr with
+						| TLocal local_var ->
+							let rec traverse scopes =
+								match scopes with
+									| [] -> false
+									| current :: rest ->
+										if current#owns_var local_var then
+											false
+										else if current#get_type = STClosure then
+											true
+										else
+											traverse rest
+							in
+							traverse scopes
+						| _ -> false
 				in
-				let captured = traverse scopes in
-				not captured && self#get_current_scope#is_safe local_var
+				not captured && self#get_current_scope#is_safe expr
 		(**
 			This method should be called upon passing `while`.
 			It collects locals which are checked against `null` and executes callbacks for expressions with proper statuses of locals.
@@ -663,31 +779,31 @@ class local_vars =
 			callback right_expr;
 			List.iter self#get_current_scope#remove_from_safety nulls
 		(**
-			Remove local var from safety list if a nullable value is assigned to that var
+			Remove subject from the safety list if a nullable value is assigned or if an object with safe field is reassigned.
 		*)
 		method handle_assignment is_nullable_expr left_expr (right_expr:texpr) =
-			match (reveal_expr left_expr).eexpr with
-				| TLocal v ->
-					if is_nullable_expr right_expr then
-						begin
+			if is_suitable left_expr then
+				self#get_current_scope#reassigned left_expr;
+				if is_nullable_expr right_expr then
+					match left_expr.eexpr with
+						| TLocal v ->
 							let captured = ref false in
 							let rec traverse (lst:safety_scope list) =
 								match lst with
 									| [] -> ()
 									| current :: rest ->
 										if current#owns_var v then
-											current#remove_from_safety ~forever:!captured v
+											current#remove_from_safety ~forever:!captured left_expr
 										else begin
 											captured := !captured || current#get_type = STClosure;
-											current#remove_from_safety ~forever:!captured v;
+											current#remove_from_safety ~forever:!captured left_expr;
 											traverse rest
 										end
 							in
 							traverse scopes
-						end
-					else if is_nullable_type v.v_type then
-						self#get_current_scope#add_to_safety v
-				| _ -> ()
+						| _ -> ()
+				else if is_nullable_type left_expr.etype then
+					self#get_current_scope#add_to_safety left_expr
 	end
 
 (**
@@ -695,7 +811,7 @@ class local_vars =
 *)
 class expr_checker immediate_execution report =
 	object (self)
-		val local_safety = new local_vars
+		val local_safety = new local_safety
 		val mutable return_types = []
 		val mutable in_closure = false
 		(* if this flag is `true` then spotted errors and warnings will not be reported *)
@@ -717,7 +833,6 @@ class expr_checker immediate_execution report =
 				| TConst _ -> false
 				| TParenthesis e -> self#is_nullable_expr e
 				| TMeta (_, e) -> self#is_nullable_expr e
-				| TLocal v -> not (local_safety#is_safe v)
 				| TThrow _ -> false
 				| TBlock exprs ->
 					(match exprs with
@@ -729,7 +844,8 @@ class expr_checker immediate_execution report =
 					let check body = nullable := !nullable || self#is_nullable_expr body in
 					local_safety#process_if e self#is_nullable_expr (fun _ -> ()) check;
 					!nullable
-				| _ -> is_nullable_type e.etype
+				| _ ->
+					is_nullable_type e.etype && not (local_safety#is_safe e)
 		(**
 			Check if `expr` can be passed to a place where `to_type` is expected.
 			This method has side effect: it logs an error if `expr` has a type parameter incompatible with the type parameter of `to_type`.
@@ -738,7 +854,7 @@ class expr_checker immediate_execution report =
 		method can_pass_expr expr to_type p =
 			if self#is_nullable_expr expr && not (is_nullable_type to_type) then
 				false
-			else
+			else begin
 				let expr_type = unfold_null expr.etype in
 				try
 					new unificator#unify expr_type to_type;
@@ -751,6 +867,7 @@ class expr_checker immediate_execution report =
 					| e ->
 						fail ~msg:"Null safety unification failure" expr.epos __POS__
 				(* can_pass_type expr.etype to_type *)
+			end
 		(**
 			Should be called for the root expressions of a method or for then initialization expressions of fields.
 		*)

+ 28 - 0
tests/nullsafety/src/cases/Test.hx

@@ -770,6 +770,34 @@ class Test {
 	static function recursiveTypedef_shouldNotCrashTheCompiler(a:Recursive<Void>, b:Recursive<Void>) {
 		a = b;
 	}
+
+	static function anonFinalNullableField_checkedForNull() {
+		var o:{ final ?f:String; } = {};
+		if (o.f != null) {
+			var s:String = o.f;
+			o = {};
+			shouldFail(var s:String = o.f);
+		}
+	}
+
+	static function staticFinalNullableField_checkedForNull() {
+		if (FinalNullableFields.staticVar != null) {
+			var s:String = FinalNullableFields.staticVar;
+		}
+		shouldFail(var s:String = FinalNullableFields.staticVar);
+	}
+}
+
+private class FinalNullableFields {
+	static public final staticVar:Null<String> = "hello";
+	public final instanceVar:Null<String> = "world";
+
+	function instanceFinalNullableField_checkedForNull() {
+		if (instanceVar != null) {
+			var s:String = instanceVar;
+		}
+		shouldFail(var s:String = instanceVar);
+	}
 }
 
 typedef Recursive<T1> = {