Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import json
  4from importlib import import_module
  5
  6import sqlparse
  7
  8from plain.models.backends import utils
  9from plain.models.db import NotSupportedError
 10from plain.runtime import settings
 11from plain.utils import timezone
 12from plain.utils.encoding import force_str
 13
 14
 15class BaseDatabaseOperations:
 16    """
 17    Encapsulate backend-specific differences, such as the way a backend
 18    performs ordering or calculates the ID of a recently-inserted row.
 19    """
 20
 21    compiler_module = "plain.models.sql.compiler"
 22
 23    # Integer field safe ranges by `internal_type` as documented
 24    # in docs/ref/models/fields.txt.
 25    integer_field_ranges = {
 26        "SmallIntegerField": (-32768, 32767),
 27        "IntegerField": (-2147483648, 2147483647),
 28        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
 29        "PositiveBigIntegerField": (0, 9223372036854775807),
 30        "PositiveSmallIntegerField": (0, 32767),
 31        "PositiveIntegerField": (0, 2147483647),
 32        "SmallAutoField": (-32768, 32767),
 33        "AutoField": (-2147483648, 2147483647),
 34        "BigAutoField": (-9223372036854775808, 9223372036854775807),
 35    }
 36    set_operators = {
 37        "union": "UNION",
 38        "intersection": "INTERSECT",
 39        "difference": "EXCEPT",
 40    }
 41    # Mapping of Field.get_internal_type() (typically the model field's class
 42    # name) to the data type to use for the Cast() function, if different from
 43    # DatabaseWrapper.data_types.
 44    cast_data_types = {}
 45    # CharField data type if the max_length argument isn't provided.
 46    cast_char_field_without_max_length = None
 47
 48    # Start and end points for window expressions.
 49    PRECEDING = "PRECEDING"
 50    FOLLOWING = "FOLLOWING"
 51    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
 52    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
 53    CURRENT_ROW = "CURRENT ROW"
 54
 55    # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
 56    explain_prefix = None
 57
 58    def __init__(self, connection):
 59        self.connection = connection
 60        self._cache = None
 61
 62    def autoinc_sql(self, table, column):
 63        """
 64        Return any SQL needed to support auto-incrementing primary keys, or
 65        None if no SQL is necessary.
 66
 67        This SQL is executed when a table is created.
 68        """
 69        return None
 70
 71    def bulk_batch_size(self, fields, objs):
 72        """
 73        Return the maximum allowed batch size for the backend. The fields
 74        are the fields going to be inserted in the batch, the objs contains
 75        all the objects to be inserted.
 76        """
 77        return len(objs)
 78
 79    def format_for_duration_arithmetic(self, sql):
 80        raise NotImplementedError(
 81            "subclasses of BaseDatabaseOperations may require a "
 82            "format_for_duration_arithmetic() method."
 83        )
 84
 85    def cache_key_culling_sql(self):
 86        """
 87        Return an SQL query that retrieves the first cache key greater than the
 88        n smallest.
 89
 90        This is used by the 'db' cache backend to determine where to start
 91        culling.
 92        """
 93        cache_key = self.quote_name("cache_key")
 94        return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s"
 95
 96    def unification_cast_sql(self, output_field):
 97        """
 98        Given a field instance, return the SQL that casts the result of a union
 99        to that type. The resulting string should contain a '%s' placeholder
100        for the expression being cast.
101        """
102        return "%s"
103
104    def date_extract_sql(self, lookup_type, sql, params):
105        """
106        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
107        extracts a value from the given date field field_name.
108        """
109        raise NotImplementedError(
110            "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
111            "method"
112        )
113
114    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
115        """
116        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
117        truncates the given date or datetime field field_name to a date object
118        with only the given specificity.
119
120        If `tzname` is provided, the given value is truncated in a specific
121        timezone.
122        """
123        raise NotImplementedError(
124            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
125            "method."
126        )
127
128    def datetime_cast_date_sql(self, sql, params, tzname):
129        """
130        Return the SQL to cast a datetime value to date value.
131        """
132        raise NotImplementedError(
133            "subclasses of BaseDatabaseOperations may require a "
134            "datetime_cast_date_sql() method."
135        )
136
137    def datetime_cast_time_sql(self, sql, params, tzname):
138        """
139        Return the SQL to cast a datetime value to time value.
140        """
141        raise NotImplementedError(
142            "subclasses of BaseDatabaseOperations may require a "
143            "datetime_cast_time_sql() method"
144        )
145
146    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
147        """
148        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
149        'second', return the SQL that extracts a value from the given
150        datetime field field_name.
151        """
152        raise NotImplementedError(
153            "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
154            "method"
155        )
156
157    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
158        """
159        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
160        'second', return the SQL that truncates the given datetime field
161        field_name to a datetime object with only the given specificity.
162        """
163        raise NotImplementedError(
164            "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
165            "method"
166        )
167
168    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
169        """
170        Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
171        that truncates the given time or datetime field field_name to a time
172        object with only the given specificity.
173
174        If `tzname` is provided, the given value is truncated in a specific
175        timezone.
176        """
177        raise NotImplementedError(
178            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
179        )
180
181    def time_extract_sql(self, lookup_type, sql, params):
182        """
183        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
184        that extracts a value from the given time field field_name.
185        """
186        return self.date_extract_sql(lookup_type, sql, params)
187
188    def deferrable_sql(self):
189        """
190        Return the SQL to make a constraint "initially deferred" during a
191        CREATE TABLE statement.
192        """
193        return ""
194
195    def distinct_sql(self, fields, params):
196        """
197        Return an SQL DISTINCT clause which removes duplicate rows from the
198        result set. If any fields are given, only check the given fields for
199        duplicates.
200        """
201        if fields:
202            raise NotSupportedError(
203                "DISTINCT ON fields is not supported by this database backend"
204            )
205        else:
206            return ["DISTINCT"], []
207
208    def fetch_returned_insert_columns(self, cursor, returning_params):
209        """
210        Given a cursor object that has just performed an INSERT...RETURNING
211        statement into a table, return the newly created data.
212        """
213        return cursor.fetchone()
214
215    def field_cast_sql(self, db_type, internal_type):
216        """
217        Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
218        (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
219        it in a WHERE statement. The resulting string should contain a '%s'
220        placeholder for the column being searched against.
221        """
222        return "%s"
223
224    def force_no_ordering(self):
225        """
226        Return a list used in the "ORDER BY" clause to force no ordering at
227        all. Return an empty list to include nothing in the ordering.
228        """
229        return []
230
231    def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
232        """
233        Return the FOR UPDATE SQL clause to lock rows for an update operation.
234        """
235        return "FOR{} UPDATE{}{}{}".format(
236            " NO KEY" if no_key else "",
237            " OF %s" % ", ".join(of) if of else "",
238            " NOWAIT" if nowait else "",
239            " SKIP LOCKED" if skip_locked else "",
240        )
241
242    def _get_limit_offset_params(self, low_mark, high_mark):
243        offset = low_mark or 0
244        if high_mark is not None:
245            return (high_mark - offset), offset
246        elif offset:
247            return self.connection.ops.no_limit_value(), offset
248        return None, offset
249
250    def limit_offset_sql(self, low_mark, high_mark):
251        """Return LIMIT/OFFSET SQL clause."""
252        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
253        return " ".join(
254            sql
255            for sql in (
256                ("LIMIT %d" % limit) if limit else None,
257                ("OFFSET %d" % offset) if offset else None,
258            )
259            if sql
260        )
261
262    def last_executed_query(self, cursor, sql, params):
263        """
264        Return a string of the query last executed by the given cursor, with
265        placeholders replaced with actual values.
266
267        `sql` is the raw query containing placeholders and `params` is the
268        sequence of parameters. These are used by default, but this method
269        exists for database backends to provide a better implementation
270        according to their own quoting schemes.
271        """
272
273        # Convert params to contain string values.
274        def to_string(s):
275            return force_str(s, strings_only=True, errors="replace")
276
277        if isinstance(params, list | tuple):
278            u_params = tuple(to_string(val) for val in params)
279        elif params is None:
280            u_params = ()
281        else:
282            u_params = {to_string(k): to_string(v) for k, v in params.items()}
283
284        return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
285
286    def last_insert_id(self, cursor, table_name, pk_name):
287        """
288        Given a cursor object that has just performed an INSERT statement into
289        a table that has an auto-incrementing ID, return the newly created ID.
290
291        `pk_name` is the name of the primary-key column.
292        """
293        return cursor.lastrowid
294
295    def lookup_cast(self, lookup_type, internal_type=None):
296        """
297        Return the string to use in a query when performing lookups
298        ("contains", "like", etc.). It should contain a '%s' placeholder for
299        the column being searched against.
300        """
301        return "%s"
302
303    def max_in_list_size(self):
304        """
305        Return the maximum number of items that can be passed in a single 'IN'
306        list condition, or None if the backend does not impose a limit.
307        """
308        return None
309
310    def max_name_length(self):
311        """
312        Return the maximum length of table and column names, or None if there
313        is no limit.
314        """
315        return None
316
317    def no_limit_value(self):
318        """
319        Return the value to use for the LIMIT when we are wanting "LIMIT
320        infinity". Return None if the limit clause can be omitted in this case.
321        """
322        raise NotImplementedError(
323            "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
324        )
325
326    def pk_default_value(self):
327        """
328        Return the value to use during an INSERT statement to specify that
329        the field should use its default value.
330        """
331        return "DEFAULT"
332
333    def prepare_sql_script(self, sql):
334        """
335        Take an SQL script that may contain multiple lines and return a list
336        of statements to feed to successive cursor.execute() calls.
337
338        Since few databases are able to process raw SQL scripts in a single
339        cursor.execute() call and PEP 249 doesn't talk about this use case,
340        the default implementation is conservative.
341        """
342        return [
343            sqlparse.format(statement, strip_comments=True)
344            for statement in sqlparse.split(sql)
345            if statement
346        ]
347
348    def process_clob(self, value):
349        """
350        Return the value of a CLOB column, for backends that return a locator
351        object that requires additional processing.
352        """
353        return value
354
355    def return_insert_columns(self, fields):
356        """
357        For backends that support returning columns as part of an insert query,
358        return the SQL and params to append to the INSERT query. The returned
359        fragment should contain a format string to hold the appropriate column.
360        """
361        pass
362
363    def compiler(self, compiler_name):
364        """
365        Return the SQLCompiler class corresponding to the given name,
366        in the namespace corresponding to the `compiler_module` attribute
367        on this backend.
368        """
369        if self._cache is None:
370            self._cache = import_module(self.compiler_module)
371        return getattr(self._cache, compiler_name)
372
373    def quote_name(self, name):
374        """
375        Return a quoted version of the given table, index, or column name. Do
376        not quote the given name if it's already been quoted.
377        """
378        raise NotImplementedError(
379            "subclasses of BaseDatabaseOperations may require a quote_name() method"
380        )
381
382    def regex_lookup(self, lookup_type):
383        """
384        Return the string to use in a query when performing regular expression
385        lookups (using "regex" or "iregex"). It should contain a '%s'
386        placeholder for the column being searched against.
387
388        If the feature is not supported (or part of it is not supported), raise
389        NotImplementedError.
390        """
391        raise NotImplementedError(
392            "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
393        )
394
395    def savepoint_create_sql(self, sid):
396        """
397        Return the SQL for starting a new savepoint. Only required if the
398        "uses_savepoints" feature is True. The "sid" parameter is a string
399        for the savepoint id.
400        """
401        return "SAVEPOINT %s" % self.quote_name(sid)
402
403    def savepoint_commit_sql(self, sid):
404        """
405        Return the SQL for committing the given savepoint.
406        """
407        return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
408
409    def savepoint_rollback_sql(self, sid):
410        """
411        Return the SQL for rolling back the given savepoint.
412        """
413        return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
414
415    def set_time_zone_sql(self):
416        """
417        Return the SQL that will set the connection's time zone.
418
419        Return '' if the backend doesn't support time zones.
420        """
421        return ""
422
423    def sequence_reset_by_name_sql(self, style, sequences):
424        """
425        Return a list of the SQL statements required to reset sequences
426        passed in `sequences`.
427
428        The `style` argument is a Style object as returned by either
429        color_style() or no_style() in plain.internal.legacy.management.color.
430        """
431        return []
432
433    def sequence_reset_sql(self, style, model_list):
434        """
435        Return a list of the SQL statements required to reset sequences for
436        the given models.
437
438        The `style` argument is a Style object as returned by either
439        color_style() or no_style() in plain.internal.legacy.management.color.
440        """
441        return []  # No sequence reset required by default.
442
443    def start_transaction_sql(self):
444        """Return the SQL statement required to start a transaction."""
445        return "BEGIN;"
446
447    def end_transaction_sql(self, success=True):
448        """Return the SQL statement required to end a transaction."""
449        if not success:
450            return "ROLLBACK;"
451        return "COMMIT;"
452
453    def tablespace_sql(self, tablespace, inline=False):
454        """
455        Return the SQL that will be used in a query to define the tablespace.
456
457        Return '' if the backend doesn't support tablespaces.
458
459        If `inline` is True, append the SQL to a row; otherwise append it to
460        the entire CREATE TABLE or CREATE INDEX statement.
461        """
462        return ""
463
464    def prep_for_like_query(self, x):
465        """Prepare a value for use in a LIKE query."""
466        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
467
468    # Same as prep_for_like_query(), but called for "iexact" matches, which
469    # need not necessarily be implemented using "LIKE" in the backend.
470    prep_for_iexact_query = prep_for_like_query
471
472    def validate_autopk_value(self, value):
473        """
474        Certain backends do not accept some values for "serial" fields
475        (for example zero in MySQL). Raise a ValueError if the value is
476        invalid, otherwise return the validated value.
477        """
478        return value
479
480    def adapt_unknown_value(self, value):
481        """
482        Transform a value to something compatible with the backend driver.
483
484        This method only depends on the type of the value. It's designed for
485        cases where the target type isn't known, such as .raw() SQL queries.
486        As a consequence it may not work perfectly in all circumstances.
487        """
488        if isinstance(value, datetime.datetime):  # must be before date
489            return self.adapt_datetimefield_value(value)
490        elif isinstance(value, datetime.date):
491            return self.adapt_datefield_value(value)
492        elif isinstance(value, datetime.time):
493            return self.adapt_timefield_value(value)
494        elif isinstance(value, decimal.Decimal):
495            return self.adapt_decimalfield_value(value)
496        else:
497            return value
498
499    def adapt_integerfield_value(self, value, internal_type):
500        return value
501
502    def adapt_datefield_value(self, value):
503        """
504        Transform a date value to an object compatible with what is expected
505        by the backend driver for date columns.
506        """
507        if value is None:
508            return None
509        return str(value)
510
511    def adapt_datetimefield_value(self, value):
512        """
513        Transform a datetime value to an object compatible with what is expected
514        by the backend driver for datetime columns.
515        """
516        if value is None:
517            return None
518        # Expression values are adapted by the database.
519        if hasattr(value, "resolve_expression"):
520            return value
521
522        return str(value)
523
524    def adapt_timefield_value(self, value):
525        """
526        Transform a time value to an object compatible with what is expected
527        by the backend driver for time columns.
528        """
529        if value is None:
530            return None
531        # Expression values are adapted by the database.
532        if hasattr(value, "resolve_expression"):
533            return value
534
535        if timezone.is_aware(value):
536            raise ValueError("Plain does not support timezone-aware times.")
537        return str(value)
538
539    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
540        """
541        Transform a decimal.Decimal value to an object compatible with what is
542        expected by the backend driver for decimal (numeric) columns.
543        """
544        return utils.format_number(value, max_digits, decimal_places)
545
546    def adapt_ipaddressfield_value(self, value):
547        """
548        Transform a string representation of an IP address into the expected
549        type for the backend driver.
550        """
551        return value or None
552
553    def adapt_json_value(self, value, encoder):
554        return json.dumps(value, cls=encoder)
555
556    def year_lookup_bounds_for_date_field(self, value, iso_year=False):
557        """
558        Return a two-elements list with the lower and upper bound to be used
559        with a BETWEEN operator to query a DateField value using a year
560        lookup.
561
562        `value` is an int, containing the looked-up year.
563        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
564        """
565        if iso_year:
566            first = datetime.date.fromisocalendar(value, 1, 1)
567            second = datetime.date.fromisocalendar(
568                value + 1, 1, 1
569            ) - datetime.timedelta(days=1)
570        else:
571            first = datetime.date(value, 1, 1)
572            second = datetime.date(value, 12, 31)
573        first = self.adapt_datefield_value(first)
574        second = self.adapt_datefield_value(second)
575        return [first, second]
576
577    def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
578        """
579        Return a two-elements list with the lower and upper bound to be used
580        with a BETWEEN operator to query a DateTimeField value using a year
581        lookup.
582
583        `value` is an int, containing the looked-up year.
584        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
585        """
586        if iso_year:
587            first = datetime.datetime.fromisocalendar(value, 1, 1)
588            second = datetime.datetime.fromisocalendar(
589                value + 1, 1, 1
590            ) - datetime.timedelta(microseconds=1)
591        else:
592            first = datetime.datetime(value, 1, 1)
593            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
594        if settings.USE_TZ:
595            tz = timezone.get_current_timezone()
596            first = timezone.make_aware(first, tz)
597            second = timezone.make_aware(second, tz)
598        first = self.adapt_datetimefield_value(first)
599        second = self.adapt_datetimefield_value(second)
600        return [first, second]
601
602    def get_db_converters(self, expression):
603        """
604        Return a list of functions needed to convert field data.
605
606        Some field types on some backends do not provide data in the correct
607        format, this is the hook for converter functions.
608        """
609        return []
610
611    def convert_durationfield_value(self, value, expression, connection):
612        if value is not None:
613            return datetime.timedelta(0, 0, value)
614
615    def check_expression_support(self, expression):
616        """
617        Check that the backend supports the provided expression.
618
619        This is used on specific backends to rule out known expressions
620        that have problematic or nonexistent implementations. If the
621        expression has a known problem, the backend should raise
622        NotSupportedError.
623        """
624        pass
625
626    def conditional_expression_supported_in_where_clause(self, expression):
627        """
628        Return True, if the conditional expression is supported in the WHERE
629        clause.
630        """
631        return True
632
633    def combine_expression(self, connector, sub_expressions):
634        """
635        Combine a list of subexpressions into a single expression, using
636        the provided connecting operator. This is required because operators
637        can vary between backends (e.g., Oracle with %% and &) and between
638        subexpression types (e.g., date expressions).
639        """
640        conn = " %s " % connector
641        return conn.join(sub_expressions)
642
643    def combine_duration_expression(self, connector, sub_expressions):
644        return self.combine_expression(connector, sub_expressions)
645
646    def binary_placeholder_sql(self, value):
647        """
648        Some backends require special syntax to insert binary content (MySQL
649        for example uses '_binary %s').
650        """
651        return "%s"
652
653    def modify_insert_params(self, placeholder, params):
654        """
655        Allow modification of insert parameters. Needed for Oracle Spatial
656        backend due to #10888.
657        """
658        return params
659
660    def integer_field_range(self, internal_type):
661        """
662        Given an integer field internal type (e.g. 'PositiveIntegerField'),
663        return a tuple of the (min_value, max_value) form representing the
664        range of the column type bound to the field.
665        """
666        return self.integer_field_ranges[internal_type]
667
668    def subtract_temporals(self, internal_type, lhs, rhs):
669        if self.connection.features.supports_temporal_subtraction:
670            lhs_sql, lhs_params = lhs
671            rhs_sql, rhs_params = rhs
672            return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
673        raise NotSupportedError(
674            "This backend does not support %s subtraction." % internal_type
675        )
676
677    def window_frame_start(self, start):
678        if isinstance(start, int):
679            if start < 0:
680                return "%d %s" % (abs(start), self.PRECEDING)
681            elif start == 0:
682                return self.CURRENT_ROW
683        elif start is None:
684            return self.UNBOUNDED_PRECEDING
685        raise ValueError(
686            "start argument must be a negative integer, zero, or None, but got '%s'."
687            % start
688        )
689
690    def window_frame_end(self, end):
691        if isinstance(end, int):
692            if end == 0:
693                return self.CURRENT_ROW
694            elif end > 0:
695                return "%d %s" % (end, self.FOLLOWING)
696        elif end is None:
697            return self.UNBOUNDED_FOLLOWING
698        raise ValueError(
699            "end argument must be a positive integer, zero, or None, but got '%s'."
700            % end
701        )
702
703    def window_frame_rows_start_end(self, start=None, end=None):
704        """
705        Return SQL for start and end points in an OVER clause window frame.
706        """
707        if not self.connection.features.supports_over_clause:
708            raise NotSupportedError("This backend does not support window expressions.")
709        return self.window_frame_start(start), self.window_frame_end(end)
710
711    def window_frame_range_start_end(self, start=None, end=None):
712        start_, end_ = self.window_frame_rows_start_end(start, end)
713        features = self.connection.features
714        if features.only_supports_unbounded_with_preceding_and_following and (
715            (start and start < 0) or (end and end > 0)
716        ):
717            raise NotSupportedError(
718                "%s only supports UNBOUNDED together with PRECEDING and "
719                "FOLLOWING." % self.connection.display_name
720            )
721        return start_, end_
722
723    def explain_query_prefix(self, format=None, **options):
724        if not self.connection.features.supports_explaining_query_execution:
725            raise NotSupportedError(
726                "This backend does not support explaining query execution."
727            )
728        if format:
729            supported_formats = self.connection.features.supported_explain_formats
730            normalized_format = format.upper()
731            if normalized_format not in supported_formats:
732                msg = "%s is not a recognized format." % normalized_format
733                if supported_formats:
734                    msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
735                else:
736                    msg += (
737                        f" {self.connection.display_name} does not support any formats."
738                    )
739                raise ValueError(msg)
740        if options:
741            raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
742        return self.explain_prefix
743
744    def insert_statement(self, on_conflict=None):
745        return "INSERT INTO"
746
747    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
748        return ""