Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import uuid
  4from functools import cached_property, lru_cache
  5
  6from plain import models
  7from plain.exceptions import FieldError
  8from plain.models.backends.base.operations import BaseDatabaseOperations
  9from plain.models.constants import OnConflict
 10from plain.models.db import DatabaseError, NotSupportedError
 11from plain.models.expressions import Col
 12from plain.utils import timezone
 13from plain.utils.dateparse import parse_date, parse_datetime, parse_time
 14
 15
 16class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = "text"
    # Date and datetime values are cast to TEXT; the error raised by
    # check_expression_support() notes that date/time is saved as text.
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # Set of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])
 26
 27    def bulk_batch_size(self, fields, objs):
 28        """
 29        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
 30        999 variables per query.
 31
 32        If there's only a single field to insert, the limit is 500
 33        (SQLITE_MAX_COMPOUND_SELECT).
 34        """
 35        if len(fields) == 1:
 36            return 500
 37        elif len(fields) > 1:
 38            return self.connection.features.max_query_params // len(fields)
 39        else:
 40            return len(objs)
 41
 42    def check_expression_support(self, expression):
 43        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
 44        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
 45        if isinstance(expression, bad_aggregates):
 46            for expr in expression.get_source_expressions():
 47                try:
 48                    output_field = expr.output_field
 49                except (AttributeError, FieldError):
 50                    # Not every subexpression has an output_field which is fine
 51                    # to ignore.
 52                    pass
 53                else:
 54                    if isinstance(output_field, bad_fields):
 55                        raise NotSupportedError(
 56                            "You cannot use Sum, Avg, StdDev, and Variance "
 57                            "aggregations on date/time fields in sqlite3 "
 58                            "since date/time is saved as text."
 59                        )
 60        if (
 61            isinstance(expression, models.Aggregate)
 62            and expression.distinct
 63            and len(expression.source_expressions) > 1
 64        ):
 65            raise NotSupportedError(
 66                "SQLite doesn't support DISTINCT on aggregate functions "
 67                "accepting multiple arguments."
 68            )
 69
 70    def date_extract_sql(self, lookup_type, sql, params):
 71        """
 72        Support EXTRACT with a user-defined function plain_date_extract()
 73        that's registered in connect(). Use single quotes because this is a
 74        string and could otherwise cause a collision with a field name.
 75        """
 76        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
 77
 78    def fetch_returned_insert_rows(self, cursor):
 79        """
 80        Given a cursor object that has just performed an INSERT...RETURNING
 81        statement into a table, return the list of returned data.
 82        """
 83        return cursor.fetchall()
 84
    def format_for_duration_arithmetic(self, sql):
        """
        Do nothing since formatting is handled in the custom function.

        Return ``sql`` unchanged.
        """
        return sql
 88
 89    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 90        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
 91            lookup_type.lower(),
 92            *params,
 93            *self._convert_tznames_to_sql(tzname),
 94        )
 95
 96    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
 97        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
 98            lookup_type.lower(),
 99            *params,
100            *self._convert_tznames_to_sql(tzname),
101        )
102
103    def _convert_tznames_to_sql(self, tzname):
104        if tzname:
105            return tzname, self.connection.timezone_name
106        return None, None
107
108    def datetime_cast_date_sql(self, sql, params, tzname):
109        return f"plain_datetime_cast_date({sql}, %s, %s)", (
110            *params,
111            *self._convert_tznames_to_sql(tzname),
112        )
113
114    def datetime_cast_time_sql(self, sql, params, tzname):
115        return f"plain_datetime_cast_time({sql}, %s, %s)", (
116            *params,
117            *self._convert_tznames_to_sql(tzname),
118        )
119
120    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
121        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
122            lookup_type.lower(),
123            *params,
124            *self._convert_tznames_to_sql(tzname),
125        )
126
127    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
128        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
129            lookup_type.lower(),
130            *params,
131            *self._convert_tznames_to_sql(tzname),
132        )
133
134    def time_extract_sql(self, lookup_type, sql, params):
135        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
136
    def pk_default_value(self):
        """Return the SQL keyword "NULL" as the primary key default value."""
        return "NULL"
139
140    def _quote_params_for_last_executed_query(self, params):
141        """
142        Only for last_executed_query! Don't use this to execute SQL queries!
143        """
144        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
145        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
146        # number of return values, default = 2000). Since Python's sqlite3
147        # module doesn't expose the get_limit() C API, assume the default
148        # limits are in effect and split the work in batches if needed.
149        BATCH_SIZE = 999
150        if len(params) > BATCH_SIZE:
151            results = ()
152            for index in range(0, len(params), BATCH_SIZE):
153                chunk = params[index : index + BATCH_SIZE]
154                results += self._quote_params_for_last_executed_query(chunk)
155            return results
156
157        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
158        # Bypass Plain's wrappers and use the underlying sqlite3 connection
159        # to avoid logging this query - it would trigger infinite recursion.
160        cursor = self.connection.connection.cursor()
161        # Native sqlite3 cursors cannot be used as context managers.
162        try:
163            return cursor.execute(sql, params).fetchone()
164        finally:
165            cursor.close()
166
167    def last_executed_query(self, cursor, sql, params):
168        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
169        # bind_parameters(state, self->statement, parameters);
170        # Unfortunately there is no way to reach self->statement from Python,
171        # so we quote and substitute parameters manually.
172        if params:
173            if isinstance(params, list | tuple):
174                params = self._quote_params_for_last_executed_query(params)
175            else:
176                values = tuple(params.values())
177                values = self._quote_params_for_last_executed_query(values)
178                params = dict(zip(params, values))
179            return sql % params
180        # For consistency with SQLiteCursorWrapper.execute(), just return sql
181        # when there are no parameters. See #13648 and #17158.
182        else:
183            return sql
184
185    def quote_name(self, name):
186        if name.startswith('"') and name.endswith('"'):
187            return name  # Quoting once is enough.
188        return f'"{name}"'
189
    def no_limit_value(self):
        """Return -1, this backend's value for "no limit"."""
        return -1
192
193    def __references_graph(self, table_name):
194        query = """
195        WITH tables AS (
196            SELECT %s name
197            UNION
198            SELECT sqlite_master.name
199            FROM sqlite_master
200            JOIN tables ON (sql REGEXP %s || tables.name || %s)
201        ) SELECT name FROM tables;
202        """
203        params = (
204            table_name,
205            r'(?i)\s+references\s+("|\')?',
206            r'("|\')?\s*\(',
207        )
208        with self.connection.cursor() as cursor:
209            results = cursor.execute(query, params)
210            return [row[0] for row in results.fetchall()]
211
212    @cached_property
213    def _references_graph(self):
214        # 512 is large enough to fit the ~330 tables (as of this writing) in
215        # Plain's test suite.
216        return lru_cache(maxsize=512)(self.__references_graph)
217
218    def adapt_datetimefield_value(self, value):
219        if value is None:
220            return None
221
222        # Expression values are adapted by the database.
223        if hasattr(value, "resolve_expression"):
224            return value
225
226        # SQLite doesn't support tz-aware datetimes
227        if timezone.is_aware(value):
228            value = timezone.make_naive(value, self.connection.timezone)
229
230        return str(value)
231
232    def adapt_timefield_value(self, value):
233        if value is None:
234            return None
235
236        # Expression values are adapted by the database.
237        if hasattr(value, "resolve_expression"):
238            return value
239
240        # SQLite doesn't support tz-aware datetimes
241        if timezone.is_aware(value):
242            raise ValueError("SQLite backend does not support timezone-aware times.")
243
244        return str(value)
245
246    def get_db_converters(self, expression):
247        converters = super().get_db_converters(expression)
248        internal_type = expression.output_field.get_internal_type()
249        if internal_type == "DateTimeField":
250            converters.append(self.convert_datetimefield_value)
251        elif internal_type == "DateField":
252            converters.append(self.convert_datefield_value)
253        elif internal_type == "TimeField":
254            converters.append(self.convert_timefield_value)
255        elif internal_type == "DecimalField":
256            converters.append(self.get_decimalfield_converter(expression))
257        elif internal_type == "UUIDField":
258            converters.append(self.convert_uuidfield_value)
259        elif internal_type == "BooleanField":
260            converters.append(self.convert_booleanfield_value)
261        return converters
262
263    def convert_datetimefield_value(self, value, expression, connection):
264        if value is not None:
265            if not isinstance(value, datetime.datetime):
266                value = parse_datetime(value)
267            if not timezone.is_aware(value):
268                value = timezone.make_aware(value, self.connection.timezone)
269        return value
270
271    def convert_datefield_value(self, value, expression, connection):
272        if value is not None:
273            if not isinstance(value, datetime.date):
274                value = parse_date(value)
275        return value
276
277    def convert_timefield_value(self, value, expression, connection):
278        if value is not None:
279            if not isinstance(value, datetime.time):
280                value = parse_time(value)
281        return value
282
283    def get_decimalfield_converter(self, expression):
284        # SQLite stores only 15 significant digits. Digits coming from
285        # float inaccuracy must be removed.
286        create_decimal = decimal.Context(prec=15).create_decimal_from_float
287        if isinstance(expression, Col):
288            quantize_value = decimal.Decimal(1).scaleb(
289                -expression.output_field.decimal_places
290            )
291
292            def converter(value, expression, connection):
293                if value is not None:
294                    return create_decimal(value).quantize(
295                        quantize_value, context=expression.output_field.context
296                    )
297
298        else:
299
300            def converter(value, expression, connection):
301                if value is not None:
302                    return create_decimal(value)
303
304        return converter
305
306    def convert_uuidfield_value(self, value, expression, connection):
307        if value is not None:
308            value = uuid.UUID(value)
309        return value
310
311    def convert_booleanfield_value(self, value, expression, connection):
312        return bool(value) if value in (1, 0) else value
313
314    def bulk_insert_sql(self, fields, placeholder_rows):
315        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
316        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
317        return f"VALUES {values_sql}"
318
319    def combine_expression(self, connector, sub_expressions):
320        # SQLite doesn't have a ^ operator, so use the user-defined POWER
321        # function that's registered in connect().
322        if connector == "^":
323            return "POWER({})".format(",".join(sub_expressions))
324        elif connector == "#":
325            return "BITXOR({})".format(",".join(sub_expressions))
326        return super().combine_expression(connector, sub_expressions)
327
328    def combine_duration_expression(self, connector, sub_expressions):
329        if connector not in ["+", "-", "*", "/"]:
330            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
331        fn_params = [f"'{connector}'"] + sub_expressions
332        if len(fn_params) > 3:
333            raise ValueError("Too many params for timedelta operations.")
334        return "plain_format_dtdelta({})".format(", ".join(fn_params))
335
336    def integer_field_range(self, internal_type):
337        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
338        # integers up to 64 bits.
339        if internal_type in [
340            "PositiveBigIntegerField",
341            "PositiveIntegerField",
342            "PositiveSmallIntegerField",
343        ]:
344            return (0, 9223372036854775807)
345        return (-9223372036854775808, 9223372036854775807)
346
347    def subtract_temporals(self, internal_type, lhs, rhs):
348        lhs_sql, lhs_params = lhs
349        rhs_sql, rhs_params = rhs
350        params = (*lhs_params, *rhs_params)
351        if internal_type == "TimeField":
352            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
353        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
354
355    def insert_statement(self, on_conflict=None):
356        if on_conflict == OnConflict.IGNORE:
357            return "INSERT OR IGNORE INTO"
358        return super().insert_statement(on_conflict=on_conflict)
359
360    def return_insert_columns(self, fields):
361        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
362        if not fields:
363            return "", ()
364        columns = [
365            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
366            for field in fields
367        ]
368        return "RETURNING {}".format(", ".join(columns)), ()
369
370    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
371        if (
372            on_conflict == OnConflict.UPDATE
373            and self.connection.features.supports_update_conflicts_with_target
374        ):
375            return "ON CONFLICT({}) DO UPDATE SET {}".format(
376                ", ".join(map(self.quote_name, unique_fields)),
377                ", ".join(
378                    [
379                        f"{field} = EXCLUDED.{field}"
380                        for field in map(self.quote_name, update_fields)
381                    ]
382                ),
383            )
384        return super().on_conflict_suffix_sql(
385            fields,
386            on_conflict,
387            update_fields,
388            unique_fields,
389        )