Browse Source

[nullsafety] Allow statics init in main (#12211)

* [nullsafety] Allow statics init in main

* usage before init check
RblSb 3 months ago
parent
commit
bbd0ee0578
2 changed files with 113 additions and 31 deletions
  1. 93 30
      src/typing/nullSafety.ml
  2. 20 1
      tests/nullsafety/src/cases/TestStrict.hx

+ 93 - 30
src/typing/nullSafety.ml

@@ -1534,7 +1534,7 @@ class expr_checker mode immediate_execution report =
 			traverse 0 args types meta
 	end
 
-class class_checker cls immediate_execution report =
+class class_checker cls immediate_execution report (main_expr : texpr option) =
 	let cls_meta = cls.cl_meta @ (match cls.cl_kind with KAbstractImpl a -> a.a_meta | _ -> []) in
 	object (self)
 			val is_safe_class = (safety_enabled cls_meta)
@@ -1616,34 +1616,84 @@ class class_checker cls immediate_execution report =
 		*)
 		method private is_in_safety field =
 			(is_safe_class && not (contains_unsafe_meta field.cf_meta)) || safety_enabled field.cf_meta
+		(**
+			Extract `tf_expr` from `com.main.main_expr` if this expr in current class
+		*)
+		method private get_main_tf_expr (main_expr : texpr option) =
+			match main_expr with
+				| Some main_expr ->
+					begin match main_expr.eexpr with
+						| TCall ({ eexpr = TField (_, FStatic (cl, field))}, _) when cl == cls ->
+							begin match field.cf_expr with
+								| Some ({ eexpr = TFunction { tf_expr = e } }) ->
+									Some e
+								| _ -> None
+							end
+						| _ -> None
+					end
+				| None -> None
 		(**
 			Check `var` fields are initialized properly
 		*)
 		method check_var_fields =
 			let check_field is_static field =
 				validate_safety_meta report field.cf_meta;
-				if should_be_initialized field then
-					if not (is_nullable_type field.cf_type) && self#is_in_safety field then
-						match field.cf_expr with
-							| None ->
-								if is_static then
-									checker#error
-										("Field \"" ^ field.cf_name ^ "\" is not nullable thus should have an initial value.")
-										[field.cf_pos]
-							| Some e ->
-								if not (checker#can_pass_expr e field.cf_type e.epos) then
-									checker#error ("Cannot set nullable initial value for not-nullable field \"" ^ field.cf_name ^ "\".") [field.cf_pos]
+				if
+					should_be_initialized field
+					&& not (is_nullable_type field.cf_type)
+					&& self#is_in_safety field
+				then
+					match field.cf_expr with
+						| Some e ->
+							if not (checker#can_pass_expr e field.cf_type e.epos) then
+								checker#error
+								("Cannot set nullable initial value for not-nullable field \"" ^ field.cf_name ^ "\".") [field.cf_pos]
+						| None -> ()
 			in
 			List.iter (check_field false) cls.cl_ordered_fields;
 			List.iter (check_field true) cls.cl_ordered_statics;
+
+			self#check_statics_initialization ();
 			self#check_fields_initialization_in_constructor ()
+
+		method private check_statics_initialization () =
+			let fields_to_initialize = Hashtbl.create 20 in
+			List.iter
+				(fun f ->
+					if
+						should_be_initialized f
+						&& not (is_nullable_type f.cf_type)
+						&& not (contains_unsafe_meta f.cf_meta)
+					then
+						match f.cf_expr with
+							| Some _ -> ()
+							| None -> Hashtbl.add fields_to_initialize f.cf_name f
+				)
+				cls.cl_ordered_statics;
+
+			begin match TClass.get_cl_init cls with
+				| Some init_expr ->
+					ignore (self#check_fields_initialization fields_to_initialize init_expr true);
+				| None -> ()
+			end;
+			let main_tf_expr = self#get_main_tf_expr main_expr in
+			(match main_tf_expr with
+				| Some tf_expr ->
+					ignore (self#check_fields_initialization fields_to_initialize tf_expr true);
+				| _ -> ()
+			);
+			Hashtbl.iter
+				(fun name field ->
+					checker#error
+						("Field \"" ^ name ^ "\" is not nullable thus should have an initial value.")
+						[field.cf_pos]
+				)
+				fields_to_initialize
 		(**
 			Check instance fields without initial values are properly initialized in constructor
 		*)
 		method private check_fields_initialization_in_constructor () =
-			let fields_to_initialize = Hashtbl.create 20
-			(* Compiler-autogenerated local vars for transfering `this` to local functions *)
-			and this_vars = Hashtbl.create 5 in
+			let fields_to_initialize = Hashtbl.create 20 in
 			List.iter
 				(fun f ->
 					if
@@ -1656,10 +1706,30 @@ class class_checker cls immediate_execution report =
 							| None -> Hashtbl.add fields_to_initialize f.cf_name f
 				)
 				cls.cl_ordered_fields;
+
+			(match cls.cl_constructor with
+				| Some { cf_expr = Some { eexpr = TFunction { tf_expr = e } } } ->
+					ignore (self#check_fields_initialization fields_to_initialize e false);
+				| _ -> ()
+			);
+			Hashtbl.iter
+				(fun name field ->
+					checker#error
+						("Field \"" ^ name ^ "\" is not nullable thus should have an initial value or should be initialized in constructor.")
+						[field.cf_pos]
+				)
+				fields_to_initialize
+
+		method private check_fields_initialization fields_to_initialize tf_expr is_static =
+			(* Compiler-autogenerated local vars for transfering `this` to local functions *)
+			let this_vars = Hashtbl.create 5 in
 			let rec check_unsafe_usage init_list safety_enabled e =
 				if Hashtbl.length init_list > 0 then
 					match e.eexpr with
-						| TField ({ eexpr = TConst TThis }, FInstance (_, _, field)) ->
+						| TField ({ eexpr = TConst TThis }, FInstance (_, _, field)) when not is_static ->
+							if Hashtbl.mem init_list field.cf_name then
+								checker#error ("Cannot use field " ^ field.cf_name ^ " until initialization.") [e.epos]
+						| TField (_, FStatic (_, field)) when is_static ->
 							if Hashtbl.mem init_list field.cf_name then
 								checker#error ("Cannot use field " ^ field.cf_name ^ " until initialization.") [e.epos]
 						| TField ({ eexpr = TConst TThis }, FClosure (_, field)) ->
@@ -1680,7 +1750,11 @@ class class_checker cls immediate_execution report =
 			in
 			let rec traverse init_list e =
 				(match e.eexpr with
-					| TBinop (OpAssign, { eexpr = TField ({ eexpr = TConst TThis }, FInstance (_, _, f)) }, right_expr) ->
+					| TBinop (OpAssign, { eexpr = TField ({ eexpr = TConst TThis }, FInstance (_, _, f)) }, right_expr)
+						when not is_static ->
+						Hashtbl.remove init_list f.cf_name;
+						ignore (traverse init_list right_expr)
+					| TBinop (OpAssign, { eexpr = TField(_, FStatic(_, f)) }, right_expr) when is_static ->
 						Hashtbl.remove init_list f.cf_name;
 						ignore (traverse init_list right_expr)
 					| TWhile (condition, body, DoWhile) ->
@@ -1702,18 +1776,7 @@ class class_checker cls immediate_execution report =
 				);
 				init_list
 			in
-			(match cls.cl_constructor with
-				| Some { cf_expr = Some { eexpr = TFunction { tf_expr = e } } } ->
-					ignore (traverse fields_to_initialize e);
-				| _ -> ()
-			);
-			Hashtbl.iter
-				(fun name field ->
-					checker#error
-						("Field \"" ^ name ^ "\" is not nullable thus should have an initial value or should be initialized in constructor.")
-						[field.cf_pos]
-				)
-				fields_to_initialize
+			traverse fields_to_initialize tf_expr
 	end
 
 (**
@@ -1728,7 +1791,7 @@ let run (com:Common.context) (types:module_type list) =
 				| TEnumDecl enm -> ()
 				| TTypeDecl typedef -> ()
 				| TAbstractDecl abstr -> ()
-				| TClassDecl cls -> (new class_checker cls immediate_execution report)#check
+				| TClassDecl cls -> (new class_checker cls immediate_execution report com.main.main_expr)#check
 		in
 		List.iter traverse types;
 		report;

+ 20 - 1
tests/nullsafety/src/cases/TestStrict.hx

@@ -149,15 +149,34 @@ class TestStrict {
 		shouldFail(return v);
 	}
 
+	@:shouldFail static var badInit:Int;
+	static var init:Int;
+	@:shouldFail static var init2:Int = null;
+
 	/**
 	 *  Null safety should work in __init__ functions
 	 */
 	static function __init__() {
 		var s:Null<String> = null;
 		shouldFail(s.length);
+
+		final v:Int = shouldFail(init);
+
+		if (true) init = 1;
+		else init = 1;
+		init2 = 1;
+
+		final v:Int = init;
+		final v:Int = shouldFail(badInit);
+
+		function name():Void {
+			shouldFail(badInit) = 1;
+		}
+		if (true) shouldFail(badInit) = 1;
 	}
 
-	static public function main() {
+	static public function main() { // not a real main
+		badInit = 1;
 	}
 
 	/**