瀏覽代碼

[nullsafety] fix checking recursive functions for immediate execution of callbacks passed to their arguments

Alexander Kuzmenko 6 年之前
父節點
當前提交
40d617ac13
共有 2 個文件被更改,包括 111 次插入67 次删除
  1. 98 67
      src/typing/nullSafety.ml
  2. 13 0
      tests/nullsafety/src/cases/Test.hx

+ 98 - 67
src/typing/nullSafety.ml

@@ -228,67 +228,6 @@ let rec can_pass_type src dst =
 			| TAbstract ({ a_path = ([],"Null") }, [t]) -> true
 			| TAbstract _ -> true
 
-(**
-	Check if a lambda passed to `arg_num`th argument of the `callee` function will be executed immediately without
-	delaying it or storing it somewhere else.
-*)
-let rec immediate_execution callee arg_num =
-	match (reveal_expr callee).eexpr with
-		| TField (_, FClosure (Some (cls, _), ({ cf_kind = Method (MethNormal | MethInline) } as field)))
-		| TField (_, FStatic (cls, ({ cf_kind = Method (MethNormal | MethInline) } as field)))
-		| TField (_, FInstance (cls, _, ({ cf_kind = Method (MethNormal | MethInline) } as field))) ->
-			if PurityState.is_pure cls field then
-				true
-			else
-				(match cls, field with
-					(* known to be pure *)
-					| { cl_path = ([], "Array") }, _ -> true
-					(* try to analyze function code *)
-					| _, { cf_expr = (Some { eexpr = TFunction fn }) } ->
-						if arg_num < 0 || arg_num >= List.length fn.tf_args then
-							false
-						else
-							let (arg_var, _) = List.nth fn.tf_args arg_num in
-							not (is_stored arg_var fn.tf_expr)
-					| _ ->
-						false
-				)
-		| _ -> false
-
-and is_stored fn_var expr =
-	match expr.eexpr with
-		| TThrow { eexpr = TLocal v }
-		| TReturn (Some { eexpr = TLocal v })
-		| TCast ({ eexpr = TLocal v }, _)
-		| TMeta (_, { eexpr = TLocal v })
-		| TBinop (OpAssign, _, { eexpr = TLocal v }) when v.v_id = fn_var.v_id ->
-			true
-		| TFunction fn ->
-			let rec captured e =
-				match e.eexpr with
-					| TLocal v -> v.v_id = fn_var.v_id
-					| _ -> check_expr captured e
-			in
-			captured fn.tf_expr
-		| TCall (callee, args) ->
-			if is_stored fn_var callee then
-				true
-			else begin
-				let arg_num = ref 0 in
-				List.exists
-					(fun arg ->
-						let result =
-							match arg.eexpr with
-								| TLocal v when v.v_id = fn_var.v_id -> not (immediate_execution callee !arg_num)
-								| _ -> is_stored fn_var arg
-						in
-						incr arg_num;
-						result
-					)
-					args
-			end
-		| _ -> check_expr (is_stored fn_var) expr
-
 (**
 	Collect nullable local vars which are checked against `null`.
 	Returns a tuple of (vars_checked_to_be_null * vars_checked_to_be_not_null) in case `condition` evaluates to `true`.
@@ -380,6 +319,97 @@ let should_be_initialized field =
 		| Var _ -> Meta.has Meta.IsVar field.cf_meta
 		| _ -> false
 
+(**
+	A class which is used to check if an anonymous function passed to a method will be executed
+	before that method execution is finished.
+*)
+class immediate_execution =
+	object(self)
+		val cache = Hashtbl.create 500
+		(**
+			Get cached results of the previous checks for the specified `field`
+		*)
+		method private get_cache field =
+			try
+				Hashtbl.find cache field
+			with
+				| Not_found ->
+					let field_cache = Hashtbl.create 5 in
+					Hashtbl.add cache field field_cache;
+					field_cache
+		(**
+			Check if a lambda passed to `arg_num`th argument of the `callee` function will be executed immediately without
+			delaying it or storing it somewhere else.
+		*)
+		method check callee arg_num =
+			match (reveal_expr callee).eexpr with
+				| TField (_, FClosure (Some (cls, _), ({ cf_kind = Method (MethNormal | MethInline) } as field)))
+				| TField (_, FStatic (cls, ({ cf_kind = Method (MethNormal | MethInline) } as field)))
+				| TField (_, FInstance (cls, _, ({ cf_kind = Method (MethNormal | MethInline) } as field))) ->
+					if PurityState.is_pure cls field then
+						true
+					else
+						(match cls, field with
+							(* known to be pure *)
+							| { cl_path = ([], "Array") }, _ -> true
+							(* try to analyze function code *)
+							| _, ({ cf_expr = (Some { eexpr = TFunction fn }) } as field) ->
+								if arg_num < 0 || arg_num >= List.length fn.tf_args then
+									false
+								else begin
+									let cache = self#get_cache field in
+									if Hashtbl.mem cache arg_num then
+										Hashtbl.find cache arg_num
+									else begin
+										Hashtbl.add cache arg_num true;
+										let (arg_var, _) = List.nth fn.tf_args arg_num in
+										let result = not (self#is_stored arg_var fn.tf_expr) in
+										Hashtbl.replace cache arg_num result;
+										result
+									end
+								end
+							| _ ->
+								false
+						)
+				| _ -> false
+		(**
+			Check if `fn_var` is passed somewhere else in `expr` (stored to a var/field, captured by a closure etc.)
+		*)
+		method private is_stored fn_var expr =
+			match expr.eexpr with
+				| TThrow { eexpr = TLocal v }
+				| TReturn (Some { eexpr = TLocal v })
+				| TCast ({ eexpr = TLocal v }, _)
+				| TMeta (_, { eexpr = TLocal v })
+				| TBinop (OpAssign, _, { eexpr = TLocal v }) when v.v_id = fn_var.v_id ->
+					true
+				| TFunction fn ->
+					let rec captured e =
+						match e.eexpr with
+							| TLocal v -> v.v_id = fn_var.v_id
+							| _ -> check_expr captured e
+					in
+					captured fn.tf_expr
+				| TCall (callee, args) ->
+					if self#is_stored fn_var callee then
+						true
+					else begin
+						let arg_num = ref 0 in
+						List.exists
+							(fun arg ->
+								let result =
+									match arg.eexpr with
+										| TLocal v when v.v_id = fn_var.v_id -> not (self#check callee !arg_num)
+										| _ -> self#is_stored fn_var arg
+								in
+								incr arg_num;
+								result
+							)
+							args
+					end
+				| _ -> check_expr (self#is_stored fn_var) expr
+	end
+
 (**
 	Each loop or function should have its own scope.
 *)
@@ -656,9 +686,9 @@ class local_vars =
 	end
 
 (**
-	This is a base class is used to recursively check typed expressions for null-safety
+	This class is used to recursively check typed expressions for null-safety
 *)
-class expr_checker report =
+class expr_checker immediate_execution report =
 	object (self)
 		val local_safety = new local_vars
 		val mutable return_types = []
@@ -1044,7 +1074,7 @@ class expr_checker report =
 					end;
 					(match arg.eexpr with
 						| TFunction fn ->
-							self#check_function ~immediate_execution:(immediate_execution callee arg_num) fn
+							self#check_function ~immediate_execution:(immediate_execution#check callee arg_num) fn
 						| _ ->
 							self#check_expr arg
 					);
@@ -1052,9 +1082,9 @@ class expr_checker report =
 				| _ -> ()
 	end
 
-class class_checker cls report  =
+class class_checker cls immediate_execution report  =
 	object (self)
-			val checker = new expr_checker report
+			val checker = new expr_checker immediate_execution report
 		(**
 			Entry point for checking a class
 		*)
@@ -1175,13 +1205,14 @@ class class_checker cls report  =
 let run (com:Common.context) (types:module_type list) =
 	let timer = Timer.timer ["null safety"] in
 	let report = { sr_errors = [] } in
+	let immediate_execution = new immediate_execution in
 	let rec traverse module_type =
 		match module_type with
 			| TEnumDecl enm -> ()
 			| TTypeDecl typedef -> ()
 			| TAbstractDecl abstr -> ()
 			| TClassDecl cls when (contains_safe_meta cls.cl_meta) && not (contains_unsafe_meta cls.cl_meta) ->
-				(new class_checker cls report)#check
+				(new class_checker cls immediate_execution report)#check
 			| TClassDecl _ -> ()
 	in
 	List.iter traverse types;

+ 13 - 0
tests/nullsafety/src/cases/Test.hx

@@ -740,4 +740,17 @@ class Test {
 	static function storesSomewhere(cb:()->Int) {
 		tmp2 = cb;
 	}
+
+	static function closure_passedToRecursiveFunction_shouldNotCrashTheCompiler(?a:String) {
+		if(a != null) {
+			recursive(() -> a.length);
+		}
+	}
+	static function recursive(cb:Void->Int) {
+		if(Std.random(10) == 0) {
+			recursive(cb);
+		} else {
+			cb();
+		}
+	}
 }