Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2SQLite backend for the sqlite3 module in the standard library.
  3"""
  4
  5from __future__ import annotations
  6
  7import datetime
  8import decimal
  9import warnings
 10from collections.abc import Callable, Iterable, Mapping
 11from itertools import chain, tee
 12from sqlite3 import dbapi2 as Database
 13from typing import Any
 14
 15from plain.exceptions import ImproperlyConfigured
 16from plain.models.backends.base.base import BaseDatabaseWrapper
 17from plain.models.db import IntegrityError
 18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 19from plain.utils.regex_helper import _lazy_re_compile
 20
 21from ._functions import register as register_functions
 22from .client import DatabaseClient
 23from .creation import DatabaseCreation
 24from .features import DatabaseFeatures
 25from .introspection import DatabaseIntrospection
 26from .operations import DatabaseOperations
 27from .schema import DatabaseSchemaEditor
 28
 29
 30def decoder(conv_func: Callable[[str], Any]) -> Callable[[bytes], Any]:
 31    """
 32    Convert bytestrings from Python's sqlite3 interface to a regular string.
 33    """
 34    return lambda s: conv_func(s.decode())
 35
 36
 37def adapt_date(val: datetime.date) -> str:
 38    return val.isoformat()
 39
 40
 41def adapt_datetime(val: datetime.datetime) -> str:
 42    return val.isoformat(" ")
 43
 44
 45def _get_varchar_column(data: dict[str, Any]) -> str:
 46    if data["max_length"] is None:
 47        return "varchar"
 48    return "varchar({max_length})".format(**data)
 49
 50
 51Database.register_converter("bool", b"1".__eq__)
 52Database.register_converter("date", decoder(parse_date))
 53Database.register_converter("time", decoder(parse_time))
 54Database.register_converter("datetime", decoder(parse_datetime))
 55Database.register_converter("timestamp", decoder(parse_datetime))
 56
 57Database.register_adapter(decimal.Decimal, str)
 58Database.register_adapter(datetime.date, adapt_date)
 59Database.register_adapter(datetime.datetime, adapt_datetime)
 60
 61
 62class SQLiteDatabaseWrapper(BaseDatabaseWrapper):
 63    vendor = "sqlite"
 64    display_name = "SQLite"
 65    # SQLite doesn't actually support most of these types, but it "does the right
 66    # thing" given more verbose field definitions, so leave them as is so that
 67    # schema inspection is more useful.
 68    data_types = {
 69        "PrimaryKeyField": "integer",
 70        "BinaryField": "BLOB",
 71        "BooleanField": "bool",
 72        "CharField": _get_varchar_column,
 73        "DateField": "date",
 74        "DateTimeField": "datetime",
 75        "DecimalField": "decimal",
 76        "DurationField": "bigint",
 77        "FloatField": "real",
 78        "IntegerField": "integer",
 79        "BigIntegerField": "bigint",
 80        "GenericIPAddressField": "char(39)",
 81        "JSONField": "text",
 82        "PositiveBigIntegerField": "bigint unsigned",
 83        "PositiveIntegerField": "integer unsigned",
 84        "PositiveSmallIntegerField": "smallint unsigned",
 85        "SmallIntegerField": "smallint",
 86        "TextField": "text",
 87        "TimeField": "time",
 88        "UUIDField": "char(32)",
 89    }
 90    data_type_check_constraints = {
 91        "PositiveBigIntegerField": '"%(column)s" >= 0',
 92        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
 93        "PositiveIntegerField": '"%(column)s" >= 0',
 94        "PositiveSmallIntegerField": '"%(column)s" >= 0',
 95    }
 96    data_types_suffix = {
 97        "PrimaryKeyField": "AUTOINCREMENT",
 98    }
 99    # SQLite requires LIKE statements to include an ESCAPE clause if the value
100    # being escaped has a percent or underscore in it.
101    # See https://www.sqlite.org/lang_expr.html for an explanation.
102    operators = {
103        "exact": "= %s",
104        "iexact": "LIKE %s ESCAPE '\\'",
105        "contains": "LIKE %s ESCAPE '\\'",
106        "icontains": "LIKE %s ESCAPE '\\'",
107        "regex": "REGEXP %s",
108        "iregex": "REGEXP '(?i)' || %s",
109        "gt": "> %s",
110        "gte": ">= %s",
111        "lt": "< %s",
112        "lte": "<= %s",
113        "startswith": "LIKE %s ESCAPE '\\'",
114        "endswith": "LIKE %s ESCAPE '\\'",
115        "istartswith": "LIKE %s ESCAPE '\\'",
116        "iendswith": "LIKE %s ESCAPE '\\'",
117    }
118
119    # The patterns below are used to generate SQL pattern lookup clauses when
120    # the right-hand side of the lookup isn't a raw string (it might be an expression
121    # or the result of a bilateral transformation).
122    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
123    # escaped on database side.
124    #
125    # Note: we use str.format() here for readability as '%' is used as a wildcard for
126    # the LIKE operator.
127    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
128    pattern_ops = {
129        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
130        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
131        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
132        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
133        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
134        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
135    }
136
137    Database = Database
138    SchemaEditorClass = DatabaseSchemaEditor
139    # Classes instantiated in __init__().
140    client_class = DatabaseClient
141    creation_class = DatabaseCreation
142    features_class = DatabaseFeatures
143    introspection_class = DatabaseIntrospection
144    ops_class = DatabaseOperations
145
146    def get_connection_params(self) -> dict[str, Any]:
147        settings_dict = self.settings_dict
148        if not settings_dict["NAME"]:
149            raise ImproperlyConfigured(
150                "settings.DATABASE is improperly configured. "
151                "Please supply the NAME value."
152            )
153        kwargs = {
154            "database": settings_dict["NAME"],
155            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
156            **settings_dict["OPTIONS"],
157        }
158        # Always allow the underlying SQLite connection to be shareable
159        # between multiple threads. The safe-guarding will be handled at a
160        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
161        # property. This is necessary as the shareability is disabled by
162        # default in sqlite3 and it cannot be changed once a connection is
163        # opened.
164        if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
165            warnings.warn(
166                "The `check_same_thread` option was provided and set to "
167                "True. It will be overridden with False. Use the "
168                "`DatabaseWrapper.allow_thread_sharing` property instead "
169                "for controlling thread shareability.",
170                RuntimeWarning,
171            )
172        kwargs.update({"check_same_thread": False, "uri": True})
173        return kwargs
174
175    def get_database_version(self) -> tuple[int, ...]:
176        return self.Database.sqlite_version_info
177
178    def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
179        conn = Database.connect(**conn_params)  # type: ignore[call-overload]
180        register_functions(conn)
181
182        conn.execute("PRAGMA foreign_keys = ON")
183        # The macOS bundled SQLite defaults legacy_alter_table ON, which
184        # prevents atomic table renames (feature supports_atomic_references_rename)
185        conn.execute("PRAGMA legacy_alter_table = OFF")
186        return conn
187
188    def create_cursor(self, name: str | None = None) -> Any:
189        return self.connection.cursor(factory=SQLiteCursorWrapper)
190
191    def close(self) -> None:
192        self.validate_thread_sharing()
193        # If database is in memory, closing the connection destroys the
194        # database. To prevent accidental data loss, ignore close requests on
195        # an in-memory db.
196        if not self.is_in_memory_db():
197            BaseDatabaseWrapper.close(self)
198        return None
199
200    def _savepoint_allowed(self) -> bool:
201        # When 'isolation_level' is not None, sqlite3 commits before each
202        # savepoint; it's a bug. When it is None, savepoints don't make sense
203        # because autocommit is enabled. The only exception is inside 'atomic'
204        # blocks. To work around that bug, on SQLite, 'atomic' starts a
205        # transaction explicitly rather than simply disable autocommit.
206        return self.in_atomic_block
207
208    def _set_autocommit(self, autocommit: bool) -> None:
209        if autocommit:
210            level = None
211        else:
212            # sqlite3's internal default is ''. It's different from None.
213            # See Modules/_sqlite/connection.c.
214            level = ""
215        # 'isolation_level' is a misleading API.
216        # SQLite always runs at the SERIALIZABLE isolation level.
217        with self.wrap_database_errors:
218            self.connection.isolation_level = level
219        return None
220
221    def disable_constraint_checking(self) -> bool:
222        with self.cursor() as cursor:
223            cursor.execute("PRAGMA foreign_keys = OFF")
224            # Foreign key constraints cannot be turned off while in a multi-
225            # statement transaction. Fetch the current state of the pragma
226            # to determine if constraints are effectively disabled.
227            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
228        return not bool(enabled)
229
230    def enable_constraint_checking(self) -> None:
231        with self.cursor() as cursor:
232            cursor.execute("PRAGMA foreign_keys = ON")
233        return None
234
235    def check_constraints(self, table_names: list[str] | None = None) -> None:
236        """
237        Check each table name in `table_names` for rows with invalid foreign
238        key references. This method is intended to be used in conjunction with
239        `disable_constraint_checking()` and `enable_constraint_checking()`, to
240        determine if rows with invalid references were entered while constraint
241        checks were off.
242        """
243        with self.cursor() as cursor:
244            if table_names is None:
245                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
246            else:
247                violations = chain.from_iterable(
248                    cursor.execute(
249                        f"PRAGMA foreign_key_check({self.ops.quote_name(table_name)})"
250                    ).fetchall()
251                    for table_name in table_names
252                )
253            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
254            for (
255                table_name,
256                rowid,
257                referenced_table_name,
258                foreign_key_index,
259            ) in violations:
260                foreign_key = cursor.execute(
261                    f"PRAGMA foreign_key_list({self.ops.quote_name(table_name)})"
262                ).fetchall()[foreign_key_index]
263                column_name, referenced_column_name = foreign_key[3:5]
264                primary_key_column_name = self.introspection.get_primary_key_column(
265                    cursor, table_name
266                )
267                primary_key_value, bad_value = cursor.execute(
268                    f"SELECT {self.ops.quote_name(primary_key_column_name)}, {self.ops.quote_name(column_name)} FROM {self.ops.quote_name(table_name)} WHERE rowid = %s",
269                    (rowid,),
270                ).fetchone()
271                raise IntegrityError(
272                    f"The row in table '{table_name}' with primary key '{primary_key_value}' has an "
273                    f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_value}' that "
274                    f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
275                )
276        return None
277
278    def is_usable(self) -> bool:
279        return True
280
281    def _start_transaction_under_autocommit(self) -> None:
282        """
283        Start a transaction explicitly in autocommit mode.
284
285        Staying in autocommit mode works around a bug of sqlite3 that breaks
286        savepoints when autocommit is disabled.
287        """
288        self.cursor().execute("BEGIN")
289        return None
290
291    def is_in_memory_db(self) -> bool:
292        return self.creation.is_in_memory_db(self.settings_dict["NAME"])
293
294
# Matches a "format"-style ``%s`` placeholder unless it is escaped as ``%%s``:
# the negative lookbehind rejects any ``%s`` immediately preceded by ``%``.
FORMAT_QMARK_REGEX = _lazy_re_compile(
    r"(?<!%)%s"
)
296
297
298class SQLiteCursorWrapper(Database.Cursor):
299    """
300    Plain uses the "format" and "pyformat" styles, but Python's sqlite3 module
301    supports neither of these styles.
302
303    This wrapper performs the following conversions:
304
305    - "format" style to "qmark" style
306    - "pyformat" style to "named" style
307
308    In both cases, if you want to use a literal "%s", you'll need to use "%%s".
309    """
310
311    def execute(  # type: ignore[override]
312        self, query: str, params: Iterable[Any] | Mapping[str, Any] | None = None
313    ) -> Any:
314        if params is None:
315            return super().execute(query)
316        # Extract names if params is a mapping, i.e. "pyformat" style is used.
317        param_names = list(params) if isinstance(params, Mapping) else None
318        query = self.convert_query(query, param_names=param_names)
319        return super().execute(query, params)
320
321    def executemany(  # type: ignore[override]
322        self,
323        query: str,
324        param_list: Iterable[Iterable[Any] | Mapping[str, Any]],
325    ) -> Any:
326        # Extract names if params is a mapping, i.e. "pyformat" style is used.
327        # Peek carefully as a generator can be passed instead of a list/tuple.
328        peekable, param_list = tee(iter(param_list))
329        if (params := next(peekable, None)) and isinstance(params, Mapping):
330            param_names = list(params)
331        else:
332            param_names = None
333        query = self.convert_query(query, param_names=param_names)
334        return super().executemany(query, param_list)
335
336    def convert_query(self, query: str, *, param_names: list[str] | None = None) -> str:
337        if param_names is None:
338            # Convert from "format" style to "qmark" style.
339            return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
340        else:
341            # Convert from "pyformat" style to "named" style.
342            return query % {name: f":{name}" for name in param_names}