Browse Source

more StateMachine, Actor, and Orchestrator improvements

Nick Sweeting 1 year ago
parent
commit
a9a3b153b1

+ 441 - 183
archivebox/actors/actor.py

@@ -2,78 +2,240 @@ __package__ = 'archivebox.actors'
 
 import os
 import time
-from abc import ABC, abstractmethod
-from typing import ClassVar, Generic, TypeVar, Any, cast, Literal, Type
-from django.utils.functional import classproperty
+from typing import ClassVar, Generic, TypeVar, Any, Literal, Type, Iterable, cast, get_args
+from datetime import timedelta
+from multiprocessing import Process, cpu_count
+from threading import Thread, get_native_id
 
-from rich import print
 import psutil
+from rich import print
+from statemachine import State, StateMachine, registry
+from statemachine.mixins import MachineMixin
 
 from django import db
-from django.db import models
-from django.db.models import QuerySet
-from multiprocessing import Process, cpu_count
-from threading import Thread, get_native_id
+from django.db.models import QuerySet, sql, Q
+from django.db.models import Model as DjangoModel
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from .models import ModelWithStateMachine
 
 # from archivebox.logging_util import TimedProgress
 
+class ActorObjectAlreadyClaimed(Exception):
+    """Raised when the Actor tries to claim the next object from the queue but it's already been claimed by another Actor"""
+    pass
+
+class ActorQueueIsEmpty(Exception):
+    """Raised when the Actor tries to get the next object from the queue but it's empty"""
+    pass
+
+CPU_COUNT = cpu_count()
+DEFAULT_MAX_TICK_TIME = 60
+DEFAULT_MAX_CONCURRENT_ACTORS = min(max(2, int(CPU_COUNT * 0.6)), 8)   # 2 < 60% * num available cpu cores < 8
+
+limit = lambda n, max: min(n, max)
+
 LaunchKwargs = dict[str, Any]
+ObjectState = State | str
+ObjectStateList = Iterable[ObjectState]
 
-ModelType = TypeVar('ModelType', bound=models.Model)
+ModelType = TypeVar('ModelType', bound=ModelWithStateMachine)
 
-class ActorType(ABC, Generic[ModelType]):
+class ActorType(Generic[ModelType]):
     """
     Base class for all actors. Usage:
-    class FaviconActor(ActorType[ArchiveResult]):
-        QUERYSET: ClassVar[QuerySet] = ArchiveResult.objects.filter(status='queued', extractor='favicon')
-        CLAIM_WHERE: ClassVar[str] = 'status = "queued" AND extractor = "favicon"'
-        CLAIM_ORDER: ClassVar[str] = 'created_at DESC'
-        ATOMIC: ClassVar[bool] = True
-
-        def claim_sql_set(self, obj: ArchiveResult) -> str:
-            # SQL fields to update atomically while claiming an object from the queue
-            retry_at = datetime.now() + timedelta(seconds=self.MAX_TICK_TIME)
-            return f"status = 'started', locked_by = {self.pid}, retry_at = {retry_at}"
-
-        def tick(self, obj: ArchiveResult) -> None:
-            run_favicon_extractor(obj)
-            ArchiveResult.objects.filter(pk=obj.pk, status='started').update(status='success')
+    
+    class FaviconActor(ActorType[FaviconArchiveResult]):
+        FINAL_STATES: ClassVar[tuple[str, ...]] = ('succeeded', 'failed', 'skipped')
+        ACTIVE_STATE: ClassVar[str] = 'started'
+        
+        @classmethod
+        def qs(cls) -> QuerySet[FaviconArchiveResult]:
+            return ArchiveResult.objects.filter(extractor='favicon')   # or leave the default: FaviconArchiveResult.objects.all()
     """
+    
+    ### Class attributes (defined on the class at compile-time when ActorType[MyModel] is defined)
+    Model: Type[ModelType]
+    StateMachineClass: Type[StateMachine]
+    
+    STATE_FIELD_NAME: ClassVar[str]
+    ACTIVE_STATE: ClassVar[ObjectState]
+    FINAL_STATES: ClassVar[ObjectStateList]
+    EVENT_NAME: ClassVar[str] = 'tick'                                    # the event name to trigger on the obj.sm: StateMachine (usually 'tick')
+    
+    CLAIM_ORDER: ClassVar[tuple[str, ...]] = ('retry_at',)                # the .order(*args) to claim the queue objects in, use ('?',) for random order
+    CLAIM_FROM_TOP_N: ClassVar[int] = CPU_COUNT * 10                      # the number of objects to consider when atomically getting the next object from the queue
+    CLAIM_ATOMIC: ClassVar[bool] = True                                   # whether to atomically fetch+claim the next object in one query, or fetch and lock it in two queries
+    
+    MAX_TICK_TIME: ClassVar[int] = DEFAULT_MAX_TICK_TIME                  # maximum duration in seconds to process a single object
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = DEFAULT_MAX_CONCURRENT_ACTORS  # maximum number of concurrent actors that can be running at once
+    
+    _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []      # used to record all the pids of Actors spawned on the class
+    
+    ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
     pid: int
     idle_count: int = 0
     launch_kwargs: LaunchKwargs = {}
     mode: Literal['thread', 'process'] = 'process'
     
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = min(max(2, int(cpu_count() * 0.6)), 8)   # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    QUERYSET: ClassVar[QuerySet]                      # the QuerySet to claim objects from
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10  # the number of objects to consider when atomically getting the next object from the queue
-    ATOMIC: ClassVar[bool] = True                     # whether to atomically fetch+claim the nextobject in one step, or fetch and lock it in two steps
-    
-    # model_type: Type[ModelType]
-    
-    _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []   # record all the pids of Actors spawned by this class
+    def __init_subclass__(cls) -> None:
+        """
+        Executed at class definition time (i.e. during import of any file containing class MyActor(ActorType[MyModel]): ...).
+        Loads the django Model from the Generic[ModelType] TypeVar arg and populates any missing class-level config using it.
+        """
+        if getattr(cls, 'Model', None) is None:
+            cls.Model = cls._get_model_from_generic_typevar()
+        cls._populate_missing_classvars_from_model(cls.Model)
     
     def __init__(self, mode: Literal['thread', 'process']|None=None, **launch_kwargs: LaunchKwargs):
+        """
+        Executed right before the Actor is spawned to create a unique Actor instance for that thread/process.
+        actor_instance.runloop() is then executed from inside the newly spawned thread/process.
+        """
         self.mode = mode or self.mode
         self.launch_kwargs = launch_kwargs or dict(self.launch_kwargs)
     
+
+    ### Private Helper Methods: Not desiged to be overridden by subclasses or called by anything outside of this class
+    
     @classproperty
     def name(cls) -> str:
         return cls.__name__  # type: ignore
     
     def __str__(self) -> str:
-        return self.__repr__()
+        return repr(self)
     
     def __repr__(self) -> str:
-        """FaviconActor[pid=1234]"""
+        """-> FaviconActor[pid=1234]"""
         label = 'pid' if self.mode == 'process' else 'tid'
         return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
     
+    @staticmethod
+    def _state_to_str(state: ObjectState) -> str:
+        """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
+        return str(state.value) if isinstance(state, State) else str(state)
+    
+    @staticmethod
+    def _sql_for_select_top_n_candidates(qs: QuerySet, claim_from_top_n: int=CLAIM_FROM_TOP_N) -> tuple[str, tuple[Any, ...]]:
+        """Get the SQL for selecting the top N candidates from the queue (to claim one from)"""
+        queryset = qs.only('id')[:claim_from_top_n]
+        select_sql, select_params = compile_sql_select(queryset)
+        return select_sql, select_params
+    
+    @staticmethod
+    def _sql_for_update_claimed_obj(qs: QuerySet, update_kwargs: dict[str, Any]) -> tuple[str, tuple[Any, ...]]:
+        """Get the SQL for updating a claimed object to mark it as ACTIVE"""
+        # qs.update(status='started', retry_at=<now + MAX_TICK_TIME>)
+        update_sql, update_params = compile_sql_update(qs, update_kwargs=update_kwargs)
+        # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN ('succeeded', 'failed', 'sealed', 'started') AND retry_at <= '2024-11-04 10:14:33.240903'
+        return update_sql, update_params
+    
+    @classmethod
+    def _get_model_from_generic_typevar(cls) -> Type[ModelType]:
+        """Get the django Model from the Generic[ModelType] TypeVar arg (and check that it inherits from django.db.models.Model)"""
+        # cls.__orig_bases__ is non-standard and may be removed in the future! if this breaks,
+        # we can just require the inerited class to define the Model as a classvar manually, e.g.:
+        #     class SnapshotActor(ActorType[Snapshot]):
+        #         Model: ClassVar[Type[Snapshot]] = Snapshot
+        # https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance
+        Model = get_args(cls.__orig_bases__[0])[0]   # type: ignore
+        assert issubclass(Model, DjangoModel), f'{cls.__name__}.Model must be a valid django Model'
+        return cast(Type[ModelType], Model)
+    
+    @staticmethod
+    def _get_state_machine_cls(Model: Type[ModelType]) -> Type[StateMachine]:
+        """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
+        assert issubclass(Model, MachineMixin), f'{Model.__name__} must inherit from MachineMixin and define a .state_machine_name: str'
+        model_state_machine_name = getattr(Model, 'state_machine_name', None)
+        if model_state_machine_name:
+            StateMachine = registry.get_machine_cls(model_state_machine_name)
+            assert issubclass(StateMachine, StateMachine)
+            return StateMachine
+        raise NotImplementedError(f'ActorType[{Model.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
+    
+    @classmethod
+    def _get_state_machine_instance(cls, obj: ModelType) -> StateMachine:
+        """Get the StateMachine instance for the given django Model instance (and check that it is a valid instance of cls.StateMachineClass)"""
+        obj_statemachine = None
+        state_machine_attr = getattr(obj, 'state_machine_attr', 'sm')
+        try:
+            obj_statemachine = getattr(obj, state_machine_attr)
+        except Exception:
+            pass
+        
+        if not isinstance(obj_statemachine, cls.StateMachineClass):
+            raise Exception(f'{cls.__name__}: Failed to find a valid StateMachine instance at {type(obj).__name__}.{state_machine_attr}')
+            
+        return obj_statemachine
+    
+    @classmethod
+    def _populate_missing_classvars_from_model(cls, Model: Type[ModelType]):
+        """Check that the class variables are set correctly based on the ModelType"""
+        
+        # check that Model is the same as the Generic[ModelType] parameter in the class definition
+        cls.Model = getattr(cls, 'Model', None) or Model
+        if cls.Model != Model:
+            raise ValueError(f'{cls.__name__}.Model must be set to the same Model as the Generic[ModelType] parameter in the class definition')
+        
+        # check that Model has a valid StateMachine with the required event defined on it
+        cls.StateMachineClass = getattr(cls, 'StateMachineClass', None) or cls._get_state_machine_cls(cls.Model)
+        assert isinstance(cls.EVENT_NAME, str), f'{cls.__name__}.EVENT_NAME must be a str, got: {type(cls.EVENT_NAME).__name__} instead'
+        assert hasattr(cls.StateMachineClass, cls.EVENT_NAME), f'StateMachine {cls.StateMachineClass.__name__} must define a {cls.EVENT_NAME} event ({cls.__name__}.EVENT_NAME = {cls.EVENT_NAME})'
+        
+        # check that Model uses .id as its primary key field
+        primary_key_field = cls.Model._meta.pk.name
+        if primary_key_field != 'id':
+            raise NotImplementedError(f'Actors currently only support models that use .id as their primary key field ({cls.__name__} uses {cls.__name__}.{primary_key_field} as primary key)')
+        
+        # check for STATE_FIELD_NAME classvar or set it from the model's state_field_name attr
+        if not getattr(cls, 'STATE_FIELD_NAME', None):
+            if hasattr(cls.Model, 'state_field_name'):
+                cls.STATE_FIELD_NAME = getattr(cls.Model, 'state_field_name')
+            else:
+                raise NotImplementedError(f'{cls.__name__} must define a STATE_FIELD_NAME: ClassVar[str] (e.g. "status") or have a .state_field_name attr on its Model')
+        assert isinstance(cls.STATE_FIELD_NAME, str), f'{cls.__name__}.STATE_FIELD_NAME must be a str, got: {type(cls.STATE_FIELD_NAME).__name__} instead'
+            
+        # check for FINAL_STATES classvar or set it from the model's final_states attr
+        if not getattr(cls, 'FINAL_STATES', None):
+            cls.FINAL_STATES = cls.StateMachineClass.final_states
+            if not cls.FINAL_STATES:
+                raise NotImplementedError(f'{cls.__name__} must define a non-empty FINAL_STATES: ClassVar[list[str]] (e.g. ["sealed"]) or have a {cls.Model.__name__}.state_machine_name pointing to a StateMachine that provides .final_states')
+        cls.FINAL_STATES = [cls._state_to_str(state) for state in cls.FINAL_STATES]
+        assert all(isinstance(state, str) for state in cls.FINAL_STATES), f'{cls.__name__}.FINAL_STATES must be a list[str], got: {type(cls.FINAL_STATES).__name__} instead'
+        
+        # check for ACTIVE_STATE classvar or set it from the model's active_state attr
+        if not getattr(cls, 'ACTIVE_STATE', None):
+            raise NotImplementedError(f'{cls.__name__} must define an ACTIVE_STATE: ClassVar[State] (e.g. SnapshotMachine.started) ({cls.Model.__name__}.{cls.STATE_FIELD_NAME} gets set to this value to mark objects as actively processing)')
+        assert isinstance(cls.ACTIVE_STATE, (State, str)), f'{cls.__name__}.ACTIVE_STATE must be a statemachine.State | str, got: {type(cls.ACTIVE_STATE).__name__} instead'
+        
+        # check the other ClassVar attributes for valid values
+        assert cls.CLAIM_ORDER and isinstance(cls.CLAIM_ORDER, tuple) and all(isinstance(order, str) for order in cls.CLAIM_ORDER), f'{cls.__name__}.CLAIM_ORDER must be a non-empty tuple[str, ...], got: {type(cls.CLAIM_ORDER).__name__} instead'
+        assert cls.CLAIM_FROM_TOP_N > 0, f'{cls.__name__}.CLAIM_FROM_TOP_N must be a positive int, got: {cls.CLAIM_FROM_TOP_N} instead'
+        assert cls.MAX_TICK_TIME >= 1, f'{cls.__name__}.MAX_TICK_TIME must be a positive int > 1, got: {cls.MAX_TICK_TIME} instead'
+        assert cls.MAX_CONCURRENT_ACTORS >= 1, f'{cls.__name__}.MAX_CONCURRENT_ACTORS must be a positive int >=1, got: {cls.MAX_CONCURRENT_ACTORS} instead'
+        assert isinstance(cls.CLAIM_ATOMIC, bool), f'{cls.__name__}.CLAIM_ATOMIC must be a bool, got: {cls.CLAIM_ATOMIC} instead'
+
+    @classmethod
+    def _fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
+        """Spawn a new background thread running the actor's runloop"""
+        actor = cls(mode='thread', **launch_kwargs)
+        bg_actor_thread = Thread(target=actor.runloop)
+        bg_actor_thread.start()
+        assert bg_actor_thread.native_id is not None
+        return bg_actor_thread.native_id
+    
+    @classmethod
+    def _fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
+        """Spawn a new background process running the actor's runloop"""
+        actor = cls(mode='process', **launch_kwargs)
+        bg_actor_process = Process(target=actor.runloop)
+        bg_actor_process.start()
+        assert bg_actor_process.pid is not None
+        cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
+        return bg_actor_process.pid
+    
+    
     ### Class Methods: Called by Orchestrator on ActorType class before it has been spawned
     
     @classmethod
@@ -94,71 +256,92 @@ class ActorType(ABC, Generic[ModelType]):
         if not queue_length:                                      # queue is empty, spawn 0 actors
             return []
         
-        actors_to_spawn: list[LaunchKwargs] = []
-        max_spawnable = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
+        # WARNING:
+        # spawning new actors processes is slow/expensive, avoid spawning many actors at once in a single orchestrator tick.
+        # limit to spawning 1 or 2 at a time per orchestrator tick, and let the next tick handle starting another couple.
+        # DONT DO THIS:
+        # if queue_length > 20:                      # queue is extremely long, spawn maximum actors at once!
+        #   num_to_spawn_this_tick = cls.MAX_CONCURRENT_ACTORS
         
-        # spawning new actors is expensive, avoid spawning all the actors at once. To stagger them,
-        # let the next orchestrator tick handle starting another 2 on the next tick()
-        # if queue_length > 10:                                   # queue is long, spawn as many as possible
-        #   actors_to_spawn += max_spawnable * [{}]
+        if queue_length > 10:    
+            num_to_spawn_this_tick = 2  # spawn more actors per tick if queue is long
+        else:
+            num_to_spawn_this_tick = 1  # spawn fewer actors per tick if queue is short
+        
+        num_remaining = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
+        num_to_spawn_now: int = limit(num_to_spawn_this_tick, num_remaining)
         
-        if queue_length > 4:                                    # queue is medium, spawn 1 or 2 actors
-            actors_to_spawn += min(2, max_spawnable) * [{**cls.launch_kwargs}]
-        else:                                                     # queue is short, spawn 1 actor
-            actors_to_spawn += min(1, max_spawnable) * [{**cls.launch_kwargs}]
-        return actors_to_spawn
+        actors_launch_kwargs: list[LaunchKwargs] = num_to_spawn_now * [{**cls.launch_kwargs}]
+        return actors_launch_kwargs
         
     @classmethod
     def start(cls, mode: Literal['thread', 'process']='process', **launch_kwargs: LaunchKwargs) -> int:
         if mode == 'thread':
-            return cls.fork_actor_as_thread(**launch_kwargs)
+            return cls._fork_actor_as_thread(**launch_kwargs)
         elif mode == 'process':
-            return cls.fork_actor_as_process(**launch_kwargs)
+            return cls._fork_actor_as_process(**launch_kwargs)
         raise ValueError(f'Invalid actor mode: {mode} must be "thread" or "process"')
-        
-    @classmethod
-    def fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
-        """Spawn a new background thread running the actor's runloop"""
-        actor = cls(mode='thread', **launch_kwargs)
-        bg_actor_thread = Thread(target=actor.runloop)
-        bg_actor_thread.start()
-        assert bg_actor_thread.native_id is not None
-        return bg_actor_thread.native_id
     
-    @classmethod
-    def fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
-        """Spawn a new background process running the actor's runloop"""
-        actor = cls(mode='process', **launch_kwargs)
-        bg_actor_process = Process(target=actor.runloop)
-        bg_actor_process.start()
-        assert bg_actor_process.pid is not None
-        cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
-        return bg_actor_process.pid
+    @classproperty
+    def qs(cls) -> QuerySet[ModelType]:
+        """
+        Get the unfiltered and unsorted QuerySet of all objects that this Actor might care about.
+        Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
+        (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
+        """
+        return cls.Model.objects.all()
     
-    @classmethod
-    def get_model(cls) -> Type[ModelType]:
-        # wish this was a @classproperty but Generic[ModelType] return type cant be statically inferred for @classproperty
-        return cls.QUERYSET.model
+    @classproperty
+    def final_q(cls) -> Q:
+        """Get the filter for objects that are in a final state"""
+        return Q(**{f'{cls.STATE_FIELD_NAME}__in': [cls._state_to_str(s) for s in cls.FINAL_STATES]})
     
-    @classmethod
-    def get_queue(cls) -> QuerySet:
-        """override this to provide your queryset as the queue"""
-        # return ArchiveResult.objects.filter(status='queued', extractor__in=('pdf', 'dom', 'screenshot'))
-        return cls.QUERYSET
+    @classproperty
+    def active_q(cls) -> Q:
+        """Get the filter for objects that are actively processing right now"""
+        return Q(**{cls.STATE_FIELD_NAME: cls._state_to_str(cls.ACTIVE_STATE)})   # e.g. Q(status='started')
     
-    ### Instance Methods: Called by Actor after it has been spawned (i.e. forked as a thread or process)
+    @classproperty
+    def stalled_q(cls) -> Q:
+        """Get the filter for objects that are marked active but have timed out"""
+        return cls.active_q & Q(retry_at__lte=timezone.now())                     # e.g. Q(status='started') AND Q(<retry_at is in the past>)
+    
+    @classproperty
+    def future_q(cls) -> Q:
+        """Get the filter for objects that have a retry_at in the future"""
+        return Q(retry_at__gt=timezone.now())
+    
+    @classproperty
+    def pending_q(cls) -> Q:
+        """Get the filter for objects that are ready for processing."""
+        return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
+    
+    @classmethod
+    def get_queue(cls, sort: bool=True) -> QuerySet[ModelType]:
+        """
+        Get the sorted and filtered QuerySet of objects that are ready for processing.
+        e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
+        """
+        unsorted_qs = cls.qs.filter(cls.pending_q)
+        return unsorted_qs.order_by(*cls.CLAIM_ORDER) if sort else unsorted_qs
+
+    ### Instance Methods: Only called from within Actor instance after it has been spawned (i.e. forked as a thread or process)
     
     def runloop(self):
         """The main runloop that starts running when the actor is spawned (as subprocess or thread) and exits when the queue is empty"""
         self.on_startup()
+        obj_to_process: ModelType | None = None
+        last_error: BaseException | None = None
         try:
             while True:
-                obj_to_process: ModelType | None = None
+                # Get the next object to process from the queue
                 try:
                     obj_to_process = cast(ModelType, self.get_next(atomic=self.atomic))
-                except Exception:
-                    pass
+                except (ActorQueueIsEmpty, ActorObjectAlreadyClaimed) as err:
+                    last_error = err
+                    obj_to_process = None
                 
+                # Handle the case where there is no next object to process
                 if obj_to_process:
                     self.idle_count = 0   # reset idle count if we got an object
                 else:
@@ -170,119 +353,127 @@ class ActorType(ABC, Generic[ModelType]):
                         time.sleep(1)
                         continue
                 
+                # Process the object by triggering its StateMachine.tick() method
                 self.on_tick_start(obj_to_process)
-                
-                # Process the object
                 try:
                     self.tick(obj_to_process)
                 except Exception as err:
-                    print(f'[red]🏃‍♂️ ERROR: {self}.tick()[/red]', err)
+                    last_error = err
+                    # print(f'[red]🏃‍♂️ {self}.tick()[/red] {obj_to_process} ERROR: [red]{type(err).__name__}: {err}[/red]')
                     db.connections.close_all()                         # always reset the db connection after an exception to clear any pending transactions
                     self.on_tick_exception(obj_to_process, err)
                 finally:
                     self.on_tick_end(obj_to_process)
-            
-            self.on_shutdown(err=None)
+
         except BaseException as err:
+            last_error = err
             if isinstance(err, KeyboardInterrupt):
                 print()
             else:
-                print(f'\n[red]🏃‍♂️ {self}.runloop() FATAL:[/red]', err.__class__.__name__, err)
-            self.on_shutdown(err=err)
+                print(f'\n[red]🏃‍♂️ {self}.runloop() FATAL:[/red] {type(err).__name__}: {err}')
+                print(f'    Last processed object: {obj_to_process}')
+                raise
+        finally:
+            self.on_shutdown(last_obj=obj_to_process, last_error=last_error)
+    
+    def get_update_kwargs_to_claim_obj(self) -> dict[str, Any]:
+        """
+        Get the field values needed to mark an pending obj_to_process as being actively processing (aka claimed)
+        by the current Actor. returned kwargs will be applied using: qs.filter(id=obj_to_process.id).update(**kwargs).
+        F() expressions are allowed in field values if you need to update a field based on its current value.
+        Can be a defined as a normal method (instead of classmethod) on subclasses if it needs to access instance vars.
+        """
+        return {
+            self.STATE_FIELD_NAME: self.ACTIVE_STATE,
+            'retry_at': timezone.now() + timedelta(seconds=self.MAX_TICK_TIME),
+        }
     
     def get_next(self, atomic: bool | None=None) -> ModelType | None:
         """get the next object from the queue, atomically locking it if self.atomic=True"""
-        if atomic is None:
-            atomic = self.ATOMIC
-
+        atomic = self.CLAIM_ATOMIC if atomic is None else atomic
         if atomic:
             # fetch and claim the next object from in the queue in one go atomically
             obj = self.get_next_atomic()
         else:
             # two-step claim: fetch the next object and lock it in a separate query
-            obj = self.get_queue().last()
-            assert obj and self.lock_next(obj), f'Unable to fetch+lock the next {self.get_model().__name__} ojbect from {self}.QUEUE'
+            obj = self.get_next_non_atomic()
         return obj
     
-    def lock_next(self, obj: ModelType) -> bool:
-        """override this to implement a custom two-step (non-atomic)lock mechanism"""
-        # For example:
-        # assert obj._model.objects.filter(pk=obj.pk, status='queued').update(status='started', locked_by=self.pid)
-        # Not needed if using get_next_and_lock() to claim the object atomically
-        # print(f'[blue]🏃‍♂️ {self}.lock()[/blue]', obj.abid or obj.id)
-        return True
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        return self.CLAIM_SET
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
-    
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def get_next_atomic(self, shallow: bool=True) -> ModelType | None:
+    def get_next_non_atomic(self) -> ModelType:
         """
-        claim a random object from the top n=50 objects in the queue (atomically updates status=queued->started for claimed object)
-        optimized for minimizing contention on the queue with other actors selecting from the same list
-        slightly faster than claim_any_obj() which selects randomly from the entire queue but needs to know the total count
+        Naiively selects the top/first object from self.get_queue().order_by(*self.CLAIM_ORDER),
+        then claims it by running .update(status='started', retry_at=<now + MAX_TICK_TIME>).
+        
+        Do not use this method if there is more than one Actor racing to get objects from the same queue,
+        it will be slow/buggy as they'll compete to lock the same object at the same time (TOCTTOU race).
         """
-        Model = self.get_model()                                     # e.g. ArchiveResult
-        table = f'{Model._meta.app_label}_{Model._meta.model_name}'  # e.g. core_archiveresult
+        obj = self.get_queue().first()
+        if obj is None:
+            raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
         
-        where_sql = self.claim_sql_where()
-        set_sql = self.claim_sql_set()
-        order_by_sql = self.claim_sql_order()
-        choose_from_top = self.claim_from_top()
+        locked = self.get_queue().filter(id=obj.id).update(**self.get_update_kwargs_to_claim_obj())
+        if not locked:
+            raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
+        return obj
         
-        with db.connection.cursor() as cursor:
-            # subquery gets the pool of the top 50 candidates sorted by sort and order
-            # main query selects a random one from that pool
-            cursor.execute(f"""
-                UPDATE {table} 
-                SET {set_sql}
-                WHERE {where_sql} and id = (
-                    SELECT id FROM (
-                        SELECT id FROM {table}
-                        WHERE {where_sql}
-                        ORDER BY {order_by_sql}
-                        LIMIT {choose_from_top}
-                    ) candidates
-                    ORDER BY RANDOM()
-                    LIMIT 1
-                )
-                RETURNING id;
-            """)
-            result = cursor.fetchone()
-            
-            if result is None:
-                return None           # If no rows were claimed, return None
-
-            if shallow:
-                # shallow: faster, returns potentially incomplete object instance missing some django auto-populated fields:
-                columns = [col[0] for col in cursor.description or ['id']]
-                return Model(**dict(zip(columns, result)))
+    def get_next_atomic(self) -> ModelType | None:
+        """
+        Selects the top n=50 objects from the queue and atomically claims a random one from that set.
+        This approach safely minimizes contention with other Actors trying to select from the same Queue.
 
-            # if not shallow do one extra query to get a more complete object instance (load it fully from scratch)
-            return Model.objects.get(id=result[0])
+        The atomic query is roughly equivalent to the following:  (all done in one SQL query to avoid a TOCTTOU race)
+            top_candidates are selected from:   qs.order_by(*CLAIM_ORDER).only('id')[:CLAIM_FROM_TOP_N]
+            a single candidate is chosen using: qs.filter(id__in=top_n_candidates).order_by('?').first()
+            the chosen obj is claimed using:    qs.filter(id=chosen_obj).update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
+        """
+        # TODO: if we switch from SQLite to PostgreSQL in the future, we should change this
+        # to use SELECT FOR UPDATE instead of a subquery + ORDER BY RANDOM() LIMIT 1
+        
+        # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (...) AND retry_at <= '...' ORDER BY retry_at ASC LIMIT 50
+        qs = self.get_queue()
+        select_top_canidates_sql, select_params = self._sql_for_select_top_n_candidates(qs=qs)
+        assert select_top_canidates_sql.startswith('SELECT ')
+        
+        # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
+        update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=qs, update_kwargs=self.get_update_kwargs_to_claim_obj())
+        assert update_claimed_obj_sql.startswith('UPDATE ')
+        db_table = self.Model._meta.db_table  # e.g. core_archiveresult
+        
+        # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
+        # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
+        # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
+        atomic_select_and_update_sql = f"""
+            {update_claimed_obj_sql} AND "{db_table}"."id" = (
+                SELECT "{db_table}"."id" FROM (
+                    {select_top_canidates_sql}
+                ) candidates
+                ORDER BY RANDOM()
+                LIMIT 1
+            )
+            RETURNING *;
+        """
+        try:
+            return self.Model.objects.raw(atomic_select_and_update_sql, (*update_params, *select_params))[0]
+        except KeyError:
+            if self.get_queue().exists():
+                raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
+            else:
+                raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
 
-    @abstractmethod
-    def tick(self, obj: ModelType) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
+    def tick(self, obj_to_process: ModelType) -> None:
+        """Call the object.sm.tick() method to process the object"""
+        print(f'[blue]🏃‍♂️ {self}.tick()[/blue] {obj_to_process}')
+        
+        # get the StateMachine instance from the object
+        obj_statemachine = self._get_state_machine_instance(obj_to_process)
+        
+        # trigger the event on the StateMachine instance
+        obj_tick_method = getattr(obj_statemachine, self.EVENT_NAME)  # e.g. obj_statemachine.tick()
+        obj_tick_method()
+        
+        # save the object to persist any state changes
+        obj_to_process.save()
+        
     def on_startup(self) -> None:
         if self.mode == 'thread':
             self.pid = get_native_id()  # thread id
@@ -290,24 +481,91 @@ class ActorType(ABC, Generic[ModelType]):
         else:
             self.pid = os.getpid()      # process id
             print(f'[green]🏃‍♂️ {self}.on_startup() STARTUP (PROCESS)[/green]')
-        # abx.pm.hook.on_actor_startup(self)
+        # abx.pm.hook.on_actor_startup(actor=self)
         
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
+    def on_shutdown(self, last_obj: ModelType | None=None, last_error: BaseException | None=None) -> None:
+        if isinstance(last_error, KeyboardInterrupt) or last_error is None:
+            last_error_str = '[green](CTRL-C)[/green]'
+        elif isinstance(last_error, ActorQueueIsEmpty):
+            last_error_str = '[green](queue empty)[/green]'
+        elif isinstance(last_error, ActorObjectAlreadyClaimed):
+            last_error_str = '[green](queue race)[/green]'
+        else:
+            last_error_str = f'[red]{type(last_error).__name__}: {last_error}[/red]'
+
+        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53] {last_error_str}')
+        # abx.pm.hook.on_actor_shutdown(actor=self, last_obj=last_obj, last_error=last_error)
         
-    def on_tick_start(self, obj: ModelType) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
+    def on_tick_start(self, obj_to_process: ModelType) -> None:
+        print(f'🏃‍♂️ {self}.on_tick_start() {obj_to_process}')
+        # abx.pm.hook.on_actor_tick_start(actor=self, obj_to_process=obj)
         # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
     
-    def on_tick_end(self, obj: ModelType) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
+    def on_tick_end(self, obj_to_process: ModelType) -> None:
+        print(f'🏃‍♂️ {self}.on_tick_end() {obj_to_process}')
+        # abx.pm.hook.on_actor_tick_end(actor=self, obj_to_process=obj_to_process)
         # self.timer.end()
-        pass
     
-    def on_tick_exception(self, obj: ModelType, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    def on_tick_exception(self, obj_to_process: ModelType, error: Exception) -> None:
+        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red] {obj_to_process}: [red]{type(error).__name__}: {error}[/red]')
+        # abx.pm.hook.on_actor_tick_exception(actor=self, obj_to_process=obj_to_process, error=error)
+
+
+
+def compile_sql_select(queryset: QuerySet, filter_kwargs: dict[str, Any] | None=None, order_args: tuple[str, ...]=(), limit: int | None=None) -> tuple[str, tuple[Any, ...]]:
+    """
+    Compute the SELECT query SQL for a queryset.filter(**filter_kwargs).order_by(*order_args)[:limit] call
+    Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
+    
+    WARNING:
+    final_sql = sql % params  DOES NOT WORK to assemble the final SQL string because the %s placeholders are not quoted/escaped
+    they should always passed separately to the DB driver so it can do its own quoting/escaping to avoid SQL injection and syntax errors
+    """
+    assert isinstance(queryset, QuerySet), f'compile_sql_select(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
+    assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_select(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
+    assert isinstance(order_args, tuple) and all(isinstance(arg, str) for arg in order_args), f'compile_sql_select(...) order_args argument must be a tuple[str, ...] got: {type(order_args).__name__} instead'
+    assert limit is None or isinstance(limit, int), f'compile_sql_select(...) limit argument must be an int, got: {type(limit).__name__} instead'
+    
+    queryset = queryset._chain()                      # type: ignore   # copy queryset to avoid modifying the original
+    if filter_kwargs:
+        queryset = queryset.filter(**filter_kwargs)
+    if order_args:
+        queryset = queryset.order_by(*order_args)
+    if limit is not None:
+        queryset = queryset[:limit]
+    query = queryset.query
+    
+    # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s ORDER BY retry_at ASC LIMIT 50
+    select_sql, select_params = query.get_compiler(queryset.db).as_sql()
+    return select_sql, select_params
+
+
+def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter_kwargs: dict[str, Any] | None=None) -> tuple[str, tuple[Any, ...]]:
+    """
+    Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
+    Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
+    
+    Based on the django.db.models.QuerySet.update() source code, but modified to return the SQL instead of executing the update
+    https://github.com/django/django/blob/611bf6c2e2a1b4ab93273980c45150c099ab146d/django/db/models/query.py#L1217
+    
+    WARNING:
+    final_sql = sql % params  DOES NOT WORK to assemble the final SQL string because the %s placeholders are not quoted/escaped
+    they should always passed separately to the DB driver so it can do its own quoting/escaping to avoid SQL injection and syntax errors
+    """
+    assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
+    assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
+    assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_update(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
+    
+    queryset = queryset._chain()                      # type: ignore   # copy queryset to avoid modifying the original
+    if filter_kwargs:
+        queryset = queryset.filter(**filter_kwargs)
+    queryset.query.clear_ordering(force=True)                          # clear any ORDER BY clauses
+    queryset.query.clear_limits()                                      # clear any LIMIT clauses aka slices[:n]
+    queryset._for_write = True                        # type: ignore
+    query = queryset.query.chain(sql.UpdateQuery)     # type: ignore
+    query.add_update_values(update_kwargs)            # type: ignore
+    query.annotations = {}                                             # clear any annotations
+    
+    # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
+    update_sql, update_params = query.get_compiler(queryset.db).as_sql()
+    return update_sql, update_params

+ 298 - 1
archivebox/actors/models.py

@@ -1,3 +1,300 @@
+from typing import ClassVar, Type, Iterable
+from datetime import datetime, timedelta
+
+from statemachine.mixins import MachineMixin
+
 from django.db import models
+from django.utils import timezone
+from django.utils.functional import classproperty
+
+from statemachine import registry, StateMachine, State
+
+from django.core import checks
+
+class DefaultStatusChoices(models.TextChoices):
+    QUEUED = 'queued', 'Queued'
+    STARTED = 'started', 'Started'
+    SEALED = 'sealed', 'Sealed'
+
+
+default_status_field: models.CharField = models.CharField(choices=DefaultStatusChoices.choices, max_length=15, default=DefaultStatusChoices.QUEUED, null=False, blank=False, db_index=True)
+default_retry_at_field: models.DateTimeField = models.DateTimeField(default=timezone.now, null=False, db_index=True)
+
+ObjectState = State | str
+ObjectStateList = Iterable[ObjectState]
+
+
+class BaseModelWithStateMachine(models.Model, MachineMixin):
+    id: models.UUIDField
+    
+    StatusChoices: ClassVar[Type[models.TextChoices]]
+    
+    # status: models.CharField
+    # retry_at: models.DateTimeField
+    
+    state_machine_name: ClassVar[str]
+    state_field_name: ClassVar[str]
+    state_machine_attr: ClassVar[str] = 'sm'
+    bind_events_as_methods: ClassVar[bool] = True
+    
+    active_state: ClassVar[ObjectState]
+    retry_at_field_name: ClassVar[str]
+    
+    class Meta:
+        abstract = True
+        
+    @classmethod
+    def check(cls, sender=None, **kwargs):
+        errors = super().check(**kwargs)
+        
+        found_id_field = False
+        found_status_field = False
+        found_retry_at_field = False
+        
+        for field in cls._meta.get_fields():
+            if getattr(field, '_is_state_field', False):
+                if cls.state_field_name == field.name:
+                    found_status_field = True
+                    if getattr(field, 'choices', None) != cls.StatusChoices.choices:
+                        errors.append(checks.Error(
+                            f'{cls.__name__}.{field.name} must have choices set to {cls.__name__}.StatusChoices.choices',
+                            hint=f'{cls.__name__}.{field.name}.choices = {getattr(field, "choices", None)!r}',
+                            obj=cls,
+                            id='actors.E011',
+                        ))
+            if getattr(field, '_is_retry_at_field', False):
+                if cls.retry_at_field_name == field.name:
+                    found_retry_at_field = True
+            if field.name == 'id' and getattr(field, 'primary_key', False):
+                found_id_field = True
+                    
+        if not found_status_field:
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_field_name must be defined and point to a StatusField()',
+                hint=f'{cls.__name__}.state_field_name = {cls.state_field_name!r} but {cls.__name__}.{cls.state_field_name!r} was not found or does not refer to StatusField',
+                obj=cls,
+                id='actors.E012',
+            ))
+        if not found_retry_at_field:
+            errors.append(checks.Error(
+                f'{cls.__name__}.retry_at_field_name must be defined and point to a RetryAtField()',
+                hint=f'{cls.__name__}.retry_at_field_name = {cls.retry_at_field_name!r} but {cls.__name__}.{cls.retry_at_field_name!r} was not found or does not refer to RetryAtField',
+                obj=cls,
+                id='actors.E013',
+            ))
+            
+        if not found_id_field:
+            errors.append(checks.Error(
+                f'{cls.__name__} must have an id field that is a primary key',
+                hint=f'{cls.__name__}.id = {cls.id!r}',
+                obj=cls,
+                id='actors.E014',
+            ))
+            
+        if not isinstance(cls.state_machine_name, str):
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_machine_name must be a dotted-import path to a StateMachine class',
+                hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
+                obj=cls,
+                id='actors.E015',
+            ))
+        
+        try:
+            cls.StateMachineClass
+        except Exception as err:
+            errors.append(checks.Error(
+                f'{cls.__name__}.state_machine_name must point to a valid StateMachine class, but got {type(err).__name__} {err} when trying to access {cls.__name__}.StateMachineClass',
+                hint=f'{cls.__name__}.state_machine_name = {cls.state_machine_name!r}',
+                obj=cls,
+                id='actors.E016',
+            ))
+        
+        if cls.INITIAL_STATE not in cls.StatusChoices.values:
+            errors.append(checks.Error(
+                f'{cls.__name__}.StateMachineClass.initial_state must be present within {cls.__name__}.StatusChoices',
+                hint=f'{cls.__name__}.StateMachineClass.initial_state = {cls.StateMachineClass.initial_state!r}',
+                obj=cls,
+                id='actors.E017',
+            ))
+            
+        if cls.ACTIVE_STATE not in cls.StatusChoices.values:
+            errors.append(checks.Error(
+                f'{cls.__name__}.active_state must be set to a valid State present within {cls.__name__}.StatusChoices',
+                hint=f'{cls.__name__}.active_state = {cls.active_state!r}',
+                obj=cls,
+                id='actors.E018',
+            ))
+            
+        
+        for state in cls.FINAL_STATES:
+            if state not in cls.StatusChoices.values:
+                errors.append(checks.Error(
+                    f'{cls.__name__}.StateMachineClass.final_states must all be present within {cls.__name__}.StatusChoices',
+                    hint=f'{cls.__name__}.StateMachineClass.final_states = {cls.StateMachineClass.final_states!r}',
+                    obj=cls,
+                    id='actors.E019',
+                ))
+                break
+        return errors
+    
+    @staticmethod
+    def _state_to_str(state: ObjectState) -> str:
+        """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
+        return str(state.value) if isinstance(state, State) else str(state)
+    
+    
+    @property
+    def RETRY_AT(self) -> datetime:
+        return getattr(self, self.retry_at_field_name)
+    
+    @RETRY_AT.setter
+    def RETRY_AT(self, value: datetime):
+        setattr(self, self.retry_at_field_name, value)
+        
+    @property
+    def STATE(self) -> str:
+        return getattr(self, self.state_field_name)
+    
+    @STATE.setter
+    def STATE(self, value: str):
+        setattr(self, self.state_field_name, value)
+        
+    def bump_retry_at(self, seconds: int = 10):
+        self.RETRY_AT = timezone.now() + timedelta(seconds=seconds)
+        
+    @classproperty
+    def ACTIVE_STATE(cls) -> str:
+        return cls._state_to_str(cls.StateMachineClass.active_state)
+        
+    @classproperty
+    def INITIAL_STATE(cls) -> str:
+        return cls._state_to_str(cls.StateMachineClass.initial_state)
+    
+    @classproperty
+    def FINAL_STATES(cls) -> list[str]:
+        return [cls._state_to_str(state) for state in cls.StateMachineClass.final_states]
+    
+    @classproperty
+    def FINAL_OR_ACTIVE_STATES(cls) -> list[str]:
+        return [*cls.FINAL_STATES, cls.ACTIVE_STATE]
+        
+    @classmethod
+    def extend_choices(cls, base_choices: Type[models.TextChoices]):
+        """
+        Decorator to extend the base choices with extra choices, e.g.:
+        
+        class MyModel(ModelWithStateMachine):
+        
+            @ModelWithStateMachine.extend_choices(ModelWithStateMachine.StatusChoices)
+            class StatusChoices(models.TextChoices):
+                SUCCEEDED = 'succeeded'
+                FAILED = 'failed'
+                SKIPPED = 'skipped'
+        """
+        assert issubclass(base_choices, models.TextChoices), f'@extend_choices(base_choices) must be a TextChoices class, not {base_choices.__name__}'
+        def wrapper(extra_choices: Type[models.TextChoices]) -> Type[models.TextChoices]:
+            joined = {}
+            for item in base_choices.choices:
+                joined[item[0]] = item[1]
+            for item in extra_choices.choices:
+                joined[item[0]] = item[1]
+            return models.TextChoices('StatusChoices', joined)
+        return wrapper
+        
+    @classmethod
+    def StatusField(cls, **kwargs) -> models.CharField:
+        """
+        Used on subclasses to extend/modify the status field with updated kwargs. e.g.:
+        
+        class MyModel(ModelWithStateMachine):
+            class StatusChoices(ModelWithStateMachine.StatusChoices):
+                QUEUED = 'queued', 'Queued'
+                STARTED = 'started', 'Started'
+                SEALED = 'sealed', 'Sealed'
+                BACKOFF = 'backoff', 'Backoff'
+                FAILED = 'failed', 'Failed'
+                SKIPPED = 'skipped', 'Skipped'
+        
+            status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
+        """
+        default_kwargs = default_status_field.deconstruct()[3]
+        updated_kwargs = {**default_kwargs, **kwargs}
+        field = models.CharField(**updated_kwargs)
+        field._is_state_field = True                    # type: ignore
+        return field
+
+    @classmethod
+    def RetryAtField(cls, **kwargs) -> models.DateTimeField:
+        """
+        Used on subclasses to extend/modify the retry_at field with updated kwargs. e.g.:
+        
+        class MyModel(ModelWithStateMachine):
+            retry_at = ModelWithStateMachine.RetryAtField(editable=False)
+        """
+        default_kwargs = default_retry_at_field.deconstruct()[3]
+        updated_kwargs = {**default_kwargs, **kwargs}
+        field = models.DateTimeField(**updated_kwargs)
+        field._is_retry_at_field = True                 # type: ignore
+        return field
+    
+    @classproperty
+    def StateMachineClass(cls) -> Type[StateMachine]:
+        """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
+
+        model_state_machine_name = getattr(cls, 'state_machine_name', None)
+        if model_state_machine_name:
+            StateMachineCls = registry.get_machine_cls(model_state_machine_name)
+            assert issubclass(StateMachineCls, StateMachine)
+            return StateMachineCls
+        raise NotImplementedError(f'ActorType[{cls.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
+    
+    # @classproperty
+    # def final_q(cls) -> Q:
+    #     """Get the filter for objects that are in a final state"""
+    #     return Q(**{f'{cls.state_field_name}__in': cls.final_states})
+    
+    # @classproperty
+    # def active_q(cls) -> Q:
+    #     """Get the filter for objects that are actively processing right now"""
+    #     return Q(**{cls.state_field_name: cls._state_to_str(cls.active_state)})   # e.g. Q(status='started')
+    
+    # @classproperty
+    # def stalled_q(cls) -> Q:
+    #     """Get the filter for objects that are marked active but have timed out"""
+    #     return cls.active_q & Q(retry_at__lte=timezone.now())                     # e.g. Q(status='started') AND Q(<retry_at is in the past>)
+    
+    # @classproperty
+    # def future_q(cls) -> Q:
+    #     """Get the filter for objects that have a retry_at in the future"""
+    #     return Q(retry_at__gt=timezone.now())
+    
+    # @classproperty
+    # def pending_q(cls) -> Q:
+    #     """Get the filter for objects that are ready for processing."""
+    #     return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
+    
+    # @classmethod
+    # def get_queue(cls) -> QuerySet:
+    #     """
+    #     Get the sorted and filtered QuerySet of objects that are ready for processing.
+    #     e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
+    #     """
+    #     return cls.objects.filter(cls.pending_q)
+
 
-# Create your models here.
+class ModelWithStateMachine(BaseModelWithStateMachine):
+    StatusChoices: ClassVar[Type[DefaultStatusChoices]] = DefaultStatusChoices
+    
+    status: models.CharField = BaseModelWithStateMachine.StatusField()
+    retry_at: models.DateTimeField = BaseModelWithStateMachine.RetryAtField()
+    
+    state_machine_name: ClassVar[str]      # e.g. 'core.statemachines.ArchiveResultMachine'
+    state_field_name: ClassVar[str]        = 'status'
+    state_machine_attr: ClassVar[str]      = 'sm'
+    bind_events_as_methods: ClassVar[bool] = True
+    
+    active_state: ClassVar[str]            = StatusChoices.STARTED
+    retry_at_field_name: ClassVar[str]     = 'retry_at'
+    
+    class Meta:
+        abstract = True

+ 1 - 1
archivebox/actors/orchestrator.py

@@ -6,7 +6,7 @@ import itertools
 from typing import Dict, Type, Literal, ClassVar
 from django.utils.functional import classproperty
 
-from multiprocessing import Process, cpu_count
+from multiprocessing import Process
 from threading import Thread, get_native_id
 
 

+ 0 - 286
archivebox/actors/statemachine.py

@@ -1,286 +0,0 @@
-from statemachine import State, StateMachine
-from django.db import models
-from multiprocessing import Process
-import psutil
-import time
-
-# State Machine Definitions
-#################################################
-
-class SnapshotMachine(StateMachine):
-    """State machine for managing Snapshot lifecycle."""
-    
-    # States
-    queued = State(initial=True)
-    started = State()
-    sealed = State(final=True)
-    
-    # Transitions
-    start = queued.to(started, cond='can_start')
-    seal = started.to(sealed, cond='is_finished')
-    
-    # Events
-    tick = (
-        queued.to.itself(unless='can_start') |
-        queued.to(started, cond='can_start') |
-        started.to.itself(unless='is_finished') |
-        started.to(sealed, cond='is_finished')
-    )
-    
-    def __init__(self, snapshot):
-        self.snapshot = snapshot
-        super().__init__()
-        
-    def can_start(self):
-        return True
-        
-    def is_finished(self):
-        return not self.snapshot.has_pending_archiveresults()
-        
-    def before_start(self):
-        """Pre-start validation and setup."""
-        self.snapshot.cleanup_dir()
-        
-    def after_start(self):
-        """Post-start side effects."""
-        self.snapshot.create_pending_archiveresults()
-        self.snapshot.update_indices()
-        self.snapshot.bump_retry_at(seconds=10)
-        
-    def before_seal(self):
-        """Pre-seal validation and cleanup."""
-        self.snapshot.cleanup_dir()
-        
-    def after_seal(self):
-        """Post-seal actions."""
-        self.snapshot.update_indices()
-        self.snapshot.seal_dir()
-        self.snapshot.upload_dir()
-        self.snapshot.retry_at = None
-        self.snapshot.save()
-
-
-class ArchiveResultMachine(StateMachine):
-    """State machine for managing ArchiveResult lifecycle."""
-    
-    # States
-    queued = State(initial=True)
-    started = State()
-    succeeded = State(final=True)
-    backoff = State()
-    failed = State(final=True)
-    
-    # Transitions
-    start = queued.to(started, cond='can_start')
-    succeed = started.to(succeeded, cond='extractor_succeeded')
-    backoff = started.to(backoff, unless='extractor_succeeded')
-    retry = backoff.to(queued, cond='can_retry')
-    fail = backoff.to(failed, unless='can_retry')
-    
-    # Events
-    tick = (
-        queued.to.itself(unless='can_start') |
-        queued.to(started, cond='can_start') |
-        started.to.itself(cond='extractor_still_running') |
-        started.to(succeeded, cond='extractor_succeeded') |
-        started.to(backoff, unless='extractor_succeeded') |
-        backoff.to.itself(cond='still_waiting_to_retry') |
-        backoff.to(queued, cond='can_retry') |
-        backoff.to(failed, unless='can_retry')
-    )
-    
-    def __init__(self, archiveresult):
-        self.archiveresult = archiveresult
-        super().__init__()
-    
-    def can_start(self):
-        return True
-    
-    def extractor_still_running(self):
-        return self.archiveresult.start_ts > time.now() - timedelta(seconds=5)
-    
-    def extractor_succeeded(self):
-        # return check_if_extractor_succeeded(self.archiveresult)
-        return self.archiveresult.start_ts < time.now() - timedelta(seconds=5)
-    
-    def can_retry(self):
-        return self.archiveresult.retries < self.archiveresult.max_retries
-        
-    def before_start(self):
-        """Pre-start initialization."""
-        self.archiveresult.retries += 1
-        self.archiveresult.start_ts = time.now()
-        self.archiveresult.output = None
-        self.archiveresult.error = None
-        
-    def after_start(self):
-        """Post-start execution."""
-        self.archiveresult.bump_retry_at(seconds=self.archiveresult.timeout + 5)
-        execute_extractor(self.archiveresult)
-        self.archiveresult.snapshot.bump_retry_at(seconds=5)
-        
-    def before_succeed(self):
-        """Pre-success validation."""
-        self.archiveresult.output = get_archiveresult_output(self.archiveresult)
-        
-    def after_succeed(self):
-        """Post-success cleanup."""
-        self.archiveresult.end_ts = time.now()
-        self.archiveresult.retry_at = None
-        self.archiveresult.update_indices()
-        
-    def before_backoff(self):
-        """Pre-backoff error capture."""
-        self.archiveresult.error = get_archiveresult_error(self.archiveresult)
-        
-    def after_backoff(self):
-        """Post-backoff retry scheduling."""
-        self.archiveresult.end_ts = time.now()
-        self.archiveresult.bump_retry_at(
-            seconds=self.archiveresult.timeout * self.archiveresult.retries
-        )
-        self.archiveresult.update_indices()
-        
-    def before_fail(self):
-        """Pre-failure finalization."""
-        self.archiveresult.retry_at = None
-        
-    def after_fail(self):
-        """Post-failure cleanup."""
-        self.archiveresult.update_indices()
-
-# Models
-#################################################
-
-class Snapshot(models.Model):
-    status = models.CharField(max_length=32, default='queued')
-    retry_at = models.DateTimeField(null=True)
-    
-    @property
-    def sm(self):
-        """Get the state machine for this snapshot."""
-        return SnapshotMachine(self)
-    
-    def has_pending_archiveresults(self):
-        return self.archiveresult_set.exclude(
-            status__in=['succeeded', 'failed']
-        ).exists()
-    
-    def bump_retry_at(self, seconds):
-        self.retry_at = time.now() + timedelta(seconds=seconds)
-        self.save()
-        
-    def cleanup_dir(self):
-        cleanup_snapshot_dir(self)
-        
-    def create_pending_archiveresults(self):
-        create_snapshot_pending_archiveresults(self)
-        
-    def update_indices(self):
-        update_snapshot_index_json(self)
-        update_snapshot_index_html(self)
-        
-    def seal_dir(self):
-        seal_snapshot_dir(self)
-        
-    def upload_dir(self):
-        upload_snapshot_dir(self)
-
-
-class ArchiveResult(models.Model):
-    snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE)
-    status = models.CharField(max_length=32, default='queued')
-    retry_at = models.DateTimeField(null=True)
-    retries = models.IntegerField(default=0)
-    max_retries = models.IntegerField(default=3)
-    timeout = models.IntegerField(default=60)
-    start_ts = models.DateTimeField(null=True)
-    end_ts = models.DateTimeField(null=True)
-    output = models.TextField(null=True)
-    error = models.TextField(null=True)
-    
-    def get_machine(self):
-        return ArchiveResultMachine(self)
-    
-    def bump_retry_at(self, seconds):
-        self.retry_at = time.now() + timedelta(seconds=seconds)
-        self.save()
-        
-    def update_indices(self):
-        update_archiveresult_index_json(self)
-        update_archiveresult_index_html(self)
-
-
-# Actor System
-#################################################
-
-class BaseActor:
-    MAX_TICK_TIME = 60
-    
-    def tick(self, obj):
-        """Process a single object through its state machine."""
-        machine = obj.get_machine()
-        
-        if machine.is_queued:
-            if machine.can_start():
-                machine.start()
-                
-        elif machine.is_started:
-            if machine.can_seal():
-                machine.seal()
-                
-        elif machine.is_backoff:
-            if machine.can_retry():
-                machine.retry()
-            else:
-                machine.fail()
-
-
-class Orchestrator:
-    """Main orchestrator that manages all actors."""
-    
-    def __init__(self):
-        self.pid = None
-        
-    @classmethod
-    def spawn(cls):
-        orchestrator = cls()
-        proc = Process(target=orchestrator.runloop)
-        proc.start()
-        return proc.pid
-        
-    def runloop(self):
-        self.pid = os.getpid()
-        abx.pm.hook.on_orchestrator_startup(self)
-        
-        try:
-            while True:
-                self.process_queue(Snapshot)
-                self.process_queue(ArchiveResult)
-                time.sleep(0.1)
-                
-        except (KeyboardInterrupt, SystemExit):
-            abx.pm.hook.on_orchestrator_shutdown(self)
-            
-    def process_queue(self, model):
-        retry_at_reached = Q(retry_at__isnull=True) | Q(retry_at__lte=time.now())
-        queue = model.objects.filter(retry_at_reached)
-        
-        if queue.exists():
-            actor = BaseActor()
-            for obj in queue:
-                try:
-                    with transaction.atomic():
-                        actor.tick(obj)
-                except Exception as e:
-                    abx.pm.hook.on_actor_tick_exception(actor, obj, e)
-
-
-# Periodic Tasks
-#################################################
-
[email protected]_task(schedule=djhuey.crontab(minute='*'))
-def ensure_orchestrator_running():
-    """Ensure orchestrator is running, start if not."""
-    if not any(p.name().startswith('Orchestrator') for p in psutil.process_iter()):
-        Orchestrator.spawn()

+ 28 - 60
archivebox/core/actors.py

@@ -2,72 +2,40 @@ __package__ = 'archivebox.core'
 
 from typing import ClassVar
 
-from rich import print
-
-from django.db.models import QuerySet
-from django.utils import timezone
-from datetime import timedelta
-from core.models import Snapshot
+from statemachine import State
 
+from core.models import Snapshot, ArchiveResult
+from core.statemachines import SnapshotMachine, ArchiveResultMachine
 from actors.actor import ActorType
 
 
 class SnapshotActor(ActorType[Snapshot]):
+    Model = Snapshot
+    StateMachineClass = SnapshotMachine
     
-    QUERYSET: ClassVar[QuerySet] = Snapshot.objects.filter(status='queued')
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = 50                # the number of objects to consider when atomically getting the next object from the queue
-    
-    # model_type: Type[ModelType]
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = 4               # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        retry_at = timezone.now() + timedelta(seconds=self.MAX_TICK_TIME)
-        # format as 2024-10-31 10:14:33.240903
-        retry_at_str = retry_at.strftime('%Y-%m-%d %H:%M:%S.%f')
-        return f'{self.CLAIM_SET}, retry_at = {retry_at_str}'
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
+    ACTIVE_STATE: ClassVar[State] = SnapshotMachine.started
+    FINAL_STATES: ClassVar[list[State]] = SnapshotMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = SnapshotMachine.state_field_name
     
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def tick(self, obj: Snapshot) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        # raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
-        
-    def on_tick_start(self, obj: Snapshot) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
-        # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 3
+    MAX_TICK_TIME: ClassVar[int] = 10
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10
+
+
+
+class ArchiveResultActor(ActorType[ArchiveResult]):
+    Model = ArchiveResult
+    StateMachineClass = ArchiveResultMachine
     
-    def on_tick_end(self, obj: Snapshot) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
-        # self.timer.end()
-        pass
+    ACTIVE_STATE: ClassVar[State] = ArchiveResultMachine.started
+    FINAL_STATES: ClassVar[list[State]] = ArchiveResultMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = ArchiveResultMachine.state_field_name
     
-    def on_tick_exception(self, obj: Snapshot, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 6
+    MAX_TICK_TIME: ClassVar[int] = 60
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10
+
+    # @classproperty
+    # def qs(cls) -> QuerySet[ModelType]:
+    #     """Get the unfiltered and unsorted QuerySet of all objects that this Actor might care about."""
+    #     return cls.Model.objects.filter(extractor='favicon')

+ 57 - 50
archivebox/core/models.py

@@ -20,7 +20,7 @@ from django.db.models import Case, When, Value, IntegerField
 from django.contrib import admin
 from django.conf import settings
 
-from statemachine.mixins import MachineMixin
+from actors.models import ModelWithStateMachine
 
 from archivebox.config import CONSTANTS
 
@@ -156,7 +156,7 @@ class SnapshotManager(models.Manager):
         return super().get_queryset().prefetch_related('tags', 'archiveresult_set')  # .annotate(archiveresult_count=models.Count('archiveresult')).distinct()
 
 
-class Snapshot(ABIDModel, MachineMixin):
+class Snapshot(ABIDModel, ModelWithStateMachine):
     abid_prefix = 'snp_'
     abid_ts_src = 'self.created_at'
     abid_uri_src = 'self.url'
@@ -164,34 +164,32 @@ class Snapshot(ABIDModel, MachineMixin):
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
 
-    state_field_name = 'status'
     state_machine_name = 'core.statemachines.SnapshotMachine'
-    state_machine_attr = 'sm'
+    state_field_name = 'status'
+    retry_at_field_name = 'retry_at'
+    StatusChoices = ModelWithStateMachine.StatusChoices
+    active_state = StatusChoices.STARTED
     
-    class SnapshotStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SEALED = 'sealed', 'Sealed'
-        
-    status = models.CharField(max_length=15, default=SnapshotStatus.QUEUED, null=False, blank=False)
-
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
 
-    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='snapshot_set')
+    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='snapshot_set', db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)  # loaded from self._init_timestamp
     modified_at = models.DateTimeField(auto_now=True)
+    
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
     # legacy ts fields
     bookmarked_at = AutoDateTimeField(default=None, null=False, editable=True, db_index=True)
     downloaded_at = models.DateTimeField(default=None, null=True, editable=False, db_index=True, blank=True)
 
-    crawl = models.ForeignKey(Crawl, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name='snapshot_set')
+    crawl: Crawl = models.ForeignKey(Crawl, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name='snapshot_set', db_index=True)  # type: ignore
 
     url = models.URLField(unique=True, db_index=True)
     timestamp = models.CharField(max_length=32, unique=True, db_index=True, editable=False)
     tags = models.ManyToManyField(Tag, blank=True, through=SnapshotTag, related_name='snapshot_set', through_fields=('snapshot', 'tag'))
-    title = models.CharField(max_length=512, null=True, blank=True, db_index=True)    
+    title = models.CharField(max_length=512, null=True, blank=True, db_index=True)
 
     keys = ('url', 'timestamp', 'title', 'tags', 'downloaded_at')
 
@@ -210,12 +208,14 @@ class Snapshot(ABIDModel, MachineMixin):
         return result
 
     def __repr__(self) -> str:
-        title = (self.title_stripped or '-')[:64]
-        return f'[{self.timestamp}] {self.url[:64]} ({title})'
+        url = self.url or '<no url set>'
+        created_at = self.created_at.strftime("%Y-%m-%d %H:%M") if self.created_at else '<no timestamp set>'
+        if self.id and self.url:
+            return f'[{self.ABID}] {url[:64]} @ {created_at}'
+        return f'[{self.abid_prefix}****not*saved*yet****] {url[:64]} @ {created_at}'
 
     def __str__(self) -> str:
-        title = (self.title_stripped or '-')[:64]
-        return f'[{self.timestamp}] {self.url[:64]} ({title})'
+        return repr(self)
 
     @classmethod
     def from_json(cls, info: dict):
@@ -413,8 +413,7 @@ class Snapshot(ABIDModel, MachineMixin):
         self.tags.add(*tags_id)
         
     def has_pending_archiveresults(self) -> bool:
-        pending_statuses = [ArchiveResult.ArchiveResultStatus.QUEUED, ArchiveResult.ArchiveResultStatus.STARTED]
-        pending_archiveresults = self.archiveresult_set.filter(status__in=pending_statuses)
+        pending_archiveresults = self.archiveresult_set.exclude(status__in=ArchiveResult.FINAL_OR_ACTIVE_STATES)
         return pending_archiveresults.exists()
     
     def create_pending_archiveresults(self) -> list['ArchiveResult']:
@@ -423,13 +422,10 @@ class Snapshot(ABIDModel, MachineMixin):
             archiveresult, _created = ArchiveResult.objects.get_or_create(
                 snapshot=self,
                 extractor=extractor,
-                status=ArchiveResult.ArchiveResultStatus.QUEUED,
+                status=ArchiveResult.INITIAL_STATE,
             )
             archiveresults.append(archiveresult)
         return archiveresults
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
 
 
     # def get_storage_dir(self, create=True, symlink=True) -> Path:
@@ -479,7 +475,7 @@ class ArchiveResultManager(models.Manager):
             ).order_by('indexing_precedence')
         return qs
 
-class ArchiveResult(ABIDModel):
+class ArchiveResult(ABIDModel, ModelWithStateMachine):
     abid_prefix = 'res_'
     abid_ts_src = 'self.snapshot.created_at'
     abid_uri_src = 'self.snapshot.url'
@@ -487,19 +483,19 @@ class ArchiveResult(ABIDModel):
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
     
-    state_field_name = 'status'
-    state_machine_name = 'core.statemachines.ArchiveResultMachine'
-    state_machine_attr = 'sm'
-
-    class ArchiveResultStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SUCCEEDED = 'succeeded', 'Succeeded'
-        FAILED = 'failed', 'Failed'
-        SKIPPED = 'skipped', 'Skipped'
-        BACKOFF = 'backoff', 'Waiting to retry'
+    class StatusChoices(models.TextChoices):
+        QUEUED = 'queued', 'Queued'                     # pending, initial
+        STARTED = 'started', 'Started'                  # active
         
-    status = models.CharField(max_length=15, choices=ArchiveResultStatus.choices, default=ArchiveResultStatus.QUEUED, null=False, blank=False)
+        BACKOFF = 'backoff', 'Waiting to retry'         # pending
+        SUCCEEDED = 'succeeded', 'Succeeded'            # final
+        FAILED = 'failed', 'Failed'                     # final
+        SKIPPED = 'skipped', 'Skipped'                  # final
+        
+    state_machine_name = 'core.statemachines.ArchiveResultMachine'
+    retry_at_field_name = 'retry_at'
+    state_field_name = 'status'
+    active_state = StatusChoices.STARTED
 
     EXTRACTOR_CHOICES = (
         ('htmltotext', 'htmltotext'),
@@ -522,19 +518,22 @@ class ArchiveResult(ABIDModel):
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
 
-    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='archiveresult_set')
+    created_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=False, related_name='archiveresult_set', db_index=True)
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     modified_at = models.DateTimeField(auto_now=True)
+    
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices.choices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
-    snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE, to_field='id', db_column='snapshot_id')
+    snapshot: Snapshot = models.ForeignKey(Snapshot, on_delete=models.CASCADE)   # type: ignore
 
-    extractor = models.CharField(choices=EXTRACTOR_CHOICES, max_length=32)
-    cmd = models.JSONField()
-    pwd = models.CharField(max_length=256)
+    extractor = models.CharField(choices=EXTRACTOR_CHOICES, max_length=32, blank=False, null=False, db_index=True)
+    cmd = models.JSONField(default=None, null=True, blank=True)
+    pwd = models.CharField(max_length=256, default=None, null=True, blank=True)
     cmd_version = models.CharField(max_length=128, default=None, null=True, blank=True)
-    output = models.CharField(max_length=1024)
-    start_ts = models.DateTimeField(db_index=True)
-    end_ts = models.DateTimeField()
+    output = models.CharField(max_length=1024, default=None, null=True, blank=True)
+    start_ts = models.DateTimeField(default=None, null=True, blank=True)
+    end_ts = models.DateTimeField(default=None, null=True, blank=True)
 
     # the network interface that was used to download this result
     # uplink = models.ForeignKey(NetworkInterface, on_delete=models.SET_NULL, null=True, blank=True, verbose_name='Network Interface Used')
@@ -545,10 +544,17 @@ class ArchiveResult(ABIDModel):
         verbose_name = 'Archive Result'
         verbose_name_plural = 'Archive Results Log'
 
+    def __repr__(self):
+        snapshot_id = getattr(self, 'snapshot_id', None)
+        url = self.snapshot.url if snapshot_id else '<no url set>'
+        created_at = self.snapshot.created_at.strftime("%Y-%m-%d %H:%M") if snapshot_id else '<no timestamp set>'
+        extractor = self.extractor or '<no extractor set>'
+        if self.id and snapshot_id:
+            return f'[{self.ABID}] {url[:64]} @ {created_at} -> {extractor}'
+        return f'[{self.abid_prefix}****not*saved*yet****] {url} @ {created_at} -> {extractor}'
 
     def __str__(self):
-        # return f'[{self.abid}] 📅 {self.start_ts.strftime("%Y-%m-%d %H:%M")} 📄 {self.extractor} {self.snapshot.url}'
-        return self.extractor
+        return repr(self)
 
     # TODO: finish connecting machine.models
     # @cached_property
@@ -558,6 +564,10 @@ class ArchiveResult(ABIDModel):
     @cached_property
     def snapshot_dir(self):
         return Path(self.snapshot.link_dir)
+    
+    @cached_property
+    def url(self):
+        return self.snapshot.url
 
     @property
     def api_url(self) -> str:
@@ -596,9 +606,6 @@ class ArchiveResult(ABIDModel):
 
     def output_exists(self) -> bool:
         return os.path.exists(self.output_path())
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
         
     def create_output_dir(self):
         snap_dir = self.snapshot_dir

+ 13 - 10
archivebox/core/statemachines.py

@@ -16,9 +16,9 @@ class SnapshotMachine(StateMachine, strict_states=True):
     model: Snapshot
     
     # States
-    queued = State(value=Snapshot.SnapshotStatus.QUEUED, initial=True)
-    started = State(value=Snapshot.SnapshotStatus.STARTED)
-    sealed = State(value=Snapshot.SnapshotStatus.SEALED, final=True)
+    queued = State(value=Snapshot.StatusChoices.QUEUED, initial=True)
+    started = State(value=Snapshot.StatusChoices.STARTED)
+    sealed = State(value=Snapshot.StatusChoices.SEALED, final=True)
     
     # Tick Event
     tick = (
@@ -53,11 +53,11 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
     model: ArchiveResult
     
     # States
-    queued = State(value=ArchiveResult.ArchiveResultStatus.QUEUED, initial=True)
-    started = State(value=ArchiveResult.ArchiveResultStatus.STARTED)
-    backoff = State(value=ArchiveResult.ArchiveResultStatus.BACKOFF)
-    succeeded = State(value=ArchiveResult.ArchiveResultStatus.SUCCEEDED, final=True)
-    failed = State(value=ArchiveResult.ArchiveResultStatus.FAILED, final=True)
+    queued = State(value=ArchiveResult.StatusChoices.QUEUED, initial=True)
+    started = State(value=ArchiveResult.StatusChoices.STARTED)
+    backoff = State(value=ArchiveResult.StatusChoices.BACKOFF)
+    succeeded = State(value=ArchiveResult.StatusChoices.SUCCEEDED, final=True)
+    failed = State(value=ArchiveResult.StatusChoices.FAILED, final=True)
     
     # Tick Event
     tick = (
@@ -78,7 +78,7 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
         super().__init__(archiveresult, *args, **kwargs)
         
     def can_start(self) -> bool:
-        return self.archiveresult.snapshot and self.archiveresult.snapshot.is_started()
+        return self.archiveresult.snapshot and self.archiveresult.snapshot.STATE == Snapshot.active_state
     
     def is_succeeded(self) -> bool:
         return self.archiveresult.output_exists()
@@ -87,7 +87,10 @@ class ArchiveResultMachine(StateMachine, strict_states=True):
         return not self.archiveresult.output_exists()
     
     def is_backoff(self) -> bool:
-        return self.archiveresult.status == ArchiveResult.ArchiveResultStatus.BACKOFF
+        return self.archiveresult.STATE == ArchiveResult.StatusChoices.BACKOFF
+    
+    def is_finished(self) -> bool:
+        return self.is_failed() or self.is_succeeded()
 
     def on_started(self):
         self.archiveresult.start_ts = timezone.now()

+ 11 - 57
archivebox/crawls/actors.py

@@ -2,68 +2,22 @@ __package__ = 'archivebox.crawls'
 
 from typing import ClassVar
 
-from rich import print
-
-from django.db.models import QuerySet
-
 from crawls.models import Crawl
+from crawls.statemachines import CrawlMachine
 
-from actors.actor import ActorType
+from actors.actor import ActorType, State
 
 
 class CrawlActor(ActorType[Crawl]):
+    """The Actor that manages the lifecycle of all Crawl objects"""
     
-    QUERYSET: ClassVar[QuerySet] = Crawl.objects.filter(status='queued')
-    CLAIM_WHERE: ClassVar[str] = 'status = "queued"'  # the WHERE clause to filter the objects when atomically getting the next object from the queue
-    CLAIM_SET: ClassVar[str] = 'status = "started"'   # the SET clause to claim the object when atomically getting the next object from the queue
-    CLAIM_ORDER: ClassVar[str] = 'created_at DESC'    # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
-    CLAIM_FROM_TOP: ClassVar[int] = 50                # the number of objects to consider when atomically getting the next object from the queue
-    
-    # model_type: Type[ModelType]
-    MAX_CONCURRENT_ACTORS: ClassVar[int] = 4               # min 2, max 8, up to 60% of available cpu cores
-    MAX_TICK_TIME: ClassVar[int] = 60                          # maximum duration in seconds to process a single object
-    
-    def claim_sql_where(self) -> str:
-        """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
-        return self.CLAIM_WHERE
-    
-    def claim_sql_set(self) -> str:
-        """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
-        return self.CLAIM_SET
-    
-    def claim_sql_order(self) -> str:
-        """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
-        return self.CLAIM_ORDER
-    
-    def claim_from_top(self) -> int:
-        """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
-        return self.CLAIM_FROM_TOP
-        
-    def tick(self, obj: Crawl) -> None:
-        """override this to process the object"""
-        print(f'[blue]🏃‍♂️ {self}.tick()[/blue]', obj.abid or obj.id)
-        # For example:
-        # do_some_task(obj)
-        # do_something_else(obj)
-        # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
-        # raise NotImplementedError('tick() must be implemented by the Actor subclass')
-    
-    def on_shutdown(self, err: BaseException | None=None) -> None:
-        print(f'[grey53]🏃‍♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
-        # abx.pm.hook.on_actor_shutdown(self)
-        
-    def on_tick_start(self, obj: Crawl) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_start()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
-        # self.timer = TimedProgress(self.MAX_TICK_TIME, prefix='      ')
-        pass
+    Model = Crawl
+    StateMachineClass = CrawlMachine
     
-    def on_tick_end(self, obj: Crawl) -> None:
-        # print(f'🏃‍♂️ {self}.on_tick_end()', obj.abid or obj.id)
-        # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
-        # self.timer.end()
-        pass
+    ACTIVE_STATE: ClassVar[State] = CrawlMachine.started
+    FINAL_STATES: ClassVar[list[State]] = CrawlMachine.final_states
+    STATE_FIELD_NAME: ClassVar[str] = Crawl.state_field_name
     
-    def on_tick_exception(self, obj: Crawl, err: BaseException) -> None:
-        print(f'[red]🏃‍♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
-        # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
+    MAX_CONCURRENT_ACTORS: ClassVar[int] = 3
+    MAX_TICK_TIME: ClassVar[int] = 10
+    CLAIM_FROM_TOP_N: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10

+ 10 - 19
archivebox/crawls/models.py

@@ -11,7 +11,7 @@ from django.conf import settings
 from django.urls import reverse_lazy
 from django.utils import timezone
 
-from statemachine.mixins import MachineMixin
+from actors.models import ModelWithStateMachine
 
 if TYPE_CHECKING:
     from core.models import Snapshot
@@ -50,7 +50,7 @@ class CrawlSchedule(ABIDModel, ModelWithHealthStats):
 
     
 
-class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
+class Crawl(ABIDModel, ModelWithHealthStats, ModelWithStateMachine):
     """
     A single session of URLs to archive starting from a given Seed and expanding outwards. An "archiving session" so to speak.
 
@@ -67,17 +67,11 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     abid_rand_src = 'self.id'
     abid_drift_allowed = True
     
-    state_field_name = 'status'
     state_machine_name = 'crawls.statemachines.CrawlMachine'
-    state_machine_attr = 'sm'
-    bind_events_as_methods = True
-
-    class CrawlStatus(models.TextChoices):
-        QUEUED = 'queued', 'Queued'
-        STARTED = 'started', 'Started'
-        SEALED = 'sealed', 'Sealed'
-
-    status = models.CharField(choices=CrawlStatus.choices, max_length=15, default=CrawlStatus.QUEUED, null=False, blank=False)
+    retry_at_field_name = 'retry_at'
+    state_field_name = 'status'
+    StatusChoices = ModelWithStateMachine.StatusChoices
+    active_state = StatusChoices.STARTED
     
     id = models.UUIDField(primary_key=True, default=None, null=False, editable=False, unique=True, verbose_name='ID')
     abid = ABIDField(prefix=abid_prefix)
@@ -86,6 +80,8 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     created_at = AutoDateTimeField(default=None, null=False, db_index=True)
     modified_at = models.DateTimeField(auto_now=True)
     
+    status = ModelWithStateMachine.StatusField(choices=StatusChoices, default=StatusChoices.QUEUED)
+    retry_at = ModelWithStateMachine.RetryAtField(default=timezone.now)
 
     seed = models.ForeignKey(Seed, on_delete=models.PROTECT, related_name='crawl_set', null=False, blank=False)
     max_depth = models.PositiveSmallIntegerField(default=0, validators=[MinValueValidator(0), MaxValueValidator(4)])
@@ -127,10 +123,8 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
     def has_pending_archiveresults(self) -> bool:
         from core.models import ArchiveResult
         
-        pending_statuses = [ArchiveResult.ArchiveResultStatus.QUEUED, ArchiveResult.ArchiveResultStatus.STARTED]
-        
         snapshot_ids = self.snapshot_set.values_list('id', flat=True)
-        pending_archiveresults = ArchiveResult.objects.filter(snapshot_id__in=snapshot_ids, status__in=pending_statuses)
+        pending_archiveresults = ArchiveResult.objects.filter(snapshot_id__in=snapshot_ids).exclude(status__in=ArchiveResult.FINAL_OR_ACTIVE_STATES)
         return pending_archiveresults.exists()
     
     def create_root_snapshot(self) -> 'Snapshot':
@@ -139,12 +133,9 @@ class Crawl(ABIDModel, ModelWithHealthStats, MachineMixin):
         root_snapshot, _ = Snapshot.objects.get_or_create(
             crawl=self,
             url=self.seed.uri,
+            status=Snapshot.INITIAL_STATE,
         )
         return root_snapshot
-    
-    def bump_retry_at(self, seconds: int = 10):
-        self.retry_at = timezone.now() + timedelta(seconds=seconds)
-        self.save()
 
 
 class Outlink(models.Model):

+ 3 - 3
archivebox/crawls/statemachines.py

@@ -14,9 +14,9 @@ class CrawlMachine(StateMachine, strict_states=True):
     model: Crawl
     
     # States
-    queued = State(value=Crawl.CrawlStatus.QUEUED, initial=True)
-    started = State(value=Crawl.CrawlStatus.STARTED)
-    sealed = State(value=Crawl.CrawlStatus.SEALED, final=True)
+    queued = State(value=Crawl.StatusChoices.QUEUED, initial=True)
+    started = State(value=Crawl.StatusChoices.STARTED)
+    sealed = State(value=Crawl.StatusChoices.SEALED, final=True)
     
     # Tick Event
     tick = (