浏览代码

more StateMachine, Actor, and Orchestrator improvements

Nick Sweeting 1 年之前
父节点
当前提交
a9a3b153b1

+ 441 - 183
archivebox/actors/actor.py

@@ -2,78 +2,240 @@ __package__ = 'archivebox.actors'
 
 
 import os
 import os
 import time
 import time
-from abc import ABC, abstractmethod
-from typing import ClassVar, Generic, TypeVar, Any, cast, Literal, Type
-from django.utils.functional import classproperty
+from typing import ClassVar, Generic, TypeVar, Any, Literal, Type, Iterable, cast, get_args
+from datetime import timedelta
+from multiprocessing import Process, cpu_count
+from threading import Thread, get_native_id
 
 
-from rich import print
 import psutil
 import psutil
+from rich import print
+from statemachine import State, StateMachine, registry
+from statemachine.mixins import MachineMixin
 
 
 from django import db
 from django import db
-from django.db import models
-from django.db.models import QuerySet
-from multiprocessing import Process, cpu_count
-from threading import Thread, get_native_id
+from django.db.models import QuerySet, sql, Q
+from django.db.models import Model as DjangoModel
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from .models import ModelWithStateMachine
 
 
 # from archivebox.logging_util import TimedProgress
 # from archivebox.logging_util import TimedProgress
 
 
+class ActorObjectAlreadyClaimed(Exception):
+    """Raised when the Actor tries to claim the next object from the queue but it's already been claimed by another Actor"""
+    pass
+
+class ActorQueueIsEmpty(Exception):
+    """Raised when the Actor tries to get the next object from the queue but it's empty"""
+    pass
+
+CPU_COUNT = cpu_count()
+DEFAULT_MAX_TICK_TIME = 60                                              # seconds
+DEFAULT_MAX_CONCURRENT_ACTORS = min(max(2, int(CPU_COUNT * 0.6)), 8)   # 60% of the available cpu cores, clamped to the inclusive range 2..8
+
+# limit(n, max) caps n at max; there is no lower bound, so a negative max yields a negative result
+limit = lambda n, max: min(n, max)
+
 LaunchKwargs = dict[str, Any]
 LaunchKwargs = dict[str, Any]
+ObjectState = State | str
+ObjectStateList = Iterable[ObjectState]
 
 
-ModelType = TypeVar('ModelType', bound=models.Model)
+ModelType = TypeVar('ModelType', bound=ModelWithStateMachine)
 
 
-class ActorType(ABC, Generic[ModelType]):
+class ActorType(Generic[ModelType]):
     """
     """
     Base class for all actors. Usage:
     Base class for all actors. Usage:
-    class FaviconActor(ActorType[ArchiveResult]):
-        QUERYSET: ClassVar[QuerySet] = ArchiveResult.objects.filter(status='queued', extractor='favicon')
-        CLAIM_WHERE: ClassVar[str] = 'status = "queued" AND extractor = "favicon"'
-        CLAIM_ORDER: ClassVar[str] = 'created_at DESC'
-        ATOMIC: ClassVar[bool] = True
-
-        def claim_sql_set(self, obj: ArchiveResult) -> str:
-            # SQL fields to update atomically while claiming an object from the queue
-            retry_at = datetime.now() + timedelta(seconds=self.MAX_TICK_TIME)
-            return f"status = 'started', locked_by = {self.pid}, retry_at = {retry_at}"
-
-        def tick(self, obj: ArchiveResult) -> None:
-            run_favicon_extractor(obj)
-            ArchiveResult.objects.filter(pk=obj.pk, status='started').update(status='success')
+    
+    class FaviconActor(ActorType[FaviconArchiveResult]):
+        FINAL_STATES: ClassVar[tuple[str, ...]] = ('succeeded', 'failed', 'skipped')
+        ACTIVE_STATE: ClassVar[str] = 'started'
+        
+        @classmethod
+        def qs(cls) -> QuerySet[FaviconArchiveResult]:
+            return ArchiveResult.objects.filter(extractor='favicon')   # or leave the default: FaviconArchiveResult.objects.all()
     """
     """
+    
+    ### Class attributes (defined on the class at compile-time when ActorType[MyModel] is defined)
+    Model: Type[ModelType]
+    StateMachineClass: Type[StateMachine]
+    
+    STATE_FIELD_NAME: ClassVar[str]
+    ACTIVE_STATE: ClassVar[ObjectState]
+    FINAL_STATES: ClassVar[ObjectStateList]
+    EVENT_NAME: ClassVar[str] = 'tick'                                    # the event name to trigger on the obj.sm: StateMachine (usually 'tick')
+    
+    CLAIM_ORDER: ClassVar[tuple[str, ...]] = ('retry_at',)                # the .order(*args) to claim the queue objects in, use ('?',) for random order
+    CLAIM_FROM_TOP_N: ClassVar[int] = CPU_COUNT * 10                      # the number of objects to consider when atomically getting the next object from the queue
+    CLAIM_ATOMIC: ClassVar[bool] = True                                   # whether to atomically fetch+claim the next object in one query, or fetch and lock it in two queries
+    
+    MAX_TICK_TIME: ClassVar[int] = DEFAULT_MAX_TICK_TIME                  # maximum duration in seconds to process a single object
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = DEFAULT_MAX_CONCURRENT_ACTORS  # maximum number of concurrent actors that can be running at once
+    
+    _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []      # used to record all the pids of Actors spawned on the class
+    
+    ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
     pid: int
     pid: int
     idle_count: int = 0
     idle_count: int = 0
     launch_kwargs: LaunchKwargs = {}
     launch_kwargs: LaunchKwargs = {}
     mode: Literal['thread', 'process'] = 'process'
     mode: Literal['thread', 'process'] = 'process'
     
     
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = min(max(2, int(cpu_count() * 0.6)), 8)   # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    QUERYSET: ClassVar[QuerySet]                      # the QuerySet to claim objects from
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10  # the number of objects to consider when atomically getting the next object from the queue
-    ATOMIC: ClassVar[bool] = True                     # whether to atomically fetch+claim the nextobject in one step, or fetch and lock it in two steps
-    
-    # model_type: Type[ModelType]
-    
-    _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []   # record all the pids of Actors spawned by this class
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        """
+        Executed at class definition time (i.e. during import of any file containing class MyActor(ActorType[MyModel]): ...).
+        Loads the django Model from the Generic[ModelType] TypeVar arg and populates any missing class-level config using it.
+        
+        Any class keyword arguments are forwarded to the next __init_subclass__ in the MRO.
+        """
+        # cooperate with Generic (and any other base/mixin) so that their subclass hooks still run
+        super().__init_subclass__(**kwargs)
+        
+        if getattr(cls, 'Model', None) is None:
+            cls.Model = cls._get_model_from_generic_typevar()
+        cls._populate_missing_classvars_from_model(cls.Model)
     
     
     def __init__(self, mode: Literal['thread', 'process']|None=None, **launch_kwargs: LaunchKwargs):
     def __init__(self, mode: Literal['thread', 'process']|None=None, **launch_kwargs: LaunchKwargs):
+        """
+        Executed right before the Actor is spawned to create a unique Actor instance for that thread/process.
+        actor_instance.runloop() is then executed from inside the newly spawned thread/process.
+        """
         self.mode = mode or self.mode
         self.mode = mode or self.mode
         self.launch_kwargs = launch_kwargs or dict(self.launch_kwargs)
         self.launch_kwargs = launch_kwargs or dict(self.launch_kwargs)
     
     
+
+    ### Private Helper Methods: Not designed to be overridden by subclasses or called by anything outside of this class
+    
     @classproperty
     @classproperty
     def name(cls) -> str:
     def name(cls) -> str:
         return cls.__name__  # type: ignore
         return cls.__name__  # type: ignore
     
     
     def __str__(self) -> str:
     def __str__(self) -> str:
-        return self.__repr__()
+        return repr(self)
     
     
     def __repr__(self) -> str:
     def __repr__(self) -> str:
-        """FaviconActor[pid=1234]"""
+        """-> FaviconActor[pid=1234]"""
         label = 'pid' if self.mode == 'process' else 'tid'
         label = 'pid' if self.mode == 'process' else 'tid'
         return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
         return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
     
     
+    @staticmethod
+    def _state_to_str(state: ObjectState) -> str:
+        """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
+        return str(state.value) if isinstance(state, State) else str(state)
+    
+    @staticmethod
+    def _sql_for_select_top_n_candidates(qs: QuerySet, claim_from_top_n: int=CLAIM_FROM_TOP_N) -> tuple[str, tuple[Any, ...]]:
+        """Get the SQL for selecting the top N candidates from the queue (to claim one from)"""
+        queryset = qs.only('id')[:claim_from_top_n]
+        select_sql, select_params = compile_sql_select(queryset)
+        return select_sql, select_params
+    
+    @staticmethod
+    def _sql_for_update_claimed_obj(qs: QuerySet, update_kwargs: dict[str, Any]) -> tuple[str, tuple[Any, ...]]:
+        """Get the SQL for updating a claimed object to mark it as ACTIVE"""
+        # qs.update(status='started', retry_at=<now + MAX_TICK_TIME>)
+        update_sql, update_params = compile_sql_update(qs, update_kwargs=update_kwargs)
+        # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN ('succeeded', 'failed', 'sealed', 'started') AND retry_at <= '2024-11-04 10:14:33.240903'
+        return update_sql, update_params
+    
+    @classmethod
+    def _get_model_from_generic_typevar(cls) -> Type[ModelType]:
+        """Get the django Model from the Generic[ModelType] TypeVar arg (and check that it inherits from django.db.models.Model)"""
+        # cls.__orig_bases__ is non-standard and may be removed in the future! if this breaks,
+        # we can just require the inherited class to define the Model as a classvar manually, e.g.:
+        #     class SnapshotActor(ActorType[Snapshot]):
+        #         Model: ClassVar[Type[Snapshot]] = Snapshot
+        # https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance
+        # NOTE: only the first original base is inspected, so ActorType[MyModel] has to be listed first in the subclass bases
+        Model = get_args(cls.__orig_bases__[0])[0]   # type: ignore
+        assert issubclass(Model, DjangoModel), f'{cls.__name__}.Model must be a valid django Model'
+        return cast(Type[ModelType], Model)
+    
+    @staticmethod
+    def _get_state_machine_cls(Model: Type[ModelType]) -> Type[StateMachine]:
+        """
+        Get the StateMachine class for the given django Model that inherits from MachineMixin.
+        
+        Looks up Model.state_machine_name in the statemachine registry and checks that the result is a StateMachine subclass.
+        Raises NotImplementedError if the Model does not define a (non-empty) state_machine_name.
+        """
+        assert issubclass(Model, MachineMixin), f'{Model.__name__} must inherit from MachineMixin and define a .state_machine_name: str'
+        model_state_machine_name = getattr(Model, 'state_machine_name', None)
+        if not model_state_machine_name:
+            raise NotImplementedError(f'ActorType[{Model.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
+        
+        # the local must not be named StateMachine: that would shadow the imported base class
+        # and turn the check below into issubclass(X, X), which is always True
+        StateMachineCls = registry.get_machine_cls(model_state_machine_name)
+        assert issubclass(StateMachineCls, StateMachine), f'{Model.__name__}.state_machine_name = {model_state_machine_name} must point to a StateMachine subclass'
+        return StateMachineCls
+    
+    @classmethod
+    def _get_state_machine_instance(cls, obj: ModelType) -> StateMachine:
+        """
+        Get the StateMachine instance for the given django Model instance (and check that it is a valid instance of cls.StateMachineClass)
+        
+        The attribute to read is named by obj.state_machine_attr, falling back to 'sm' when obj does not define it.
+        Raises Exception if that attribute is missing, fails to load, or is not an instance of cls.StateMachineClass.
+        """
+        obj_statemachine = None
+        state_machine_attr = getattr(obj, 'state_machine_attr', 'sm')
+        try:
+            obj_statemachine = getattr(obj, state_machine_attr)
+        except Exception:
+            # any failure while reading the attribute is treated the same as it being missing,
+            # the isinstance check below then raises a single uniform error
+            pass
+        
+        if not isinstance(obj_statemachine, cls.StateMachineClass):
+            raise Exception(f'{cls.__name__}: Failed to find a valid StateMachine instance at {type(obj).__name__}.{state_machine_attr}')
+            
+        return obj_statemachine
+    
+    @classmethod
+    def _populate_missing_classvars_from_model(cls, Model: Type[ModelType]):
+        """
+        Validate the class-level config on this Actor, filling in missing values from the Model / its StateMachine.
+        
+        Sets cls.Model and cls.StateMachineClass if they are unset, derives STATE_FIELD_NAME from Model.state_field_name and
+        FINAL_STATES from StateMachineClass.final_states when they are not defined, and normalizes FINAL_STATES to a list[str].
+        Raises ValueError if cls.Model differs from the Model passed in, NotImplementedError if a required value is
+        missing and cannot be derived, and AssertionError if a value has the wrong type or is out of range.
+        """
+        
+        # check that Model is the same as the Generic[ModelType] parameter in the class definition
+        cls.Model = getattr(cls, 'Model', None) or Model
+        if cls.Model != Model:
+            raise ValueError(f'{cls.__name__}.Model must be set to the same Model as the Generic[ModelType] parameter in the class definition')
+        
+        # check that Model has a valid StateMachine with the required event defined on it
+        cls.StateMachineClass = getattr(cls, 'StateMachineClass', None) or cls._get_state_machine_cls(cls.Model)
+        assert isinstance(cls.EVENT_NAME, str), f'{cls.__name__}.EVENT_NAME must be a str, got: {type(cls.EVENT_NAME).__name__} instead'
+        assert hasattr(cls.StateMachineClass, cls.EVENT_NAME), f'StateMachine {cls.StateMachineClass.__name__} must define a {cls.EVENT_NAME} event ({cls.__name__}.EVENT_NAME = {cls.EVENT_NAME})'
+        
+        # check that Model uses .id as its primary key field (the error names the Model, since that is what owns the primary key)
+        primary_key_field = cls.Model._meta.pk.name
+        if primary_key_field != 'id':
+            raise NotImplementedError(f'Actors currently only support models that use .id as their primary key field ({cls.Model.__name__} uses {cls.Model.__name__}.{primary_key_field} as primary key)')
+        
+        # check for STATE_FIELD_NAME classvar or set it from the model's state_field_name attr
+        if not getattr(cls, 'STATE_FIELD_NAME', None):
+            if hasattr(cls.Model, 'state_field_name'):
+                cls.STATE_FIELD_NAME = getattr(cls.Model, 'state_field_name')
+            else:
+                raise NotImplementedError(f'{cls.__name__} must define a STATE_FIELD_NAME: ClassVar[str] (e.g. "status") or have a .state_field_name attr on its Model')
+        assert isinstance(cls.STATE_FIELD_NAME, str), f'{cls.__name__}.STATE_FIELD_NAME must be a str, got: {type(cls.STATE_FIELD_NAME).__name__} instead'
+        
+        # check for FINAL_STATES classvar or set it from the StateMachine's final_states
+        if not getattr(cls, 'FINAL_STATES', None):
+            cls.FINAL_STATES = cls.StateMachineClass.final_states
+            if not cls.FINAL_STATES:
+                raise NotImplementedError(f'{cls.__name__} must define a non-empty FINAL_STATES: ClassVar[list[str]] (e.g. ["sealed"]) or have a {cls.Model.__name__}.state_machine_name pointing to a StateMachine that provides .final_states')
+        cls.FINAL_STATES = [cls._state_to_str(state) for state in cls.FINAL_STATES]
+        assert all(isinstance(state, str) for state in cls.FINAL_STATES), f'{cls.__name__}.FINAL_STATES must be a list[str], got: {type(cls.FINAL_STATES).__name__} instead'
+        
+        # check for ACTIVE_STATE classvar (there is no fallback for this one, it has to be defined explicitly)
+        if not getattr(cls, 'ACTIVE_STATE', None):
+            raise NotImplementedError(f'{cls.__name__} must define an ACTIVE_STATE: ClassVar[State] (e.g. SnapshotMachine.started) ({cls.Model.__name__}.{cls.STATE_FIELD_NAME} gets set to this value to mark objects as actively processing)')
+        assert isinstance(cls.ACTIVE_STATE, (State, str)), f'{cls.__name__}.ACTIVE_STATE must be a statemachine.State | str, got: {type(cls.ACTIVE_STATE).__name__} instead'
+        
+        # check the other ClassVar attributes for valid values
+        assert cls.CLAIM_ORDER and isinstance(cls.CLAIM_ORDER, tuple) and all(isinstance(order, str) for order in cls.CLAIM_ORDER), f'{cls.__name__}.CLAIM_ORDER must be a non-empty tuple[str, ...], got: {type(cls.CLAIM_ORDER).__name__} instead'
+        assert cls.CLAIM_FROM_TOP_N > 0, f'{cls.__name__}.CLAIM_FROM_TOP_N must be a positive int, got: {cls.CLAIM_FROM_TOP_N} instead'
+        assert cls.MAX_TICK_TIME >= 1, f'{cls.__name__}.MAX_TICK_TIME must be a positive int >= 1, got: {cls.MAX_TICK_TIME} instead'
+        assert cls.MAX_CONCURRENT_ACTORS >= 1, f'{cls.__name__}.MAX_CONCURRENT_ACTORS must be a positive int >=1, got: {cls.MAX_CONCURRENT_ACTORS} instead'
+        assert isinstance(cls.CLAIM_ATOMIC, bool), f'{cls.__name__}.CLAIM_ATOMIC must be a bool, got: {cls.CLAIM_ATOMIC} instead'
+
+    @classmethod
+    def _fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
+        """Spawn a new background thread running the actor's runloop"""
+        actor = cls(mode='thread', **launch_kwargs)
+        bg_actor_thread = Thread(target=actor.runloop)
+        bg_actor_thread.start()
+        assert bg_actor_thread.native_id is not None
+        return bg_actor_thread.native_id
+    
+    @classmethod
+    def _fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
+        """Spawn a new background process running the actor's runloop"""
+        actor = cls(mode='process', **launch_kwargs)
+        bg_actor_process = Process(target=actor.runloop)
+        bg_actor_process.start()
+        assert bg_actor_process.pid is not None
+        cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
+        return bg_actor_process.pid
+    
+    
     ### Class Methods: Called by Orchestrator on ActorType class before it has been spawned
     ### Class Methods: Called by Orchestrator on ActorType class before it has been spawned
     
     
     @classmethod
     @classmethod
@@ -94,71 +256,92 @@ class ActorType(ABC, Generic[ModelType]):
         if not queue_length:                                      # queue is empty, spawn 0 actors
         if not queue_length:                                      # queue is empty, spawn 0 actors
             return []
             return []
         
         
-        actors_to_spawn: list[LaunchKwargs] = []
-        max_spawnable = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
+        # WARNING:
+        # spawning new actors processes is slow/expensive, avoid spawning many actors at once in a single orchestrator tick.
+        # limit to spawning 1 or 2 at a time per orchestrator tick, and let the next tick handle starting another couple.
+        # DONT DO THIS:
+        # if queue_length > 20:                      # queue is extremely long, spawn maximum actors at once!
+        #   num_to_spawn_this_tick = cls.MAX_CONCURRENT_ACTORS
         
         
-        # spawning new actors is expensive, avoid spawning all the actors at once. To stagger them,
-        # let the next orchestrator tick handle starting another 2 on the next tick()
-        # if queue_length > 10:                                   # queue is long, spawn as many as possible
-        #   actors_to_spawn += max_spawnable * [{}]
+        if queue_length > 10:    
+            num_to_spawn_this_tick = 2  # spawn more actors per tick if queue is long
+        else:
+            num_to_spawn_this_tick = 1  # spawn fewer actors per tick if queue is short
+        
+        num_remaining = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
+        num_to_spawn_now: int = limit(num_to_spawn_this_tick, num_remaining)
         
         
-        if queue_length > 4:                                    # queue is medium, spawn 1 or 2 actors
-            actors_to_spawn += min(2, max_spawnable) * [{**cls.launch_kwargs}]
-        else:                                                     # queue is short, spawn 1 actor
-            actors_to_spawn += min(1, max_spawnable) * [{**cls.launch_kwargs}]
-        return actors_to_spawn
+        actors_launch_kwargs: list[LaunchKwargs] = num_to_spawn_now * [{**cls.launch_kwargs}]
+        return actors_launch_kwargs
         
         
     @classmethod
     @classmethod
     def start(cls, mode: Literal['thread', 'process']='process', **launch_kwargs: LaunchKwargs) -> int:
     def start(cls, mode: Literal['thread', 'process']='process', **launch_kwargs: LaunchKwargs) -> int:
         if mode == 'thread':
         if mode == 'thread':
-            return cls.fork_actor_as_thread(**launch_kwargs)
+            return cls._fork_actor_as_thread(**launch_kwargs)
         elif mode == 'process':
         elif mode == 'process':
-            return cls.fork_actor_as_process(**launch_kwargs)
+            return cls._fork_actor_as_process(**launch_kwargs)
         raise ValueError(f'Invalid actor mode: {mode} must be "thread" or "process"')
         raise ValueError(f'Invalid actor mode: {mode} must be "thread" or "process"')
-        
-    @classmethod
-    def fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
-        """Spawn a new background thread running the actor's runloop"""
-        actor = cls(mode='thread', **launch_kwargs)
-        bg_actor_thread = Thread(target=actor.runloop)
-        bg_actor_thread.start()
-        assert bg_actor_thread.native_id is not None
-        return bg_actor_thread.native_id
     
     
-    @classmethod
-    def fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
-        """Spawn a new background process running the actor's runloop"""
-        actor = cls(mode='process', **launch_kwargs)
-        bg_actor_process = Process(target=actor.runloop)
-        bg_actor_process.start()
-        assert bg_actor_process.pid is not None
-        cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
-        return bg_actor_process.pid
+    @classproperty
+    def qs(cls) -> QuerySet[ModelType]:
+        """
+        Get the unfiltered and unsorted QuerySet of all objects that this Actor might care about.
+        Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
+        (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
+        """
+        return cls.Model.objects.all()
     
     
-    @classmethod
-    def get_model(cls) -> Type[ModelType]:
-        # wish this was a @classproperty but Generic[ModelType] return type cant be statically inferred for @classproperty
-        return cls.QUERYSET.model
+    @classproperty
+    def final_q(cls) -> Q:
+        """Get the filter for objects that are in a final state"""
+        return Q(**{f'{cls.STATE_FIELD_NAME}__in': [cls._state_to_str(s) for s in cls.FINAL_STATES]})
     
     
-    @classmethod
-    def get_queue(cls) -> QuerySet:
-        """override this to provide your queryset as the queue"""
-        # return ArchiveResult.objects.filter(status='queued', extractor__in=('pdf', 'dom', 'screenshot'))
-        return cls.QUERYSET
+    @classproperty
+    def active_q(cls) -> Q:
+        """Get the filter for objects that are actively processing right now"""
+        return Q(**{cls.STATE_FIELD_NAME: cls._state_to_str(cls.ACTIVE_STATE)})   # e.g. Q(status='started')
     
     
-    ### Instance Methods: Called by Actor after it has been spawned (i.e. forked as a thread or process)
+    @classproperty
+    def stalled_q(cls) -> Q:
+        """Get the filter for objects that are marked active but have timed out"""
+        return cls.active_q & Q(retry_at__lte=timezone.now())                     # e.g. Q(status='started') AND Q(<retry_at is in the past>)
+    
+    @classproperty
+    def future_q(cls) -> Q:
+        """Get the filter for objects that have a retry_at in the future"""
+        return Q(retry_at__gt=timezone.now())
+    
+    @classproperty
+    def pending_q(cls) -> Q:
+        """Get the filter for objects that are ready for processing."""
+        return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
+    
+    @classmethod
+    def get_queue(cls, sort: bool=True) -> QuerySet[ModelType]:
+        """
+        Get the QuerySet of objects that are ready for processing (cls.qs filtered by cls.pending_q).
+        Ordered by cls.CLAIM_ORDER when sort=True, returned without that ordering applied when sort=False.
+        e.g. qs.exclude(status__in=('sealed', 'started')).exclude(retry_at__gt=timezone.now()).order_by('retry_at')
+        (each condition is excluded on its own: an object is skipped if it is active/final OR has a retry_at in the future)
+        """
+        unsorted_qs = cls.qs.filter(cls.pending_q)
+        return unsorted_qs.order_by(*cls.CLAIM_ORDER) if sort else unsorted_qs
+
+    ### Instance Methods: Only called from within Actor instance after it has been spawned (i.e. forked as a thread or process)
     
     
     def runloop(self):
     def runloop(self):
         """The main runloop that starts running when the actor is spawned (as subprocess or thread) and exits when the queue is empty"""
         """The main runloop that starts running when the actor is spawned (as subprocess or thread) and exits when the queue is empty"""
         self.on_startup()
         self.on_startup()
+        obj_to_process: ModelType | None = None
+        last_error: BaseException | None = None
         try:
         try:
             while True:
             while True:
-                obj_to_process: ModelType | None = None
+                # Get the next object to process from the queue
                 try:
                 try:
                     obj_to_process = cast(ModelType, self.get_next(atomic=self.atomic))
                     obj_to_process = cast(ModelType, self.get_next(atomic=self.atomic))
-                except Exception:
-                    pass
+                except (ActorQueueIsEmpty, ActorObjectAlreadyClaimed) as err:
+                    last_error = err
+                    obj_to_process = None
                 
                 
+                # Handle the case where there is no next object to process
                 if obj_to_process:
                 if obj_to_process:
                     self.idle_count = 0   # reset idle count if we got an object
                     self.idle_count = 0   # reset idle count if we got an object
                 else:
                 else:
@@ -170,119 +353,127 @@ class ActorType(ABC, Generic[ModelType]):
                         time.sleep(1)
                         time.sleep(1)
                         continue
                         continue
                 
                 
+                # Process the object by triggering its StateMachine.tick() method
                 self.on_tick_start(obj_to_process)
                 self.on_tick_start(obj_to_process)
-                
-                # Process the object
                 try:
                 try:
                     self.tick(obj_to_process)
                     self.tick(obj_to_process)
                 except Exception as err:
                 except Exception as err:
-                    print(f'[red]🏃‍♂️ ERROR: {self}.tick()[/red]', err)
+                    last_error = err
+                    # print(f'[red]🏃‍♂️ {self}.tick()[/red] {obj_to_process} ERROR: [red]{type(err).__name__}: {err}[/red]')
                     db.connections.close_all()                         # always reset the db connection after an exception to clear any pending transactions
                     db.connections.close_all()                         # always reset the db connection after an exception to clear any pending transactions
                     self.on_tick_exception(obj_to_process, err)
                     self.on_tick_exception(obj_to_process, err)
                 finally:
                 finally:
                     self.on_tick_end(obj_to_process)
                     self.on_tick_end(obj_to_process)
-            
-            self.on_shutdown(err=None)
+
         except BaseException as err:
         except BaseException as err:
+            last_error = err
             if isinstance(err, KeyboardInterrupt):
             if isinstance(err, KeyboardInterrupt):
                 print()
                 print()
             else:
             else:
-                print(f'\n[red]🏃‍♂️ {self}.runloop() FATAL:[/red]', err.__class__.__name__, err)
-            self.on_shutdown(err=err)
+                print(f'\n[red]🏃‍♂️ {self}.runloop() FATAL:[/red] {type(err).__name__}: {err}')
+                print(f'    Last processed object: {obj_to_process}')
+                raise
+        finally:
+            self.on_shutdown(last_obj=obj_to_process, last_error=last_error)
+    
+    def get_update_kwargs_to_claim_obj(self) -> dict[str, Any]:
+        """
+        Build the field values that mark a pending obj_to_process as claimed (actively processing) by this Actor.
+
+        The state field is set to ACTIVE_STATE and retry_at is set to now + MAX_TICK_TIME seconds.
+        The returned kwargs are applied with: qs.filter(id=obj_to_process.id).update(**kwargs).
+        F() expressions are allowed as values if a field needs to be updated based on its current value.
+        Subclasses may override this as a normal method (instead of a classmethod) to access instance vars.
+        """
+        claim_kwargs: dict[str, Any] = {self.STATE_FIELD_NAME: self.ACTIVE_STATE}
+        claim_kwargs['retry_at'] = timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
+        return claim_kwargs
     
     
     def get_next(self, atomic: bool | None=None) -> ModelType | None:
     def get_next(self, atomic: bool | None=None) -> ModelType | None:
         """get the next object from the queue, atomically locking it if self.atomic=True"""
         """get the next object from the queue, atomically locking it if self.atomic=True"""
-        if atomic is None:
-            atomic = self.ATOMIC
-
+        # atomic=None means "use the class default": fall back to CLAIM_ATOMIC unless the caller passed True/False
+        atomic = self.CLAIM_ATOMIC if atomic is None else atomic
         if atomic:
         if atomic:
             # fetch and claim the next object from in the queue in one go atomically
             # fetch and claim the next object from in the queue in one go atomically
             obj = self.get_next_atomic()
             obj = self.get_next_atomic()
         else:
         else:
             # two-step claim: fetch the next object and lock it in a separate query
             # two-step claim: fetch the next object and lock it in a separate query
-            obj = self.get_queue().last()
-            assert obj and self.lock_next(obj), f'Unable to fetch+lock the next {self.get_model().__name__} ojbect from {self}.QUEUE'
+            obj = self.get_next_non_atomic()
         return obj
         return obj
     
     
-    def lock_next(self, obj: ModelType) -> bool:
-        """override this to implement a custom two-step (non-atomic)lock mechanism"""
-        # For example:
-        # assert obj._model.objects.filter(pk=obj.pk, status='queued').update(status='started', locked_by=self.pid)
-        # Not needed if using get_next_and_lock() to claim the object atomically
-        # print(f'[blue]🏃‍♂️ {self}.lock()[/blue]', obj.abid or obj.id)
-        return True
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        return self.CLAIM_SET
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
-    
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def get_next_atomic(self, shallow: bool=True) -> ModelType | None:
+    def get_next_non_atomic(self) -> ModelType:
         """
         """
-        claim a random object from the top n=50 objects in the queue (atomically updates status=queued->started for claimed object)
-        optimized for minimizing contention on the queue with other actors selecting from the same list
-        slightly faster than claim_any_obj() which selects randomly from the entire queue but needs to know the total count
+        Naively selects the top/first object from self.get_queue().order_by(*self.CLAIM_ORDER),
+        then claims it by running .update(status='started', retry_at=<now + MAX_TICK_TIME>).
+        
+        Do not use this method if there is more than one Actor racing to get objects from the same queue,
+        it will be slow/buggy as they'll compete to lock the same object at the same time (TOCTTOU race).
+
+        Raises ActorQueueIsEmpty when the queue has no first object, and ActorObjectAlreadyClaimed
+        when the claiming UPDATE matches no rows.
         """
         """
-        Model = self.get_model()                                     # e.g. ArchiveResult
-        table = f'{Model._meta.app_label}_{Model._meta.model_name}'  # e.g. core_archiveresult
+        # step 1: read the first candidate (nothing is locked yet, another Actor may read the same row)
+        obj = self.get_queue().first()
+        if obj is None:
+            raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
         
         
-        where_sql = self.claim_sql_where()
-        set_sql = self.claim_sql_set()
-        order_by_sql = self.claim_sql_order()
-        choose_from_top = self.claim_from_top()
+        # step 2: claim it with an UPDATE that is still filtered by the queue conditions, so it matches 0 rows
+        # if the object is no longer in the queue; the returned obj is the instance loaded in step 1,
+        # it is not reloaded after the UPDATE
+        locked = self.get_queue().filter(id=obj.id).update(**self.get_update_kwargs_to_claim_obj())
+        if not locked:
+            raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
+        return obj
         
         
-        with db.connection.cursor() as cursor:
-            # subquery gets the pool of the top 50 candidates sorted by sort and order
-            # main query selects a random one from that pool
-            cursor.execute(f"""
-                UPDATE {table} 
-                SET {set_sql}
-                WHERE {where_sql} and id = (
-                    SELECT id FROM (
-                        SELECT id FROM {table}
-                        WHERE {where_sql}
-                        ORDER BY {order_by_sql}
-                        LIMIT {choose_from_top}
-                    ) candidates
-                    ORDER BY RANDOM()
-                    LIMIT 1
-                )
-                RETURNING id;
-            """)
-            result = cursor.fetchone()
-            
-            if result is None:
-                return None           # If no rows were claimed, return None
-
-            if shallow:
-                # shallow: faster, returns potentially incomplete object instance missing some django auto-populated fields:
-                columns = [col[0] for col in cursor.description or ['id']]
-                return Model(**dict(zip(columns, result)))
+    def get_next_atomic(self) -> ModelType | None:
+        """
+        Selects the top n=50 objects from the queue and atomically claims a random one from that set.
+        This approach safely minimizes contention with other Actors trying to select from the same Queue.
 
 
-            # if not shallow do one extra query to get a more complete object instance (load it fully from scratch)
-            return Model.objects.get(id=result[0])
+        The atomic query is roughly equivalent to the following:  (all done in one SQL query to avoid a TOCTTOU race)
+            top_candidates are selected from:   qs.order_by(*CLAIM_ORDER).only('id')[:CLAIM_FROM_TOP_N]
+            a single candidate is chosen using: qs.filter(id__in=top_n_candidates).order_by('?').first()
+            the chosen obj is claimed using:    qs.filter(id=chosen_obj).update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
+
+        Raises:
+            ActorObjectAlreadyClaimed: no row was claimed but the queue still contains objects (lost the race).
+            ActorQueueIsEmpty: no row was claimed and the queue is empty.
+        """
+        # TODO: if we switch from SQLite to PostgreSQL in the future, we should change this
+        # to use SELECT FOR UPDATE instead of a subquery + ORDER BY RANDOM() LIMIT 1
+
+        # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (...) AND retry_at <= '...' ORDER BY retry_at ASC LIMIT 50
+        qs = self.get_queue()
+        select_top_candidates_sql, select_params = self._sql_for_select_top_n_candidates(qs=qs)
+        assert select_top_candidates_sql.startswith('SELECT ')
+
+        # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
+        update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=qs, update_kwargs=self.get_update_kwargs_to_claim_obj())
+        assert update_claimed_obj_sql.startswith('UPDATE ')
+        db_table = self.Model._meta.db_table  # e.g. core_archiveresult
+
+        # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
+        # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
+        # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
+        atomic_select_and_update_sql = f"""
+            {update_claimed_obj_sql} AND "{db_table}"."id" = (
+                SELECT "{db_table}"."id" FROM (
+                    {select_top_candidates_sql}
+                ) candidates
+                ORDER BY RANDOM()
+                LIMIT 1
+            )
+            RETURNING *;
+        """
+        try:
+            # params must follow placeholder order: the UPDATE ... SET/WHERE params come before the subquery params
+            return self.Model.objects.raw(atomic_select_and_update_sql, (*update_params, *select_params))[0]
+        except IndexError:
+            # RawQuerySet indexing evaluates the query into a list and indexes it, so "no row was claimed"
+            # surfaces as IndexError (not KeyError); translate it into the Actor-level exceptions that runloop() handles
+            if self.get_queue().exists():
+                raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
+            else:
+                raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
 
 
-    @abstractmethod
-    def tick(self, obj: ModelType) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
+    def tick(self, obj_to_process: ModelType) -> None:
+        """Fire the EVENT_NAME event on the object's StateMachine instance, then save the object."""
+        print(f'[blue]🏃‍♂️ {self}.tick()[/blue] {obj_to_process}')
+
+        # look up the event trigger on the object's StateMachine (e.g. machine.tick) and call it
+        machine = self._get_state_machine_instance(obj_to_process)
+        getattr(machine, self.EVENT_NAME)()
+
+        # persist whatever the triggered event changed on the object
+        obj_to_process.save()
+        
     def on_startup(self) -> None:
     def on_startup(self) -> None:
         if self.mode == 'thread':
         if self.mode == 'thread':
             self.pid = get_native_id()  # thread id
             self.pid = get_native_id()  # thread id
@@ -290,24 +481,91 @@ class ActorType(ABC, Generic[ModelType]):
         else:
         else:
             self.pid = os.getpid()      # process id
             self.pid = os.getpid()      # process id
             print(f'[green]🏃‍♂️ {self}.on_startup() STARTUP (PROCESS)[/green]')
             print(f'[green]🏃‍♂️ {self}.on_startup() STARTUP (PROCESS)[/green]')
-        # abx.pm.hook.on_actor_startup(self)
+        # abx.pm.hook.on_actor_startup(actor=self)
         
         
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
+    def on_shutdown(self, last_obj: ModelType | None=None, last_error: BaseException | None=None) -> None:
+        """Log why the Actor is shutting down: benign reasons are shown in green, any other error in red."""
+        # checked in order, first matching exception type wins
+        benign_reasons = (
+            (KeyboardInterrupt, '[green](CTRL-C)[/green]'),
+            (ActorQueueIsEmpty, '[green](queue empty)[/green]'),
+            (ActorObjectAlreadyClaimed, '[green](queue race)[/green]'),
+        )
+        if last_error is None:
+            last_error_str = '[green](CTRL-C)[/green]'
+        else:
+            for exc_type, label in benign_reasons:
+                if isinstance(last_error, exc_type):
+                    last_error_str = label
+                    break
+            else:
+                last_error_str = f'[red]{type(last_error).__name__}: {last_error}[/red]'
+
+        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53] {last_error_str}')
+        # abx.pm.hook.on_actor_shutdown(actor=self, last_obj=last_obj, last_error=last_error)
         
         
-    def on_tick_start(self, obj: ModelType) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
+    def on_tick_start(self, obj_to_process: ModelType) -> None:
+        """Hook called by runloop() just before tick(obj_to_process); logs the object about to be processed."""
+        print(f'🏃‍♂️ {self}.on_tick_start() {obj_to_process}')
+        # abx.pm.hook.on_actor_tick_start(actor=self, obj_to_process=obj_to_process)
         # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
         # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
     
     
-    def on_tick_end(self, obj: ModelType) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
+    def on_tick_end(self, obj_to_process: ModelType) -> None:
+        """Hook called from runloop()'s finally clause after tick(), whether or not tick() raised; logs the object."""
+        print(f'🏃‍♂️ {self}.on_tick_end() {obj_to_process}')
+        # abx.pm.hook.on_actor_tick_end(actor=self, obj_to_process=obj_to_process)
         # self.timer.end()
         # self.timer.end()
-        pass
     
     
-    def on_tick_exception(self, obj: ModelType, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    def on_tick_exception(self, obj_to_process: ModelType, error: Exception) -> None:
+        """Hook called by runloop() when tick() raises; logs the object being processed and the error."""
+        error_summary = f'[red]{type(error).__name__}: {error}[/red]'
+        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red] {obj_to_process}: {error_summary}')
+        # abx.pm.hook.on_actor_tick_exception(actor=self, obj_to_process=obj_to_process, error=error)
+
+
+
+def compile_sql_select(queryset: QuerySet, filter_kwargs: dict[str, Any] | None=None, order_args: tuple[str, ...]=(), limit: int | None=None) -> tuple[str, tuple[Any, ...]]:
+    """
+    Build (without executing) the SELECT statement for queryset.filter(**filter_kwargs).order_by(*order_args)[:limit].
+
+    Returns a tuple of (sql, params): sql is a template string with bare %s placeholders, params are their values.
+
+    WARNING:
+    Never assemble the final statement with sql % params: the %s placeholders are not quoted/escaped.
+    Always pass sql and params separately to the DB driver so it can do its own quoting/escaping
+    (this avoids SQL injection and syntax errors).
+    """
+    assert isinstance(queryset, QuerySet), f'compile_sql_select(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
+    assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_select(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
+    assert isinstance(order_args, tuple) and all(isinstance(arg, str) for arg in order_args), f'compile_sql_select(...) order_args argument must be a tuple[str, ...] got: {type(order_args).__name__} instead'
+    assert limit is None or isinstance(limit, int), f'compile_sql_select(...) limit argument must be an int, got: {type(limit).__name__} instead'
+
+    # work on a copy so the caller's queryset is never modified, then apply each optional refinement in turn
+    narrowed_qs = queryset._chain()                   # type: ignore
+    narrowed_qs = narrowed_qs.filter(**filter_kwargs) if filter_kwargs else narrowed_qs
+    narrowed_qs = narrowed_qs.order_by(*order_args) if order_args else narrowed_qs
+    narrowed_qs = narrowed_qs if limit is None else narrowed_qs[:limit]
+
+    # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s ORDER BY retry_at ASC LIMIT 50
+    sql_compiler = narrowed_qs.query.get_compiler(narrowed_qs.db)
+    sql_template, sql_params = sql_compiler.as_sql()
+    return sql_template, sql_params
+
+
+def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter_kwargs: dict[str, Any] | None=None) -> tuple[str, tuple[Any, ...]]:
+    """
+    Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
+    Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
+    
+    Based on the django.db.models.QuerySet.update() source code, but modified to return the SQL instead of executing the update
+    https://github.com/django/django/blob/611bf6c2e2a1b4ab93273980c45150c099ab146d/django/db/models/query.py#L1217
+    
+    WARNING:
+    final_sql = sql % params  DOES NOT WORK to assemble the final SQL string because the %s placeholders are not quoted/escaped
+    they should always be passed separately to the DB driver so it can do its own quoting/escaping to avoid SQL injection and syntax errors
+    """
+    assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
+    assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
+    assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_update(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
+    
+    queryset = queryset._chain()                      # type: ignore   # copy queryset to avoid modifying the original
+    if filter_kwargs:
+        queryset = queryset.filter(**filter_kwargs)
+    # strip ORDER BY / LIMIT from the copy first, then convert the SELECT query into an UPDATE query
+    queryset.query.clear_ordering(force=True)                          # clear any ORDER BY clauses
+    queryset.query.clear_limits()                                      # clear any LIMIT clauses aka slices[:n]
+    queryset._for_write = True                        # type: ignore
+    query = queryset.query.chain(sql.UpdateQuery)     # type: ignore
+    query.add_update_values(update_kwargs)            # type: ignore
+    query.annotations = {}                                             # clear any annotations
+    
+    # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
+    update_sql, update_params = query.get_compiler(queryset.db).as_sql()
+    return update_sql, update_params

+ 298 - 1
archivebox/actors/models.py

@@ -1,3 +1,300 @@
+from typing import ClassVar, Type, Iterable
+from datetime import datetime, timedelta
+
+from statemachine.mixins import MachineMixin
+
 from django.db import models
 from django.db import models
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from statemachine import registry, StateMachine, State
+
+from django.core import checks
+
+class DefaultStatusChoices(models.TextChoices):
+    """Default (value, label) choices for the status field: queued, started, sealed."""
+    QUEUED = 'queued', 'Queued'
+    STARTED = 'started', 'Started'
+    SEALED = 'sealed', 'Sealed'
+
+
+# Template fields: StatusField() / RetryAtField() read their kwargs via .deconstruct()[3] and build a fresh field from them
+default_status_field: models.CharField = models.CharField(choices=DefaultStatusChoices.choices, max_length=15, default=DefaultStatusChoices.QUEUED, null=False, blank=False, db_index=True)
+default_retry_at_field: models.DateTimeField = models.DateTimeField(default=timezone.now, null=False, db_index=True)
+
+# A state may be given either as a statemachine.State or as its plain string value
+ObjectState = State | str
+ObjectStateList = Iterable[ObjectState]
+
+
+class BaseModelWithStateMachine(models.Model, MachineMixin):
+    """
+    Abstract base for Django models whose lifecycle is driven by a StateMachine.
+
+    Subclasses declare which field holds the state (state_field_name), which field holds the
+    retry timestamp (retry_at_field_name), and which registered StateMachine class drives them
+    (state_machine_name). check() validates that configuration via Django's system check framework.
+    """
+    id: models.UUIDField
+
+    StatusChoices: ClassVar[Type[models.TextChoices]]
+
+    # status: models.CharField
+    # retry_at: models.DateTimeField
+
+    state_machine_name: ClassVar[str]
+    state_field_name: ClassVar[str]
+    state_machine_attr: ClassVar[str] = 'sm'
+    bind_events_as_methods: ClassVar[bool] = True
+
+    active_state: ClassVar[ObjectState]
+    retry_at_field_name: ClassVar[str]
+
+    class Meta:
+        abstract = True
+
+    @classmethod
+    def check(cls, sender=None, **kwargs):
+        """
+        Validate the state-machine configuration of the model and return a list of checks.Error.
+
+        Misconfiguration is always reported as an Error (actors.E011 - actors.E019) rather than raised:
+        the ClassVars that are only annotated on this base class are read defensively, and the checks
+        that need the StateMachine class are skipped once it has been reported as unresolvable.
+        """
+        errors = super().check(**kwargs)
+
+        # these are annotation-only on the base class, so a subclass that forgot one has no such attribute at all
+        state_field_name = getattr(cls, 'state_field_name', None)
+        retry_at_field_name = getattr(cls, 'retry_at_field_name', None)
+        state_machine_name = getattr(cls, 'state_machine_name', None)
+        status_choices = getattr(cls, 'StatusChoices', None)
+        expected_choices = getattr(status_choices, 'choices', None)
+        valid_values = list(getattr(status_choices, 'values', ()))
+
+        found_id_field = False
+        found_status_field = False
+        found_retry_at_field = False
+
+        for field in cls._meta.get_fields():
+            if getattr(field, '_is_state_field', False):
+                if state_field_name == field.name:
+                    found_status_field = True
+                    if getattr(field, 'choices', None) != expected_choices:
+                        errors.append(checks.Error(
+                            f'{cls.__name__}.{field.name} must have choices set to {cls.__name__}.StatusChoices.choices',
+                            hint=f'{cls.__name__}.{field.name}.choices = {getattr(field, "choices", None)!r}',
+                            obj=cls,
+                            id='actors.E011',
+                        ))
+            if getattr(field, '_is_retry_at_field', False):
+                if retry_at_field_name == field.name:
+                    found_retry_at_field = True
+            if field.name == 'id' and getattr(field, 'primary_key', False):
+                found_id_field = True
+
+        if not found_status_field:
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
+                hint=f'{cls.__name__}.state_field_name = {state_field_name!r} but {cls.__name__}.{state_field_name!r} was not found or does not refer to StatusField',
+                obj=cls,
+                id='actors.E012',
+            ))
+        if not found_retry_at_field:
+            errors.append(checks.Error(
+                f'{cls.__name__}.retry_at_field_name must be defined and point to a RetryAtField()',
+                hint=f'{cls.__name__}.retry_at_field_name = {retry_at_field_name!r} but {cls.__name__}.{retry_at_field_name!r} was not found or does not refer to RetryAtField',
+                obj=cls,
+                id='actors.E013',
+            ))
+
+        if not found_id_field:
+            errors.append(checks.Error(
+                f'{cls.__name__} must have an id field that is a primary key',
+                # cls.id may not exist at all in exactly the case this error reports
+                hint=f'{cls.__name__}.id = {getattr(cls, "id", None)!r}',
+                obj=cls,
+                id='actors.E014',
+            ))
+
+        if not isinstance(state_machine_name, str):
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
+                hint=f'{cls.__name__}.state_machine_name = {state_machine_name!r}',
+                obj=cls,
+                id='actors.E015',
+            ))
+
+        try:
+            cls.StateMachineClass
+        except Exception as err:
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_machine_name must point to a valid StateMachine class, but got {type(err).__name__} {err} when trying to access {cls.__name__}.StateMachineClass',
+                hint=f'{cls.__name__}.state_machine_name = {state_machine_name!r}',
+                obj=cls,
+                id='actors.E016',
+            ))
+            # every remaining check reads from the StateMachine class, so they would only re-raise the same error
+            return errors
+
+        if cls.INITIAL_STATE not in valid_values:
+            errors.append(checks.Error(
+                f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
+                hint=f'{cls.__name__}.StateMachineClass.initial_state = {cls.StateMachineClass.initial_state!r}',
+                obj=cls,
+                id='actors.E017',
+            ))
+
+        try:
+            active_state = cls.ACTIVE_STATE
+        except AttributeError:
+            active_state = None          # no active_state defined anywhere: reported as E018 below
+        if active_state not in valid_values:
+            errors.append(checks.Error(
+                f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
+                hint=f'{cls.__name__}.active_state = {getattr(cls, "active_state", None)!r}',
+                obj=cls,
+                id='actors.E018',
+            ))
+
+        for state in cls.FINAL_STATES:
+            if state not in valid_values:
+                errors.append(checks.Error(
+                    f'{cls.__name__}.StateMachineClass.final_states must all be present within {cls.__name__}.StatusChoices',
+                    hint=f'{cls.__name__}.StateMachineClass.final_states = {cls.StateMachineClass.final_states!r}',
+                    obj=cls,
+                    id='actors.E019',
+                ))
+                break
+        return errors
+
+    @staticmethod
+    def _state_to_str(state: ObjectState) -> str:
+        """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
+        return str(state.value) if isinstance(state, State) else str(state)
+
+    @property
+    def RETRY_AT(self) -> datetime:
+        """Value of the field named by retry_at_field_name."""
+        return getattr(self, self.retry_at_field_name)
+
+    @RETRY_AT.setter
+    def RETRY_AT(self, value: datetime):
+        setattr(self, self.retry_at_field_name, value)
+
+    @property
+    def STATE(self) -> str:
+        """Value of the field named by state_field_name."""
+        return getattr(self, self.state_field_name)
+
+    @STATE.setter
+    def STATE(self, value: str):
+        setattr(self, self.state_field_name, value)
+
+    def bump_retry_at(self, seconds: int = 10):
+        """Set RETRY_AT to now + seconds on the instance (does not save)."""
+        self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
+
+    @classproperty
+    def ACTIVE_STATE(cls) -> str:
+        """
+        The state value that marks an object as actively being processed, as a str.
+
+        Uses StateMachineClass.active_state when the StateMachine defines one, otherwise falls back
+        to the model's own active_state ClassVar (the attribute check() validates as actors.E018).
+        """
+        machine_active_state = getattr(cls.StateMachineClass, 'active_state', None)
+        active_state = cls.active_state if machine_active_state is None else machine_active_state
+        return cls._state_to_str(active_state)
+
+    @classproperty
+    def INITIAL_STATE(cls) -> str:
+        """StateMachineClass.initial_state as a str."""
+        return cls._state_to_str(cls.StateMachineClass.initial_state)
+
+    @classproperty
+    def FINAL_STATES(cls) -> list[str]:
+        """StateMachineClass.final_states as a list of str."""
+        return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]
+
+    @classproperty
+    def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
+        """FINAL_STATES followed by ACTIVE_STATE."""
+        return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
+
+    @classmethod
+    def extend_choices(cls, base_choices: Type[models.TextChoices]):
+        """
+        Decorator to extend the base choices with extra choices, e.g.:
+
+        class MyModel(ModelWithStateMachine):
+
+            @ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
+            class StatusChoices(models.TextChoices):
+                SUCCEEDED = 'succeeded'
+                FAILED = 'failed'
+                SKIPPED = 'skipped'
+
+        The returned TextChoices keeps every member's name, value and label; a member in the
+        extra choices replaces a base member with the same name.
+        """
+        assert issubclass(base_choices, models.TextChoices), f'@extend_choices(base_choices) must be a TextChoices class, not {base_choices.__name__}'
+        def wrapper(extra_choices: Type[models.TextChoices]) -> Type[models.TextChoices]:
+            # the functional enum API maps member NAME -> definition, so each entry must be
+            # name -> (value, label); mapping value -> label would turn the labels into the values
+            joined = {}
+            for choices_cls in (base_choices, extra_choices):
+                for member in choices_cls:
+                    joined[member.name] = (member.value, member.label)
+            return models.TextChoices('StatusChoices', joined)
+        return wrapper
+
+    @classmethod
+    def StatusField(cls, **kwargs) -> models.CharField:
+        """
+        Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
+
+        class MyModel(ModelWithStateMachine):
+            class StatusChoices(ModelWithStateMachine.StatusChoices):
+                QUEUED = 'queued', 'Queued'
+                STARTED = 'started', 'Started'
+                SEALED = 'sealed', 'Sealed'
+                BACKOFF = 'backoff', 'Backoff'
+                FAILED = 'failed', 'Failed'
+                SKIPPED = 'skipped', 'Skipped'
+
+            status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
+        """
+        default_kwargs = default_status_field.deconstruct()[3]
+        updated_kwargs = {**default_kwargs, **kwargs}
+        field = models.CharField(**updated_kwargs)
+        field._is_state_field = True                    # type: ignore   # marker that check() looks for
+        return field
+
+    @classmethod
+    def RetryAtField(cls, **kwargs) -> models.DateTimeField:
+        """
+        Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
+
+        class MyModel(ModelWithStateMachine):
+            retry_at = ModelWithStateMachine.RetryAtField(editable=False)
+        """
+        default_kwargs = default_retry_at_field.deconstruct()[3]
+        updated_kwargs = {**default_kwargs, **kwargs}
+        field = models.DateTimeField(**updated_kwargs)
+        field._is_retry_at_field = True                 # type: ignore   # marker that check() looks for
+        return field
+
+    @classproperty
+    def StateMachineClass(cls) -> Type[StateMachine]:
+        """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
+
+        model_state_machine_name = getattr(cls, 'state_machine_name', None)
+        if model_state_machine_name:
+            StateMachineCls = registry.get_machine_cls(model_state_machine_name)
+            assert issubclass(StateMachineCls, StateMachine)
+            return StateMachineCls
+        raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
+
+    # @classproperty
+    # def final_q(cls) -> Q:
+    #     """Get the filter for objects that are in a final state"""
+    #     return Q(**{f'{cls.state_field_name}__in': cls.final_states})
+
+    # @classproperty
+    # def active_q(cls) -> Q:
+    #     """Get the filter for objects that are actively processing right now"""
+    #     return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)})   # e.g. Q(status='started')
+
+    # @classproperty
+    # def stalled_q(cls) -> Q:
+    #     """Get the filter for objects that are marked active but have timed out"""
+    #     return cls.active_q & Q(retry_at__lte=timezone.now())                     # e.g. Q(status='started') AND Q(<retry_at is in the past>)
+
+    # @classproperty
+    # def future_q(cls) -> Q:
+    #     """Get the filter for objects that have a retry_at in the future"""
+    #     return Q(retry_at__gt=timezone.now())
+
+    # @classproperty
+    # def pending_q(cls) -> Q:
+    #     """Get the filter for objects that are ready for processing."""
+    #     return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
+
+    # @classmethod
+    # def get_queue(cls) -> QuerySet:
+    #     """
+    #     Get the sorted and filtered QuerySet of objects that are ready for processing.
+    #     e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
+    #     """
+    #     return cls.objects.filter(cls.pending_q)
+
 
 
-# Create your models here.
+class ModelWithStateMachine(BaseModelWithStateMachine):
+    """
+    Abstract model with the default state-machine setup: a `status` field using DefaultStatusChoices
+    and a `retry_at` field. state_machine_name is only annotated here, so subclasses must set it.
+    """
+    StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices
+    
+    status: models.CharField = BaseModelWithStateMachine.StatusField()
+    retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()
+    
+    state_machine_name: ClassVar[str]      # e.g. 'core.statemachines.ArchiveResultMachine'
+    state_field_name: ClassVar[str]        = 'status'
+    state_machine_attr: ClassVar[str]      = 'sm'
+    bind_events_as_methods: ClassVar[bool] = True
+    
+    active_state: ClassVar[str]            = StatusChoices.STARTED
+    retry_at_field_name: ClassVar[str]     = 'retry_at'
+    
+    class Meta:
+        abstract = True

+ 1 - 1
archivebox/actors/orchestrator.py

@@ -6,7 +6,7 @@ import itertools
 from typing import Dict, Type, Literal, ClassVar
 from typing import Dict, Type, Literal, ClassVar
 from django.utils.functional import classproperty
 from django.utils.functional import classproperty
 
 
-from multiprocessing import Process, cpu_count
+from multiprocessing import Process
 from threading import Thread, get_native_id
 from threading import Thread, get_native_id
 
 
 
 

+ 0 - 286
archivebox/actors/statemachine.py

@@ -1,286 +0,0 @@
-from statemachine import State, StateMachine
-from django.db import models
-from multiprocessing import Process
-import psutil
-import time
-
-# State Machine Definitions
-#################################################
-
-class SnapshotMachine(StateMachine):
-    """State machine for managing Snapshot lifecycle."""
-    
-    # States
-    queued = State(initial=True)
-    started = State()
-    sealed = State(final=True)
-    
-    # Transitions
-    start = queued.to(started, cond='can_start')
-    seal = started.to(sealed, cond='is_finished')
-    
-    # Events
-    tick = (
-        queued.to.itself(unless='can_start') |
-        queued.to(started, cond='can_start') |
-        started.to.itself(unless='is_finished') |
-        started.to(sealed, cond='is_finished')
-    )
-    
-    def __init__(self, snapshot):
-        self.snapshot = snapshot
-        super().__init__()
-        
-    def can_start(self):
-        return True
-        
-    def is_finished(self):
-        return not self.snapshot.has_pending_archiveresults()
-        
-    def before_start(self):
-        """Pre-start validation and setup."""
-        self.snapshot.cleanup_dir()
-        
-    def after_start(self):
-        """Post-start side effects."""
-        self.snapshot.create_pending_archiveresults()
-        self.snapshot.update_indices()
-        self.snapshot.bump_retry_at(seconds=10)
-        
-    def before_seal(self):
-        """Pre-seal validation and cleanup."""
-        self.snapshot.cleanup_dir()
-        
-    def after_seal(self):
-        """Post-seal actions."""
-        self.snapshot.update_indices()
-        self.snapshot.seal_dir()
-        self.snapshot.upload_dir()
-        self.snapshot.retry_at = None
-        self.snapshot.save()
-
-
-class ArchiveResultMachine(StateMachine):
-    """State machine for managing ArchiveResult lifecycle."""
-    
-    # States
-    queued = State(initial=True)
-    started = State()
-    succeeded = State(final=True)
-    backoff = State()
-    failed = State(final=True)
-    
-    # Transitions
-    start = queued.to(started, cond='can_start')
-    succeed = started.to(succeeded, cond='extractor_succeeded')
-    backoff = started.to(backoff, unless='extractor_succeeded')
-    retry = backoff.to(queued, cond='can_retry')
-    fail = backoff.to(failed, unless='can_retry')
-    
-    # Events
-    tick = (
-        queued.to.itself(unless='can_start') |
-        queued.to(started, cond='can_start') |
-        started.to.itself(cond='extractor_still_running') |
-        started.to(succeeded, cond='extractor_succeeded') |
-        started.to(backoff, unless='extractor_succeeded') |
-        backoff.to.itself(cond='still_waiting_to_retry') |
-        backoff.to(queued, cond='can_retry') |
-        backoff.to(failed, unless='can_retry')
-    )
-    
-    def __init__(self, archiveresult):
-        self.archiveresult = archiveresult
-        super().__init__()
-    
-    def can_start(self):
-        return True
-    
-    def extractor_still_running(self):
-        return self.archiveresult.start_ts > time.now() - timedelta(seconds=5)
-    
-    def extractor_succeeded(self):
-        # return check_if_extractor_succeeded(self.archiveresult)
-        return self.archiveresult.start_ts < time.now() - timedelta(seconds=5)
-    
-    def can_retry(self):
-        return self.archiveresult.retries < self.archiveresult.max_retries
-        
-    def before_start(self):
-        """Pre-start initialization."""
-        self.archiveresult.retries += 1
-        self.archiveresult.start_ts = time.now()
-        self.archiveresult.output = None
-        self.archiveresult.error = None
-        
-    def after_start(self):
-        """Post-start execution."""
-        self.archiveresult.bump_retry_at(seconds=self.archiveresult.timeout + 5)
-        execute_extractor(self.archiveresult)
-        self.archiveresult.snapshot.bump_retry_at(seconds=5)
-        
-    def before_succeed(self):
-        """Pre-success validation."""
-        self.archiveresult.output = get_archiveresult_output(self.archiveresult)
-        
-    def after_succeed(self):
-        """Post-success cleanup."""
-        self.archiveresult.end_ts = time.now()
-        self.archiveresult.retry_at = None
-        self.archiveresult.update_indices()
-        
-    def before_backoff(self):
-        """Pre-backoff error capture."""
-        self.archiveresult.error = get_archiveresult_error(self.archiveresult)
-        
-    def after_backoff(self):
-        """Post-backoff retry scheduling."""
-        self.archiveresult.end_ts = time.now()
-        self.archiveresult.bump_retry_at(
-            seconds=self.archiveresult.timeout * self.archiveresult.retries
-        )
-        self.archiveresult.update_indices()
-        
-    def before_fail(self):
-        """Pre-failure finalization."""
-        self.archiveresult.retry_at = None
-        
-    def after_fail(self):
-        """Post-failure cleanup."""
-        self.archiveresult.update_indices()
-
-# Models
-#################################################
-
-class Snapshot(models.Model):
-    status = models.CharField(max_length=32, default='queued')
-    retry_at = models.DateTimeField(null=True)
-    
-    @property
-    def sm(self):
-        """Get the state machine for this snapshot."""
-        return SnapshotMachine(self)
-    
-    def has_pending_archiveresults(self):
-        return self.archiveresult_set.exclude(
-            status__in=['succeeded', 'failed']
-        ).exists()
-    
-    def bump_retry_at(self, seconds):
-        self.retry_at = time.now() + timedelta(seconds=seconds)
-        self.save()
-        
-    def cleanup_dir(self):
-        cleanup_snapshot_dir(self)
-        
-    def create_pending_archiveresults(self):
-        create_snapshot_pending_archiveresults(self)
-        
-    def update_indices(self):
-        update_snapshot_index_json(self)
-        update_snapshot_index_html(self)
-        
-    def seal_dir(self):
-        seal_snapshot_dir(self)
-        
-    def upload_dir(self):
-        upload_snapshot_dir(self)
-
-
-class ArchiveResult(models.Model):
-    snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE)
-    status = models.CharField(max_length=32, default='queued')
-    retry_at = models.DateTimeField(null=True)
-    retries = models.IntegerField(default=0)
-    max_retries = models.IntegerField(default=3)
-    timeout = models.IntegerField(default=60)
-    start_ts = models.DateTimeField(null=True)
-    end_ts = models.DateTimeField(null=True)
-    output = models.TextField(null=True)
-    error = models.TextField(null=True)
-    
-    def get_machine(self):
-        return ArchiveResultMachine(self)
-    
-    def bump_retry_at(self, seconds):
-        self.retry_at = time.now() + timedelta(seconds=seconds)
-        self.save()
-        
-    def update_indices(self):
-        update_archiveresult_index_json(self)
-        update_archiveresult_index_html(self)
-
-
-# Actor System
-#################################################
-
-class BaseActor:
-    MAX_TICK_TIME = 60
-    
-    def tick(self, obj):
-        """Process a single object through its state machine."""
-        machine = obj.get_machine()
-        
-        if machine.is_queued:
-            if machine.can_start():
-                machine.start()
-                
-        elif machine.is_started:
-            if machine.can_seal():
-                machine.seal()
-                
-        elif machine.is_backoff:
-            if machine.can_retry():
-                machine.retry()
-            else:
-                machine.fail()
-
-
-class Orchestrator:
-    """Main orchestrator that manages all actors."""
-    
-    def __init__(self):
-        self.pid = None
-        
-    @classmethod
-    def spawn(cls):
-        orchestrator = cls()
-        proc = Process(target=orchestrator.runloop)
-        proc.start()
-        return proc.pid
-        
-    def runloop(self):
-        self.pid = os.getpid()
-        abx.pm.hook.on_orchestrator_startup(self)
-        
-        try:
-            while True:
-                self.process_queue(Snapshot)
-                self.process_queue(ArchiveResult)
-                time.sleep(0.1)
-                
-        except (KeyboardInterrupt, SystemExit):
-            abx.pm.hook.on_orchestrator_shutdown(self)
-            
-    def process_queue(self, model):
-        retry_at_reached = Q(retry_at__isnull=True) | Q(retry_at__lte=time.now())
-        queue = model.objects.filter(retry_at_reached)
-        
-        if queue.exists():
-            actor = BaseActor()
-            for obj in queue:
-                try:
-                    with transaction.atomic():
-                        actor.tick(obj)
-                except Exception as e:
-                    abx.pm.hook.on_actor_tick_exception(actor, obj, e)
-
-
-# Periodic Tasks
-#################################################
-
-@djhuey.db_periodic_task(schedule=djhuey.crontab(minute='*'))
-def ensure_orchestrator_running():
-    """Ensure orchestrator is running, start if not."""
-    if not any(p.name().startswith('Orchestrator') for p in psutil.process_iter()):
-        Orchestrator.spawn()

+ 28 - 60
archivebox/core/actors.py

@@ -2,72 +2,40 @@ __package__ = 'archivebox.core'
 
 
 from typing import ClassVar
 from typing import ClassVar
 
 
-from rich import print
-
-from django.db.models import QuerySet
-from django.utils import timezone
-from datetime import timedelta
-from core.models import Snapshot
+from statemachine import State
 
 
+from core.models import Snapshot, ArchiveResult
+from core.statemachines import SnapshotMachine, ArchiveResultMachine
 from actors.actor import ActorType
 from actors.actor import ActorType
 
 
 
 
 class SnapshotActor(ActorType[Snapshot]):
 class SnapshotActor(ActorType[Snapshot]):
+    Model = Snapshot
+    StateMachineClass = SnapshotMachine
     
     
-    QUERYSET: ClassVar[QuerySet] = Snapshot.objects.filter(status='queued')
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = 50                # the number of objects to consider when atomically getting the next object from the queue
-    
-    # model_type: Type[ModelType]
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = 4               # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        retry_at = timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
-        # format as 2024-10-31 10:14:33.240903
-        retry_at_str = retry_at.strftime('%Y-%m-%d %H:%M:%S.%f')
-        return f'{self.CLAIM_SET}, retry_at = {retry_at_str}'
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
+    ACTIVE_STATE: ClassVar[State] = SnapshotMachine.started
+    FINAL_STATES: ClassVar[list[State]] = SnapshotMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = SnapshotMachine.state_field_name
     
     
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def tick(self, obj: Snapshot) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        # raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
-        
-    def on_tick_start(self, obj: Snapshot) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
-        # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 3
+    MAX_TICK_TIME: ClassVar[int] = 10
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10
+
+
+
+class ArchiveResultActor(ActorType[ArchiveResult]):
+    Model = ArchiveResult
+    StateMachineClass = ArchiveResultMachine
     
     
-    def on_tick_end(self, obj: Snapshot) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
-        # self.timer.end()
-        pass
+    ACTIVE_STATE: ClassVar[State] = ArchiveResultMachine.started
+    FINAL_STATES: ClassVar[list[State]] = ArchiveResultMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = ArchiveResultMachine.state_field_name
     
     
-    def on_tick_exception(self, obj: Snapshot, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 6
+    MAX_TICK_TIME: ClassVar[int] = 60
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10
+
+    # @classproperty
+    # def qs(cls) -> QuerySet[ModelType]:
+    #     """Get the unfiltered and unsorted QuerySet of all objects that this Actor might care about."""
+    #     return cls.Model.objects.filter(extractor='favicon')

+ 57 - 50
archivebox/core/models.py

@@ -20,7 +20,7 @@ from django.db.models import Case, When, Value, IntegerField
 from django.contrib import admin
 from django.contrib import admin
 from django.conf import settings
 from django.conf import settings
 
 
-from statemachine.mixins import MachineMixin
+from actors.models import ModelWithStateMachine
 
 
 from archivebox.config import CONSTANTS
 from archivebox.config import CONSTANTS
 
 
@@ -156,7 +156,7 @@ class SnapshotManager(models.Manager):
         return super().get_queryset().prefetch_related('tags', 'archiveresult_set')  # .annotate(archiveresult_count=models.Count('archiveresult')).distinct()
         return super().get_queryset().prefetch_related('tags', 'archiveresult_set')  # .annotate(archiveresult_count=models.Count('archiveresult')).distinct()
 
 
 
 
-class Snapshot(ABIDModel, MachineMixin):
+class Snapshot(ABIDModel, ModelWithStateMachine):
     abid_prefix = 'snp_'
     abid_prefix = 'snp_'
     abid_ts_src = 'self.created_at'
     abid_ts_src = 'self.created_at'
     abid_uri_src = 'self.url'
     abid_uri_src = 'self.url'
@@ -164,34 +164,32 @@ class Snapshot(ABIDModel, MachineMixin):
     abid_rand_src = 'self.id'
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
     abid_drift_allowed = True
 
 
-    state_field_name = 'status'
     state_machine_name = 'core.statemachines.SnapshotMachine'
     state_machine_name = 'core.statemachines.SnapshotMachine'
-    state_machine_attr = 'sm'
+    state_field_name = 'status'
+    retry_at_field_name = 'retry_at'
+    StatusChoices = ModelWithStateMachine.StatusChoices
+    active_state = StatusChoices.STARTED
     
     
-    class SnapshotStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SEALED = 'sealed', 'Sealed'
-        
-    status = models.CharField(max_length=15, default=SnapshotStatus.QUEUED, null=False, blank=False)
-
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
     abid = ABIDField(prefix=abid_prefix)
 
 
-    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='snapshot_set')
+    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='snapshot_set', db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)  # loaded from self._init_timestamp
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)  # loaded from self._init_timestamp
     modified_at = models.DateTimeField(auto_now=True)
     modified_at = models.DateTimeField(auto_now=True)
+    
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
 
     # legacy ts fields
     # legacy ts fields
     bookmarked_at = AutoDateTimeField(default=None, null=False, editable=True, db_index=True)
     bookmarked_at = AutoDateTimeField(default=None, null=False, editable=True, db_index=True)
     downloaded_at = models.DateTimeField(default=None, null=True, editable=False, db_index=True, blank=True)
     downloaded_at = models.DateTimeField(default=None, null=True, editable=False, db_index=True, blank=True)
 
 
-    crawl = models.ForeignKey(Crawl, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name='snapshot_set')
+    crawl: Crawl = models.ForeignKey(Crawl, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name='snapshot_set', db_index=True)  # type: ignore
 
 
     url = models.URLField(unique=True, db_index=True)
     url = models.URLField(unique=True, db_index=True)
     timestamp = models.CharField(max_length=32, unique=True, db_index=True, editable=False)
     timestamp = models.CharField(max_length=32, unique=True, db_index=True, editable=False)
     tags = models.ManyToManyField(Tag, blank=True, through=SnapshotTag, related_name='snapshot_set', through_fields=('snapshot', 'tag'))
     tags = models.ManyToManyField(Tag, blank=True, through=SnapshotTag, related_name='snapshot_set', through_fields=('snapshot', 'tag'))
-    title = models.CharField(max_length=512, null=True, blank=True, db_index=True)    
+    title = models.CharField(max_length=512, null=True, blank=True, db_index=True)
 
 
     keys = ('url', 'timestamp', 'title', 'tags', 'downloaded_at')
     keys = ('url', 'timestamp', 'title', 'tags', 'downloaded_at')
 
 
@@ -210,12 +208,14 @@ class Snapshot(ABIDModel, MachineMixin):
         return result
         return result
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
-        title = (self.title_stripped or '-')[:64]
-        return f'[{self.timestamp}] {self.url[:64]} ({title})'
+        url = self.url or '<no url set>'
+        created_at = self.created_at.strftime("%Y-%m-%d %H:%M") if self.created_at else '<no timestamp set>'
+        if self.id and self.url:
+            return f'[{self.ABID}] {url[:64]} @ {created_at}'
+        return f'[{self.abid_prefix}****not*saved*yet****] {url[:64]} @ {created_at}'
 
 
     def __str__(self) -> str:
     def __str__(self) -> str:
-        title = (self.title_stripped or '-')[:64]
-        return f'[{self.timestamp}] {self.url[:64]} ({title})'
+        return repr(self)
 
 
     @classmethod
     @classmethod
     def from_json(cls, info: dict):
     def from_json(cls, info: dict):
@@ -413,8 +413,7 @@ class Snapshot(ABIDModel, MachineMixin):
         self.tags.add(*tags_id)
         self.tags.add(*tags_id)
         
         
     def has_pending_archiveresults(self) -> bool:
     def has_pending_archiveresults(self) -> bool:
-        pending_statuses = [ArchiveResult.ArchiveResultStatus.QUEUED, ArchiveResult.ArchiveResultStatus.STARTED]
-        pending_archiveresults = self.archiveresult_set.filter(status__in=pending_statuses)
+        pending_archiveresults = self.archiveresult_set.exclude(status__in=ArchiveResult.FINAL_OR_ACTIVE_STATES)
         return pending_archiveresults.exists()
         return pending_archiveresults.exists()
     
     
     def create_pending_archiveresults(self) -> list['ArchiveResult']:
     def create_pending_archiveresults(self) -> list['ArchiveResult']:
@@ -423,13 +422,10 @@ class Snapshot(ABIDModel, MachineMixin):
             archiveresult, _created = ArchiveResult.objects.get_or_create(
             archiveresult, _created = ArchiveResult.objects.get_or_create(
                 snapshot=self,
                 snapshot=self,
                 extractor=extractor,
                 extractor=extractor,
-                status=ArchiveResult.ArchiveResultStatus.QUEUED,
+                status=ArchiveResult.INITIAL_STATE,
             )
             )
             archiveresults.append(archiveresult)
             archiveresults.append(archiveresult)
         return archiveresults
         return archiveresults
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
 
 
 
 
     # def get_storage_dir(self, create=True, symlink=True) -> Path:
     # def get_storage_dir(self, create=True, symlink=True) -> Path:
@@ -479,7 +475,7 @@ class ArchiveResultManager(models.Manager):
             ).order_by('indexing_precedence')
             ).order_by('indexing_precedence')
         return qs
         return qs
 
 
-class ArchiveResult(ABIDModel):
+class ArchiveResult(ABIDModel, ModelWithStateMachine):
     abid_prefix = 'res_'
     abid_prefix = 'res_'
     abid_ts_src = 'self.snapshot.created_at'
     abid_ts_src = 'self.snapshot.created_at'
     abid_uri_src = 'self.snapshot.url'
     abid_uri_src = 'self.snapshot.url'
@@ -487,19 +483,19 @@ class ArchiveResult(ABIDModel):
     abid_rand_src = 'self.id'
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
     abid_drift_allowed = True
     
     
-    state_field_name = 'status'
-    state_machine_name = 'core.statemachines.ArchiveResultMachine'
-    state_machine_attr = 'sm'
-
-    class ArchiveResultStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SUCCEEDED = 'succeeded', 'Succeeded'
-        FAILED = 'failed', 'Failed'
-        SKIPPED = 'skipped', 'Skipped'
-        BACKOFF = 'backoff', 'Waiting to retry'
+    class StatusChoices(models.TextChoices):
+        QUEUED = 'queued', 'Queued'                     # pending, initial
+        STARTED = 'started', 'Started'                  # active
         
         
-    status = models.CharField(max_length=15, choices=ArchiveResultStatus.choices, default=ArchiveResultStatus.QUEUED, null=False, blank=False)
+        BACKOFF = 'backoff', 'Waiting to retry'         # pending
+        SUCCEEDED = 'succeeded', 'Succeeded'            # final
+        FAILED = 'failed', 'Failed'                     # final
+        SKIPPED = 'skipped', 'Skipped'                  # final
+        
+    state_machine_name = 'core.statemachines.ArchiveResultMachine'
+    retry_at_field_name = 'retry_at'
+    state_field_name = 'status'
+    active_state = StatusChoices.STARTED
 
 
     EXTRACTOR_CHOICES = (
     EXTRACTOR_CHOICES = (
         ('htmltotext', 'htmltotext'),
         ('htmltotext', 'htmltotext'),
@@ -522,19 +518,22 @@ class ArchiveResult(ABIDModel):
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
     abid = ABIDField(prefix=abid_prefix)
 
 
-    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='archiveresult_set')
+    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='archiveresult_set', db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     modified_at = models.DateTimeField(auto_now=True)
     modified_at = models.DateTimeField(auto_now=True)
+    
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
 
-    snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE, to_field='id', db_column='snapshot_id')
+    snapshot: Snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE)   # type: ignore
 
 
-    extractor = models.CharField(choices=EXTRACTOR_CHOICES, max_length=32)
-    cmd = models.JSONField()
-    pwd = models.CharField(max_length=256)
+    extractor = models.CharField(choices=EXTRACTOR_CHOICES, max_length=32, blank=False, null=False, db_index=True)
+    cmd = models.JSONField(default=None, null=True, blank=True)
+    pwd = models.CharField(max_length=256, default=None, null=True, blank=True)
     cmd_version = models.CharField(max_length=128, default=None, null=True, blank=True)
     cmd_version = models.CharField(max_length=128, default=None, null=True, blank=True)
-    output = models.CharField(max_length=1024)
-    start_ts = models.DateTimeField(db_index=True)
-    end_ts = models.DateTimeField()
+    output = models.CharField(max_length=1024, default=None, null=True, blank=True)
+    start_ts = models.DateTimeField(default=None, null=True, blank=True)
+    end_ts = models.DateTimeField(default=None, null=True, blank=True)
 
 
     # the network interface that was used to download this result
     # the network interface that was used to download this result
     # uplink = models.ForeignKey(NetworkInterface, on_delete=models.SET_NULL, null=True, blank=True, verbose_name='Network Interface Used')
     # uplink = models.ForeignKey(NetworkInterface, on_delete=models.SET_NULL, null=True, blank=True, verbose_name='Network Interface Used')
@@ -545,10 +544,17 @@ class ArchiveResult(ABIDModel):
         verbose_name = 'Archive Result'
         verbose_name = 'Archive Result'
         verbose_name_plural = 'Archive Results Log'
         verbose_name_plural = 'Archive Results Log'
 
 
+    def __repr__(self):
+        snapshot_id = getattr(self, 'snapshot_id', None)
+        url = self.snapshot.url if snapshot_id else '<no url set>'
+        created_at = self.snapshot.created_at.strftime("%Y-%m-%d %H:%M") if snapshot_id else '<no timestamp set>'
+        extractor = self.extractor or '<no extractor set>'
+        if self.id and snapshot_id:
+            return f'[{self.ABID}] {url[:64]} @ {created_at} -> {extractor}'
+        return f'[{self.abid_prefix}****not*saved*yet****] {url} @ {created_at} -> {extractor}'
 
 
     def __str__(self):
     def __str__(self):
-        # return f'[{self.abid}] 📅 {self.start_ts.strftime("%Y-%m-%d %H:%M")} 📄 {self.extractor} {self.snapshot.url}'
-        return self.extractor
+        return repr(self)
 
 
     # TODO: finish connecting machine.models
     # TODO: finish connecting machine.models
     # @cached_property
     # @cached_property
@@ -558,6 +564,10 @@ class ArchiveResult(ABIDModel):
     @cached_property
     @cached_property
     def snapshot_dir(self):
     def snapshot_dir(self):
         return Path(self.snapshot.link_dir)
         return Path(self.snapshot.link_dir)
+    
+    @cached_property
+    def url(self):
+        return self.snapshot.url
 
 
     @property
     @property
     def api_url(self) -> str:
     def api_url(self) -> str:
@@ -596,9 +606,6 @@ class ArchiveResult(ABIDModel):
 
 
     def output_exists(self) -> bool:
     def output_exists(self) -> bool:
         return os.path.exists(self.output_path())
         return os.path.exists(self.output_path())
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
         
         
     def create_output_dir(self):
     def create_output_dir(self):
         snap_dir = self.snapshot_dir
         snap_dir = self.snapshot_dir

+ 13 - 10
archivebox/core/statemachines.py

@@ -16,9 +16,9 @@ class SnapshotMachine(StateMachine, strict_states=True):
     model: Snapshot
     model: Snapshot
     
     
     # States
     # States
-    queued = State(value=Snapshot.SnapshotStatus.QUEUED, initial=True)
-    started = State(value=Snapshot.SnapshotStatus.STARTED)
-    sealed = State(value=Snapshot.SnapshotStatus.SEALED, final=True)
+    queued = State(value=Snapshot.StatusChoices.QUEUED, initial=True)
+    started = State(value=Snapshot.StatusChoices.STARTED)
+    sealed = State(value=Snapshot.StatusChoices.SEALED, final=True)
     
     
     # Tick Event
     # Tick Event
     tick = (
     tick = (
@@ -53,11 +53,11 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
     model: ArchiveResult
     model: ArchiveResult
     
     
     # States
     # States
-    queued = State(value=ArchiveResult.ArchiveResultStatus.QUEUED, initial=True)
-    started = State(value=ArchiveResult.ArchiveResultStatus.STARTED)
-    backoff = State(value=ArchiveResult.ArchiveResultStatus.BACKOFF)
-    succeeded = State(value=ArchiveResult.ArchiveResultStatus.SUCCEEDED, final=True)
-    failed = State(value=ArchiveResult.ArchiveResultStatus.FAILED, final=True)
+    queued = State(value=ArchiveResult.StatusChoices.QUEUED, initial=True)
+    started = State(value=ArchiveResult.StatusChoices.STARTED)
+    backoff = State(value=ArchiveResult.StatusChoices.BACKOFF)
+    succeeded = State(value=ArchiveResult.StatusChoices.SUCCEEDED, final=True)
+    failed = State(value=ArchiveResult.StatusChoices.FAILED, final=True)
     
     
     # Tick Event
     # Tick Event
     tick = (
     tick = (
@@ -78,7 +78,7 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
         super().__init__(archiveresult, *args, **kwargs)
         super().__init__(archiveresult, *args, **kwargs)
         
         
     def can_start(self) -> bool:
     def can_start(self) -> bool:
-        return self.archiveresult.snapshot and self.archiveresult.snapshot.is_started()
+        return self.archiveresult.snapshot and self.archiveresult.snapshot.STATE == Snapshot.active_state
     
     
     def is_succeeded(self) -> bool:
     def is_succeeded(self) -> bool:
         return self.archiveresult.output_exists()
         return self.archiveresult.output_exists()
@@ -87,7 +87,10 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
         return not self.archiveresult.output_exists()
         return not self.archiveresult.output_exists()
     
     
     def is_backoff(self) -> bool:
     def is_backoff(self) -> bool:
-        return self.archiveresult.status == ArchiveResult.ArchiveResultStatus.BACKOFF
+        return self.archiveresult.STATE == ArchiveResult.StatusChoices.BACKOFF
+    
+    def is_finished(self) -> bool:
+        return self.is_failed() or self.is_succeeded()
 
 
     def on_started(self):
     def on_started(self):
         self.archiveresult.start_ts = timezone.now()
         self.archiveresult.start_ts = timezone.now()

+ 11 - 57
archivebox/crawls/actors.py

@@ -2,68 +2,22 @@ __package__ = 'archivebox.crawls'
 
 
 from typing import ClassVar
 from typing import ClassVar
 
 
-from rich import print
-
-from django.db.models import QuerySet
-
 from crawls.models import Crawl
 from crawls.models import Crawl
+from crawls.statemachines import CrawlMachine
 
 
-from actors.actor import ActorType
+from actors.actor import ActorType, State
 
 
 
 
 class CrawlActor(ActorType[Crawl]):
 class CrawlActor(ActorType[Crawl]):
+    """The Actor that manages the lifecycle of all Crawl objects"""
     
     
-    QUERYSET: ClassVar[QuerySet] = Crawl.objects.filter(status='queued')
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = 50                # the number of objects to consider when atomically getting the next object from the queue
-    
-    # model_type: Type[ModelType]
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = 4               # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        return self.CLAIM_SET
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
-    
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def tick(self, obj: Crawl) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        # raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
-        
-    def on_tick_start(self, obj: Crawl) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
-        # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
+    Model = Crawl
+    StateMachineClass = CrawlMachine
     
     
-    def on_tick_end(self, obj: Crawl) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
-        # self.timer.end()
-        pass
+    ACTIVE_STATE: ClassVar[State] = CrawlMachine.started
+    FINAL_STATES: ClassVar[list[State]] = CrawlMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = Crawl.state_field_name
     
     
-    def on_tick_exception(self, obj: Crawl, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 3
+    MAX_TICK_TIME: ClassVar[int] = 10
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10

+ 10 - 19
archivebox/crawls/models.py

@@ -11,7 +11,7 @@ from django.conf import settings
 from django.urls import reverse_lazy
 from django.urls import reverse_lazy
 from django.utils import timezone
 from django.utils import timezone
 
 
-from statemachine.mixins import MachineMixin
+from actors.models import ModelWithStateMachine
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from core.models import Snapshot
     from core.models import Snapshot
@@ -50,7 +50,7 @@ class CrawlSchedule(ABIDModel, ModelWithHealthStats):
 
 
     
     
 
 
-class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
+class Crawl(ABIDModel, ModelWithHealthStats, ModelWithStateMachine):
     """
     """
     A single session of URLs to archive starting from a given Seed and expanding outwards. An "archiving session" so to speak.
     A single session of URLs to archive starting from a given Seed and expanding outwards. An "archiving session" so to speak.
 
 
@@ -67,17 +67,11 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     abid_rand_src = 'self.id'
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
     abid_drift_allowed = True
     
     
-    state_field_name = 'status'
     state_machine_name = 'crawls.statemachines.CrawlMachine'
     state_machine_name = 'crawls.statemachines.CrawlMachine'
-    state_machine_attr = 'sm'
-    bind_events_as_methods = True
-
-    class CrawlStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SEALED = 'sealed', 'Sealed'
-
-    status = models.CharField(choices=CrawlStatus.choices, max_length=15, default=CrawlStatus.QUEUED, null=False, blank=False)
+    retry_at_field_name = 'retry_at'
+    state_field_name = 'status'
+    StatusChoices = ModelWithStateMachine.StatusChoices
+    active_state = StatusChoices.STARTED
     
     
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
     abid = ABIDField(prefix=abid_prefix)
@@ -86,6 +80,8 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     modified_at = models.DateTimeField(auto_now=True)
     modified_at = models.DateTimeField(auto_now=True)
     
     
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
 
     seed = models.ForeignKey(Seed, on_delete=models.PROTECT, related_name='crawl_set', null=False, blank=False)
     seed = models.ForeignKey(Seed, on_delete=models.PROTECT, related_name='crawl_set', null=False, blank=False)
     max_depth = models.PositiveSmallIntegerField(default=0, validators=[MinValueValidator(0), MaxValueValidator(4)])
     max_depth = models.PositiveSmallIntegerField(default=0, validators=[MinValueValidator(0), MaxValueValidator(4)])
@@ -127,10 +123,8 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     def has_pending_archiveresults(self) -> bool:
     def has_pending_archiveresults(self) -> bool:
         from core.models import ArchiveResult
         from core.models import ArchiveResult
         
         
-        pending_statuses = [ArchiveResult.ArchiveResultStatus.QUEUED, ArchiveResult.ArchiveResultStatus.STARTED]
-        
         snapshot_ids = self.snapshot_set.values_list('id', flat=True)
         snapshot_ids = self.snapshot_set.values_list('id', flat=True)
-        pending_archiveresults = ArchiveResult.objects.filter(snapshot_id__in=snapshot_ids, status__in=pending_statuses)
+        pending_archiveresults = ArchiveResult.objects.filter(snapshot_id__in=snapshot_ids).exclude(status__in=ArchiveResult.FINAL_OR_ACTIVE_STATES)
         return pending_archiveresults.exists()
         return pending_archiveresults.exists()
     
     
     def create_root_snapshot(self) -> 'Snapshot':
     def create_root_snapshot(self) -> 'Snapshot':
@@ -139,12 +133,9 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
         root_snapshot, _ = Snapshot.objects.get_or_create(
         root_snapshot, _ = Snapshot.objects.get_or_create(
             crawl=self,
             crawl=self,
             url=self.seed.uri,
             url=self.seed.uri,
+            status=Snapshot.INITIAL_STATE,
         )
         )
         return root_snapshot
         return root_snapshot
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
-        self.save()
 
 
 
 
 class Outlink(models.Model):
 class Outlink(models.Model):

+ 3 - 3
archivebox/crawls/statemachines.py

@@ -14,9 +14,9 @@ class CrawlMachine(StateMachine, strict_states=True):
     model: Crawl
     model: Crawl
     
     
     # States
     # States
-    queued = State(value=Crawl.CrawlStatus.QUEUED, initial=True)
-    started = State(value=Crawl.CrawlStatus.STARTED)
-    sealed = State(value=Crawl.CrawlStatus.SEALED, final=True)
+    queued = State(value=Crawl.StatusChoices.QUEUED, initial=True)
+    started = State(value=Crawl.StatusChoices.STARTED)
+    sealed = State(value=Crawl.StatusChoices.SEALED, final=True)
     
     
     # Tick Event
     # Tick Event
     tick = (
     tick = (