Jelajahi Sumber

sqlite search: clean up errors and type-checking

Clean up error handling, and report a better error message
on search and flush if FTS5 tables haven't yet been created.

Add some mypy comments to clean up type-checking errors.
Ross Williams 2 tahun lalu
induk
melakukan
1e604a1352
1 mengubah file dengan 31 tambahan dan 13 penghapusan
  1. 31 13
      archivebox/search/backends/sqlite.py

+ 31 - 13
archivebox/search/backends/sqlite.py

@@ -1,5 +1,5 @@
 import codecs
-from typing import List, Optional, Generator
+from typing import List, Generator
 import sqlite3
 
 from archivebox.util import enforce_types
@@ -22,7 +22,7 @@ if FTS_SEPARATE_DATABASE:
         return database
     SQLITE_BIND = "?"
 else:
-    from django.db import connection as database
+    from django.db import connection as database  # type: ignore[no-redef, assignment]
     get_connection = database.cursor
     SQLITE_BIND = "%s"
 
@@ -31,7 +31,7 @@ else:
 try:
     limit_id = sqlite3.SQLITE_LIMIT_LENGTH
     try:
-        with database.temporary_connection() as cursor:
+        with database.temporary_connection() as cursor:  # type: ignore[attr-defined]
             SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id)
     except AttributeError:
         SQLITE_LIMIT_LENGTH = database.getlimit(limit_id)
@@ -51,6 +51,7 @@ def _escape_sqlite3(value: str, *, quote: str, errors='strict') -> str:
                                    nul_index, nul_index + 1, "NUL not allowed")
         error_handler = codecs.lookup_error(errors)
         replacement, _ = error_handler(error)
+        assert isinstance(replacement, str), "handling a UnicodeEncodeError should return a str replacement"
         encodable = encodable.replace("\x00", replacement)
 
     return quote + encodable.replace(quote, quote * 2) + quote
@@ -99,6 +100,16 @@ def _create_tables():
             " END;"
             )
 
+def _handle_query_exception(exc: Exception):
+    message = str(exc)
+    if message.startswith("no such table:"):
+        raise RuntimeError(
+            "SQLite full-text search index has not yet"
+            " been created; run `archivebox update --index-only`."
+        )
+    else:
+        raise exc
+
 @enforce_types
 def index(snapshot_id: str, texts: List[str]):
     text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH]
@@ -145,22 +156,29 @@ def search(text: str) -> List[str]:
     id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
 
     with get_connection() as cursor:
-        res = cursor.execute(
-            f"SELECT snapshot_id FROM {table}"
-            f" INNER JOIN {id_table}"
-            f" ON {id_table}.rowid = {table}.rowid"
-            f" WHERE {table} MATCH {SQLITE_BIND}",
-            [text])
+        try:
+            res = cursor.execute(
+                f"SELECT snapshot_id FROM {table}"
+                f" INNER JOIN {id_table}"
+                f" ON {id_table}.rowid = {table}.rowid"
+                f" WHERE {table} MATCH {SQLITE_BIND}",
+                [text])
+        except Exception as e:
+            _handle_query_exception(e)
+
         snap_ids = [row[0] for row in res.fetchall()]
     return snap_ids
 
 @enforce_types
 def flush(snapshot_ids: Generator[str, None, None]):
-    snapshot_ids = list(snapshot_ids)
+    snapshot_ids = list(snapshot_ids)  # type: ignore[assignment]
 
     id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
 
     with get_connection() as cursor:
-        cursor.executemany(
-            f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
-            [snapshot_ids])
+        try:
+            cursor.executemany(
+                f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
+                [snapshot_ids])
+        except Exception as e:
+            _handle_query_exception(e)