1"""
2SQLite backend for the sqlite3 module in the standard library.
3"""
4
5from __future__ import annotations
6
7import datetime
8import decimal
9import warnings
10from collections.abc import Callable, Iterable, Mapping, Sequence
11from itertools import chain, tee
12from sqlite3 import dbapi2 as Database
13from typing import Any
14
15from plain.exceptions import ImproperlyConfigured
16from plain.models.backends.base.base import BaseDatabaseWrapper
17from plain.models.db import IntegrityError
18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
19from plain.utils.regex_helper import _lazy_re_compile
20
21from ._functions import register as register_functions
22from .client import DatabaseClient
23from .creation import DatabaseCreation
24from .features import DatabaseFeatures
25from .introspection import DatabaseIntrospection
26from .operations import DatabaseOperations
27from .schema import DatabaseSchemaEditor
28
29
def decoder(conv_func: Callable[[str], Any]) -> Callable[[bytes], Any]:
    """
    Adapt a str-based parser for use as a sqlite3 converter.

    sqlite3 passes converters the raw column value as bytes; the returned
    callable decodes those bytes (``bytes.decode()`` default, UTF-8) and
    hands the resulting text to ``conv_func``.
    """

    def _decode_then_convert(raw: bytes) -> Any:
        text = raw.decode()
        return conv_func(text)

    return _decode_then_convert
35
36
def adapt_date(val: datetime.date) -> str:
    """Serialize ``val`` to its ISO 8601 text form for storage in SQLite."""
    iso_text = val.isoformat()
    return iso_text
39
40
def adapt_datetime(val: datetime.datetime) -> str:
    """
    Serialize ``val`` to ISO 8601 text, using a space (not "T") between the
    date and time parts.
    """
    return val.isoformat(sep=" ")
43
44
def _get_varchar_column(data: dict[str, Any]) -> str:
    """
    Return the column type for a CharField: ``varchar(N)`` when
    ``data["max_length"]`` is set, otherwise a plain ``varchar`` with no length.
    """
    max_length = data["max_length"]
    if max_length is not None:
        return "varchar({max_length})".format(max_length=max_length)
    return "varchar"
49
50
# Converters: sqlite3 calls these (PARSE_DECLTYPES / PARSE_COLNAMES) with the
# raw bytes of a column whose declared type matches the name. "bool" is True
# only for the raw value b"1"; the temporal types are decoded to str first.
for _decltype, _convert in (
    ("bool", b"1".__eq__),
    ("date", decoder(parse_date)),
    ("time", decoder(parse_time)),
    ("datetime", decoder(parse_datetime)),
    ("timestamp", decoder(parse_datetime)),
):
    Database.register_converter(_decltype, _convert)

# Adapters: how Python values of these exact types are bound as parameters.
for _py_type, _adapt in (
    (decimal.Decimal, str),
    (datetime.date, adapt_date),
    (datetime.datetime, adapt_datetime),
):
    Database.register_adapter(_py_type, _adapt)

# Keep the loop temporaries out of the module namespace.
del _decltype, _convert, _py_type, _adapt
60
61
class SQLiteDatabaseWrapper(BaseDatabaseWrapper):
    """
    Database wrapper for SQLite, built on the standard library's sqlite3 module.
    """

    # Type checker hints: narrow base class attribute types to backend-specific classes
    ops: DatabaseOperations
    features: DatabaseFeatures
    introspection: DatabaseIntrospection
    creation: DatabaseCreation

    vendor = "sqlite"
    display_name = "SQLite"
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        "PrimaryKeyField": "integer",
        "BinaryField": "BLOB",
        "BooleanField": "bool",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "datetime",
        "DecimalField": "decimal",
        "DurationField": "bigint",
        "FloatField": "real",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "text",
        "PositiveBigIntegerField": "bigint unsigned",
        "PositiveIntegerField": "integer unsigned",
        "PositiveSmallIntegerField": "smallint unsigned",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "char(32)",
    }
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "PrimaryKeyField": "AUTOINCREMENT",
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s ESCAPE '\\'",
        "contains": "LIKE %s ESCAPE '\\'",
        "icontains": "LIKE %s ESCAPE '\\'",
        "regex": "REGEXP %s",
        "iregex": "REGEXP '(?i)' || %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s ESCAPE '\\'",
        "endswith": "LIKE %s ESCAPE '\\'",
        "istartswith": "LIKE %s ESCAPE '\\'",
        "iendswith": "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (\, %, _) must be
    # escaped on the database side; pattern_esc does that with nested REPLACE().
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self) -> dict[str, Any]:
        """
        Build the keyword arguments passed to ``sqlite3.connect()``.

        Entries in OPTIONS override the ``database`` / ``detect_types``
        defaults, but ``check_same_thread`` is always forced to False and
        ``uri`` to True.

        Raises ImproperlyConfigured when NAME is empty.
        """
        settings_dict = self.settings_dict
        if not settings_dict["NAME"]:
            raise ImproperlyConfigured(
                "settings.DATABASE is improperly configured. "
                "Please supply the NAME value."
            )
        kwargs = {
            "database": settings_dict["NAME"],
            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict["OPTIONS"],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in sqlite3 and it cannot be changed once a connection is
        # opened.
        if kwargs.get("check_same_thread"):
            warnings.warn(
                "The `check_same_thread` option was provided and set to "
                "True. It will be overridden with False. Use the "
                "`DatabaseWrapper.allow_thread_sharing` property instead "
                "for controlling thread shareability.",
                RuntimeWarning,
            )
        kwargs.update({"check_same_thread": False, "uri": True})
        return kwargs

    def get_database_version(self) -> tuple[int, ...]:
        """Return the version of the linked SQLite library as a tuple of ints."""
        return self.Database.sqlite_version_info

    def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
        """
        Open a sqlite3 connection with `conn_params`, register this backend's
        SQL functions on it (see ._functions), and set the connection PRAGMAs.
        """
        conn = Database.connect(**conn_params)
        register_functions(conn)

        conn.execute("PRAGMA foreign_keys = ON")
        # The macOS bundled SQLite defaults legacy_alter_table ON, which
        # prevents atomic table renames (feature supports_atomic_references_rename)
        conn.execute("PRAGMA legacy_alter_table = OFF")
        return conn

    def create_cursor(self, name: str | None = None) -> Any:
        """Return a cursor that translates "format"/"pyformat" placeholders."""
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    def close(self) -> None:
        """Close the connection unless it backs an in-memory database."""
        self.validate_thread_sharing()
        # If database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
        # an in-memory db.
        if not self.is_in_memory_db():
            BaseDatabaseWrapper.close(self)
        return None

    def _savepoint_allowed(self) -> bool:
        # When 'isolation_level' is not None, sqlite3 commits before each
        # savepoint; it's a bug. When it is None, savepoints don't make sense
        # because autocommit is enabled. The only exception is inside 'atomic'
        # blocks. To work around that bug, on SQLite, 'atomic' starts a
        # transaction explicitly rather than simply disable autocommit.
        return self.in_atomic_block

    def _set_autocommit(self, autocommit: bool) -> None:
        if autocommit:
            level = None
        else:
            # sqlite3's internal default is ''. It's different from None.
            # See Modules/_sqlite/connection.c.
            level = ""
        # 'isolation_level' is a misleading API.
        # SQLite always runs at the SERIALIZABLE isolation level.
        with self.wrap_database_errors:
            self.connection.isolation_level = level
        return None

    def disable_constraint_checking(self) -> bool:
        """
        Try to turn foreign key enforcement off.

        Returns True only if enforcement is actually off afterwards.
        """
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = OFF")
            # Foreign key constraints cannot be turned off while in a multi-
            # statement transaction. Fetch the current state of the pragma
            # to determine if constraints are effectively disabled.
            row = cursor.execute("PRAGMA foreign_keys").fetchone()
            assert row is not None
            enabled = row[0]
        return not bool(enabled)

    def enable_constraint_checking(self) -> None:
        """Turn foreign key enforcement back on."""
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = ON")
        return None

    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while constraint
        checks were off.

        When `table_names` is None, every table is checked. Raises
        IntegrityError describing the first violation found.
        """
        quote_name = self.ops.quote_name
        with self.cursor() as cursor:
            if table_names is None:
                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
            else:
                violations = chain.from_iterable(
                    cursor.execute(
                        f"PRAGMA foreign_key_check({quote_name(table_name)})"
                    ).fetchall()
                    for table_name in table_names
                )
            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
            for (
                table_name,
                rowid,
                referenced_table_name,
                foreign_key_id,
            ) in violations:
                # The fourth column of foreign_key_check is the constraint's
                # id, i.e. the first column of foreign_key_list. Because
                # foreign_key_list emits one row per column of each constraint,
                # that id is not a row index once a composite foreign key
                # exists, so match on the id column instead of indexing.
                foreign_key_rows = cursor.execute(
                    f"PRAGMA foreign_key_list({quote_name(table_name)})"
                ).fetchall()
                foreign_key = next(
                    row for row in foreign_key_rows if row[0] == foreign_key_id
                )
                # foreign_key_list columns 3 and 4 are "from" and "to".
                column_name, referenced_column_name = foreign_key[3:5]
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                assert primary_key_column_name is not None, (
                    f"Table {table_name} must have a primary key"
                )
                row = cursor.execute(
                    f"SELECT {quote_name(primary_key_column_name)}, "
                    f"{quote_name(column_name)} FROM {quote_name(table_name)} "
                    "WHERE rowid = %s",
                    (rowid,),
                ).fetchone()
                assert row is not None
                primary_key_value, bad_value = row
                raise IntegrityError(
                    f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
                    f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "
                    f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                )
        return None

    def is_usable(self) -> bool:
        return True

    def _start_transaction_under_autocommit(self) -> None:
        """
        Start a transaction explicitly in autocommit mode.

        Staying in autocommit mode works around a bug of sqlite3 that breaks
        savepoints when autocommit is disabled.
        """
        # Close the cursor once BEGIN has run instead of leaving it open; the
        # transaction belongs to the connection, not to this cursor.
        with self.cursor() as cursor:
            cursor.execute("BEGIN")
        return None

    def is_in_memory_db(self) -> bool:
        """Return whether the configured NAME refers to an in-memory database."""
        name = self.settings_dict.get("NAME") or ""
        return self.creation.is_in_memory_db(name)
307
308
# Matches a "format"-style "%s" placeholder that is not immediately preceded by
# "%", so an escaped "%%s" (meaning a literal "%s") is left alone. Because the
# lookbehind only inspects one character, a placeholder that directly follows
# an escaped percent (e.g. the "%s" in "%%%s") is not matched either.
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
310
311
class SQLiteCursorWrapper(Database.Cursor):
    """
    Plain uses the "format" and "pyformat" styles, but Python's sqlite3 module
    supports neither of these styles.

    This wrapper performs the following conversions:

    - "format" style to "qmark" style
    - "pyformat" style to "named" style

    In both cases, if you want to use a literal "%s", you'll need to use "%%s".
    """

    def execute(  # type: ignore[override]
        self, query: str, params: Sequence[Any] | Mapping[str, Any] = ()
    ) -> Any:
        """
        Convert the placeholders in `query` and execute it with `params`.

        A Mapping selects "pyformat" conversion; any other value selects
        "format" conversion. Returns what sqlite3's `Cursor.execute()` returns.
        """
        if not params:
            # Still need convert_query for %% → % conversion
            query = self.convert_query(query)
            return super().execute(query)
        # Extract names if params is a mapping, i.e. "pyformat" style is used.
        param_names = list(params) if isinstance(params, Mapping) else None
        query = self.convert_query(query, param_names=param_names)
        return super().execute(query, params)

    def executemany(  # type: ignore[override]
        self,
        query: str,
        param_list: Iterable[Sequence[Any] | Mapping[str, Any]],
    ) -> Any:
        """
        Convert the placeholders in `query` and execute it once per entry of
        `param_list`. The placeholder style is chosen from the first entry.
        """
        # Extract names if params is a mapping, i.e. "pyformat" style is used.
        # Peek carefully as a generator can be passed instead of a list/tuple.
        peekable, param_list = tee(iter(param_list))
        if (params := next(peekable, None)) and isinstance(params, Mapping):
            param_names = list(params)
        else:
            param_names = None
        query = self.convert_query(query, param_names=param_names)
        return super().executemany(query, param_list)

    def convert_query(self, query: str, *, param_names: list[str] | None = None) -> str:
        """
        Rewrite `query` into a placeholder style sqlite3 understands.

        Without `param_names`, "format" placeholders (``%s``) become ``?`` and
        ``%%`` escapes become ``%``. With `param_names`, "pyformat"
        placeholders (``%(name)s``) become ``:name`` via %-formatting, which
        also collapses ``%%`` to ``%``.
        """
        if param_names is not None:
            # Convert from "pyformat" style to "named" style.
            return query % {name: f":{name}" for name in param_names}
        # Convert from "format" style to "qmark" style.
        if "%%" not in query:
            # Fast path: with no escaped percents every "%s" is a placeholder.
            return query.replace("%s", "?")
        # str.split pairs "%%" left to right, exactly as %-formatting does, so
        # each "%s" left inside a chunk is a real placeholder. Unlike a
        # one-character lookbehind, this also converts a placeholder that
        # directly follows an escaped percent, e.g. "10 %%%s" -> "10 %?".
        return "%".join(chunk.replace("%s", "?") for chunk in query.split("%%"))