models.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. from typing import ClassVar, Type, Iterable
  2. from datetime import datetime, timedelta
  3. from statemachine.mixins import MachineMixin
  4. from django.db import models
  5. from django.utils import timezone
  6. from django.utils.functional import classproperty
  7. from statemachine import registry, StateMachine, State
  8. from django.core import checks
  9. class DefaultStatusChoices(models.TextChoices):
  10. QUEUED = 'queued', 'Queued'
  11. STARTED = 'started', 'Started'
  12. SEALED = 'sealed', 'Sealed'
  13. default_status_field: models.CharField = models.CharField(choices=DefaultStatusChoices.choices, max_length=15, default=DefaultStatusChoices.QUEUED, null=False, blank=False, db_index=True)
  14. default_retry_at_field: models.DateTimeField = models.DateTimeField(default=timezone.now, null=False, db_index=True)
  15. ObjectState = State | str
  16. ObjectStateList = Iterable[ObjectState]
  17. class BaseModelWithStateMachine(models.Model, MachineMixin):
  18. id: models.UUIDField
  19. StatusChoices: ClassVar[Type[models.TextChoices]]
  20. # status: models.CharField
  21. # retry_at: models.DateTimeField
  22. state_machine_name: ClassVar[str]
  23. state_field_name: ClassVar[str]
  24. state_machine_attr: ClassVar[str] = 'sm'
  25. bind_events_as_methods: ClassVar[bool] = True
  26. active_state: ClassVar[ObjectState]
  27. retry_at_field_name: ClassVar[str]
  28. class Meta:
  29. abstract = True
  30. @classmethod
  31. def check(cls, sender=None, **kwargs):
  32. errors = super().check(**kwargs)
  33. found_id_field = False
  34. found_status_field = False
  35. found_retry_at_field = False
  36. for field in cls._meta.get_fields():
  37. if getattr(field, '_is_state_field', False):
  38. if cls.state_field_name == field.name:
  39. found_status_field = True
  40. if getattr(field, 'choices', None) != cls.StatusChoices.choices:
  41. errors.append(checks.Error(
  42. f'{cls.__name__}.{field.name} must have choices set to {cls.__name__}.StatusChoices.choices',
  43. hint=f'{cls.__name__}.{field.name}.choices = {getattr(field, "choices", None)!r}',
  44. obj=cls,
  45. id='actors.E011',
  46. ))
  47. if getattr(field, '_is_retry_at_field', False):
  48. if cls.retry_at_field_name == field.name:
  49. found_retry_at_field = True
  50. if field.name == 'id' and getattr(field, 'primary_key', False):
  51. found_id_field = True
  52. if not found_status_field:
  53. errors.append(checks.Error(
  54. f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
  55. hint=f'{cls.__name__}.state_field_name = {cls.state_field_name!r} but {cls.__name__}.{cls.state_field_name!r} was not found or does not refer to StatusField',
  56. obj=cls,
  57. id='actors.E012',
  58. ))
  59. if not found_retry_at_field:
  60. errors.append(checks.Error(
  61. f'{cls.__name__}.retry_at_field_name must be defined and point to a RetryAtField()',
  62. hint=f'{cls.__name__}.retry_at_field_name = {cls.retry_at_field_name!r} but {cls.__name__}.{cls.retry_at_field_name!r} was not found or does not refer to RetryAtField',
  63. obj=cls,
  64. id='actors.E013',
  65. ))
  66. if not found_id_field:
  67. errors.append(checks.Error(
  68. f'{cls.__name__} must have an id field that is a primary key',
  69. hint=f'{cls.__name__}.id = {cls.id!r}',
  70. obj=cls,
  71. id='actors.E014',
  72. ))
  73. if not isinstance(cls.state_machine_name, str):
  74. errors.append(checks.Error(
  75. f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
  76. hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
  77. obj=cls,
  78. id='actors.E015',
  79. ))
  80. try:
  81. cls.StateMachineClass
  82. except Exception as err:
  83. errors.append(checks.Error(
  84. f'{cls.__name__}.state_machine_name must point to a valid StateMachine class, but got {type(err).__name__} {err} when trying to access {cls.__name__}.StateMachineClass',
  85. hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
  86. obj=cls,
  87. id='actors.E016',
  88. ))
  89. if cls.INITIAL_STATE not in cls.StatusChoices.values:
  90. errors.append(checks.Error(
  91. f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
  92. hint=f'{cls.__name__}.StateMachineClass.initial_state = {cls.StateMachineClass.initial_state!r}',
  93. obj=cls,
  94. id='actors.E017',
  95. ))
  96. if cls.ACTIVE_STATE not in cls.StatusChoices.values:
  97. errors.append(checks.Error(
  98. f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
  99. hint=f'{cls.__name__}.active_state = {cls.active_state!r}',
  100. obj=cls,
  101. id='actors.E018',
  102. ))
  103. for state in cls.FINAL_STATES:
  104. if state not in cls.StatusChoices.values:
  105. errors.append(checks.Error(
  106. f'{cls.__name__}.StateMachineClass.final_states must all be present within {cls.__name__}.StatusChoices',
  107. hint=f'{cls.__name__}.StateMachineClass.final_states = {cls.StateMachineClass.final_states!r}',
  108. obj=cls,
  109. id='actors.E019',
  110. ))
  111. break
  112. return errors
  113. @staticmethod
  114. def _state_to_str(state: ObjectState) -> str:
  115. """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
  116. return str(state.value) if isinstance(state, State) else str(state)
  117. @property
  118. def RETRY_AT(self) -> datetime:
  119. return getattr(self, self.retry_at_field_name)
  120. @RETRY_AT.setter
  121. def RETRY_AT(self, value: datetime):
  122. setattr(self, self.retry_at_field_name, value)
  123. @property
  124. def STATE(self) -> str:
  125. return getattr(self, self.state_field_name)
  126. @STATE.setter
  127. def STATE(self, value: str):
  128. setattr(self, self.state_field_name, value)
  129. def bump_retry_at(self, seconds: int = 10):
  130. self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
  131. @classproperty
  132. def ACTIVE_STATE(cls) -> str:
  133. return cls._state_to_str(cls.active_state)
  134. @classproperty
  135. def INITIAL_STATE(cls) -> str:
  136. return cls._state_to_str(cls.StateMachineClass.initial_state)
  137. @classproperty
  138. def FINAL_STATES(cls) -> list[str]:
  139. return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]
  140. @classproperty
  141. def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
  142. return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
  143. @classmethod
  144. def extend_choices(cls, base_choices: Type[models.TextChoices]):
  145. """
  146. Decorator to extend the base choices with extra choices, e.g.:
  147. class MyModel(ModelWithStateMachine):
  148. @ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
  149. class StatusChoices(models.TextChoices):
  150. SUCCEEDED = 'succeeded'
  151. FAILED = 'failed'
  152. SKIPPED = 'skipped'
  153. """
  154. assert issubclass(base_choices, models.TextChoices), f'@extend_choices(base_choices) must be a TextChoices class, not {base_choices.__name__}'
  155. def wrapper(extra_choices: Type[models.TextChoices]) -> Type[models.TextChoices]:
  156. joined = {}
  157. for item in base_choices.choices:
  158. joined[item[0]] = item[1]
  159. for item in extra_choices.choices:
  160. joined[item[0]] = item[1]
  161. return models.TextChoices('StatusChoices', joined)
  162. return wrapper
  163. @classmethod
  164. def StatusField(cls, **kwargs) -> models.CharField:
  165. """
  166. Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
  167. class MyModel(ModelWithStateMachine):
  168. class StatusChoices(ModelWithStateMachine.StatusChoices):
  169. QUEUED = 'queued', 'Queued'
  170. STARTED = 'started', 'Started'
  171. SEALED = 'sealed', 'Sealed'
  172. BACKOFF = 'backoff', 'Backoff'
  173. FAILED = 'failed', 'Failed'
  174. SKIPPED = 'skipped', 'Skipped'
  175. status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
  176. """
  177. default_kwargs = default_status_field.deconstruct()[3]
  178. updated_kwargs = {**default_kwargs, **kwargs}
  179. field = models.CharField(**updated_kwargs)
  180. field._is_state_field = True # type: ignore
  181. return field
  182. @classmethod
  183. def RetryAtField(cls, **kwargs) -> models.DateTimeField:
  184. """
  185. Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
  186. class MyModel(ModelWithStateMachine):
  187. retry_at = ModelWithStateMachine.RetryAtField(editable=False)
  188. """
  189. default_kwargs = default_retry_at_field.deconstruct()[3]
  190. updated_kwargs = {**default_kwargs, **kwargs}
  191. field = models.DateTimeField(**updated_kwargs)
  192. field._is_retry_at_field = True # type: ignore
  193. return field
  194. @classproperty
  195. def StateMachineClass(cls) -> Type[StateMachine]:
  196. """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
  197. model_state_machine_name = getattr(cls, 'state_machine_name', None)
  198. if model_state_machine_name:
  199. StateMachineCls = registry.get_machine_cls(model_state_machine_name)
  200. assert issubclass(StateMachineCls, StateMachine)
  201. return StateMachineCls
  202. raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
  203. # @classproperty
  204. # def final_q(cls) -> Q:
  205. # """Get the filter for objects that are in a final state"""
  206. # return Q(**{f'{cls.state_field_name}__in': cls.final_states})
  207. # @classproperty
  208. # def active_q(cls) -> Q:
  209. # """Get the filter for objects that are actively processing right now"""
  210. # return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)}) # e.g. Q(status='started')
  211. # @classproperty
  212. # def stalled_q(cls) -> Q:
  213. # """Get the filter for objects that are marked active but have timed out"""
  214. # return cls.active_q & Q(retry_at__lte=timezone.now()) # e.g. Q(status='started') AND Q(<retry_at is in the past>)
  215. # @classproperty
  216. # def future_q(cls) -> Q:
  217. # """Get the filter for objects that have a retry_at in the future"""
  218. # return Q(retry_at__gt=timezone.now())
  219. # @classproperty
  220. # def pending_q(cls) -> Q:
  221. # """Get the filter for objects that are ready for processing."""
  222. # return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
  223. # @classmethod
  224. # def get_queue(cls) -> QuerySet:
  225. # """
  226. # Get the sorted and filtered QuerySet of objects that are ready for processing.
  227. # e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
  228. # """
  229. # return cls.objects.filter(cls.pending_q)
  230. class ModelWithStateMachine(BaseModelWithStateMachine):
  231. StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices
  232. status: models.CharField = BaseModelWithStateMachine.StatusField()
  233. retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()
  234. state_machine_name: ClassVar[str] # e.g. 'core.statemachines.ArchiveResultMachine'
  235. state_field_name: ClassVar[str] = 'status'
  236. state_machine_attr: ClassVar[str] = 'sm'
  237. bind_events_as_methods: ClassVar[bool] = True
  238. active_state: ClassVar[str] = StatusChoices.STARTED
  239. retry_at_field_name: ClassVar[str] = 'retry_at'
  240. class Meta:
  241. abstract = True