Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import uuid
  6from collections.abc import Callable, Iterable
  7from functools import cached_property, lru_cache
  8from typing import TYPE_CHECKING, Any
  9
 10from plain import models
 11from plain.models.aggregates import Aggregate, Avg, StdDev, Sum, Variance
 12from plain.models.backends.base.operations import BaseDatabaseOperations
 13from plain.models.constants import OnConflict
 14from plain.models.db import DatabaseError, NotSupportedError
 15from plain.models.exceptions import FieldError
 16from plain.models.expressions import Col
 17from plain.utils import timezone
 18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 19
 20if TYPE_CHECKING:
 21    from plain.models.backends.base.base import BaseDatabaseWrapper
 22    from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
 23    from plain.models.fields import Field
 24
 25
 26class DatabaseOperations(BaseDatabaseOperations):
 27    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
 28    connection: SQLiteDatabaseWrapper
 29
 30    cast_char_field_without_max_length = "text"
 31    cast_data_types = {
 32        "DateField": "TEXT",
 33        "DateTimeField": "TEXT",
 34    }
 35    explain_prefix = "EXPLAIN QUERY PLAN"
 36    # List of datatypes to that cannot be extracted with JSON_EXTRACT() on
 37    # SQLite. Use JSON_TYPE() instead.
 38    jsonfield_datatype_values = frozenset(["null", "false", "true"])
 39
 40    def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
 41        """
 42        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
 43        999 variables per query.
 44
 45        If there's only a single field to insert, the limit is 500
 46        (SQLITE_MAX_COMPOUND_SELECT).
 47        """
 48        if len(fields) == 1:
 49            return 500
 50        elif len(fields) > 1:
 51            return self.connection.features.max_query_params // len(fields)
 52        else:
 53            return len(objs)
 54
 55    def check_expression_support(self, expression: Any) -> None:
 56        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
 57        bad_aggregates = (Sum, Avg, Variance, StdDev)
 58        if isinstance(expression, bad_aggregates):
 59            for expr in expression.get_source_expressions():
 60                try:
 61                    output_field = expr.output_field
 62                except (AttributeError, FieldError):
 63                    # Not every subexpression has an output_field which is fine
 64                    # to ignore.
 65                    pass
 66                else:
 67                    if isinstance(output_field, bad_fields):
 68                        raise NotSupportedError(
 69                            "You cannot use Sum, Avg, StdDev, and Variance "
 70                            "aggregations on date/time fields in sqlite3 "
 71                            "since date/time is saved as text."
 72                        )
 73        if (
 74            isinstance(expression, Aggregate)
 75            and expression.distinct
 76            and len(expression.source_expressions) > 1
 77        ):
 78            raise NotSupportedError(
 79                "SQLite doesn't support DISTINCT on aggregate functions "
 80                "accepting multiple arguments."
 81            )
 82
 83    def date_extract_sql(
 84        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 85    ) -> tuple[str, tuple[Any, ...]]:
 86        """
 87        Support EXTRACT with a user-defined function plain_date_extract()
 88        that's registered in connect(). Use single quotes because this is a
 89        string and could otherwise cause a collision with a field name.
 90        """
 91        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
 92
 93    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
 94        """
 95        Given a cursor object that has just performed an INSERT...RETURNING
 96        statement into a table, return the list of returned data.
 97        """
 98        return cursor.fetchall()
 99
100    def format_for_duration_arithmetic(self, sql: str) -> str:
101        """Do nothing since formatting is handled in the custom function."""
102        return sql
103
104    def date_trunc_sql(
105        self,
106        lookup_type: str,
107        sql: str,
108        params: list[Any] | tuple[Any, ...],
109        tzname: str | None = None,
110    ) -> tuple[str, tuple[Any, ...]]:
111        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
112            lookup_type.lower(),
113            *params,
114            *self._convert_tznames_to_sql(tzname),
115        )
116
117    def time_trunc_sql(
118        self,
119        lookup_type: str,
120        sql: str,
121        params: list[Any] | tuple[Any, ...],
122        tzname: str | None = None,
123    ) -> tuple[str, tuple[Any, ...]]:
124        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
125            lookup_type.lower(),
126            *params,
127            *self._convert_tznames_to_sql(tzname),
128        )
129
130    def _convert_tznames_to_sql(
131        self, tzname: str | None
132    ) -> tuple[str | None, str | None]:
133        if tzname:
134            return tzname, self.connection.timezone_name
135        return None, None
136
137    def datetime_cast_date_sql(
138        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
139    ) -> tuple[str, tuple[Any, ...]]:
140        return f"plain_datetime_cast_date({sql}, %s, %s)", (
141            *params,
142            *self._convert_tznames_to_sql(tzname),
143        )
144
145    def datetime_cast_time_sql(
146        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
147    ) -> tuple[str, tuple[Any, ...]]:
148        return f"plain_datetime_cast_time({sql}, %s, %s)", (
149            *params,
150            *self._convert_tznames_to_sql(tzname),
151        )
152
153    def datetime_extract_sql(
154        self,
155        lookup_type: str,
156        sql: str,
157        params: list[Any] | tuple[Any, ...],
158        tzname: str | None,
159    ) -> tuple[str, tuple[Any, ...]]:
160        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
161            lookup_type.lower(),
162            *params,
163            *self._convert_tznames_to_sql(tzname),
164        )
165
166    def datetime_trunc_sql(
167        self,
168        lookup_type: str,
169        sql: str,
170        params: list[Any] | tuple[Any, ...],
171        tzname: str | None,
172    ) -> tuple[str, tuple[Any, ...]]:
173        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
174            lookup_type.lower(),
175            *params,
176            *self._convert_tznames_to_sql(tzname),
177        )
178
179    def time_extract_sql(
180        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
181    ) -> tuple[str, tuple[Any, ...]]:
182        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
183
184    def pk_default_value(self) -> str:
185        return "NULL"
186
187    def _quote_params_for_last_executed_query(
188        self, params: list[Any] | tuple[Any, ...]
189    ) -> tuple[Any, ...]:
190        """
191        Only for last_executed_query! Don't use this to execute SQL queries!
192        """
193        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
194        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
195        # number of return values, default = 2000). Since Python's sqlite3
196        # module doesn't expose the get_limit() C API, assume the default
197        # limits are in effect and split the work in batches if needed.
198        BATCH_SIZE = 999
199        if len(params) > BATCH_SIZE:
200            results = ()
201            for index in range(0, len(params), BATCH_SIZE):
202                chunk = params[index : index + BATCH_SIZE]
203                results += self._quote_params_for_last_executed_query(chunk)
204            return results
205
206        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
207        # Bypass Plain's wrappers and use the underlying sqlite3 connection
208        # to avoid logging this query - it would trigger infinite recursion.
209        cursor = self.connection.connection.cursor()
210        # Native sqlite3 cursors cannot be used as context managers.
211        try:
212            return cursor.execute(sql, params).fetchone()
213        finally:
214            cursor.close()
215
216    def last_executed_query(
217        self,
218        cursor: Any,
219        sql: str,
220        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
221    ) -> str:
222        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
223        # bind_parameters(state, self->statement, parameters);
224        # Unfortunately there is no way to reach self->statement from Python,
225        # so we quote and substitute parameters manually.
226        if params:
227            if isinstance(params, list | tuple):
228                params = self._quote_params_for_last_executed_query(params)
229            else:
230                values = tuple(params.values())  # type: ignore[union-attr]
231                values = self._quote_params_for_last_executed_query(values)
232                params = dict(zip(params, values))
233            return sql % params
234        # For consistency with SQLiteCursorWrapper.execute(), just return sql
235        # when there are no parameters. See #13648 and #17158.
236        else:
237            return sql
238
239    def quote_name(self, name: str) -> str:
240        if name.startswith('"') and name.endswith('"'):
241            return name  # Quoting once is enough.
242        return f'"{name}"'
243
244    def no_limit_value(self) -> int:
245        return -1
246
247    def __references_graph(self, table_name: str) -> list[str]:
248        query = """
249        WITH tables AS (
250            SELECT %s name
251            UNION
252            SELECT sqlite_master.name
253            FROM sqlite_master
254            JOIN tables ON (sql REGEXP %s || tables.name || %s)
255        ) SELECT name FROM tables;
256        """
257        params = (
258            table_name,
259            r'(?i)\s+references\s+("|\')?',
260            r'("|\')?\s*\(',
261        )
262        with self.connection.cursor() as cursor:
263            results = cursor.execute(query, params)
264            return [row[0] for row in results.fetchall()]
265
266    @cached_property
267    def _references_graph(self) -> Callable[[str], list[str]]:
268        # 512 is large enough to fit the ~330 tables (as of this writing) in
269        # Plain's test suite.
270        return lru_cache(maxsize=512)(self.__references_graph)
271
272    def adapt_datetimefield_value(
273        self, value: datetime.datetime | Any | None
274    ) -> str | Any | None:
275        if value is None:
276            return None
277
278        # Expression values are adapted by the database.
279        if hasattr(value, "resolve_expression"):
280            return value
281
282        # SQLite doesn't support tz-aware datetimes
283        if timezone.is_aware(value):
284            value = timezone.make_naive(value, self.connection.timezone)
285
286        return str(value)
287
288    def adapt_timefield_value(
289        self, value: datetime.time | Any | None
290    ) -> str | Any | None:
291        if value is None:
292            return None
293
294        # Expression values are adapted by the database.
295        if hasattr(value, "resolve_expression"):
296            return value
297
298        # SQLite doesn't support tz-aware datetimes
299        if timezone.is_aware(value):  # type: ignore[arg-type]
300            raise ValueError("SQLite backend does not support timezone-aware times.")
301
302        return str(value)
303
304    def get_db_converters(self, expression: Any) -> list[Any]:
305        converters = super().get_db_converters(expression)
306        internal_type = expression.output_field.get_internal_type()
307        if internal_type == "DateTimeField":
308            converters.append(self.convert_datetimefield_value)
309        elif internal_type == "DateField":
310            converters.append(self.convert_datefield_value)
311        elif internal_type == "TimeField":
312            converters.append(self.convert_timefield_value)
313        elif internal_type == "DecimalField":
314            converters.append(self.get_decimalfield_converter(expression))
315        elif internal_type == "UUIDField":
316            converters.append(self.convert_uuidfield_value)
317        elif internal_type == "BooleanField":
318            converters.append(self.convert_booleanfield_value)
319        return converters
320
321    def convert_datetimefield_value(
322        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
323    ) -> datetime.datetime | None:
324        if value is not None:
325            if not isinstance(value, datetime.datetime):
326                value = parse_datetime(value)
327            if value is not None and not timezone.is_aware(value):
328                value = timezone.make_aware(value, self.connection.timezone)
329        return value
330
331    def convert_datefield_value(
332        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
333    ) -> datetime.date | None:
334        if value is not None:
335            if not isinstance(value, datetime.date):
336                value = parse_date(value)
337        return value
338
339    def convert_timefield_value(
340        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
341    ) -> datetime.time | None:
342        if value is not None:
343            if not isinstance(value, datetime.time):
344                value = parse_time(value)
345        return value
346
347    def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
348        # SQLite stores only 15 significant digits. Digits coming from
349        # float inaccuracy must be removed.
350        create_decimal = decimal.Context(prec=15).create_decimal_from_float
351        if isinstance(expression, Col):
352            quantize_value = decimal.Decimal(1).scaleb(
353                -expression.output_field.decimal_places
354            )
355
356            def converter(
357                value: Any, expression: Any, connection: BaseDatabaseWrapper
358            ) -> decimal.Decimal | None:
359                if value is not None:
360                    return create_decimal(value).quantize(
361                        quantize_value, context=expression.output_field.context
362                    )
363                return None
364
365        else:
366
367            def converter(
368                value: Any, expression: Any, connection: BaseDatabaseWrapper
369            ) -> decimal.Decimal | None:
370                if value is not None:
371                    return create_decimal(value)
372                return None
373
374        return converter
375
376    def convert_uuidfield_value(
377        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
378    ) -> uuid.UUID | None:
379        if value is not None:
380            value = uuid.UUID(value)
381        return value
382
383    def convert_booleanfield_value(
384        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
385    ) -> bool | Any:
386        return bool(value) if value in (1, 0) else value
387
388    def bulk_insert_sql(
389        self, fields: list[Any], placeholder_rows: list[list[str]]
390    ) -> str:
391        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
392        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
393        return f"VALUES {values_sql}"
394
395    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
396        # SQLite doesn't have a ^ operator, so use the user-defined POWER
397        # function that's registered in connect().
398        if connector == "^":
399            return "POWER({})".format(",".join(sub_expressions))
400        elif connector == "#":
401            return "BITXOR({})".format(",".join(sub_expressions))
402        return super().combine_expression(connector, sub_expressions)
403
404    def combine_duration_expression(
405        self, connector: str, sub_expressions: list[str]
406    ) -> str:
407        if connector not in ["+", "-", "*", "/"]:
408            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
409        fn_params = [f"'{connector}'"] + sub_expressions
410        if len(fn_params) > 3:
411            raise ValueError("Too many params for timedelta operations.")
412        return "plain_format_dtdelta({})".format(", ".join(fn_params))
413
414    def integer_field_range(self, internal_type: str) -> tuple[int, int]:
415        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
416        # integers up to 64 bits.
417        if internal_type in [
418            "PositiveBigIntegerField",
419            "PositiveIntegerField",
420            "PositiveSmallIntegerField",
421        ]:
422            return (0, 9223372036854775807)
423        return (-9223372036854775808, 9223372036854775807)
424
425    def subtract_temporals(
426        self,
427        internal_type: str,
428        lhs: tuple[str, list[Any] | tuple[Any, ...]],
429        rhs: tuple[str, list[Any] | tuple[Any, ...]],
430    ) -> tuple[str, tuple[Any, ...]]:
431        lhs_sql, lhs_params = lhs
432        rhs_sql, rhs_params = rhs
433        params = (*lhs_params, *rhs_params)
434        if internal_type == "TimeField":
435            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
436        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
437
438    def insert_statement(self, on_conflict: Any = None) -> str:
439        if on_conflict == OnConflict.IGNORE:
440            return "INSERT OR IGNORE INTO"
441        return super().insert_statement(on_conflict=on_conflict)
442
443    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
444        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
445        if not fields:
446            return "", ()
447        columns = [
448            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
449            for field in fields
450        ]
451        return "RETURNING {}".format(", ".join(columns)), ()
452
453    def on_conflict_suffix_sql(
454        self,
455        fields: list[Field],
456        on_conflict: Any,
457        update_fields: Iterable[str],
458        unique_fields: Iterable[str],
459    ) -> str:
460        if (
461            on_conflict == OnConflict.UPDATE
462            and self.connection.features.supports_update_conflicts_with_target
463        ):
464            return "ON CONFLICT({}) DO UPDATE SET {}".format(
465                ", ".join(map(self.quote_name, unique_fields)),
466                ", ".join(
467                    [
468                        f"{field} = EXCLUDED.{field}"
469                        for field in map(self.quote_name, update_fields)
470                    ]
471                ),
472            )
473        return super().on_conflict_suffix_sql(
474            fields,
475            on_conflict,
476            update_fields,
477            unique_fields,
478        )