Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import json
  4from importlib import import_module
  5
  6import sqlparse
  7
  8from plain.models.backends import utils
  9from plain.models.db import NotSupportedError
 10from plain.utils import timezone
 11from plain.utils.encoding import force_str
 12
 13
 14class BaseDatabaseOperations:
 15    """
 16    Encapsulate backend-specific differences, such as the way a backend
 17    performs ordering or calculates the ID of a recently-inserted row.
 18    """
 19
 20    compiler_module = "plain.models.sql.compiler"
 21
 22    # Integer field safe ranges by `internal_type` as documented
 23    # in docs/ref/models/fields.txt.
 24    integer_field_ranges = {
 25        "SmallIntegerField": (-32768, 32767),
 26        "IntegerField": (-2147483648, 2147483647),
 27        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
 28        "PositiveBigIntegerField": (0, 9223372036854775807),
 29        "PositiveSmallIntegerField": (0, 32767),
 30        "PositiveIntegerField": (0, 2147483647),
 31        "PrimaryKeyField": (-9223372036854775808, 9223372036854775807),
 32    }
 33    set_operators = {
 34        "union": "UNION",
 35        "intersection": "INTERSECT",
 36        "difference": "EXCEPT",
 37    }
 38    # Mapping of Field.get_internal_type() (typically the model field's class
 39    # name) to the data type to use for the Cast() function, if different from
 40    # DatabaseWrapper.data_types.
 41    cast_data_types = {}
 42    # CharField data type if the max_length argument isn't provided.
 43    cast_char_field_without_max_length = None
 44
 45    # Start and end points for window expressions.
 46    PRECEDING = "PRECEDING"
 47    FOLLOWING = "FOLLOWING"
 48    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
 49    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
 50    CURRENT_ROW = "CURRENT ROW"
 51
 52    # Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
 53    explain_prefix = None
 54
 55    def __init__(self, connection):
 56        self.connection = connection
 57        self._cache = None
 58
 59    def autoinc_sql(self, table, column):
 60        """
 61        Return any SQL needed to support auto-incrementing primary keys, or
 62        None if no SQL is necessary.
 63
 64        This SQL is executed when a table is created.
 65        """
 66        return None
 67
 68    def bulk_batch_size(self, fields, objs):
 69        """
 70        Return the maximum allowed batch size for the backend. The fields
 71        are the fields going to be inserted in the batch, the objs contains
 72        all the objects to be inserted.
 73        """
 74        return len(objs)
 75
 76    def format_for_duration_arithmetic(self, sql):
 77        raise NotImplementedError(
 78            "subclasses of BaseDatabaseOperations may require a "
 79            "format_for_duration_arithmetic() method."
 80        )
 81
 82    def unification_cast_sql(self, output_field):
 83        """
 84        Given a field instance, return the SQL that casts the result of a union
 85        to that type. The resulting string should contain a '%s' placeholder
 86        for the expression being cast.
 87        """
 88        return "%s"
 89
 90    def date_extract_sql(self, lookup_type, sql, params):
 91        """
 92        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
 93        extracts a value from the given date field field_name.
 94        """
 95        raise NotImplementedError(
 96            "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
 97            "method"
 98        )
 99
100    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
101        """
102        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
103        truncates the given date or datetime field field_name to a date object
104        with only the given specificity.
105
106        If `tzname` is provided, the given value is truncated in a specific
107        timezone.
108        """
109        raise NotImplementedError(
110            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
111            "method."
112        )
113
114    def datetime_cast_date_sql(self, sql, params, tzname):
115        """
116        Return the SQL to cast a datetime value to date value.
117        """
118        raise NotImplementedError(
119            "subclasses of BaseDatabaseOperations may require a "
120            "datetime_cast_date_sql() method."
121        )
122
123    def datetime_cast_time_sql(self, sql, params, tzname):
124        """
125        Return the SQL to cast a datetime value to time value.
126        """
127        raise NotImplementedError(
128            "subclasses of BaseDatabaseOperations may require a "
129            "datetime_cast_time_sql() method"
130        )
131
132    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
133        """
134        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
135        'second', return the SQL that extracts a value from the given
136        datetime field field_name.
137        """
138        raise NotImplementedError(
139            "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
140            "method"
141        )
142
143    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
144        """
145        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
146        'second', return the SQL that truncates the given datetime field
147        field_name to a datetime object with only the given specificity.
148        """
149        raise NotImplementedError(
150            "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
151            "method"
152        )
153
154    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
155        """
156        Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
157        that truncates the given time or datetime field field_name to a time
158        object with only the given specificity.
159
160        If `tzname` is provided, the given value is truncated in a specific
161        timezone.
162        """
163        raise NotImplementedError(
164            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
165        )
166
167    def time_extract_sql(self, lookup_type, sql, params):
168        """
169        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
170        that extracts a value from the given time field field_name.
171        """
172        return self.date_extract_sql(lookup_type, sql, params)
173
174    def deferrable_sql(self):
175        """
176        Return the SQL to make a constraint "initially deferred" during a
177        CREATE TABLE statement.
178        """
179        return ""
180
181    def distinct_sql(self, fields, params):
182        """
183        Return an SQL DISTINCT clause which removes duplicate rows from the
184        result set. If any fields are given, only check the given fields for
185        duplicates.
186        """
187        if fields:
188            raise NotSupportedError(
189                "DISTINCT ON fields is not supported by this database backend"
190            )
191        else:
192            return ["DISTINCT"], []
193
194    def fetch_returned_insert_columns(self, cursor, returning_params):
195        """
196        Given a cursor object that has just performed an INSERT...RETURNING
197        statement into a table, return the newly created data.
198        """
199        return cursor.fetchone()
200
201    def field_cast_sql(self, db_type, internal_type):
202        """
203        Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
204        (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
205        it in a WHERE statement. The resulting string should contain a '%s'
206        placeholder for the column being searched against.
207        """
208        return "%s"
209
210    def force_no_ordering(self):
211        """
212        Return a list used in the "ORDER BY" clause to force no ordering at
213        all. Return an empty list to include nothing in the ordering.
214        """
215        return []
216
217    def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
218        """
219        Return the FOR UPDATE SQL clause to lock rows for an update operation.
220        """
221        return "FOR{} UPDATE{}{}{}".format(
222            " NO KEY" if no_key else "",
223            " OF {}".format(", ".join(of)) if of else "",
224            " NOWAIT" if nowait else "",
225            " SKIP LOCKED" if skip_locked else "",
226        )
227
228    def _get_limit_offset_params(self, low_mark, high_mark):
229        offset = low_mark or 0
230        if high_mark is not None:
231            return (high_mark - offset), offset
232        elif offset:
233            return self.connection.ops.no_limit_value(), offset
234        return None, offset
235
236    def limit_offset_sql(self, low_mark, high_mark):
237        """Return LIMIT/OFFSET SQL clause."""
238        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
239        return " ".join(
240            sql
241            for sql in (
242                ("LIMIT %d" % limit) if limit else None,  # noqa: UP031
243                ("OFFSET %d" % offset) if offset else None,  # noqa: UP031
244            )
245            if sql
246        )
247
248    def last_executed_query(self, cursor, sql, params):
249        """
250        Return a string of the query last executed by the given cursor, with
251        placeholders replaced with actual values.
252
253        `sql` is the raw query containing placeholders and `params` is the
254        sequence of parameters. These are used by default, but this method
255        exists for database backends to provide a better implementation
256        according to their own quoting schemes.
257        """
258
259        # Convert params to contain string values.
260        def to_string(s):
261            return force_str(s, strings_only=True, errors="replace")
262
263        if isinstance(params, list | tuple):
264            u_params = tuple(to_string(val) for val in params)
265        elif params is None:
266            u_params = ()
267        else:
268            u_params = {to_string(k): to_string(v) for k, v in params.items()}
269
270        return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
271
272    def last_insert_id(self, cursor, table_name, pk_name):
273        """
274        Given a cursor object that has just performed an INSERT statement into
275        a table that has an auto-incrementing ID, return the newly created ID.
276
277        `pk_name` is the name of the primary-key column.
278        """
279        return cursor.lastrowid
280
281    def lookup_cast(self, lookup_type, internal_type=None):
282        """
283        Return the string to use in a query when performing lookups
284        ("contains", "like", etc.). It should contain a '%s' placeholder for
285        the column being searched against.
286        """
287        return "%s"
288
289    def max_in_list_size(self):
290        """
291        Return the maximum number of items that can be passed in a single 'IN'
292        list condition, or None if the backend does not impose a limit.
293        """
294        return None
295
296    def max_name_length(self):
297        """
298        Return the maximum length of table and column names, or None if there
299        is no limit.
300        """
301        return None
302
303    def no_limit_value(self):
304        """
305        Return the value to use for the LIMIT when we are wanting "LIMIT
306        infinity". Return None if the limit clause can be omitted in this case.
307        """
308        raise NotImplementedError(
309            "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
310        )
311
312    def pk_default_value(self):
313        """
314        Return the value to use during an INSERT statement to specify that
315        the field should use its default value.
316        """
317        return "DEFAULT"
318
319    def prepare_sql_script(self, sql):
320        """
321        Take an SQL script that may contain multiple lines and return a list
322        of statements to feed to successive cursor.execute() calls.
323
324        Since few databases are able to process raw SQL scripts in a single
325        cursor.execute() call and PEP 249 doesn't talk about this use case,
326        the default implementation is conservative.
327        """
328        return [
329            sqlparse.format(statement, strip_comments=True)
330            for statement in sqlparse.split(sql)
331            if statement
332        ]
333
334    def return_insert_columns(self, fields):
335        """
336        For backends that support returning columns as part of an insert query,
337        return the SQL and params to append to the INSERT query. The returned
338        fragment should contain a format string to hold the appropriate column.
339        """
340        pass
341
342    def compiler(self, compiler_name):
343        """
344        Return the SQLCompiler class corresponding to the given name,
345        in the namespace corresponding to the `compiler_module` attribute
346        on this backend.
347        """
348        if self._cache is None:
349            self._cache = import_module(self.compiler_module)
350        return getattr(self._cache, compiler_name)
351
352    def quote_name(self, name):
353        """
354        Return a quoted version of the given table, index, or column name. Do
355        not quote the given name if it's already been quoted.
356        """
357        raise NotImplementedError(
358            "subclasses of BaseDatabaseOperations may require a quote_name() method"
359        )
360
361    def regex_lookup(self, lookup_type):
362        """
363        Return the string to use in a query when performing regular expression
364        lookups (using "regex" or "iregex"). It should contain a '%s'
365        placeholder for the column being searched against.
366
367        If the feature is not supported (or part of it is not supported), raise
368        NotImplementedError.
369        """
370        raise NotImplementedError(
371            "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
372        )
373
374    def savepoint_create_sql(self, sid):
375        """
376        Return the SQL for starting a new savepoint. Only required if the
377        "uses_savepoints" feature is True. The "sid" parameter is a string
378        for the savepoint id.
379        """
380        return f"SAVEPOINT {self.quote_name(sid)}"
381
382    def savepoint_commit_sql(self, sid):
383        """
384        Return the SQL for committing the given savepoint.
385        """
386        return f"RELEASE SAVEPOINT {self.quote_name(sid)}"
387
388    def savepoint_rollback_sql(self, sid):
389        """
390        Return the SQL for rolling back the given savepoint.
391        """
392        return f"ROLLBACK TO SAVEPOINT {self.quote_name(sid)}"
393
394    def set_time_zone_sql(self):
395        """
396        Return the SQL that will set the connection's time zone.
397
398        Return '' if the backend doesn't support time zones.
399        """
400        return ""
401
402    def prep_for_like_query(self, x):
403        """Prepare a value for use in a LIKE query."""
404        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
405
406    # Same as prep_for_like_query(), but called for "iexact" matches, which
407    # need not necessarily be implemented using "LIKE" in the backend.
408    prep_for_iexact_query = prep_for_like_query
409
410    def validate_autopk_value(self, value):
411        """
412        Certain backends do not accept some values for "serial" fields
413        (for example zero in MySQL). Raise a ValueError if the value is
414        invalid, otherwise return the validated value.
415        """
416        return value
417
418    def adapt_unknown_value(self, value):
419        """
420        Transform a value to something compatible with the backend driver.
421
422        This method only depends on the type of the value. It's designed for
423        cases where the target type isn't known, such as .raw() SQL queries.
424        As a consequence it may not work perfectly in all circumstances.
425        """
426        if isinstance(value, datetime.datetime):  # must be before date
427            return self.adapt_datetimefield_value(value)
428        elif isinstance(value, datetime.date):
429            return self.adapt_datefield_value(value)
430        elif isinstance(value, datetime.time):
431            return self.adapt_timefield_value(value)
432        elif isinstance(value, decimal.Decimal):
433            return self.adapt_decimalfield_value(value)
434        else:
435            return value
436
437    def adapt_integerfield_value(self, value, internal_type):
438        return value
439
440    def adapt_datefield_value(self, value):
441        """
442        Transform a date value to an object compatible with what is expected
443        by the backend driver for date columns.
444        """
445        if value is None:
446            return None
447        return str(value)
448
449    def adapt_datetimefield_value(self, value):
450        """
451        Transform a datetime value to an object compatible with what is expected
452        by the backend driver for datetime columns.
453        """
454        if value is None:
455            return None
456        # Expression values are adapted by the database.
457        if hasattr(value, "resolve_expression"):
458            return value
459
460        return str(value)
461
462    def adapt_timefield_value(self, value):
463        """
464        Transform a time value to an object compatible with what is expected
465        by the backend driver for time columns.
466        """
467        if value is None:
468            return None
469        # Expression values are adapted by the database.
470        if hasattr(value, "resolve_expression"):
471            return value
472
473        if timezone.is_aware(value):
474            raise ValueError("Plain does not support timezone-aware times.")
475        return str(value)
476
477    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
478        """
479        Transform a decimal.Decimal value to an object compatible with what is
480        expected by the backend driver for decimal (numeric) columns.
481        """
482        return utils.format_number(value, max_digits, decimal_places)
483
484    def adapt_ipaddressfield_value(self, value):
485        """
486        Transform a string representation of an IP address into the expected
487        type for the backend driver.
488        """
489        return value or None
490
491    def adapt_json_value(self, value, encoder):
492        return json.dumps(value, cls=encoder)
493
494    def year_lookup_bounds_for_date_field(self, value, iso_year=False):
495        """
496        Return a two-elements list with the lower and upper bound to be used
497        with a BETWEEN operator to query a DateField value using a year
498        lookup.
499
500        `value` is an int, containing the looked-up year.
501        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
502        """
503        if iso_year:
504            first = datetime.date.fromisocalendar(value, 1, 1)
505            second = datetime.date.fromisocalendar(
506                value + 1, 1, 1
507            ) - datetime.timedelta(days=1)
508        else:
509            first = datetime.date(value, 1, 1)
510            second = datetime.date(value, 12, 31)
511        first = self.adapt_datefield_value(first)
512        second = self.adapt_datefield_value(second)
513        return [first, second]
514
515    def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
516        """
517        Return a two-elements list with the lower and upper bound to be used
518        with a BETWEEN operator to query a DateTimeField value using a year
519        lookup.
520
521        `value` is an int, containing the looked-up year.
522        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
523        """
524        if iso_year:
525            first = datetime.datetime.fromisocalendar(value, 1, 1)
526            second = datetime.datetime.fromisocalendar(
527                value + 1, 1, 1
528            ) - datetime.timedelta(microseconds=1)
529        else:
530            first = datetime.datetime(value, 1, 1)
531            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
532
533        # Make sure that datetimes are aware in the current timezone
534        tz = timezone.get_current_timezone()
535        first = timezone.make_aware(first, tz)
536        second = timezone.make_aware(second, tz)
537
538        first = self.adapt_datetimefield_value(first)
539        second = self.adapt_datetimefield_value(second)
540        return [first, second]
541
542    def get_db_converters(self, expression):
543        """
544        Return a list of functions needed to convert field data.
545
546        Some field types on some backends do not provide data in the correct
547        format, this is the hook for converter functions.
548        """
549        return []
550
551    def convert_durationfield_value(self, value, expression, connection):
552        if value is not None:
553            return datetime.timedelta(0, 0, value)
554
555    def check_expression_support(self, expression):
556        """
557        Check that the backend supports the provided expression.
558
559        This is used on specific backends to rule out known expressions
560        that have problematic or nonexistent implementations. If the
561        expression has a known problem, the backend should raise
562        NotSupportedError.
563        """
564        pass
565
566    def conditional_expression_supported_in_where_clause(self, expression):
567        """
568        Return True, if the conditional expression is supported in the WHERE
569        clause.
570        """
571        return True
572
573    def combine_expression(self, connector, sub_expressions):
574        """
575        Combine a list of subexpressions into a single expression, using
576        the provided connecting operator. This is required because operators
577        can vary between backends (e.g., Oracle with %% and &) and between
578        subexpression types (e.g., date expressions).
579        """
580        conn = f" {connector} "
581        return conn.join(sub_expressions)
582
583    def combine_duration_expression(self, connector, sub_expressions):
584        return self.combine_expression(connector, sub_expressions)
585
586    def binary_placeholder_sql(self, value):
587        """
588        Some backends require special syntax to insert binary content (MySQL
589        for example uses '_binary %s').
590        """
591        return "%s"
592
593    def modify_insert_params(self, placeholder, params):
594        """
595        Allow modification of insert parameters. Needed for Oracle Spatial
596        backend due to #10888.
597        """
598        return params
599
600    def integer_field_range(self, internal_type):
601        """
602        Given an integer field internal type (e.g. 'PositiveIntegerField'),
603        return a tuple of the (min_value, max_value) form representing the
604        range of the column type bound to the field.
605        """
606        return self.integer_field_ranges[internal_type]
607
608    def subtract_temporals(self, internal_type, lhs, rhs):
609        if self.connection.features.supports_temporal_subtraction:
610            lhs_sql, lhs_params = lhs
611            rhs_sql, rhs_params = rhs
612            return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
613        raise NotSupportedError(
614            f"This backend does not support {internal_type} subtraction."
615        )
616
617    def window_frame_start(self, start):
618        if isinstance(start, int):
619            if start < 0:
620                return "%d %s" % (abs(start), self.PRECEDING)  # noqa: UP031
621            elif start == 0:
622                return self.CURRENT_ROW
623        elif start is None:
624            return self.UNBOUNDED_PRECEDING
625        raise ValueError(
626            f"start argument must be a negative integer, zero, or None, but got '{start}'."
627        )
628
629    def window_frame_end(self, end):
630        if isinstance(end, int):
631            if end == 0:
632                return self.CURRENT_ROW
633            elif end > 0:
634                return "%d %s" % (end, self.FOLLOWING)  # noqa: UP031
635        elif end is None:
636            return self.UNBOUNDED_FOLLOWING
637        raise ValueError(
638            f"end argument must be a positive integer, zero, or None, but got '{end}'."
639        )
640
641    def window_frame_rows_start_end(self, start=None, end=None):
642        """
643        Return SQL for start and end points in an OVER clause window frame.
644        """
645        if not self.connection.features.supports_over_clause:
646            raise NotSupportedError("This backend does not support window expressions.")
647        return self.window_frame_start(start), self.window_frame_end(end)
648
649    def window_frame_range_start_end(self, start=None, end=None):
650        start_, end_ = self.window_frame_rows_start_end(start, end)
651        features = self.connection.features
652        if features.only_supports_unbounded_with_preceding_and_following and (
653            (start and start < 0) or (end and end > 0)
654        ):
655            raise NotSupportedError(
656                f"{self.connection.display_name} only supports UNBOUNDED together with PRECEDING and "
657                "FOLLOWING."
658            )
659        return start_, end_
660
661    def explain_query_prefix(self, format=None, **options):
662        if not self.connection.features.supports_explaining_query_execution:
663            raise NotSupportedError(
664                "This backend does not support explaining query execution."
665            )
666        if format:
667            supported_formats = self.connection.features.supported_explain_formats
668            normalized_format = format.upper()
669            if normalized_format not in supported_formats:
670                msg = f"{normalized_format} is not a recognized format."
671                if supported_formats:
672                    msg += " Allowed formats: {}".format(
673                        ", ".join(sorted(supported_formats))
674                    )
675                else:
676                    msg += (
677                        f" {self.connection.display_name} does not support any formats."
678                    )
679                raise ValueError(msg)
680        if options:
681            raise ValueError(
682                "Unknown options: {}".format(", ".join(sorted(options.keys())))
683            )
684        return self.explain_prefix
685
686    def insert_statement(self, on_conflict=None):
687        return "INSERT INTO"
688
689    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
690        return ""