Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2SQLite backend for the sqlite3 module in the standard library.
  3"""
  4import datetime
  5import decimal
  6import warnings
  7from collections.abc import Mapping
  8from itertools import chain, tee
  9from sqlite3 import dbapi2 as Database
 10
 11from plain.exceptions import ImproperlyConfigured
 12from plain.models.backends.base.base import BaseDatabaseWrapper
 13from plain.models.db import IntegrityError
 14from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 15from plain.utils.regex_helper import _lazy_re_compile
 16
 17from ._functions import register as register_functions
 18from .client import DatabaseClient
 19from .creation import DatabaseCreation
 20from .features import DatabaseFeatures
 21from .introspection import DatabaseIntrospection
 22from .operations import DatabaseOperations
 23from .schema import DatabaseSchemaEditor
 24
 25
 26def decoder(conv_func):
 27    """
 28    Convert bytestrings from Python's sqlite3 interface to a regular string.
 29    """
 30    return lambda s: conv_func(s.decode())
 31
 32
 33def adapt_date(val):
 34    return val.isoformat()
 35
 36
 37def adapt_datetime(val):
 38    return val.isoformat(" ")
 39
 40
 41Database.register_converter("bool", b"1".__eq__)
 42Database.register_converter("date", decoder(parse_date))
 43Database.register_converter("time", decoder(parse_time))
 44Database.register_converter("datetime", decoder(parse_datetime))
 45Database.register_converter("timestamp", decoder(parse_datetime))
 46
 47Database.register_adapter(decimal.Decimal, str)
 48Database.register_adapter(datetime.date, adapt_date)
 49Database.register_adapter(datetime.datetime, adapt_datetime)
 50
 51
 52class DatabaseWrapper(BaseDatabaseWrapper):
 53    vendor = "sqlite"
 54    display_name = "SQLite"
 55    # SQLite doesn't actually support most of these types, but it "does the right
 56    # thing" given more verbose field definitions, so leave them as is so that
 57    # schema inspection is more useful.
 58    data_types = {
 59        "AutoField": "integer",
 60        "BigAutoField": "integer",
 61        "BinaryField": "BLOB",
 62        "BooleanField": "bool",
 63        "CharField": "varchar(%(max_length)s)",
 64        "DateField": "date",
 65        "DateTimeField": "datetime",
 66        "DecimalField": "decimal",
 67        "DurationField": "bigint",
 68        "FloatField": "real",
 69        "IntegerField": "integer",
 70        "BigIntegerField": "bigint",
 71        "IPAddressField": "char(15)",
 72        "GenericIPAddressField": "char(39)",
 73        "JSONField": "text",
 74        "OneToOneField": "integer",
 75        "PositiveBigIntegerField": "bigint unsigned",
 76        "PositiveIntegerField": "integer unsigned",
 77        "PositiveSmallIntegerField": "smallint unsigned",
 78        "SlugField": "varchar(%(max_length)s)",
 79        "SmallAutoField": "integer",
 80        "SmallIntegerField": "smallint",
 81        "TextField": "text",
 82        "TimeField": "time",
 83        "UUIDField": "char(32)",
 84    }
 85    data_type_check_constraints = {
 86        "PositiveBigIntegerField": '"%(column)s" >= 0',
 87        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
 88        "PositiveIntegerField": '"%(column)s" >= 0',
 89        "PositiveSmallIntegerField": '"%(column)s" >= 0',
 90    }
 91    data_types_suffix = {
 92        "AutoField": "AUTOINCREMENT",
 93        "BigAutoField": "AUTOINCREMENT",
 94        "SmallAutoField": "AUTOINCREMENT",
 95    }
 96    # SQLite requires LIKE statements to include an ESCAPE clause if the value
 97    # being escaped has a percent or underscore in it.
 98    # See https://www.sqlite.org/lang_expr.html for an explanation.
 99    operators = {
100        "exact": "= %s",
101        "iexact": "LIKE %s ESCAPE '\\'",
102        "contains": "LIKE %s ESCAPE '\\'",
103        "icontains": "LIKE %s ESCAPE '\\'",
104        "regex": "REGEXP %s",
105        "iregex": "REGEXP '(?i)' || %s",
106        "gt": "> %s",
107        "gte": ">= %s",
108        "lt": "< %s",
109        "lte": "<= %s",
110        "startswith": "LIKE %s ESCAPE '\\'",
111        "endswith": "LIKE %s ESCAPE '\\'",
112        "istartswith": "LIKE %s ESCAPE '\\'",
113        "iendswith": "LIKE %s ESCAPE '\\'",
114    }
115
116    # The patterns below are used to generate SQL pattern lookup clauses when
117    # the right-hand side of the lookup isn't a raw string (it might be an expression
118    # or the result of a bilateral transformation).
119    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
120    # escaped on database side.
121    #
122    # Note: we use str.format() here for readability as '%' is used as a wildcard for
123    # the LIKE operator.
124    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
125    pattern_ops = {
126        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
127        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
128        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
129        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
130        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
131        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
132    }
133
134    Database = Database
135    SchemaEditorClass = DatabaseSchemaEditor
136    # Classes instantiated in __init__().
137    client_class = DatabaseClient
138    creation_class = DatabaseCreation
139    features_class = DatabaseFeatures
140    introspection_class = DatabaseIntrospection
141    ops_class = DatabaseOperations
142
143    def get_connection_params(self):
144        settings_dict = self.settings_dict
145        if not settings_dict["NAME"]:
146            raise ImproperlyConfigured(
147                "settings.DATABASES is improperly configured. "
148                "Please supply the NAME value."
149            )
150        kwargs = {
151            "database": settings_dict["NAME"],
152            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
153            **settings_dict["OPTIONS"],
154        }
155        # Always allow the underlying SQLite connection to be shareable
156        # between multiple threads. The safe-guarding will be handled at a
157        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
158        # property. This is necessary as the shareability is disabled by
159        # default in sqlite3 and it cannot be changed once a connection is
160        # opened.
161        if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
162            warnings.warn(
163                "The `check_same_thread` option was provided and set to "
164                "True. It will be overridden with False. Use the "
165                "`DatabaseWrapper.allow_thread_sharing` property instead "
166                "for controlling thread shareability.",
167                RuntimeWarning,
168            )
169        kwargs.update({"check_same_thread": False, "uri": True})
170        return kwargs
171
172    def get_database_version(self):
173        return self.Database.sqlite_version_info
174
175    def get_new_connection(self, conn_params):
176        conn = Database.connect(**conn_params)
177        register_functions(conn)
178
179        conn.execute("PRAGMA foreign_keys = ON")
180        # The macOS bundled SQLite defaults legacy_alter_table ON, which
181        # prevents atomic table renames (feature supports_atomic_references_rename)
182        conn.execute("PRAGMA legacy_alter_table = OFF")
183        return conn
184
185    def create_cursor(self, name=None):
186        return self.connection.cursor(factory=SQLiteCursorWrapper)
187
188    def close(self):
189        self.validate_thread_sharing()
190        # If database is in memory, closing the connection destroys the
191        # database. To prevent accidental data loss, ignore close requests on
192        # an in-memory db.
193        if not self.is_in_memory_db():
194            BaseDatabaseWrapper.close(self)
195
196    def _savepoint_allowed(self):
197        # When 'isolation_level' is not None, sqlite3 commits before each
198        # savepoint; it's a bug. When it is None, savepoints don't make sense
199        # because autocommit is enabled. The only exception is inside 'atomic'
200        # blocks. To work around that bug, on SQLite, 'atomic' starts a
201        # transaction explicitly rather than simply disable autocommit.
202        return self.in_atomic_block
203
204    def _set_autocommit(self, autocommit):
205        if autocommit:
206            level = None
207        else:
208            # sqlite3's internal default is ''. It's different from None.
209            # See Modules/_sqlite/connection.c.
210            level = ""
211        # 'isolation_level' is a misleading API.
212        # SQLite always runs at the SERIALIZABLE isolation level.
213        with self.wrap_database_errors:
214            self.connection.isolation_level = level
215
216    def disable_constraint_checking(self):
217        with self.cursor() as cursor:
218            cursor.execute("PRAGMA foreign_keys = OFF")
219            # Foreign key constraints cannot be turned off while in a multi-
220            # statement transaction. Fetch the current state of the pragma
221            # to determine if constraints are effectively disabled.
222            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
223        return not bool(enabled)
224
225    def enable_constraint_checking(self):
226        with self.cursor() as cursor:
227            cursor.execute("PRAGMA foreign_keys = ON")
228
229    def check_constraints(self, table_names=None):
230        """
231        Check each table name in `table_names` for rows with invalid foreign
232        key references. This method is intended to be used in conjunction with
233        `disable_constraint_checking()` and `enable_constraint_checking()`, to
234        determine if rows with invalid references were entered while constraint
235        checks were off.
236        """
237        with self.cursor() as cursor:
238            if table_names is None:
239                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
240            else:
241                violations = chain.from_iterable(
242                    cursor.execute(
243                        "PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name)
244                    ).fetchall()
245                    for table_name in table_names
246                )
247            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
248            for (
249                table_name,
250                rowid,
251                referenced_table_name,
252                foreign_key_index,
253            ) in violations:
254                foreign_key = cursor.execute(
255                    "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
256                ).fetchall()[foreign_key_index]
257                column_name, referenced_column_name = foreign_key[3:5]
258                primary_key_column_name = self.introspection.get_primary_key_column(
259                    cursor, table_name
260                )
261                primary_key_value, bad_value = cursor.execute(
262                    "SELECT {}, {} FROM {} WHERE rowid = %s".format(
263                        self.ops.quote_name(primary_key_column_name),
264                        self.ops.quote_name(column_name),
265                        self.ops.quote_name(table_name),
266                    ),
267                    (rowid,),
268                ).fetchone()
269                raise IntegrityError(
270                    "The row in table '{}' with primary key '{}' has an "
271                    "invalid foreign key: {}.{} contains a value '{}' that "
272                    "does not have a corresponding value in {}.{}.".format(
273                        table_name,
274                        primary_key_value,
275                        table_name,
276                        column_name,
277                        bad_value,
278                        referenced_table_name,
279                        referenced_column_name,
280                    )
281                )
282
283    def is_usable(self):
284        return True
285
286    def _start_transaction_under_autocommit(self):
287        """
288        Start a transaction explicitly in autocommit mode.
289
290        Staying in autocommit mode works around a bug of sqlite3 that breaks
291        savepoints when autocommit is disabled.
292        """
293        self.cursor().execute("BEGIN")
294
295    def is_in_memory_db(self):
296        return self.creation.is_in_memory_db(self.settings_dict["NAME"])
297
298
# Matches a "format"-style "%s" placeholder that is not preceded by "%", so an
# escaped "%%s" is left untouched. SQLiteCursorWrapper.convert_query replaces
# each match with sqlite3's "?" (qmark) placeholder.
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
300
301
class SQLiteCursorWrapper(Database.Cursor):
    """
    Cursor that rewrites Plain's placeholder styles into ones sqlite3 accepts.

    Plain emits "format" (``%s``) and "pyformat" (``%(name)s``) placeholders,
    and Python's sqlite3 module supports neither. Before execution, queries
    are rewritten:

    - "format" style becomes "qmark" style (``?``)
    - "pyformat" style becomes "named" style (``:name``)

    In both cases a literal "%s" must be written as "%%s".
    """

    def execute(self, query, params=None):
        # Without params the query is sent as-is: no placeholder rewriting and
        # no "%%" unescaping.
        if params is None:
            return super().execute(query)
        names = None
        if isinstance(params, Mapping):
            # "pyformat" style: the mapping's keys are the placeholder names.
            names = list(params)
        return super().execute(self.convert_query(query, param_names=names), params)

    def executemany(self, query, param_list):
        # param_list may be a one-shot generator, so inspect its first item
        # through one tee() branch and hand the other, untouched, to sqlite3.
        probe, param_list = tee(iter(param_list))
        first = next(probe, None)
        names = list(first) if first and isinstance(first, Mapping) else None
        return super().executemany(
            self.convert_query(query, param_names=names), param_list
        )

    def convert_query(self, query, *, param_names=None):
        if param_names is not None:
            # "pyformat" -> "named": each %(name)s becomes :name.
            return query % {name: f":{name}" for name in param_names}
        # "format" -> "qmark": unescaped %s becomes ?, then %% collapses to %.
        return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")