1"""
2SQLite backend for the sqlite3 module in the standard library.
3"""
4
import datetime
import decimal
import re
import warnings
from collections.abc import Mapping
from itertools import chain, tee
from sqlite3 import dbapi2 as Database

from plain.exceptions import ImproperlyConfigured
from plain.models.backends.base.base import BaseDatabaseWrapper
from plain.models.db import IntegrityError
from plain.utils.dateparse import parse_date, parse_datetime, parse_time
from plain.utils.regex_helper import _lazy_re_compile

from ._functions import register as register_functions
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
25
26
def decoder(conv_func):
    """
    Adapt a text parser so it can be registered as a sqlite3 converter.

    sqlite3 hands converters the raw column value as ``bytes``. The returned
    callable decodes those bytes with ``bytes.decode()`` and passes the
    resulting ``str`` to *conv_func*, returning whatever *conv_func* returns.
    """

    def convert(raw):
        text = raw.decode()
        return conv_func(text)

    return convert
32
33
def adapt_date(val):
    """Serialize *val* for sqlite3 using its default ISO 8601 ``isoformat()``."""
    serialized = val.isoformat()
    return serialized
36
37
def adapt_datetime(val):
    """
    Serialize *val* for sqlite3 as ISO 8601 text.

    A space, rather than the default "T", separates the date and time parts.
    """
    separator = " "
    return val.isoformat(separator)
40
41
# Converters turn raw bytes read from a column back into Python objects. The
# converter is chosen by declared column type or column-name annotation, per
# the detect_types flags requested when connecting.
Database.register_converter("bool", b"1".__eq__)
for _decltype, _parser in (
    ("date", parse_date),
    ("time", parse_time),
    ("datetime", parse_datetime),
    ("timestamp", parse_datetime),
):
    Database.register_converter(_decltype, decoder(_parser))

# Adapters turn Python values into something sqlite3 can bind as a parameter.
for _pytype, _adapter in (
    (decimal.Decimal, str),
    (datetime.date, adapt_date),
    (datetime.datetime, adapt_datetime),
):
    Database.register_adapter(_pytype, _adapter)

# Don't leak the loop variables into the module namespace.
del _decltype, _parser, _pytype, _adapter
51
52
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    SQLite implementation of ``BaseDatabaseWrapper`` built on the standard
    library's ``sqlite3`` module.

    Declares SQLite column types and lookup operators, opens connections with
    foreign key enforcement switched on, and works around sqlite3's implicit
    transaction handling.
    """

    vendor = "sqlite"
    display_name = "SQLite"
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        "PrimaryKeyField": "integer",
        "BinaryField": "BLOB",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime",
        "DecimalField": "decimal",
        "DurationField": "bigint",
        "FloatField": "real",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "text",
        "PositiveBigIntegerField": "bigint unsigned",
        "PositiveIntegerField": "integer unsigned",
        "PositiveSmallIntegerField": "smallint unsigned",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "char(32)",
    }
    # CHECK constraint expressions for these fields' columns: non-negative
    # values for the positive integer fields, valid JSON (or NULL) for JSONField.
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    # Extra column-definition keywords per field type.
    data_types_suffix = {
        "PrimaryKeyField": "AUTOINCREMENT",
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s ESCAPE '\\'",
        "contains": "LIKE %s ESCAPE '\\'",
        "icontains": "LIKE %s ESCAPE '\\'",
        "regex": "REGEXP %s",
        "iregex": "REGEXP '(?i)' || %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s ESCAPE '\\'",
        "endswith": "LIKE %s ESCAPE '\\'",
        "istartswith": "LIKE %s ESCAPE '\\'",
        "iendswith": "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, the LIKE special characters (\, %, _) must be escaped on
    # the database side, which is what pattern_esc does.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self):
        """
        Build the keyword arguments that ``get_new_connection()`` unpacks into
        ``sqlite3.connect()``.

        Starts from NAME and the declared-type/column-name detection flags and
        lets ``OPTIONS`` override them. Then ``check_same_thread=False`` and
        ``uri=True`` are always forced.

        Raises:
            ImproperlyConfigured: if NAME is empty.
        """
        settings_dict = self.settings_dict
        if not settings_dict["NAME"]:
            raise ImproperlyConfigured(
                "settings.DATABASE is improperly configured. "
                "Please supply the NAME value."
            )
        kwargs = {
            "database": settings_dict["NAME"],
            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict["OPTIONS"],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in sqlite3 and it cannot be changed once a connection is
        # opened.
        if kwargs.get("check_same_thread"):
            warnings.warn(
                "The `check_same_thread` option was provided and set to "
                "True. It will be overridden with False. Use the "
                "`DatabaseWrapper.allow_thread_sharing` property instead "
                "for controlling thread shareability.",
                RuntimeWarning,
            )
        kwargs.update({"check_same_thread": False, "uri": True})
        return kwargs

    def get_database_version(self):
        """Return the linked SQLite library version (``sqlite_version_info``)."""
        return self.Database.sqlite_version_info

    def get_new_connection(self, conn_params):
        """
        Open a sqlite3 connection from *conn_params* and prepare it for use.

        Registers the backend's custom SQL functions, enables foreign key
        enforcement and turns ``legacy_alter_table`` off.
        """
        conn = Database.connect(**conn_params)
        register_functions(conn)

        conn.execute("PRAGMA foreign_keys = ON")
        # The macOS bundled SQLite defaults legacy_alter_table ON, which
        # prevents atomic table renames (feature supports_atomic_references_rename)
        conn.execute("PRAGMA legacy_alter_table = OFF")
        return conn

    def create_cursor(self, name=None):
        """
        Return a cursor that translates "%s" / "%(name)s" placeholders.

        *name* is accepted for interface compatibility and ignored.
        """
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    def close(self):
        """Close the connection, unless the database lives in memory."""
        self.validate_thread_sharing()
        # If database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
        # an in-memory db.
        if not self.is_in_memory_db():
            BaseDatabaseWrapper.close(self)

    def _savepoint_allowed(self):
        """Allow savepoints only inside an ``atomic`` block."""
        # When 'isolation_level' is not None, sqlite3 commits before each
        # savepoint; it's a bug. When it is None, savepoints don't make sense
        # because autocommit is enabled. The only exception is inside 'atomic'
        # blocks. To work around that bug, on SQLite, 'atomic' starts a
        # transaction explicitly rather than simply disable autocommit.
        return self.in_atomic_block

    def _set_autocommit(self, autocommit):
        """
        Switch sqlite3's implicit transaction handling on or off.

        Autocommit maps to ``isolation_level = None``. Otherwise the level is
        set to ``""``, sqlite3's internal default.
        """
        if autocommit:
            level = None
        else:
            # sqlite3's internal default is ''. It's different from None.
            # See Modules/_sqlite/connection.c.
            level = ""
        # 'isolation_level' is a misleading API.
        # SQLite always runs at the SERIALIZABLE isolation level.
        with self.wrap_database_errors:
            self.connection.isolation_level = level

    def disable_constraint_checking(self):
        """
        Try to turn off foreign key enforcement.

        Returns:
            True if enforcement is actually off afterwards, False otherwise.
        """
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = OFF")
            # Foreign key constraints cannot be turned off while in a multi-
            # statement transaction. Fetch the current state of the pragma
            # to determine if constraints are effectively disabled.
            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
        return not bool(enabled)

    def enable_constraint_checking(self):
        """Turn foreign key enforcement back on."""
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = ON")

    def check_constraints(self, table_names=None):
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while constraint
        checks were off.

        With ``table_names=None`` every table is checked.

        Raises:
            IntegrityError: describing the first violating row found.
        """
        with self.cursor() as cursor:
            if table_names is None:
                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
            else:
                violations = chain.from_iterable(
                    cursor.execute(
                        f"PRAGMA foreign_key_check({self.ops.quote_name(table_name)})"
                    ).fetchall()
                    for table_name in table_names
                )
            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
            for (
                table_name,
                rowid,
                referenced_table_name,
                foreign_key_index,
            ) in violations:
                # foreign_key_check reports the constraint by its id, while
                # foreign_key_list emits one row per constrained column (the
                # id is the first value of each row). Match on the id rather
                # than indexing by it, so an earlier multi-column foreign key
                # can't shift the lookup onto the wrong constraint. The first
                # matching row describes the constraint's first column.
                foreign_key = [
                    row
                    for row in cursor.execute(
                        f"PRAGMA foreign_key_list({self.ops.quote_name(table_name)})"
                    ).fetchall()
                    if row[0] == foreign_key_index
                ][0]
                column_name, referenced_column_name = foreign_key[3:5]
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                primary_key_value, bad_value = cursor.execute(
                    f"SELECT {self.ops.quote_name(primary_key_column_name)}, "
                    f"{self.ops.quote_name(column_name)} "
                    f"FROM {self.ops.quote_name(table_name)} WHERE rowid = %s",
                    (rowid,),
                ).fetchone()
                raise IntegrityError(
                    f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
                    f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "
                    f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                )

    def is_usable(self):
        """Always report the connection as usable; no health check is run."""
        return True

    def _start_transaction_under_autocommit(self):
        """
        Start a transaction explicitly in autocommit mode.

        Staying in autocommit mode works around a bug of sqlite3 that breaks
        savepoints when autocommit is disabled.
        """
        # Close the cursor deterministically. The transaction belongs to the
        # connection, so it stays open after the cursor is closed.
        with self.cursor() as cursor:
            cursor.execute("BEGIN")

    def is_in_memory_db(self):
        """Return whether NAME refers to an in-memory database (per the creation class)."""
        return self.creation.is_in_memory_db(self.settings_dict["NAME"])
280
281
# A "format"-style "%s" placeholder whose "%" is not itself preceded by "%",
# so the "%%s" escape for a literal "%s" is left alone. Built with
# _lazy_re_compile, presumably deferring compilation until first use.
FORMAT_QMARK_REGEX = _lazy_re_compile("(?<!%)%s")
283
284
class SQLiteCursorWrapper(Database.Cursor):
    """
    Plain uses the "format" and "pyformat" styles, but Python's sqlite3 module
    supports neither of these styles.

    This wrapper performs the following conversions:

    - "format" style to "qmark" style
    - "pyformat" style to "named" style

    In both cases, if you want to use a literal "%s", you'll need to use "%%s".
    A literal "%" is written "%%", including directly before a placeholder:
    "%%%s" becomes "%?".
    """

    # One "format"-style token: an escaped percent ("%%") or a placeholder
    # ("%s"). Consuming tokens left to right keeps "%%" pairs intact, so a
    # placeholder right after an escaped percent is still recognised.
    _format_token_regex = re.compile(r"%([%s])")

    def execute(self, query, params=None):
        """
        Execute *query*, translating its placeholders for sqlite3.

        Without *params* the query is passed through untouched. A mapping
        selects "pyformat" style ("%(name)s"); any other value selects
        "format" style ("%s").
        """
        if params is None:
            return super().execute(query)
        # Extract names if params is a mapping, i.e. "pyformat" style is used.
        param_names = list(params) if isinstance(params, Mapping) else None
        query = self.convert_query(query, param_names=param_names)
        return super().execute(query, params)

    def executemany(self, query, param_list):
        """
        Execute *query* once per parameter set in *param_list*.

        The placeholder style is chosen from the first parameter set. That set
        is peeked without consuming *param_list*, which may be a one-shot
        iterator.
        """
        # Extract names if params is a mapping, i.e. "pyformat" style is used.
        # Peek carefully as a generator can be passed instead of a list/tuple.
        peekable, param_list = tee(iter(param_list))
        if (params := next(peekable, None)) and isinstance(params, Mapping):
            param_names = list(params)
        else:
            param_names = None
        query = self.convert_query(query, param_names=param_names)
        return super().executemany(query, param_list)

    def convert_query(self, query, *, param_names=None):
        """
        Rewrite *query* into a placeholder style sqlite3 understands.

        With ``param_names=None``, "format" style is assumed: each "%s" becomes
        "?" and each "%%" becomes "%". Other "%" sequences are left as they are.
        Otherwise "pyformat" placeholders ("%(name)s") become named ones
        (":name") via %-interpolation, which also collapses "%%" to "%".
        """
        if param_names is None:
            # Convert from "format" style to "qmark" style.
            return self._format_token_regex.sub(self._replace_format_token, query)
        # Convert from "pyformat" style to "named" style.
        return query % {name: f":{name}" for name in param_names}

    @staticmethod
    def _replace_format_token(match):
        """Map a "%s" token to "?" and a "%%" token to a literal "%"."""
        return "?" if match[1] == "s" else "%"