Explorar o código

Find references of the base field or include descendant fields (#9315)

Aleksandr Kuzmenko %!s(int64=5) %!d(string=hai) anos
pai
achega
ce56f300b6

+ 2 - 2
src/compiler/displayOutput.ml

@@ -243,7 +243,7 @@ let handle_display_argument com file_pos pre_compilation did_something =
 				DMDefinition
 			| "usage" ->
 				Common.define com Define.NoCOpt;
-				DMUsage false
+				DMUsage (false,false,false)
 			(*| "rename" ->
 				Common.define com Define.NoCOpt;
 				DMUsage true*)
@@ -413,7 +413,7 @@ let promote_type_hints tctx =
 let process_global_display_mode com tctx =
 	promote_type_hints tctx;
 	match com.display.dms_kind with
-	| DMUsage with_definition ->
+	| DMUsage (with_definition,_,_) ->
 		FindReferences.find_references tctx com with_definition
 	| DMImplementation ->
 		FindReferences.find_implementations tctx com

+ 3 - 1
src/context/display/displayEmitter.ml

@@ -101,8 +101,10 @@ let display_field ctx origin scope cf p = match ctx.com.display.dms_kind with
 			| "new",(Self (TClassDecl c) | Parent(TClassDecl c)) ->
 				(* For constructors, we care about the class name so we don't end up looking for "new". *)
 				snd c.cl_path,SKConstructor cf
+			| _,(Self (TClassDecl c) | Parent(TClassDecl c)) ->
+				cf.cf_name,SKField (cf,Some c.cl_path)
 			| _ ->
-				cf.cf_name,SKField cf
+				cf.cf_name,SKField (cf,None)
 		in
 		ReferencePosition.set (name,cf.cf_name_pos,kind)
 	| DMHover ->

+ 7 - 1
src/context/display/displayJson.ml

@@ -139,7 +139,13 @@ let handler =
 		"display/references", (fun hctx ->
 			Common.define hctx.com Define.NoCOpt;
 			hctx.display#set_display_file false true;
-			hctx.display#enable_display (DMUsage false);
+			match hctx.jsonrpc#get_opt_param (fun () -> hctx.jsonrpc#get_string_param "kind") "normal" with
+			| "withBaseAndDescendants" ->
+				hctx.display#enable_display (DMUsage (false,true,true));
+			| "withDescendants" ->
+				hctx.display#enable_display (DMUsage (false,true,false));
+			| _ ->
+				hctx.display#enable_display (DMUsage (false,false,false));
 		);
 		"display/hover", (fun hctx ->
 			Common.define hctx.com Define.NoCOpt;

+ 65 - 8
src/context/display/findReferences.ml

@@ -2,6 +2,7 @@ open Globals
 open Ast
 open DisplayTypes
 open Common
+open Type
 open Typecore
 open CompilationServer
 open ImportHandling
@@ -25,18 +26,74 @@ let find_references tctx com with_definition name pos kind =
 		(try loop acc (Hashtbl.find relations p)
 		with Not_found -> acc)
 	) symbols [] in
-	let usages = List.sort (fun p1 p2 ->
-		let c = compare p1.pfile p2.pfile in
-		if c <> 0 then c else compare p1.pmin p2.pmin
-	) usages in
 	t();
 	Display.ReferencePosition.set ("",null_pos,SKOther);
-	DisplayException.raise_positions usages
+	usages
 
-let find_references tctx com with_definition =
+let collect_reference_positions com =
 	let name,pos,kind = Display.ReferencePosition.get () in
-	if pos <> null_pos then find_references tctx com with_definition name pos kind
-	else DisplayException.raise_positions []
+	match kind, com.display.dms_kind with
+	| SKField (cf,Some cl_path), DMUsage (_,find_descendants,find_base) when find_descendants || find_base ->
+		let collect() =
+			let c =
+				let rec loop = function
+					| [] -> raise Exit
+					| TClassDecl c :: _ when c.cl_path = cl_path -> c
+					| _ :: types -> loop types
+				in
+				loop com.types
+			in
+			let cf,c =
+				if find_base then
+					let rec loop c =
+						match c.cl_super with
+						| None -> (PMap.find cf.cf_name c.cl_fields),c
+						| Some (csup,_) ->
+							try loop csup
+							with Not_found -> (PMap.find cf.cf_name c.cl_fields),c
+					in
+					try loop c
+					with Not_found -> cf,c
+				else
+					cf,c
+			in
+			let full_pos p = { p with pfile = Path.unique_full_path p.pfile } in
+			if find_descendants then
+				List.fold_left (fun acc t ->
+					match t with
+					| TClassDecl child_cls when extends child_cls c ->
+						(try
+							let cf = PMap.find cf.cf_name child_cls.cl_fields in
+							(name,full_pos cf.cf_name_pos,SKField (cf,Some child_cls.cl_path)) :: acc
+						with Not_found -> acc
+						)
+					| _ ->
+						acc
+				) [] com.types
+			else
+				[name,full_pos cf.cf_name_pos,SKField (cf,Some c.cl_path)]
+		in
+		(try collect()
+		with Exit -> [name,pos,kind])
+	| _ ->
+		[name,pos,kind]
+
+let find_references tctx com with_definition =
+	let usages =
+		List.fold_left (fun acc (name,pos,kind) ->
+			if pos <> null_pos then begin
+				acc @ (find_references tctx com with_definition name pos kind)
+			end
+			else acc
+		) [] (collect_reference_positions com)
+	in
+	let usages =
+		List.sort (fun p1 p2 ->
+			let c = compare p1.pfile p2.pfile in
+			if c <> 0 then c else compare p1.pmin p2.pmin
+		) usages
+	in
+	DisplayException.raise_positions usages
 
 let find_implementations tctx com name pos kind =
 	let t = Timer.timer ["display";"implementations";"collect"] in

+ 2 - 2
src/context/display/statistics.ml

@@ -153,7 +153,7 @@ let collect_statistics ctx pfilter with_expressions =
 					| FInstance(c,_,cf) | FClosure(Some(c,_),cf) ->
 						field_reference (Some c) cf e.epos
 					| FAnon cf ->
-						declare  (SKField cf) cf.cf_name_pos;
+						declare  (SKField (cf,None)) cf.cf_name_pos;
 						field_reference None cf e.epos
 					| FEnum(_,ef) ->
 						add_relation ef.ef_name_pos (Referenced,patch_string_pos e.epos ef.ef_name)
@@ -232,7 +232,7 @@ let collect_statistics ctx pfilter with_expressions =
 			if c.cl_interface then
 				collect_implementations c;
 			let field cf =
-				if cf.cf_pos.pmin > c.cl_name_pos.pmin then declare (SKField cf) cf.cf_name_pos;
+				if cf.cf_pos.pmin > c.cl_name_pos.pmin then declare (SKField (cf,Some c.cl_path)) cf.cf_name_pos;
 				if with_expressions then begin
 					let _ = follow cf.cf_type in
 					match cf.cf_expr with None -> () | Some e -> collect_references c e

+ 11 - 5
src/core/displayTypes.ml

@@ -190,7 +190,13 @@ module DisplayMode = struct
 	type t =
 		| DMNone
 		| DMDefault
-		| DMUsage of bool (* true = also report definition *)
+		(**
+			Find usages/references of the requested symbol.
+			@param bool - add symbol definition to the response
+			@param bool - also find usages of descendants of the symbol (e.g methods, which override the requested one)
+			@param bool - look for a base method if requested for a method with `override` accessor.
+		*)
+		| DMUsage of bool * bool * bool
 		| DMDefinition
 		| DMTypeDefinition
 		| DMImplementation
@@ -292,8 +298,8 @@ module DisplayMode = struct
 		| DMResolve s -> "resolve " ^ s
 		| DMPackage -> "package"
 		| DMHover -> "type"
-		| DMUsage true -> "rename"
-		| DMUsage false -> "references"
+		| DMUsage (true,_,_) -> "rename"
+		| DMUsage (false,_,_) -> "references"
 		| DMModuleSymbols None -> "module-symbols"
 		| DMModuleSymbols (Some s) -> "workspace-symbols " ^ s
 		| DMDiagnostics _ -> "diagnostics"
@@ -307,7 +313,7 @@ type symbol =
 	| SKEnum of tenum
 	| SKTypedef of tdef
 	| SKAbstract of tabstract
-	| SKField of tclass_field
+	| SKField of tclass_field * path option (* path - class path *)
 	| SKConstructor of tclass_field
 	| SKEnumField of tenum_field
 	| SKVariable of tvar
@@ -330,7 +336,7 @@ let string_of_symbol = function
 	| SKEnum en -> snd en.e_path
 	| SKTypedef td -> snd td.t_path
 	| SKAbstract a -> snd a.a_path
-	| SKField cf | SKConstructor cf -> cf.cf_name
+	| SKField (cf,_) | SKConstructor cf -> cf.cf_name
 	| SKEnumField ef -> ef.ef_name
 	| SKVariable v -> v.v_name
 	| SKOther -> ""

+ 8 - 6
src/typing/typerDisplay.ml

@@ -296,7 +296,7 @@ and display_expr ctx e_ast e dk with_type p =
 		| None -> error "Current class does not have a super" p
 		| Some (c,params) ->
 			let _, f = get_constructor ctx c params p in
-			f
+			f,c
 	in
 	match ctx.com.display.dms_kind with
 	| DMResolve _ | DMPackage ->
@@ -310,8 +310,10 @@ and display_expr ctx e_ast e dk with_type p =
 		let rec loop e = match e.eexpr with
 		| TField(_,FEnum(_,ef)) ->
 			Display.ReferencePosition.set (ef.ef_name,ef.ef_name_pos,SKEnumField ef);
-		| TField(_,(FAnon cf | FInstance (_,_,cf) | FStatic (_,cf) | FClosure (_,cf))) ->
-			Display.ReferencePosition.set (cf.cf_name,cf.cf_name_pos,SKField cf);
+		| TField(_,(FAnon cf | FClosure (None,cf))) ->
+			Display.ReferencePosition.set (cf.cf_name,cf.cf_name_pos,SKField (cf,None));
+		| TField(_,(FInstance (c,_,cf) | FStatic (c,cf) | FClosure (Some (c,_),cf))) ->
+			Display.ReferencePosition.set (cf.cf_name,cf.cf_name_pos,SKField (cf,Some c.cl_path));
 		| TLocal v | TVar(v,_) ->
 			Display.ReferencePosition.set (v.v_name,v.v_pos,SKVariable v);
 		| TTypeExpr mt ->
@@ -326,8 +328,8 @@ and display_expr ctx e_ast e dk with_type p =
 			end
 		| TCall({eexpr = TConst TSuper},_) ->
 			begin try
-				let cf = get_super_constructor() in
-				Display.ReferencePosition.set (cf.cf_name,cf.cf_name_pos,SKField cf);
+				let cf,c = get_super_constructor() in
+				Display.ReferencePosition.set (cf.cf_name,cf.cf_name_pos,SKField (cf,Some c.cl_path));
 			with Not_found ->
 				()
 			end
@@ -380,7 +382,7 @@ and display_expr ctx e_ast e dk with_type p =
 			end
 		| TCall({eexpr = TConst TSuper},_) ->
 			begin try
-				let cf = get_super_constructor() in
+				let cf,_ = get_super_constructor() in
 				[cf.cf_name_pos]
 			with Not_found ->
 				[]

+ 23 - 1
std/haxe/display/Display.hx

@@ -46,7 +46,7 @@ class DisplayMethods {
 	/**
 		The find references request is sent from the client to Haxe to find locations that reference the symbol at a given text document position.
 	**/
-	static inline var FindReferences = new HaxeRequestMethod<PositionParams, GotoDefinitionResult>("display/references");
+	static inline var FindReferences = new HaxeRequestMethod<FindReferencesParams, GotoDefinitionResult>("display/references");
 
 	/**
 		The goto definition request is sent from the client to Haxe to resolve the definition location(s) of a symbol at a given text document position.
@@ -464,6 +464,28 @@ typedef CompletionItemResolveResult = Response<{
 	var item:DisplayItem<Dynamic>;
 }>;
 
+/** FindReferences **/
+typedef FindReferencesParams = PositionParams & {
+	var ?kind:FindReferencesKind;
+}
+
+enum abstract FindReferencesKind(String) to String {
+	/**
+		Find only direct references to the requested symbol.
+		Does not look for references to parent or overriding methods.
+	**/
+	var Direct = "direct";
+	/**
+		Find references to the base field and all the overidding fields in the inheritance chain.
+	**/
+	var WithBaseAndDescendants = "withBaseAndDescendants";
+	/**
+		Find references to the requested field and references to all
+		descendants of the requested field.
+	**/
+	var WithDescendants = "withDescendants";
+}
+
 /** GotoDefinition **/
 typedef GotoDefinitionResult = Response<Array<Location>>;
 

+ 3 - 0
tests/display/src-shared/Marker.hx

@@ -1,3 +1,6 @@
+import haxe.Exception;
+import haxe.display.Position;
+
 class Marker {
 	static var markerRe = ~/{-(\d+)-}/g;
 

+ 11 - 0
tests/server/src/TestCase.hx

@@ -1,3 +1,5 @@
+import haxe.Exception;
+import haxe.display.Position;
 import haxeserver.HaxeServerRequestResult;
 import haxe.display.JsonModuleTypes;
 import haxe.display.Display;
@@ -125,6 +127,15 @@ class TestCase implements ITest {
 		return Json.parse(lastResult.stderr).result;
 	}
 
+	function parseGotoDefinitionLocations():Array<Location> {
+		switch parseGotoDefinition().result {
+			case null:
+				throw new Exception('No result for GotoDefinition found');
+			case result:
+				return result;
+		}
+	}
+
 	function assertSuccess(?p:haxe.PosInfos) {
 		Assert.isTrue(0 == errorMessages.length, p);
 	}

+ 55 - 0
tests/server/src/cases/display/issues/Issue9044.hx

@@ -0,0 +1,55 @@
+package cases.display.issues;
+
+class Issue9044 extends DisplayTestCase {
+	/**
+		class Child extends Base {
+			public function new() {}
+
+			override function f{-1-}unc() {
+				super.{-2-}func{-3-}();
+			}
+		}
+
+		class GrandChild extends Child {
+			override function func() {
+				super.{-5-}func{-6-}();
+			}
+		}
+
+		class Base {
+			public function func() {}
+		}
+
+		class Main {
+			static function main() {
+				var c = new Child();
+				c.{-8-}func{-9-}();
+				var base:Base = c;
+				base.{-10-}func{-11-}();
+				var g = new GrandChild();
+				g.{-12-}func{-13-}();
+			}
+		}
+	**/
+	function test(_) {
+		runHaxeJson([], DisplayMethods.FindReferences, {
+			file: file,
+			offset: offset(1),
+			contents: source,
+			kind: WithBaseAndDescendants
+		});
+		var result = parseGotoDefinitionLocations();
+		var expectedRanges = [range(2, 3), range(5, 6), range(8, 9), range(10, 11), range(12, 13)];
+		Assert.same(expectedRanges, result.map(l -> l.range));
+
+		runHaxeJson([], DisplayMethods.FindReferences, {
+			file: file,
+			offset: offset(1),
+			contents: source,
+			kind: WithDescendants
+		});
+		var result = parseGotoDefinitionLocations();
+		var expectedRanges = [range(5, 6), range(8, 9), range(12, 13)];
+		Assert.same(expectedRanges, result.map(l -> l.range));
+	}
+}