# worker.py
  1. """
  2. Worker classes for processing queue items.
  3. Workers poll the database for items to process, claim them atomically,
  4. and run the state machine tick() to process each item.
  5. Architecture:
  6. Orchestrator (spawns workers)
  7. └── Worker (claims items from queue, processes them directly)
  8. """
  9. __package__ = 'archivebox.workers'
  10. import os
  11. import time
  12. import traceback
  13. from typing import ClassVar, Any
  14. from datetime import timedelta
  15. from pathlib import Path
  16. from multiprocessing import Process, cpu_count
  17. from django.db.models import QuerySet
  18. from django.utils import timezone
  19. from django.conf import settings
  20. from rich import print
  21. from archivebox.misc.logging_util import log_worker_event
  22. from .pid_utils import (
  23. write_pid_file,
  24. remove_pid_file,
  25. get_all_worker_pids,
  26. get_next_worker_id,
  27. cleanup_stale_pid_files,
  28. )
  29. CPU_COUNT = cpu_count()
  30. # Registry of worker types by name (defined at bottom, referenced here for _run_worker)
  31. WORKER_TYPES: dict[str, type['Worker']] = {}
  32. def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
  33. """
  34. Module-level function to run a worker. Must be at module level for pickling.
  35. """
  36. from archivebox.config.django import setup_django
  37. setup_django()
  38. # Get worker class by name to avoid pickling class objects
  39. worker_cls = WORKER_TYPES[worker_class_name]
  40. worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
  41. worker.runloop()
  42. class Worker:
  43. """
  44. Base worker class that polls a queue and processes items directly.
  45. Each item is processed by calling its state machine tick() method.
  46. Workers exit when idle for too long (unless daemon mode).
  47. """
  48. name: ClassVar[str] = 'worker'
  49. # Configuration (can be overridden by subclasses)
  50. MAX_TICK_TIME: ClassVar[int] = 60
  51. MAX_CONCURRENT_TASKS: ClassVar[int] = 1
  52. POLL_INTERVAL: ClassVar[float] = 0.2 # How often to check for new work (seconds)
  53. IDLE_TIMEOUT: ClassVar[int] = 50 # Exit after N idle iterations (10 sec at 0.2 poll interval)
  54. def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
  55. self.worker_id = worker_id
  56. self.daemon = daemon
  57. self.pid: int = os.getpid()
  58. self.pid_file: Path | None = None
  59. self.idle_count: int = 0
  60. def __repr__(self) -> str:
  61. return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
  62. def get_model(self):
  63. """Get the Django model class. Subclasses must override this."""
  64. raise NotImplementedError("Subclasses must implement get_model()")
  65. def get_queue(self) -> QuerySet:
  66. """Get the queue of objects ready for processing."""
  67. Model = self.get_model()
  68. return Model.objects.filter(
  69. retry_at__lte=timezone.now()
  70. ).exclude(
  71. status__in=Model.FINAL_STATES
  72. ).order_by('retry_at')
  73. def claim_next(self):
  74. """
  75. Atomically claim the next object from the queue.
  76. Returns the claimed object or None if queue is empty or claim failed.
  77. """
  78. Model = self.get_model()
  79. obj = self.get_queue().first()
  80. if obj is None:
  81. return None
  82. # Atomic claim using optimistic locking on retry_at
  83. claimed = Model.objects.filter(
  84. pk=obj.pk,
  85. retry_at=obj.retry_at,
  86. ).update(
  87. retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
  88. )
  89. if claimed == 1:
  90. obj.refresh_from_db()
  91. return obj
  92. return None # Someone else claimed it
  93. def process_item(self, obj) -> bool:
  94. """
  95. Process a single item by calling its state machine tick().
  96. Returns True on success, False on failure.
  97. Subclasses can override for custom processing.
  98. """
  99. try:
  100. obj.sm.tick()
  101. return True
  102. except Exception as e:
  103. # Error will be logged in runloop's completion event
  104. traceback.print_exc()
  105. return False
  106. def on_startup(self) -> None:
  107. """Called when worker starts."""
  108. self.pid = os.getpid()
  109. self.pid_file = write_pid_file(self.name, self.worker_id)
  110. # Determine worker type for logging
  111. worker_type_name = self.__class__.__name__
  112. indent_level = 1 # Default for most workers
  113. # Adjust indent level based on worker type
  114. if 'Snapshot' in worker_type_name:
  115. indent_level = 2
  116. elif 'ArchiveResult' in worker_type_name:
  117. indent_level = 3
  118. log_worker_event(
  119. worker_type=worker_type_name,
  120. event='Starting...',
  121. indent_level=indent_level,
  122. pid=self.pid,
  123. worker_id=str(self.worker_id),
  124. metadata={
  125. 'max_concurrent': self.MAX_CONCURRENT_TASKS,
  126. 'poll_interval': self.POLL_INTERVAL,
  127. },
  128. )
  129. def on_shutdown(self, error: BaseException | None = None) -> None:
  130. """Called when worker shuts down."""
  131. # Remove PID file
  132. if self.pid_file:
  133. remove_pid_file(self.pid_file)
  134. # Determine worker type for logging
  135. worker_type_name = self.__class__.__name__
  136. indent_level = 1
  137. if 'Snapshot' in worker_type_name:
  138. indent_level = 2
  139. elif 'ArchiveResult' in worker_type_name:
  140. indent_level = 3
  141. log_worker_event(
  142. worker_type=worker_type_name,
  143. event='Shutting down',
  144. indent_level=indent_level,
  145. pid=self.pid,
  146. worker_id=str(self.worker_id),
  147. error=error if error and not isinstance(error, KeyboardInterrupt) else None,
  148. )
  149. def should_exit(self) -> bool:
  150. """Check if worker should exit due to idle timeout."""
  151. if self.daemon:
  152. return False
  153. if self.IDLE_TIMEOUT == 0:
  154. return False
  155. return self.idle_count >= self.IDLE_TIMEOUT
  156. def runloop(self) -> None:
  157. """Main worker loop - polls queue, processes items."""
  158. self.on_startup()
  159. # Determine worker type for logging
  160. worker_type_name = self.__class__.__name__
  161. indent_level = 1
  162. if 'Snapshot' in worker_type_name:
  163. indent_level = 2
  164. elif 'ArchiveResult' in worker_type_name:
  165. indent_level = 3
  166. try:
  167. while True:
  168. # Try to claim and process an item
  169. obj = self.claim_next()
  170. if obj is not None:
  171. self.idle_count = 0
  172. # Build metadata for task start
  173. start_metadata = {}
  174. url = None
  175. if hasattr(obj, 'url'):
  176. # SnapshotWorker
  177. url = str(obj.url) if obj.url else None
  178. elif hasattr(obj, 'snapshot') and hasattr(obj.snapshot, 'url'):
  179. # ArchiveResultWorker
  180. url = str(obj.snapshot.url) if obj.snapshot.url else None
  181. elif hasattr(obj, 'get_urls_list'):
  182. # CrawlWorker
  183. urls = obj.get_urls_list()
  184. url = urls[0] if urls else None
  185. plugin = None
  186. if hasattr(obj, 'plugin'):
  187. # ArchiveResultWorker, Crawl
  188. plugin = obj.plugin
  189. log_worker_event(
  190. worker_type=worker_type_name,
  191. event='Starting...',
  192. indent_level=indent_level,
  193. pid=self.pid,
  194. worker_id=str(self.worker_id),
  195. url=url,
  196. plugin=plugin,
  197. metadata=start_metadata if start_metadata else None,
  198. )
  199. start_time = time.time()
  200. success = self.process_item(obj)
  201. elapsed = time.time() - start_time
  202. # Build metadata for task completion
  203. complete_metadata = {
  204. 'duration': elapsed,
  205. 'status': 'success' if success else 'failed',
  206. }
  207. log_worker_event(
  208. worker_type=worker_type_name,
  209. event='Completed' if success else 'Failed',
  210. indent_level=indent_level,
  211. pid=self.pid,
  212. worker_id=str(self.worker_id),
  213. url=url,
  214. plugin=plugin,
  215. metadata=complete_metadata,
  216. )
  217. else:
  218. # No work available - idle logging suppressed
  219. self.idle_count += 1
  220. # Check if we should exit
  221. if self.should_exit():
  222. # Exit logging suppressed - shutdown will be logged by on_shutdown()
  223. break
  224. time.sleep(self.POLL_INTERVAL)
  225. except KeyboardInterrupt:
  226. pass
  227. except BaseException as e:
  228. self.on_shutdown(error=e)
  229. raise
  230. else:
  231. self.on_shutdown()
  232. @classmethod
  233. def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
  234. """
  235. Fork a new worker as a subprocess.
  236. Returns the PID of the new process.
  237. """
  238. if worker_id is None:
  239. worker_id = get_next_worker_id(cls.name)
  240. # Use module-level function for pickling compatibility
  241. proc = Process(
  242. target=_run_worker,
  243. args=(cls.name, worker_id, daemon),
  244. kwargs=kwargs,
  245. name=f'{cls.name}_worker_{worker_id}',
  246. )
  247. proc.start()
  248. assert proc.pid is not None
  249. return proc.pid
  250. @classmethod
  251. def get_running_workers(cls) -> list[dict]:
  252. """Get info about all running workers of this type."""
  253. cleanup_stale_pid_files()
  254. return get_all_worker_pids(cls.name)
  255. @classmethod
  256. def get_worker_count(cls) -> int:
  257. """Get count of running workers of this type."""
  258. return len(cls.get_running_workers())
  259. class CrawlWorker(Worker):
  260. """Worker for processing Crawl objects."""
  261. name: ClassVar[str] = 'crawl'
  262. MAX_TICK_TIME: ClassVar[int] = 60
  263. def get_model(self):
  264. from archivebox.crawls.models import Crawl
  265. return Crawl
  266. class SnapshotWorker(Worker):
  267. """Worker for processing Snapshot objects."""
  268. name: ClassVar[str] = 'snapshot'
  269. MAX_TICK_TIME: ClassVar[int] = 60
  270. def get_model(self):
  271. from archivebox.core.models import Snapshot
  272. return Snapshot
  273. class ArchiveResultWorker(Worker):
  274. """Worker for processing ArchiveResult objects."""
  275. name: ClassVar[str] = 'archiveresult'
  276. MAX_TICK_TIME: ClassVar[int] = 120
  277. def __init__(self, plugin: str | None = None, **kwargs: Any):
  278. super().__init__(**kwargs)
  279. self.plugin = plugin
  280. def get_model(self):
  281. from archivebox.core.models import ArchiveResult
  282. return ArchiveResult
  283. def get_queue(self) -> QuerySet:
  284. """
  285. Get queue of ArchiveResults ready for processing.
  286. Uses step-based filtering: only claims ARs where hook step <= snapshot.current_step.
  287. This ensures hooks execute in order (step 0 → 1 → 2 ... → 9).
  288. """
  289. from archivebox.core.models import ArchiveResult
  290. from archivebox.hooks import extract_step
  291. qs = super().get_queue()
  292. if self.plugin:
  293. qs = qs.filter(plugin=self.plugin)
  294. # Step-based filtering: only process ARs whose step <= snapshot.current_step
  295. # Since step is derived from hook_name, we filter in Python after initial query
  296. # This is efficient because the base query already filters by retry_at and status
  297. # Get candidate ARs
  298. candidates = list(qs[:50]) # Limit to avoid loading too many
  299. ready_pks = []
  300. for ar in candidates:
  301. if not ar.hook_name:
  302. # Legacy ARs without hook_name - process them
  303. ready_pks.append(ar.pk)
  304. continue
  305. ar_step = extract_step(ar.hook_name)
  306. snapshot_step = ar.snapshot.current_step
  307. if ar_step <= snapshot_step:
  308. ready_pks.append(ar.pk)
  309. # Return filtered queryset ordered by hook_name (so earlier hooks run first within a step)
  310. return ArchiveResult.objects.filter(pk__in=ready_pks).order_by('hook_name', 'retry_at')
  311. def process_item(self, obj) -> bool:
  312. """Process an ArchiveResult by running its plugin."""
  313. try:
  314. obj.sm.tick()
  315. return True
  316. except Exception as e:
  317. # Error will be logged in runloop's completion event
  318. traceback.print_exc()
  319. return False
  320. @classmethod
  321. def start(cls, worker_id: int | None = None, daemon: bool = False, plugin: str | None = None, **kwargs: Any) -> int:
  322. """Fork a new worker as subprocess with optional plugin filter."""
  323. if worker_id is None:
  324. worker_id = get_next_worker_id(cls.name)
  325. # Use module-level function for pickling compatibility
  326. proc = Process(
  327. target=_run_worker,
  328. args=(cls.name, worker_id, daemon),
  329. kwargs={'plugin': plugin, **kwargs},
  330. name=f'{cls.name}_worker_{worker_id}',
  331. )
  332. proc.start()
  333. assert proc.pid is not None
  334. return proc.pid
  335. # Populate the registry
  336. WORKER_TYPES.update({
  337. 'crawl': CrawlWorker,
  338. 'snapshot': SnapshotWorker,
  339. 'archiveresult': ArchiveResultWorker,
  340. })
  341. def get_worker_class(name: str) -> type[Worker]:
  342. """Get worker class by name."""
  343. if name not in WORKER_TYPES:
  344. raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
  345. return WORKER_TYPES[name]