1"""
2SQLite backend for the sqlite3 module in the standard library.
3"""
4
5from __future__ import annotations
6
7import datetime
8import decimal
9import warnings
10from collections.abc import Callable, Iterable, Mapping
11from itertools import chain, tee
12from sqlite3 import dbapi2 as Database
13from typing import Any
14
15from plain.exceptions import ImproperlyConfigured
16from plain.models.backends.base.base import BaseDatabaseWrapper
17from plain.models.db import IntegrityError
18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
19from plain.utils.regex_helper import _lazy_re_compile
20
21from ._functions import register as register_functions
22from .client import DatabaseClient
23from .creation import DatabaseCreation
24from .features import DatabaseFeatures
25from .introspection import DatabaseIntrospection
26from .operations import DatabaseOperations
27from .schema import DatabaseSchemaEditor
28
29
def decoder(conv_func: Callable[[str], Any]) -> Callable[[bytes], Any]:
    """
    Adapt a str-based parser for use as a sqlite3 converter.

    sqlite3 passes converters the raw column value as bytes; the returned
    callable decodes those bytes with ``bytes.decode()`` (default codec) and
    hands the resulting str to ``conv_func``.
    """

    def convert(raw: bytes) -> Any:
        text = raw.decode()
        return conv_func(text)

    return convert
35
36
def adapt_date(val: datetime.date) -> str:
    """Serialize ``val`` to text for sqlite3 using its ``isoformat()``."""
    iso_text = val.isoformat()
    return iso_text
39
40
def adapt_datetime(val: datetime.datetime) -> str:
    """Serialize ``val`` to ISO 8601 text, separating date and time with a space."""
    separator = " "
    return val.isoformat(sep=separator)
43
44
def _get_varchar_column(data: dict[str, Any]) -> str:
    """
    Build the SQLite column type used for CharField.

    A ``data["max_length"]`` of None yields a plain ``varchar``; any other
    value is rendered into ``varchar(<max_length>)``.
    """
    if data["max_length"] is not None:
        return "varchar({max_length})".format(**data)
    return "varchar"
49
50
# Converters: map declared column types (enabled via detect_types in
# get_connection_params) back to Python objects. sqlite3 hands converters raw
# bytes; ``decoder`` turns them into str before the parse_* helper runs. A
# "bool" column converts to True only when its raw bytes equal b"1".
for _decl_type, _converter in (
    ("bool", b"1".__eq__),
    ("date", decoder(parse_date)),
    ("time", decoder(parse_time)),
    ("datetime", decoder(parse_datetime)),
    ("timestamp", decoder(parse_datetime)),
):
    Database.register_converter(_decl_type, _converter)

# Adapters: serialize these Python types when they are bound as parameters.
for _py_type, _adapter in (
    (decimal.Decimal, str),
    (datetime.date, adapt_date),
    (datetime.datetime, adapt_datetime),
):
    Database.register_adapter(_py_type, _adapter)

# Keep the loop variables from lingering as module attributes.
del _decl_type, _converter, _py_type, _adapter
60
61
class SQLiteDatabaseWrapper(BaseDatabaseWrapper):
    """
    Database wrapper for SQLite, built on the standard library's sqlite3 module.

    The class attributes describe SQLite column types, lookup operators and the
    helper classes this backend instantiates. The methods manage the sqlite3
    connection, autocommit/transactions, and foreign-key constraint checking.
    """

    vendor = "sqlite"
    display_name = "SQLite"
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        "PrimaryKeyField": "integer",
        "BinaryField": "BLOB",
        "BooleanField": "bool",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "datetime",
        "DecimalField": "decimal",
        "DurationField": "bigint",
        "FloatField": "real",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "text",
        "PositiveBigIntegerField": "bigint unsigned",
        "PositiveIntegerField": "integer unsigned",
        "PositiveSmallIntegerField": "smallint unsigned",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "char(32)",
    }
    # CHECK-constraint SQL templates keyed by field type; %(column)s is
    # presumably filled with the column name by the schema editor.
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "PrimaryKeyField": "AUTOINCREMENT",
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s ESCAPE '\\'",
        "contains": "LIKE %s ESCAPE '\\'",
        "icontains": "LIKE %s ESCAPE '\\'",
        "regex": "REGEXP %s",
        "iregex": "REGEXP '(?i)' || %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s ESCAPE '\\'",
        "endswith": "LIKE %s ESCAPE '\\'",
        "istartswith": "LIKE %s ESCAPE '\\'",
        "iendswith": "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (\, %, _) must be
    # escaped on the database side; pattern_esc does exactly that.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self) -> dict[str, Any]:
        """
        Build the keyword arguments passed to ``sqlite3.connect()``.

        OPTIONS from the settings are merged in last, so they may override
        ``database`` and ``detect_types``. ``check_same_thread`` is always
        forced to False and ``uri`` to True.

        Raises:
            ImproperlyConfigured: if the NAME setting is empty.
        """
        settings_dict = self.settings_dict
        if not settings_dict["NAME"]:
            raise ImproperlyConfigured(
                "settings.DATABASE is improperly configured. "
                "Please supply the NAME value."
            )
        kwargs = {
            "database": settings_dict["NAME"],
            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict["OPTIONS"],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in sqlite3 and it cannot be changed once a connection is
        # opened.
        if kwargs.get("check_same_thread"):
            warnings.warn(
                "The `check_same_thread` option was provided and set to "
                "True. It will be overridden with False. Use the "
                "`DatabaseWrapper.allow_thread_sharing` property instead "
                "for controlling thread shareability.",
                RuntimeWarning,
            )
        kwargs.update({"check_same_thread": False, "uri": True})
        return kwargs

    def get_database_version(self) -> tuple[int, ...]:
        """Return the linked SQLite library version as a tuple of ints."""
        return self.Database.sqlite_version_info

    def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
        """
        Open a sqlite3 connection, register the custom SQL functions, and
        turn foreign-key enforcement on.
        """
        conn = Database.connect(**conn_params)  # type: ignore[call-overload]
        register_functions(conn)

        conn.execute("PRAGMA foreign_keys = ON")
        # The macOS bundled SQLite defaults legacy_alter_table ON, which
        # prevents atomic table renames (feature supports_atomic_references_rename)
        conn.execute("PRAGMA legacy_alter_table = OFF")
        return conn

    def create_cursor(self, name: str | None = None) -> Any:
        """
        Return a cursor that translates Plain's placeholder styles for sqlite3.

        ``name`` is accepted for interface compatibility and is not used.
        """
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    def close(self) -> None:
        """Close the connection unless the database lives in memory."""
        self.validate_thread_sharing()
        # If database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
        # an in-memory db.
        if not self.is_in_memory_db():
            BaseDatabaseWrapper.close(self)
        return None

    def _savepoint_allowed(self) -> bool:
        """Allow savepoints only inside an atomic block."""
        # When 'isolation_level' is not None, sqlite3 commits before each
        # savepoint; it's a bug. When it is None, savepoints don't make sense
        # because autocommit is enabled. The only exception is inside 'atomic'
        # blocks. To work around that bug, on SQLite, 'atomic' starts a
        # transaction explicitly rather than simply disable autocommit.
        return self.in_atomic_block

    def _set_autocommit(self, autocommit: bool) -> None:
        """Toggle sqlite3 autocommit via the connection's ``isolation_level``."""
        if autocommit:
            level = None
        else:
            # sqlite3's internal default is ''. It's different from None.
            # See Modules/_sqlite/connection.c.
            level = ""
        # 'isolation_level' is a misleading API.
        # SQLite always runs at the SERIALIZABLE isolation level.
        with self.wrap_database_errors:
            self.connection.isolation_level = level
        return None

    def disable_constraint_checking(self) -> bool:
        """
        Try to turn off foreign-key enforcement.

        Returns True only if the pragma reads back as disabled afterwards.
        """
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = OFF")
            # Foreign key constraints cannot be turned off while in a multi-
            # statement transaction. Fetch the current state of the pragma
            # to determine if constraints are effectively disabled.
            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
        return not bool(enabled)

    def enable_constraint_checking(self) -> None:
        """Turn foreign-key enforcement back on."""
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = ON")
        return None

    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while constraint
        checks were off.

        When `table_names` is None, every table is checked.

        Raises:
            IntegrityError: describing the first violating row found.
        """
        with self.cursor() as cursor:
            quote_name = self.ops.quote_name
            if table_names is None:
                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
            else:
                violations = chain.from_iterable(
                    cursor.execute(
                        f"PRAGMA foreign_key_check({quote_name(table_name)})"
                    ).fetchall()
                    for table_name in table_names
                )
            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
            for (
                table_name,
                rowid,
                referenced_table_name,
                foreign_key_index,
            ) in violations:
                foreign_key = cursor.execute(
                    f"PRAGMA foreign_key_list({quote_name(table_name)})"
                ).fetchall()[foreign_key_index]
                # Columns 3 and 4 of foreign_key_list are "from" and "to".
                column_name, referenced_column_name = foreign_key[3:5]
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                primary_key_value, bad_value = cursor.execute(
                    f"SELECT {quote_name(primary_key_column_name)}, "
                    f"{quote_name(column_name)} "
                    f"FROM {quote_name(table_name)} WHERE rowid = %s",
                    (rowid,),
                ).fetchone()
                raise IntegrityError(
                    f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
                    f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "
                    f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                )
        return None

    def is_usable(self) -> bool:
        """Always report the connection as usable."""
        return True

    def _start_transaction_under_autocommit(self) -> None:
        """
        Start a transaction explicitly in autocommit mode.

        Staying in autocommit mode works around a bug of sqlite3 that breaks
        savepoints when autocommit is disabled.
        """
        # Close the cursor deterministically; the transaction opened by BEGIN
        # belongs to the connection and outlives the cursor.
        with self.cursor() as cursor:
            cursor.execute("BEGIN")
        return None

    def is_in_memory_db(self) -> bool:
        """Return whether the configured NAME refers to an in-memory database."""
        return self.creation.is_in_memory_db(self.settings_dict["NAME"])
293
294
# Matches a "format"-style %s placeholder that is not preceded by another "%",
# so an escaped "%%s" is not matched. SQLiteCursorWrapper.convert_query replaces
# each match with sqlite3's "?" placeholder.
# NOTE(review): because of the negative lookbehind, a literal percent directly
# before a placeholder ("%%%s") is not matched either — confirm callers never
# produce that sequence.
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
296
297
298class SQLiteCursorWrapper(Database.Cursor):
299 """
300 Plain uses the "format" and "pyformat" styles, but Python's sqlite3 module
301 supports neither of these styles.
302
303 This wrapper performs the following conversions:
304
305 - "format" style to "qmark" style
306 - "pyformat" style to "named" style
307
308 In both cases, if you want to use a literal "%s", you'll need to use "%%s".
309 """
310
311 def execute( # type: ignore[override]
312 self, query: str, params: Iterable[Any] | Mapping[str, Any] | None = None
313 ) -> Any:
314 if params is None:
315 return super().execute(query)
316 # Extract names if params is a mapping, i.e. "pyformat" style is used.
317 param_names = list(params) if isinstance(params, Mapping) else None
318 query = self.convert_query(query, param_names=param_names)
319 return super().execute(query, params)
320
321 def executemany( # type: ignore[override]
322 self,
323 query: str,
324 param_list: Iterable[Iterable[Any] | Mapping[str, Any]],
325 ) -> Any:
326 # Extract names if params is a mapping, i.e. "pyformat" style is used.
327 # Peek carefully as a generator can be passed instead of a list/tuple.
328 peekable, param_list = tee(iter(param_list))
329 if (params := next(peekable, None)) and isinstance(params, Mapping):
330 param_names = list(params)
331 else:
332 param_names = None
333 query = self.convert_query(query, param_names=param_names)
334 return super().executemany(query, param_list)
335
336 def convert_query(self, query: str, *, param_names: list[str] | None = None) -> str:
337 if param_names is None:
338 # Convert from "format" style to "qmark" style.
339 return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
340 else:
341 # Convert from "pyformat" style to "named" style.
342 return query % {name: f":{name}" for name in param_names}