Plain is headed towards 1.0! Subscribe for development updates →

   1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from abc import ABC, abstractmethod
   8from collections import defaultdict
   9from decimal import Decimal
  10from functools import cached_property
  11from types import NoneType
  12from typing import TYPE_CHECKING, Any, Protocol, Self, cast, runtime_checkable
  13from uuid import UUID
  14
  15from plain.models import fields
  16from plain.models.constants import LOOKUP_SEP
  17from plain.models.db import (
  18    DatabaseError,
  19    NotSupportedError,
  20    db_connection,
  21)
  22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
  23from plain.models.query_utils import Q
  24from plain.utils.deconstruct import deconstructible
  25from plain.utils.hashable import make_hashable
  26
  27if TYPE_CHECKING:
  28    from collections.abc import Callable, Iterable, Sequence
  29
  30    from plain.models.backends.base.base import BaseDatabaseWrapper
  31    from plain.models.fields import Field
  32    from plain.models.lookups import Lookup, Transform
  33    from plain.models.query import QuerySet
  34    from plain.models.sql.compiler import SQLCompilable, SQLCompiler
  35    from plain.models.sql.query import Query
  36
  37
  38@runtime_checkable
  39class ResolvableExpression(Protocol):
  40    """Protocol for expressions that can be resolved in query context."""
  41
  42    def resolve_expression(
  43        self,
  44        query: Any = None,
  45        allow_joins: bool = True,
  46        reuse: Any = None,
  47        summarize: bool = False,
  48        for_save: bool = False,
  49    ) -> Any: ...
  50
  51
  52@runtime_checkable
  53class ReplaceableExpression(Protocol):
  54    """Protocol for expressions that support expression replacement."""
  55
  56    def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
  57
  58
  59class Combinable:
  60    """
  61    Provide the ability to combine one or two objects with
  62    some connector. For example F('foo') + F('bar').
  63    """
  64
  65    # Arithmetic connectors
  66    ADD = "+"
  67    SUB = "-"
  68    MUL = "*"
  69    DIV = "/"
  70    POW = "^"
  71    # The following is a quoted % operator - it is quoted because it can be
  72    # used in strings that also have parameter substitution.
  73    MOD = "%%"
  74
  75    # Bitwise operators - note that these are generated by .bitand()
  76    # and .bitor(), the '&' and '|' are reserved for boolean operator
  77    # usage.
  78    BITAND = "&"
  79    BITOR = "|"
  80    BITLEFTSHIFT = "<<"
  81    BITRIGHTSHIFT = ">>"
  82    BITXOR = "#"
  83
  84    def _combine(
  85        self, other: Any, connector: str, reversed: bool
  86    ) -> CombinedExpression:
  87        if not isinstance(other, ResolvableExpression):
  88            # everything must be resolvable to an expression
  89            other = Value(other)
  90
  91        if reversed:
  92            return CombinedExpression(other, connector, self)
  93        return CombinedExpression(self, connector, other)
  94
  95    #############
  96    # OPERATORS #
  97    #############
  98
  99    def __neg__(self) -> CombinedExpression:
 100        return self._combine(-1, self.MUL, False)
 101
 102    def __add__(self, other: Any) -> CombinedExpression:
 103        return self._combine(other, self.ADD, False)
 104
 105    def __sub__(self, other: Any) -> CombinedExpression:
 106        return self._combine(other, self.SUB, False)
 107
 108    def __mul__(self, other: Any) -> CombinedExpression:
 109        return self._combine(other, self.MUL, False)
 110
 111    def __truediv__(self, other: Any) -> CombinedExpression:
 112        return self._combine(other, self.DIV, False)
 113
 114    def __mod__(self, other: Any) -> CombinedExpression:
 115        return self._combine(other, self.MOD, False)
 116
 117    def __pow__(self, other: Any) -> CombinedExpression:
 118        return self._combine(other, self.POW, False)
 119
 120    def __and__(self, other: Any) -> Q:
 121        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 122            return Q(self) & Q(other)
 123        raise NotImplementedError(
 124            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 125        )
 126
 127    def bitand(self, other: Any) -> CombinedExpression:
 128        return self._combine(other, self.BITAND, False)
 129
 130    def bitleftshift(self, other: Any) -> CombinedExpression:
 131        return self._combine(other, self.BITLEFTSHIFT, False)
 132
 133    def bitrightshift(self, other: Any) -> CombinedExpression:
 134        return self._combine(other, self.BITRIGHTSHIFT, False)
 135
 136    def __xor__(self, other: Any) -> Q:
 137        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 138            return Q(self) ^ Q(other)
 139        raise NotImplementedError(
 140            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 141        )
 142
 143    def bitxor(self, other: Any) -> CombinedExpression:
 144        return self._combine(other, self.BITXOR, False)
 145
 146    def __or__(self, other: Any) -> Q:
 147        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 148            return Q(self) | Q(other)
 149        raise NotImplementedError(
 150            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 151        )
 152
 153    def bitor(self, other: Any) -> CombinedExpression:
 154        return self._combine(other, self.BITOR, False)
 155
 156    def __radd__(self, other: Any) -> CombinedExpression:
 157        return self._combine(other, self.ADD, True)
 158
 159    def __rsub__(self, other: Any) -> CombinedExpression:
 160        return self._combine(other, self.SUB, True)
 161
 162    def __rmul__(self, other: Any) -> CombinedExpression:
 163        return self._combine(other, self.MUL, True)
 164
 165    def __rtruediv__(self, other: Any) -> CombinedExpression:
 166        return self._combine(other, self.DIV, True)
 167
 168    def __rmod__(self, other: Any) -> CombinedExpression:
 169        return self._combine(other, self.MOD, True)
 170
 171    def __rpow__(self, other: Any) -> CombinedExpression:
 172        return self._combine(other, self.POW, True)
 173
 174    def __rand__(self, other: Any) -> None:
 175        raise NotImplementedError(
 176            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 177        )
 178
 179    def __ror__(self, other: Any) -> None:
 180        raise NotImplementedError(
 181            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 182        )
 183
 184    def __rxor__(self, other: Any) -> None:
 185        raise NotImplementedError(
 186            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 187        )
 188
 189    def __invert__(self) -> NegatedExpression:
 190        return NegatedExpression(self)
 191
 192
 193class BaseExpression:
 194    """Base class for all query expressions."""
 195
 196    empty_result_set_value = NotImplemented
 197    # aggregate specific fields
 198    is_summary = False
 199    _output_field_resolved_to_none = False
 200    # Can the expression be used in a WHERE clause?
 201    filterable = True
 202    # Can the expression can be used as a source expression in Window?
 203    window_compatible = False
 204
 205    def __init__(self, output_field: Field | None = None):
 206        if output_field is not None:
 207            self.output_field = output_field
 208
 209    def __getstate__(self) -> dict[str, Any]:
 210        state = self.__dict__.copy()
 211        state.pop("convert_value", None)
 212        return state
 213
 214    def get_db_converters(
 215        self, connection: BaseDatabaseWrapper
 216    ) -> list[Callable[..., Any]]:
 217        converters = []
 218        if self.convert_value is not self._convert_value_noop:
 219            converters.append(self.convert_value)
 220        converters.extend(self.output_field.get_db_converters(connection))
 221        return converters
 222
 223    def get_source_expressions(self) -> list[Any]:
 224        return []
 225
 226    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 227        assert not exprs
 228
 229    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 230        return [
 231            arg
 232            if isinstance(arg, ResolvableExpression)
 233            else (F(arg) if isinstance(arg, str) else Value(arg))
 234            for arg in expressions
 235        ]
 236
 237    def as_sql(
 238        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 239    ) -> tuple[str, Sequence[Any]]:
 240        """
 241        Responsible for returning a (sql, [params]) tuple to be included
 242        in the current query.
 243
 244        Different backends can provide their own implementation, by
 245        providing an `as_{vendor}` method and patching the Expression:
 246
 247        ```
 248        def override_as_sql(self, compiler, connection):
 249            # custom logic
 250            return super().as_sql(compiler, connection)
 251        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
 252        ```
 253
 254        Arguments:
 255         * compiler: the query compiler responsible for generating the query.
 256           Must have a compile method, returning a (sql, [params]) tuple.
 257           Calling compiler(value) will return a quoted `value`.
 258
 259         * connection: the database connection used for the current query.
 260
 261        Return: (sql, params)
 262          Where `sql` is a string containing ordered sql parameters to be
 263          replaced with the elements of the list `params`.
 264        """
 265        raise NotImplementedError("Subclasses must implement as_sql()")
 266
 267    @cached_property
 268    def contains_aggregate(self) -> bool:
 269        return any(
 270            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 271        )
 272
 273    @cached_property
 274    def contains_over_clause(self) -> bool:
 275        return any(
 276            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 277        )
 278
 279    @cached_property
 280    def contains_column_references(self) -> bool:
 281        return any(
 282            expr and expr.contains_column_references
 283            for expr in self.get_source_expressions()
 284        )
 285
 286    def resolve_expression(
 287        self,
 288        query: Any = None,
 289        allow_joins: bool = True,
 290        reuse: Any = None,
 291        summarize: bool = False,
 292        for_save: bool = False,
 293    ) -> Self:
 294        """
 295        Provide the chance to do any preprocessing or validation before being
 296        added to the query.
 297
 298        Arguments:
 299         * query: the backend query implementation
 300         * allow_joins: boolean allowing or denying use of joins
 301           in this query
 302         * reuse: a set of reusable joins for multijoins
 303         * summarize: a terminal aggregate clause
 304         * for_save: whether this expression about to be used in a save or update
 305
 306        Return: an Expression to be added to the query.
 307        """
 308        c = self.copy()
 309        c.is_summary = summarize
 310        c.set_source_expressions(
 311            [
 312                expr.resolve_expression(query, allow_joins, reuse, summarize)
 313                if expr
 314                else None
 315                for expr in c.get_source_expressions()
 316            ]
 317        )
 318        return c
 319
 320    @property
 321    def conditional(self) -> bool:
 322        output_field = getattr(self, "output_field", None)
 323        return isinstance(output_field, fields.BooleanField)
 324
 325    @property
 326    def field(self) -> Field:
 327        return self.output_field
 328
 329    @cached_property
 330    def output_field(self) -> Field:
 331        """Return the output type of this expressions."""
 332        output_field = self._resolve_output_field()
 333        if output_field is None:
 334            self._output_field_resolved_to_none = True
 335            raise FieldError("Cannot resolve expression type, unknown output_field")
 336        return output_field
 337
 338    @cached_property
 339    def _output_field_or_none(self) -> Field | None:
 340        """
 341        Return the output field of this expression, or None if
 342        _resolve_output_field() didn't return an output type.
 343        """
 344        try:
 345            return self.output_field
 346        except FieldError:
 347            if not self._output_field_resolved_to_none:
 348                raise
 349            return None
 350
 351    def _resolve_output_field(self) -> Field | None:
 352        """
 353        Attempt to infer the output type of the expression.
 354
 355        As a guess, if the output fields of all source fields match then simply
 356        infer the same type here.
 357
 358        If a source's output field resolves to None, exclude it from this check.
 359        If all sources are None, then an error is raised higher up the stack in
 360        the output_field property.
 361        """
 362        # This guess is mostly a bad idea, but there is quite a lot of code
 363        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 364        # deprecation path to fix it.
 365        sources_iter = (
 366            source for source in self.get_source_fields() if source is not None
 367        )
 368        for output_field in sources_iter:
 369            for source in sources_iter:
 370                if not isinstance(output_field, source.__class__):
 371                    raise FieldError(
 372                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 373                        "set output_field."
 374                    )
 375            return output_field
 376        return None
 377
 378    @staticmethod
 379    def _convert_value_noop(
 380        value: Any, expression: Any, connection: BaseDatabaseWrapper
 381    ) -> Any:
 382        return value
 383
 384    @cached_property
 385    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 386        """
 387        Expressions provide their own converters because users have the option
 388        of manually specifying the output_field which may be a different type
 389        from the one the database returns.
 390        """
 391        field = self.output_field
 392        internal_type = field.get_internal_type()
 393        if internal_type == "FloatField":
 394            return (
 395                lambda value, expression, connection: None
 396                if value is None
 397                else float(value)
 398            )
 399        elif internal_type.endswith("IntegerField"):
 400            return (
 401                lambda value, expression, connection: None
 402                if value is None
 403                else int(value)
 404            )
 405        elif internal_type == "DecimalField":
 406            return (
 407                lambda value, expression, connection: None
 408                if value is None
 409                else Decimal(value)
 410            )
 411        return self._convert_value_noop
 412
 413    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 414        return self.output_field.get_lookup(lookup)
 415
 416    def get_transform(self, name: str) -> type[Transform] | None:
 417        return self.output_field.get_transform(name)
 418
 419    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
 420        clone = self.copy()
 421        clone.set_source_expressions(
 422            [
 423                e.relabeled_clone(change_map) if e is not None else None
 424                for e in self.get_source_expressions()
 425            ]
 426        )
 427        return clone
 428
 429    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
 430        if replacement := replacements.get(self):
 431            return replacement
 432        clone = self.copy()
 433        source_expressions = clone.get_source_expressions()
 434        clone.set_source_expressions(
 435            [
 436                expr.replace_expressions(replacements) if expr else None
 437                for expr in source_expressions
 438            ]
 439        )
 440        return clone
 441
 442    def get_refs(self) -> set[str]:
 443        refs = set()
 444        for expr in self.get_source_expressions():
 445            refs |= expr.get_refs()
 446        return refs
 447
 448    def copy(self) -> Self:
 449        return copy.copy(self)
 450
 451    def prefix_references(self, prefix: str) -> Self:
 452        clone = self.copy()
 453        clone.set_source_expressions(
 454            [
 455                F(f"{prefix}{expr.name}")
 456                if isinstance(expr, F)
 457                else expr.prefix_references(prefix)
 458                for expr in self.get_source_expressions()
 459            ]
 460        )
 461        return clone
 462
 463    def get_group_by_cols(self) -> list[BaseExpression]:
 464        if not self.contains_aggregate:
 465            return [self]
 466        cols = []
 467        for source in self.get_source_expressions():
 468            cols.extend(source.get_group_by_cols())
 469        return cols
 470
 471    def get_source_fields(self) -> list[Field | None]:
 472        """Return the underlying field types used by this aggregate."""
 473        return [e._output_field_or_none for e in self.get_source_expressions()]
 474
 475    def asc(self, **kwargs: Any) -> OrderBy:
 476        return OrderBy(self, **kwargs)
 477
 478    def desc(self, **kwargs: Any) -> OrderBy:
 479        return OrderBy(self, descending=True, **kwargs)
 480
 481    def reverse_ordering(self) -> Self:
 482        return self
 483
 484    def flatten(self) -> Iterable[Any]:
 485        """
 486        Recursively yield this expression and all subexpressions, in
 487        depth-first order.
 488        """
 489        yield self
 490        for expr in self.get_source_expressions():
 491            if expr:
 492                if hasattr(expr, "flatten"):
 493                    yield from expr.flatten()
 494                else:
 495                    yield expr
 496
 497    def select_format(
 498        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 499    ) -> tuple[str, Sequence[Any]]:
 500        """
 501        Custom format for select clauses. For example, EXISTS expressions need
 502        to be wrapped in CASE WHEN on Oracle.
 503        """
 504        if output_field := getattr(self, "output_field", None):
 505            if select_format := getattr(output_field, "select_format", None):
 506                return select_format(compiler, sql, params)
 507        return sql, params
 508
 509
 510@deconstructible
 511class Expression(BaseExpression, Combinable):
 512    """An expression that can be combined with other expressions."""
 513
 514    # Set by @deconstructible decorator in __new__
 515    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]
 516
 517    @cached_property
 518    def identity(self) -> tuple[Any, ...]:
 519        constructor_signature = inspect.signature(self.__init__)
 520        args, kwargs = self._constructor_args
 521        signature = constructor_signature.bind_partial(*args, **kwargs)
 522        signature.apply_defaults()
 523        arguments = signature.arguments.items()
 524        identity: list[Any] = [self.__class__]
 525        for arg, value in arguments:
 526            if isinstance(value, fields.Field):
 527                if value.name and value.model:
 528                    value = (value.model.model_options.label, value.name)
 529                else:
 530                    value = type(value)
 531            else:
 532                value = make_hashable(value)
 533            identity.append((arg, value))
 534        return tuple(identity)
 535
 536    def __eq__(self, other: object) -> bool:
 537        if not isinstance(other, Expression):
 538            return NotImplemented
 539        return other.identity == self.identity
 540
 541    def __hash__(self) -> int:
 542        return hash(self.identity)
 543
 544
 545# Type inference for CombinedExpression.output_field.
 546# Missing items will result in FieldError, by design.
 547#
 548# The current approach for NULL is based on lowest common denominator behavior
 549# i.e. if one of the supported databases is raising an error (rather than
 550# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 551
 552_connector_combinations = [
 553    # Numeric operations - operands of same type.
 554    {
 555        connector: [
 556            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 557            (fields.FloatField, fields.FloatField, fields.FloatField),
 558            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 559        ]
 560        for connector in (
 561            Combinable.ADD,
 562            Combinable.SUB,
 563            Combinable.MUL,
 564            # Behavior for DIV with integer arguments follows Postgres/SQLite,
 565            # not MySQL/Oracle.
 566            Combinable.DIV,
 567            Combinable.MOD,
 568            Combinable.POW,
 569        )
 570    },
 571    # Numeric operations - operands of different type.
 572    {
 573        connector: [
 574            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 575            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 576            (fields.IntegerField, fields.FloatField, fields.FloatField),
 577            (fields.FloatField, fields.IntegerField, fields.FloatField),
 578        ]
 579        for connector in (
 580            Combinable.ADD,
 581            Combinable.SUB,
 582            Combinable.MUL,
 583            Combinable.DIV,
 584            Combinable.MOD,
 585        )
 586    },
 587    # Bitwise operators.
 588    {
 589        connector: [
 590            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 591        ]
 592        for connector in (
 593            Combinable.BITAND,
 594            Combinable.BITOR,
 595            Combinable.BITLEFTSHIFT,
 596            Combinable.BITRIGHTSHIFT,
 597            Combinable.BITXOR,
 598        )
 599    },
 600    # Numeric with NULL.
 601    {
 602        connector: [
 603            (field_type, NoneType, field_type),
 604            (NoneType, field_type, field_type),
 605        ]
 606        for connector in (
 607            Combinable.ADD,
 608            Combinable.SUB,
 609            Combinable.MUL,
 610            Combinable.DIV,
 611            Combinable.MOD,
 612            Combinable.POW,
 613        )
 614        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 615    },
 616    # Date/DateTimeField/DurationField/TimeField.
 617    {
 618        Combinable.ADD: [
 619            # Date/DateTimeField.
 620            (fields.DateField, fields.DurationField, fields.DateTimeField),
 621            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 622            (fields.DurationField, fields.DateField, fields.DateTimeField),
 623            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 624            # DurationField.
 625            (fields.DurationField, fields.DurationField, fields.DurationField),
 626            # TimeField.
 627            (fields.TimeField, fields.DurationField, fields.TimeField),
 628            (fields.DurationField, fields.TimeField, fields.TimeField),
 629        ],
 630    },
 631    {
 632        Combinable.SUB: [
 633            # Date/DateTimeField.
 634            (fields.DateField, fields.DurationField, fields.DateTimeField),
 635            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 636            (fields.DateField, fields.DateField, fields.DurationField),
 637            (fields.DateField, fields.DateTimeField, fields.DurationField),
 638            (fields.DateTimeField, fields.DateField, fields.DurationField),
 639            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 640            # DurationField.
 641            (fields.DurationField, fields.DurationField, fields.DurationField),
 642            # TimeField.
 643            (fields.TimeField, fields.DurationField, fields.TimeField),
 644            (fields.TimeField, fields.TimeField, fields.DurationField),
 645        ],
 646    },
 647]
 648
 649_connector_combinators = defaultdict(list)
 650
 651
 652def register_combinable_fields(
 653    lhs: type[Field] | type[None],
 654    connector: str,
 655    rhs: type[Field] | type[None],
 656    result: type[Field],
 657) -> None:
 658    """
 659    Register combinable types:
 660        lhs <connector> rhs -> result
 661    e.g.
 662        register_combinable_fields(
 663            IntegerField, Combinable.ADD, FloatField, FloatField
 664        )
 665    """
 666    _connector_combinators[connector].append((lhs, rhs, result))
 667
 668
 669for d in _connector_combinations:
 670    for connector, field_types in d.items():
 671        for lhs, rhs, result in field_types:
 672            register_combinable_fields(lhs, connector, rhs, result)
 673
 674
 675@functools.lru_cache(maxsize=128)
 676def _resolve_combined_type(
 677    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 678) -> type[Field] | None:
 679    combinators = _connector_combinators.get(connector, ())
 680    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 681        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 682            rhs_type, combinator_rhs_type
 683        ):
 684            return combined_type
 685    return None
 686
 687
 688class SQLiteNumericMixin(Expression):
 689    """
 690    Some expressions with output_field=DecimalField() must be cast to
 691    numeric to be properly filtered.
 692    """
 693
 694    def as_sqlite(
 695        self,
 696        compiler: SQLCompiler,
 697        connection: BaseDatabaseWrapper,
 698        **extra_context: Any,
 699    ) -> tuple[str, Sequence[Any]]:
 700        sql, params = self.as_sql(compiler, connection, **extra_context)
 701        try:
 702            if self.output_field.get_internal_type() == "DecimalField":
 703                sql = f"CAST({sql} AS NUMERIC)"
 704        except FieldError:
 705            pass
 706        return sql, params
 707
 708
 709class CombinedExpression(SQLiteNumericMixin, Expression):
 710    def __init__(
 711        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 712    ):
 713        super().__init__(output_field=output_field)
 714        self.connector = connector
 715        self.lhs = lhs
 716        self.rhs = rhs
 717
 718    def __repr__(self) -> str:
 719        return f"<{self.__class__.__name__}: {self}>"
 720
 721    def __str__(self) -> str:
 722        return f"{self.lhs} {self.connector} {self.rhs}"
 723
 724    def get_source_expressions(self) -> list[Any]:
 725        return [self.lhs, self.rhs]
 726
 727    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 728        self.lhs, self.rhs = exprs
 729
 730    def _resolve_output_field(self) -> Field | None:
 731        # We avoid using super() here for reasons given in
 732        # Expression._resolve_output_field()
 733        combined_type = _resolve_combined_type(
 734            self.connector,
 735            type(self.lhs._output_field_or_none),
 736            type(self.rhs._output_field_or_none),
 737        )
 738        if combined_type is None:
 739            raise FieldError(
 740                f"Cannot infer type of {self.connector!r} expression involving these "
 741                f"types: {self.lhs.output_field.__class__.__name__}, "
 742                f"{self.rhs.output_field.__class__.__name__}. You must set "
 743                f"output_field."
 744            )
 745        return combined_type()
 746
 747    def as_sql(
 748        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 749    ) -> tuple[str, list[Any]]:
 750        expressions = []
 751        expression_params = []
 752        sql, params = compiler.compile(self.lhs)
 753        expressions.append(sql)
 754        expression_params.extend(params)
 755        sql, params = compiler.compile(self.rhs)
 756        expressions.append(sql)
 757        expression_params.extend(params)
 758        # order of precedence
 759        expression_wrapper = "(%s)"
 760        sql = connection.ops.combine_expression(self.connector, expressions)
 761        return expression_wrapper % sql, expression_params
 762
 763    def resolve_expression(
 764        self,
 765        query: Any = None,
 766        allow_joins: bool = True,
 767        reuse: Any = None,
 768        summarize: bool = False,
 769        for_save: bool = False,
 770    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
 771        lhs = self.lhs.resolve_expression(
 772            query, allow_joins, reuse, summarize, for_save
 773        )
 774        rhs = self.rhs.resolve_expression(
 775            query, allow_joins, reuse, summarize, for_save
 776        )
 777        if not isinstance(self, DurationExpression | TemporalSubtraction):
 778            try:
 779                lhs_type = lhs.output_field.get_internal_type()
 780            except (AttributeError, FieldError):
 781                lhs_type = None
 782            try:
 783                rhs_type = rhs.output_field.get_internal_type()
 784            except (AttributeError, FieldError):
 785                rhs_type = None
 786            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
 787                return DurationExpression(
 788                    self.lhs, self.connector, self.rhs
 789                ).resolve_expression(
 790                    query,
 791                    allow_joins,
 792                    reuse,
 793                    summarize,
 794                    for_save,
 795                )
 796            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 797            if (
 798                self.connector == self.SUB
 799                and lhs_type in datetime_fields
 800                and lhs_type == rhs_type
 801            ):
 802                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 803                    query,
 804                    allow_joins,
 805                    reuse,
 806                    summarize,
 807                    for_save,
 808                )
 809        c = self.copy()
 810        c.is_summary = summarize
 811        c.lhs = lhs
 812        c.rhs = rhs
 813        return c
 814
 815
 816class DurationExpression(CombinedExpression):
 817    def compile(
 818        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 819    ) -> tuple[str, Sequence[Any]]:
 820        try:
 821            output = side.output_field
 822        except FieldError:
 823            pass
 824        else:
 825            if output.get_internal_type() == "DurationField":
 826                sql, params = compiler.compile(side)
 827                return connection.ops.format_for_duration_arithmetic(sql), params
 828        return compiler.compile(side)
 829
 830    def as_sql(
 831        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 832    ) -> tuple[str, list[Any]]:
 833        if connection.features.has_native_duration_field:
 834            return super().as_sql(compiler, connection)
 835        connection.ops.check_expression_support(self)
 836        expressions = []
 837        expression_params = []
 838        sql, params = self.compile(self.lhs, compiler, connection)
 839        expressions.append(sql)
 840        expression_params.extend(params)
 841        sql, params = self.compile(self.rhs, compiler, connection)
 842        expressions.append(sql)
 843        expression_params.extend(params)
 844        # order of precedence
 845        expression_wrapper = "(%s)"
 846        sql = connection.ops.combine_duration_expression(self.connector, expressions)
 847        return expression_wrapper % sql, expression_params
 848
 849    def as_sqlite(
 850        self,
 851        compiler: SQLCompiler,
 852        connection: BaseDatabaseWrapper,
 853        **extra_context: Any,
 854    ) -> tuple[str, Sequence[Any]]:
 855        sql, params = self.as_sql(compiler, connection, **extra_context)
 856        if self.connector in {Combinable.MUL, Combinable.DIV}:
 857            try:
 858                lhs_type = self.lhs.output_field.get_internal_type()
 859                rhs_type = self.rhs.output_field.get_internal_type()
 860            except (AttributeError, FieldError):
 861                pass
 862            else:
 863                allowed_fields = {
 864                    "DecimalField",
 865                    "DurationField",
 866                    "FloatField",
 867                    "IntegerField",
 868                }
 869                if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
 870                    raise DatabaseError(
 871                        f"Invalid arguments for operator {self.connector}."
 872                    )
 873        return sql, params
 874
 875
 876class TemporalSubtraction(CombinedExpression):
 877    output_field = fields.DurationField()
 878
 879    def __init__(self, lhs: Any, rhs: Any):
 880        super().__init__(lhs, self.SUB, rhs)
 881
 882    def as_sql(
 883        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 884    ) -> tuple[str, Sequence[Any]]:
 885        connection.ops.check_expression_support(self)
 886        lhs = compiler.compile(self.lhs)
 887        rhs = compiler.compile(self.rhs)
 888        return connection.ops.subtract_temporals(
 889            self.lhs.output_field.get_internal_type(), lhs, rhs
 890        )
 891
 892
 893@deconstructible(path="plain.models.F")
 894class F(Combinable):
 895    """An object capable of resolving references to existing query objects."""
 896
 897    def __init__(self, name: str):
 898        """
 899        Arguments:
 900         * name: the name of the field this expression references
 901        """
 902        self.name = name
 903
 904    def __repr__(self) -> str:
 905        return f"{self.__class__.__name__}({self.name})"
 906
 907    def resolve_expression(
 908        self,
 909        query: Any = None,
 910        allow_joins: bool = True,
 911        reuse: Any = None,
 912        summarize: bool = False,
 913        for_save: bool = False,
 914    ) -> Any:
 915        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 916
 917    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 918        return replacements.get(self, self)
 919
 920    def asc(self, **kwargs: Any) -> OrderBy:
 921        return OrderBy(self, **kwargs)
 922
 923    def desc(self, **kwargs: Any) -> OrderBy:
 924        return OrderBy(self, descending=True, **kwargs)
 925
 926    def __eq__(self, other: object) -> bool:
 927        if not isinstance(other, F):
 928            return NotImplemented
 929        return self.__class__ == other.__class__ and self.name == other.name
 930
 931    def __hash__(self) -> int:
 932        return hash(self.name)
 933
 934    def copy(self) -> Self:
 935        return copy.copy(self)
 936
 937
 938class ResolvedOuterRef(F):
 939    """
 940    An object that contains a reference to an outer query.
 941
 942    In this case, the reference to the outer query has been resolved because
 943    the inner query has been used as a subquery.
 944    """
 945
 946    contains_aggregate = False
 947    contains_over_clause = False
 948
 949    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 950        raise ValueError(
 951            "This queryset contains a reference to an outer query and may "
 952            "only be used in a subquery."
 953        )
 954
 955    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 956        col = super().resolve_expression(*args, **kwargs)
 957        if col.contains_over_clause:
 958            raise NotSupportedError(
 959                f"Referencing outer query window expression is not supported: "
 960                f"{self.name}."
 961            )
 962        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 963        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 964        # into account only many-to-many and one-to-many relationships.
 965        col.possibly_multivalued = LOOKUP_SEP in self.name
 966        return col
 967
 968    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 969        return self
 970
 971    def get_group_by_cols(self) -> list[Any]:
 972        return []
 973
 974
 975class OuterRef(F):
 976    contains_aggregate = False
 977
 978    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
 979        if isinstance(self.name, self.__class__):
 980            return self.name
 981        return ResolvedOuterRef(self.name)
 982
 983    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
 984        return self
 985
 986
 987@deconstructible(path="plain.models.Func")
 988class Func(SQLiteNumericMixin, Expression):
 989    """An SQL function call."""
 990
 991    function = None
 992    template = "%(function)s(%(expressions)s)"
 993    arg_joiner = ", "
 994    arity = None  # The number of arguments the function accepts.
 995
 996    def __init__(
 997        self, *expressions: Any, output_field: Field | None = None, **extra: Any
 998    ):
 999        if self.arity is not None and len(expressions) != self.arity:
1000            raise TypeError(
1001                "'{}' takes exactly {} {} ({} given)".format(
1002                    self.__class__.__name__,
1003                    self.arity,
1004                    "argument" if self.arity == 1 else "arguments",
1005                    len(expressions),
1006                )
1007            )
1008        super().__init__(output_field=output_field)
1009        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
1010        self.extra = extra
1011
1012    def __repr__(self) -> str:
1013        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
1014        extra = {**self.extra, **self._get_repr_options()}
1015        if extra:
1016            extra = ", ".join(
1017                str(key) + "=" + str(val) for key, val in sorted(extra.items())
1018            )
1019            return f"{self.__class__.__name__}({args}, {extra})"
1020        return f"{self.__class__.__name__}({args})"
1021
1022    def _get_repr_options(self) -> dict[str, Any]:
1023        """Return a dict of extra __init__() options to include in the repr."""
1024        return {}
1025
1026    def get_source_expressions(self) -> list[Any]:
1027        return self.source_expressions
1028
1029    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1030        self.source_expressions = list(exprs)
1031
1032    def resolve_expression(
1033        self,
1034        query: Any = None,
1035        allow_joins: bool = True,
1036        reuse: Any = None,
1037        summarize: bool = False,
1038        for_save: bool = False,
1039    ) -> Self:
1040        c = self.copy()
1041        c.is_summary = summarize
1042        for pos, arg in enumerate(c.source_expressions):
1043            c.source_expressions[pos] = arg.resolve_expression(
1044                query, allow_joins, reuse, summarize, for_save
1045            )
1046        return c
1047
1048    def as_sql(
1049        self,
1050        compiler: SQLCompiler,
1051        connection: BaseDatabaseWrapper,
1052        function: str | None = None,
1053        template: str | None = None,
1054        arg_joiner: str | None = None,
1055        **extra_context: Any,
1056    ) -> tuple[str, list[Any]]:
1057        connection.ops.check_expression_support(self)
1058        sql_parts = []
1059        params = []
1060        for arg in self.source_expressions:
1061            try:
1062                arg_sql, arg_params = compiler.compile(arg)
1063            except EmptyResultSet:
1064                empty_result_set_value = getattr(
1065                    arg, "empty_result_set_value", NotImplemented
1066                )
1067                if empty_result_set_value is NotImplemented:
1068                    raise
1069                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
1070            except FullResultSet:
1071                arg_sql, arg_params = compiler.compile(Value(True))
1072            sql_parts.append(arg_sql)
1073            params.extend(arg_params)
1074        data = {**self.extra, **extra_context}
1075        # Use the first supplied value in this order: the parameter to this
1076        # method, a value supplied in __init__()'s **extra (the value in
1077        # `data`), or the value defined on the class.
1078        if function is not None:
1079            data["function"] = function
1080        else:
1081            data.setdefault("function", self.function)
1082        template = template or data.get("template", self.template)
1083        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
1084        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
1085        return template % data, params
1086
1087    def copy(self) -> Self:
1088        clone = super().copy()
1089        clone.source_expressions = self.source_expressions[:]
1090        clone.extra = self.extra.copy()
1091        return cast(Self, clone)
1092
1093
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            # Values resolved for saving use the field's save-time preparation.
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [prepared]
        if prepared is None:
            # Emit a literal SQL NULL rather than binding None: cx_Oracle does
            # not always convert None to the appropriate NULL type (like in
            # case expressions using numbers).
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remember the save context so as_sql() picks the matching preparation.
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        """Infer an output field from the Python type of the wrapped value."""
        # Checked in order: bool precedes int (bool subclasses int) and
        # datetime.datetime precedes datetime.date (it subclasses date).
        type_to_field = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in type_to_field:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1180
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters, rendered inside parentheses.

    If no ``output_field`` is given, a generic ``fields.Field()`` is used.
    The fragment is used as its own GROUP BY column.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[RawSQL]:
        return [self]
1200
1201
class Star(Expression):
    """An expression that renders as the bare SQL ``*`` with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1210
1211
class Col(Expression):
    """
    A column reference, optionally qualified with a table alias.

    ``target`` supplies the column name (``target.column``) and, unless
    ``output_field`` is given, also serves as the output field.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quote = compiler.quote_name_unless_alias
        column = self.target.column
        if self.alias:
            return f"{quote(self.alias)}.{quote(column)}", []
        return quote(column), []

    def relabeled_clone(self, relabels: dict[str, str]) -> Col:
        # Unqualified columns have no alias to rename.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[Col]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the output field's converters, followed by the target's own
        converters when the target differs from the output field.
        """
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1255
1256
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.

    It compiles to the quoted alias name and keeps the aliased expression as
    its single source expression.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces it.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # Nothing to resolve: the referenced `source` was resolved when the
        # alias was created; this is only a pointer to its name.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, relabels: dict[str, str]) -> Ref:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[Ref]:
        return [self]
1302
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.

    It renders only the joined arguments, with no surrounding function call,
    and requires at least one expression.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # Use plain as_sql(): casting a list of expressions to numeric is
        # unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
1330
1331
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a list of ordering expressions.

    A string argument starting with ``-`` becomes a descending ``OrderBy`` on
    the field named by the rest of the string. Every other argument is passed
    to ``Func`` unchanged. An empty list renders as an empty string with no
    parameters.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # str.startswith() (unlike expr[0]) does not raise IndexError on "".
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, tuple[Any, ...]]:
        # No ordering expressions: emit no clause at all.
        if not self.source_expressions:
            return "", ()
        sql, params = super().as_sql(*args, **kwargs)
        return sql, tuple(params)

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering expression."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1357
1358
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.

    It compiles to exactly the SQL of the wrapped expression.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression carrying this wrapper's
        # output field; the wrapped expression itself is left untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"
1391        return f"{self.__class__.__name__}({self.expression})"
1392
1393
class NegatedExpression(ExpressionWrapper):
    """
    The logical negation of a conditional expression.

    The output field is always a BooleanField. Inverting it again returns a
    copy of the original wrapped expression.
    """

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped condition compiled to an empty result set, so its
            # negation is rendered as always-true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        backend_ops = compiler.connection.ops
        if backend_ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        negated = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Only resolved expressions flagged as conditional can be negated.
        if getattr(negated.expression, "conditional", False):
            return negated
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list. Expressions the backend cannot use directly in WHERE
        # were already CASE-wrapped by as_sql(), so they are skipped here to
        # avoid double wrapping.
        backend = compiler.connection
        needs_case_wrapper = (
            not backend.features.supports_boolean_expr_in_select_clause
            and backend.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1451        return sql, params
1452
1453
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch, for use inside ``Case()``.

    The condition may be a Q object, a conditional expression, or keyword
    lookups (combined into a Q, together with a conditional ``condition`` if
    one is also given). ``then`` is the branch's result.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        # Leftover lookups that could not be folded into a Q condition.
        pending: dict[str, Any] | None = lookups or None
        if pending:
            if condition is None:
                condition, pending = Q(**pending), None
            elif getattr(condition, "conditional", False):
                condition, pending = Q(condition, **pending), None
        is_valid = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not pending
        )
        if not is_valid:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # Only the result determines the branch's output field.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        resolved = self.copy()
        resolved.is_summary = summarize
        if isinstance(resolved.condition, ResolvableExpression):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's for_save.
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        # After resolve_expression, condition is WhereNode | resolved Expression
        # (both SQLCompilable). The condition is compiled before the result.
        condition_sql, condition_params = compiler.compile(self.condition)
        extra_context["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # A When is never grouped by as a whole; it contributes the GROUP BY
        # columns of its condition and result instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1548
1549
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be ``When`` objects. ``default`` supplies the
    ELSE branch. Extra keyword arguments are stored in ``self.extra`` and
    merged into the template context by ``as_sql()``.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Raises TypeError if any positional argument is not a ``When``.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        # Normalized with _parse_expressions(), the same helper When uses for
        # its `then` result.
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every When branch and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        c = super().copy()
        # Give the copy its own cases list, because resolve_expression()
        # replaces entries in place.
        c.cases = c.cases[:]
        return cast(Self, c)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the CASE expression.

        Branches whose condition raises EmptyResultSet are dropped. The first
        branch whose condition raises FullResultSet ends the scan, and its
        result becomes the ELSE value. If no branch remains, only the default
        is emitted.
        """
        connection.ops.check_expression_support(self)
        # With no WHEN branches the CASE reduces to its default.
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; leave it out.
                continue
            except FullResultSet:
                # This branch always matches, so later branches are unreachable
                # and its result replaces the default.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Default params follow all branch params, matching the template order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Pass the SQL through the backend's unification_cast_sql()
            # template for the output field.
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without branches only the default is emitted, so group by its columns.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1662
1663
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery usable as an expression. OuterRef() references inside
    it are resolved against the outer query once it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Imported lazily to avoid a circular import.
        from plain.models.sql.query import Query

        # Accept either a sql.Query directly or a QuerySet (via its
        # ``sql_query``). Always keep a private clone marked as a subquery.
        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # Delegate to the inner query's output field.
        return self.query.output_field

    def copy(self) -> Self:
        duplicate = super().copy()
        # Give the copy its own inner query so either side can be modified
        # independently.
        duplicate.query = duplicate.query.clone()
        return cast(Self, duplicate)

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render the inner query into ``template`` and return its params."""
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        inner_sql, params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the inner SQL (presumably the
        # parentheses added because the query is flagged as a subquery) so the
        # template controls the wrapping.
        context["subquery"] = inner_sql[1:-1]
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1734
1735
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, typed as a boolean field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned inner query for the result of its ``exists()``.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Backends that can't put a bare boolean expression in the SELECT or
        # GROUP BY list (e.g. Oracle) get it wrapped as CASE WHEN ... 1/0.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1754
1755
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    An expression paired with an ORDER BY direction and optional NULL placement.

    ``nulls_first`` / ``nulls_last`` must each be ``True`` or ``None`` and may
    not both be set; violations raise ``ValueError``.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        # ResolvableExpression is a runtime-checkable Protocol, so anything
        # exposing resolve_expression() is accepted.
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to ``<expression> ASC|DESC`` plus any requested NULL ordering.

        Backends with ``supports_order_by_nulls_modifier`` get ``NULLS LAST`` /
        ``NULLS FIRST`` appended. Other backends emulate it by prefixing an
        ``IS NULL`` / ``IS NOT NULL`` sort key. The prefix is skipped in the
        combinations where ``order_by_nulls_first`` indicates the backend's
        default ordering already places NULLs as requested. The prefix repeats
        the expression, so its parameters are repeated to match.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Build a new sequence instead of using ``*=``. If compile() returned
        # a list that the compiled expression keeps (e.g. stored params
        # returned as-is), in-place repetition would grow that list on every
        # compilation.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction in place, swapping NULLS FIRST/LAST; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:
        """Set ascending order in place."""
        self.descending = False

    def desc(self) -> None:
        """Set descending order in place."""
        self.descending = True
1841
1842
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call rendered as ``<expression> OVER (<window>)``.

    ``partition_by`` accepts a single value or a list/tuple and is wrapped in an
    ExpressionList. ``order_by`` accepts a string, an expression, or a
    list/tuple of them and is wrapped in an OrderByList. ``frame`` is an
    optional frame clause (presumably a WindowFrame such as RowRange or
    ValueRange). A ValueError is raised for a non-window-compatible expression
    or an unsupported ``order_by`` type.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the window call.

        Raises NotSupportedError when the backend lacks OVER support. The
        PARTITION BY, ORDER BY and frame parts are emitted in that order, and
        their parameters follow the expression's.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """SQLite compilation for DecimalField windows.

        For a DecimalField window, the inner expression is compiled as a
        FloatField on a clone. The clone is handed to
        ``SQLiteNumericMixin.as_sqlite``, so the numeric handling happens
        outside the OVER clause.
        """
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            clone = self.copy()
            source_expressions = clone.get_source_expressions()
            # self.copy() appears to be shallow (other copy() overrides in this
            # module re-copy their containers after calling it), so the source
            # expression is still shared with ``self``. Copy it too before
            # overriding its output field, so SQLite compilation does not
            # mutate this expression tree.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            clone.set_source_expressions(source_expressions)
            return super(Window, clone).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1968
1969
class WindowFrame(Expression, ABC):
    """
    Base class for the frame clause of a window expression.

    Subclasses supply ``frame_type`` and ``window_frame_start_end``; this class
    holds the bounds and renders the clause. Both bounds are optional. An
    omitted end renders as UNBOUNDED FOLLOWING, the last row of the frame.
    Validation of the bounds is by no means intended to be complete.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render the frame clause; bounds are inlined, so no params are returned."""
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {"frame_type": self.frame_type, "start": start_sql, "end": end_sql}
        return self.template % context, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        return []

    def __str__(self) -> str:
        """Describe the frame using the default connection's keywords.

        A ``None`` bound, or a bound on the "wrong" side of the current row,
        is shown as UNBOUNDED PRECEDING / UNBOUNDED FOLLOWING.
        """
        ops = db_connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = f"{abs(lower)} {ops.PRECEDING}"
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = f"{upper} {ops.FOLLOWING}"
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]: ...
2039
2040
class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # Bound rendering is delegated to the backend's ROWS-specific hook.
        render_bounds = connection.ops.window_frame_rows_start_end
        return render_bounds(start, end)
2048
2049
class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # Bound rendering is delegated to the backend's RANGE-specific hook.
        render_bounds = connection.ops.window_frame_range_start_end
        return render_bounds(start, end)