1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from collections import defaultdict
   8from decimal import Decimal
   9from functools import cached_property
  10from types import NoneType
  11from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
  12from uuid import UUID
  13
  14from plain.postgres import fields
  15from plain.postgres.constants import LOOKUP_SEP
  16from plain.postgres.db import NotSupportedError
  17from plain.postgres.dialect import (
  18    CURRENT_ROW,
  19    FOLLOWING,
  20    PRECEDING,
  21    UNBOUNDED_FOLLOWING,
  22    UNBOUNDED_PRECEDING,
  23    combine_expression,
  24    quote_name,
  25    subtract_temporals,
  26    window_frame_range_start_end,
  27    window_frame_rows_start_end,
  28)
  29from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
  30from plain.postgres.query_utils import Q
  31from plain.utils.deconstruct import deconstructible
  32from plain.utils.hashable import make_hashable
  33
  34if TYPE_CHECKING:
  35    from collections.abc import Callable, Iterable, Sequence
  36
  37    from plain.postgres.connection import DatabaseConnection
  38    from plain.postgres.fields import Field
  39    from plain.postgres.lookups import Lookup, Transform
  40    from plain.postgres.query import QuerySet
  41    from plain.postgres.sql.compiler import SQLCompilable, SQLCompiler
  42    from plain.postgres.sql.query import Query
  43
  44__all__ = [
  45    # Core expression classes
  46    "F",
  47    "Value",
  48    "Case",
  49    "When",
  50    "Subquery",
  51    "Exists",
  52    "OuterRef",
  53    "Window",
  54    "ExpressionWrapper",
  55    "RawSQL",
  56    "OrderBy",
  57    # Base classes (for extension)
  58    "Func",
  59    "Expression",
  60    "Combinable",
  61    # Window frame specs
  62    "RowRange",
  63    "ValueRange",
  64]
  65
  66
  67@runtime_checkable
  68class ResolvableExpression(Protocol):
  69    """Protocol for expressions that can be resolved in query context."""
  70
  71    def resolve_expression(
  72        self,
  73        query: Any = None,
  74        allow_joins: bool = True,
  75        reuse: Any = None,
  76        summarize: bool = False,
  77        for_save: bool = False,
  78    ) -> Any: ...
  79
  80
  81@runtime_checkable
  82class ReplaceableExpression(Protocol):
  83    """Protocol for expressions that support expression replacement."""
  84
  85    def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
  86
  87
  88class Combinable:
  89    """
  90    Provide the ability to combine one or two objects with
  91    some connector. For example F('foo') + F('bar').
  92    """
  93
  94    # Arithmetic connectors
  95    ADD = "+"
  96    SUB = "-"
  97    MUL = "*"
  98    DIV = "/"
  99    POW = "^"
 100    # The following is a quoted % operator - it is quoted because it can be
 101    # used in strings that also have parameter substitution.
 102    MOD = "%%"
 103
 104    # Bitwise operators - note that these are generated by .bitand()
 105    # and .bitor(), the '&' and '|' are reserved for boolean operator
 106    # usage.
 107    BITAND = "&"
 108    BITOR = "|"
 109    BITLEFTSHIFT = "<<"
 110    BITRIGHTSHIFT = ">>"
 111    BITXOR = "#"
 112
 113    def _combine(
 114        self, other: Any, connector: str, reversed: bool
 115    ) -> CombinedExpression:
 116        if not isinstance(other, ResolvableExpression):
 117            # everything must be resolvable to an expression
 118            other = Value(other)
 119
 120        if reversed:
 121            return CombinedExpression(other, connector, self)
 122        return CombinedExpression(self, connector, other)
 123
 124    #############
 125    # OPERATORS #
 126    #############
 127
 128    def __neg__(self) -> CombinedExpression:
 129        return self._combine(-1, self.MUL, False)
 130
 131    def __add__(self, other: Any) -> CombinedExpression:
 132        return self._combine(other, self.ADD, False)
 133
 134    def __sub__(self, other: Any) -> CombinedExpression:
 135        return self._combine(other, self.SUB, False)
 136
 137    def __mul__(self, other: Any) -> CombinedExpression:
 138        return self._combine(other, self.MUL, False)
 139
 140    def __truediv__(self, other: Any) -> CombinedExpression:
 141        return self._combine(other, self.DIV, False)
 142
 143    def __mod__(self, other: Any) -> CombinedExpression:
 144        return self._combine(other, self.MOD, False)
 145
 146    def __pow__(self, other: Any) -> CombinedExpression:
 147        return self._combine(other, self.POW, False)
 148
 149    def __and__(self, other: Any) -> Q:
 150        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 151            return Q(self) & Q(other)
 152        raise NotImplementedError(
 153            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 154        )
 155
 156    def bitand(self, other: Any) -> CombinedExpression:
 157        return self._combine(other, self.BITAND, False)
 158
 159    def bitleftshift(self, other: Any) -> CombinedExpression:
 160        return self._combine(other, self.BITLEFTSHIFT, False)
 161
 162    def bitrightshift(self, other: Any) -> CombinedExpression:
 163        return self._combine(other, self.BITRIGHTSHIFT, False)
 164
 165    def __xor__(self, other: Any) -> Q:
 166        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 167            return Q(self) ^ Q(other)
 168        raise NotImplementedError(
 169            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 170        )
 171
 172    def bitxor(self, other: Any) -> CombinedExpression:
 173        return self._combine(other, self.BITXOR, False)
 174
 175    def __or__(self, other: Any) -> Q:
 176        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 177            return Q(self) | Q(other)
 178        raise NotImplementedError(
 179            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 180        )
 181
 182    def bitor(self, other: Any) -> CombinedExpression:
 183        return self._combine(other, self.BITOR, False)
 184
 185    def __radd__(self, other: Any) -> CombinedExpression:
 186        return self._combine(other, self.ADD, True)
 187
 188    def __rsub__(self, other: Any) -> CombinedExpression:
 189        return self._combine(other, self.SUB, True)
 190
 191    def __rmul__(self, other: Any) -> CombinedExpression:
 192        return self._combine(other, self.MUL, True)
 193
 194    def __rtruediv__(self, other: Any) -> CombinedExpression:
 195        return self._combine(other, self.DIV, True)
 196
 197    def __rmod__(self, other: Any) -> CombinedExpression:
 198        return self._combine(other, self.MOD, True)
 199
 200    def __rpow__(self, other: Any) -> CombinedExpression:
 201        return self._combine(other, self.POW, True)
 202
 203    def __rand__(self, other: Any) -> None:
 204        raise NotImplementedError(
 205            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 206        )
 207
 208    def __ror__(self, other: Any) -> None:
 209        raise NotImplementedError(
 210            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 211        )
 212
 213    def __rxor__(self, other: Any) -> None:
 214        raise NotImplementedError(
 215            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 216        )
 217
 218    def __invert__(self) -> NegatedExpression:
 219        return NegatedExpression(self)
 220
 221
 222class BaseExpression:
 223    """Base class for all query expressions."""
 224
 225    empty_result_set_value = NotImplemented
 226    # aggregate specific fields
 227    is_summary = False
 228    _output_field_resolved_to_none = False
 229    # Can the expression be used in a WHERE clause?
 230    filterable = True
 231    # Can the expression can be used as a source expression in Window?
 232    window_compatible = False
 233
 234    def __init__(self, output_field: Field | None = None):
 235        if output_field is not None:
 236            self.output_field = output_field
 237
 238    def __getstate__(self) -> dict[str, Any]:
 239        state = self.__dict__.copy()
 240        state.pop("convert_value", None)
 241        return state
 242
 243    def get_db_converters(
 244        self, connection: DatabaseConnection
 245    ) -> list[Callable[..., Any]]:
 246        converters = []
 247        if self.convert_value is not self._convert_value_noop:
 248            converters.append(self.convert_value)
 249        converters.extend(self.output_field.get_db_converters(connection))
 250        return converters
 251
 252    def get_source_expressions(self) -> list[Any]:
 253        return []
 254
 255    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 256        assert not exprs
 257
 258    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 259        return [
 260            arg
 261            if isinstance(arg, ResolvableExpression)
 262            else (F(arg) if isinstance(arg, str) else Value(arg))
 263            for arg in expressions
 264        ]
 265
 266    def as_sql(
 267        self, compiler: SQLCompiler, connection: DatabaseConnection
 268    ) -> tuple[str, Sequence[Any]]:
 269        """
 270        Return a (sql, params) tuple to be included in the current query.
 271
 272        Arguments:
 273         * compiler: the query compiler responsible for generating the query.
 274           Must have a compile method, returning a (sql, [params]) tuple.
 275           Calling compiler(value) will return a quoted `value`.
 276
 277         * connection: the database connection used for the current query.
 278
 279        Return: (sql, params)
 280          Where `sql` is a string containing ordered sql parameters to be
 281          replaced with the elements of the list `params`.
 282        """
 283        raise NotImplementedError("Subclasses must implement as_sql()")
 284
 285    @cached_property
 286    def contains_aggregate(self) -> bool:
 287        return any(
 288            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 289        )
 290
 291    @cached_property
 292    def contains_over_clause(self) -> bool:
 293        return any(
 294            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 295        )
 296
 297    @cached_property
 298    def contains_column_references(self) -> bool:
 299        return any(
 300            expr and expr.contains_column_references
 301            for expr in self.get_source_expressions()
 302        )
 303
 304    def resolve_expression(
 305        self,
 306        query: Any = None,
 307        allow_joins: bool = True,
 308        reuse: Any = None,
 309        summarize: bool = False,
 310        for_save: bool = False,
 311    ) -> Self:
 312        """
 313        Provide the chance to do any preprocessing or validation before being
 314        added to the query.
 315
 316        Arguments:
 317         * query: the backend query implementation
 318         * allow_joins: boolean allowing or denying use of joins
 319           in this query
 320         * reuse: a set of reusable joins for multijoins
 321         * summarize: a terminal aggregate clause
 322         * for_save: whether this expression about to be used in a save or update
 323
 324        Return: an Expression to be added to the query.
 325        """
 326        c = self.copy()
 327        c.is_summary = summarize
 328        c.set_source_expressions(
 329            [
 330                expr.resolve_expression(query, allow_joins, reuse, summarize)
 331                if expr
 332                else None
 333                for expr in c.get_source_expressions()
 334            ]
 335        )
 336        return c
 337
 338    @property
 339    def conditional(self) -> bool:
 340        output_field = getattr(self, "output_field", None)
 341        return isinstance(output_field, fields.BooleanField)
 342
 343    @property
 344    def field(self) -> Field:
 345        return self.output_field
 346
 347    @cached_property
 348    def output_field(self) -> Field:
 349        """Return the output type of this expressions."""
 350        output_field = self._resolve_output_field()
 351        if output_field is None:
 352            self._output_field_resolved_to_none = True
 353            raise FieldError("Cannot resolve expression type, unknown output_field")
 354        return output_field
 355
 356    @cached_property
 357    def _output_field_or_none(self) -> Field | None:
 358        """
 359        Return the output field of this expression, or None if
 360        _resolve_output_field() didn't return an output type.
 361        """
 362        try:
 363            return self.output_field
 364        except FieldError:
 365            if not self._output_field_resolved_to_none:
 366                raise
 367            return None
 368
 369    def _resolve_output_field(self) -> Field | None:
 370        """
 371        Attempt to infer the output type of the expression.
 372
 373        As a guess, if the output fields of all source fields match then simply
 374        infer the same type here.
 375
 376        If a source's output field resolves to None, exclude it from this check.
 377        If all sources are None, then an error is raised higher up the stack in
 378        the output_field property.
 379        """
 380        # This guess is mostly a bad idea, but there is quite a lot of code
 381        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 382        # deprecation path to fix it.
 383        sources_iter = (
 384            source for source in self.get_source_fields() if source is not None
 385        )
 386        for output_field in sources_iter:
 387            for source in sources_iter:
 388                if not isinstance(output_field, source.__class__):
 389                    raise FieldError(
 390                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 391                        "set output_field."
 392                    )
 393            return output_field
 394        return None
 395
 396    @staticmethod
 397    def _convert_value_noop(
 398        value: Any, expression: Any, connection: DatabaseConnection
 399    ) -> Any:
 400        return value
 401
 402    @cached_property
 403    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 404        """
 405        Expressions provide their own converters because users have the option
 406        of manually specifying the output_field which may be a different type
 407        from the one the database returns.
 408        """
 409        field = self.output_field
 410        internal_type = field.get_internal_type()
 411        if internal_type == "FloatField":
 412            return (
 413                lambda value, expression, connection: None
 414                if value is None
 415                else float(value)
 416            )
 417        elif internal_type.endswith("IntegerField"):
 418            return (
 419                lambda value, expression, connection: None
 420                if value is None
 421                else int(value)
 422            )
 423        elif internal_type == "DecimalField":
 424            return (
 425                lambda value, expression, connection: None
 426                if value is None
 427                else Decimal(value)
 428            )
 429        return self._convert_value_noop
 430
 431    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 432        return self.output_field.get_lookup(lookup)
 433
 434    def get_transform(self, name: str) -> type[Transform] | None:
 435        return self.output_field.get_transform(name)  # type: ignore[return-type]
 436
 437    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
 438        clone = self.copy()
 439        clone.set_source_expressions(
 440            [
 441                e.relabeled_clone(change_map) if e is not None else None
 442                for e in self.get_source_expressions()
 443            ]
 444        )
 445        return clone
 446
 447    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
 448        if replacement := replacements.get(self):
 449            return replacement
 450        clone = self.copy()
 451        source_expressions = clone.get_source_expressions()
 452        clone.set_source_expressions(
 453            [
 454                expr.replace_expressions(replacements) if expr else None
 455                for expr in source_expressions
 456            ]
 457        )
 458        return clone
 459
 460    def get_refs(self) -> set[str]:
 461        refs = set()
 462        for expr in self.get_source_expressions():
 463            refs |= expr.get_refs()
 464        return refs
 465
 466    def copy(self) -> Self:
 467        return copy.copy(self)
 468
 469    def prefix_references(self, prefix: str) -> Self:
 470        clone = self.copy()
 471        clone.set_source_expressions(
 472            [
 473                F(f"{prefix}{expr.name}")
 474                if isinstance(expr, F)
 475                else expr.prefix_references(prefix)
 476                for expr in self.get_source_expressions()
 477            ]
 478        )
 479        return clone
 480
 481    def get_group_by_cols(self) -> list[BaseExpression]:
 482        if not self.contains_aggregate:
 483            return [self]
 484        cols: list[BaseExpression] = []
 485        for source in self.get_source_expressions():
 486            cols.extend(source.get_group_by_cols())
 487        return cols
 488
 489    def get_source_fields(self) -> list[Field | None]:
 490        """Return the underlying field types used by this aggregate."""
 491        return [e._output_field_or_none for e in self.get_source_expressions()]
 492
 493    def asc(self, **kwargs: Any) -> OrderBy:
 494        return OrderBy(self, **kwargs)
 495
 496    def desc(self, **kwargs: Any) -> OrderBy:
 497        return OrderBy(self, descending=True, **kwargs)
 498
 499    def reverse_ordering(self) -> Self:
 500        return self
 501
 502    def flatten(self) -> Iterable[Any]:
 503        """
 504        Recursively yield this expression and all subexpressions, in
 505        depth-first order.
 506        """
 507        yield self
 508        for expr in self.get_source_expressions():
 509            if expr:
 510                if hasattr(expr, "flatten"):
 511                    yield from expr.flatten()
 512                else:
 513                    yield expr
 514
 515    def select_format(
 516        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 517    ) -> tuple[str, Sequence[Any]]:
 518        """Custom format for select clauses."""
 519        if output_field := getattr(self, "output_field", None):
 520            if select_format := getattr(output_field, "select_format", None):
 521                return select_format(compiler, sql, params)
 522        return sql, params
 523
 524
 525@deconstructible
 526class Expression(BaseExpression, Combinable):
 527    """An expression that can be combined with other expressions."""
 528
 529    # Set by @deconstructible decorator in __new__
 530    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]
 531
 532    @cached_property
 533    def identity(self) -> tuple[Any, ...]:
 534        constructor_signature = inspect.signature(self.__init__)
 535        args, kwargs = self._constructor_args
 536        signature = constructor_signature.bind_partial(*args, **kwargs)
 537        signature.apply_defaults()
 538        arguments = signature.arguments.items()
 539        identity: list[Any] = [self.__class__]
 540        for arg, value in arguments:
 541            if isinstance(value, fields.Field):
 542                if value.name and value.model:
 543                    value = (value.model.model_options.label, value.name)
 544                else:
 545                    value = type(value)
 546            else:
 547                value = make_hashable(value)
 548            identity.append((arg, value))
 549        return tuple(identity)
 550
 551    def __eq__(self, other: object) -> bool:
 552        if not isinstance(other, Expression):
 553            return NotImplemented
 554        return other.identity == self.identity
 555
 556    def __hash__(self) -> int:
 557        return hash(self.identity)
 558
 559
 560# Type inference for CombinedExpression.output_field.
 561# Missing items will result in FieldError, by design.
 562#
 563# The current approach for NULL is based on lowest common denominator behavior
 564# i.e. if one of the supported databases is raising an error (rather than
 565# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 566
 567_connector_combinations = [
 568    # Numeric operations - operands of same type.
 569    {
 570        connector: [
 571            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 572            (fields.FloatField, fields.FloatField, fields.FloatField),
 573            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 574        ]
 575        for connector in (
 576            Combinable.ADD,
 577            Combinable.SUB,
 578            Combinable.MUL,
 579            Combinable.DIV,
 580            Combinable.MOD,
 581            Combinable.POW,
 582        )
 583    },
 584    # Numeric operations - operands of different type.
 585    {
 586        connector: [
 587            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 588            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 589            (fields.IntegerField, fields.FloatField, fields.FloatField),
 590            (fields.FloatField, fields.IntegerField, fields.FloatField),
 591        ]
 592        for connector in (
 593            Combinable.ADD,
 594            Combinable.SUB,
 595            Combinable.MUL,
 596            Combinable.DIV,
 597            Combinable.MOD,
 598        )
 599    },
 600    # Bitwise operators.
 601    {
 602        connector: [
 603            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 604        ]
 605        for connector in (
 606            Combinable.BITAND,
 607            Combinable.BITOR,
 608            Combinable.BITLEFTSHIFT,
 609            Combinable.BITRIGHTSHIFT,
 610            Combinable.BITXOR,
 611        )
 612    },
 613    # Numeric with NULL.
 614    {
 615        connector: [
 616            (field_type, NoneType, field_type),
 617            (NoneType, field_type, field_type),
 618        ]
 619        for connector in (
 620            Combinable.ADD,
 621            Combinable.SUB,
 622            Combinable.MUL,
 623            Combinable.DIV,
 624            Combinable.MOD,
 625            Combinable.POW,
 626        )
 627        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 628    },
 629    # Date/DateTimeField/DurationField/TimeField.
 630    {
 631        Combinable.ADD: [
 632            # Date/DateTimeField.
 633            (fields.DateField, fields.DurationField, fields.DateTimeField),
 634            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 635            (fields.DurationField, fields.DateField, fields.DateTimeField),
 636            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 637            # DurationField.
 638            (fields.DurationField, fields.DurationField, fields.DurationField),
 639            # TimeField.
 640            (fields.TimeField, fields.DurationField, fields.TimeField),
 641            (fields.DurationField, fields.TimeField, fields.TimeField),
 642        ],
 643    },
 644    {
 645        Combinable.SUB: [
 646            # Date/DateTimeField.
 647            (fields.DateField, fields.DurationField, fields.DateTimeField),
 648            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 649            (fields.DateField, fields.DateField, fields.DurationField),
 650            (fields.DateField, fields.DateTimeField, fields.DurationField),
 651            (fields.DateTimeField, fields.DateField, fields.DurationField),
 652            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 653            # DurationField.
 654            (fields.DurationField, fields.DurationField, fields.DurationField),
 655            # TimeField.
 656            (fields.TimeField, fields.DurationField, fields.TimeField),
 657            (fields.TimeField, fields.TimeField, fields.DurationField),
 658        ],
 659    },
 660]
 661
 662_connector_combinators = defaultdict(list)
 663
 664
 665def register_combinable_fields(
 666    lhs: type[Field] | type[None],
 667    connector: str,
 668    rhs: type[Field] | type[None],
 669    result: type[Field],
 670) -> None:
 671    """
 672    Register combinable types:
 673        lhs <connector> rhs -> result
 674    e.g.
 675        register_combinable_fields(
 676            IntegerField, Combinable.ADD, FloatField, FloatField
 677        )
 678    """
 679    _connector_combinators[connector].append((lhs, rhs, result))
 680
 681
 682for d in _connector_combinations:
 683    for connector, field_types in d.items():
 684        for lhs, rhs, result in field_types:
 685            register_combinable_fields(lhs, connector, rhs, result)
 686
 687
 688@functools.lru_cache(maxsize=128)
 689def _resolve_combined_type(
 690    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 691) -> type[Field] | None:
 692    combinators = _connector_combinators.get(connector, ())
 693    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 694        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 695            rhs_type, combinator_rhs_type
 696        ):
 697            return combined_type
 698    return None
 699
 700
 701class CombinedExpression(Expression):
 702    def __init__(
 703        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 704    ):
 705        super().__init__(output_field=output_field)
 706        self.connector = connector
 707        self.lhs = lhs
 708        self.rhs = rhs
 709
 710    def __repr__(self) -> str:
 711        return f"<{self.__class__.__name__}: {self}>"
 712
 713    def __str__(self) -> str:
 714        return f"{self.lhs} {self.connector} {self.rhs}"
 715
 716    def get_source_expressions(self) -> list[Any]:
 717        return [self.lhs, self.rhs]
 718
 719    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 720        self.lhs, self.rhs = exprs
 721
 722    def _resolve_output_field(self) -> Field | None:
 723        # We avoid using super() here for reasons given in
 724        # Expression._resolve_output_field()
 725        combined_type = _resolve_combined_type(
 726            self.connector,
 727            type(self.lhs._output_field_or_none),
 728            type(self.rhs._output_field_or_none),
 729        )
 730        if combined_type is None:
 731            raise FieldError(
 732                f"Cannot infer type of {self.connector!r} expression involving these "
 733                f"types: {self.lhs.output_field.__class__.__name__}, "
 734                f"{self.rhs.output_field.__class__.__name__}. You must set "
 735                f"output_field."
 736            )
 737        return combined_type()
 738
 739    def as_sql(
 740        self, compiler: SQLCompiler, connection: DatabaseConnection
 741    ) -> tuple[str, list[Any]]:
 742        expressions = []
 743        expression_params = []
 744        sql, params = compiler.compile(self.lhs)
 745        expressions.append(sql)
 746        expression_params.extend(params)
 747        sql, params = compiler.compile(self.rhs)
 748        expressions.append(sql)
 749        expression_params.extend(params)
 750        # order of precedence
 751        expression_wrapper = "(%s)"
 752        sql = combine_expression(self.connector, expressions)
 753        return expression_wrapper % sql, expression_params
 754
 755    def resolve_expression(
 756        self,
 757        query: Any = None,
 758        allow_joins: bool = True,
 759        reuse: Any = None,
 760        summarize: bool = False,
 761        for_save: bool = False,
 762    ) -> CombinedExpression | TemporalSubtraction:
 763        lhs = self.lhs.resolve_expression(
 764            query, allow_joins, reuse, summarize, for_save
 765        )
 766        rhs = self.rhs.resolve_expression(
 767            query, allow_joins, reuse, summarize, for_save
 768        )
 769        if not isinstance(self, TemporalSubtraction):
 770            try:
 771                lhs_type = lhs.output_field.get_internal_type()
 772            except (AttributeError, FieldError):
 773                lhs_type = None
 774            try:
 775                rhs_type = rhs.output_field.get_internal_type()
 776            except (AttributeError, FieldError):
 777                rhs_type = None
 778            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 779            if (
 780                self.connector == self.SUB
 781                and lhs_type in datetime_fields
 782                and lhs_type == rhs_type
 783            ):
 784                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 785                    query,
 786                    allow_joins,
 787                    reuse,
 788                    summarize,
 789                    for_save,
 790                )
 791        c = self.copy()
 792        c.is_summary = summarize
 793        c.lhs = lhs
 794        c.rhs = rhs
 795        return c
 796
 797
 798class TemporalSubtraction(CombinedExpression):
 799    output_field = fields.DurationField()
 800
 801    def __init__(self, lhs: Any, rhs: Any):
 802        super().__init__(lhs, self.SUB, rhs)
 803
 804    def as_sql(
 805        self, compiler: SQLCompiler, connection: DatabaseConnection
 806    ) -> tuple[str, list[Any]]:
 807        lhs = compiler.compile(self.lhs)
 808        rhs = compiler.compile(self.rhs)
 809        sql, params = subtract_temporals(
 810            self.lhs.output_field.get_internal_type(), lhs, rhs
 811        )
 812        return sql, list(params)
 813
 814
 815@deconstructible(path="plain.postgres.F")
 816class F(Combinable):
 817    """An object capable of resolving references to existing query objects."""
 818
 819    def __init__(self, name: str):
 820        """
 821        Arguments:
 822         * name: the name of the field this expression references
 823        """
 824        self.name = name
 825
 826    def __repr__(self) -> str:
 827        return f"{self.__class__.__name__}({self.name})"
 828
 829    def resolve_expression(
 830        self,
 831        query: Any = None,
 832        allow_joins: bool = True,
 833        reuse: Any = None,
 834        summarize: bool = False,
 835        for_save: bool = False,
 836    ) -> Any:
 837        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 838
 839    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 840        return replacements.get(self, self)
 841
 842    def asc(self, **kwargs: Any) -> OrderBy:
 843        return OrderBy(self, **kwargs)
 844
 845    def desc(self, **kwargs: Any) -> OrderBy:
 846        return OrderBy(self, descending=True, **kwargs)
 847
 848    def __eq__(self, other: object) -> bool:
 849        if not isinstance(other, F):
 850            return NotImplemented
 851        return self.__class__ == other.__class__ and self.name == other.name
 852
 853    def __hash__(self) -> int:
 854        return hash(self.name)
 855
 856    def copy(self) -> Self:
 857        return copy.copy(self)
 858
 859
 860class ResolvedOuterRef(F):
 861    """
 862    An object that contains a reference to an outer query.
 863
 864    In this case, the reference to the outer query has been resolved because
 865    the inner query has been used as a subquery.
 866    """
 867
 868    contains_aggregate = False
 869    contains_over_clause = False
 870
 871    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 872        raise ValueError(
 873            "This queryset contains a reference to an outer query and may "
 874            "only be used in a subquery."
 875        )
 876
 877    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 878        col = super().resolve_expression(*args, **kwargs)
 879        if col.contains_over_clause:
 880            raise NotSupportedError(
 881                f"Referencing outer query window expression is not supported: "
 882                f"{self.name}."
 883            )
 884        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 885        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 886        # into accountย only many-to-many and one-to-many relationships.
 887        col.possibly_multivalued = LOOKUP_SEP in self.name
 888        return col
 889
 890    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 891        return self
 892
 893    def get_group_by_cols(self) -> list[Any]:
 894        return []
 895
 896
 897class OuterRef(F):
 898    contains_aggregate = False
 899
 900    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
 901        if isinstance(self.name, self.__class__):
 902            return self.name
 903        return ResolvedOuterRef(self.name)
 904
 905    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
 906        return self
 907
 908
 909@deconstructible(path="plain.postgres.Func")
 910class Func(Expression):
 911    """An SQL function call."""
 912
 913    function = None
 914    template = "%(function)s(%(expressions)s)"
 915    arg_joiner = ", "
 916    arity = None  # The number of arguments the function accepts.
 917
 918    def __init__(
 919        self, *expressions: Any, output_field: Field | None = None, **extra: Any
 920    ):
 921        if self.arity is not None and len(expressions) != self.arity:
 922            raise TypeError(
 923                "'{}' takes exactly {} {} ({} given)".format(
 924                    self.__class__.__name__,
 925                    self.arity,
 926                    "argument" if self.arity == 1 else "arguments",
 927                    len(expressions),
 928                )
 929            )
 930        super().__init__(output_field=output_field)
 931        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
 932        self.extra = extra
 933
 934    def __repr__(self) -> str:
 935        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
 936        extra = {**self.extra, **self._get_repr_options()}
 937        if extra:
 938            extra = ", ".join(
 939                str(key) + "=" + str(val) for key, val in sorted(extra.items())
 940            )
 941            return f"{self.__class__.__name__}({args}, {extra})"
 942        return f"{self.__class__.__name__}({args})"
 943
 944    def _get_repr_options(self) -> dict[str, Any]:
 945        """Return a dict of extra __init__() options to include in the repr."""
 946        return {}
 947
 948    def get_source_expressions(self) -> list[Any]:
 949        return self.source_expressions
 950
 951    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 952        self.source_expressions = list(exprs)
 953
 954    def resolve_expression(
 955        self,
 956        query: Any = None,
 957        allow_joins: bool = True,
 958        reuse: Any = None,
 959        summarize: bool = False,
 960        for_save: bool = False,
 961    ) -> Self:
 962        c = self.copy()
 963        c.is_summary = summarize
 964        for pos, arg in enumerate(c.source_expressions):
 965            c.source_expressions[pos] = arg.resolve_expression(
 966                query, allow_joins, reuse, summarize, for_save
 967            )
 968        return c
 969
 970    def as_sql(
 971        self,
 972        compiler: SQLCompiler,
 973        connection: DatabaseConnection,
 974        function: str | None = None,
 975        template: str | None = None,
 976        arg_joiner: str | None = None,
 977        **extra_context: Any,
 978    ) -> tuple[str, list[Any]]:
 979        sql_parts = []
 980        params = []
 981        for arg in self.source_expressions:
 982            try:
 983                arg_sql, arg_params = compiler.compile(arg)
 984            except EmptyResultSet:
 985                empty_result_set_value = getattr(
 986                    arg, "empty_result_set_value", NotImplemented
 987                )
 988                if empty_result_set_value is NotImplemented:
 989                    raise
 990                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
 991            except FullResultSet:
 992                arg_sql, arg_params = compiler.compile(Value(True))
 993            sql_parts.append(arg_sql)
 994            params.extend(arg_params)
 995        data = {**self.extra, **extra_context}
 996        # Use the first supplied value in this order: the parameter to this
 997        # method, a value supplied in __init__()'s **extra (the value in
 998        # `data`), or the value defined on the class.
 999        if function is not None:
1000            data["function"] = function
1001        else:
1002            data.setdefault("function", self.function)
1003        template = template or data.get("template", self.template)
1004        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
1005        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
1006        return template % data, params
1007
1008    def copy(self) -> Self:
1009        clone = super().copy()
1010        clone.source_expressions = self.source_expressions[:]
1011        clone.extra = self.extra.copy()
1012        return clone
1013
1014
@deconstructible(path="plain.postgres.Value")
class Value(Expression):
    """A literal value compiled to a query parameter (or to ``NULL``)."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the literal this expression stands for; it is sent as a
           query parameter rather than interpolated into the SQL text.

         * output_field: optional field instance describing the value's type,
           e.g. IntegerField() or CharField(). When omitted, a field is
           inferred from the Python type of ``value`` where possible.
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            # Use save-time or query-time preparation depending on for_save.
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)  # type: ignore[call-non-callable]
                return placeholder, [prepared]
        if prepared is None:
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        # A constant never needs to appear in GROUP BY.
        return []

    def _resolve_output_field(self) -> Field | None:
        # Order matters: bool before int (bool subclasses int) and datetime
        # before date (datetime subclasses date).
        candidates = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in candidates:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        # The literal itself stands in when its context matches no rows.
        return self.value
1096
1097
class RawSQL(Expression):
    """
    A verbatim SQL fragment plus its parameters.

    The fragment is wrapped in parentheses when compiled, and the parameters
    are returned as given. Without an explicit ``output_field`` a plain
    ``fields.Field()`` is used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=field)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The whole fragment is grouped on as a unit.
        return [self]
1117
1118
class Star(Expression):
    """The bare ``*`` selector; compiles to ``*`` with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1127
1128
class Col(Expression):
    """
    A column reference: ``target.column``, optionally qualified by a table alias.

    ``output_field`` defaults to ``target`` itself when not given.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        column = self.target.column
        parts = [self.alias, column] if self.alias else [column]
        quoted = [compiler.quote_name_unless_alias(part) for part in parts]
        return ".".join(quoted), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # Unqualified columns have nothing to relabel.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        # When the output field differs from the target, apply the output
        # field's converters followed by the target's.
        if self.target == self.output_field:
            return self.output_field.get_db_converters(connection)
        output_converters = self.output_field.get_db_converters(connection)
        return output_converters + self.target.get_db_converters(connection)
1172
1173
class Ref(Expression):
    """
    A reference, by alias, to an expression already present in the query.

    For example, with ``qs.annotate(sum_cost=Sum('cost'))`` the annotation
    can be referred to as ``Ref('sum_cost', <the Sum expression>)``. It
    compiles to the quoted alias, not to the SQL of ``source``.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        (self.source,) = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # Nothing to resolve: `source` was resolved before this reference to
        # its name was created.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # An alias name is unaffected by table relabeling.
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        quoted_alias = quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1218
1219
class ExpressionList(Func):
    """
    A bare, joiner-separated list of expressions with no surrounding function
    call. Useful as an argument to another expression, e.g. a PARTITION BY
    clause. At least one expression is required.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))
1238
1239
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a list of ordering terms.

    String terms with a leading ``-`` become descending ``OrderBy(F(name))``
    terms; all other terms are passed through unchanged. With no terms the
    clause compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # startswith() instead of expr[0]: an empty string would otherwise
        # raise IndexError here. It is passed through unchanged instead.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # No terms: emit nothing at all rather than a dangling "ORDER BY".
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        # Collect the GROUP BY columns contributed by every ordering term.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1265
1266
@deconstructible(path="plain.postgres.ExpressionWrapper")
class ExpressionWrapper(Expression):
    """
    Wrap an inner expression to supply extra context, chiefly an explicit
    ``output_field``. It compiles to exactly the inner expression's SQL.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        inner = self.expression
        if not isinstance(inner, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Ask a copy of the inner expression, retyped with this wrapper's
        # output_field, for its GROUP BY columns.
        retyped = inner.copy()
        retyped.output_field = self.output_field
        return retyped.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        # The wrapper contributes no SQL of its own.
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.expression})"
1300
1301
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression.

    Compiles to ``NOT <inner sql>``. An inner expression that matches nothing
    (``EmptyResultSet``) negates to a literal TRUE; one that matches
    everything (``FullResultSet``) negates to a literal FALSE.
    """

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Double negation cancels: hand back a copy of the wrapped expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # NOT (matches nothing) -> matches everything.
            return compiler.compile(Value(True))
        except FullResultSet:
            # NOT (matches everything) -> matches nothing. Letting the
            # exception propagate would make callers treat the negation
            # itself as "always true".
            return compiler.compile(Value(False))
        return f"NOT {sql}", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """Resolve the wrapped expression.

        Raises:
            TypeError: if the resolved inner expression is not conditional.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not getattr(resolved.expression, "conditional", False):
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT
        return sql, params
1340
1341
@deconstructible(path="plain.postgres.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a ``Case``.

    The condition may be a Q object, a conditional (boolean) expression, or
    keyword lookups. Lookups are folded into a Q, and are combined with a
    conditional expression when both are given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        leftover: dict[str, Any] | None = lookups or None
        if leftover:
            if condition is None:
                condition = Q(**leftover)
                leftover = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **leftover)
                leftover = None
        usable = condition is not None and getattr(condition, "conditional", False)
        if not usable or leftover:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # Only the result's type matters for the output type, not the condition's.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        clone = self.copy()
        clone.is_summary = summarize
        if isinstance(clone.condition, ResolvableExpression):
            # The condition is never resolved for saving.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        context = {**extra_context, "condition": condition_sql, "result": result_sql}
        rendered = (template or self.template) % context
        return rendered, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # When is only meaningful inside Case(); it contributes the GROUP BY
        # columns of its condition and result.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1435
1436
@deconstructible(path="plain.postgres.Case")
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be ``When`` branches; ``default`` supplies the
    ELSE value.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        if any(not isinstance(branch, When) for branch in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        branches = ", ".join(str(branch) for branch in self.cases)
        return f"CASE {branches}, ELSE {self.default!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        clone = self.copy()
        clone.is_summary = summarize
        # copy() gave the clone its own cases list, so rebinding is safe.
        clone.cases = [
            branch.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for branch in clone.cases
        ]
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self) -> Self:
        clone = super().copy()
        clone.cases = list(clone.cases)
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        if not self.cases:
            only_sql, only_params = compiler.compile(self.default)
            return only_sql, list(only_params)
        context = {**self.extra, **extra_context}
        default_sql, default_params = compiler.compile(self.default)
        branch_sql: list[str] = []
        branch_params: list[Any] = []
        for branch in self.cases:
            try:
                sql, params = compiler.compile(branch)
            except EmptyResultSet:
                # This branch can never match; leave it out.
                continue
            except FullResultSet:
                # This branch always matches: its result replaces the default
                # and no later branch is reachable.
                default_sql, default_params = compiler.compile(branch.result)
                break
            branch_sql.append(sql)
            branch_params.extend(params)
        if not branch_sql:
            return default_sql, list(default_params)
        context["cases"] = (case_joiner or self.case_joiner).join(branch_sql)
        context["default"] = default_sql
        branch_params.extend(default_params)
        template = template or context.get("template", self.template)
        rendered = template % context
        if self._output_field_or_none is not None:
            rendered = connection.unification_cast_sql(self.output_field) % rendered
        return rendered, branch_params

    def get_group_by_cols(self) -> list[Any]:
        # With no branches the CASE collapses to its default.
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1548
1549
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery built from a QuerySet or an ``sql.Query``.

    The query is cloned and flagged as a subquery. It may contain OuterRef()
    references that are resolved when the subquery is applied to an outer
    query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.postgres.sql.query import Query

        # Accept a Query as-is; for a QuerySet, use its underlying sql_query.
        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        return self.query.output_field

    def copy(self) -> Self:
        clone = super().copy()
        # Each copy gets its own query so later changes stay isolated.
        clone.query = clone.query.clone()
        return clone

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        context = {**self.extra, **extra_context}
        inner_sql, inner_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters (presumably the parentheses the
        # query adds because query.subquery is set) so the template supplies
        # its own wrapping.
        context["subquery"] = inner_sql[1:-1]
        template = template or context.get("template", self.template)
        return template % context, inner_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1619
1620
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery; always boolean-typed."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for its exists() form.
        existence_query = self.query.exists()
        self.query = existence_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Boolean expressions work directly in SELECT, so no wrapping is needed.
        return sql, params
1635
1636
@deconstructible(path="plain.postgres.OrderBy")
class OrderBy(Expression):
    """
    One ORDER BY term: an expression, a direction, and an optional
    ``NULLS FIRST`` / ``NULLS LAST`` modifier.

    ``nulls_first`` and ``nulls_last`` accept only True or None, and at most
    one of them may be True.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        template = template or self.template
        # Handle NULLS FIRST/LAST modifiers
        if self.nulls_last:
            template = f"{template} NULLS LAST"
        elif self.nulls_first:
            template = f"{template} NULLS FIRST"
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Repeat the params once per occurrence of the expression in the
        # template. Build a new sequence rather than multiplying in place:
        # compile() may return a list owned by the compiled expression (e.g.
        # RawSQL.as_sql returns self.params), which `*=` would mutate.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction and swap any NULLS modifier, in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # type: ignore[override]
        """Set this term to ascending order, in place."""
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        """Set this term to descending order, in place."""
        self.descending = True
1712
1713
class Window(Expression):
    """
    An expression evaluated over a window: ``<expression> OVER (...)``.

    ``expression`` must be marked ``window_compatible``. ``partition_by`` may
    be a single value or a tuple/list of values and is wrapped in an
    ``ExpressionList``. ``order_by`` may be a string, an expression, or a
    tuple/list of them and is wrapped in an ``OrderByList``. ``frame`` is an
    optional frame clause such as ``RowRange`` or ``ValueRange``.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Raises:
            ValueError: if ``expression`` is not window-compatible, or if
                ``order_by`` is not a string, an expression, or a list/tuple.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        """The window's output type is that of the wrapped expression."""
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        # Order must match the unpacking in set_source_expressions().
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render ``<expression> OVER (<partition> <ordering> <frame>)``.

        Each window component is compiled only when present; the returned
        params are the expression's params followed by the window's params,
        in clause order.
        """
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def __str__(self) -> str:
        # Join the present window components with single spaces; previously
        # they were concatenated with no separator, which ran them together.
        parts = []
        if self.partition_by:
            parts.append("PARTITION BY " + str(self.partition_by))
        if self.order_by:
            parts.append(str(self.order_by))
        if self.frame:
            parts.append(str(self.frame))
        return "{} OVER ({})".format(str(self.source_expression), " ".join(parts))

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # The wrapped expression is intentionally left out (see the note on
        # contains_aggregate); only partitioning and ordering columns count.
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1824
1825
class WindowFrame(Expression):
    """
    Model the frame clause of a window expression.

    Concrete frame types set ``frame_type`` (``RowRange`` uses ROWS,
    ``ValueRange`` uses RANGE) and implement ``window_frame_start_end()``,
    which converts the bounds to SQL for ``as_sql``. Everything else is
    shared here; validation is by no means intended to be complete.

    Bounds are integer offsets from the current row: negative means
    PRECEDING, zero means CURRENT ROW, positive means FOLLOWING. Which
    combinations are accepted in SQL is decided by the subclass's
    ``window_frame_start_end()``. Both bounds are optional: an omitted
    ``start`` is UNBOUNDED PRECEDING and an omitted ``end`` is UNBOUNDED
    FOLLOWING (the frame extends to the last row of the partition).
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Bounds are stored wrapped in Value so they act as source expressions.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Render the frame clause; bounds are inlined, so no params are returned."""
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        return []

    @staticmethod
    def _describe_bound(value: int | None, unbounded: str) -> str:
        """Render one frame bound for display, using ``unbounded`` when it is None."""
        if value is None:
            return unbounded
        if value < 0:
            return f"{abs(value)} {PRECEDING}"
        if value == 0:
            return CURRENT_ROW
        return f"{value} {FOLLOWING}"

    def __str__(self) -> str:
        # Each bound is described from its actual sign. Previously a positive
        # start was shown as UNBOUNDED PRECEDING and a negative end as
        # UNBOUNDED FOLLOWING, which misrepresented the frame.
        return self.template % {
            "frame_type": self.frame_type,
            "start": self._describe_bound(self.start.value, UNBOUNDED_PRECEDING),
            "end": self._describe_bound(self.end.value, UNBOUNDED_FOLLOWING),
        }

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the window frame start and end for the given connection."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end()")
1895
1896
class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # ``connection`` is unused; the ROWS bounds come from the dialect helper.
        bounds = window_frame_rows_start_end(start, end)
        return bounds
1904
1905
class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # ``connection`` is unused; the RANGE bounds come from the dialect helper.
        bounds = window_frame_range_start_end(start, end)
        return bounds