# models.py (23 KB)


  1. __package__ = 'archivebox.workers'
  2. import uuid
  3. import json
  4. from typing import ClassVar, Type, Iterable, TypedDict
  5. from datetime import datetime, timedelta
  6. from statemachine.mixins import MachineMixin
  7. from django.db import models
  8. from django.db.models import QuerySet
  9. from django.core import checks
  10. from django.utils import timezone
  11. from django.utils.functional import classproperty
  12. from base_models.models import ABIDModel, ABIDField
  13. from machine.models import Process
  14. from statemachine import registry, StateMachine, State
  15. class DefaultStatusChoices(models.TextChoices):
  16. QUEUED = 'queued', 'Queued'
  17. STARTED = 'started', 'Started'
  18. SEALED = 'sealed', 'Sealed'
  19. default_status_field: models.CharField = models.CharField(choices=DefaultStatusChoices.choices, max_length=15, default=DefaultStatusChoices.QUEUED, null=False, blank=False, db_index=True)
  20. default_retry_at_field: models.DateTimeField = models.DateTimeField(default=timezone.now, null=True, blank=True, db_index=True)
  21. ObjectState = State | str
  22. ObjectStateList = Iterable[ObjectState]
  23. class BaseModelWithStateMachine(models.Model, MachineMixin):
  24. id: models.UUIDField
  25. StatusChoices: ClassVar[Type[models.TextChoices]]
  26. # status: models.CharField
  27. # retry_at: models.DateTimeField
  28. state_machine_name: ClassVar[str]
  29. state_field_name: ClassVar[str]
  30. state_machine_attr: ClassVar[str] = 'sm'
  31. bind_events_as_methods: ClassVar[bool] = True
  32. active_state: ClassVar[ObjectState]
  33. retry_at_field_name: ClassVar[str]
  34. class Meta:
  35. abstract = True
  36. @classmethod
  37. def check(cls, sender=None, **kwargs):
  38. errors = super().check(**kwargs)
  39. found_id_field = False
  40. found_status_field = False
  41. found_retry_at_field = False
  42. for field in cls._meta.get_fields():
  43. if getattr(field, '_is_state_field', False):
  44. if cls.state_field_name == field.name:
  45. found_status_field = True
  46. if getattr(field, 'choices', None) != cls.StatusChoices.choices:
  47. errors.append(checks.Error(
  48. f'{cls.__name__}.{field.name} must have choices set to {cls.__name__}.StatusChoices.choices',
  49. hint=f'{cls.__name__}.{field.name}.choices = {getattr(field, "choices", None)!r}',
  50. obj=cls,
  51. id='workers.E011',
  52. ))
  53. if getattr(field, '_is_retry_at_field', False):
  54. if cls.retry_at_field_name == field.name:
  55. found_retry_at_field = True
  56. if field.name == 'id' and getattr(field, 'primary_key', False):
  57. found_id_field = True
  58. if not found_status_field:
  59. errors.append(checks.Error(
  60. f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
  61. hint=f'{cls.__name__}.state_field_name = {cls.state_field_name!r} but {cls.__name__}.{cls.state_field_name!r} was not found or does not refer to StatusField',
  62. obj=cls,
  63. id='workers.E012',
  64. ))
  65. if not found_retry_at_field:
  66. errors.append(checks.Error(
  67. f'{cls.__name__}.retry_at_field_name must be defined and point to a RetryAtField()',
  68. hint=f'{cls.__name__}.retry_at_field_name = {cls.retry_at_field_name!r} but {cls.__name__}.{cls.retry_at_field_name!r} was not found or does not refer to RetryAtField',
  69. obj=cls,
  70. id='workers.E013',
  71. ))
  72. if not found_id_field:
  73. errors.append(checks.Error(
  74. f'{cls.__name__} must have an id field that is a primary key',
  75. hint=f'{cls.__name__}.id = {cls.id!r}',
  76. obj=cls,
  77. id='workers.E014',
  78. ))
  79. if not isinstance(cls.state_machine_name, str):
  80. errors.append(checks.Error(
  81. f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
  82. hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
  83. obj=cls,
  84. id='workers.E015',
  85. ))
  86. try:
  87. cls.StateMachineClass
  88. except Exception as err:
  89. errors.append(checks.Error(
  90. f'{cls.__name__}.state_machine_name must point to a valid StateMachine class, but got {type(err).__name__} {err} when trying to access {cls.__name__}.StateMachineClass',
  91. hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
  92. obj=cls,
  93. id='workers.E016',
  94. ))
  95. if cls.INITIAL_STATE not in cls.StatusChoices.values:
  96. errors.append(checks.Error(
  97. f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
  98. hint=f'{cls.__name__}.StateMachineClass.initial_state = {cls.StateMachineClass.initial_state!r}',
  99. obj=cls,
  100. id='workers.E017',
  101. ))
  102. if cls.ACTIVE_STATE not in cls.StatusChoices.values:
  103. errors.append(checks.Error(
  104. f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
  105. hint=f'{cls.__name__}.active_state = {cls.active_state!r}',
  106. obj=cls,
  107. id='workers.E018',
  108. ))
  109. for state in cls.FINAL_STATES:
  110. if state not in cls.StatusChoices.values:
  111. errors.append(checks.Error(
  112. f'{cls.__name__}.StateMachineClass.final_states must all be present within {cls.__name__}.StatusChoices',
  113. hint=f'{cls.__name__}.StateMachineClass.final_states = {cls.StateMachineClass.final_states!r}',
  114. obj=cls,
  115. id='workers.E019',
  116. ))
  117. break
  118. return errors
  119. @staticmethod
  120. def _state_to_str(state: ObjectState) -> str:
  121. """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
  122. return str(state.value) if isinstance(state, State) else str(state)
  123. @property
  124. def RETRY_AT(self) -> datetime:
  125. return getattr(self, self.retry_at_field_name)
  126. @RETRY_AT.setter
  127. def RETRY_AT(self, value: datetime):
  128. setattr(self, self.retry_at_field_name, value)
  129. @property
  130. def STATE(self) -> str:
  131. return getattr(self, self.state_field_name)
  132. @STATE.setter
  133. def STATE(self, value: str):
  134. setattr(self, self.state_field_name, value)
  135. def bump_retry_at(self, seconds: int = 10):
  136. self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
  137. @classproperty
  138. def ACTIVE_STATE(cls) -> str:
  139. return cls._state_to_str(cls.active_state)
  140. @classproperty
  141. def INITIAL_STATE(cls) -> str:
  142. return cls._state_to_str(cls.StateMachineClass.initial_state)
  143. @classproperty
  144. def FINAL_STATES(cls) -> list[str]:
  145. return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]
  146. @classproperty
  147. def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
  148. return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
  149. @classmethod
  150. def extend_choices(cls, base_choices: Type[models.TextChoices]):
  151. """
  152. Decorator to extend the base choices with extra choices, e.g.:
  153. class MyModel(ModelWithStateMachine):
  154. @ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
  155. class StatusChoices(models.TextChoices):
  156. SUCCEEDED = 'succeeded'
  157. FAILED = 'failed'
  158. SKIPPED = 'skipped'
  159. """
  160. assert issubclass(base_choices, models.TextChoices), f'@extend_choices(base_choices) must be a TextChoices class, not {base_choices.__name__}'
  161. def wrapper(extra_choices: Type[models.TextChoices]) -> Type[models.TextChoices]:
  162. joined = {}
  163. for item in base_choices.choices:
  164. joined[item[0]] = item[1]
  165. for item in extra_choices.choices:
  166. joined[item[0]] = item[1]
  167. return models.TextChoices('StatusChoices', joined)
  168. return wrapper
  169. @classmethod
  170. def StatusField(cls, **kwargs) -> models.CharField:
  171. """
  172. Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
  173. class MyModel(ModelWithStateMachine):
  174. class StatusChoices(ModelWithStateMachine.StatusChoices):
  175. QUEUED = 'queued', 'Queued'
  176. STARTED = 'started', 'Started'
  177. SEALED = 'sealed', 'Sealed'
  178. BACKOFF = 'backoff', 'Backoff'
  179. FAILED = 'failed', 'Failed'
  180. SKIPPED = 'skipped', 'Skipped'
  181. status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
  182. """
  183. default_kwargs = default_status_field.deconstruct()[3]
  184. updated_kwargs = {**default_kwargs, **kwargs}
  185. field = models.CharField(**updated_kwargs)
  186. field._is_state_field = True # type: ignore
  187. return field
  188. @classmethod
  189. def RetryAtField(cls, **kwargs) -> models.DateTimeField:
  190. """
  191. Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
  192. class MyModel(ModelWithStateMachine):
  193. retry_at = ModelWithStateMachine.RetryAtField(editable=False)
  194. """
  195. default_kwargs = default_retry_at_field.deconstruct()[3]
  196. updated_kwargs = {**default_kwargs, **kwargs}
  197. field = models.DateTimeField(**updated_kwargs)
  198. field._is_retry_at_field = True # type: ignore
  199. return field
  200. @classproperty
  201. def StateMachineClass(cls) -> Type[StateMachine]:
  202. """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
  203. model_state_machine_name = getattr(cls, 'state_machine_name', None)
  204. if model_state_machine_name:
  205. StateMachineCls = registry.get_machine_cls(model_state_machine_name)
  206. assert issubclass(StateMachineCls, StateMachine)
  207. return StateMachineCls
  208. raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
  209. # @classproperty
  210. # def final_q(cls) -> Q:
  211. # """Get the filter for objects that are in a final state"""
  212. # return Q(**{f'{cls.state_field_name}__in': cls.final_states})
  213. # @classproperty
  214. # def active_q(cls) -> Q:
  215. # """Get the filter for objects that are actively processing right now"""
  216. # return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)}) # e.g. Q(status='started')
  217. # @classproperty
  218. # def stalled_q(cls) -> Q:
  219. # """Get the filter for objects that are marked active but have timed out"""
  220. # return cls.active_q & Q(retry_at__lte=timezone.now()) # e.g. Q(status='started') AND Q(<retry_at is in the past>)
  221. # @classproperty
  222. # def future_q(cls) -> Q:
  223. # """Get the filter for objects that have a retry_at in the future"""
  224. # return Q(retry_at__gt=timezone.now())
  225. # @classproperty
  226. # def pending_q(cls) -> Q:
  227. # """Get the filter for objects that are ready for processing."""
  228. # return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
  229. # @classmethod
  230. # def get_queue(cls) -> QuerySet:
  231. # """
  232. # Get the sorted and filtered QuerySet of objects that are ready for processing.
  233. # e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
  234. # """
  235. # return cls.objects.filter(cls.pending_q)
  236. class ModelWithStateMachine(BaseModelWithStateMachine):
  237. StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices
  238. status: models.CharField = BaseModelWithStateMachine.StatusField()
  239. retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()
  240. state_machine_name: ClassVar[str] # e.g. 'core.statemachines.ArchiveResultMachine'
  241. state_field_name: ClassVar[str] = 'status'
  242. state_machine_attr: ClassVar[str] = 'sm'
  243. bind_events_as_methods: ClassVar[bool] = True
  244. active_state: ClassVar[str] = StatusChoices.STARTED
  245. retry_at_field_name: ClassVar[str] = 'retry_at'
  246. class Meta:
  247. abstract = True
  248. class EventDict(TypedDict, total=False):
  249. name: str
  250. id: str | uuid.UUID
  251. path: str
  252. content: str
  253. status: str
  254. retry_at: datetime | None
  255. url: str
  256. seed_id: str | uuid.UUID
  257. crawl_id: str | uuid.UUID
  258. snapshot_id: str | uuid.UUID
  259. process_id: str | uuid.UUID
  260. extractor: str
  261. error: str
  262. on_success: dict | None
  263. on_failure: dict | None
  264. class EventManager(models.Manager):
  265. pass
  266. class EventQuerySet(models.QuerySet):
  267. def get_next_unclaimed(self) -> 'Event | None':
  268. return self.filter(claimed_at=None).order_by('deliver_at').first()
  269. def expired(self, older_than: int=60 * 10) -> QuerySet['Event']:
  270. return self.filter(claimed_at__lt=timezone.now() - timedelta(seconds=older_than))
  271. class Event(ABIDModel):
  272. abid_prefix = 'evn_'
  273. abid_ts_src = 'self.deliver_at' # e.g. 'self.created_at'
  274. abid_uri_src = 'self.name' # e.g. 'self.uri' (MUST BE SET)
  275. abid_subtype_src = 'self.emitted_by' # e.g. 'self.extractor'
  276. abid_rand_src = 'self.id' # e.g. 'self.uuid' or 'self.id'
  277. abid_drift_allowed: bool = False # set to True to allow abid_field values to change after a fixed ABID has been issued (NOT RECOMMENDED: means values can drift out of sync from original ABID)
  278. read_only_fields = ('id', 'deliver_at', 'name', 'kwargs', 'timeout', 'parent', 'emitted_by', 'on_success', 'on_failure')
  279. id = models.UUIDField(primary_key=True, default=uuid.uuid4, null=False, editable=False, unique=True, verbose_name='ID')
  280. # disable these fields from inherited models, they're not needed / take up too much room
  281. abid = None
  282. created_at = None
  283. created_by = None
  284. created_by_id = None
  285. # immutable fields
  286. deliver_at = models.DateTimeField(default=timezone.now, null=False, editable=False, unique=True, db_index=True)
  287. name = models.CharField(max_length=255, null=False, blank=False, db_index=True)
  288. kwargs = models.JSONField(default=dict)
  289. timeout = models.IntegerField(null=False, default=60)
  290. parent = models.ForeignKey('Event', null=True, on_delete=models.SET_NULL, related_name='child_events')
  291. emitted_by = models.ForeignKey(Process, null=False, on_delete=models.PROTECT, related_name='emitted_events')
  292. on_success = models.JSONField(null=True)
  293. on_failure = models.JSONField(null=True)
  294. # mutable fields
  295. modified_at = models.DateTimeField(auto_now=True)
  296. claimed_proc = models.ForeignKey(Process, null=True, on_delete=models.CASCADE, related_name='claimed_events')
  297. claimed_at = models.DateTimeField(null=True)
  298. finished_at = models.DateTimeField(null=True)
  299. error = models.TextField(null=True)
  300. objects: EventManager = EventManager.from_queryset(EventQuerySet)()
  301. child_events: models.RelatedManager['Event']
  302. @classmethod
  303. def get_next_timestamp(cls):
  304. """Get the next monotonically increasing timestamp for the next event.dispatch_at"""
  305. latest_event = cls.objects.order_by('-deliver_at').first()
  306. ts = timezone.now()
  307. if latest_event:
  308. assert ts > latest_event.deliver_at, f'Event.deliver_at is not monotonically increasing: {latest_event.deliver_at} > {ts}'
  309. return ts
  310. @classmethod
  311. def dispatch(cls, name: str | EventDict | None = None, event: EventDict | None = None, **event_init_kwargs) -> 'Event':
  312. """
  313. Create a new Event and save it to the database.
  314. Can be called as either:
  315. >>> Event.dispatch(name, {**kwargs}, **event_init_kwargs)
  316. # OR
  317. >>> Event.dispatch({name, **kwargs}, **event_init_kwargs)
  318. """
  319. event_kwargs: EventDict = event or {}
  320. if isinstance(name, dict):
  321. event_kwargs.update(name)
  322. assert isinstance(event_kwargs, dict), 'must be called as Event.dispatch(name, {**kwargs}) or Event.dispatch({name, **kwargs})'
  323. event_name: str = name if (isinstance(name, str) and name) else event_kwargs.pop('name')
  324. new_event = cls(
  325. name=event_name,
  326. kwargs=event_kwargs,
  327. emitted_by=Process.current(),
  328. **event_init_kwargs,
  329. )
  330. new_event.save()
  331. return new_event
  332. def clean(self, *args, **kwargs) -> None:
  333. """Fill and validate all the event fields"""
  334. # check uuid and deliver_at are set
  335. assert self.id, 'Event.id must be set to a valid v4 UUID'
  336. if not self.deliver_at:
  337. self.deliver_at = self.get_next_timestamp()
  338. assert self.deliver_at and (datetime(2024, 12, 8, 12, 0, 0, tzinfo=timezone.utc) < self.deliver_at < datetime(2100, 12, 31, 23, 59, 0, tzinfo=timezone.utc)), (
  339. f'Event.deliver_at must be set to a valid UTC datetime (got Event.deliver_at = {self.deliver_at})')
  340. # if name is not set but it's found in the kwargs, move it out of the kwargs to the name field
  341. if 'type' in self.kwargs and ((self.name == self.kwargs['type']) or not self.name):
  342. self.name = self.kwargs.pop('type')
  343. if 'name' in self.kwargs and ((self.name == self.kwargs['name']) or not self.name):
  344. self.name = self.kwargs.pop('name')
  345. # check name is set and is a valid identifier
  346. assert isinstance(self.name, str) and len(self.name) > 3, 'Event.name must be set to a non-empty string'
  347. assert self.name.isidentifier(), f'Event.name must be a valid identifier (got Event.name = {self.name})'
  348. assert self.name.isupper(), f'Event.name must be in uppercase (got Event.name = {self.name})'
  349. # check that kwargs keys and values are valid
  350. for key, value in self.kwargs.items():
  351. assert isinstance(key, str), f'Event kwargs keys can only be strings (got Event.kwargs[{key}: {type(key).__name__}])'
  352. assert key not in self._meta.get_fields(), f'Event.kwargs cannot contain "{key}" key (Event.kwargs[{key}] conflicts with with reserved attr Event.{key} = {getattr(self, key)})'
  353. assert json.dumps(value, sort_keys=True), f'Event can only contain JSON serializable values (got Event.kwargs[{key}]: {type(value).__name__} = {value})'
  354. # validate on_success and on_failure are valid event dicts if set
  355. if self.on_success:
  356. assert isinstance(self.on_success, dict) and self.on_success.get('name', '!invalid').isidentifier(), f'Event.on_success must be a valid event dict (got {self.on_success})'
  357. if self.on_failure:
  358. assert isinstance(self.on_failure, dict) and self.on_failure.get('name', '!invalid').isidentifier(), f'Event.on_failure must be a valid event dict (got {self.on_failure})'
  359. # validate mutable fields like claimed_at, claimed_proc, finished_at are set correctly
  360. if self.claimed_at:
  361. assert self.claimed_proc, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_at = {self.claimed_at})'
  362. if self.claimed_proc:
  363. assert self.claimed_at, f'Event.claimed_at and Event.claimed_proc must be set together (only found Event.claimed_proc = {self.claimed_proc})'
  364. if self.finished_at:
  365. assert self.claimed_at, f'If Event.finished_at is set, Event.claimed_at and Event.claimed_proc must also be set (Event.claimed_proc = {self.claimed_proc} and Event.claimed_at = {self.claimed_at})'
  366. # validate error is a non-empty string or None
  367. if isinstance(self.error, BaseException):
  368. self.error = f'{type(self.error).__name__}: {self.error}'
  369. if self.error:
  370. assert isinstance(self.error, str) and str(self.error).strip(), f'Event.error must be a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
  371. else:
  372. assert self.error is None, f'Event.error must be None or a non-empty string (got Event.error: {type(self.error).__name__} = {self.error})'
  373. def save(self, *args, **kwargs):
  374. self.clean()
  375. return super().save(*args, **kwargs)
  376. def reset(self):
  377. """Force-update an event to a pending/unclaimed state (without running any of its handlers or callbacks)"""
  378. self.claimed_proc = None
  379. self.claimed_at = None
  380. self.finished_at = None
  381. self.error = None
  382. self.save()
  383. def abort(self):
  384. """Force-update an event to a completed/failed state (without running any of its handlers or callbacks)"""
  385. self.claimed_proc = Process.current()
  386. self.claimed_at = timezone.now()
  387. self.finished_at = timezone.now()
  388. self.error = 'Aborted'
  389. self.save()
  390. def __repr__(self) -> str:
  391. label = f'[{self.name} {self.kwargs}]'
  392. if self.is_finished:
  393. label += f' ✅'
  394. elif self.claimed_proc:
  395. label += f' 🏃'
  396. return label
  397. def __str__(self) -> str:
  398. return repr(self)
  399. @property
  400. def type(self) -> str:
  401. return self.name
  402. @property
  403. def is_queued(self):
  404. return not self.is_claimed and not self.is_finished
  405. @property
  406. def is_claimed(self):
  407. return self.claimed_at is not None
  408. @property
  409. def is_expired(self):
  410. if not self.claimed_at:
  411. return False
  412. elapsed_time = timezone.now() - self.claimed_at
  413. return elapsed_time > timedelta(seconds=self.timeout)
  414. @property
  415. def is_processing(self):
  416. return self.is_claimed and not self.is_finished
  417. @property
  418. def is_finished(self):
  419. return self.finished_at is not None
  420. @property
  421. def is_failed(self):
  422. return self.is_finished and bool(self.error)
  423. @property
  424. def is_succeeded(self):
  425. return self.is_finished and not bool(self.error)
  426. def __getattr__(self, key: str):
  427. """
  428. Allow access to the event kwargs as attributes e.g.
  429. Event(name='CRAWL_CREATE', kwargs={'some_key': 'some_val'}).some_key -> 'some_val'
  430. """
  431. return self.kwargs.get(key)