Browse Source

fix serious bug with Actor.get_next updating all rows instead of only top row

Nick Sweeting 1 year ago
parent
commit
148ea907bd
3 changed files with 26 additions and 16 deletions
  1. 1 0
      archivebox/__init__.py
  2. 23 14
      archivebox/actors/actor.py
  3. 2 2
      archivebox/actors/orchestrator.py

+ 1 - 0
archivebox/__init__.py

@@ -84,6 +84,7 @@ ARCHIVEBOX_BUILTIN_PLUGINS = {
     'config': PACKAGE_DIR / 'config',
     'config': PACKAGE_DIR / 'config',
     'core': PACKAGE_DIR / 'core',
     'core': PACKAGE_DIR / 'core',
     'crawls': PACKAGE_DIR / 'crawls',
     'crawls': PACKAGE_DIR / 'crawls',
+    'queues': PACKAGE_DIR / 'queues',
     'seeds': PACKAGE_DIR / 'seeds',
     'seeds': PACKAGE_DIR / 'seeds',
     'actors': PACKAGE_DIR / 'actors',
     'actors': PACKAGE_DIR / 'actors',
     # 'search': PACKAGE_DIR / 'search',
     # 'search': PACKAGE_DIR / 'search',

+ 23 - 14
archivebox/actors/actor.py

@@ -75,7 +75,7 @@ class ActorType(Generic[ModelType]):
     _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []      # used to record all the pids of Actors spawned on the class
     _SPAWNED_ACTOR_PIDS: ClassVar[list[psutil.Process]] = []      # used to record all the pids of Actors spawned on the class
     
     
     ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
     ### Instance attributes (only used within an actor instance inside a spawned actor thread/process)
-    pid: int
+    pid: int = os.getpid()
     idle_count: int = 0
     idle_count: int = 0
     launch_kwargs: LaunchKwargs = {}
     launch_kwargs: LaunchKwargs = {}
     mode: Literal['thread', 'process'] = 'process'
     mode: Literal['thread', 'process'] = 'process'
@@ -290,7 +290,7 @@ class ActorType(Generic[ModelType]):
         Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
         Override this in the subclass to define the QuerySet of objects that the Actor is going to poll for new work.
         (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
         (don't limit, order, or filter this by retry_at or status yet, Actor.get_queue() handles that part)
         """
         """
-        return cls.Model.objects.all()
+        return cls.Model.objects.filter()
     
     
     @classproperty
     @classproperty
     def final_q(cls) -> Q:
     def final_q(cls) -> Q:
@@ -438,25 +438,30 @@ class ActorType(Generic[ModelType]):
         assert select_top_canidates_sql.startswith('SELECT ')
         assert select_top_canidates_sql.startswith('SELECT ')
         
         
         # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
         # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (...) AND retry_at <= '...'
-        update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=qs, update_kwargs=self.get_update_kwargs_to_claim_obj())
-        assert update_claimed_obj_sql.startswith('UPDATE ')
+        update_claimed_obj_sql, update_params = self._sql_for_update_claimed_obj(qs=self.qs.all(), update_kwargs=self.get_update_kwargs_to_claim_obj())
+        assert update_claimed_obj_sql.startswith('UPDATE ') and 'WHERE' not in update_claimed_obj_sql
         db_table = self.Model._meta.db_table  # e.g. core_archiveresult
         db_table = self.Model._meta.db_table  # e.g. core_archiveresult
         
         
         # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
         # subquery gets the pool of the top candidates e.g. self.get_queue().only('id')[:CLAIM_FROM_TOP_N]
         # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
         # main query selects a random one from that pool, and claims it using .update(status=ACTIVE_STATE, retry_at=<now + MAX_TICK_TIME>)
         # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
         # this is all done in one atomic SQL query to avoid TOCTTOU race conditions (as much as possible)
         atomic_select_and_update_sql = f"""
         atomic_select_and_update_sql = f"""
-            {update_claimed_obj_sql} AND "{db_table}"."id" = (
-                SELECT "{db_table}"."id" FROM (
-                    {select_top_canidates_sql}
-                ) candidates
+            with top_candidates AS ({select_top_canidates_sql})
+            {update_claimed_obj_sql}
+            WHERE "{db_table}"."id" IN (
+                SELECT id FROM top_candidates
                 ORDER BY RANDOM()
                 ORDER BY RANDOM()
                 LIMIT 1
                 LIMIT 1
             )
             )
             RETURNING *;
             RETURNING *;
         """
         """
+        
+        # import ipdb; ipdb.set_trace()
+
         try:
         try:
-            return self.Model.objects.raw(atomic_select_and_update_sql, (*update_params, *select_params))[0]
+            updated = qs.raw(atomic_select_and_update_sql, (*select_params, *update_params))
+            assert len(updated) <= 1, f'Expected to claim at most 1 object, but Django modified {len(updated)} objects!'
+            return updated[0]
         except IndexError:
         except IndexError:
             if self.get_queue().exists():
             if self.get_queue().exists():
                 raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
                 raise ActorObjectAlreadyClaimed(f'Unable to lock the next {self.Model.__name__} object from {self}.get_queue().first()')
@@ -548,7 +553,7 @@ def compile_sql_select(queryset: QuerySet, filter_kwargs: dict[str, Any] | None=
     return select_sql, select_params
     return select_sql, select_params
 
 
 
 
-def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter_kwargs: dict[str, Any] | None=None) -> tuple[str, tuple[Any, ...]]:
+def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any]) -> tuple[str, tuple[Any, ...]]:
     """
     """
     Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
     Compute the UPDATE query SQL for a queryset.filter(**filter_kwargs).update(**update_kwargs) call
     Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
     Returns a tuple of (sql, params) where sql is a template string containing %s (unquoted) placeholders for the params
@@ -562,11 +567,8 @@ def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter
     """
     """
     assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
     assert isinstance(queryset, QuerySet), f'compile_sql_update(...) first argument must be a QuerySet, got: {type(queryset).__name__} instead'
     assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
     assert isinstance(update_kwargs, dict), f'compile_sql_update(...) update_kwargs argument must be a dict[str, Any], got: {type(update_kwargs).__name__} instead'
-    assert filter_kwargs is None or isinstance(filter_kwargs, dict), f'compile_sql_update(...) filter_kwargs argument must be a dict[str, Any], got: {type(filter_kwargs).__name__} instead'
     
     
-    queryset = queryset._chain()                      # type: ignore   # copy queryset to avoid modifying the original
-    if filter_kwargs:
-        queryset = queryset.filter(**filter_kwargs)
+    queryset = queryset._chain().all()                # type: ignore   # copy queryset to avoid modifying the original and clear any filters
     queryset.query.clear_ordering(force=True)                          # clear any ORDER BY clauses
     queryset.query.clear_ordering(force=True)                          # clear any ORDER BY clauses
     queryset.query.clear_limits()                                      # clear any LIMIT clauses aka slices[:n]
     queryset.query.clear_limits()                                      # clear any LIMIT clauses aka slices[:n]
     queryset._for_write = True                        # type: ignore
     queryset._for_write = True                        # type: ignore
@@ -576,5 +578,12 @@ def compile_sql_update(queryset: QuerySet, update_kwargs: dict[str, Any], filter
     
     
     # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
     # e.g. UPDATE core_archiveresult SET status='%s', retry_at='%s' WHERE status NOT IN (%s, %s, %s) AND retry_at <= %s
     update_sql, update_params = query.get_compiler(queryset.db).as_sql()
     update_sql, update_params = query.get_compiler(queryset.db).as_sql()
+    
    # make sure you only pass a raw queryset with no .filter(...) clauses applied to it; the return value is designed to be used
    # in a manually assembled SQL query with its own WHERE clause later on
+    assert 'WHERE' not in update_sql, f'compile_sql_update(...) should only contain a SET statement but it tried to return a query with a WHERE clause: {update_sql}'
+    
+    # print(update_sql, update_params)
+
     return update_sql, update_params
     return update_sql, update_params
 
 

+ 2 - 2
archivebox/actors/orchestrator.py

@@ -102,7 +102,7 @@ class Orchestrator:
         # returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types
         # returns a list of objects that are in the queues of all actor types but not in the queues of any other actor types
 
 
         return any(
         return any(
-            queue.filter(retry_at__gt=timezone.now()).exists()
+            queue.filter(retry_at__gte=timezone.now()).exists()
             for queue in all_queues.values()
             for queue in all_queues.values()
         )
         )
     
     
@@ -163,7 +163,7 @@ class Orchestrator:
 
 
                 for actor_type, queue in all_queues.items():
                 for actor_type, queue in all_queues.items():
                     next_obj = queue.first()
                     next_obj = queue.first()
-                    print(f'🏃‍♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.abid if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj else "None"}')
+                    print(f'🏃‍♂️ {self}.runloop() {actor_type.__name__.ljust(20)} queue={str(queue.count()).ljust(3)} next={next_obj.abid if next_obj else "None"} {next_obj.status if next_obj else "None"} {(timezone.now() - next_obj.retry_at).total_seconds() if next_obj and next_obj.retry_at else "None"}')
                     try:
                     try:
                         existing_actors = actor_type.get_running_actors()
                         existing_actors = actor_type.get_running_actors()
                         all_existing_actors.extend(existing_actors)
                         all_existing_actors.extend(existing_actors)