sqlite.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import codecs
  2. from typing import List, Generator
  3. import sqlite3
  4. from archivebox.util import enforce_types
  5. from archivebox.config import (
  6. FTS_SEPARATE_DATABASE,
  7. FTS_TOKENIZERS,
  8. FTS_SQLITE_MAX_LENGTH
  9. )
  10. FTS_TABLE = "snapshot_fts"
  11. FTS_ID_TABLE = "snapshot_id_fts"
  12. FTS_COLUMN = "texts"
# Select the database that holds the full-text index.
# - FTS_SEPARATE_DATABASE set: a dedicated sqlite3 connection to
#   "search.sqlite3" (a path relative to the current working directory) is
#   opened once, at import time.
# - Otherwise: the Django default database connection is reused.
# Both branches define `database`, `get_connection` and `SQLITE_BIND`.
if FTS_SEPARATE_DATABASE:
    # NOTE(review): sqlite3.connect() defaults to check_same_thread=True, so
    # this shared connection may only be used from the importing thread —
    # confirm callers never use it from worker threads.
    database = sqlite3.connect("search.sqlite3")
    # Make get_connection callable, because `django.db.connection.cursor()`
    # has to be called to get a context manager, but sqlite3.Connection
    # is a context manager without being called.
    def get_connection():
        return database
    # Bound-parameter placeholder used by the sqlite3 module.
    SQLITE_BIND = "?"
else:
    from django.db import connection as database # type: ignore[no-redef, assignment]
    # Each call returns a new Django cursor, used as a context manager by callers.
    get_connection = database.cursor
    # Bound-parameter placeholder expected by Django cursors.
    SQLITE_BIND = "%s"
# Only Python >= 3.11 supports sqlite3.Connection.getlimit(),
# so fall back to the default if the API to get the real value isn't present
try:
    # The sqlite3.SQLITE_LIMIT_LENGTH constant is also Python >= 3.11; if it is
    # missing, control jumps straight to the outer fallback below.
    limit_id = sqlite3.SQLITE_LIMIT_LENGTH
    try:
        # Django connection path: read the limit from the DB-API connection
        # behind a temporary cursor.
        # NOTE(review): assumes the Django backend is SQLite; other backends
        # presumably lack .getlimit() and end up at the outer fallback.
        with database.temporary_connection() as cursor: # type: ignore[attr-defined]
            SQLITE_LIMIT_LENGTH = cursor.connection.getlimit(limit_id)
    except AttributeError:
        # Plain sqlite3.Connection path (it has no temporary_connection()).
        # This handler also catches an AttributeError raised inside the `with`
        # above; then database.getlimit is tried on the Django connection and,
        # if that is missing too, the outer handler applies.
        SQLITE_LIMIT_LENGTH = database.getlimit(limit_id)
except AttributeError:
    # Configured fallback.
    # NOTE(review): SQLite documents SQLITE_LIMIT_LENGTH in bytes — confirm
    # FTS_SQLITE_MAX_LENGTH is expressed in the same unit.
    SQLITE_LIMIT_LENGTH = FTS_SQLITE_MAX_LENGTH
  36. def _escape_sqlite3(value: str, *, quote: str, errors='strict') -> str:
  37. assert isinstance(quote, str), "quote is not a str"
  38. assert len(quote) == 1, "quote must be a single character"
  39. encodable = value.encode('utf-8', errors).decode('utf-8')
  40. nul_index = encodable.find("\x00")
  41. if nul_index >= 0:
  42. error = UnicodeEncodeError("NUL-terminated utf-8", encodable,
  43. nul_index, nul_index + 1, "NUL not allowed")
  44. error_handler = codecs.lookup_error(errors)
  45. replacement, _ = error_handler(error)
  46. assert isinstance(replacement, str), "handling a UnicodeEncodeError should return a str replacement"
  47. encodable = encodable.replace("\x00", replacement)
  48. return quote + encodable.replace(quote, quote * 2) + quote
  49. def _escape_sqlite3_value(value: str, errors='strict') -> str:
  50. return _escape_sqlite3(value, quote="'", errors=errors)
  51. def _escape_sqlite3_identifier(value: str) -> str:
  52. return _escape_sqlite3(value, quote='"', errors='strict')
  53. @enforce_types
  54. def _create_tables():
  55. table = _escape_sqlite3_identifier(FTS_TABLE)
  56. # Escape as value, because fts5() expects
  57. # string literal column names
  58. column = _escape_sqlite3_value(FTS_COLUMN)
  59. id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
  60. tokenizers = _escape_sqlite3_value(FTS_TOKENIZERS)
  61. trigger_name = _escape_sqlite3_identifier(f"{FTS_ID_TABLE}_ad")
  62. with get_connection() as cursor:
  63. # Create a contentless-delete FTS5 table that indexes
  64. # but does not store the texts of snapshots
  65. try:
  66. cursor.execute(
  67. f"CREATE VIRTUAL TABLE {table}"
  68. f" USING fts5({column},"
  69. f" tokenize={tokenizers},"
  70. " content='', contentless_delete=1);"
  71. )
  72. except Exception as e:
  73. msg = str(e)
  74. if 'unrecognized option: "contentlessdelete"' in msg:
  75. sqlite_version = getattr(sqlite3, "sqlite_version", "Unknown")
  76. raise RuntimeError(
  77. "SQLite full-text search requires SQLite >= 3.43.0;"
  78. f" the running version is {sqlite_version}"
  79. ) from e
  80. else:
  81. raise
  82. # Create a one-to-one mapping between ArchiveBox snapshot_id
  83. # and FTS5 rowid, because the column type of rowid can't be
  84. # customized.
  85. cursor.execute(
  86. f"CREATE TABLE {id_table}("
  87. " rowid INTEGER PRIMARY KEY AUTOINCREMENT,"
  88. " snapshot_id char(32) NOT NULL UNIQUE"
  89. ");"
  90. )
  91. # Create a trigger to delete items from the FTS5 index when
  92. # the snapshot_id is deleted from the mapping, to maintain
  93. # consistency and make the `flush()` query simpler.
  94. cursor.execute(
  95. f"CREATE TRIGGER {trigger_name}"
  96. f" AFTER DELETE ON {id_table} BEGIN"
  97. f" DELETE FROM {table} WHERE rowid=old.rowid;"
  98. " END;"
  99. )
  100. def _handle_query_exception(exc: Exception):
  101. message = str(exc)
  102. if message.startswith("no such table:"):
  103. raise RuntimeError(
  104. "SQLite full-text search index has not yet"
  105. " been created; run `archivebox update --index-only`."
  106. )
  107. else:
  108. raise exc
  109. @enforce_types
  110. def index(snapshot_id: str, texts: List[str]):
  111. text = ' '.join(texts)[:SQLITE_LIMIT_LENGTH]
  112. table = _escape_sqlite3_identifier(FTS_TABLE)
  113. column = _escape_sqlite3_identifier(FTS_COLUMN)
  114. id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
  115. with get_connection() as cursor:
  116. retries = 2
  117. while retries > 0:
  118. retries -= 1
  119. try:
  120. # If there is already an FTS index rowid to snapshot_id mapping,
  121. # then don't insert a new one, silently ignoring the operation.
  122. # {id_table}.rowid is AUTOINCREMENT, so will generate an unused
  123. # rowid for the index if it is an unindexed snapshot_id.
  124. cursor.execute(
  125. f"INSERT OR IGNORE INTO {id_table}(snapshot_id) VALUES({SQLITE_BIND})",
  126. [snapshot_id])
  127. # Fetch the FTS index rowid for the given snapshot_id
  128. id_res = cursor.execute(
  129. f"SELECT rowid FROM {id_table} WHERE snapshot_id = {SQLITE_BIND}",
  130. [snapshot_id])
  131. rowid = id_res.fetchone()[0]
  132. # (Re-)index the content
  133. cursor.execute(
  134. "INSERT OR REPLACE INTO"
  135. f" {table}(rowid, {column}) VALUES ({SQLITE_BIND}, {SQLITE_BIND})",
  136. [rowid, text])
  137. # All statements succeeded; return
  138. return
  139. except Exception as e:
  140. if str(e).startswith("no such table:") and retries > 0:
  141. _create_tables()
  142. else:
  143. raise
  144. raise RuntimeError("Failed to create tables for SQLite FTS5 search")
  145. @enforce_types
  146. def search(text: str) -> List[str]:
  147. table = _escape_sqlite3_identifier(FTS_TABLE)
  148. id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
  149. with get_connection() as cursor:
  150. try:
  151. res = cursor.execute(
  152. f"SELECT snapshot_id FROM {table}"
  153. f" INNER JOIN {id_table}"
  154. f" ON {id_table}.rowid = {table}.rowid"
  155. f" WHERE {table} MATCH {SQLITE_BIND}",
  156. [text])
  157. except Exception as e:
  158. _handle_query_exception(e)
  159. snap_ids = [row[0] for row in res.fetchall()]
  160. return snap_ids
  161. @enforce_types
  162. def flush(snapshot_ids: Generator[str, None, None]):
  163. snapshot_ids = list(snapshot_ids) # type: ignore[assignment]
  164. id_table = _escape_sqlite3_identifier(FTS_ID_TABLE)
  165. with get_connection() as cursor:
  166. try:
  167. cursor.executemany(
  168. f"DELETE FROM {id_table} WHERE snapshot_id={SQLITE_BIND}",
  169. [snapshot_ids])
  170. except Exception as e:
  171. _handle_query_exception(e)