Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import uuid
  6from collections.abc import Callable
  7from functools import cached_property, lru_cache
  8from typing import TYPE_CHECKING, Any
  9
 10from plain import models
 11from plain.models.backends.base.operations import BaseDatabaseOperations
 12from plain.models.constants import OnConflict
 13from plain.models.db import DatabaseError, NotSupportedError
 14from plain.models.exceptions import FieldError
 15from plain.models.expressions import Col
 16from plain.utils import timezone
 17from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 18
 19if TYPE_CHECKING:
 20    from plain.models.backends.base.base import BaseDatabaseWrapper
 21
 22
 23class DatabaseOperations(BaseDatabaseOperations):
 24    cast_char_field_without_max_length = "text"
 25    cast_data_types = {
 26        "DateField": "TEXT",
 27        "DateTimeField": "TEXT",
 28    }
 29    explain_prefix = "EXPLAIN QUERY PLAN"
 30    # List of datatypes to that cannot be extracted with JSON_EXTRACT() on
 31    # SQLite. Use JSON_TYPE() instead.
 32    jsonfield_datatype_values = frozenset(["null", "false", "true"])
 33
 34    def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
 35        """
 36        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
 37        999 variables per query.
 38
 39        If there's only a single field to insert, the limit is 500
 40        (SQLITE_MAX_COMPOUND_SELECT).
 41        """
 42        if len(fields) == 1:
 43            return 500
 44        elif len(fields) > 1:
 45            return self.connection.features.max_query_params // len(fields)
 46        else:
 47            return len(objs)
 48
 49    def check_expression_support(self, expression: Any) -> None:
 50        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
 51        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
 52        if isinstance(expression, bad_aggregates):
 53            for expr in expression.get_source_expressions():
 54                try:
 55                    output_field = expr.output_field
 56                except (AttributeError, FieldError):
 57                    # Not every subexpression has an output_field which is fine
 58                    # to ignore.
 59                    pass
 60                else:
 61                    if isinstance(output_field, bad_fields):
 62                        raise NotSupportedError(
 63                            "You cannot use Sum, Avg, StdDev, and Variance "
 64                            "aggregations on date/time fields in sqlite3 "
 65                            "since date/time is saved as text."
 66                        )
 67        if (
 68            isinstance(expression, models.Aggregate)
 69            and expression.distinct
 70            and len(expression.source_expressions) > 1
 71        ):
 72            raise NotSupportedError(
 73                "SQLite doesn't support DISTINCT on aggregate functions "
 74                "accepting multiple arguments."
 75            )
 76
 77    def date_extract_sql(
 78        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 79    ) -> tuple[str, tuple[Any, ...]]:
 80        """
 81        Support EXTRACT with a user-defined function plain_date_extract()
 82        that's registered in connect(). Use single quotes because this is a
 83        string and could otherwise cause a collision with a field name.
 84        """
 85        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
 86
 87    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
 88        """
 89        Given a cursor object that has just performed an INSERT...RETURNING
 90        statement into a table, return the list of returned data.
 91        """
 92        return cursor.fetchall()
 93
 94    def format_for_duration_arithmetic(self, sql: str) -> str:
 95        """Do nothing since formatting is handled in the custom function."""
 96        return sql
 97
 98    def date_trunc_sql(
 99        self,
100        lookup_type: str,
101        sql: str,
102        params: list[Any] | tuple[Any, ...],
103        tzname: str | None = None,
104    ) -> tuple[str, tuple[Any, ...]]:
105        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
106            lookup_type.lower(),
107            *params,
108            *self._convert_tznames_to_sql(tzname),
109        )
110
111    def time_trunc_sql(
112        self,
113        lookup_type: str,
114        sql: str,
115        params: list[Any] | tuple[Any, ...],
116        tzname: str | None = None,
117    ) -> tuple[str, tuple[Any, ...]]:
118        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
119            lookup_type.lower(),
120            *params,
121            *self._convert_tznames_to_sql(tzname),
122        )
123
124    def _convert_tznames_to_sql(
125        self, tzname: str | None
126    ) -> tuple[str | None, str | None]:
127        if tzname:
128            return tzname, self.connection.timezone_name
129        return None, None
130
131    def datetime_cast_date_sql(
132        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
133    ) -> tuple[str, tuple[Any, ...]]:
134        return f"plain_datetime_cast_date({sql}, %s, %s)", (
135            *params,
136            *self._convert_tznames_to_sql(tzname),
137        )
138
139    def datetime_cast_time_sql(
140        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
141    ) -> tuple[str, tuple[Any, ...]]:
142        return f"plain_datetime_cast_time({sql}, %s, %s)", (
143            *params,
144            *self._convert_tznames_to_sql(tzname),
145        )
146
147    def datetime_extract_sql(
148        self,
149        lookup_type: str,
150        sql: str,
151        params: list[Any] | tuple[Any, ...],
152        tzname: str | None,
153    ) -> tuple[str, tuple[Any, ...]]:
154        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
155            lookup_type.lower(),
156            *params,
157            *self._convert_tznames_to_sql(tzname),
158        )
159
160    def datetime_trunc_sql(
161        self,
162        lookup_type: str,
163        sql: str,
164        params: list[Any] | tuple[Any, ...],
165        tzname: str | None,
166    ) -> tuple[str, tuple[Any, ...]]:
167        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
168            lookup_type.lower(),
169            *params,
170            *self._convert_tznames_to_sql(tzname),
171        )
172
173    def time_extract_sql(
174        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
175    ) -> tuple[str, tuple[Any, ...]]:
176        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
177
178    def pk_default_value(self) -> str:
179        return "NULL"
180
181    def _quote_params_for_last_executed_query(
182        self, params: list[Any] | tuple[Any, ...]
183    ) -> tuple[Any, ...]:
184        """
185        Only for last_executed_query! Don't use this to execute SQL queries!
186        """
187        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
188        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
189        # number of return values, default = 2000). Since Python's sqlite3
190        # module doesn't expose the get_limit() C API, assume the default
191        # limits are in effect and split the work in batches if needed.
192        BATCH_SIZE = 999
193        if len(params) > BATCH_SIZE:
194            results = ()
195            for index in range(0, len(params), BATCH_SIZE):
196                chunk = params[index : index + BATCH_SIZE]
197                results += self._quote_params_for_last_executed_query(chunk)
198            return results
199
200        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
201        # Bypass Plain's wrappers and use the underlying sqlite3 connection
202        # to avoid logging this query - it would trigger infinite recursion.
203        cursor = self.connection.connection.cursor()
204        # Native sqlite3 cursors cannot be used as context managers.
205        try:
206            return cursor.execute(sql, params).fetchone()
207        finally:
208            cursor.close()
209
210    def last_executed_query(
211        self,
212        cursor: Any,
213        sql: str,
214        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
215    ) -> str:
216        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
217        # bind_parameters(state, self->statement, parameters);
218        # Unfortunately there is no way to reach self->statement from Python,
219        # so we quote and substitute parameters manually.
220        if params:
221            if isinstance(params, list | tuple):
222                params = self._quote_params_for_last_executed_query(params)
223            else:
224                values = tuple(params.values())  # type: ignore[union-attr]
225                values = self._quote_params_for_last_executed_query(values)
226                params = dict(zip(params, values))
227            return sql % params
228        # For consistency with SQLiteCursorWrapper.execute(), just return sql
229        # when there are no parameters. See #13648 and #17158.
230        else:
231            return sql
232
233    def quote_name(self, name: str) -> str:
234        if name.startswith('"') and name.endswith('"'):
235            return name  # Quoting once is enough.
236        return f'"{name}"'
237
238    def no_limit_value(self) -> int:
239        return -1
240
241    def __references_graph(self, table_name: str) -> list[str]:
242        query = """
243        WITH tables AS (
244            SELECT %s name
245            UNION
246            SELECT sqlite_master.name
247            FROM sqlite_master
248            JOIN tables ON (sql REGEXP %s || tables.name || %s)
249        ) SELECT name FROM tables;
250        """
251        params = (
252            table_name,
253            r'(?i)\s+references\s+("|\')?',
254            r'("|\')?\s*\(',
255        )
256        with self.connection.cursor() as cursor:
257            results = cursor.execute(query, params)
258            return [row[0] for row in results.fetchall()]
259
260    @cached_property
261    def _references_graph(self) -> Callable[[str], list[str]]:
262        # 512 is large enough to fit the ~330 tables (as of this writing) in
263        # Plain's test suite.
264        return lru_cache(maxsize=512)(self.__references_graph)
265
266    def adapt_datetimefield_value(
267        self, value: datetime.datetime | Any | None
268    ) -> str | Any | None:
269        if value is None:
270            return None
271
272        # Expression values are adapted by the database.
273        if hasattr(value, "resolve_expression"):
274            return value
275
276        # SQLite doesn't support tz-aware datetimes
277        if timezone.is_aware(value):
278            value = timezone.make_naive(value, self.connection.timezone)
279
280        return str(value)
281
282    def adapt_timefield_value(
283        self, value: datetime.time | Any | None
284    ) -> str | Any | None:
285        if value is None:
286            return None
287
288        # Expression values are adapted by the database.
289        if hasattr(value, "resolve_expression"):
290            return value
291
292        # SQLite doesn't support tz-aware datetimes
293        if timezone.is_aware(value):  # type: ignore[arg-type]
294            raise ValueError("SQLite backend does not support timezone-aware times.")
295
296        return str(value)
297
298    def get_db_converters(self, expression: Any) -> list[Any]:
299        converters = super().get_db_converters(expression)
300        internal_type = expression.output_field.get_internal_type()
301        if internal_type == "DateTimeField":
302            converters.append(self.convert_datetimefield_value)
303        elif internal_type == "DateField":
304            converters.append(self.convert_datefield_value)
305        elif internal_type == "TimeField":
306            converters.append(self.convert_timefield_value)
307        elif internal_type == "DecimalField":
308            converters.append(self.get_decimalfield_converter(expression))
309        elif internal_type == "UUIDField":
310            converters.append(self.convert_uuidfield_value)
311        elif internal_type == "BooleanField":
312            converters.append(self.convert_booleanfield_value)
313        return converters
314
315    def convert_datetimefield_value(
316        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
317    ) -> datetime.datetime | None:
318        if value is not None:
319            if not isinstance(value, datetime.datetime):
320                value = parse_datetime(value)
321            if value is not None and not timezone.is_aware(value):
322                value = timezone.make_aware(value, self.connection.timezone)
323        return value
324
325    def convert_datefield_value(
326        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
327    ) -> datetime.date | None:
328        if value is not None:
329            if not isinstance(value, datetime.date):
330                value = parse_date(value)
331        return value
332
333    def convert_timefield_value(
334        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
335    ) -> datetime.time | None:
336        if value is not None:
337            if not isinstance(value, datetime.time):
338                value = parse_time(value)
339        return value
340
341    def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
342        # SQLite stores only 15 significant digits. Digits coming from
343        # float inaccuracy must be removed.
344        create_decimal = decimal.Context(prec=15).create_decimal_from_float
345        if isinstance(expression, Col):
346            quantize_value = decimal.Decimal(1).scaleb(
347                -expression.output_field.decimal_places
348            )
349
350            def converter(
351                value: Any, expression: Any, connection: BaseDatabaseWrapper
352            ) -> decimal.Decimal | None:
353                if value is not None:
354                    return create_decimal(value).quantize(
355                        quantize_value, context=expression.output_field.context
356                    )
357                return None
358
359        else:
360
361            def converter(
362                value: Any, expression: Any, connection: BaseDatabaseWrapper
363            ) -> decimal.Decimal | None:
364                if value is not None:
365                    return create_decimal(value)
366                return None
367
368        return converter
369
370    def convert_uuidfield_value(
371        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
372    ) -> uuid.UUID | None:
373        if value is not None:
374            value = uuid.UUID(value)
375        return value
376
377    def convert_booleanfield_value(
378        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
379    ) -> bool | Any:
380        return bool(value) if value in (1, 0) else value
381
382    def bulk_insert_sql(
383        self, fields: list[Any], placeholder_rows: list[list[str]]
384    ) -> str:
385        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
386        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
387        return f"VALUES {values_sql}"
388
389    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
390        # SQLite doesn't have a ^ operator, so use the user-defined POWER
391        # function that's registered in connect().
392        if connector == "^":
393            return "POWER({})".format(",".join(sub_expressions))
394        elif connector == "#":
395            return "BITXOR({})".format(",".join(sub_expressions))
396        return super().combine_expression(connector, sub_expressions)
397
398    def combine_duration_expression(
399        self, connector: str, sub_expressions: list[str]
400    ) -> str:
401        if connector not in ["+", "-", "*", "/"]:
402            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
403        fn_params = [f"'{connector}'"] + sub_expressions
404        if len(fn_params) > 3:
405            raise ValueError("Too many params for timedelta operations.")
406        return "plain_format_dtdelta({})".format(", ".join(fn_params))
407
408    def integer_field_range(self, internal_type: str) -> tuple[int, int]:
409        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
410        # integers up to 64 bits.
411        if internal_type in [
412            "PositiveBigIntegerField",
413            "PositiveIntegerField",
414            "PositiveSmallIntegerField",
415        ]:
416            return (0, 9223372036854775807)
417        return (-9223372036854775808, 9223372036854775807)
418
419    def subtract_temporals(
420        self,
421        internal_type: str,
422        lhs: tuple[str, list[Any] | tuple[Any, ...]],
423        rhs: tuple[str, list[Any] | tuple[Any, ...]],
424    ) -> tuple[str, tuple[Any, ...]]:
425        lhs_sql, lhs_params = lhs
426        rhs_sql, rhs_params = rhs
427        params = (*lhs_params, *rhs_params)
428        if internal_type == "TimeField":
429            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
430        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
431
432    def insert_statement(self, on_conflict: Any = None) -> str:
433        if on_conflict == OnConflict.IGNORE:
434            return "INSERT OR IGNORE INTO"
435        return super().insert_statement(on_conflict=on_conflict)
436
437    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
438        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
439        if not fields:
440            return "", ()
441        columns = [
442            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
443            for field in fields
444        ]
445        return "RETURNING {}".format(", ".join(columns)), ()
446
447    def on_conflict_suffix_sql(
448        self,
449        fields: list[Any],
450        on_conflict: Any,
451        update_fields: list[Any],
452        unique_fields: list[Any],
453    ) -> str:
454        if (
455            on_conflict == OnConflict.UPDATE
456            and self.connection.features.supports_update_conflicts_with_target
457        ):
458            return "ON CONFLICT({}) DO UPDATE SET {}".format(
459                ", ".join(map(self.quote_name, unique_fields)),
460                ", ".join(
461                    [
462                        f"{field} = EXCLUDED.{field}"
463                        for field in map(self.quote_name, update_fields)
464                    ]
465                ),
466            )
467        return super().on_conflict_suffix_sql(
468            fields,
469            on_conflict,
470            update_fields,
471            unique_fields,
472        )