|
|
@@ -2,78 +2,240 @@ __package__ = 'archivebox.actors'
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
-from abc import ABC, abstractmethod
|
|
|
-from typing import ClassVar, Generic, TypeVar, Any, cast, Literal, Type
|
|
|
-from django.utils.functional import classproperty
|
|
|
+from typing import ClassVar, Generic, TypeVar, Any, Literal, Type, Iterable, cast, get_args
|
|
|
+from datetime import timedelta
|
|
|
+from multiprocessing import Process, cpu_count
|
|
|
+from threading import Thread, get_native_id
|
|
|
|
|
|
-from rich import print
|
|
|
import psutil
|
|
|
+from rich import print
|
|
|
+from statemachine import State, StateMachine, registry
|
|
|
+from statemachine.mixins import MachineMixin
|
|
|
|
|
|
from django import db
|
|
|
-from django.db import models
|
|
|
-from django.db.models import QuerySet
|
|
|
-from multiprocessing import Process, cpu_count
|
|
|
-from threading import Thread, get_native_id
|
|
|
+from django.db.models import QuerySet, sql, Q
|
|
|
+from django.db.models import Model as DjangoModel
|
|
|
+from django.utils import timezone
|
|
|
+from django.utils.functional import classproperty
|
|
|
+
|
|
|
+from .models import ModelWithStateMachine
|
|
|
|
|
|
# from archivebox.logging_util import TimedProgress
|
|
|
|
|
|
+class ActorObjectAlreadyClaimed(Exception):
|
|
|
+ """Raised when the Actor tries to claim the next object from the queue but it's already been claimed by another Actor"""
|
|
|
+ pass
|
|
|
+
|
|
|
+class ActorQueueIsEmpty(Exception):
|
|
|
+ """Raised when the Actor tries to get the next object from the queue but it's empty"""
|
|
|
+ pass
|
|
|
+
|
|
|
+CPU_COUNT = cpu_count()
|
|
|
+DEFAULT_MAX_TICK_TIME = 60
|
|
|
+DEFAULT_MAX_CONCURRENT_ACTORS = min(max(2, int(CPU_COUNT * 0.6)), 8) # 2 < 60% * num available cpu cores < 8
|
|
|
+
|
|
|
+limit = lambda n, max: min(n, max)
|
|
|
+
|
|
|
LaunchKwargs = dict[str, Any]
|
|
|
+ObjectState = State | str
|
|
|
+ObjectStateList = Iterable[ObjectState]
|
|
|
|
|
|
-ModelType = TypeVar('ModelType', bound=models.Model)
|
|
|
+ModelType = TypeVar('ModelType', bound=ModelWithStateMachine)
|
|
|
|
|
|
-class ActorType(ABC, Generic[ModelType]):
|
|
|
+class ActorType(Generic[ModelType]):
|
|
|
"""
|
|
|
Base class for all actors. Usage:
|
|
|
- class FaviconActor(ActorType[ArchiveResult]):
|
|
|
- QUERYSET: ClassVar[QuerySet] = ArchiveResult.objects.filter(status='queued', extractor='favicon')
|
|
|
- CLAIM_WHERE: ClassVar[str] = 'status = "queued" AND extractor = "favicon"'
|
|
|
- CLAIM_ORDER: ClassVar[str] = 'created_at DESC'
|
|
|
- ATOMIC: ClassVar[bool] = True
|
|
|
-
|
|
|
- def claim_sql_set(self, obj: ArchiveResult) -> str:
|
|
|
- # SQL fields to update atomically while claiming an object from the queue
|
|
|
- retry_at = datetime.now() + timedelta(seconds=self.MAX_TICK_TIME)
|
|
|
- return f"status = 'started', locked_by = {self.pid}, retry_at = {retry_at}"
|
|
|
-
|
|
|
- def tick(self, obj: ArchiveResult) -> None:
|
|
|
- run_favicon_extractor(obj)
|
|
|
- ArchiveResult.objects.filter(pk=obj.pk, status='started').update(status='success')
|
|
|
+
|
|
|
+ class FaviconActor(ActorType[FaviconArchiveResult]):
|
|
|
+ FINAL_STATES: ClassVar[tuple[str, ...]] = ('succeeded', 'failed', 'skipped')
|
|
|
+ ACTIVE_STATE: ClassVar[str] = 'started'
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def qs(cls) -> QuerySet[FaviconArchiveResult]:
|
|
|
+ return ArchiveResult.objects.filter(extractor='favicon') # or leave the default: FaviconArchiveResult.objects.all()
|
|
|
"""
|
|
|
+
|
|
|
+ ### Class attributes (defined on the class at compile-time when ActorType[MyModel] is defined)
|
|
|
+ Model: Type[ModelType]
|
|
|
+ StateMachineClass: Type[StateMachine]
|
|
|
+
|
|
|
+ STATE_FIELD_NAME: ClassVar[str]
|
|
|
+ ACTIVE_STATE: ClassVar[ObjectState]
|
|
|
+ FINAL_STATES: ClassVar[ObjectStateList]
|
|
|
+ EVENT_NAME: ClassVar[str] = 'tick' # the event name to trigger on the obj.sm: StateMachine (usually 'tick')
|
|
|
+
|
|
|
+ CLAIM_ORDER: ClassVar[tuple[str, ...]] = ('retry_at',) # the .order(*args) to claim the queue objects in, use ('?',) for random order
|
|
|
+ CLAIM_FROM_TOP_N: ClassVar[int] = CPU_COUNT * 10 # the number of objects to consider when atomically getting the next object from the queue
|
|
|
+ CLAIM_ATOMIC: ClassVar[bool] = True # whether to atomically fetch+claim the next object in one query, or fetch and lock it in two queries
|
|
|
+
|
|
|
+ MAX_TICK_TIME: ClassVar[int] = DEFAULT_MAX_TICK_TIME # maximum duration in seconds to process a single object
|
|
|
+ MAX_CONCURRENT_ACTORS: ClassVar[int] = DEFAULT_MAX_CONCURRENT_ACTORS # maximum number of concurrent actors that can be running at once
|
|
|
+
|
|
|
+ _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = [] # used to record all the pids of Actors spawned on the class
|
|
|
+
|
|
|
+ ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
|
|
|
pid: int
|
|
|
idle_count: int = 0
|
|
|
launch_kwargs: LaunchKwargs = {}
|
|
|
mode: Literal['thread', 'process'] = 'process'
|
|
|
|
|
|
- MAX_CONCURRENT_ACTORS: ClassVar[int] = min(max(2, int(cpu_count() * 0.6)), 8) # min 2, max 8, up to 60% of available cpu cores
|
|
|
- MAX_TICK_TIME: ClassVar[int] = 60 # maximum duration in seconds to process a single object
|
|
|
-
|
|
|
- QUERYSET: ClassVar[QuerySet] # the QuerySet to claim objects from
|
|
|
- CLAIM_WHERE: ClassVar[str] = 'status = "queued"' # the WHERE clause to filter the objects when atomically getting the next object from the queue
|
|
|
- CLAIM_SET: ClassVar[str] = 'status = "started"' # the SET clause to claim the object when atomically getting the next object from the queue
|
|
|
- CLAIM_ORDER: ClassVar[str] = 'created_at DESC' # the ORDER BY clause to sort the objects with when atomically getting the next object from the queue
|
|
|
- CLAIM_FROM_TOP: ClassVar[int] = MAX_CONCURRENT_ACTORS * 10 # the number of objects to consider when atomically getting the next object from the queue
|
|
|
- ATOMIC: ClassVar[bool] = True # whether to atomically fetch+claim the nextobject in one step, or fetch and lock it in two steps
|
|
|
-
|
|
|
- # model_type: Type[ModelType]
|
|
|
-
|
|
|
- _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = [] # record all the pids of Actors spawned by this class
|
|
|
+ def __init_subclass__(cls) -> None:
|
|
|
+ """
|
|
|
+ Executed at class definition time (i.e. during import of any file containing class MyActor(ActorType[MyModel]): ...).
|
|
|
+ Loads the django Model from the Generic[ModelType] TypeVar arg and populates any missing class-level config using it.
|
|
|
+ """
|
|
|
+ if getattr(cls, 'Model', None) is None:
|
|
|
+ cls.Model = cls._get_model_from_generic_typevar()
|
|
|
+ cls._populate_missing_classvars_from_model(cls.Model)
|
|
|
|
|
|
def __init__(self, mode: Literal['thread', 'process']|None=None, **launch_kwargs: LaunchKwargs):
|
|
|
+ """
|
|
|
+ Executed right before the Actor is spawned to create a unique Actor instance for that thread/process.
|
|
|
+ actor_instance.runloop() is then executed from inside the newly spawned thread/process.
|
|
|
+ """
|
|
|
self.mode = mode or self.mode
|
|
|
self.launch_kwargs = launch_kwargs or dict(self.launch_kwargs)
|
|
|
|
|
|
+
|
|
|
+ ### Private Helper Methods: Not desiged to be overridden by subclasses or called by anything outside of this class
|
|
|
+
|
|
|
@classproperty
|
|
|
def name(cls) -> str:
|
|
|
return cls.__name__ # type: ignore
|
|
|
|
|
|
def __str__(self) -> str:
|
|
|
- return self.__repr__()
|
|
|
+ return repr(self)
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
- """FaviconActor[pid=1234]"""
|
|
|
+ """-> FaviconActor[pid=1234]"""
|
|
|
label = 'pid' if self.mode == 'process' else 'tid'
|
|
|
return f'[underline]{self.name}[/underline]\\[{label}={self.pid}]'
|
|
|
|
|
|
+ @staticmethod
|
|
|
+ def _state_to_str(state: ObjectState) -> str:
|
|
|
+ """Convert a statemachine.State, models.TextChoices.choices value, or Enum value to a str"""
|
|
|
+ return str(state.value) if isinstance(state, State) else str(state)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _sql_for_select_top_n_candidates(qs: QuerySet, claim_from_top_n: int=CLAIM_FROM_TOP_N) -> tuple[str, tuple[Any, ...]]:
|
|
|
+ """Get the SQL for selecting the top N candidates from the queue (to claim one from)"""
|
|
|
+ queryset = qs.only('id')[:claim_from_top_n]
|
|
|
+ select_sql, select_params = compile_sql_select(queryset)
|
|
|
+ return select_sql, select_params
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _sql_for_update_claimed_obj(qs: QuerySet, update_kwargs: dict[str, Any]) -> tuple[str, tuple[Any, ...]]:
|
|
|
+ """Get the SQL for updating a claimed object to mark it as ACTIVE"""
|
|
|
+ # qs.update(status='started', retry_at=<now + MAX_TICK_TIME>)
|
|
|
+ update_sql, update_params = compile_sql_update(qs, update_kwargs=update_kwargs)
|
|
|
+ # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN ('succeeded', 'failed', 'sealed', 'started') AND retry_at <= '2024-11-04 10:14:33.240903'
|
|
|
+ return update_sql, update_params
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_model_from_generic_typevar(cls) -> Type[ModelType]:
|
|
|
+ """Get the django Model from the Generic[ModelType] TypeVar arg (and check that it inherits from django.db.models.Model)"""
|
|
|
+ # cls.__orig_bases__ is non-standard and may be removed in the future! if this breaks,
|
|
|
+ # we can just require the inerited class to define the Model as a classvar manually, e.g.:
|
|
|
+ # class SnapshotActor(ActorType[Snapshot]):
|
|
|
+ # Model: ClassVar[Type[Snapshot]] = Snapshot
|
|
|
+ # https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance
|
|
|
+ Model = get_args(cls.__orig_bases__[0])[0] # type: ignore
|
|
|
+ assert issubclass(Model, DjangoModel), f'{cls.__name__}.Model must be a valid django Model'
|
|
|
+ return cast(Type[ModelType], Model)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _get_state_machine_cls(Model: Type[ModelType]) -> Type[StateMachine]:
|
|
|
+ """Get the StateMachine class for the given django Model that inherits from MachineMixin"""
|
|
|
+ assert issubclass(Model, MachineMixin), f'{Model.__name__} must inherit from MachineMixin and define a .state_machine_name: str'
|
|
|
+ model_state_machine_name = getattr(Model, 'state_machine_name', None)
|
|
|
+ if model_state_machine_name:
|
|
|
+ StateMachine = registry.get_machine_cls(model_state_machine_name)
|
|
|
+ assert issubclass(StateMachine, StateMachine)
|
|
|
+ return StateMachine
|
|
|
+ raise NotImplementedError(f'ActorType[{Model.__name__}] must define .state_machine_name: str that points to a valid StateMachine')
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_state_machine_instance(cls, obj: ModelType) -> StateMachine:
|
|
|
+ """Get the StateMachine instance for the given django Model instance (and check that it is a valid instance of cls.StateMachineClass)"""
|
|
|
+ obj_statemachine = None
|
|
|
+ state_machine_attr = getattr(obj, 'state_machine_attr', 'sm')
|
|
|
+ try:
|
|
|
+ obj_statemachine = getattr(obj, state_machine_attr)
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ if not isinstance(obj_statemachine, cls.StateMachineClass):
|
|
|
+ raise Exception(f'{cls.__name__}: Failed to find a valid StateMachine instance at {type(obj).__name__}.{state_machine_attr}')
|
|
|
+
|
|
|
+ return obj_statemachine
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _populate_missing_classvars_from_model(cls, Model: Type[ModelType]):
|
|
|
+ """Check that the class variables are set correctly based on the ModelType"""
|
|
|
+
|
|
|
+ # check that Model is the same as the Generic[ModelType] parameter in the class definition
|
|
|
+ cls.Model = getattr(cls, 'Model', None) or Model
|
|
|
+ if cls.Model != Model:
|
|
|
+ raise ValueError(f'{cls.__name__}.Model must be set to the same Model as the Generic[ModelType] parameter in the class definition')
|
|
|
+
|
|
|
+ # check that Model has a valid StateMachine with the required event defined on it
|
|
|
+ cls.StateMachineClass = getattr(cls, 'StateMachineClass', None) or cls._get_state_machine_cls(cls.Model)
|
|
|
+ assert isinstance(cls.EVENT_NAME, str), f'{cls.__name__}.EVENT_NAME must be a str, got: {type(cls.EVENT_NAME).__name__} instead'
|
|
|
+ assert hasattr(cls.StateMachineClass, cls.EVENT_NAME), f'StateMachine {cls.StateMachineClass.__name__} must define a {cls.EVENT_NAME} event ({cls.__name__}.EVENT_NAME = {cls.EVENT_NAME})'
|
|
|
+
|
|
|
+ # check that Model uses .id as its primary key field
|
|
|
+ primary_key_field = cls.Model._meta.pk.name
|
|
|
+ if primary_key_field != 'id':
|
|
|
+ raise NotImplementedError(f'Actors currently only support models that use .id as their primary key field ({cls.__name__} uses {cls.__name__}.{primary_key_field} as primary key)')
|
|
|
+
|
|
|
+ # check for STATE_FIELD_NAME classvar or set it from the model's state_field_name attr
|
|
|
+ if not getattr(cls, 'STATE_FIELD_NAME', None):
|
|
|
+ if hasattr(cls.Model, 'state_field_name'):
|
|
|
+ cls.STATE_FIELD_NAME = getattr(cls.Model, 'state_field_name')
|
|
|
+ else:
|
|
|
+ raise NotImplementedError(f'{cls.__name__} must define a STATE_FIELD_NAME: ClassVar[str] (e.g. "status") or have a .state_field_name attr on its Model')
|
|
|
+ assert isinstance(cls.STATE_FIELD_NAME, str), f'{cls.__name__}.STATE_FIELD_NAME must be a str, got: {type(cls.STATE_FIELD_NAME).__name__} instead'
|
|
|
+
|
|
|
+ # check for FINAL_STATES classvar or set it from the model's final_states attr
|
|
|
+ if not getattr(cls, 'FINAL_STATES', None):
|
|
|
+ cls.FINAL_STATES = cls.StateMachineClass.final_states
|
|
|
+ if not cls.FINAL_STATES:
|
|
|
+ raise NotImplementedError(f'{cls.__name__} must define a non-empty FINAL_STATES: ClassVar[list[str]] (e.g. ["sealed"]) or have a {cls.Model.__name__}.state_machine_name pointing to a StateMachine that provides .final_states')
|
|
|
+ cls.FINAL_STATES = [cls._state_to_str(state) for state in cls.FINAL_STATES]
|
|
|
+ assert all(isinstance(state, str) for state in cls.FINAL_STATES), f'{cls.__name__}.FINAL_STATES must be a list[str], got: {type(cls.FINAL_STATES).__name__} instead'
|
|
|
+
|
|
|
+ # check for ACTIVE_STATE classvar or set it from the model's active_state attr
|
|
|
+ if not getattr(cls, 'ACTIVE_STATE', None):
|
|
|
+ raise NotImplementedError(f'{cls.__name__} must define an ACTIVE_STATE: ClassVar[State] (e.g. SnapshotMachine.started) ({cls.Model.__name__}.{cls.STATE_FIELD_NAME} gets set to this value to mark objects as actively processing)')
|
|
|
+ assert isinstance(cls.ACTIVE_STATE, (State, str)), f'{cls.__name__}.ACTIVE_STATE must be a statemachine.State | str, got: {type(cls.ACTIVE_STATE).__name__} instead'
|
|
|
+
|
|
|
+ # check the other ClassVar attributes for valid values
|
|
|
+ assert cls.CLAIM_ORDER and isinstance(cls.CLAIM_ORDER, tuple) and all(isinstance(order, str) for order in cls.CLAIM_ORDER), f'{cls.__name__}.CLAIM_ORDER must be a non-empty tuple[str, ...], got: {type(cls.CLAIM_ORDER).__name__} instead'
|
|
|
+ assert cls.CLAIM_FROM_TOP_N > 0, f'{cls.__name__}.CLAIM_FROM_TOP_N must be a positive int, got: {cls.CLAIM_FROM_TOP_N} instead'
|
|
|
+ assert cls.MAX_TICK_TIME >= 1, f'{cls.__name__}.MAX_TICK_TIME must be a positive int > 1, got: {cls.MAX_TICK_TIME} instead'
|
|
|
+ assert cls.MAX_CONCURRENT_ACTORS >= 1, f'{cls.__name__}.MAX_CONCURRENT_ACTORS must be a positive int >=1, got: {cls.MAX_CONCURRENT_ACTORS} instead'
|
|
|
+ assert isinstance(cls.CLAIM_ATOMIC, bool), f'{cls.__name__}.CLAIM_ATOMIC must be a bool, got: {cls.CLAIM_ATOMIC} instead'
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
|
|
|
+ """Spawn a new background thread running the actor's runloop"""
|
|
|
+ actor = cls(mode='thread', **launch_kwargs)
|
|
|
+ bg_actor_thread = Thread(target=actor.runloop)
|
|
|
+ bg_actor_thread.start()
|
|
|
+ assert bg_actor_thread.native_id is not None
|
|
|
+ return bg_actor_thread.native_id
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
|
|
|
+ """Spawn a new background process running the actor's runloop"""
|
|
|
+ actor = cls(mode='process', **launch_kwargs)
|
|
|
+ bg_actor_process = Process(target=actor.runloop)
|
|
|
+ bg_actor_process.start()
|
|
|
+ assert bg_actor_process.pid is not None
|
|
|
+ cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
|
|
|
+ return bg_actor_process.pid
|
|
|
+
|
|
|
+
|
|
|
### Class Methods: Called by Orchestrator on ActorType class before it has been spawned
|
|
|
|
|
|
@classmethod
|
|
|
@@ -94,71 +256,92 @@ class ActorType(ABC, Generic[ModelType]):
|
|
|
if not queue_length: # queue is empty, spawn 0 actors
|
|
|
return []
|
|
|
|
|
|
- actors_to_spawn: list[LaunchKwargs] = []
|
|
|
- max_spawnable = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
|
|
|
+ # WARNING:
|
|
|
+ # spawning new actors processes is slow/expensive, avoid spawning many actors at once in a single orchestrator tick.
|
|
|
+ # limit to spawning 1 or 2 at a time per orchestrator tick, and let the next tick handle starting another couple.
|
|
|
+ # DONT DO THIS:
|
|
|
+ # if queue_length > 20: # queue is extremely long, spawn maximum actors at once!
|
|
|
+ # num_to_spawn_this_tick = cls.MAX_CONCURRENT_ACTORS
|
|
|
|
|
|
- # spawning new actors is expensive, avoid spawning all the actors at once. To stagger them,
|
|
|
- # let the next orchestrator tick handle starting another 2 on the next tick()
|
|
|
- # if queue_length > 10: # queue is long, spawn as many as possible
|
|
|
- # actors_to_spawn += max_spawnable * [{}]
|
|
|
+ if queue_length > 10:
|
|
|
+ num_to_spawn_this_tick = 2 # spawn more actors per tick if queue is long
|
|
|
+ else:
|
|
|
+ num_to_spawn_this_tick = 1 # spawn fewer actors per tick if queue is short
|
|
|
+
|
|
|
+ num_remaining = cls.MAX_CONCURRENT_ACTORS - len(running_actors)
|
|
|
+ num_to_spawn_now: int = limit(num_to_spawn_this_tick, num_remaining)
|
|
|
|
|
|
- if queue_length > 4: # queue is medium, spawn 1 or 2 actors
|
|
|
- actors_to_spawn += min(2, max_spawnable) * [{**cls.launch_kwargs}]
|
|
|
- else: # queue is short, spawn 1 actor
|
|
|
- actors_to_spawn += min(1, max_spawnable) * [{**cls.launch_kwargs}]
|
|
|
- return actors_to_spawn
|
|
|
+ actors_launch_kwargs: list[LaunchKwargs] = num_to_spawn_now * [{**cls.launch_kwargs}]
|
|
|
+ return actors_launch_kwargs
|
|
|
|
|
|
@classmethod
|
|
|
def start(cls, mode: Literal['thread', 'process']='process', **launch_kwargs: LaunchKwargs) -> int:
|
|
|
if mode == 'thread':
|
|
|
- return cls.fork_actor_as_thread(**launch_kwargs)
|
|
|
+ return cls._fork_actor_as_thread(**launch_kwargs)
|
|
|
elif mode == 'process':
|
|
|
- return cls.fork_actor_as_process(**launch_kwargs)
|
|
|
+ return cls._fork_actor_as_process(**launch_kwargs)
|
|
|
raise ValueError(f'Invalid actor mode: {mode} must be "thread" or "process"')
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def fork_actor_as_thread(cls, **launch_kwargs: LaunchKwargs) -> int:
|
|
|
- """Spawn a new background thread running the actor's runloop"""
|
|
|
- actor = cls(mode='thread', **launch_kwargs)
|
|
|
- bg_actor_thread = Thread(target=actor.runloop)
|
|
|
- bg_actor_thread.start()
|
|
|
- assert bg_actor_thread.native_id is not None
|
|
|
- return bg_actor_thread.native_id
|
|
|
|
|
|
- @classmethod
|
|
|
- def fork_actor_as_process(cls, **launch_kwargs: LaunchKwargs) -> int:
|
|
|
- """Spawn a new background process running the actor's runloop"""
|
|
|
- actor = cls(mode='process', **launch_kwargs)
|
|
|
- bg_actor_process = Process(target=actor.runloop)
|
|
|
- bg_actor_process.start()
|
|
|
- assert bg_actor_process.pid is not None
|
|
|
- cls._SPAWNED_ACTOR_PIDS.append(psutil.Process(pid=bg_actor_process.pid))
|
|
|
- return bg_actor_process.pid
|
|
|
+ @classproperty
|
|
|
+ def qs(cls) -> QuerySet[ModelType]:
|
|
|
+ """
|
|
|
+ Get the unfiltered and unsorted QuerySet of all objects that this Actor might care about.
|
|
|
+ Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
|
|
|
+ (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
|
|
|
+ """
|
|
|
+ return cls.Model.objects.all()
|
|
|
|
|
|
- @classmethod
|
|
|
- def get_model(cls) -> Type[ModelType]:
|
|
|
- # wish this was a @classproperty but Generic[ModelType] return type cant be statically inferred for @classproperty
|
|
|
- return cls.QUERYSET.model
|
|
|
+ @classproperty
|
|
|
+ def final_q(cls) -> Q:
|
|
|
+ """Get the filter for objects that are in a final state"""
|
|
|
+ return Q(**{f'{cls.STATE_FIELD_NAME}__in': [cls._state_to_str(s) for s in cls.FINAL_STATES]})
|
|
|
|
|
|
- @classmethod
|
|
|
- def get_queue(cls) -> QuerySet:
|
|
|
- """override this to provide your queryset as the queue"""
|
|
|
- # return ArchiveResult.objects.filter(status='queued', extractor__in=('pdf', 'dom', 'screenshot'))
|
|
|
- return cls.QUERYSET
|
|
|
+ @classproperty
|
|
|
+ def active_q(cls) -> Q:
|
|
|
+ """Get the filter for objects that are actively processing right now"""
|
|
|
+ return Q(**{cls.STATE_FIELD_NAME: cls._state_to_str(cls.ACTIVE_STATE)}) # e.g. Q(status='started')
|
|
|
|
|
|
- ### Instance Methods: Called by Actor after it has been spawned (i.e. forked as a thread or process)
|
|
|
+ @classproperty
|
|
|
+ def stalled_q(cls) -> Q:
|
|
|
+ """Get the filter for objects that are marked active but have timed out"""
|
|
|
+ return cls.active_q & Q(retry_at__lte=timezone.now()) # e.g. Q(status='started') AND Q(<retry_at is in the past>)
|
|
|
+
|
|
|
+ @classproperty
|
|
|
+ def future_q(cls) -> Q:
|
|
|
+ """Get the filter for objects that have a retry_at in the future"""
|
|
|
+ return Q(retry_at__gt=timezone.now())
|
|
|
+
|
|
|
+ @classproperty
|
|
|
+ def pending_q(cls) -> Q:
|
|
|
+ """Get the filter for objects that are ready for processing."""
|
|
|
+ return ~(cls.active_q) & ~(cls.final_q) & ~(cls.future_q)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_queue(cls, sort: bool=True) -> QuerySet[ModelType]:
|
|
|
+ """
|
|
|
+ Get the sorted and filtered QuerySet of objects that are ready for processing.
|
|
|
+ e.g. qs.exclude(status__in=('sealed', 'started'), retry_at__gt=timezone.now()).order_by('retry_at')
|
|
|
+ """
|
|
|
+ unsorted_qs = cls.qs.filter(cls.pending_q)
|
|
|
+ return unsorted_qs.order_by(*cls.CLAIM_ORDER) if sort else unsorted_qs
|
|
|
+
|
|
|
+ ### Instance Methods: Only called from within Actor instance after it has been spawned (i.e. forked as a thread or process)
|
|
|
|
|
|
def runloop(self):
|
|
|
"""The main runloop that starts running when the actor is spawned (as subprocess or thread) and exits when the queue is empty"""
|
|
|
self.on_startup()
|
|
|
+ obj_to_process: ModelType | None = None
|
|
|
+ last_error: BaseException | None = None
|
|
|
try:
|
|
|
while True:
|
|
|
- obj_to_process: ModelType | None = None
|
|
|
+ # Get the next object to process from the queue
|
|
|
try:
|
|
|
obj_to_process = cast(ModelType, self.get_next(atomic=self.atomic))
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
+ except (ActorQueueIsEmpty, ActorObjectAlreadyClaimed) as err:
|
|
|
+ last_error = err
|
|
|
+ obj_to_process = None
|
|
|
|
|
|
+ # Handle the case where there is no next object to process
|
|
|
if obj_to_process:
|
|
|
self.idle_count = 0 # reset idle count if we got an object
|
|
|
else:
|
|
|
@@ -170,119 +353,127 @@ class ActorType(ABC, Generic[ModelType]):
|
|
|
time.sleep(1)
|
|
|
continue
|
|
|
|
|
|
+ # Process the object by triggering its StateMachine.tick() method
|
|
|
self.on_tick_start(obj_to_process)
|
|
|
-
|
|
|
- # Process the object
|
|
|
try:
|
|
|
self.tick(obj_to_process)
|
|
|
except Exception as err:
|
|
|
- print(f'[red]🏃♂️ ERROR: {self}.tick()[/red]', err)
|
|
|
+ last_error = err
|
|
|
+ # print(f'[red]🏃♂️ {self}.tick()[/red] {obj_to_process} ERROR: [red]{type(err).__name__}: {err}[/red]')
|
|
|
db.connections.close_all() # always reset the db connection after an exception to clear any pending transactions
|
|
|
self.on_tick_exception(obj_to_process, err)
|
|
|
finally:
|
|
|
self.on_tick_end(obj_to_process)
|
|
|
-
|
|
|
- self.on_shutdown(err=None)
|
|
|
+
|
|
|
except BaseException as err:
|
|
|
+ last_error = err
|
|
|
if isinstance(err, KeyboardInterrupt):
|
|
|
print()
|
|
|
else:
|
|
|
- print(f'\n[red]🏃♂️ {self}.runloop() FATAL:[/red]', err.__class__.__name__, err)
|
|
|
- self.on_shutdown(err=err)
|
|
|
+ print(f'\n[red]🏃♂️ {self}.runloop() FATAL:[/red] {type(err).__name__}: {err}')
|
|
|
+ print(f' Last processed object: {obj_to_process}')
|
|
|
+ raise
|
|
|
+ finally:
|
|
|
+ self.on_shutdown(last_obj=obj_to_process, last_error=last_error)
|
|
|
+
|
|
|
+ def get_update_kwargs_to_claim_obj(self) -> dict[str, Any]:
|
|
|
+ """
|
|
|
+ Get the field values needed to mark an pending obj_to_process as being actively processing (aka claimed)
|
|
|
+ by the current Actor. returned kwargs will be applied using: qs.filter(id=obj_to_process.id).update(**kwargs).
|
|
|
+ F() expressions are allowed in field values if you need to update a field based on its current value.
|
|
|
+ Can be a defined as a normal method (instead of classmethod) on subclasses if it needs to access instance vars.
|
|
|
+ """
|
|
|
+ return {
|
|
|
+ self.STATE_FIELD_NAME: self.ACTIVE_STATE,
|
|
|
+ 'retry_at': timezone.now() + timedelta(seconds=self.MAX_TICK_TIME),
|
|
|
+ }
|
|
|
|
|
|
def get_next(self, atomic: bool | None=None) -> ModelType | None:
|
|
|
"""get the next object from the queue, atomically locking it if self.atomic=True"""
|
|
|
- if atomic is None:
|
|
|
- atomic = self.ATOMIC
|
|
|
-
|
|
|
+ atomic = self.CLAIM_ATOMIC if atomic is None else atomic
|
|
|
if atomic:
|
|
|
# fetch and claim the next object from in the queue in one go atomically
|
|
|
obj = self.get_next_atomic()
|
|
|
else:
|
|
|
# two-step claim: fetch the next object and lock it in a separate query
|
|
|
- obj = self.get_queue().last()
|
|
|
- assert obj and self.lock_next(obj), f'Unable to fetch+lock the next {self.get_model().__name__} ojbect from {self}.QUEUE'
|
|
|
+ obj = self.get_next_non_atomic()
|
|
|
return obj
|
|
|
|
|
|
- def lock_next(self, obj: ModelType) -> bool:
|
|
|
- """override this to implement a custom two-step (non-atomic)lock mechanism"""
|
|
|
- # For example:
|
|
|
- # assert obj._model.objects.filter(pk=obj.pk, status='queued').update(status='started', locked_by=self.pid)
|
|
|
- # Not needed if using get_next_and_lock() to claim the object atomically
|
|
|
- # print(f'[blue]🏃♂️ {self}.lock()[/blue]', obj.abid or obj.id)
|
|
|
- return True
|
|
|
-
|
|
|
- def claim_sql_where(self) -> str:
|
|
|
- """override this to implement a custom WHERE clause for the atomic claim step e.g. "status = 'queued' AND locked_by = NULL" """
|
|
|
- return self.CLAIM_WHERE
|
|
|
-
|
|
|
- def claim_sql_set(self) -> str:
|
|
|
- """override this to implement a custom SET clause for the atomic claim step e.g. "status = 'started' AND locked_by = {self.pid}" """
|
|
|
- return self.CLAIM_SET
|
|
|
-
|
|
|
- def claim_sql_order(self) -> str:
|
|
|
- """override this to implement a custom ORDER BY clause for the atomic claim step e.g. "created_at DESC" """
|
|
|
- return self.CLAIM_ORDER
|
|
|
-
|
|
|
- def claim_from_top(self) -> int:
|
|
|
- """override this to implement a custom number of objects to consider when atomically claiming the next object from the top of the queue"""
|
|
|
- return self.CLAIM_FROM_TOP
|
|
|
-
|
|
|
- def get_next_atomic(self, shallow: bool=True) -> ModelType | None:
|
|
|
+ def get_next_non_atomic(self) -> ModelType:
|
|
|
"""
|
|
|
- claim a random object from the top n=50 objects in the queue (atomically updates status=queued->started for claimed object)
|
|
|
- optimized for minimizing contention on the queue with other actors selecting from the same list
|
|
|
- slightly faster than claim_any_obj() which selects randomly from the entire queue but needs to know the total count
|
|
|
+ Naiively selects the top/first object from self.get_queue().order_by(*self.CLAIM_ORDER),
|
|
|
+ then claims it by running .update(status='started', retry_at=<now + MAX_TICK_TIME>).
|
|
|
+
|
|
|
+ Do not use this method if there is more than one Actor racing to get objects from the same queue,
|
|
|
+ it will be slow/buggy as they'll compete to lock the same object at the same time (TOCTTOU race).
|
|
|
"""
|
|
|
- Model = self.get_model() # e.g. ArchiveResult
|
|
|
- table = f'{Model._meta.app_label}_{Model._meta.model_name}' # e.g. core_archiveresult
|
|
|
+ obj = self.get_queue().first()
|
|
|
+ if obj is None:
|
|
|
+ raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
|
|
|
|
|
|
- where_sql = self.claim_sql_where()
|
|
|
- set_sql = self.claim_sql_set()
|
|
|
- order_by_sql = self.claim_sql_order()
|
|
|
- choose_from_top = self.claim_from_top()
|
|
|
+ locked = self.get_queue().filter(id=obj.id).update(**self.get_update_kwargs_to_claim_obj())
|
|
|
+ if not locked:
|
|
|
+ raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
|
|
|
+ return obj
|
|
|
|
|
|
- with db.connection.cursor() as cursor:
|
|
|
- # subquery gets the pool of the top 50 candidates sorted by sort and order
|
|
|
- # main query selects a random one from that pool
|
|
|
- cursor.execute(f"""
|
|
|
- UPDATE {table}
|
|
|
- SET {set_sql}
|
|
|
- WHERE {where_sql} and id = (
|
|
|
- SELECT id FROM (
|
|
|
- SELECT id FROM {table}
|
|
|
- WHERE {where_sql}
|
|
|
- ORDER BY {order_by_sql}
|
|
|
- LIMIT {choose_from_top}
|
|
|
- ) candidates
|
|
|
- ORDER BY RANDOM()
|
|
|
- LIMIT 1
|
|
|
- )
|
|
|
- RETURNING id;
|
|
|
- """)
|
|
|
- result = cursor.fetchone()
|
|
|
-
|
|
|
- if result is None:
|
|
|
- return None # If no rows were claimed, return None
|
|
|
-
|
|
|
- if shallow:
|
|
|
- # shallow: faster, returns potentially incomplete object instance missing some django auto-populated fields:
|
|
|
- columns = [col[0] for col in cursor.description or ['id']]
|
|
|
- return Model(**dict(zip(columns, result)))
|
|
|
+ def get_next_atomic(self) -> ModelType | None:
|
|
|
+ """
|
|
|
+ Selects the top n=50 objects from the queue and atomically claims a random one from that set.
|
|
|
+ This approach safely minimizes contention with other Actors trying to select from the same Queue.
|
|
|
|
|
|
- # if not shallow do one extra query to get a more complete object instance (load it fully from scratch)
|
|
|
- return Model.objects.get(id=result[0])
|
|
|
+ The atomic query is roughly equivalent to the following: (all done in one SQL query to avoid a TOCTTOU race)
|
|
|
+ top_candidates are selected from: qs.order_by(*CLAIM_ORDER).only('id')[:CLAIM_FROM_TOP_N]
|
|
|
+ a single candidate is chosen using: qs.filter(id__in=top_n_candidates).order_by('?').first()
|
|
|
+ the chosen obj is claimed using: qs.filter(id=chosen_obj).update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
|
|
|
+ """
|
|
|
+ # TODO: if we switch from SQLite to PostgreSQL in the future, we should change this
|
|
|
+ # to use SELECT FOR UPDATE instead of a subquery + ORDER BY RANDOM() LIMIT 1
|
|
|
+
|
|
|
+ # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (...) AND retry_at <= '...' ORDER BY retry_at ASC LIMIT 50
|
|
|
+ qs = self.get_queue()
|
|
|
+ select_top_canidates_sql, select_params = self._sql_for_select_top_n_candidates(qs=qs)
|
|
|
+ assert select_top_canidates_sql.startswith('SELECT ')
|
|
|
+
|
|
|
+ # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
|
|
|
+ update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=qs, update_kwargs=self.get_update_kwargs_to_claim_obj())
|
|
|
+ assert update_claimed_obj_sql.startswith('UPDATE ')
|
|
|
+ db_table = self.Model._meta.db_table # e.g. core_archiveresult
|
|
|
+
|
|
|
+ # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
|
|
|
+ # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
|
|
|
+ # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
|
|
|
+ atomic_select_and_update_sql = f"""
|
|
|
+ {update_claimed_obj_sql} AND "{db_table}"."id" = (
|
|
|
+ SELECT "{db_table}"."id" FROM (
|
|
|
+ {select_top_canidates_sql}
|
|
|
+ ) candidates
|
|
|
+ ORDER BY RANDOM()
|
|
|
+ LIMIT 1
|
|
|
+ )
|
|
|
+ RETURNING *;
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ return self.Model.objects.raw(atomic_select_and_update_sql, (*update_params, *select_params))[0]
|
|
|
+ except KeyError:
|
|
|
+ if self.get_queue().exists():
|
|
|
+ raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
|
|
|
+ else:
|
|
|
+ raise ActorQueueIsEmpty(f'No next object available in {self}.get_queue()')
|
|
|
|
|
|
- @abstractmethod
|
|
|
- def tick(self, obj: ModelType) -> None:
|
|
|
- """override this to process the object"""
|
|
|
- print(f'[blue]🏃♂️ {self}.tick()[/blue]', obj.abid or obj.id)
|
|
|
- # For example:
|
|
|
- # do_some_task(obj)
|
|
|
- # do_something_else(obj)
|
|
|
- # obj._model.objects.filter(pk=obj.pk, status='started').update(status='success')
|
|
|
- raise NotImplementedError('tick() must be implemented by the Actor subclass')
|
|
|
-
|
|
|
+ def tick(self, obj_to_process: ModelType) -> None:
|
|
|
+ """Call the object.sm.tick() method to process the object"""
|
|
|
+ print(f'[blue]🏃♂️ {self}.tick()[/blue] {obj_to_process}')
|
|
|
+
|
|
|
+ # get the StateMachine instance from the object
|
|
|
+ obj_statemachine = self._get_state_machine_instance(obj_to_process)
|
|
|
+
|
|
|
+ # trigger the event on the StateMachine instance
|
|
|
+ obj_tick_method = getattr(obj_statemachine, self.EVENT_NAME) # e.g. obj_statemachine.tick()
|
|
|
+ obj_tick_method()
|
|
|
+
|
|
|
+ # save the object to persist any state changes
|
|
|
+ obj_to_process.save()
|
|
|
+
|
|
|
def on_startup(self) -> None:
|
|
|
if self.mode == 'thread':
|
|
|
self.pid = get_native_id() # thread id
|
|
|
@@ -290,24 +481,91 @@ class ActorType(ABC, Generic[ModelType]):
|
|
|
else:
|
|
|
self.pid = os.getpid() # process id
|
|
|
print(f'[green]🏃♂️ {self}.on_startup() STARTUP (PROCESS)[/green]')
|
|
|
- # abx.pm.hook.on_actor_startup(self)
|
|
|
+ # abx.pm.hook.on_actor_startup(actor=self)
|
|
|
|
|
|
- def on_shutdown(self, err: BaseException | None=None) -> None:
|
|
|
- print(f'[grey53]🏃♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53]', err or '[green](gracefully)[/green]')
|
|
|
- # abx.pm.hook.on_actor_shutdown(self)
|
|
|
+ def on_shutdown(self, last_obj: ModelType | None=None, last_error: BaseException | None=None) -> None:
|
|
|
+ if isinstance(last_error, KeyboardInterrupt) or last_error is None:
|
|
|
+ last_error_str = '[green](CTRL-C)[/green]'
|
|
|
+ elif isinstance(last_error, ActorQueueIsEmpty):
|
|
|
+ last_error_str = '[green](queue empty)[/green]'
|
|
|
+ elif isinstance(last_error, ActorObjectAlreadyClaimed):
|
|
|
+ last_error_str = '[green](queue race)[/green]'
|
|
|
+ else:
|
|
|
+ last_error_str = f'[red]{type(last_error).__name__}: {last_error}[/red]'
|
|
|
+
|
|
|
+ print(f'[grey53]🏃♂️ {self}.on_shutdown() SHUTTING DOWN[/grey53] {last_error_str}')
|
|
|
+ # abx.pm.hook.on_actor_shutdown(actor=self, last_obj=last_obj, last_error=last_error)
|
|
|
|
|
|
- def on_tick_start(self, obj: ModelType) -> None:
|
|
|
- # print(f'🏃♂️ {self}.on_tick_start()', obj.abid or obj.id)
|
|
|
- # abx.pm.hook.on_actor_tick_start(self, obj_to_process)
|
|
|
+ def on_tick_start(self, obj_to_process: ModelType) -> None:
|
|
|
+ print(f'🏃♂️ {self}.on_tick_start() {obj_to_process}')
|
|
|
+ # abx.pm.hook.on_actor_tick_start(actor=self, obj_to_process=obj)
|
|
|
# self.timer = TimedProgress(self.MAX_TICK_TIME, prefix=' ')
|
|
|
- pass
|
|
|
|
|
|
- def on_tick_end(self, obj: ModelType) -> None:
|
|
|
- # print(f'🏃♂️ {self}.on_tick_end()', obj.abid or obj.id)
|
|
|
- # abx.pm.hook.on_actor_tick_end(self, obj_to_process)
|
|
|
+ def on_tick_end(self, obj_to_process: ModelType) -> None:
|
|
|
+ print(f'🏃♂️ {self}.on_tick_end() {obj_to_process}')
|
|
|
+ # abx.pm.hook.on_actor_tick_end(actor=self, obj_to_process=obj_to_process)
|
|
|
# self.timer.end()
|
|
|
- pass
|
|
|
|
|
|
- def on_tick_exception(self, obj: ModelType, err: BaseException) -> None:
|
|
|
- print(f'[red]🏃♂️ {self}.on_tick_exception()[/red]', obj.abid or obj.id, err)
|
|
|
- # abx.pm.hook.on_actor_tick_exception(self, obj_to_process, err)
|
|
|
+ def on_tick_exception(self, obj_to_process: ModelType, error: Exception) -> None:
|
|
|
+ print(f'[red]🏃♂️ {self}.on_tick_exception()[/red] {obj_to_process}: [red]{type(error).__name__}: {error}[/red]')
|
|
|
+ # abx.pm.hook.on_actor_tick_exception(actor=self, obj_to_process=obj_to_process, error=error)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def compile_sql_select(queryset: QuerySet, filter_kwargs: dict[str, Any] | None=None, order_args: tuple[str, ...]=(), limit: int | None=None) -> tuple[str, tuple[Any, ...]]:
|
|
|
+ """
|
|
|
+ Compute the SELECT query SQL for a queryset.filter(**filter_kwargs).order_by(*order_args)[:limit] call
|
|
|
+ Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
|
|
|
+
|
|
|
+ WARNING:
|
|
|
+ final_sql = sql % params DOES NOT WORK to assemble the final SQL string because the %s placeholders are not quoted/escaped
|
|
|
+ they should always passed separately to the DB driver so it can do its own quoting/escaping to avoid SQL injection and syntax errors
|
|
|
+ """
|
|
|
+ assert isinstance(queryset, QuerySet), f'compile_sql_select(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
|
|
|
+ assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_select(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
|
|
|
+ assert isinstance(order_args, tuple) and all(isinstance(arg, str) for arg in order_args), f'compile_sql_select(...) order_args argument must be a tuple[str, ...] got: {type(order_args).__name__} instead'
|
|
|
+ assert limit is None or isinstance(limit, int), f'compile_sql_select(...) limit argument must be an int, got: {type(limit).__name__} instead'
|
|
|
+
|
|
|
+ queryset = queryset._chain() # type: ignore # copy queryset to avoid modifying the original
|
|
|
+ if filter_kwargs:
|
|
|
+ queryset = queryset.filter(**filter_kwargs)
|
|
|
+ if order_args:
|
|
|
+ queryset = queryset.order_by(*order_args)
|
|
|
+ if limit is not None:
|
|
|
+ queryset = queryset[:limit]
|
|
|
+ query = queryset.query
|
|
|
+
|
|
|
+ # e.g. SELECT id FROM core_archiveresult WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s ORDER BY retry_at ASC LIMIT 50
|
|
|
+ select_sql, select_params = query.get_compiler(queryset.db).as_sql()
|
|
|
+ return select_sql, select_params
|
|
|
+
|
|
|
+
|
|
|
+def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter_kwargs: dict[str, Any] | None=None) -> tuple[str, tuple[Any, ...]]:
|
|
|
+ """
|
|
|
+ Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
|
|
|
+ Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
|
|
|
+
|
|
|
+ Based on the django.db.models.QuerySet.update() source code, but modified to return the SQL instead of executing the update
|
|
|
+ https://github.com/django/django/blob/611bf6c2e2a1b4ab93273980c45150c099ab146d/django/db/models/query.py#L1217
|
|
|
+
|
|
|
+ WARNING:
|
|
|
+ final_sql = sql % params DOES NOT WORK to assemble the final SQL string because the %s placeholders are not quoted/escaped
|
|
|
+ they should always passed separately to the DB driver so it can do its own quoting/escaping to avoid SQL injection and syntax errors
|
|
|
+ """
|
|
|
+ assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
|
|
|
+ assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
|
|
|
+ assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_update(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
|
|
|
+
|
|
|
+ queryset = queryset._chain() # type: ignore # copy queryset to avoid modifying the original
|
|
|
+ if filter_kwargs:
|
|
|
+ queryset = queryset.filter(**filter_kwargs)
|
|
|
+ queryset.query.clear_ordering(force=True) # clear any ORDER BY clauses
|
|
|
+ queryset.query.clear_limits() # clear any LIMIT clauses aka slices[:n]
|
|
|
+ queryset._for_write = True # type: ignore
|
|
|
+ query = queryset.query.chain(sql.UpdateQuery) # type: ignore
|
|
|
+ query.add_update_values(update_kwargs) # type: ignore
|
|
|
+ query.annotations = {} # clear any annotations
|
|
|
+
|
|
|
+ # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
|
|
|
+ update_sql, update_params = query.get_compiler(queryset.db).as_sql()
|
|
|
+ return update_sql, update_params
|