|
|
@@ -1,5 +1,5 @@
|
|
|
import codecs
|
|
|
-from typing import List, Optional, Generator
|
|
|
+from typing import List, Generator
|
|
|
import sqlite3
|
|
|
|
|
|
from archivebox.util import enforce_types
|
|
|
@@ -22,7 +22,7 @@ if FTS_SEPARATE_DATABASE:
|
|
|
return database
|
|
|
SQLITE_BIND = "?"
|
|
|
else:
|
|
|
- from django.db import connection as database
|
|
|
+ from django.db import connection as database # type: ignore[no-redef, assignment]
|
|
|
get_connection = database.cursor
|
|
|
SQLITE_BIND = "%s"
|
|
|
|
|
|
@@ -31,7 +31,7 @@ else:
|
|
|
try:
|
|
|
limit_id = sqlite3.SQLITE_LIMIT_LENGTH
|
|
|
try:
|
|
|
- with database.temporary_connection() as cursor:
|
|
|
+ with database.temporary_connection() as cursor: # type: ignore[attr-defined]
|
|
|
SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id)
|
|
|
except AttributeError:
|
|
|
SQLITE_LIMIT_LENGTH = database.getlimit(limit_id)
|
|
|
@@ -51,6 +51,7 @@ def _escape_sqlite3(value: str, *, quote: str, errors='strict') -> str:
|
|
|
nul_index, nul_index + 1, "NUL not allowed")
|
|
|
error_handler = codecs.lookup_error(errors)
|
|
|
replacement, _ = error_handler(error)
|
|
|
+ assert isinstance(replacement, str), "handling a UnicodeEncodeError should return a str replacement"
|
|
|
encodable = encodable.replace("\x00", replacement)
|
|
|
|
|
|
return quote + encodable.replace(quote, quote * 2) + quote
|
|
|
@@ -99,6 +100,16 @@ def _create_tables():
|
|
|
" END;"
|
|
|
)
|
|
|
|
|
|
+def _handle_query_exception(exc: Exception):
|
|
|
+ message = str(exc)
|
|
|
+ if message.startswith("no such table:"):
|
|
|
+ raise RuntimeError(
|
|
|
+ "SQLite full-text search index has not yet"
|
|
|
+ " been created; run `archivebox update --index-only`."
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise exc
|
|
|
+
|
|
|
@enforce_types
|
|
|
def index(snapshot_id: str, texts: List[str]):
|
|
|
text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH]
|
|
|
@@ -145,22 +156,29 @@ def search(text: str) -> List[str]:
|
|
|
id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
|
|
|
|
|
|
with get_connection() as cursor:
|
|
|
- res = cursor.execute(
|
|
|
- f"SELECT snapshot_id FROM {table}"
|
|
|
- f" INNER JOIN {id_table}"
|
|
|
- f" ON {id_table}.rowid = {table}.rowid"
|
|
|
- f" WHERE {table} MATCH {SQLITE_BIND}",
|
|
|
- [text])
|
|
|
+ try:
|
|
|
+ res = cursor.execute(
|
|
|
+ f"SELECT snapshot_id FROM {table}"
|
|
|
+ f" INNER JOIN {id_table}"
|
|
|
+ f" ON {id_table}.rowid = {table}.rowid"
|
|
|
+ f" WHERE {table} MATCH {SQLITE_BIND}",
|
|
|
+ [text])
|
|
|
+ except Exception as e:
|
|
|
+ _handle_query_exception(e)
|
|
|
+
|
|
|
snap_ids = [row[0] for row in res.fetchall()]
|
|
|
return snap_ids
|
|
|
|
|
|
@enforce_types
|
|
|
def flush(snapshot_ids: Generator[str, None, None]):
|
|
|
- snapshot_ids = list(snapshot_ids)
|
|
|
+ snapshot_ids = list(snapshot_ids) # type: ignore[assignment]
|
|
|
|
|
|
id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
|
|
|
|
|
|
with get_connection() as cursor:
|
|
|
- cursor.executemany(
|
|
|
- f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
|
|
|
- [snapshot_ids])
|
|
|
+ try:
|
|
|
+ cursor.executemany(
|
|
|
+ f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
|
|
|
+ [snapshot_ids])
|
|
|
+ except Exception as e:
|
|
|
+ _handle_query_exception(e)
|