Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import json
  4from importlib import import_module
  5
  6import sqlparse
  7
  8from plain.models.backends import utils
  9from plain.models.db import NotSupportedError
 10from plain.utils import timezone
 11from plain.utils.encoding import force_str
 12
 13
 14class BaseDatabaseOperations:
 15    """
 16    Encapsulate backend-specific differences, such as the way a backend
 17    performs ordering or calculates the ID of a recently-inserted row.
 18    """
 19
 20    compiler_module = "plain.models.sql.compiler"
 21
 22    # Integer field safe ranges by `internal_type` as documented
 23    # in docs/ref/models/fields.txt.
 24    integer_field_ranges = {
 25        "SmallIntegerField": (-32768, 32767),
 26        "IntegerField": (-2147483648, 2147483647),
 27        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
 28        "PositiveBigIntegerField": (0, 9223372036854775807),
 29        "PositiveSmallIntegerField": (0, 32767),
 30        "PositiveIntegerField": (0, 2147483647),
 31        "SmallAutoField": (-32768, 32767),
 32        "AutoField": (-2147483648, 2147483647),
 33        "BigAutoField": (-9223372036854775808, 9223372036854775807),
 34    }
 35    set_operators = {
 36        "union": "UNION",
 37        "intersection": "INTERSECT",
 38        "difference": "EXCEPT",
 39    }
 40    # Mapping of Field.get_internal_type() (typically the model field's class
 41    # name) to the data type to use for the Cast() function, if different from
 42    # DatabaseWrapper.data_types.
 43    cast_data_types = {}
 44    # CharField data type if the max_length argument isn't provided.
 45    cast_char_field_without_max_length = None
 46
 47    # Start and end points for window expressions.
 48    PRECEDING = "PRECEDING"
 49    FOLLOWING = "FOLLOWING"
 50    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
 51    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
 52    CURRENT_ROW = "CURRENT ROW"
 53
 54    # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
 55    explain_prefix = None
 56
 57    def __init__(self, connection):
 58        self.connection = connection
 59        self._cache = None
 60
 61    def autoinc_sql(self, table, column):
 62        """
 63        Return any SQL needed to support auto-incrementing primary keys, or
 64        None if no SQL is necessary.
 65
 66        This SQL is executed when a table is created.
 67        """
 68        return None
 69
 70    def bulk_batch_size(self, fields, objs):
 71        """
 72        Return the maximum allowed batch size for the backend. The fields
 73        are the fields going to be inserted in the batch, the objs contains
 74        all the objects to be inserted.
 75        """
 76        return len(objs)
 77
 78    def format_for_duration_arithmetic(self, sql):
 79        raise NotImplementedError(
 80            "subclasses of BaseDatabaseOperations may require a "
 81            "format_for_duration_arithmetic() method."
 82        )
 83
 84    def unification_cast_sql(self, output_field):
 85        """
 86        Given a field instance, return the SQL that casts the result of a union
 87        to that type. The resulting string should contain a '%s' placeholder
 88        for the expression being cast.
 89        """
 90        return "%s"
 91
 92    def date_extract_sql(self, lookup_type, sql, params):
 93        """
 94        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
 95        extracts a value from the given date field field_name.
 96        """
 97        raise NotImplementedError(
 98            "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
 99            "method"
100        )
101
102    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
103        """
104        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
105        truncates the given date or datetime field field_name to a date object
106        with only the given specificity.
107
108        If `tzname` is provided, the given value is truncated in a specific
109        timezone.
110        """
111        raise NotImplementedError(
112            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
113            "method."
114        )
115
116    def datetime_cast_date_sql(self, sql, params, tzname):
117        """
118        Return the SQL to cast a datetime value to date value.
119        """
120        raise NotImplementedError(
121            "subclasses of BaseDatabaseOperations may require a "
122            "datetime_cast_date_sql() method."
123        )
124
125    def datetime_cast_time_sql(self, sql, params, tzname):
126        """
127        Return the SQL to cast a datetime value to time value.
128        """
129        raise NotImplementedError(
130            "subclasses of BaseDatabaseOperations may require a "
131            "datetime_cast_time_sql() method"
132        )
133
134    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
135        """
136        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
137        'second', return the SQL that extracts a value from the given
138        datetime field field_name.
139        """
140        raise NotImplementedError(
141            "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
142            "method"
143        )
144
145    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
146        """
147        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
148        'second', return the SQL that truncates the given datetime field
149        field_name to a datetime object with only the given specificity.
150        """
151        raise NotImplementedError(
152            "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
153            "method"
154        )
155
156    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
157        """
158        Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
159        that truncates the given time or datetime field field_name to a time
160        object with only the given specificity.
161
162        If `tzname` is provided, the given value is truncated in a specific
163        timezone.
164        """
165        raise NotImplementedError(
166            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
167        )
168
169    def time_extract_sql(self, lookup_type, sql, params):
170        """
171        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
172        that extracts a value from the given time field field_name.
173        """
174        return self.date_extract_sql(lookup_type, sql, params)
175
176    def deferrable_sql(self):
177        """
178        Return the SQL to make a constraint "initially deferred" during a
179        CREATE TABLE statement.
180        """
181        return ""
182
183    def distinct_sql(self, fields, params):
184        """
185        Return an SQL DISTINCT clause which removes duplicate rows from the
186        result set. If any fields are given, only check the given fields for
187        duplicates.
188        """
189        if fields:
190            raise NotSupportedError(
191                "DISTINCT ON fields is not supported by this database backend"
192            )
193        else:
194            return ["DISTINCT"], []
195
196    def fetch_returned_insert_columns(self, cursor, returning_params):
197        """
198        Given a cursor object that has just performed an INSERT...RETURNING
199        statement into a table, return the newly created data.
200        """
201        return cursor.fetchone()
202
203    def field_cast_sql(self, db_type, internal_type):
204        """
205        Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
206        (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
207        it in a WHERE statement. The resulting string should contain a '%s'
208        placeholder for the column being searched against.
209        """
210        return "%s"
211
212    def force_no_ordering(self):
213        """
214        Return a list used in the "ORDER BY" clause to force no ordering at
215        all. Return an empty list to include nothing in the ordering.
216        """
217        return []
218
219    def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
220        """
221        Return the FOR UPDATE SQL clause to lock rows for an update operation.
222        """
223        return "FOR{} UPDATE{}{}{}".format(
224            " NO KEY" if no_key else "",
225            " OF {}".format(", ".join(of)) if of else "",
226            " NOWAIT" if nowait else "",
227            " SKIP LOCKED" if skip_locked else "",
228        )
229
230    def _get_limit_offset_params(self, low_mark, high_mark):
231        offset = low_mark or 0
232        if high_mark is not None:
233            return (high_mark - offset), offset
234        elif offset:
235            return self.connection.ops.no_limit_value(), offset
236        return None, offset
237
238    def limit_offset_sql(self, low_mark, high_mark):
239        """Return LIMIT/OFFSET SQL clause."""
240        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
241        return " ".join(
242            sql
243            for sql in (
244                ("LIMIT %d" % limit) if limit else None,  # noqa: UP031
245                ("OFFSET %d" % offset) if offset else None,  # noqa: UP031
246            )
247            if sql
248        )
249
250    def last_executed_query(self, cursor, sql, params):
251        """
252        Return a string of the query last executed by the given cursor, with
253        placeholders replaced with actual values.
254
255        `sql` is the raw query containing placeholders and `params` is the
256        sequence of parameters. These are used by default, but this method
257        exists for database backends to provide a better implementation
258        according to their own quoting schemes.
259        """
260
261        # Convert params to contain string values.
262        def to_string(s):
263            return force_str(s, strings_only=True, errors="replace")
264
265        if isinstance(params, list | tuple):
266            u_params = tuple(to_string(val) for val in params)
267        elif params is None:
268            u_params = ()
269        else:
270            u_params = {to_string(k): to_string(v) for k, v in params.items()}
271
272        return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
273
274    def last_insert_id(self, cursor, table_name, pk_name):
275        """
276        Given a cursor object that has just performed an INSERT statement into
277        a table that has an auto-incrementing ID, return the newly created ID.
278
279        `pk_name` is the name of the primary-key column.
280        """
281        return cursor.lastrowid
282
283    def lookup_cast(self, lookup_type, internal_type=None):
284        """
285        Return the string to use in a query when performing lookups
286        ("contains", "like", etc.). It should contain a '%s' placeholder for
287        the column being searched against.
288        """
289        return "%s"
290
291    def max_in_list_size(self):
292        """
293        Return the maximum number of items that can be passed in a single 'IN'
294        list condition, or None if the backend does not impose a limit.
295        """
296        return None
297
298    def max_name_length(self):
299        """
300        Return the maximum length of table and column names, or None if there
301        is no limit.
302        """
303        return None
304
305    def no_limit_value(self):
306        """
307        Return the value to use for the LIMIT when we are wanting "LIMIT
308        infinity". Return None if the limit clause can be omitted in this case.
309        """
310        raise NotImplementedError(
311            "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
312        )
313
314    def pk_default_value(self):
315        """
316        Return the value to use during an INSERT statement to specify that
317        the field should use its default value.
318        """
319        return "DEFAULT"
320
321    def prepare_sql_script(self, sql):
322        """
323        Take an SQL script that may contain multiple lines and return a list
324        of statements to feed to successive cursor.execute() calls.
325
326        Since few databases are able to process raw SQL scripts in a single
327        cursor.execute() call and PEP 249 doesn't talk about this use case,
328        the default implementation is conservative.
329        """
330        return [
331            sqlparse.format(statement, strip_comments=True)
332            for statement in sqlparse.split(sql)
333            if statement
334        ]
335
336    def return_insert_columns(self, fields):
337        """
338        For backends that support returning columns as part of an insert query,
339        return the SQL and params to append to the INSERT query. The returned
340        fragment should contain a format string to hold the appropriate column.
341        """
342        pass
343
344    def compiler(self, compiler_name):
345        """
346        Return the SQLCompiler class corresponding to the given name,
347        in the namespace corresponding to the `compiler_module` attribute
348        on this backend.
349        """
350        if self._cache is None:
351            self._cache = import_module(self.compiler_module)
352        return getattr(self._cache, compiler_name)
353
354    def quote_name(self, name):
355        """
356        Return a quoted version of the given table, index, or column name. Do
357        not quote the given name if it's already been quoted.
358        """
359        raise NotImplementedError(
360            "subclasses of BaseDatabaseOperations may require a quote_name() method"
361        )
362
363    def regex_lookup(self, lookup_type):
364        """
365        Return the string to use in a query when performing regular expression
366        lookups (using "regex" or "iregex"). It should contain a '%s'
367        placeholder for the column being searched against.
368
369        If the feature is not supported (or part of it is not supported), raise
370        NotImplementedError.
371        """
372        raise NotImplementedError(
373            "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
374        )
375
376    def savepoint_create_sql(self, sid):
377        """
378        Return the SQL for starting a new savepoint. Only required if the
379        "uses_savepoints" feature is True. The "sid" parameter is a string
380        for the savepoint id.
381        """
382        return f"SAVEPOINT {self.quote_name(sid)}"
383
384    def savepoint_commit_sql(self, sid):
385        """
386        Return the SQL for committing the given savepoint.
387        """
388        return f"RELEASE SAVEPOINT {self.quote_name(sid)}"
389
390    def savepoint_rollback_sql(self, sid):
391        """
392        Return the SQL for rolling back the given savepoint.
393        """
394        return f"ROLLBACK TO SAVEPOINT {self.quote_name(sid)}"
395
396    def set_time_zone_sql(self):
397        """
398        Return the SQL that will set the connection's time zone.
399
400        Return '' if the backend doesn't support time zones.
401        """
402        return ""
403
404    def prep_for_like_query(self, x):
405        """Prepare a value for use in a LIKE query."""
406        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
407
408    # Same as prep_for_like_query(), but called for "iexact" matches, which
409    # need not necessarily be implemented using "LIKE" in the backend.
410    prep_for_iexact_query = prep_for_like_query
411
412    def validate_autopk_value(self, value):
413        """
414        Certain backends do not accept some values for "serial" fields
415        (for example zero in MySQL). Raise a ValueError if the value is
416        invalid, otherwise return the validated value.
417        """
418        return value
419
420    def adapt_unknown_value(self, value):
421        """
422        Transform a value to something compatible with the backend driver.
423
424        This method only depends on the type of the value. It's designed for
425        cases where the target type isn't known, such as .raw() SQL queries.
426        As a consequence it may not work perfectly in all circumstances.
427        """
428        if isinstance(value, datetime.datetime):  # must be before date
429            return self.adapt_datetimefield_value(value)
430        elif isinstance(value, datetime.date):
431            return self.adapt_datefield_value(value)
432        elif isinstance(value, datetime.time):
433            return self.adapt_timefield_value(value)
434        elif isinstance(value, decimal.Decimal):
435            return self.adapt_decimalfield_value(value)
436        else:
437            return value
438
439    def adapt_integerfield_value(self, value, internal_type):
440        return value
441
442    def adapt_datefield_value(self, value):
443        """
444        Transform a date value to an object compatible with what is expected
445        by the backend driver for date columns.
446        """
447        if value is None:
448            return None
449        return str(value)
450
451    def adapt_datetimefield_value(self, value):
452        """
453        Transform a datetime value to an object compatible with what is expected
454        by the backend driver for datetime columns.
455        """
456        if value is None:
457            return None
458        # Expression values are adapted by the database.
459        if hasattr(value, "resolve_expression"):
460            return value
461
462        return str(value)
463
464    def adapt_timefield_value(self, value):
465        """
466        Transform a time value to an object compatible with what is expected
467        by the backend driver for time columns.
468        """
469        if value is None:
470            return None
471        # Expression values are adapted by the database.
472        if hasattr(value, "resolve_expression"):
473            return value
474
475        if timezone.is_aware(value):
476            raise ValueError("Plain does not support timezone-aware times.")
477        return str(value)
478
479    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
480        """
481        Transform a decimal.Decimal value to an object compatible with what is
482        expected by the backend driver for decimal (numeric) columns.
483        """
484        return utils.format_number(value, max_digits, decimal_places)
485
486    def adapt_ipaddressfield_value(self, value):
487        """
488        Transform a string representation of an IP address into the expected
489        type for the backend driver.
490        """
491        return value or None
492
493    def adapt_json_value(self, value, encoder):
494        return json.dumps(value, cls=encoder)
495
496    def year_lookup_bounds_for_date_field(self, value, iso_year=False):
497        """
498        Return a two-elements list with the lower and upper bound to be used
499        with a BETWEEN operator to query a DateField value using a year
500        lookup.
501
502        `value` is an int, containing the looked-up year.
503        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
504        """
505        if iso_year:
506            first = datetime.date.fromisocalendar(value, 1, 1)
507            second = datetime.date.fromisocalendar(
508                value + 1, 1, 1
509            ) - datetime.timedelta(days=1)
510        else:
511            first = datetime.date(value, 1, 1)
512            second = datetime.date(value, 12, 31)
513        first = self.adapt_datefield_value(first)
514        second = self.adapt_datefield_value(second)
515        return [first, second]
516
517    def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
518        """
519        Return a two-elements list with the lower and upper bound to be used
520        with a BETWEEN operator to query a DateTimeField value using a year
521        lookup.
522
523        `value` is an int, containing the looked-up year.
524        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
525        """
526        if iso_year:
527            first = datetime.datetime.fromisocalendar(value, 1, 1)
528            second = datetime.datetime.fromisocalendar(
529                value + 1, 1, 1
530            ) - datetime.timedelta(microseconds=1)
531        else:
532            first = datetime.datetime(value, 1, 1)
533            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
534
535        # Make sure that datetimes are aware in the current timezone
536        tz = timezone.get_current_timezone()
537        first = timezone.make_aware(first, tz)
538        second = timezone.make_aware(second, tz)
539
540        first = self.adapt_datetimefield_value(first)
541        second = self.adapt_datetimefield_value(second)
542        return [first, second]
543
544    def get_db_converters(self, expression):
545        """
546        Return a list of functions needed to convert field data.
547
548        Some field types on some backends do not provide data in the correct
549        format, this is the hook for converter functions.
550        """
551        return []
552
553    def convert_durationfield_value(self, value, expression, connection):
554        if value is not None:
555            return datetime.timedelta(0, 0, value)
556
557    def check_expression_support(self, expression):
558        """
559        Check that the backend supports the provided expression.
560
561        This is used on specific backends to rule out known expressions
562        that have problematic or nonexistent implementations. If the
563        expression has a known problem, the backend should raise
564        NotSupportedError.
565        """
566        pass
567
568    def conditional_expression_supported_in_where_clause(self, expression):
569        """
570        Return True, if the conditional expression is supported in the WHERE
571        clause.
572        """
573        return True
574
575    def combine_expression(self, connector, sub_expressions):
576        """
577        Combine a list of subexpressions into a single expression, using
578        the provided connecting operator. This is required because operators
579        can vary between backends (e.g., Oracle with %% and &) and between
580        subexpression types (e.g., date expressions).
581        """
582        conn = f" {connector} "
583        return conn.join(sub_expressions)
584
585    def combine_duration_expression(self, connector, sub_expressions):
586        return self.combine_expression(connector, sub_expressions)
587
588    def binary_placeholder_sql(self, value):
589        """
590        Some backends require special syntax to insert binary content (MySQL
591        for example uses '_binary %s').
592        """
593        return "%s"
594
595    def modify_insert_params(self, placeholder, params):
596        """
597        Allow modification of insert parameters. Needed for Oracle Spatial
598        backend due to #10888.
599        """
600        return params
601
602    def integer_field_range(self, internal_type):
603        """
604        Given an integer field internal type (e.g. 'PositiveIntegerField'),
605        return a tuple of the (min_value, max_value) form representing the
606        range of the column type bound to the field.
607        """
608        return self.integer_field_ranges[internal_type]
609
610    def subtract_temporals(self, internal_type, lhs, rhs):
611        if self.connection.features.supports_temporal_subtraction:
612            lhs_sql, lhs_params = lhs
613            rhs_sql, rhs_params = rhs
614            return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
615        raise NotSupportedError(
616            f"This backend does not support {internal_type} subtraction."
617        )
618
619    def window_frame_start(self, start):
620        if isinstance(start, int):
621            if start < 0:
622                return "%d %s" % (abs(start), self.PRECEDING)  # noqa: UP031
623            elif start == 0:
624                return self.CURRENT_ROW
625        elif start is None:
626            return self.UNBOUNDED_PRECEDING
627        raise ValueError(
628            f"start argument must be a negative integer, zero, or None, but got '{start}'."
629        )
630
631    def window_frame_end(self, end):
632        if isinstance(end, int):
633            if end == 0:
634                return self.CURRENT_ROW
635            elif end > 0:
636                return "%d %s" % (end, self.FOLLOWING)  # noqa: UP031
637        elif end is None:
638            return self.UNBOUNDED_FOLLOWING
639        raise ValueError(
640            f"end argument must be a positive integer, zero, or None, but got '{end}'."
641        )
642
643    def window_frame_rows_start_end(self, start=None, end=None):
644        """
645        Return SQL for start and end points in an OVER clause window frame.
646        """
647        if not self.connection.features.supports_over_clause:
648            raise NotSupportedError("This backend does not support window expressions.")
649        return self.window_frame_start(start), self.window_frame_end(end)
650
651    def window_frame_range_start_end(self, start=None, end=None):
652        start_, end_ = self.window_frame_rows_start_end(start, end)
653        features = self.connection.features
654        if features.only_supports_unbounded_with_preceding_and_following and (
655            (start and start < 0) or (end and end > 0)
656        ):
657            raise NotSupportedError(
658                f"{self.connection.display_name} only supports UNBOUNDED together with PRECEDING and "
659                "FOLLOWING."
660            )
661        return start_, end_
662
663    def explain_query_prefix(self, format=None, **options):
664        if not self.connection.features.supports_explaining_query_execution:
665            raise NotSupportedError(
666                "This backend does not support explaining query execution."
667            )
668        if format:
669            supported_formats = self.connection.features.supported_explain_formats
670            normalized_format = format.upper()
671            if normalized_format not in supported_formats:
672                msg = f"{normalized_format} is not a recognized format."
673                if supported_formats:
674                    msg += " Allowed formats: {}".format(
675                        ", ".join(sorted(supported_formats))
676                    )
677                else:
678                    msg += (
679                        f" {self.connection.display_name} does not support any formats."
680                    )
681                raise ValueError(msg)
682        if options:
683            raise ValueError(
684                "Unknown options: {}".format(", ".join(sorted(options.keys())))
685            )
686        return self.explain_prefix
687
688    def insert_statement(self, on_conflict=None):
689        return "INSERT INTO"
690
691    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
692        return ""