abx.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. __package__ = 'abx'
  2. __id__ = 'abx'
  3. __label__ = 'ABX'
  4. __author__ = 'Nick Sweeting'
  5. __homepage__ = 'https://github.com/ArchiveBox'
  6. __order__ = 0
  7. import inspect
  8. import importlib
  9. import itertools
  10. from pathlib import Path
  11. from typing import Dict, Callable, List, Set, Tuple, Iterable, Any, TypeVar, TypedDict, Type, cast, Generic, Mapping, overload, Final, ParamSpec, Literal, Protocol
  12. from types import ModuleType
  13. from typing_extensions import Annotated
  14. from functools import cache
  15. from benedict import benedict
  16. from pydantic import AfterValidator
  17. from pluggy import HookimplMarker, PluginManager, HookimplOpts, HookspecOpts, HookCaller
  18. ParamsT = ParamSpec("ParamsT")
  19. ReturnT = TypeVar('ReturnT')
  20. class HookSpecDecoratorThatReturnsFirstResult(Protocol):
  21. def __call__(self, func: Callable[ParamsT, ReturnT]) -> Callable[ParamsT, ReturnT]: ...
  22. class HookSpecDecoratorThatReturnsListResults(Protocol):
  23. def __call__(self, func: Callable[ParamsT, ReturnT]) -> Callable[ParamsT, List[ReturnT]]: ...
  24. class TypedHookspecMarker:
  25. """
  26. Improved version of pluggy.HookspecMarker that supports type inference of hookspecs with firstresult=True|False correctly
  27. https://github.com/pytest-dev/pluggy/issues/191
  28. """
  29. __slots__ = ('project_name',)
  30. def __init__(self, project_name: str) -> None:
  31. self.project_name: Final[str] = project_name
  32. # handle @hookspec(firstresult=False) -> List[ReturnT] (test_firstresult_False_hookspec)
  33. @overload
  34. def __call__(
  35. self,
  36. function: None = ...,
  37. firstresult: Literal[False] = ...,
  38. historic: bool = ...,
  39. warn_on_impl: Warning | None = ...,
  40. warn_on_impl_args: Mapping[str, Warning] | None = ...,
  41. ) -> HookSpecDecoratorThatReturnsListResults: ...
  42. # handle @hookspec(firstresult=True) -> ReturnT (test_firstresult_True_hookspec)
  43. @overload
  44. def __call__(
  45. self,
  46. function: None = ...,
  47. firstresult: Literal[True] = ...,
  48. historic: bool = ...,
  49. warn_on_impl: Warning | None = ...,
  50. warn_on_impl_args: Mapping[str, Warning] | None = ...,
  51. ) -> HookSpecDecoratorThatReturnsFirstResult: ...
  52. # handle @hookspec -> List[ReturnT] (test_normal_hookspec)
  53. # order matters!!! this one has to come last
  54. @overload
  55. def __call__(
  56. self,
  57. function: Callable[ParamsT, ReturnT] = ...,
  58. firstresult: Literal[False] = ...,
  59. historic: bool = ...,
  60. warn_on_impl: None = ...,
  61. warn_on_impl_args: None = ...,
  62. ) -> Callable[ParamsT, List[ReturnT]]: ...
  63. def __call__(
  64. self,
  65. function: Callable[ParamsT, ReturnT] | None = None,
  66. firstresult: bool = False,
  67. historic: bool = False,
  68. warn_on_impl: Warning | None = None,
  69. warn_on_impl_args: Mapping[str, Warning] | None = None,
  70. ) -> Callable[ParamsT, List[ReturnT]] | HookSpecDecoratorThatReturnsListResults | HookSpecDecoratorThatReturnsFirstResult:
  71. def setattr_hookspec_opts(func) -> Callable:
  72. if historic and firstresult:
  73. raise ValueError("cannot have a historic firstresult hook")
  74. opts: HookspecOpts = {
  75. "firstresult": firstresult,
  76. "historic": historic,
  77. "warn_on_impl": warn_on_impl,
  78. "warn_on_impl_args": warn_on_impl_args,
  79. }
  80. setattr(func, self.project_name + "_spec", opts)
  81. return func
  82. if function is not None:
  83. return setattr_hookspec_opts(function)
  84. else:
  85. return setattr_hookspec_opts
  86. spec = hookspec = TypedHookspecMarker("abx")
  87. impl = hookimpl = HookimplMarker("abx")
  88. def is_valid_attr_name(x: str) -> str:
  89. assert x.isidentifier() and not x.startswith('_')
  90. return x
  91. def is_valid_module_name(x: str) -> str:
  92. assert x.isidentifier() and not x.startswith('_') and x.islower()
  93. return x
  94. AttrName = Annotated[str, AfterValidator(is_valid_attr_name)]
  95. PluginId = Annotated[str, AfterValidator(is_valid_module_name)]
  96. class PluginInfo(TypedDict, total=True):
  97. id: PluginId
  98. package: AttrName
  99. label: str
  100. version: str
  101. author: str
  102. homepage: str
  103. dependencies: List[str]
  104. source_code: str
  105. hooks: Dict[AttrName, Callable]
  106. module: ModuleType
  107. PluginSpec = TypeVar("PluginSpec")
  108. class ABXPluginManager(PluginManager, Generic[PluginSpec]):
  109. """
  110. Patch to fix pluggy's PluginManager to work with pydantic models.
  111. See: https://github.com/pytest-dev/pluggy/pull/536
  112. """
  113. # enable static type checking of pm.hook.call() calls
  114. # https://stackoverflow.com/a/62871889/2156113
  115. # https://github.com/pytest-dev/pluggy/issues/191
  116. hook: PluginSpec
  117. def create_typed_hookcaller(self, name: str, module_or_class: Type[PluginSpec], spec_opts: HookspecOpts) -> HookCaller:
  118. """
  119. create a new HookCaller subclass with a modified __signature__
  120. so that the return type is correct and args are converted to kwargs
  121. """
  122. TypedHookCaller = type('TypedHookCaller', (HookCaller,), {})
  123. hookspec_signature = inspect.signature(getattr(module_or_class, name))
  124. hookspec_return_type = hookspec_signature.return_annotation
  125. # replace return type with list if firstresult=False
  126. hookcall_return_type = hookspec_return_type if spec_opts['firstresult'] else List[hookspec_return_type]
  127. # replace each arg with kwarg equivalent (pm.hook.call() only accepts kwargs)
  128. args_as_kwargs = [
  129. param.replace(kind=inspect.Parameter.KEYWORD_ONLY) if param.name != 'self' else param
  130. for param in hookspec_signature.parameters.values()
  131. ]
  132. TypedHookCaller.__signature__ = hookspec_signature.replace(parameters=args_as_kwargs, return_annotation=hookcall_return_type)
  133. TypedHookCaller.__name__ = f'{name}_HookCaller'
  134. return TypedHookCaller(name, self._hookexec, module_or_class, spec_opts)
  135. def add_hookspecs(self, module_or_class: Type[PluginSpec]) -> None:
  136. """Add HookSpecs from the given class, (generic type allows us to enforce types of pm.hook.call() statically)"""
  137. names = []
  138. for name in dir(module_or_class):
  139. spec_opts = self.parse_hookspec_opts(module_or_class, name)
  140. if spec_opts is not None:
  141. hc: HookCaller | None = getattr(self.hook, name, None)
  142. if hc is None:
  143. hc = self.create_typed_hookcaller(name, module_or_class, spec_opts)
  144. setattr(self.hook, name, hc)
  145. else:
  146. # Plugins registered this hook without knowing the spec.
  147. hc.set_specification(module_or_class, spec_opts)
  148. for hookfunction in hc.get_hookimpls():
  149. self._verify_hook(hc, hookfunction)
  150. names.append(name)
  151. if not names:
  152. raise ValueError(
  153. f"did not find any {self.project_name!r} hooks in {module_or_class!r}"
  154. )
  155. def parse_hookimpl_opts(self, plugin, name: str) -> HookimplOpts | None:
  156. # IMPORTANT: @property methods can have side effects, and are never hookimpl
  157. # if attr is a property, skip it in advance
  158. # plugin_class = plugin if inspect.isclass(plugin) else type(plugin)
  159. if isinstance(getattr(plugin, name, None), property):
  160. return None
  161. try:
  162. return super().parse_hookimpl_opts(plugin, name)
  163. except AttributeError:
  164. return None
  165. pm = ABXPluginManager("abx")
  166. def get_plugin_order(plugin: PluginId | Path | ModuleType | Type) -> Tuple[int, Path]:
  167. assert plugin
  168. plugin_module = None
  169. plugin_dir = None
  170. if isinstance(plugin, str) or isinstance(plugin, Path):
  171. if str(plugin).endswith('.py'):
  172. plugin_dir = Path(plugin).parent
  173. elif '/' in str(plugin):
  174. # assume it's a path to a plugin directory
  175. plugin_dir = Path(plugin)
  176. elif str(plugin).isidentifier():
  177. pass
  178. elif inspect.ismodule(plugin):
  179. plugin_module = plugin
  180. plugin_dir = Path(str(plugin_module.__file__)).parent
  181. elif inspect.isclass(plugin):
  182. plugin_module = plugin
  183. plugin_dir = Path(inspect.getfile(plugin)).parent
  184. else:
  185. raise ValueError(f'Invalid plugin, cannot get order: {plugin}')
  186. if plugin_dir:
  187. try:
  188. # if .plugin_order file exists, use it to set the load priority
  189. order = int((plugin_dir / '.plugin_order').read_text())
  190. assert -1000000 < order < 100000000
  191. return (order, plugin_dir)
  192. except FileNotFoundError:
  193. pass
  194. if plugin_module:
  195. order = getattr(plugin_module, '__order__', 999)
  196. else:
  197. order = 999
  198. assert order is not None
  199. assert plugin_dir
  200. return (order, plugin_dir)
  201. # @cache
  202. def get_plugin(plugin: PluginId | ModuleType | Type) -> PluginInfo:
  203. assert plugin
  204. # import the plugin module by its name
  205. if isinstance(plugin, str):
  206. module = importlib.import_module(plugin)
  207. # print('IMPORTED PLUGIN:', plugin)
  208. plugin = getattr(module, 'PLUGIN_SPEC', getattr(module, 'PLUGIN', module))
  209. elif inspect.ismodule(plugin):
  210. module = plugin
  211. plugin = getattr(module, 'PLUGIN_SPEC', getattr(module, 'PLUGIN', module))
  212. elif inspect.isclass(plugin):
  213. module = inspect.getmodule(plugin)
  214. else:
  215. raise ValueError(f'Invalid plugin, must be a module, class, or plugin ID (package name): {plugin}')
  216. assert module
  217. plugin_file = Path(inspect.getfile(module))
  218. plugin_package = module.__package__ or module.__name__
  219. plugin_id = plugin_package.replace('.', '_')
  220. # load the plugin info from the plugin/__init__.py __attr__s if they exist
  221. plugin_module_attrs = {
  222. 'label': getattr(module, '__label__', plugin_id),
  223. 'version': getattr(module, '__version__', '0.0.1'),
  224. 'author': getattr(module, '__author__', 'ArchiveBox'),
  225. 'homepage': getattr(module, '__homepage__', 'https://github.com/ArchiveBox'),
  226. 'dependencies': getattr(module, '__dependencies__', []),
  227. }
  228. # load the plugin info from the plugin/pyproject.toml file if it has one
  229. plugin_toml_info = {}
  230. try:
  231. # try loading ./pyproject.toml first in case the plugin is a bare python file not inside a package dir
  232. plugin_toml_info = benedict.from_toml((plugin_file.parent / 'pyproject.toml').read_text()).project
  233. except Exception:
  234. try:
  235. # try loading ../pyproject.toml next in case the plugin is in a packge dir
  236. plugin_toml_info = benedict.from_toml((plugin_file.parent.parent / 'pyproject.toml').read_text()).project
  237. except Exception:
  238. # print('WARNING: could not detect pyproject.toml for PLUGIN:', plugin_id, plugin_file.parent, 'ERROR:', e)
  239. pass
  240. assert plugin_id
  241. assert plugin_package
  242. assert module.__file__
  243. # merge the plugin info from all sources + add dyanmically calculated info
  244. return cast(PluginInfo, benedict(PluginInfo(**{
  245. 'id': plugin_id,
  246. **plugin_module_attrs,
  247. **plugin_toml_info,
  248. 'package': plugin_package,
  249. 'source_code': module.__file__,
  250. 'order': get_plugin_order(plugin),
  251. 'hooks': get_plugin_hooks(plugin),
  252. 'module': module,
  253. 'plugin': plugin,
  254. })))
  255. def get_all_plugins() -> Dict[PluginId, PluginInfo]:
  256. """Get the metadata for all the plugins registered with Pluggy."""
  257. plugins = {}
  258. for plugin_module in pm.get_plugins():
  259. plugin_info = get_plugin(plugin=plugin_module)
  260. assert 'id' in plugin_info
  261. plugins[plugin_info['id']] = plugin_info
  262. return benedict(plugins)
  263. def get_all_hook_names() -> Set[str]:
  264. """Get a set of all hook names across all plugins"""
  265. return {
  266. hook_name
  267. for plugin_module in pm.get_plugins()
  268. for hook_name in get_plugin_hooks(plugin_module)
  269. }
  270. def get_all_hook_specs() -> Dict[str, Dict[str, Any]]:
  271. """Get a set of all hookspec methods defined in all plugins (useful for type checking if a pm.hook.call() is valid)"""
  272. hook_specs = {}
  273. for hook_name in get_all_hook_names():
  274. for plugin_module in pm.get_plugins():
  275. if hasattr(plugin_module, hook_name):
  276. hookspecopts = pm.parse_hookspec_opts(plugin_module, hook_name)
  277. if hookspecopts:
  278. method = getattr(plugin_module, hook_name)
  279. signature = inspect.signature(method)
  280. return_type = signature.return_annotation if signature.return_annotation != inspect._empty else None
  281. if hookspecopts.get('firstresult'):
  282. return_type = return_type
  283. else:
  284. # if not firstresult, return_type is a sequence
  285. return_type = List[return_type]
  286. call_signature = signature.replace(return_annotation=return_type)
  287. method = lambda *args, **kwargs: getattr(pm.hook, hook_name)(*args, **kwargs)
  288. method.__signature__ = call_signature
  289. method.__name__ = hook_name
  290. method.__package__ = plugin_module.__package__
  291. hook_specs[hook_name] = {
  292. 'name': hook_name,
  293. 'method': method,
  294. 'signature': call_signature,
  295. 'hookspec_opts': hookspecopts,
  296. 'hookspec_signature': signature,
  297. 'hookspec_plugin': plugin_module.__package__,
  298. }
  299. return hook_specs
  300. ###### PLUGIN DISCOVERY AND LOADING ########################################################
  301. def find_plugins_in_dir(plugins_dir: Path) -> Dict[PluginId, Path]:
  302. """
  303. Find all the plugins in a given directory. Just looks for an __init__.py file.
  304. """
  305. python_dirs = plugins_dir.glob("*/__init__.py")
  306. sorted_python_dirs = sorted(python_dirs, key=lambda p: get_plugin_order(plugin=p) or 500)
  307. return {
  308. plugin_entrypoint.parent.name: plugin_entrypoint.parent
  309. for plugin_entrypoint in sorted_python_dirs
  310. if plugin_entrypoint.parent.name not in ('abx', 'core')
  311. }
  312. def get_pip_installed_plugins(group: PluginId='abx') -> Dict[PluginId, Path]:
  313. """replaces pm.load_setuptools_entrypoints("abx"), finds plugins that registered entrypoints via pip"""
  314. import importlib.metadata
  315. DETECTED_PLUGINS = {} # module_name: module_dir_path
  316. for dist in list(importlib.metadata.distributions()):
  317. for entrypoint in dist.entry_points:
  318. if entrypoint.group != group or pm.is_blocked(entrypoint.name):
  319. continue
  320. DETECTED_PLUGINS[entrypoint.name] = Path(entrypoint.load().__file__).parent
  321. # pm.register(plugin, name=ep.name)
  322. # pm._plugin_distinfo.append((plugin, DistFacade(dist)))
  323. return DETECTED_PLUGINS
  324. # Load all plugins from pip packages, archivebox built-ins, and user plugins
  325. def load_plugins(plugins: Iterable[PluginId | ModuleType | Type] | Dict[PluginId, Path]):
  326. """
  327. Load all the plugins from a dictionary of module names and directory paths.
  328. """
  329. PLUGINS_TO_LOAD = []
  330. LOADED_PLUGINS = {}
  331. for plugin in plugins:
  332. plugin_info = get_plugin(plugin)
  333. assert plugin_info, f'No plugin metadata found for {plugin}'
  334. assert 'id' in plugin_info and 'module' in plugin_info
  335. if plugin_info['module'] in pm.get_plugins():
  336. LOADED_PLUGINS[plugin_info['id']] = plugin_info
  337. continue
  338. else:
  339. PLUGINS_TO_LOAD.append(plugin_info)
  340. PLUGINS_TO_LOAD = sorted(PLUGINS_TO_LOAD, key=lambda x: x['order'])
  341. for plugin_info in PLUGINS_TO_LOAD:
  342. pm.register(plugin_info['module'])
  343. LOADED_PLUGINS[plugin_info['id']] = plugin_info
  344. # print(f' √ Loaded plugin: {plugin_id}')
  345. return benedict(LOADED_PLUGINS)
  346. @cache
  347. def get_plugin_hooks(plugin: PluginId | ModuleType | Type | None) -> Dict[AttrName, Callable]:
  348. """Get all the functions marked with @hookimpl on a module."""
  349. if not plugin:
  350. return {}
  351. hooks = {}
  352. if isinstance(plugin, str):
  353. plugin_module = importlib.import_module(plugin)
  354. elif inspect.ismodule(plugin) or inspect.isclass(plugin):
  355. plugin_module = plugin
  356. else:
  357. raise ValueError(f'Invalid plugin, cannot get hooks: {plugin}')
  358. for attr_name in dir(plugin_module):
  359. if attr_name.startswith('_'):
  360. continue
  361. try:
  362. attr = getattr(plugin_module, attr_name)
  363. if isinstance(attr, Callable):
  364. if pm.parse_hookimpl_opts(plugin_module, attr_name):
  365. hooks[attr_name] = attr
  366. except Exception as e:
  367. print(f'Error getting hookimpls for {plugin}: {e}')
  368. return hooks
  369. ReturnT = TypeVar('ReturnT')
  370. def as_list(results: List[List[ReturnT]]) -> List[ReturnT]:
  371. """Flatten a list of lists returned by a pm.hook.call() into a single list"""
  372. return list(itertools.chain(*results))
  373. def as_dict(results: List[Dict[PluginId, ReturnT]]) -> Dict[PluginId, ReturnT]:
  374. """Flatten a list of dicts returned by a pm.hook.call() into a single dict"""
  375. if isinstance(results, (dict, benedict)):
  376. results_list = results.values()
  377. else:
  378. results_list = results
  379. return benedict({
  380. result_id: result
  381. for plugin_results in results_list
  382. for result_id, result in plugin_results.items()
  383. })