Plain is headed towards 1.0! Subscribe for development updates →

   1import copy
   2import datetime
   3import functools
   4import inspect
   5from collections import defaultdict
   6from decimal import Decimal
   7from functools import cached_property
   8from types import NoneType
   9from uuid import UUID
  10
  11from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
  12from plain.models import fields
  13from plain.models.constants import LOOKUP_SEP
  14from plain.models.db import (
  15    DatabaseError,
  16    NotSupportedError,
  17    db_connection,
  18)
  19from plain.models.query_utils import Q
  20from plain.utils.deconstruct import deconstructible
  21from plain.utils.hashable import make_hashable
  22
  23
  24class SQLiteNumericMixin:
  25    """
  26    Some expressions with output_field=DecimalField() must be cast to
  27    numeric to be properly filtered.
  28    """
  29
  30    def as_sqlite(self, compiler, connection, **extra_context):
  31        sql, params = self.as_sql(compiler, connection, **extra_context)
  32        try:
  33            if self.output_field.get_internal_type() == "DecimalField":
  34                sql = f"CAST({sql} AS NUMERIC)"
  35        except FieldError:
  36            pass
  37        return sql, params
  38
  39
  40class Combinable:
  41    """
  42    Provide the ability to combine one or two objects with
  43    some connector. For example F('foo') + F('bar').
  44    """
  45
  46    # Arithmetic connectors
  47    ADD = "+"
  48    SUB = "-"
  49    MUL = "*"
  50    DIV = "/"
  51    POW = "^"
  52    # The following is a quoted % operator - it is quoted because it can be
  53    # used in strings that also have parameter substitution.
  54    MOD = "%%"
  55
  56    # Bitwise operators - note that these are generated by .bitand()
  57    # and .bitor(), the '&' and '|' are reserved for boolean operator
  58    # usage.
  59    BITAND = "&"
  60    BITOR = "|"
  61    BITLEFTSHIFT = "<<"
  62    BITRIGHTSHIFT = ">>"
  63    BITXOR = "#"
  64
  65    def _combine(self, other, connector, reversed):
  66        if not hasattr(other, "resolve_expression"):
  67            # everything must be resolvable to an expression
  68            other = Value(other)
  69
  70        if reversed:
  71            return CombinedExpression(other, connector, self)
  72        return CombinedExpression(self, connector, other)
  73
  74    #############
  75    # OPERATORS #
  76    #############
  77
  78    def __neg__(self):
  79        return self._combine(-1, self.MUL, False)
  80
  81    def __add__(self, other):
  82        return self._combine(other, self.ADD, False)
  83
  84    def __sub__(self, other):
  85        return self._combine(other, self.SUB, False)
  86
  87    def __mul__(self, other):
  88        return self._combine(other, self.MUL, False)
  89
  90    def __truediv__(self, other):
  91        return self._combine(other, self.DIV, False)
  92
  93    def __mod__(self, other):
  94        return self._combine(other, self.MOD, False)
  95
  96    def __pow__(self, other):
  97        return self._combine(other, self.POW, False)
  98
  99    def __and__(self, other):
 100        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 101            return Q(self) & Q(other)
 102        raise NotImplementedError(
 103            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 104        )
 105
 106    def bitand(self, other):
 107        return self._combine(other, self.BITAND, False)
 108
 109    def bitleftshift(self, other):
 110        return self._combine(other, self.BITLEFTSHIFT, False)
 111
 112    def bitrightshift(self, other):
 113        return self._combine(other, self.BITRIGHTSHIFT, False)
 114
 115    def __xor__(self, other):
 116        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 117            return Q(self) ^ Q(other)
 118        raise NotImplementedError(
 119            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 120        )
 121
 122    def bitxor(self, other):
 123        return self._combine(other, self.BITXOR, False)
 124
 125    def __or__(self, other):
 126        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 127            return Q(self) | Q(other)
 128        raise NotImplementedError(
 129            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 130        )
 131
 132    def bitor(self, other):
 133        return self._combine(other, self.BITOR, False)
 134
 135    def __radd__(self, other):
 136        return self._combine(other, self.ADD, True)
 137
 138    def __rsub__(self, other):
 139        return self._combine(other, self.SUB, True)
 140
 141    def __rmul__(self, other):
 142        return self._combine(other, self.MUL, True)
 143
 144    def __rtruediv__(self, other):
 145        return self._combine(other, self.DIV, True)
 146
 147    def __rmod__(self, other):
 148        return self._combine(other, self.MOD, True)
 149
 150    def __rpow__(self, other):
 151        return self._combine(other, self.POW, True)
 152
 153    def __rand__(self, other):
 154        raise NotImplementedError(
 155            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 156        )
 157
 158    def __ror__(self, other):
 159        raise NotImplementedError(
 160            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 161        )
 162
 163    def __rxor__(self, other):
 164        raise NotImplementedError(
 165            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 166        )
 167
 168    def __invert__(self):
 169        return NegatedExpression(self)
 170
 171
 172class BaseExpression:
 173    """Base class for all query expressions."""
 174
 175    empty_result_set_value = NotImplemented
 176    # aggregate specific fields
 177    is_summary = False
 178    _output_field_resolved_to_none = False
 179    # Can the expression be used in a WHERE clause?
 180    filterable = True
 181    # Can the expression can be used as a source expression in Window?
 182    window_compatible = False
 183
 184    def __init__(self, output_field=None):
 185        if output_field is not None:
 186            self.output_field = output_field
 187
 188    def __getstate__(self):
 189        state = self.__dict__.copy()
 190        state.pop("convert_value", None)
 191        return state
 192
 193    def get_db_converters(self, connection):
 194        return (
 195            []
 196            if self.convert_value is self._convert_value_noop
 197            else [self.convert_value]
 198        ) + self.output_field.get_db_converters(connection)
 199
 200    def get_source_expressions(self):
 201        return []
 202
 203    def set_source_expressions(self, exprs):
 204        assert not exprs
 205
 206    def _parse_expressions(self, *expressions):
 207        return [
 208            arg
 209            if hasattr(arg, "resolve_expression")
 210            else (F(arg) if isinstance(arg, str) else Value(arg))
 211            for arg in expressions
 212        ]
 213
 214    def as_sql(self, compiler, connection):
 215        """
 216        Responsible for returning a (sql, [params]) tuple to be included
 217        in the current query.
 218
 219        Different backends can provide their own implementation, by
 220        providing an `as_{vendor}` method and patching the Expression:
 221
 222        ```
 223        def override_as_sql(self, compiler, connection):
 224            # custom logic
 225            return super().as_sql(compiler, connection)
 226        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
 227        ```
 228
 229        Arguments:
 230         * compiler: the query compiler responsible for generating the query.
 231           Must have a compile method, returning a (sql, [params]) tuple.
 232           Calling compiler(value) will return a quoted `value`.
 233
 234         * connection: the database connection used for the current query.
 235
 236        Return: (sql, params)
 237          Where `sql` is a string containing ordered sql parameters to be
 238          replaced with the elements of the list `params`.
 239        """
 240        raise NotImplementedError("Subclasses must implement as_sql()")
 241
 242    @cached_property
 243    def contains_aggregate(self):
 244        return any(
 245            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 246        )
 247
 248    @cached_property
 249    def contains_over_clause(self):
 250        return any(
 251            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 252        )
 253
 254    @cached_property
 255    def contains_column_references(self):
 256        return any(
 257            expr and expr.contains_column_references
 258            for expr in self.get_source_expressions()
 259        )
 260
 261    def resolve_expression(
 262        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 263    ):
 264        """
 265        Provide the chance to do any preprocessing or validation before being
 266        added to the query.
 267
 268        Arguments:
 269         * query: the backend query implementation
 270         * allow_joins: boolean allowing or denying use of joins
 271           in this query
 272         * reuse: a set of reusable joins for multijoins
 273         * summarize: a terminal aggregate clause
 274         * for_save: whether this expression about to be used in a save or update
 275
 276        Return: an Expression to be added to the query.
 277        """
 278        c = self.copy()
 279        c.is_summary = summarize
 280        c.set_source_expressions(
 281            [
 282                expr.resolve_expression(query, allow_joins, reuse, summarize)
 283                if expr
 284                else None
 285                for expr in c.get_source_expressions()
 286            ]
 287        )
 288        return c
 289
 290    @property
 291    def conditional(self):
 292        return isinstance(self.output_field, fields.BooleanField)
 293
 294    @property
 295    def field(self):
 296        return self.output_field
 297
 298    @cached_property
 299    def output_field(self):
 300        """Return the output type of this expressions."""
 301        output_field = self._resolve_output_field()
 302        if output_field is None:
 303            self._output_field_resolved_to_none = True
 304            raise FieldError("Cannot resolve expression type, unknown output_field")
 305        return output_field
 306
 307    @cached_property
 308    def _output_field_or_none(self):
 309        """
 310        Return the output field of this expression, or None if
 311        _resolve_output_field() didn't return an output type.
 312        """
 313        try:
 314            return self.output_field
 315        except FieldError:
 316            if not self._output_field_resolved_to_none:
 317                raise
 318
 319    def _resolve_output_field(self):
 320        """
 321        Attempt to infer the output type of the expression.
 322
 323        As a guess, if the output fields of all source fields match then simply
 324        infer the same type here.
 325
 326        If a source's output field resolves to None, exclude it from this check.
 327        If all sources are None, then an error is raised higher up the stack in
 328        the output_field property.
 329        """
 330        # This guess is mostly a bad idea, but there is quite a lot of code
 331        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 332        # deprecation path to fix it.
 333        sources_iter = (
 334            source for source in self.get_source_fields() if source is not None
 335        )
 336        for output_field in sources_iter:
 337            for source in sources_iter:
 338                if not isinstance(output_field, source.__class__):
 339                    raise FieldError(
 340                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 341                        "set output_field."
 342                    )
 343            return output_field
 344
 345    @staticmethod
 346    def _convert_value_noop(value, expression, connection):
 347        return value
 348
 349    @cached_property
 350    def convert_value(self):
 351        """
 352        Expressions provide their own converters because users have the option
 353        of manually specifying the output_field which may be a different type
 354        from the one the database returns.
 355        """
 356        field = self.output_field
 357        internal_type = field.get_internal_type()
 358        if internal_type == "FloatField":
 359            return (
 360                lambda value, expression, connection: None
 361                if value is None
 362                else float(value)
 363            )
 364        elif internal_type.endswith("IntegerField"):
 365            return (
 366                lambda value, expression, connection: None
 367                if value is None
 368                else int(value)
 369            )
 370        elif internal_type == "DecimalField":
 371            return (
 372                lambda value, expression, connection: None
 373                if value is None
 374                else Decimal(value)
 375            )
 376        return self._convert_value_noop
 377
 378    def get_lookup(self, lookup):
 379        return self.output_field.get_lookup(lookup)
 380
 381    def get_transform(self, name):
 382        return self.output_field.get_transform(name)
 383
 384    def relabeled_clone(self, change_map):
 385        clone = self.copy()
 386        clone.set_source_expressions(
 387            [
 388                e.relabeled_clone(change_map) if e is not None else None
 389                for e in self.get_source_expressions()
 390            ]
 391        )
 392        return clone
 393
 394    def replace_expressions(self, replacements):
 395        if replacement := replacements.get(self):
 396            return replacement
 397        clone = self.copy()
 398        source_expressions = clone.get_source_expressions()
 399        clone.set_source_expressions(
 400            [
 401                expr.replace_expressions(replacements) if expr else None
 402                for expr in source_expressions
 403            ]
 404        )
 405        return clone
 406
 407    def get_refs(self):
 408        refs = set()
 409        for expr in self.get_source_expressions():
 410            refs |= expr.get_refs()
 411        return refs
 412
 413    def copy(self):
 414        return copy.copy(self)
 415
 416    def prefix_references(self, prefix):
 417        clone = self.copy()
 418        clone.set_source_expressions(
 419            [
 420                F(f"{prefix}{expr.name}")
 421                if isinstance(expr, F)
 422                else expr.prefix_references(prefix)
 423                for expr in self.get_source_expressions()
 424            ]
 425        )
 426        return clone
 427
 428    def get_group_by_cols(self):
 429        if not self.contains_aggregate:
 430            return [self]
 431        cols = []
 432        for source in self.get_source_expressions():
 433            cols.extend(source.get_group_by_cols())
 434        return cols
 435
 436    def get_source_fields(self):
 437        """Return the underlying field types used by this aggregate."""
 438        return [e._output_field_or_none for e in self.get_source_expressions()]
 439
 440    def asc(self, **kwargs):
 441        return OrderBy(self, **kwargs)
 442
 443    def desc(self, **kwargs):
 444        return OrderBy(self, descending=True, **kwargs)
 445
 446    def reverse_ordering(self):
 447        return self
 448
 449    def flatten(self):
 450        """
 451        Recursively yield this expression and all subexpressions, in
 452        depth-first order.
 453        """
 454        yield self
 455        for expr in self.get_source_expressions():
 456            if expr:
 457                if hasattr(expr, "flatten"):
 458                    yield from expr.flatten()
 459                else:
 460                    yield expr
 461
 462    def select_format(self, compiler, sql, params):
 463        """
 464        Custom format for select clauses. For example, EXISTS expressions need
 465        to be wrapped in CASE WHEN on Oracle.
 466        """
 467        if hasattr(self.output_field, "select_format"):
 468            return self.output_field.select_format(compiler, sql, params)
 469        return sql, params
 470
 471
 472@deconstructible
 473class Expression(BaseExpression, Combinable):
 474    """An expression that can be combined with other expressions."""
 475
 476    @cached_property
 477    def identity(self):
 478        constructor_signature = inspect.signature(self.__init__)
 479        args, kwargs = self._constructor_args
 480        signature = constructor_signature.bind_partial(*args, **kwargs)
 481        signature.apply_defaults()
 482        arguments = signature.arguments.items()
 483        identity = [self.__class__]
 484        for arg, value in arguments:
 485            if isinstance(value, fields.Field):
 486                if value.name and value.model:
 487                    value = (value.model._meta.label, value.name)
 488                else:
 489                    value = type(value)
 490            else:
 491                value = make_hashable(value)
 492            identity.append((arg, value))
 493        return tuple(identity)
 494
 495    def __eq__(self, other):
 496        if not isinstance(other, Expression):
 497            return NotImplemented
 498        return other.identity == self.identity
 499
 500    def __hash__(self):
 501        return hash(self.identity)
 502
 503
 504# Type inference for CombinedExpression.output_field.
 505# Missing items will result in FieldError, by design.
 506#
 507# The current approach for NULL is based on lowest common denominator behavior
 508# i.e. if one of the supported databases is raising an error (rather than
 509# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 510
 511_connector_combinations = [
 512    # Numeric operations - operands of same type.
 513    {
 514        connector: [
 515            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 516            (fields.FloatField, fields.FloatField, fields.FloatField),
 517            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 518        ]
 519        for connector in (
 520            Combinable.ADD,
 521            Combinable.SUB,
 522            Combinable.MUL,
 523            # Behavior for DIV with integer arguments follows Postgres/SQLite,
 524            # not MySQL/Oracle.
 525            Combinable.DIV,
 526            Combinable.MOD,
 527            Combinable.POW,
 528        )
 529    },
 530    # Numeric operations - operands of different type.
 531    {
 532        connector: [
 533            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 534            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 535            (fields.IntegerField, fields.FloatField, fields.FloatField),
 536            (fields.FloatField, fields.IntegerField, fields.FloatField),
 537        ]
 538        for connector in (
 539            Combinable.ADD,
 540            Combinable.SUB,
 541            Combinable.MUL,
 542            Combinable.DIV,
 543            Combinable.MOD,
 544        )
 545    },
 546    # Bitwise operators.
 547    {
 548        connector: [
 549            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 550        ]
 551        for connector in (
 552            Combinable.BITAND,
 553            Combinable.BITOR,
 554            Combinable.BITLEFTSHIFT,
 555            Combinable.BITRIGHTSHIFT,
 556            Combinable.BITXOR,
 557        )
 558    },
 559    # Numeric with NULL.
 560    {
 561        connector: [
 562            (field_type, NoneType, field_type),
 563            (NoneType, field_type, field_type),
 564        ]
 565        for connector in (
 566            Combinable.ADD,
 567            Combinable.SUB,
 568            Combinable.MUL,
 569            Combinable.DIV,
 570            Combinable.MOD,
 571            Combinable.POW,
 572        )
 573        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 574    },
 575    # Date/DateTimeField/DurationField/TimeField.
 576    {
 577        Combinable.ADD: [
 578            # Date/DateTimeField.
 579            (fields.DateField, fields.DurationField, fields.DateTimeField),
 580            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 581            (fields.DurationField, fields.DateField, fields.DateTimeField),
 582            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 583            # DurationField.
 584            (fields.DurationField, fields.DurationField, fields.DurationField),
 585            # TimeField.
 586            (fields.TimeField, fields.DurationField, fields.TimeField),
 587            (fields.DurationField, fields.TimeField, fields.TimeField),
 588        ],
 589    },
 590    {
 591        Combinable.SUB: [
 592            # Date/DateTimeField.
 593            (fields.DateField, fields.DurationField, fields.DateTimeField),
 594            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 595            (fields.DateField, fields.DateField, fields.DurationField),
 596            (fields.DateField, fields.DateTimeField, fields.DurationField),
 597            (fields.DateTimeField, fields.DateField, fields.DurationField),
 598            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 599            # DurationField.
 600            (fields.DurationField, fields.DurationField, fields.DurationField),
 601            # TimeField.
 602            (fields.TimeField, fields.DurationField, fields.TimeField),
 603            (fields.TimeField, fields.TimeField, fields.DurationField),
 604        ],
 605    },
 606]
 607
 608_connector_combinators = defaultdict(list)
 609
 610
 611def register_combinable_fields(lhs, connector, rhs, result):
 612    """
 613    Register combinable types:
 614        lhs <connector> rhs -> result
 615    e.g.
 616        register_combinable_fields(
 617            IntegerField, Combinable.ADD, FloatField, FloatField
 618        )
 619    """
 620    _connector_combinators[connector].append((lhs, rhs, result))
 621
 622
 623for d in _connector_combinations:
 624    for connector, field_types in d.items():
 625        for lhs, rhs, result in field_types:
 626            register_combinable_fields(lhs, connector, rhs, result)
 627
 628
 629@functools.lru_cache(maxsize=128)
 630def _resolve_combined_type(connector, lhs_type, rhs_type):
 631    combinators = _connector_combinators.get(connector, ())
 632    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 633        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 634            rhs_type, combinator_rhs_type
 635        ):
 636            return combined_type
 637
 638
 639class CombinedExpression(SQLiteNumericMixin, Expression):
 640    def __init__(self, lhs, connector, rhs, output_field=None):
 641        super().__init__(output_field=output_field)
 642        self.connector = connector
 643        self.lhs = lhs
 644        self.rhs = rhs
 645
 646    def __repr__(self):
 647        return f"<{self.__class__.__name__}: {self}>"
 648
 649    def __str__(self):
 650        return f"{self.lhs} {self.connector} {self.rhs}"
 651
 652    def get_source_expressions(self):
 653        return [self.lhs, self.rhs]
 654
 655    def set_source_expressions(self, exprs):
 656        self.lhs, self.rhs = exprs
 657
 658    def _resolve_output_field(self):
 659        # We avoid using super() here for reasons given in
 660        # Expression._resolve_output_field()
 661        combined_type = _resolve_combined_type(
 662            self.connector,
 663            type(self.lhs._output_field_or_none),
 664            type(self.rhs._output_field_or_none),
 665        )
 666        if combined_type is None:
 667            raise FieldError(
 668                f"Cannot infer type of {self.connector!r} expression involving these "
 669                f"types: {self.lhs.output_field.__class__.__name__}, "
 670                f"{self.rhs.output_field.__class__.__name__}. You must set "
 671                f"output_field."
 672            )
 673        return combined_type()
 674
 675    def as_sql(self, compiler, connection):
 676        expressions = []
 677        expression_params = []
 678        sql, params = compiler.compile(self.lhs)
 679        expressions.append(sql)
 680        expression_params.extend(params)
 681        sql, params = compiler.compile(self.rhs)
 682        expressions.append(sql)
 683        expression_params.extend(params)
 684        # order of precedence
 685        expression_wrapper = "(%s)"
 686        sql = connection.ops.combine_expression(self.connector, expressions)
 687        return expression_wrapper % sql, expression_params
 688
 689    def resolve_expression(
 690        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 691    ):
 692        lhs = self.lhs.resolve_expression(
 693            query, allow_joins, reuse, summarize, for_save
 694        )
 695        rhs = self.rhs.resolve_expression(
 696            query, allow_joins, reuse, summarize, for_save
 697        )
 698        if not isinstance(self, DurationExpression | TemporalSubtraction):
 699            try:
 700                lhs_type = lhs.output_field.get_internal_type()
 701            except (AttributeError, FieldError):
 702                lhs_type = None
 703            try:
 704                rhs_type = rhs.output_field.get_internal_type()
 705            except (AttributeError, FieldError):
 706                rhs_type = None
 707            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
 708                return DurationExpression(
 709                    self.lhs, self.connector, self.rhs
 710                ).resolve_expression(
 711                    query,
 712                    allow_joins,
 713                    reuse,
 714                    summarize,
 715                    for_save,
 716                )
 717            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 718            if (
 719                self.connector == self.SUB
 720                and lhs_type in datetime_fields
 721                and lhs_type == rhs_type
 722            ):
 723                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 724                    query,
 725                    allow_joins,
 726                    reuse,
 727                    summarize,
 728                    for_save,
 729                )
 730        c = self.copy()
 731        c.is_summary = summarize
 732        c.lhs = lhs
 733        c.rhs = rhs
 734        return c
 735
 736
 737class DurationExpression(CombinedExpression):
 738    def compile(self, side, compiler, connection):
 739        try:
 740            output = side.output_field
 741        except FieldError:
 742            pass
 743        else:
 744            if output.get_internal_type() == "DurationField":
 745                sql, params = compiler.compile(side)
 746                return connection.ops.format_for_duration_arithmetic(sql), params
 747        return compiler.compile(side)
 748
 749    def as_sql(self, compiler, connection):
 750        if connection.features.has_native_duration_field:
 751            return super().as_sql(compiler, connection)
 752        connection.ops.check_expression_support(self)
 753        expressions = []
 754        expression_params = []
 755        sql, params = self.compile(self.lhs, compiler, connection)
 756        expressions.append(sql)
 757        expression_params.extend(params)
 758        sql, params = self.compile(self.rhs, compiler, connection)
 759        expressions.append(sql)
 760        expression_params.extend(params)
 761        # order of precedence
 762        expression_wrapper = "(%s)"
 763        sql = connection.ops.combine_duration_expression(self.connector, expressions)
 764        return expression_wrapper % sql, expression_params
 765
 766    def as_sqlite(self, compiler, connection, **extra_context):
 767        sql, params = self.as_sql(compiler, connection, **extra_context)
 768        if self.connector in {Combinable.MUL, Combinable.DIV}:
 769            try:
 770                lhs_type = self.lhs.output_field.get_internal_type()
 771                rhs_type = self.rhs.output_field.get_internal_type()
 772            except (AttributeError, FieldError):
 773                pass
 774            else:
 775                allowed_fields = {
 776                    "DecimalField",
 777                    "DurationField",
 778                    "FloatField",
 779                    "IntegerField",
 780                }
 781                if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
 782                    raise DatabaseError(
 783                        f"Invalid arguments for operator {self.connector}."
 784                    )
 785        return sql, params
 786
 787
 788class TemporalSubtraction(CombinedExpression):
 789    output_field = fields.DurationField()
 790
 791    def __init__(self, lhs, rhs):
 792        super().__init__(lhs, self.SUB, rhs)
 793
 794    def as_sql(self, compiler, connection):
 795        connection.ops.check_expression_support(self)
 796        lhs = compiler.compile(self.lhs)
 797        rhs = compiler.compile(self.rhs)
 798        return connection.ops.subtract_temporals(
 799            self.lhs.output_field.get_internal_type(), lhs, rhs
 800        )
 801
 802
 803@deconstructible(path="plain.models.F")
 804class F(Combinable):
 805    """An object capable of resolving references to existing query objects."""
 806
 807    def __init__(self, name):
 808        """
 809        Arguments:
 810         * name: the name of the field this expression references
 811        """
 812        self.name = name
 813
 814    def __repr__(self):
 815        return f"{self.__class__.__name__}({self.name})"
 816
 817    def resolve_expression(
 818        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 819    ):
 820        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 821
 822    def replace_expressions(self, replacements):
 823        return replacements.get(self, self)
 824
 825    def asc(self, **kwargs):
 826        return OrderBy(self, **kwargs)
 827
 828    def desc(self, **kwargs):
 829        return OrderBy(self, descending=True, **kwargs)
 830
 831    def __eq__(self, other):
 832        return self.__class__ == other.__class__ and self.name == other.name
 833
 834    def __hash__(self):
 835        return hash(self.name)
 836
 837    def copy(self):
 838        return copy.copy(self)
 839
 840
 841class ResolvedOuterRef(F):
 842    """
 843    An object that contains a reference to an outer query.
 844
 845    In this case, the reference to the outer query has been resolved because
 846    the inner query has been used as a subquery.
 847    """
 848
 849    contains_aggregate = False
 850    contains_over_clause = False
 851
 852    def as_sql(self, *args, **kwargs):
 853        raise ValueError(
 854            "This queryset contains a reference to an outer query and may "
 855            "only be used in a subquery."
 856        )
 857
 858    def resolve_expression(self, *args, **kwargs):
 859        col = super().resolve_expression(*args, **kwargs)
 860        if col.contains_over_clause:
 861            raise NotSupportedError(
 862                f"Referencing outer query window expression is not supported: "
 863                f"{self.name}."
 864            )
 865        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 866        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 867        # into account only many-to-many and one-to-many relationships.
 868        col.possibly_multivalued = LOOKUP_SEP in self.name
 869        return col
 870
 871    def relabeled_clone(self, relabels):
 872        return self
 873
 874    def get_group_by_cols(self):
 875        return []
 876
 877
 878class OuterRef(F):
 879    contains_aggregate = False
 880
 881    def resolve_expression(self, *args, **kwargs):
 882        if isinstance(self.name, self.__class__):
 883            return self.name
 884        return ResolvedOuterRef(self.name)
 885
 886    def relabeled_clone(self, relabels):
 887        return self
 888
 889
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call.

    Rendering fills ``template`` (``%``-style) with the compiled arguments
    joined by ``arg_joiner``. Subclasses typically set ``function`` and may
    override ``template``, ``arg_joiner`` and ``arity``. Keyword arguments
    passed to ``__init__()`` are kept in ``self.extra`` and become extra
    template context in ``as_sql()``.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """
        Arguments:
         * expressions: the function's arguments, normalized with
           ``_parse_expressions()``.
         * output_field: optional field describing the result type.
         * extra: extra template context stored on ``self.extra``. Keys such
           as ``function``, ``template`` and ``arg_joiner`` override the
           class-level defaults in ``as_sql()``.

        Raises TypeError if ``arity`` is set and a different number of
        expressions is given.
        """
        if self.arity is not None and len(expressions) != self.arity:
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__,
                    self.arity,
                    "argument" if self.arity == 1 else "arguments",
                    len(expressions),
                )
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        # Arguments first, then extra/repr options as sorted key=value pairs.
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        extra = {**self.extra, **self._get_repr_options()}
        if extra:
            extra = ", ".join(
                str(key) + "=" + str(val) for key, val in sorted(extra.items())
            )
            return f"{self.__class__.__name__}({args}, {extra})"
        return f"{self.__class__.__name__}({args})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with every argument resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        # copy() gives ``c`` its own list, so replacing items in place leaves
        # the original's arguments untouched.
        for pos, arg in enumerate(c.source_expressions):
            c.source_expressions[pos] = arg.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Compile the call into ``(sql, params)``.

        ``function``, ``template`` and ``arg_joiner`` take precedence over the
        values in ``self.extra`` and on the class. ``extra_context`` adds or
        overrides template keys.
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Use the argument's empty_result_set_value if it defines one;
                # otherwise re-raise.
                empty_result_set_value = getattr(
                    arg, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
            except FullResultSet:
                # Replace the argument with a literal Value(True).
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data["function"] = function
        else:
            data.setdefault("function", self.function)
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        # The joined arguments are exposed as both %(expressions)s and
        # %(field)s, so templates can use either name.
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self):
        # Shallow copy, but with private source_expressions and extra
        # containers so the copy can be mutated independently.
        copy = super().copy()
        copy.source_expressions = self.source_expressions[:]
        copy.extra = self.extra.copy()
        return copy
 988
 989
 990@deconstructible(path="plain.models.Value")
 991class Value(SQLiteNumericMixin, Expression):
 992    """Represent a wrapped value as a node within an expression."""
 993
 994    # Provide a default value for `for_save` in order to allow unresolved
 995    # instances to be compiled until a decision is taken in #25425.
 996    for_save = False
 997
 998    def __init__(self, value, output_field=None):
 999        """
1000        Arguments:
1001         * value: the value this expression represents. The value will be
1002           added into the sql parameter list and properly quoted.
1003
1004         * output_field: an instance of the model field type that this
1005           expression will return, such as IntegerField() or CharField().
1006        """
1007        super().__init__(output_field=output_field)
1008        self.value = value
1009
1010    def __repr__(self):
1011        return f"{self.__class__.__name__}({self.value!r})"
1012
1013    def as_sql(self, compiler, connection):
1014        connection.ops.check_expression_support(self)
1015        val = self.value
1016        output_field = self._output_field_or_none
1017        if output_field is not None:
1018            if self.for_save:
1019                val = output_field.get_db_prep_save(val, connection=connection)
1020            else:
1021                val = output_field.get_db_prep_value(val, connection=connection)
1022            if hasattr(output_field, "get_placeholder"):
1023                return output_field.get_placeholder(val, compiler, connection), [val]
1024        if val is None:
1025            # cx_Oracle does not always convert None to the appropriate
1026            # NULL type (like in case expressions using numbers), so we
1027            # use a literal SQL NULL
1028            return "NULL", []
1029        return "%s", [val]
1030
1031    def resolve_expression(
1032        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
1033    ):
1034        c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1035        c.for_save = for_save
1036        return c
1037
1038    def get_group_by_cols(self):
1039        return []
1040
1041    def _resolve_output_field(self):
1042        if isinstance(self.value, str):
1043            return fields.CharField()
1044        if isinstance(self.value, bool):
1045            return fields.BooleanField()
1046        if isinstance(self.value, int):
1047            return fields.IntegerField()
1048        if isinstance(self.value, float):
1049            return fields.FloatField()
1050        if isinstance(self.value, datetime.datetime):
1051            return fields.DateTimeField()
1052        if isinstance(self.value, datetime.date):
1053            return fields.DateField()
1054        if isinstance(self.value, datetime.time):
1055            return fields.TimeField()
1056        if isinstance(self.value, datetime.timedelta):
1057            return fields.DurationField()
1058        if isinstance(self.value, Decimal):
1059            return fields.DecimalField()
1060        if isinstance(self.value, bytes):
1061            return fields.BinaryField()
1062        if isinstance(self.value, UUID):
1063            return fields.UUIDField()
1064
1065    @property
1066    def empty_result_set_value(self):
1067        return self.value
1068
1069
class RawSQL(Expression):
    """
    A raw SQL fragment together with its query parameters.

    ``sql`` is inserted verbatim, wrapped in parentheses, and ``params`` are
    returned unchanged. Never build ``sql`` from untrusted input; pass such
    values through ``params`` instead.
    """

    def __init__(self, sql, params, output_field=None):
        # Without an explicit type, fall back to a generic Field().
        if output_field is None:
            output_field = fields.Field()
        self.sql = sql
        self.params = params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self):
        # The whole fragment is grouped by as a unit.
        return [self]
1085
1086
class Star(Expression):
    """The SQL ``*`` wildcard; compiles to ``*`` with no parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        sql = "*"
        return sql, []
1093
1094
class Col(Expression):
    """A table column (the ``target`` field), optionally qualified by an alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        # The target field doubles as the output field unless one is given.
        resolved_field = output_field if output_field is not None else target
        super().__init__(output_field=resolved_field)
        self.alias = alias
        self.target = target

    def __repr__(self):
        target_name = str(self.target)
        inner = ", ".join((self.alias, target_name)) if self.alias else target_name
        return f"{self.__class__.__name__}({inner})"

    def as_sql(self, compiler, connection):
        column = self.target.column
        parts = [self.alias, column] if self.alias else [column]
        quote = compiler.quote_name_unless_alias
        return ".".join(quote(part) for part in parts), []

    def relabeled_clone(self, relabels):
        alias = self.alias
        if alias is None:
            return self
        new_alias = relabels.get(alias, alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        # Output-field converters run before the target field's converters.
        return converters + self.target.get_db_converters(connection)
1132
1133
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        # Exactly one source expression is expected.
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # `source` is already resolved; a Ref only names it.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        # Emit the quoted alias name with no parameters.
        quoted = connection.ops.quote_name(self.refs)
        return quoted, []

    def get_group_by_cols(self):
        return [self]
1171
1172
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.
    """

    # Render only the joined expressions, with no function name around them.
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if not expressions:
            message = f"{self.__class__.__name__} requires at least one expression."
            raise ValueError(message)
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Bypass SQLiteNumericMixin.as_sqlite(): casting to numeric is
        # unnecessary for a list of expressions.
        return self.as_sql(compiler, connection, **extra_context)
1195
1196
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a sequence of ordering items.

    String items starting with ``"-"`` become descending
    ``OrderBy(F(name))`` expressions. Every other item is passed to ``Func``
    unchanged. With no items, the clause compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # str.startswith() rather than expr[0], so an empty string does not
        # raise IndexError here.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        # An empty ordering list renders no clause at all.
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        # Concatenate the GROUP BY columns of every ordering term, in order.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1221
1222
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to attach extra context, such as an explicit
    ``output_field``. It compiles to exactly the wrapped expression.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression that carries this
        # wrapper's output_field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        # The wrapper contributes no SQL of its own.
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"
1254
1255
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        # A negated condition is always boolean.
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation yields (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped expression matches nothing, so its negation is
            # always true: emit "1=1" or a literal Value(True).
            features = compiler.connection.features
            if not features.supports_boolean_expr_in_select_clause:
                return "1=1", ()
            return compiler.compile(Value(True))
        ops = compiler.connection.ops
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        if not ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
        return f"NOT {sql}", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the wrapped expression; raise TypeError unless it is conditional."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not getattr(resolved.expression, "conditional", False):
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(self, compiler, sql, params):
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        expression_supported_in_where_clause = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        if (
            not compiler.connection.features.supports_boolean_expr_in_select_clause
            # Avoid double wrapping: as_sql() already emits a CASE WHEN when
            # this check is false.
            and expression_supported_in_where_clause(self.expression)
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1305
1306
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    A single ``WHEN <condition> THEN <result>`` branch for use in ``Case()``.

    The condition can be a Q object, a conditional (boolean) expression, or
    field lookups given as keyword arguments, which are folded into a Q.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        """
        Arguments:
         * condition: a Q object or a conditional expression.
         * then: the value or expression produced when the condition holds,
           parsed with ``_parse_expressions()``.
         * lookups: keyword lookups. They are combined with ``condition``
           into a Q object when ``condition`` is None or conditional.

        Raises TypeError for an unsupported condition, or for lookups combined
        with a non-conditional condition. Raises ValueError for an empty Q().
        """
        if lookups:
            if condition is None:
                condition, lookups = Q(**lookups), None
            elif getattr(condition, "conditional", False):
                condition, lookups = Q(condition, **lookups), None
        # Lookups still present here were paired with a non-conditional
        # condition and could not be merged.
        if condition is None or not getattr(condition, "conditional", False) or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with condition and result resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        # The condition is only resolved if it supports it, and always with
        # for_save=False. Only the result inherits the caller's for_save.
        if hasattr(c.condition, "resolve_expression"):
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        # extra_context is a fresh dict per call, so filling it in is safe.
        template_params = extra_context
        # Always empty; it only leads the params tuple built below.
        sql_params = []
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self):
        # A When is not a complete expression (it only lives inside Case()),
        # so group by the columns of its condition and result, not by itself.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1381
1382
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() instances. ``default`` becomes the
    ELSE branch, and extra keyword arguments are kept as template context.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    # Separator placed between the compiled WHEN clauses.
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        # The default always comes last; set_source_expressions() relies on it.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with every case and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        # copy() gave ``c`` its own cases list, so in-place replacement is safe.
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        # With no branches, the CASE reduces to its default.
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; drop it.
                continue
            except FullResultSet:
                # This branch always matches. Its result replaces the default
                # and any later branches are skipped.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            # Every branch was dropped, or the first one always matched.
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Params follow SQL order: WHEN clauses first, then the ELSE default.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Wrap with the backend's unification_cast_sql() format
            # (presumably a cast to the output type).
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        if not self.cases:
            # Compiles to just the default, so group by the default.
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1478
1479
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Allow the usage of both QuerySet and sql.Query objects: use the
        # object's ``query`` attribute when present, else the object itself.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        # The subquery's type is whatever the inner query produces.
        return self.query.output_field

    def copy(self):
        # Give the copy its own query so later mutations stay local.
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        inner_sql, params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters (presumably the inner query's
        # own parentheses) so the template controls the wrapping.
        context["subquery"] = inner_sql[1:-1]
        if not template:
            template = context.get("template", self.template)
        return template % context, params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
1530
1531
class Exists(Subquery):
    """An ``EXISTS(...)`` test around a subquery, typed as a boolean."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Swap the cloned query for its exists() form.
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        features = compiler.connection.features
        if features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends that can't select a bare boolean in SELECT or GROUP BY
        # (e.g. Oracle) get the EXISTS wrapped in a CASE WHEN yielding 1/0.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1548
1549
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    A single ``ORDER BY`` term: an expression, a direction, and optional
    NULL placement (``nulls_first`` / ``nulls_last``).
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Arguments:
         * expression: any object with a ``resolve_expression`` method.
         * descending: sort in descending order when true.
         * nulls_first / nulls_last: True to request NULL placement, or None
           (the default) to leave it unspecified. At most one may be True,
           and False is rejected.

        Raises ValueError for invalid flag combinations or a non-expression.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            # The backend supports NULLS LAST / NULLS FIRST natively.
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate NULL placement by first sorting on "IS NULL" or
            # "IS NOT NULL". The prefix is skipped when the backend's own
            # ordering already satisfies the request. This presumes
            # order_by_nulls_first means NULLs sort first in ascending order;
            # NOTE(review): confirm against the backend feature definitions.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression may appear more than once in the template, so its
        # params are repeated once per occurrence.
        params *= template.count("%(expression)s")
        # rstrip() trims trailing whitespace, e.g. when extra_context
        # supplies an empty "ordering".
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        """Flip direction and NULL placement in place, and return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        # Switch to ascending order in place; returns None.
        self.descending = False

    def desc(self):
        # Switch to descending order in place; returns None.
        self.descending = True
1623
1624
1625class Window(SQLiteNumericMixin, Expression):
1626    template = "%(expression)s OVER (%(window)s)"
1627    # Although the main expression may either be an aggregate or an
1628    # expression with an aggregate function, the GROUP BY that will
1629    # be introduced in the query as a result is not desired.
1630    contains_aggregate = False
1631    contains_over_clause = True
1632
    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        """
        Arguments:
         * expression: the windowed expression. It must have a truthy
           ``window_compatible`` attribute.
         * partition_by: an expression, or a tuple/list of them. It is
           wrapped in an ExpressionList.
         * order_by: a string, an expression, or a list/tuple of them. It is
           wrapped in an OrderByList.
         * frame: optional frame clause, compiled as-is.
         * output_field: optional explicit output field.

        Raises ValueError for an incompatible expression or an order_by of an
        unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            # Normalize a single expression into a one-element tuple.
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]
1667
1668    def _resolve_output_field(self):
1669        return self.source_expression.output_field
1670
1671    def get_source_expressions(self):
1672        return [self.source_expression, self.partition_by, self.order_by, self.frame]
1673
1674    def set_source_expressions(self, exprs):
1675        self.source_expression, self.partition_by, self.order_by, self.frame = exprs
1676
1677    def as_sql(self, compiler, connection, template=None):
1678        connection.ops.check_expression_support(self)
1679        if not connection.features.supports_over_clause:
1680            raise NotSupportedError("This backend does not support window expressions.")
1681        expr_sql, params = compiler.compile(self.source_expression)
1682        window_sql, window_params = [], ()
1683
1684        if self.partition_by is not None:
1685            sql_expr, sql_params = self.partition_by.as_sql(
1686                compiler=compiler,
1687                connection=connection,
1688                template="PARTITION BY %(expressions)s",
1689            )
1690            window_sql.append(sql_expr)
1691            window_params += tuple(sql_params)
1692
1693        if self.order_by is not None:
1694            order_sql, order_params = compiler.compile(self.order_by)
1695            window_sql.append(order_sql)
1696            window_params += tuple(order_params)
1697
1698        if self.frame:
1699            frame_sql, frame_params = compiler.compile(self.frame)
1700            window_sql.append(frame_sql)
1701            window_params += tuple(frame_params)
1702
1703        template = template or self.template
1704
1705        return (
1706            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
1707            (*params, *window_params),
1708        )
1709
1710    def as_sqlite(self, compiler, connection):
1711        if isinstance(self.output_field, fields.DecimalField):
1712            # Casting to numeric must be outside of the window expression.
1713            copy = self.copy()
1714            source_expressions = copy.get_source_expressions()
1715            source_expressions[0].output_field = fields.FloatField()
1716            copy.set_source_expressions(source_expressions)
1717            return super(Window, copy).as_sqlite(compiler, connection)
1718        return self.as_sql(compiler, connection)
1719
1720    def __str__(self):
1721        return "{} OVER ({}{}{})".format(
1722            str(self.source_expression),
1723            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
1724            str(self.order_by or ""),
1725            str(self.frame or ""),
1726        )
1727
1728    def __repr__(self):
1729        return f"<{self.__class__.__name__}: {self}>"
1730
1731    def get_group_by_cols(self):
1732        group_by_cols = []
1733        if self.partition_by:
1734            group_by_cols.extend(self.partition_by.get_group_by_cols())
1735        if self.order_by is not None:
1736            group_by_cols.extend(self.order_by.get_group_by_cols())
1737        return group_by_cols
1738
1739
class WindowFrame(Expression):
    """
    Model the frame clause in window expressions.

    There are two frame types, implemented as subclasses (RowRange and
    ValueRange). They only set ``frame_type`` and pick the backend hook used
    to render the bounds. All shared processing lives here.

    Providing an end for a frame is optional. A None ``end`` means UNBOUNDED
    FOLLOWING, which is the last row of the partition. A None ``start`` means
    UNBOUNDED PRECEDING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        # Bounds are held as Value expressions so they participate in
        # get/set_source_expressions like any other child expression.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        """
        Render the frame clause.

        Rendering of the start/end bounds is delegated to the backend via
        window_frame_start_end(). The bounds are inlined into the SQL, so no
        params are returned.
        """
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        """
        Describe the frame for debugging.

        Negative offsets render as N PRECEDING, positive ones as N FOLLOWING,
        zero as CURRENT ROW, and None as the unbounded edge.
        """
        ops = db_connection.ops
        start_value = self.start.value
        end_value = self.end.value

        if start_value is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_value < 0:
            start = f"{abs(start_value)} {ops.PRECEDING}"
        elif start_value == 0:
            start = ops.CURRENT_ROW
        else:
            # Previously misreported as UNBOUNDED PRECEDING.
            start = f"{start_value} {ops.FOLLOWING}"

        if end_value is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_value > 0:
            end = f"{end_value} {ops.FOLLOWING}"
        elif end_value == 0:
            end = ops.CURRENT_ROW
        else:
            # Previously misreported as UNBOUNDED FOLLOWING.
            end = f"{abs(end_value)} {ops.PRECEDING}"

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        """Return the backend-rendered (start, end) SQL fragments."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
1804
1805
class RowRange(WindowFrame):
    """Window frame rendered as ``ROWS BETWEEN <start> AND <end>``."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # The textual form of the bounds is backend specific.
        backend_ops = connection.ops
        return backend_ops.window_frame_rows_start_end(start, end)
1811
1812
class ValueRange(WindowFrame):
    """Window frame rendered as ``RANGE BETWEEN <start> AND <end>``."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # The textual form of the bounds is backend specific.
        backend_ops = connection.ops
        return backend_ops.window_frame_range_start_end(start, end)