Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import uuid
  6from collections.abc import Callable, Iterable
  7from functools import cached_property, lru_cache
  8from typing import TYPE_CHECKING, Any
  9
 10from plain import models
 11from plain.models.aggregates import Aggregate, Avg, StdDev, Sum, Variance
 12from plain.models.backends.base.operations import BaseDatabaseOperations
 13from plain.models.backends.utils import CursorWrapper
 14from plain.models.constants import OnConflict
 15from plain.models.db import DatabaseError, NotSupportedError
 16from plain.models.exceptions import FieldError
 17from plain.models.expressions import Col, ResolvableExpression
 18from plain.utils import timezone
 19from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 20
 21if TYPE_CHECKING:
 22    from plain.models.backends.base.base import BaseDatabaseWrapper
 23    from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
 24    from plain.models.fields import Field
 25
 26
class DatabaseOperations(BaseDatabaseOperations):
    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
    connection: SQLiteDatabaseWrapper

    # SQLite has no length-limited char type; CAST targets plain "text".
    cast_char_field_without_max_length = "text"
    # Date/datetime values are stored as text on SQLite, so CAST to TEXT.
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

 41    def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
 42        """
 43        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
 44        999 variables per query.
 45
 46        If there's only a single field to insert, the limit is 500
 47        (SQLITE_MAX_COMPOUND_SELECT).
 48        """
 49        if len(fields) == 1:
 50            return 500
 51        elif len(fields) > 1:
 52            return self.connection.features.max_query_params // len(fields)
 53        else:
 54            return len(objs)
 55
 56    def check_expression_support(self, expression: Any) -> None:
 57        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
 58        bad_aggregates = (Sum, Avg, Variance, StdDev)
 59        if isinstance(expression, bad_aggregates):
 60            for expr in expression.get_source_expressions():
 61                try:
 62                    output_field = expr.output_field
 63                except (AttributeError, FieldError):
 64                    # Not every subexpression has an output_field which is fine
 65                    # to ignore.
 66                    pass
 67                else:
 68                    if isinstance(output_field, bad_fields):
 69                        raise NotSupportedError(
 70                            "You cannot use Sum, Avg, StdDev, and Variance "
 71                            "aggregations on date/time fields in sqlite3 "
 72                            "since date/time is saved as text."
 73                        )
 74        if (
 75            isinstance(expression, Aggregate)
 76            and expression.distinct
 77            and len(expression.source_expressions) > 1
 78        ):
 79            raise NotSupportedError(
 80                "SQLite doesn't support DISTINCT on aggregate functions "
 81                "accepting multiple arguments."
 82            )
 83
 84    def date_extract_sql(
 85        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 86    ) -> tuple[str, tuple[Any, ...]]:
 87        """
 88        Support EXTRACT with a user-defined function plain_date_extract()
 89        that's registered in connect(). Use single quotes because this is a
 90        string and could otherwise cause a collision with a field name.
 91        """
 92        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
 93
 94    def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
 95        """
 96        Given a cursor object that has just performed an INSERT...RETURNING
 97        statement into a table, return the list of returned data.
 98        """
 99        return cursor.fetchall()
100
101    def format_for_duration_arithmetic(self, sql: str) -> str:
102        """Do nothing since formatting is handled in the custom function."""
103        return sql
104
105    def date_trunc_sql(
106        self,
107        lookup_type: str,
108        sql: str,
109        params: list[Any] | tuple[Any, ...],
110        tzname: str | None = None,
111    ) -> tuple[str, tuple[Any, ...]]:
112        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
113            lookup_type.lower(),
114            *params,
115            *self._convert_tznames_to_sql(tzname),
116        )
117
118    def time_trunc_sql(
119        self,
120        lookup_type: str,
121        sql: str,
122        params: list[Any] | tuple[Any, ...],
123        tzname: str | None = None,
124    ) -> tuple[str, tuple[Any, ...]]:
125        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
126            lookup_type.lower(),
127            *params,
128            *self._convert_tznames_to_sql(tzname),
129        )
130
131    def _convert_tznames_to_sql(
132        self, tzname: str | None
133    ) -> tuple[str | None, str | None]:
134        if tzname:
135            return tzname, self.connection.timezone_name
136        return None, None
137
138    def datetime_cast_date_sql(
139        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
140    ) -> tuple[str, tuple[Any, ...]]:
141        return f"plain_datetime_cast_date({sql}, %s, %s)", (
142            *params,
143            *self._convert_tznames_to_sql(tzname),
144        )
145
146    def datetime_cast_time_sql(
147        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
148    ) -> tuple[str, tuple[Any, ...]]:
149        return f"plain_datetime_cast_time({sql}, %s, %s)", (
150            *params,
151            *self._convert_tznames_to_sql(tzname),
152        )
153
154    def datetime_extract_sql(
155        self,
156        lookup_type: str,
157        sql: str,
158        params: list[Any] | tuple[Any, ...],
159        tzname: str | None,
160    ) -> tuple[str, tuple[Any, ...]]:
161        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
162            lookup_type.lower(),
163            *params,
164            *self._convert_tznames_to_sql(tzname),
165        )
166
167    def datetime_trunc_sql(
168        self,
169        lookup_type: str,
170        sql: str,
171        params: list[Any] | tuple[Any, ...],
172        tzname: str | None,
173    ) -> tuple[str, tuple[Any, ...]]:
174        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
175            lookup_type.lower(),
176            *params,
177            *self._convert_tznames_to_sql(tzname),
178        )
179
180    def time_extract_sql(
181        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
182    ) -> tuple[str, tuple[Any, ...]]:
183        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
184
185    def pk_default_value(self) -> str:
186        return "NULL"
187
188    def _quote_params_for_last_executed_query(
189        self, params: list[Any] | tuple[Any, ...]
190    ) -> tuple[Any, ...]:
191        """
192        Only for last_executed_query! Don't use this to execute SQL queries!
193        """
194        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
195        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
196        # number of return values, default = 2000). Since Python's sqlite3
197        # module doesn't expose the get_limit() C API, assume the default
198        # limits are in effect and split the work in batches if needed.
199        BATCH_SIZE = 999
200        if len(params) > BATCH_SIZE:
201            results = ()
202            for index in range(0, len(params), BATCH_SIZE):
203                chunk = params[index : index + BATCH_SIZE]
204                results += self._quote_params_for_last_executed_query(chunk)
205            return results
206
207        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
208        # Bypass Plain's wrappers and use the underlying sqlite3 connection
209        # to avoid logging this query - it would trigger infinite recursion.
210        cursor = self.connection.connection.cursor()
211        # Native sqlite3 cursors cannot be used as context managers.
212        try:
213            return cursor.execute(sql, params).fetchone()
214        finally:
215            cursor.close()
216
217    def last_executed_query(
218        self,
219        cursor: CursorWrapper,
220        sql: str,
221        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
222    ) -> str:
223        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
224        # bind_parameters(state, self->statement, parameters);
225        # Unfortunately there is no way to reach self->statement from Python,
226        # so we quote and substitute parameters manually.
227        if params:
228            if isinstance(params, list | tuple):
229                params = self._quote_params_for_last_executed_query(params)
230            else:
231                values = tuple(params.values())
232                values = self._quote_params_for_last_executed_query(values)
233                params = dict(zip(params, values))
234            return sql % params
235        # For consistency with SQLiteCursorWrapper.execute(), just return sql
236        # when there are no parameters. See #13648 and #17158.
237        else:
238            return sql
239
240    def quote_name(self, name: str) -> str:
241        if name.startswith('"') and name.endswith('"'):
242            return name  # Quoting once is enough.
243        return f'"{name}"'
244
245    def no_limit_value(self) -> int:
246        return -1
247
248    def __references_graph(self, table_name: str) -> list[str]:
249        query = """
250        WITH tables AS (
251            SELECT %s name
252            UNION
253            SELECT sqlite_master.name
254            FROM sqlite_master
255            JOIN tables ON (sql REGEXP %s || tables.name || %s)
256        ) SELECT name FROM tables;
257        """
258        params = (
259            table_name,
260            r'(?i)\s+references\s+("|\')?',
261            r'("|\')?\s*\(',
262        )
263        with self.connection.cursor() as cursor:
264            results = cursor.execute(query, params)
265            return [row[0] for row in results.fetchall()]
266
    @cached_property
    def _references_graph(self) -> Callable[[str], list[str]]:
        # Memoized wrapper around __references_graph, built lazily per
        # instance via cached_property.
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)
272
273    def adapt_datetimefield_value(
274        self, value: datetime.datetime | Any | None
275    ) -> str | Any | None:
276        if value is None:
277            return None
278
279        # Expression values are adapted by the database.
280        if isinstance(value, ResolvableExpression):
281            return value
282
283        # SQLite doesn't support tz-aware datetimes
284        if timezone.is_aware(value):
285            value = timezone.make_naive(value, self.connection.timezone)
286
287        return str(value)
288
289    def adapt_timefield_value(
290        self, value: datetime.time | Any | None
291    ) -> str | Any | None:
292        if value is None:
293            return None
294
295        # Expression values are adapted by the database.
296        if isinstance(value, ResolvableExpression):
297            return value
298
299        # SQLite doesn't support tz-aware datetimes
300        if isinstance(value, datetime.time) and value.utcoffset() is not None:
301            raise ValueError("SQLite backend does not support timezone-aware times.")
302
303        return str(value)
304
305    def get_db_converters(self, expression: Any) -> list[Any]:
306        converters = super().get_db_converters(expression)
307        internal_type = expression.output_field.get_internal_type()
308        if internal_type == "DateTimeField":
309            converters.append(self.convert_datetimefield_value)
310        elif internal_type == "DateField":
311            converters.append(self.convert_datefield_value)
312        elif internal_type == "TimeField":
313            converters.append(self.convert_timefield_value)
314        elif internal_type == "DecimalField":
315            converters.append(self.get_decimalfield_converter(expression))
316        elif internal_type == "UUIDField":
317            converters.append(self.convert_uuidfield_value)
318        elif internal_type == "BooleanField":
319            converters.append(self.convert_booleanfield_value)
320        return converters
321
322    def convert_datetimefield_value(
323        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
324    ) -> datetime.datetime | None:
325        if value is not None:
326            if not isinstance(value, datetime.datetime):
327                value = parse_datetime(value)
328            if value is not None and not timezone.is_aware(value):
329                value = timezone.make_aware(value, self.connection.timezone)
330        return value
331
332    def convert_datefield_value(
333        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
334    ) -> datetime.date | None:
335        if value is not None:
336            if not isinstance(value, datetime.date):
337                value = parse_date(value)
338        return value
339
340    def convert_timefield_value(
341        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
342    ) -> datetime.time | None:
343        if value is not None:
344            if not isinstance(value, datetime.time):
345                value = parse_time(value)
346        return value
347
    def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
        """
        Return a converter turning SQLite results into Decimal values.

        Direct column references are additionally quantized to the field's
        declared decimal_places; other expressions only get the 15-digit
        precision trim.
        """
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            # 10 ** -decimal_places, e.g. Decimal("0.01") for two places.
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(
                value: Any, expression: Any, connection: BaseDatabaseWrapper
            ) -> decimal.Decimal | None:
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )
                return None

        else:

            def converter(
                value: Any, expression: Any, connection: BaseDatabaseWrapper
            ) -> decimal.Decimal | None:
                if value is not None:
                    return create_decimal(value)
                return None

        return converter
376
377    def convert_uuidfield_value(
378        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
379    ) -> uuid.UUID | None:
380        if value is not None:
381            value = uuid.UUID(value)
382        return value
383
384    def convert_booleanfield_value(
385        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
386    ) -> bool | Any:
387        return bool(value) if value in (1, 0) else value
388
389    def bulk_insert_sql(
390        self, fields: list[Any], placeholder_rows: list[list[str]]
391    ) -> str:
392        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
393        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
394        return f"VALUES {values_sql}"
395
396    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
397        # SQLite doesn't have a ^ operator, so use the user-defined POWER
398        # function that's registered in connect().
399        if connector == "^":
400            return "POWER({})".format(",".join(sub_expressions))
401        elif connector == "#":
402            return "BITXOR({})".format(",".join(sub_expressions))
403        return super().combine_expression(connector, sub_expressions)
404
405    def combine_duration_expression(
406        self, connector: str, sub_expressions: list[str]
407    ) -> str:
408        if connector not in ["+", "-", "*", "/"]:
409            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
410        fn_params = [f"'{connector}'"] + sub_expressions
411        if len(fn_params) > 3:
412            raise ValueError("Too many params for timedelta operations.")
413        return "plain_format_dtdelta({})".format(", ".join(fn_params))
414
415    def integer_field_range(self, internal_type: str) -> tuple[int, int]:
416        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
417        # integers up to 64 bits.
418        if internal_type in [
419            "PositiveBigIntegerField",
420            "PositiveIntegerField",
421            "PositiveSmallIntegerField",
422        ]:
423            return (0, 9223372036854775807)
424        return (-9223372036854775808, 9223372036854775807)
425
426    def subtract_temporals(
427        self,
428        internal_type: str,
429        lhs: tuple[str, list[Any] | tuple[Any, ...]],
430        rhs: tuple[str, list[Any] | tuple[Any, ...]],
431    ) -> tuple[str, tuple[Any, ...]]:
432        lhs_sql, lhs_params = lhs
433        rhs_sql, rhs_params = rhs
434        params = (*lhs_params, *rhs_params)
435        if internal_type == "TimeField":
436            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
437        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
438
439    def insert_statement(self, on_conflict: Any = None) -> str:
440        if on_conflict == OnConflict.IGNORE:
441            return "INSERT OR IGNORE INTO"
442        return super().insert_statement(on_conflict=on_conflict)
443
444    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
445        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
446        if not fields:
447            return "", ()
448        columns = [
449            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
450            for field in fields
451        ]
452        return "RETURNING {}".format(", ".join(columns)), ()
453
454    def on_conflict_suffix_sql(
455        self,
456        fields: list[Field],
457        on_conflict: Any,
458        update_fields: Iterable[str],
459        unique_fields: Iterable[str],
460    ) -> str:
461        if (
462            on_conflict == OnConflict.UPDATE
463            and self.connection.features.supports_update_conflicts_with_target
464        ):
465            return "ON CONFLICT({}) DO UPDATE SET {}".format(
466                ", ".join(map(self.quote_name, unique_fields)),
467                ", ".join(
468                    [
469                        f"{field} = EXCLUDED.{field}"
470                        for field in map(self.quote_name, update_fields)
471                    ]
472                ),
473            )
474        return super().on_conflict_suffix_sql(
475            fields,
476            on_conflict,
477            update_fields,
478            unique_fields,
479        )