Plain is headed towards 1.0! Subscribe for development updates →

   1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from abc import ABC, abstractmethod
   8from collections import defaultdict
   9from decimal import Decimal
  10from functools import cached_property
  11from types import NoneType
  12from typing import TYPE_CHECKING, Any, Protocol, Self, cast, runtime_checkable
  13from uuid import UUID
  14
  15from plain.models import fields
  16from plain.models.constants import LOOKUP_SEP
  17from plain.models.db import (
  18    DatabaseError,
  19    NotSupportedError,
  20    db_connection,
  21)
  22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
  23from plain.models.query_utils import Q
  24from plain.utils.deconstruct import deconstructible
  25from plain.utils.hashable import make_hashable
  26
  27if TYPE_CHECKING:
  28    from collections.abc import Callable, Iterable, Sequence
  29
  30    from plain.models.backends.base.base import BaseDatabaseWrapper
  31    from plain.models.fields import Field
  32    from plain.models.lookups import Lookup, Transform
  33    from plain.models.query import QuerySet
  34    from plain.models.sql.compiler import SQLCompilable, SQLCompiler
  35    from plain.models.sql.query import Query
  36
  37
  38@runtime_checkable
  39class ResolvableExpression(Protocol):
  40    """Protocol for expressions that can be resolved in query context."""
  41
  42    def resolve_expression(
  43        self,
  44        query: Any = None,
  45        allow_joins: bool = True,
  46        reuse: Any = None,
  47        summarize: bool = False,
  48        for_save: bool = False,
  49    ) -> Any: ...
  50
  51
  52@runtime_checkable
  53class ReplaceableExpression(Protocol):
  54    """Protocol for expressions that support expression replacement."""
  55
  56    def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
  57
  58
  59class Combinable:
  60    """
  61    Provide the ability to combine one or two objects with
  62    some connector. For example F('foo') + F('bar').
  63    """
  64
  65    # Arithmetic connectors
  66    ADD = "+"
  67    SUB = "-"
  68    MUL = "*"
  69    DIV = "/"
  70    POW = "^"
  71    # The following is a quoted % operator - it is quoted because it can be
  72    # used in strings that also have parameter substitution.
  73    MOD = "%%"
  74
  75    # Bitwise operators - note that these are generated by .bitand()
  76    # and .bitor(), the '&' and '|' are reserved for boolean operator
  77    # usage.
  78    BITAND = "&"
  79    BITOR = "|"
  80    BITLEFTSHIFT = "<<"
  81    BITRIGHTSHIFT = ">>"
  82    BITXOR = "#"
  83
  84    def _combine(
  85        self, other: Any, connector: str, reversed: bool
  86    ) -> CombinedExpression:
  87        if not isinstance(other, ResolvableExpression):
  88            # everything must be resolvable to an expression
  89            other = Value(other)
  90
  91        if reversed:
  92            return CombinedExpression(other, connector, self)
  93        return CombinedExpression(self, connector, other)
  94
  95    #############
  96    # OPERATORS #
  97    #############
  98
  99    def __neg__(self) -> CombinedExpression:
 100        return self._combine(-1, self.MUL, False)
 101
 102    def __add__(self, other: Any) -> CombinedExpression:
 103        return self._combine(other, self.ADD, False)
 104
 105    def __sub__(self, other: Any) -> CombinedExpression:
 106        return self._combine(other, self.SUB, False)
 107
 108    def __mul__(self, other: Any) -> CombinedExpression:
 109        return self._combine(other, self.MUL, False)
 110
 111    def __truediv__(self, other: Any) -> CombinedExpression:
 112        return self._combine(other, self.DIV, False)
 113
 114    def __mod__(self, other: Any) -> CombinedExpression:
 115        return self._combine(other, self.MOD, False)
 116
 117    def __pow__(self, other: Any) -> CombinedExpression:
 118        return self._combine(other, self.POW, False)
 119
 120    def __and__(self, other: Any) -> Q:
 121        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 122            return Q(self) & Q(other)
 123        raise NotImplementedError(
 124            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 125        )
 126
 127    def bitand(self, other: Any) -> CombinedExpression:
 128        return self._combine(other, self.BITAND, False)
 129
 130    def bitleftshift(self, other: Any) -> CombinedExpression:
 131        return self._combine(other, self.BITLEFTSHIFT, False)
 132
 133    def bitrightshift(self, other: Any) -> CombinedExpression:
 134        return self._combine(other, self.BITRIGHTSHIFT, False)
 135
 136    def __xor__(self, other: Any) -> Q:
 137        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 138            return Q(self) ^ Q(other)
 139        raise NotImplementedError(
 140            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 141        )
 142
 143    def bitxor(self, other: Any) -> CombinedExpression:
 144        return self._combine(other, self.BITXOR, False)
 145
 146    def __or__(self, other: Any) -> Q:
 147        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 148            return Q(self) | Q(other)
 149        raise NotImplementedError(
 150            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 151        )
 152
 153    def bitor(self, other: Any) -> CombinedExpression:
 154        return self._combine(other, self.BITOR, False)
 155
 156    def __radd__(self, other: Any) -> CombinedExpression:
 157        return self._combine(other, self.ADD, True)
 158
 159    def __rsub__(self, other: Any) -> CombinedExpression:
 160        return self._combine(other, self.SUB, True)
 161
 162    def __rmul__(self, other: Any) -> CombinedExpression:
 163        return self._combine(other, self.MUL, True)
 164
 165    def __rtruediv__(self, other: Any) -> CombinedExpression:
 166        return self._combine(other, self.DIV, True)
 167
 168    def __rmod__(self, other: Any) -> CombinedExpression:
 169        return self._combine(other, self.MOD, True)
 170
 171    def __rpow__(self, other: Any) -> CombinedExpression:
 172        return self._combine(other, self.POW, True)
 173
 174    def __rand__(self, other: Any) -> None:
 175        raise NotImplementedError(
 176            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 177        )
 178
 179    def __ror__(self, other: Any) -> None:
 180        raise NotImplementedError(
 181            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 182        )
 183
 184    def __rxor__(self, other: Any) -> None:
 185        raise NotImplementedError(
 186            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 187        )
 188
 189    def __invert__(self) -> NegatedExpression:
 190        return NegatedExpression(self)
 191
 192
 193class BaseExpression:
 194    """Base class for all query expressions."""
 195
 196    empty_result_set_value = NotImplemented
 197    # aggregate specific fields
 198    is_summary = False
 199    _output_field_resolved_to_none = False
 200    # Can the expression be used in a WHERE clause?
 201    filterable = True
 202    # Can the expression can be used as a source expression in Window?
 203    window_compatible = False
 204
 205    def __init__(self, output_field: Field | None = None):
 206        if output_field is not None:
 207            self.output_field = output_field
 208
 209    def __getstate__(self) -> dict[str, Any]:
 210        state = self.__dict__.copy()
 211        state.pop("convert_value", None)
 212        return state
 213
 214    def get_db_converters(
 215        self, connection: BaseDatabaseWrapper
 216    ) -> list[Callable[..., Any]]:
 217        converters = []
 218        if self.convert_value is not self._convert_value_noop:
 219            converters.append(self.convert_value)
 220        converters.extend(self.output_field.get_db_converters(connection))
 221        return converters
 222
 223    def get_source_expressions(self) -> list[Any]:
 224        return []
 225
 226    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 227        assert not exprs
 228
 229    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 230        return [
 231            arg
 232            if isinstance(arg, ResolvableExpression)
 233            else (F(arg) if isinstance(arg, str) else Value(arg))
 234            for arg in expressions
 235        ]
 236
 237    def as_sql(
 238        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 239    ) -> tuple[str, Sequence[Any]]:
 240        """
 241        Responsible for returning a (sql, [params]) tuple to be included
 242        in the current query.
 243
 244        Different backends can provide their own implementation, by
 245        providing an `as_{vendor}` method and patching the Expression:
 246
 247        ```
 248        def override_as_sql(self, compiler, connection):
 249            # custom logic
 250            return super().as_sql(compiler, connection)
 251        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
 252        ```
 253
 254        Arguments:
 255         * compiler: the query compiler responsible for generating the query.
 256           Must have a compile method, returning a (sql, [params]) tuple.
 257           Calling compiler(value) will return a quoted `value`.
 258
 259         * connection: the database connection used for the current query.
 260
 261        Return: (sql, params)
 262          Where `sql` is a string containing ordered sql parameters to be
 263          replaced with the elements of the list `params`.
 264        """
 265        raise NotImplementedError("Subclasses must implement as_sql()")
 266
 267    @cached_property
 268    def contains_aggregate(self) -> bool:
 269        return any(
 270            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 271        )
 272
 273    @cached_property
 274    def contains_over_clause(self) -> bool:
 275        return any(
 276            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 277        )
 278
 279    @cached_property
 280    def contains_column_references(self) -> bool:
 281        return any(
 282            expr and expr.contains_column_references
 283            for expr in self.get_source_expressions()
 284        )
 285
 286    def resolve_expression(
 287        self,
 288        query: Any = None,
 289        allow_joins: bool = True,
 290        reuse: Any = None,
 291        summarize: bool = False,
 292        for_save: bool = False,
 293    ) -> Self:
 294        """
 295        Provide the chance to do any preprocessing or validation before being
 296        added to the query.
 297
 298        Arguments:
 299         * query: the backend query implementation
 300         * allow_joins: boolean allowing or denying use of joins
 301           in this query
 302         * reuse: a set of reusable joins for multijoins
 303         * summarize: a terminal aggregate clause
 304         * for_save: whether this expression about to be used in a save or update
 305
 306        Return: an Expression to be added to the query.
 307        """
 308        c = self.copy()
 309        c.is_summary = summarize
 310        c.set_source_expressions(
 311            [
 312                expr.resolve_expression(query, allow_joins, reuse, summarize)
 313                if expr
 314                else None
 315                for expr in c.get_source_expressions()
 316            ]
 317        )
 318        return c
 319
 320    @property
 321    def conditional(self) -> bool:
 322        output_field = getattr(self, "output_field", None)
 323        return isinstance(output_field, fields.BooleanField)
 324
 325    @property
 326    def field(self) -> Field:
 327        return self.output_field
 328
 329    @cached_property
 330    def output_field(self) -> Field:
 331        """Return the output type of this expressions."""
 332        output_field = self._resolve_output_field()
 333        if output_field is None:
 334            self._output_field_resolved_to_none = True
 335            raise FieldError("Cannot resolve expression type, unknown output_field")
 336        return output_field
 337
 338    @cached_property
 339    def _output_field_or_none(self) -> Field | None:
 340        """
 341        Return the output field of this expression, or None if
 342        _resolve_output_field() didn't return an output type.
 343        """
 344        try:
 345            return self.output_field
 346        except FieldError:
 347            if not self._output_field_resolved_to_none:
 348                raise
 349            return None
 350
 351    def _resolve_output_field(self) -> Field | None:
 352        """
 353        Attempt to infer the output type of the expression.
 354
 355        As a guess, if the output fields of all source fields match then simply
 356        infer the same type here.
 357
 358        If a source's output field resolves to None, exclude it from this check.
 359        If all sources are None, then an error is raised higher up the stack in
 360        the output_field property.
 361        """
 362        # This guess is mostly a bad idea, but there is quite a lot of code
 363        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 364        # deprecation path to fix it.
 365        sources_iter = (
 366            source for source in self.get_source_fields() if source is not None
 367        )
 368        for output_field in sources_iter:
 369            for source in sources_iter:
 370                if not isinstance(output_field, source.__class__):
 371                    raise FieldError(
 372                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 373                        "set output_field."
 374                    )
 375            return output_field
 376        return None
 377
 378    @staticmethod
 379    def _convert_value_noop(
 380        value: Any, expression: Any, connection: BaseDatabaseWrapper
 381    ) -> Any:
 382        return value
 383
 384    @cached_property
 385    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 386        """
 387        Expressions provide their own converters because users have the option
 388        of manually specifying the output_field which may be a different type
 389        from the one the database returns.
 390        """
 391        field = self.output_field
 392        internal_type = field.get_internal_type()
 393        if internal_type == "FloatField":
 394            return (
 395                lambda value, expression, connection: None
 396                if value is None
 397                else float(value)
 398            )
 399        elif internal_type.endswith("IntegerField"):
 400            return (
 401                lambda value, expression, connection: None
 402                if value is None
 403                else int(value)
 404            )
 405        elif internal_type == "DecimalField":
 406            return (
 407                lambda value, expression, connection: None
 408                if value is None
 409                else Decimal(value)
 410            )
 411        return self._convert_value_noop
 412
 413    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 414        return self.output_field.get_lookup(lookup)
 415
 416    def get_transform(self, name: str) -> type[Transform] | None:
 417        return self.output_field.get_transform(name)
 418
 419    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
 420        clone = self.copy()
 421        clone.set_source_expressions(
 422            [
 423                e.relabeled_clone(change_map) if e is not None else None
 424                for e in self.get_source_expressions()
 425            ]
 426        )
 427        return clone
 428
 429    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
 430        if replacement := replacements.get(self):
 431            return replacement
 432        clone = self.copy()
 433        source_expressions = clone.get_source_expressions()
 434        clone.set_source_expressions(
 435            [
 436                expr.replace_expressions(replacements) if expr else None
 437                for expr in source_expressions
 438            ]
 439        )
 440        return clone
 441
 442    def get_refs(self) -> set[str]:
 443        refs = set()
 444        for expr in self.get_source_expressions():
 445            refs |= expr.get_refs()
 446        return refs
 447
 448    def copy(self) -> Self:
 449        return copy.copy(self)
 450
 451    def prefix_references(self, prefix: str) -> Self:
 452        clone = self.copy()
 453        clone.set_source_expressions(
 454            [
 455                F(f"{prefix}{expr.name}")
 456                if isinstance(expr, F)
 457                else expr.prefix_references(prefix)
 458                for expr in self.get_source_expressions()
 459            ]
 460        )
 461        return clone
 462
 463    def get_group_by_cols(self) -> list[BaseExpression]:
 464        if not self.contains_aggregate:
 465            return [self]
 466        cols: list[BaseExpression] = []
 467        for source in self.get_source_expressions():
 468            cols.extend(source.get_group_by_cols())
 469        return cols
 470
 471    def get_source_fields(self) -> list[Field | None]:
 472        """Return the underlying field types used by this aggregate."""
 473        return [e._output_field_or_none for e in self.get_source_expressions()]
 474
 475    def asc(self, **kwargs: Any) -> OrderBy:
 476        return OrderBy(self, **kwargs)
 477
 478    def desc(self, **kwargs: Any) -> OrderBy:
 479        return OrderBy(self, descending=True, **kwargs)
 480
 481    def reverse_ordering(self) -> Self:
 482        return self
 483
 484    def flatten(self) -> Iterable[Any]:
 485        """
 486        Recursively yield this expression and all subexpressions, in
 487        depth-first order.
 488        """
 489        yield self
 490        for expr in self.get_source_expressions():
 491            if expr:
 492                if hasattr(expr, "flatten"):
 493                    yield from expr.flatten()
 494                else:
 495                    yield expr
 496
 497    def select_format(
 498        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 499    ) -> tuple[str, Sequence[Any]]:
 500        """
 501        Custom format for select clauses. For example, EXISTS expressions need
 502        to be wrapped in CASE WHEN on Oracle.
 503        """
 504        if output_field := getattr(self, "output_field", None):
 505            if select_format := getattr(output_field, "select_format", None):
 506                return select_format(compiler, sql, params)
 507        return sql, params
 508
 509
 510@deconstructible
 511class Expression(BaseExpression, Combinable):
 512    """An expression that can be combined with other expressions."""
 513
 514    # Set by @deconstructible decorator in __new__
 515    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]
 516
 517    @cached_property
 518    def identity(self) -> tuple[Any, ...]:
 519        constructor_signature = inspect.signature(self.__init__)
 520        args, kwargs = self._constructor_args
 521        signature = constructor_signature.bind_partial(*args, **kwargs)
 522        signature.apply_defaults()
 523        arguments = signature.arguments.items()
 524        identity: list[Any] = [self.__class__]
 525        for arg, value in arguments:
 526            if isinstance(value, fields.Field):
 527                if value.name and value.model:
 528                    value = (value.model.model_options.label, value.name)
 529                else:
 530                    value = type(value)
 531            else:
 532                value = make_hashable(value)
 533            identity.append((arg, value))
 534        return tuple(identity)
 535
 536    def __eq__(self, other: object) -> bool:
 537        if not isinstance(other, Expression):
 538            return NotImplemented
 539        return other.identity == self.identity
 540
 541    def __hash__(self) -> int:
 542        return hash(self.identity)
 543
 544
 545# Type inference for CombinedExpression.output_field.
 546# Missing items will result in FieldError, by design.
 547#
 548# The current approach for NULL is based on lowest common denominator behavior
 549# i.e. if one of the supported databases is raising an error (rather than
 550# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 551
 552_connector_combinations = [
 553    # Numeric operations - operands of same type.
 554    {
 555        connector: [
 556            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 557            (fields.FloatField, fields.FloatField, fields.FloatField),
 558            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 559        ]
 560        for connector in (
 561            Combinable.ADD,
 562            Combinable.SUB,
 563            Combinable.MUL,
 564            # Behavior for DIV with integer arguments follows Postgres/SQLite,
 565            # not MySQL/Oracle.
 566            Combinable.DIV,
 567            Combinable.MOD,
 568            Combinable.POW,
 569        )
 570    },
 571    # Numeric operations - operands of different type.
 572    {
 573        connector: [
 574            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 575            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 576            (fields.IntegerField, fields.FloatField, fields.FloatField),
 577            (fields.FloatField, fields.IntegerField, fields.FloatField),
 578        ]
 579        for connector in (
 580            Combinable.ADD,
 581            Combinable.SUB,
 582            Combinable.MUL,
 583            Combinable.DIV,
 584            Combinable.MOD,
 585        )
 586    },
 587    # Bitwise operators.
 588    {
 589        connector: [
 590            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 591        ]
 592        for connector in (
 593            Combinable.BITAND,
 594            Combinable.BITOR,
 595            Combinable.BITLEFTSHIFT,
 596            Combinable.BITRIGHTSHIFT,
 597            Combinable.BITXOR,
 598        )
 599    },
 600    # Numeric with NULL.
 601    {
 602        connector: [
 603            (field_type, NoneType, field_type),
 604            (NoneType, field_type, field_type),
 605        ]
 606        for connector in (
 607            Combinable.ADD,
 608            Combinable.SUB,
 609            Combinable.MUL,
 610            Combinable.DIV,
 611            Combinable.MOD,
 612            Combinable.POW,
 613        )
 614        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 615    },
 616    # Date/DateTimeField/DurationField/TimeField.
 617    {
 618        Combinable.ADD: [
 619            # Date/DateTimeField.
 620            (fields.DateField, fields.DurationField, fields.DateTimeField),
 621            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 622            (fields.DurationField, fields.DateField, fields.DateTimeField),
 623            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 624            # DurationField.
 625            (fields.DurationField, fields.DurationField, fields.DurationField),
 626            # TimeField.
 627            (fields.TimeField, fields.DurationField, fields.TimeField),
 628            (fields.DurationField, fields.TimeField, fields.TimeField),
 629        ],
 630    },
 631    {
 632        Combinable.SUB: [
 633            # Date/DateTimeField.
 634            (fields.DateField, fields.DurationField, fields.DateTimeField),
 635            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 636            (fields.DateField, fields.DateField, fields.DurationField),
 637            (fields.DateField, fields.DateTimeField, fields.DurationField),
 638            (fields.DateTimeField, fields.DateField, fields.DurationField),
 639            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 640            # DurationField.
 641            (fields.DurationField, fields.DurationField, fields.DurationField),
 642            # TimeField.
 643            (fields.TimeField, fields.DurationField, fields.TimeField),
 644            (fields.TimeField, fields.TimeField, fields.DurationField),
 645        ],
 646    },
 647]
 648
 649_connector_combinators = defaultdict(list)
 650
 651
 652def register_combinable_fields(
 653    lhs: type[Field] | type[None],
 654    connector: str,
 655    rhs: type[Field] | type[None],
 656    result: type[Field],
 657) -> None:
 658    """
 659    Register combinable types:
 660        lhs <connector> rhs -> result
 661    e.g.
 662        register_combinable_fields(
 663            IntegerField, Combinable.ADD, FloatField, FloatField
 664        )
 665    """
 666    _connector_combinators[connector].append((lhs, rhs, result))
 667
 668
 669for d in _connector_combinations:
 670    for connector, field_types in d.items():
 671        for lhs, rhs, result in field_types:
 672            register_combinable_fields(lhs, connector, rhs, result)
 673
 674
 675@functools.lru_cache(maxsize=128)
 676def _resolve_combined_type(
 677    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 678) -> type[Field] | None:
 679    combinators = _connector_combinators.get(connector, ())
 680    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 681        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 682            rhs_type, combinator_rhs_type
 683        ):
 684            return combined_type
 685    return None
 686
 687
 688class SQLiteNumericMixin(Expression):
 689    """
 690    Some expressions with output_field=DecimalField() must be cast to
 691    numeric to be properly filtered.
 692    """
 693
 694    def as_sqlite(
 695        self,
 696        compiler: SQLCompiler,
 697        connection: BaseDatabaseWrapper,
 698        **extra_context: Any,
 699    ) -> tuple[str, Sequence[Any]]:
 700        sql, params = self.as_sql(compiler, connection, **extra_context)
 701        try:
 702            if self.output_field.get_internal_type() == "DecimalField":
 703                sql = f"CAST({sql} AS NUMERIC)"
 704        except FieldError:
 705            pass
 706        return sql, params
 707
 708
 709class CombinedExpression(SQLiteNumericMixin, Expression):
 710    def __init__(
 711        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 712    ):
 713        super().__init__(output_field=output_field)
 714        self.connector = connector
 715        self.lhs = lhs
 716        self.rhs = rhs
 717
 718    def __repr__(self) -> str:
 719        return f"<{self.__class__.__name__}: {self}>"
 720
 721    def __str__(self) -> str:
 722        return f"{self.lhs} {self.connector} {self.rhs}"
 723
 724    def get_source_expressions(self) -> list[Any]:
 725        return [self.lhs, self.rhs]
 726
 727    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 728        self.lhs, self.rhs = exprs
 729
 730    def _resolve_output_field(self) -> Field | None:
 731        # We avoid using super() here for reasons given in
 732        # Expression._resolve_output_field()
 733        combined_type = _resolve_combined_type(
 734            self.connector,
 735            type(self.lhs._output_field_or_none),
 736            type(self.rhs._output_field_or_none),
 737        )
 738        if combined_type is None:
 739            raise FieldError(
 740                f"Cannot infer type of {self.connector!r} expression involving these "
 741                f"types: {self.lhs.output_field.__class__.__name__}, "
 742                f"{self.rhs.output_field.__class__.__name__}. You must set "
 743                f"output_field."
 744            )
 745        return combined_type()
 746
 747    def as_sql(
 748        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 749    ) -> tuple[str, list[Any]]:
 750        expressions = []
 751        expression_params = []
 752        sql, params = compiler.compile(self.lhs)
 753        expressions.append(sql)
 754        expression_params.extend(params)
 755        sql, params = compiler.compile(self.rhs)
 756        expressions.append(sql)
 757        expression_params.extend(params)
 758        # order of precedence
 759        expression_wrapper = "(%s)"
 760        sql = connection.ops.combine_expression(self.connector, expressions)
 761        return expression_wrapper % sql, expression_params
 762
 763    def resolve_expression(
 764        self,
 765        query: Any = None,
 766        allow_joins: bool = True,
 767        reuse: Any = None,
 768        summarize: bool = False,
 769        for_save: bool = False,
 770    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
 771        lhs = self.lhs.resolve_expression(
 772            query, allow_joins, reuse, summarize, for_save
 773        )
 774        rhs = self.rhs.resolve_expression(
 775            query, allow_joins, reuse, summarize, for_save
 776        )
 777        if not isinstance(self, DurationExpression | TemporalSubtraction):
 778            try:
 779                lhs_type = lhs.output_field.get_internal_type()
 780            except (AttributeError, FieldError):
 781                lhs_type = None
 782            try:
 783                rhs_type = rhs.output_field.get_internal_type()
 784            except (AttributeError, FieldError):
 785                rhs_type = None
 786            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
 787                return DurationExpression(
 788                    self.lhs, self.connector, self.rhs
 789                ).resolve_expression(
 790                    query,
 791                    allow_joins,
 792                    reuse,
 793                    summarize,
 794                    for_save,
 795                )
 796            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 797            if (
 798                self.connector == self.SUB
 799                and lhs_type in datetime_fields
 800                and lhs_type == rhs_type
 801            ):
 802                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 803                    query,
 804                    allow_joins,
 805                    reuse,
 806                    summarize,
 807                    for_save,
 808                )
 809        c = self.copy()
 810        c.is_summary = summarize
 811        c.lhs = lhs
 812        c.rhs = rhs
 813        return c
 814
 815
 816class DurationExpression(CombinedExpression):
 817    def compile(
 818        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 819    ) -> tuple[str, Sequence[Any]]:
 820        try:
 821            output = side.output_field
 822        except FieldError:
 823            pass
 824        else:
 825            if output.get_internal_type() == "DurationField":
 826                sql, params = compiler.compile(side)
 827                return connection.ops.format_for_duration_arithmetic(sql), params
 828        return compiler.compile(side)
 829
 830    def as_sql(
 831        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 832    ) -> tuple[str, list[Any]]:
 833        if connection.features.has_native_duration_field:
 834            return super().as_sql(compiler, connection)
 835        connection.ops.check_expression_support(self)
 836        expressions = []
 837        expression_params = []
 838        sql, params = self.compile(self.lhs, compiler, connection)
 839        expressions.append(sql)
 840        expression_params.extend(params)
 841        sql, params = self.compile(self.rhs, compiler, connection)
 842        expressions.append(sql)
 843        expression_params.extend(params)
 844        # order of precedence
 845        expression_wrapper = "(%s)"
 846        sql = connection.ops.combine_duration_expression(self.connector, expressions)
 847        return expression_wrapper % sql, expression_params
 848
 849    def as_sqlite(
 850        self,
 851        compiler: SQLCompiler,
 852        connection: BaseDatabaseWrapper,
 853        **extra_context: Any,
 854    ) -> tuple[str, Sequence[Any]]:
 855        sql, params = self.as_sql(compiler, connection, **extra_context)
 856        if self.connector in {Combinable.MUL, Combinable.DIV}:
 857            try:
 858                lhs_type = self.lhs.output_field.get_internal_type()
 859                rhs_type = self.rhs.output_field.get_internal_type()
 860            except (AttributeError, FieldError):
 861                pass
 862            else:
 863                allowed_fields = {
 864                    "DecimalField",
 865                    "DurationField",
 866                    "FloatField",
 867                    "IntegerField",
 868                }
 869                if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
 870                    raise DatabaseError(
 871                        f"Invalid arguments for operator {self.connector}."
 872                    )
 873        return sql, params
 874
 875
 876class TemporalSubtraction(CombinedExpression):
 877    output_field = fields.DurationField()
 878
 879    def __init__(self, lhs: Any, rhs: Any):
 880        super().__init__(lhs, self.SUB, rhs)
 881
 882    def as_sql(
 883        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 884    ) -> tuple[str, list[Any]]:
 885        connection.ops.check_expression_support(self)
 886        lhs = compiler.compile(self.lhs)
 887        rhs = compiler.compile(self.rhs)
 888        sql, params = connection.ops.subtract_temporals(
 889            self.lhs.output_field.get_internal_type(), lhs, rhs
 890        )
 891        return sql, list(params)
 892
 893
 894@deconstructible(path="plain.models.F")
 895class F(Combinable):
 896    """An object capable of resolving references to existing query objects."""
 897
 898    def __init__(self, name: str):
 899        """
 900        Arguments:
 901         * name: the name of the field this expression references
 902        """
 903        self.name = name
 904
 905    def __repr__(self) -> str:
 906        return f"{self.__class__.__name__}({self.name})"
 907
 908    def resolve_expression(
 909        self,
 910        query: Any = None,
 911        allow_joins: bool = True,
 912        reuse: Any = None,
 913        summarize: bool = False,
 914        for_save: bool = False,
 915    ) -> Any:
 916        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 917
 918    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 919        return replacements.get(self, self)
 920
 921    def asc(self, **kwargs: Any) -> OrderBy:
 922        return OrderBy(self, **kwargs)
 923
 924    def desc(self, **kwargs: Any) -> OrderBy:
 925        return OrderBy(self, descending=True, **kwargs)
 926
 927    def __eq__(self, other: object) -> bool:
 928        if not isinstance(other, F):
 929            return NotImplemented
 930        return self.__class__ == other.__class__ and self.name == other.name
 931
 932    def __hash__(self) -> int:
 933        return hash(self.name)
 934
 935    def copy(self) -> Self:
 936        return copy.copy(self)
 937
 938
 939class ResolvedOuterRef(F):
 940    """
 941    An object that contains a reference to an outer query.
 942
 943    In this case, the reference to the outer query has been resolved because
 944    the inner query has been used as a subquery.
 945    """
 946
 947    contains_aggregate = False
 948    contains_over_clause = False
 949
 950    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 951        raise ValueError(
 952            "This queryset contains a reference to an outer query and may "
 953            "only be used in a subquery."
 954        )
 955
 956    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 957        col = super().resolve_expression(*args, **kwargs)
 958        if col.contains_over_clause:
 959            raise NotSupportedError(
 960                f"Referencing outer query window expression is not supported: "
 961                f"{self.name}."
 962            )
 963        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 964        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 965        # into account only many-to-many and one-to-many relationships.
 966        col.possibly_multivalued = LOOKUP_SEP in self.name
 967        return col
 968
 969    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 970        return self
 971
 972    def get_group_by_cols(self) -> list[Any]:
 973        return []
 974
 975
 976class OuterRef(F):
 977    contains_aggregate = False
 978
 979    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
 980        if isinstance(self.name, self.__class__):
 981            return self.name
 982        return ResolvedOuterRef(self.name)
 983
 984    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
 985        return self
 986
 987
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    Compiles by %-formatting ``template`` with a dict that contains
    ``function``, ``expressions`` (the compiled arguments joined by
    ``arg_joiner``), ``field`` (same value as ``expressions``) and any extra
    keyword arguments passed to __init__().
    """

    # Name substituted for %(function)s in ``template``; set by subclasses.
    function = None
    template = "%(function)s(%(expressions)s)"
    # Separator placed between the compiled arguments.
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Arguments:
         * expressions: the function's arguments, normalized through
           _parse_expressions().
         * output_field: the field type of the result, if known up front.
         * extra: extra template values; may also supply ``function``,
           ``template`` or ``arg_joiner`` overrides used by as_sql().

        Raises TypeError if ``arity`` is set and the number of expressions
        differs from it.
        """
        if self.arity is not None and len(expressions) != self.arity:
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__,
                    self.arity,
                    "argument" if self.arity == 1 else "arguments",
                    len(expressions),
                )
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        # Arguments first, then extra kwargs and repr options as sorted
        # key=value pairs (note: `extra` is rebound from dict to str here).
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        extra = {**self.extra, **self._get_repr_options()}
        if extra:
            extra = ", ".join(
                str(key) + "=" + str(val) for key, val in sorted(extra.items())
            )
            return f"{self.__class__.__name__}({args}, {extra})"
        return f"{self.__class__.__name__}({args})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy whose arguments have each been resolved against ``query``."""
        c = self.copy()
        c.is_summary = summarize
        for pos, arg in enumerate(c.source_expressions):
            c.source_expressions[pos] = arg.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the call to ``(sql, params)``.

        ``function``, ``template`` and ``arg_joiner`` passed here take
        precedence over values in ``self.extra``/``extra_context``, which in
        turn take precedence over the class attributes.
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Substitute the argument's empty_result_set_value when it
                # defines one; otherwise let EmptyResultSet propagate.
                empty_result_set_value = getattr(
                    arg, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
            except FullResultSet:
                # An argument that matches everything compiles as Value(True).
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data["function"] = function
        else:
            data.setdefault("function", self.function)
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self) -> Self:
        """
        Return a copy whose argument list and extra dict are fresh shallow
        copies, so the clone can be mutated without touching the original.
        """
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return cast(Self, clone)
1093
1094
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Class-level default so an unresolved Value can still be compiled;
    # resolve_expression() overwrites it on the resolved copy.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        field = self._output_field_or_none
        if field is None:
            prepared = self.value
        else:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(self.value, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [
                    prepared
                ]
        if prepared is None:
            # Emit a literal SQL NULL instead of binding None: some drivers
            # don't always convert None to the appropriate NULL type (e.g. in
            # CASE expressions over numbers).
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        """Infer a field type from the Python type of ``value``, if possible."""
        # Order matters: bool must be tested before int (bool subclasses int)
        # and datetime before date (datetime subclasses date).
        value = self.value
        for python_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1180
1181
class RawSQL(Expression):
    """
    A literal SQL fragment plus its parameters, compiled wrapped in
    parentheses. Defaults to a generic ``fields.Field()`` output field.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1201
1202
class Star(Expression):
    """Compiles to the bare SQL ``*`` wildcard with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1211
1212
class Col(Expression):
    """
    A reference to ``target.column``, optionally qualified by a table alias.

    The output field defaults to ``target`` itself.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        column = self.target.column
        parts = [self.alias, column] if self.alias else [column]
        quoted = [compiler.quote_name_unless_alias(part) for part in parts]
        return ".".join(quoted), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # Unqualified columns have nothing to relabel.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the output field's converters, followed by the target's when
        the target differs from the output field.
        """
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1256
1257
class Ref(Expression):
    """
    Reference to a column alias of the query, compiled as the quoted alias.

    Example: ``Ref('sum_cost')`` in a ``qs.annotate(sum_cost=Sum('cost'))``
    query.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` was already resolved when the alias was created; this node
        # only refers to it by name, so there is nothing left to resolve.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1302
1303
class ExpressionList(Func):
    """
    A comma-joined sequence of expressions with no surrounding function call.

    Useful as a single argument to another expression, e.g. a partition
    clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # An expression list never needs a numeric cast; compile it plainly.
        return self.as_sql(compiler, connection, **extra_context)
1331
1332
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from ordering expressions.

    A string argument starting with ``-`` is turned into a descending
    ``OrderBy`` of ``F(<name without the dash>)``. Every other argument is
    handed to Func unchanged. With no expressions, the clause compiles to an
    empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # Use startswith() rather than expr[0]: indexing an empty string
        # raised IndexError, whereas "" is now passed through unchanged.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # Without expressions, emit nothing rather than a dangling "ORDER BY".
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering expression."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1358
1359
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wraps another expression to give it extra context, such as an explicit
    output_field. Compiles to exactly the inner expression's SQL.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression carrying this wrapper's
        # output field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.expression})"
1393
1394
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation returns (a copy of) the original expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # NOT of something that matches nothing matches everything.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Wrap the SQL in CASE WHEN for backends that can't select a boolean
        expression directly.
        """
        connection = compiler.connection
        # When the expression isn't supported in WHERE, as_sql() has already
        # produced a CASE WHEN, so skip wrapping it a second time.
        if (
            not connection.features.supports_boolean_expr_in_select_clause
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1453
1454
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional (boolean) expression, or
    keyword lookups, which are turned into a Q object.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Arguments:
         * condition: a Q object or a conditional expression. It may be
           omitted when ``lookups`` are given.
         * then: the value or expression produced when the condition holds.
         * lookups: keyword lookups. They are wrapped into a Q object, and
           combined with ``condition`` when that is also conditional.

        Raises TypeError when no usable condition results (e.g. lookups given
        together with a non-conditional condition) and ValueError for an
        empty Q().
        """
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        # lookups_dict is still set only when it couldn't be merged above,
        # i.e. a non-conditional condition was passed alongside lookups.
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """
        Return a copy with the condition (when resolvable) and the result
        resolved against ``query``.
        """
        c = self.copy()
        c.is_summary = summarize
        if isinstance(c.condition, ResolvableExpression):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's for_save flag.
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``WHEN ... THEN ...`` SQL. The params are the condition's
        params followed by the result's.
        """
        connection.ops.check_expression_support(self)
        template_params = extra_context
        # NOTE(review): sql_params is never appended to in this method.
        sql_params = []
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1549
1550
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    # Separator placed between the compiled WHEN branches.
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Arguments:
         * cases: When() branches, evaluated in order.
         * default: value or expression used for ELSE.
         * output_field: the field type of the result, if known.
         * extra: extra template values used by as_sql().

        Raises TypeError if any positional argument is not a When.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every branch and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        # Give the clone its own list of branches so resolving it doesn't
        # mutate the original's list.
        c = super().copy()
        c.cases = c.cases[:]
        return cast(Self, c)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the CASE expression.

        Branches that raise EmptyResultSet are dropped. The first branch that
        raises FullResultSet makes its result the ELSE value, and later
        branches are ignored. If no branches remain, only the default's SQL
        is returned. When an output field is known, the SQL is wrapped with
        the backend's unification_cast_sql().
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; omit it.
                continue
            except FullResultSet:
                # This branch always matches: it becomes the ELSE value and
                # any later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # With no branches the expression is just its default.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1663
1664
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Wrap ``query`` (a QuerySet or an sql.Query) as a subquery.

        The underlying sql.Query is cloned and flagged with ``subquery = True``,
        so the caller's query object is not modified. Extra keyword arguments
        become template context for ``as_sql``.
        """
        # Deferred import to avoid a circular import with the sql package.
        from plain.models.sql.query import Query

        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # The subquery's type is whatever the inner query produces.
        return self.query.output_field

    def copy(self) -> Self:
        duplicate = super().copy()
        # Give the copy its own sql.Query so mutations stay isolated.
        duplicate.query = duplicate.query.clone()
        return cast(Self, duplicate)

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render the inner query through ``template`` (or the class default)."""
        connection.ops.check_expression_support(self)
        inner_sql, params = self.query.as_sql(compiler, connection)
        # Drop the first and last character of the inner SQL (the wrapping
        # parentheses) so the template decides the outer wrapping.
        context = {**self.extra, **extra_context, "subquery": inner_sql[1:-1]}
        chosen = template or context.get("template", self.template)
        return chosen % context, params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1735
1736
class Exists(Subquery):
    """A boolean ``EXISTS(...)`` test over a subquery."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap in the query returned by ``Query.exists()`` for the cloned one.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """Adapt the EXISTS expression for use in a SELECT/GROUP BY list."""
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends that can't select a bare boolean expression (e.g. Oracle)
        # get it wrapped in a CASE WHEN that yields 1 or 0.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1755
1756
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    An ``ORDER BY`` term: an expression plus direction and NULL placement.

    ``nulls_first`` and ``nulls_last`` are mutually exclusive and may only be
    ``True`` or ``None``. On backends without a native ``NULLS FIRST/LAST``
    modifier, the placement is emulated by ordering on an ``IS [NOT] NULL``
    test before the expression itself.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Raises:
            ValueError: If both NULL options are set, if either is ``False``,
                or if ``expression`` has no ``resolve_expression`` method.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``(sql, params)``.

        When NULL placement is emulated, the expression appears more than once
        in the SQL, so its parameters are repeated to match.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate the modifier unless the backend's default NULL placement
            # for this direction already gives the requested order.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Build a new sequence instead of using ``*=``: for a list, ``*=``
        # mutates in place, and the compiled params may be a list owned by
        # the expression itself, which would then grow on every compile (or
        # be emptied when a custom template omits the expression).
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and NULL placement in place; return ``self``."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # type: ignore[override]
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        self.descending = True
1842
1843
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call rendered as ``<expression> OVER (<window>)``.

    The OVER clause is built from an optional PARTITION BY list, an optional
    ORDER BY list, and an optional frame clause, in that order.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Args:
            expression: The windowed expression; it must have a truthy
                ``window_compatible`` attribute.
            partition_by: A single expression or a list/tuple of them,
                wrapped in an ``ExpressionList``.
            order_by: A string field reference, an expression, or a
                list/tuple of them, wrapped in an ``OrderByList``.
            frame: Optional frame clause, compiled when truthy.
            output_field: Explicit output field. When omitted,
                ``_resolve_output_field`` uses the source expression's.

        Raises:
            ValueError: If ``expression`` is not window-compatible or
                ``order_by`` has an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``(sql, params)``.

        ``extra_context`` is accepted like the other ``as_sql`` methods, since
        ``as_sqlite`` forwards it here. Its keys are available to the template
        but do not override ``expression`` or ``window``.

        Raises:
            NotSupportedError: If the backend has no OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template
            % {
                **extra_context,
                "expression": expr_sql,
                "window": " ".join(window_sql).strip(),
            },
            (*params, *window_params),
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            # self.copy() does not clone child expressions, so copy the
            # source expression too. Otherwise the FloatField override would
            # be written onto the instance ``self`` still references.
            window = self.copy()
            source_expressions = window.get_source_expressions()
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            window.set_source_expressions(source_expressions)
            return super(Window, window).as_sqlite(compiler, connection, **extra_context)
        return self.as_sql(compiler, connection, **extra_context)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1973
class WindowFrame(Expression, ABC):
    """
    Abstract frame clause of a window expression.

    Subclasses set ``frame_type`` and implement ``window_frame_start_end`` so
    the backend renders both bounds. Bounds are integers or ``None``. In
    ``__str__``, a negative start means N PRECEDING, a positive end means N
    FOLLOWING, zero means CURRENT ROW, and anything else means unbounded.
    Giving an end is optional.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Bounds are held as Value nodes; they are this frame's source
        # expressions.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render the frame clause; it never carries query parameters."""
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        return self.template % context, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # A frame clause contributes no GROUP BY columns.
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = f"{abs(lower)} {ops.PRECEDING}"
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = f"{upper} {ops.FOLLOWING}"
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]: ...
2043
2044
class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend's operations object renders row-based bounds.
        render_bounds = connection.ops.window_frame_rows_start_end
        return render_bounds(start, end)
2052
2053
class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend's operations object renders value-based (RANGE) bounds.
        render_bounds = connection.ops.window_frame_range_start_end
        return render_bounds(start, end)