Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import uuid
  4from functools import lru_cache
  5
  6from plain import models
  7from plain.exceptions import FieldError
  8from plain.models.backends.base.operations import BaseDatabaseOperations
  9from plain.models.constants import OnConflict
 10from plain.models.db import DatabaseError, NotSupportedError
 11from plain.models.expressions import Col
 12from plain.runtime import settings
 13from plain.utils import timezone
 14from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 15from plain.utils.functional import cached_property
 16
 17
 18class DatabaseOperations(BaseDatabaseOperations):
 19    cast_char_field_without_max_length = "text"
 20    cast_data_types = {
 21        "DateField": "TEXT",
 22        "DateTimeField": "TEXT",
 23    }
 24    explain_prefix = "EXPLAIN QUERY PLAN"
 25    # List of datatypes to that cannot be extracted with JSON_EXTRACT() on
 26    # SQLite. Use JSON_TYPE() instead.
 27    jsonfield_datatype_values = frozenset(["null", "false", "true"])
 28
 29    def bulk_batch_size(self, fields, objs):
 30        """
 31        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
 32        999 variables per query.
 33
 34        If there's only a single field to insert, the limit is 500
 35        (SQLITE_MAX_COMPOUND_SELECT).
 36        """
 37        if len(fields) == 1:
 38            return 500
 39        elif len(fields) > 1:
 40            return self.connection.features.max_query_params // len(fields)
 41        else:
 42            return len(objs)
 43
 44    def check_expression_support(self, expression):
 45        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
 46        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
 47        if isinstance(expression, bad_aggregates):
 48            for expr in expression.get_source_expressions():
 49                try:
 50                    output_field = expr.output_field
 51                except (AttributeError, FieldError):
 52                    # Not every subexpression has an output_field which is fine
 53                    # to ignore.
 54                    pass
 55                else:
 56                    if isinstance(output_field, bad_fields):
 57                        raise NotSupportedError(
 58                            "You cannot use Sum, Avg, StdDev, and Variance "
 59                            "aggregations on date/time fields in sqlite3 "
 60                            "since date/time is saved as text."
 61                        )
 62        if (
 63            isinstance(expression, models.Aggregate)
 64            and expression.distinct
 65            and len(expression.source_expressions) > 1
 66        ):
 67            raise NotSupportedError(
 68                "SQLite doesn't support DISTINCT on aggregate functions "
 69                "accepting multiple arguments."
 70            )
 71
 72    def date_extract_sql(self, lookup_type, sql, params):
 73        """
 74        Support EXTRACT with a user-defined function plain_date_extract()
 75        that's registered in connect(). Use single quotes because this is a
 76        string and could otherwise cause a collision with a field name.
 77        """
 78        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
 79
 80    def fetch_returned_insert_rows(self, cursor):
 81        """
 82        Given a cursor object that has just performed an INSERT...RETURNING
 83        statement into a table, return the list of returned data.
 84        """
 85        return cursor.fetchall()
 86
 87    def format_for_duration_arithmetic(self, sql):
 88        """Do nothing since formatting is handled in the custom function."""
 89        return sql
 90
 91    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 92        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
 93            lookup_type.lower(),
 94            *params,
 95            *self._convert_tznames_to_sql(tzname),
 96        )
 97
 98    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
 99        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
100            lookup_type.lower(),
101            *params,
102            *self._convert_tznames_to_sql(tzname),
103        )
104
105    def _convert_tznames_to_sql(self, tzname):
106        if tzname and settings.USE_TZ:
107            return tzname, self.connection.timezone_name
108        return None, None
109
110    def datetime_cast_date_sql(self, sql, params, tzname):
111        return f"plain_datetime_cast_date({sql}, %s, %s)", (
112            *params,
113            *self._convert_tznames_to_sql(tzname),
114        )
115
116    def datetime_cast_time_sql(self, sql, params, tzname):
117        return f"plain_datetime_cast_time({sql}, %s, %s)", (
118            *params,
119            *self._convert_tznames_to_sql(tzname),
120        )
121
122    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
123        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
124            lookup_type.lower(),
125            *params,
126            *self._convert_tznames_to_sql(tzname),
127        )
128
129    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
130        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
131            lookup_type.lower(),
132            *params,
133            *self._convert_tznames_to_sql(tzname),
134        )
135
136    def time_extract_sql(self, lookup_type, sql, params):
137        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
138
139    def pk_default_value(self):
140        return "NULL"
141
142    def _quote_params_for_last_executed_query(self, params):
143        """
144        Only for last_executed_query! Don't use this to execute SQL queries!
145        """
146        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
147        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
148        # number of return values, default = 2000). Since Python's sqlite3
149        # module doesn't expose the get_limit() C API, assume the default
150        # limits are in effect and split the work in batches if needed.
151        BATCH_SIZE = 999
152        if len(params) > BATCH_SIZE:
153            results = ()
154            for index in range(0, len(params), BATCH_SIZE):
155                chunk = params[index : index + BATCH_SIZE]
156                results += self._quote_params_for_last_executed_query(chunk)
157            return results
158
159        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
160        # Bypass Plain's wrappers and use the underlying sqlite3 connection
161        # to avoid logging this query - it would trigger infinite recursion.
162        cursor = self.connection.connection.cursor()
163        # Native sqlite3 cursors cannot be used as context managers.
164        try:
165            return cursor.execute(sql, params).fetchone()
166        finally:
167            cursor.close()
168
169    def last_executed_query(self, cursor, sql, params):
170        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
171        # bind_parameters(state, self->statement, parameters);
172        # Unfortunately there is no way to reach self->statement from Python,
173        # so we quote and substitute parameters manually.
174        if params:
175            if isinstance(params, list | tuple):
176                params = self._quote_params_for_last_executed_query(params)
177            else:
178                values = tuple(params.values())
179                values = self._quote_params_for_last_executed_query(values)
180                params = dict(zip(params, values))
181            return sql % params
182        # For consistency with SQLiteCursorWrapper.execute(), just return sql
183        # when there are no parameters. See #13648 and #17158.
184        else:
185            return sql
186
187    def quote_name(self, name):
188        if name.startswith('"') and name.endswith('"'):
189            return name  # Quoting once is enough.
190        return '"%s"' % name
191
192    def no_limit_value(self):
193        return -1
194
195    def __references_graph(self, table_name):
196        query = """
197        WITH tables AS (
198            SELECT %s name
199            UNION
200            SELECT sqlite_master.name
201            FROM sqlite_master
202            JOIN tables ON (sql REGEXP %s || tables.name || %s)
203        ) SELECT name FROM tables;
204        """
205        params = (
206            table_name,
207            r'(?i)\s+references\s+("|\')?',
208            r'("|\')?\s*\(',
209        )
210        with self.connection.cursor() as cursor:
211            results = cursor.execute(query, params)
212            return [row[0] for row in results.fetchall()]
213
214    @cached_property
215    def _references_graph(self):
216        # 512 is large enough to fit the ~330 tables (as of this writing) in
217        # Plain's test suite.
218        return lru_cache(maxsize=512)(self.__references_graph)
219
220    def sequence_reset_by_name_sql(self, style, sequences):
221        if not sequences:
222            return []
223        return [
224            "{} {} {} {} = 0 {} {} {} ({});".format(
225                style.SQL_KEYWORD("UPDATE"),
226                style.SQL_TABLE(self.quote_name("sqlite_sequence")),
227                style.SQL_KEYWORD("SET"),
228                style.SQL_FIELD(self.quote_name("seq")),
229                style.SQL_KEYWORD("WHERE"),
230                style.SQL_FIELD(self.quote_name("name")),
231                style.SQL_KEYWORD("IN"),
232                ", ".join(
233                    ["'%s'" % sequence_info["table"] for sequence_info in sequences]
234                ),
235            ),
236        ]
237
238    def adapt_datetimefield_value(self, value):
239        if value is None:
240            return None
241
242        # Expression values are adapted by the database.
243        if hasattr(value, "resolve_expression"):
244            return value
245
246        # SQLite doesn't support tz-aware datetimes
247        if timezone.is_aware(value):
248            if settings.USE_TZ:
249                value = timezone.make_naive(value, self.connection.timezone)
250            else:
251                raise ValueError(
252                    "SQLite backend does not support timezone-aware datetimes when "
253                    "USE_TZ is False."
254                )
255
256        return str(value)
257
258    def adapt_timefield_value(self, value):
259        if value is None:
260            return None
261
262        # Expression values are adapted by the database.
263        if hasattr(value, "resolve_expression"):
264            return value
265
266        # SQLite doesn't support tz-aware datetimes
267        if timezone.is_aware(value):
268            raise ValueError("SQLite backend does not support timezone-aware times.")
269
270        return str(value)
271
272    def get_db_converters(self, expression):
273        converters = super().get_db_converters(expression)
274        internal_type = expression.output_field.get_internal_type()
275        if internal_type == "DateTimeField":
276            converters.append(self.convert_datetimefield_value)
277        elif internal_type == "DateField":
278            converters.append(self.convert_datefield_value)
279        elif internal_type == "TimeField":
280            converters.append(self.convert_timefield_value)
281        elif internal_type == "DecimalField":
282            converters.append(self.get_decimalfield_converter(expression))
283        elif internal_type == "UUIDField":
284            converters.append(self.convert_uuidfield_value)
285        elif internal_type == "BooleanField":
286            converters.append(self.convert_booleanfield_value)
287        return converters
288
289    def convert_datetimefield_value(self, value, expression, connection):
290        if value is not None:
291            if not isinstance(value, datetime.datetime):
292                value = parse_datetime(value)
293            if settings.USE_TZ and not timezone.is_aware(value):
294                value = timezone.make_aware(value, self.connection.timezone)
295        return value
296
297    def convert_datefield_value(self, value, expression, connection):
298        if value is not None:
299            if not isinstance(value, datetime.date):
300                value = parse_date(value)
301        return value
302
303    def convert_timefield_value(self, value, expression, connection):
304        if value is not None:
305            if not isinstance(value, datetime.time):
306                value = parse_time(value)
307        return value
308
309    def get_decimalfield_converter(self, expression):
310        # SQLite stores only 15 significant digits. Digits coming from
311        # float inaccuracy must be removed.
312        create_decimal = decimal.Context(prec=15).create_decimal_from_float
313        if isinstance(expression, Col):
314            quantize_value = decimal.Decimal(1).scaleb(
315                -expression.output_field.decimal_places
316            )
317
318            def converter(value, expression, connection):
319                if value is not None:
320                    return create_decimal(value).quantize(
321                        quantize_value, context=expression.output_field.context
322                    )
323
324        else:
325
326            def converter(value, expression, connection):
327                if value is not None:
328                    return create_decimal(value)
329
330        return converter
331
332    def convert_uuidfield_value(self, value, expression, connection):
333        if value is not None:
334            value = uuid.UUID(value)
335        return value
336
337    def convert_booleanfield_value(self, value, expression, connection):
338        return bool(value) if value in (1, 0) else value
339
340    def bulk_insert_sql(self, fields, placeholder_rows):
341        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
342        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
343        return f"VALUES {values_sql}"
344
345    def combine_expression(self, connector, sub_expressions):
346        # SQLite doesn't have a ^ operator, so use the user-defined POWER
347        # function that's registered in connect().
348        if connector == "^":
349            return "POWER(%s)" % ",".join(sub_expressions)
350        elif connector == "#":
351            return "BITXOR(%s)" % ",".join(sub_expressions)
352        return super().combine_expression(connector, sub_expressions)
353
354    def combine_duration_expression(self, connector, sub_expressions):
355        if connector not in ["+", "-", "*", "/"]:
356            raise DatabaseError("Invalid connector for timedelta: %s." % connector)
357        fn_params = ["'%s'" % connector] + sub_expressions
358        if len(fn_params) > 3:
359            raise ValueError("Too many params for timedelta operations.")
360        return "plain_format_dtdelta(%s)" % ", ".join(fn_params)
361
362    def integer_field_range(self, internal_type):
363        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
364        # integers up to 64 bits.
365        if internal_type in [
366            "PositiveBigIntegerField",
367            "PositiveIntegerField",
368            "PositiveSmallIntegerField",
369        ]:
370            return (0, 9223372036854775807)
371        return (-9223372036854775808, 9223372036854775807)
372
373    def subtract_temporals(self, internal_type, lhs, rhs):
374        lhs_sql, lhs_params = lhs
375        rhs_sql, rhs_params = rhs
376        params = (*lhs_params, *rhs_params)
377        if internal_type == "TimeField":
378            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
379        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
380
381    def insert_statement(self, on_conflict=None):
382        if on_conflict == OnConflict.IGNORE:
383            return "INSERT OR IGNORE INTO"
384        return super().insert_statement(on_conflict=on_conflict)
385
386    def return_insert_columns(self, fields):
387        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
388        if not fields:
389            return "", ()
390        columns = [
391            "{}.{}".format(
392                self.quote_name(field.model._meta.db_table),
393                self.quote_name(field.column),
394            )
395            for field in fields
396        ]
397        return "RETURNING %s" % ", ".join(columns), ()
398
399    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
400        if (
401            on_conflict == OnConflict.UPDATE
402            and self.connection.features.supports_update_conflicts_with_target
403        ):
404            return "ON CONFLICT({}) DO UPDATE SET {}".format(
405                ", ".join(map(self.quote_name, unique_fields)),
406                ", ".join(
407                    [
408                        f"{field} = EXCLUDED.{field}"
409                        for field in map(self.quote_name, update_fields)
410                    ]
411                ),
412            )
413        return super().on_conflict_suffix_sql(
414            fields,
415            on_conflict,
416            update_fields,
417            unique_fields,
418        )