Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2SQLite backend for the sqlite3 module in the standard library.
  3"""
  4
  5from __future__ import annotations
  6
  7import datetime
  8import decimal
  9import warnings
 10from collections.abc import Callable, Iterable, Mapping
 11from itertools import chain, tee
 12from sqlite3 import dbapi2 as Database
 13from typing import Any
 14
 15from plain.exceptions import ImproperlyConfigured
 16from plain.models.backends.base.base import BaseDatabaseWrapper
 17from plain.models.db import IntegrityError
 18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 19from plain.utils.regex_helper import _lazy_re_compile
 20
 21from ._functions import register as register_functions
 22from .client import DatabaseClient
 23from .creation import DatabaseCreation
 24from .features import DatabaseFeatures
 25from .introspection import DatabaseIntrospection
 26from .operations import DatabaseOperations
 27from .schema import DatabaseSchemaEditor
 28
 29
 30def decoder(conv_func: Callable[[str], Any]) -> Callable[[bytes], Any]:
 31    """
 32    Convert bytestrings from Python's sqlite3 interface to a regular string.
 33    """
 34    return lambda s: conv_func(s.decode())
 35
 36
 37def adapt_date(val: datetime.date) -> str:
 38    return val.isoformat()
 39
 40
 41def adapt_datetime(val: datetime.datetime) -> str:
 42    return val.isoformat(" ")
 43
 44
 45def _get_varchar_column(data: dict[str, Any]) -> str:
 46    if data["max_length"] is None:
 47        return "varchar"
 48    return "varchar({max_length})".format(**data)
 49
 50
 51Database.register_converter("bool", b"1".__eq__)
 52Database.register_converter("date", decoder(parse_date))
 53Database.register_converter("time", decoder(parse_time))
 54Database.register_converter("datetime", decoder(parse_datetime))
 55Database.register_converter("timestamp", decoder(parse_datetime))
 56
 57Database.register_adapter(decimal.Decimal, str)
 58Database.register_adapter(datetime.date, adapt_date)
 59Database.register_adapter(datetime.datetime, adapt_datetime)
 60
 61
 62class SQLiteDatabaseWrapper(BaseDatabaseWrapper):
 63    # Type checker hints: narrow base class attribute types to backend-specific classes
 64    ops: DatabaseOperations
 65    features: DatabaseFeatures
 66    introspection: DatabaseIntrospection
 67    creation: DatabaseCreation
 68
 69    vendor = "sqlite"
 70    display_name = "SQLite"
 71    # SQLite doesn't actually support most of these types, but it "does the right
 72    # thing" given more verbose field definitions, so leave them as is so that
 73    # schema inspection is more useful.
 74    data_types = {
 75        "PrimaryKeyField": "integer",
 76        "BinaryField": "BLOB",
 77        "BooleanField": "bool",
 78        "CharField": _get_varchar_column,
 79        "DateField": "date",
 80        "DateTimeField": "datetime",
 81        "DecimalField": "decimal",
 82        "DurationField": "bigint",
 83        "FloatField": "real",
 84        "IntegerField": "integer",
 85        "BigIntegerField": "bigint",
 86        "GenericIPAddressField": "char(39)",
 87        "JSONField": "text",
 88        "PositiveBigIntegerField": "bigint unsigned",
 89        "PositiveIntegerField": "integer unsigned",
 90        "PositiveSmallIntegerField": "smallint unsigned",
 91        "SmallIntegerField": "smallint",
 92        "TextField": "text",
 93        "TimeField": "time",
 94        "UUIDField": "char(32)",
 95    }
 96    data_type_check_constraints = {
 97        "PositiveBigIntegerField": '"%(column)s" >= 0',
 98        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
 99        "PositiveIntegerField": '"%(column)s" >= 0',
100        "PositiveSmallIntegerField": '"%(column)s" >= 0',
101    }
102    data_types_suffix = {
103        "PrimaryKeyField": "AUTOINCREMENT",
104    }
105    # SQLite requires LIKE statements to include an ESCAPE clause if the value
106    # being escaped has a percent or underscore in it.
107    # See https://www.sqlite.org/lang_expr.html for an explanation.
108    operators = {
109        "exact": "= %s",
110        "iexact": "LIKE %s ESCAPE '\\'",
111        "contains": "LIKE %s ESCAPE '\\'",
112        "icontains": "LIKE %s ESCAPE '\\'",
113        "regex": "REGEXP %s",
114        "iregex": "REGEXP '(?i)' || %s",
115        "gt": "> %s",
116        "gte": ">= %s",
117        "lt": "< %s",
118        "lte": "<= %s",
119        "startswith": "LIKE %s ESCAPE '\\'",
120        "endswith": "LIKE %s ESCAPE '\\'",
121        "istartswith": "LIKE %s ESCAPE '\\'",
122        "iendswith": "LIKE %s ESCAPE '\\'",
123    }
124
125    # The patterns below are used to generate SQL pattern lookup clauses when
126    # the right-hand side of the lookup isn't a raw string (it might be an expression
127    # or the result of a bilateral transformation).
128    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
129    # escaped on database side.
130    #
131    # Note: we use str.format() here for readability as '%' is used as a wildcard for
132    # the LIKE operator.
133    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
134    pattern_ops = {
135        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
136        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
137        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
138        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
139        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
140        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
141    }
142
143    Database = Database
144    SchemaEditorClass = DatabaseSchemaEditor
145    # Classes instantiated in __init__().
146    client_class = DatabaseClient
147    creation_class = DatabaseCreation
148    features_class = DatabaseFeatures
149    introspection_class = DatabaseIntrospection
150    ops_class = DatabaseOperations
151
152    def get_connection_params(self) -> dict[str, Any]:
153        settings_dict = self.settings_dict
154        if not settings_dict["NAME"]:
155            raise ImproperlyConfigured(
156                "settings.DATABASE is improperly configured. "
157                "Please supply the NAME value."
158            )
159        kwargs = {
160            "database": settings_dict["NAME"],
161            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
162            **settings_dict["OPTIONS"],
163        }
164        # Always allow the underlying SQLite connection to be shareable
165        # between multiple threads. The safe-guarding will be handled at a
166        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
167        # property. This is necessary as the shareability is disabled by
168        # default in sqlite3 and it cannot be changed once a connection is
169        # opened.
170        if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
171            warnings.warn(
172                "The `check_same_thread` option was provided and set to "
173                "True. It will be overridden with False. Use the "
174                "`DatabaseWrapper.allow_thread_sharing` property instead "
175                "for controlling thread shareability.",
176                RuntimeWarning,
177            )
178        kwargs.update({"check_same_thread": False, "uri": True})
179        return kwargs
180
181    def get_database_version(self) -> tuple[int, ...]:
182        return self.Database.sqlite_version_info
183
184    def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
185        conn = Database.connect(**conn_params)  # type: ignore[call-overload]
186        register_functions(conn)
187
188        conn.execute("PRAGMA foreign_keys = ON")
189        # The macOS bundled SQLite defaults legacy_alter_table ON, which
190        # prevents atomic table renames (feature supports_atomic_references_rename)
191        conn.execute("PRAGMA legacy_alter_table = OFF")
192        return conn
193
194    def create_cursor(self, name: str | None = None) -> Any:
195        return self.connection.cursor(factory=SQLiteCursorWrapper)
196
197    def close(self) -> None:
198        self.validate_thread_sharing()
199        # If database is in memory, closing the connection destroys the
200        # database. To prevent accidental data loss, ignore close requests on
201        # an in-memory db.
202        if not self.is_in_memory_db():
203            BaseDatabaseWrapper.close(self)
204        return None
205
206    def _savepoint_allowed(self) -> bool:
207        # When 'isolation_level' is not None, sqlite3 commits before each
208        # savepoint; it's a bug. When it is None, savepoints don't make sense
209        # because autocommit is enabled. The only exception is inside 'atomic'
210        # blocks. To work around that bug, on SQLite, 'atomic' starts a
211        # transaction explicitly rather than simply disable autocommit.
212        return self.in_atomic_block
213
214    def _set_autocommit(self, autocommit: bool) -> None:
215        if autocommit:
216            level = None
217        else:
218            # sqlite3's internal default is ''. It's different from None.
219            # See Modules/_sqlite/connection.c.
220            level = ""
221        # 'isolation_level' is a misleading API.
222        # SQLite always runs at the SERIALIZABLE isolation level.
223        with self.wrap_database_errors:
224            self.connection.isolation_level = level
225        return None
226
227    def disable_constraint_checking(self) -> bool:
228        with self.cursor() as cursor:
229            cursor.execute("PRAGMA foreign_keys = OFF")
230            # Foreign key constraints cannot be turned off while in a multi-
231            # statement transaction. Fetch the current state of the pragma
232            # to determine if constraints are effectively disabled.
233            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
234        return not bool(enabled)
235
236    def enable_constraint_checking(self) -> None:
237        with self.cursor() as cursor:
238            cursor.execute("PRAGMA foreign_keys = ON")
239        return None
240
241    def check_constraints(self, table_names: list[str] | None = None) -> None:
242        """
243        Check each table name in `table_names` for rows with invalid foreign
244        key references. This method is intended to be used in conjunction with
245        `disable_constraint_checking()` and `enable_constraint_checking()`, to
246        determine if rows with invalid references were entered while constraint
247        checks were off.
248        """
249        with self.cursor() as cursor:
250            if table_names is None:
251                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
252            else:
253                violations = chain.from_iterable(
254                    cursor.execute(
255                        f"PRAGMA foreign_key_check({self.ops.quote_name(table_name)})"
256                    ).fetchall()
257                    for table_name in table_names
258                )
259            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
260            for (
261                table_name,
262                rowid,
263                referenced_table_name,
264                foreign_key_index,
265            ) in violations:
266                foreign_key = cursor.execute(
267                    f"PRAGMA foreign_key_list({self.ops.quote_name(table_name)})"
268                ).fetchall()[foreign_key_index]
269                column_name, referenced_column_name = foreign_key[3:5]
270                primary_key_column_name = self.introspection.get_primary_key_column(
271                    cursor, table_name
272                )
273                assert primary_key_column_name is not None, (
274                    f"Table {table_name} must have a primary key"
275                )
276                primary_key_value, bad_value = cursor.execute(
277                    f"SELECT {self.ops.quote_name(primary_key_column_name)}, {self.ops.quote_name(column_name)} FROM {self.ops.quote_name(table_name)} WHERE rowid = %s",
278                    (rowid,),
279                ).fetchone()
280                raise IntegrityError(
281                    f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
282                    f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "
283                    f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
284                )
285        return None
286
287    def is_usable(self) -> bool:
288        return True
289
290    def _start_transaction_under_autocommit(self) -> None:
291        """
292        Start a transaction explicitly in autocommit mode.
293
294        Staying in autocommit mode works around a bug of sqlite3 that breaks
295        savepoints when autocommit is disabled.
296        """
297        self.cursor().execute("BEGIN")
298        return None
299
300    def is_in_memory_db(self) -> bool:
301        return self.creation.is_in_memory_db(self.settings_dict["NAME"])
302
303
# Matches a "format"-style ``%s`` placeholder that is not immediately preceded
# by another ``%``. The lookbehind leaves an escaped ``%%s`` alone, so it is
# not turned into ``?`` (SQLiteCursorWrapper.convert_query later collapses
# ``%%`` to ``%``, yielding a literal ``%s``).
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
305
306
class SQLiteCursorWrapper(Database.Cursor):
    """
    sqlite3 cursor that accepts Plain's SQL placeholder styles.

    Plain writes SQL with "format" (``%s``) and "pyformat" (``%(name)s``)
    placeholders, neither of which Python's sqlite3 module understands.
    Before execution the query is rewritten:

    - sequence parameters: "format" style becomes "qmark" style (``?``)
    - mapping parameters: "pyformat" style becomes "named" style (``:name``)

    Either way, a literal "%s" must be written as "%%s".
    """

    def execute(  # type: ignore[override]
        self, query: str, params: Iterable[Any] | Mapping[str, Any] | None = None
    ) -> Any:
        """Rewrite *query* for the style implied by *params*, then run it."""
        if params is None:
            # Nothing to bind: the query goes to sqlite3 untouched.
            return super().execute(query)
        # A mapping means "pyformat"; its keys are the placeholder names.
        names = list(params) if isinstance(params, Mapping) else None
        converted = self.convert_query(query, param_names=names)
        return super().execute(converted, params)

    def executemany(  # type: ignore[override]
        self,
        query: str,
        param_list: Iterable[Iterable[Any] | Mapping[str, Any]],
    ) -> Any:
        """
        Rewrite *query* once, then run it for every parameter set.

        The placeholder style is decided by looking at the first parameter
        set only.
        """
        # param_list may be a one-shot generator; tee() lets us inspect the
        # first item without consuming it from the iterator sqlite3 receives.
        lookahead, param_list = tee(iter(param_list))
        first = next(lookahead, None)
        names = list(first) if first and isinstance(first, Mapping) else None
        converted = self.convert_query(query, param_names=names)
        return super().executemany(converted, param_list)

    def convert_query(self, query: str, *, param_names: list[str] | None = None) -> str:
        """
        Translate *query* into a placeholder style sqlite3 accepts.

        With *param_names*, each ``%(name)s`` becomes ``:name`` via
        %-formatting; without them, each unescaped ``%s`` becomes ``?``.
        In both paths ``%%`` ends up as a single ``%``.
        """
        if param_names is not None:
            named = {name: f":{name}" for name in param_names}
            return query % named
        qmarked = FORMAT_QMARK_REGEX.sub("?", query)
        return qmarked.replace("%%", "%")