# worker.py
  1. """
  2. Worker classes for processing queue items.
  3. Workers poll the database for items to process, claim them atomically,
  4. and run the state machine tick() to process each item.
  5. Architecture:
  6. Orchestrator (spawns workers)
  7. └── Worker (claims items from queue, processes them directly)
  8. """
  9. __package__ = 'archivebox.workers'
  10. import os
  11. import time
  12. import traceback
  13. from typing import ClassVar, Any
  14. from datetime import timedelta
  15. from pathlib import Path
  16. from multiprocessing import Process as MPProcess, cpu_count
  17. from django.db.models import QuerySet
  18. from django.utils import timezone
  19. from django.conf import settings
  20. from rich import print
  21. from archivebox.misc.logging_util import log_worker_event
  22. CPU_COUNT = cpu_count()
  23. # Registry of worker types by name (defined at bottom, referenced here for _run_worker)
  24. WORKER_TYPES: dict[str, type['Worker']] = {}
  25. def _run_worker(worker_class_name: str, worker_id: int, daemon: bool, **kwargs):
  26. """
  27. Module-level function to run a worker. Must be at module level for pickling.
  28. """
  29. from archivebox.config.django import setup_django
  30. setup_django()
  31. # Get worker class by name to avoid pickling class objects
  32. worker_cls = WORKER_TYPES[worker_class_name]
  33. worker = worker_cls(worker_id=worker_id, daemon=daemon, **kwargs)
  34. worker.runloop()
  35. class Worker:
  36. """
  37. Base worker class that polls a queue and processes items directly.
  38. Each item is processed by calling its state machine tick() method.
  39. Workers exit when idle for too long (unless daemon mode).
  40. """
  41. name: ClassVar[str] = 'worker'
  42. # Configuration (can be overridden by subclasses)
  43. MAX_TICK_TIME: ClassVar[int] = 60
  44. MAX_CONCURRENT_TASKS: ClassVar[int] = 1
  45. POLL_INTERVAL: ClassVar[float] = 0.2 # How often to check for new work (seconds)
  46. IDLE_TIMEOUT: ClassVar[int] = 50 # Exit after N idle iterations (10 sec at 0.2 poll interval)
  47. def __init__(self, worker_id: int = 0, daemon: bool = False, **kwargs: Any):
  48. self.worker_id = worker_id
  49. self.daemon = daemon
  50. self.pid: int = os.getpid()
  51. self.pid_file: Path | None = None
  52. self.idle_count: int = 0
  53. def __repr__(self) -> str:
  54. return f'[underline]{self.__class__.__name__}[/underline]\\[id={self.worker_id}, pid={self.pid}]'
  55. def get_model(self):
  56. """Get the Django model class. Subclasses must override this."""
  57. raise NotImplementedError("Subclasses must implement get_model()")
  58. def get_queue(self) -> QuerySet:
  59. """Get the queue of objects ready for processing."""
  60. Model = self.get_model()
  61. return Model.objects.filter(
  62. retry_at__lte=timezone.now()
  63. ).exclude(
  64. status__in=Model.FINAL_STATES
  65. ).order_by('retry_at')
  66. def claim_next(self):
  67. """
  68. Atomically claim the next object from the queue.
  69. Returns the claimed object or None if queue is empty or claim failed.
  70. """
  71. Model = self.get_model()
  72. obj = self.get_queue().first()
  73. if obj is None:
  74. return None
  75. # Atomic claim using optimistic locking on retry_at
  76. claimed = Model.objects.filter(
  77. pk=obj.pk,
  78. retry_at=obj.retry_at,
  79. ).update(
  80. retry_at=timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
  81. )
  82. if claimed == 1:
  83. obj.refresh_from_db()
  84. return obj
  85. return None # Someone else claimed it
  86. def process_item(self, obj) -> bool:
  87. """
  88. Process a single item by calling its state machine tick().
  89. Returns True on success, False on failure.
  90. Subclasses can override for custom processing.
  91. """
  92. try:
  93. obj.sm.tick()
  94. return True
  95. except Exception as e:
  96. # Error will be logged in runloop's completion event
  97. traceback.print_exc()
  98. return False
  99. def on_startup(self) -> None:
  100. """Called when worker starts."""
  101. from archivebox.machine.models import Process
  102. self.pid = os.getpid()
  103. # Register this worker process in the database
  104. self.db_process = Process.current()
  105. # Explicitly set process_type to WORKER to prevent mis-detection
  106. if self.db_process.process_type != Process.TypeChoices.WORKER:
  107. self.db_process.process_type = Process.TypeChoices.WORKER
  108. self.db_process.save(update_fields=['process_type'])
  109. # Determine worker type for logging
  110. worker_type_name = self.__class__.__name__
  111. indent_level = 1 # Default for most workers
  112. # Adjust indent level based on worker type
  113. if 'Snapshot' in worker_type_name:
  114. indent_level = 2
  115. elif 'ArchiveResult' in worker_type_name:
  116. indent_level = 3
  117. log_worker_event(
  118. worker_type=worker_type_name,
  119. event='Starting...',
  120. indent_level=indent_level,
  121. pid=self.pid,
  122. worker_id=str(self.worker_id),
  123. metadata={
  124. 'max_concurrent': self.MAX_CONCURRENT_TASKS,
  125. 'poll_interval': self.POLL_INTERVAL,
  126. },
  127. )
  128. def on_shutdown(self, error: BaseException | None = None) -> None:
  129. """Called when worker shuts down."""
  130. # Update Process record status
  131. if hasattr(self, 'db_process') and self.db_process:
  132. self.db_process.exit_code = 1 if error else 0
  133. self.db_process.status = self.db_process.StatusChoices.EXITED
  134. self.db_process.ended_at = timezone.now()
  135. self.db_process.save()
  136. # Determine worker type for logging
  137. worker_type_name = self.__class__.__name__
  138. indent_level = 1
  139. if 'Snapshot' in worker_type_name:
  140. indent_level = 2
  141. elif 'ArchiveResult' in worker_type_name:
  142. indent_level = 3
  143. log_worker_event(
  144. worker_type=worker_type_name,
  145. event='Shutting down',
  146. indent_level=indent_level,
  147. pid=self.pid,
  148. worker_id=str(self.worker_id),
  149. error=error if error and not isinstance(error, KeyboardInterrupt) else None,
  150. )
  151. def should_exit(self) -> bool:
  152. """Check if worker should exit due to idle timeout."""
  153. if self.daemon:
  154. return False
  155. if self.IDLE_TIMEOUT == 0:
  156. return False
  157. return self.idle_count >= self.IDLE_TIMEOUT
  158. def runloop(self) -> None:
  159. """Main worker loop - polls queue, processes items."""
  160. self.on_startup()
  161. # Determine worker type for logging
  162. worker_type_name = self.__class__.__name__
  163. indent_level = 1
  164. if 'Snapshot' in worker_type_name:
  165. indent_level = 2
  166. elif 'ArchiveResult' in worker_type_name:
  167. indent_level = 3
  168. try:
  169. while True:
  170. # Try to claim and process an item
  171. obj = self.claim_next()
  172. if obj is not None:
  173. self.idle_count = 0
  174. # Build metadata for task start
  175. start_metadata = {}
  176. url = None
  177. if hasattr(obj, 'url'):
  178. # SnapshotWorker
  179. url = str(obj.url) if obj.url else None
  180. elif hasattr(obj, 'snapshot') and hasattr(obj.snapshot, 'url'):
  181. # ArchiveResultWorker
  182. url = str(obj.snapshot.url) if obj.snapshot.url else None
  183. elif hasattr(obj, 'get_urls_list'):
  184. # CrawlWorker
  185. urls = obj.get_urls_list()
  186. url = urls[0] if urls else None
  187. plugin = None
  188. if hasattr(obj, 'plugin'):
  189. # ArchiveResultWorker, Crawl
  190. plugin = obj.plugin
  191. log_worker_event(
  192. worker_type=worker_type_name,
  193. event='Starting...',
  194. indent_level=indent_level,
  195. pid=self.pid,
  196. worker_id=str(self.worker_id),
  197. url=url,
  198. plugin=plugin,
  199. metadata=start_metadata if start_metadata else None,
  200. )
  201. start_time = time.time()
  202. success = self.process_item(obj)
  203. elapsed = time.time() - start_time
  204. # Build metadata for task completion
  205. complete_metadata = {
  206. 'duration': elapsed,
  207. 'status': 'success' if success else 'failed',
  208. }
  209. log_worker_event(
  210. worker_type=worker_type_name,
  211. event='Completed' if success else 'Failed',
  212. indent_level=indent_level,
  213. pid=self.pid,
  214. worker_id=str(self.worker_id),
  215. url=url,
  216. plugin=plugin,
  217. metadata=complete_metadata,
  218. )
  219. else:
  220. # No work available - idle logging suppressed
  221. self.idle_count += 1
  222. # Check if we should exit
  223. if self.should_exit():
  224. # Exit logging suppressed - shutdown will be logged by on_shutdown()
  225. break
  226. time.sleep(self.POLL_INTERVAL)
  227. except KeyboardInterrupt:
  228. pass
  229. except BaseException as e:
  230. self.on_shutdown(error=e)
  231. raise
  232. else:
  233. self.on_shutdown()
  234. @classmethod
  235. def start(cls, worker_id: int | None = None, daemon: bool = False, **kwargs: Any) -> int:
  236. """
  237. Fork a new worker as a subprocess.
  238. Returns the PID of the new process.
  239. """
  240. from archivebox.machine.models import Process
  241. if worker_id is None:
  242. worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)
  243. # Use module-level function for pickling compatibility
  244. proc = MPProcess(
  245. target=_run_worker,
  246. args=(cls.name, worker_id, daemon),
  247. kwargs=kwargs,
  248. name=f'{cls.name}_worker_{worker_id}',
  249. )
  250. proc.start()
  251. assert proc.pid is not None
  252. return proc.pid
  253. @classmethod
  254. def get_running_workers(cls) -> list:
  255. """Get info about all running workers of this type."""
  256. from archivebox.machine.models import Process
  257. Process.cleanup_stale_running()
  258. # Convert Process objects to dicts to match the expected API contract
  259. processes = Process.get_running(process_type=Process.TypeChoices.WORKER)
  260. # Note: worker_id is not stored on Process model, it's dynamically generated
  261. # We return process_id (UUID) and pid (OS process ID) instead
  262. return [
  263. {
  264. 'pid': p.pid,
  265. 'process_id': str(p.id), # UUID of Process record
  266. 'started_at': p.started_at.isoformat() if p.started_at else None,
  267. 'status': p.status,
  268. }
  269. for p in processes
  270. ]
  271. @classmethod
  272. def get_worker_count(cls) -> int:
  273. """Get count of running workers of this type."""
  274. from archivebox.machine.models import Process
  275. return Process.get_running_count(process_type=Process.TypeChoices.WORKER)
  276. class CrawlWorker(Worker):
  277. """Worker for processing Crawl objects."""
  278. name: ClassVar[str] = 'crawl'
  279. MAX_TICK_TIME: ClassVar[int] = 60
  280. def get_model(self):
  281. from archivebox.crawls.models import Crawl
  282. return Crawl
  283. class SnapshotWorker(Worker):
  284. """Worker for processing Snapshot objects."""
  285. name: ClassVar[str] = 'snapshot'
  286. MAX_TICK_TIME: ClassVar[int] = 60
  287. def get_model(self):
  288. from archivebox.core.models import Snapshot
  289. return Snapshot
  290. class ArchiveResultWorker(Worker):
  291. """Worker for processing ArchiveResult objects."""
  292. name: ClassVar[str] = 'archiveresult'
  293. MAX_TICK_TIME: ClassVar[int] = 120
  294. def __init__(self, plugin: str | None = None, **kwargs: Any):
  295. super().__init__(**kwargs)
  296. self.plugin = plugin
  297. def get_model(self):
  298. from archivebox.core.models import ArchiveResult
  299. return ArchiveResult
  300. def get_queue(self) -> QuerySet:
  301. """
  302. Get queue of ArchiveResults ready for processing.
  303. Uses step-based filtering: only claims ARs where hook step <= snapshot.current_step.
  304. This ensures hooks execute in order (step 0 → 1 → 2 ... → 9).
  305. """
  306. from archivebox.core.models import ArchiveResult
  307. from archivebox.hooks import extract_step
  308. qs = super().get_queue()
  309. if self.plugin:
  310. qs = qs.filter(plugin=self.plugin)
  311. # Step-based filtering: only process ARs whose step <= snapshot.current_step
  312. # Since step is derived from hook_name, we filter in Python after initial query
  313. # This is efficient because the base query already filters by retry_at and status
  314. # Get candidate ARs
  315. candidates = list(qs[:50]) # Limit to avoid loading too many
  316. ready_pks = []
  317. for ar in candidates:
  318. if not ar.hook_name:
  319. # Legacy ARs without hook_name - process them
  320. ready_pks.append(ar.pk)
  321. continue
  322. ar_step = extract_step(ar.hook_name)
  323. snapshot_step = ar.snapshot.current_step
  324. if ar_step <= snapshot_step:
  325. ready_pks.append(ar.pk)
  326. # Return filtered queryset ordered by hook_name (so earlier hooks run first within a step)
  327. return ArchiveResult.objects.filter(pk__in=ready_pks).order_by('hook_name', 'retry_at')
  328. def process_item(self, obj) -> bool:
  329. """Process an ArchiveResult by running its plugin."""
  330. try:
  331. obj.sm.tick()
  332. return True
  333. except Exception as e:
  334. # Error will be logged in runloop's completion event
  335. traceback.print_exc()
  336. return False
  337. @classmethod
  338. def start(cls, worker_id: int | None = None, daemon: bool = False, plugin: str | None = None, **kwargs: Any) -> int:
  339. """Fork a new worker as subprocess with optional plugin filter."""
  340. from archivebox.machine.models import Process
  341. if worker_id is None:
  342. worker_id = Process.get_next_worker_id(process_type=Process.TypeChoices.WORKER)
  343. # Use module-level function for pickling compatibility
  344. proc = MPProcess(
  345. target=_run_worker,
  346. args=(cls.name, worker_id, daemon),
  347. kwargs={'plugin': plugin, **kwargs},
  348. name=f'{cls.name}_worker_{worker_id}',
  349. )
  350. proc.start()
  351. assert proc.pid is not None
  352. return proc.pid
  353. # Populate the registry
  354. WORKER_TYPES.update({
  355. 'crawl': CrawlWorker,
  356. 'snapshot': SnapshotWorker,
  357. 'archiveresult': ArchiveResultWorker,
  358. })
  359. def get_worker_class(name: str) -> type[Worker]:
  360. """Get worker class by name."""
  361. if name not in WORKER_TYPES:
  362. raise ValueError(f'Unknown worker type: {name}. Valid types: {list(WORKER_TYPES.keys())}')
  363. return WORKER_TYPES[name]