Plain is headed towards 1.0! Subscribe for development updates →

   1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from abc import ABC, abstractmethod
   8from collections import defaultdict
   9from decimal import Decimal
  10from functools import cached_property
  11from types import NoneType
  12from typing import TYPE_CHECKING, Any, Self
  13from uuid import UUID
  14
  15from plain.models import fields
  16from plain.models.constants import LOOKUP_SEP
  17from plain.models.db import (
  18    DatabaseError,
  19    NotSupportedError,
  20    db_connection,
  21)
  22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
  23from plain.models.query_utils import Q
  24from plain.utils.deconstruct import deconstructible
  25from plain.utils.hashable import make_hashable
  26
  27if TYPE_CHECKING:
  28    from collections.abc import Callable, Iterable, Sequence
  29
  30    from plain.models.backends.base.base import BaseDatabaseWrapper
  31    from plain.models.fields import Field
  32    from plain.models.lookups import Lookup, Transform
  33    from plain.models.query import QuerySet
  34    from plain.models.sql.compiler import SQLCompiler
  35    from plain.models.sql.query import Query
  36
  37
  38class Combinable:
  39    """
  40    Provide the ability to combine one or two objects with
  41    some connector. For example F('foo') + F('bar').
  42    """
  43
  44    # Arithmetic connectors
  45    ADD = "+"
  46    SUB = "-"
  47    MUL = "*"
  48    DIV = "/"
  49    POW = "^"
  50    # The following is a quoted % operator - it is quoted because it can be
  51    # used in strings that also have parameter substitution.
  52    MOD = "%%"
  53
  54    # Bitwise operators - note that these are generated by .bitand()
  55    # and .bitor(), the '&' and '|' are reserved for boolean operator
  56    # usage.
  57    BITAND = "&"
  58    BITOR = "|"
  59    BITLEFTSHIFT = "<<"
  60    BITRIGHTSHIFT = ">>"
  61    BITXOR = "#"
  62
  63    def _combine(
  64        self, other: Any, connector: str, reversed: bool
  65    ) -> CombinedExpression:
  66        if not hasattr(other, "resolve_expression"):
  67            # everything must be resolvable to an expression
  68            other = Value(other)
  69
  70        if reversed:
  71            return CombinedExpression(other, connector, self)
  72        return CombinedExpression(self, connector, other)
  73
  74    #############
  75    # OPERATORS #
  76    #############
  77
  78    def __neg__(self) -> CombinedExpression:
  79        return self._combine(-1, self.MUL, False)
  80
  81    def __add__(self, other: Any) -> CombinedExpression:
  82        return self._combine(other, self.ADD, False)
  83
  84    def __sub__(self, other: Any) -> CombinedExpression:
  85        return self._combine(other, self.SUB, False)
  86
  87    def __mul__(self, other: Any) -> CombinedExpression:
  88        return self._combine(other, self.MUL, False)
  89
  90    def __truediv__(self, other: Any) -> CombinedExpression:
  91        return self._combine(other, self.DIV, False)
  92
  93    def __mod__(self, other: Any) -> CombinedExpression:
  94        return self._combine(other, self.MOD, False)
  95
  96    def __pow__(self, other: Any) -> CombinedExpression:
  97        return self._combine(other, self.POW, False)
  98
  99    def __and__(self, other: Any) -> Q:
 100        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 101            return Q(self) & Q(other)  # type: ignore[unsupported-operator]
 102        raise NotImplementedError(
 103            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 104        )
 105
 106    def bitand(self, other: Any) -> CombinedExpression:
 107        return self._combine(other, self.BITAND, False)
 108
 109    def bitleftshift(self, other: Any) -> CombinedExpression:
 110        return self._combine(other, self.BITLEFTSHIFT, False)
 111
 112    def bitrightshift(self, other: Any) -> CombinedExpression:
 113        return self._combine(other, self.BITRIGHTSHIFT, False)
 114
 115    def __xor__(self, other: Any) -> Q:
 116        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 117            return Q(self) ^ Q(other)  # type: ignore[unsupported-operator]
 118        raise NotImplementedError(
 119            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 120        )
 121
 122    def bitxor(self, other: Any) -> CombinedExpression:
 123        return self._combine(other, self.BITXOR, False)
 124
 125    def __or__(self, other: Any) -> Q:
 126        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 127            return Q(self) | Q(other)  # type: ignore[unsupported-operator]
 128        raise NotImplementedError(
 129            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 130        )
 131
 132    def bitor(self, other: Any) -> CombinedExpression:
 133        return self._combine(other, self.BITOR, False)
 134
 135    def __radd__(self, other: Any) -> CombinedExpression:
 136        return self._combine(other, self.ADD, True)
 137
 138    def __rsub__(self, other: Any) -> CombinedExpression:
 139        return self._combine(other, self.SUB, True)
 140
 141    def __rmul__(self, other: Any) -> CombinedExpression:
 142        return self._combine(other, self.MUL, True)
 143
 144    def __rtruediv__(self, other: Any) -> CombinedExpression:
 145        return self._combine(other, self.DIV, True)
 146
 147    def __rmod__(self, other: Any) -> CombinedExpression:
 148        return self._combine(other, self.MOD, True)
 149
 150    def __rpow__(self, other: Any) -> CombinedExpression:
 151        return self._combine(other, self.POW, True)
 152
 153    def __rand__(self, other: Any) -> None:
 154        raise NotImplementedError(
 155            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 156        )
 157
 158    def __ror__(self, other: Any) -> None:
 159        raise NotImplementedError(
 160            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 161        )
 162
 163    def __rxor__(self, other: Any) -> None:
 164        raise NotImplementedError(
 165            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 166        )
 167
 168    def __invert__(self) -> NegatedExpression:
 169        return NegatedExpression(self)
 170
 171
 172class BaseExpression:
 173    """Base class for all query expressions."""
 174
 175    empty_result_set_value = NotImplemented
 176    # aggregate specific fields
 177    is_summary = False
 178    _output_field_resolved_to_none = False
 179    # Can the expression be used in a WHERE clause?
 180    filterable = True
 181    # Can the expression can be used as a source expression in Window?
 182    window_compatible = False
 183
 184    def __init__(self, output_field: Field | None = None):
 185        if output_field is not None:
 186            self.output_field = output_field
 187
 188    def __getstate__(self) -> dict[str, Any]:
 189        state = self.__dict__.copy()
 190        state.pop("convert_value", None)
 191        return state
 192
 193    def get_db_converters(
 194        self, connection: BaseDatabaseWrapper
 195    ) -> list[Callable[..., Any]]:
 196        converters = []
 197        if self.convert_value is not self._convert_value_noop:
 198            converters.append(self.convert_value)
 199        converters.extend(self.output_field.get_db_converters(connection))
 200        return converters
 201
 202    def get_source_expressions(self) -> list[Any]:
 203        return []
 204
 205    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 206        assert not exprs
 207
 208    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 209        return [
 210            arg
 211            if hasattr(arg, "resolve_expression")
 212            else (F(arg) if isinstance(arg, str) else Value(arg))
 213            for arg in expressions
 214        ]
 215
 216    def as_sql(
 217        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 218    ) -> tuple[str, Sequence[Any]]:
 219        """
 220        Responsible for returning a (sql, [params]) tuple to be included
 221        in the current query.
 222
 223        Different backends can provide their own implementation, by
 224        providing an `as_{vendor}` method and patching the Expression:
 225
 226        ```
 227        def override_as_sql(self, compiler, connection):
 228            # custom logic
 229            return super().as_sql(compiler, connection)
 230        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
 231        ```
 232
 233        Arguments:
 234         * compiler: the query compiler responsible for generating the query.
 235           Must have a compile method, returning a (sql, [params]) tuple.
 236           Calling compiler(value) will return a quoted `value`.
 237
 238         * connection: the database connection used for the current query.
 239
 240        Return: (sql, params)
 241          Where `sql` is a string containing ordered sql parameters to be
 242          replaced with the elements of the list `params`.
 243        """
 244        raise NotImplementedError("Subclasses must implement as_sql()")
 245
 246    @cached_property
 247    def contains_aggregate(self) -> bool:
 248        return any(
 249            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 250        )
 251
 252    @cached_property
 253    def contains_over_clause(self) -> bool:
 254        return any(
 255            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 256        )
 257
 258    @cached_property
 259    def contains_column_references(self) -> bool:
 260        return any(
 261            expr and expr.contains_column_references
 262            for expr in self.get_source_expressions()
 263        )
 264
 265    def resolve_expression(
 266        self,
 267        query: Any = None,
 268        allow_joins: bool = True,
 269        reuse: Any = None,
 270        summarize: bool = False,
 271        for_save: bool = False,
 272    ) -> BaseExpression:
 273        """
 274        Provide the chance to do any preprocessing or validation before being
 275        added to the query.
 276
 277        Arguments:
 278         * query: the backend query implementation
 279         * allow_joins: boolean allowing or denying use of joins
 280           in this query
 281         * reuse: a set of reusable joins for multijoins
 282         * summarize: a terminal aggregate clause
 283         * for_save: whether this expression about to be used in a save or update
 284
 285        Return: an Expression to be added to the query.
 286        """
 287        c = self.copy()
 288        c.is_summary = summarize
 289        c.set_source_expressions(
 290            [
 291                expr.resolve_expression(query, allow_joins, reuse, summarize)
 292                if expr
 293                else None
 294                for expr in c.get_source_expressions()
 295            ]
 296        )
 297        return c
 298
 299    @property
 300    def conditional(self) -> bool:
 301        output_field = getattr(self, "output_field", None)
 302        return isinstance(output_field, fields.BooleanField)
 303
 304    @property
 305    def field(self) -> Field:
 306        return self.output_field
 307
 308    @cached_property
 309    def output_field(self) -> Field:
 310        """Return the output type of this expressions."""
 311        output_field = self._resolve_output_field()
 312        if output_field is None:
 313            self._output_field_resolved_to_none = True
 314            raise FieldError("Cannot resolve expression type, unknown output_field")
 315        return output_field
 316
 317    @cached_property
 318    def _output_field_or_none(self) -> Field | None:
 319        """
 320        Return the output field of this expression, or None if
 321        _resolve_output_field() didn't return an output type.
 322        """
 323        try:
 324            return self.output_field
 325        except FieldError:
 326            if not self._output_field_resolved_to_none:
 327                raise
 328            return None
 329
 330    def _resolve_output_field(self) -> Field | None:
 331        """
 332        Attempt to infer the output type of the expression.
 333
 334        As a guess, if the output fields of all source fields match then simply
 335        infer the same type here.
 336
 337        If a source's output field resolves to None, exclude it from this check.
 338        If all sources are None, then an error is raised higher up the stack in
 339        the output_field property.
 340        """
 341        # This guess is mostly a bad idea, but there is quite a lot of code
 342        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 343        # deprecation path to fix it.
 344        sources_iter = (
 345            source for source in self.get_source_fields() if source is not None
 346        )
 347        for output_field in sources_iter:
 348            for source in sources_iter:
 349                if not isinstance(output_field, source.__class__):
 350                    raise FieldError(
 351                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 352                        "set output_field."
 353                    )
 354            return output_field
 355        return None
 356
 357    @staticmethod
 358    def _convert_value_noop(
 359        value: Any, expression: Any, connection: BaseDatabaseWrapper
 360    ) -> Any:
 361        return value
 362
 363    @cached_property
 364    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 365        """
 366        Expressions provide their own converters because users have the option
 367        of manually specifying the output_field which may be a different type
 368        from the one the database returns.
 369        """
 370        field = self.output_field
 371        internal_type = field.get_internal_type()
 372        if internal_type == "FloatField":
 373            return (
 374                lambda value, expression, connection: None
 375                if value is None
 376                else float(value)
 377            )
 378        elif internal_type.endswith("IntegerField"):
 379            return (
 380                lambda value, expression, connection: None
 381                if value is None
 382                else int(value)
 383            )
 384        elif internal_type == "DecimalField":
 385            return (
 386                lambda value, expression, connection: None
 387                if value is None
 388                else Decimal(value)
 389            )
 390        return self._convert_value_noop
 391
 392    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 393        return self.output_field.get_lookup(lookup)
 394
 395    def get_transform(self, name: str) -> type[Transform] | None:
 396        return self.output_field.get_transform(name)
 397
 398    def relabeled_clone(self, change_map: dict[str, str]) -> BaseExpression:
 399        clone = self.copy()
 400        clone.set_source_expressions(
 401            [
 402                e.relabeled_clone(change_map) if e is not None else None
 403                for e in self.get_source_expressions()
 404            ]
 405        )
 406        return clone
 407
 408    def replace_expressions(
 409        self, replacements: dict[BaseExpression, Any]
 410    ) -> BaseExpression:
 411        if replacement := replacements.get(self):
 412            return replacement
 413        clone = self.copy()
 414        source_expressions = clone.get_source_expressions()
 415        clone.set_source_expressions(
 416            [
 417                expr.replace_expressions(replacements) if expr else None
 418                for expr in source_expressions
 419            ]
 420        )
 421        return clone
 422
 423    def get_refs(self) -> set[str]:
 424        refs = set()
 425        for expr in self.get_source_expressions():
 426            refs |= expr.get_refs()
 427        return refs
 428
 429    def copy(self) -> Self:
 430        return copy.copy(self)
 431
 432    def prefix_references(self, prefix: str) -> Self:
 433        clone = self.copy()
 434        clone.set_source_expressions(
 435            [
 436                F(f"{prefix}{expr.name}")
 437                if isinstance(expr, F)
 438                else expr.prefix_references(prefix)
 439                for expr in self.get_source_expressions()
 440            ]
 441        )
 442        return clone
 443
 444    def get_group_by_cols(self) -> list[BaseExpression]:
 445        if not self.contains_aggregate:
 446            return [self]
 447        cols = []
 448        for source in self.get_source_expressions():
 449            cols.extend(source.get_group_by_cols())
 450        return cols
 451
 452    def get_source_fields(self) -> list[Field | None]:
 453        """Return the underlying field types used by this aggregate."""
 454        return [e._output_field_or_none for e in self.get_source_expressions()]
 455
 456    def asc(self, **kwargs: Any) -> OrderBy:
 457        return OrderBy(self, **kwargs)
 458
 459    def desc(self, **kwargs: Any) -> OrderBy:
 460        return OrderBy(self, descending=True, **kwargs)
 461
 462    def reverse_ordering(self) -> BaseExpression:
 463        return self
 464
 465    def flatten(self) -> Iterable[Any]:
 466        """
 467        Recursively yield this expression and all subexpressions, in
 468        depth-first order.
 469        """
 470        yield self
 471        for expr in self.get_source_expressions():
 472            if expr:
 473                if hasattr(expr, "flatten"):
 474                    yield from expr.flatten()
 475                else:
 476                    yield expr
 477
 478    def select_format(
 479        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 480    ) -> tuple[str, Sequence[Any]]:
 481        """
 482        Custom format for select clauses. For example, EXISTS expressions need
 483        to be wrapped in CASE WHEN on Oracle.
 484        """
 485        if output_field := getattr(self, "output_field", None):
 486            if select_format := getattr(output_field, "select_format", None):
 487                return select_format(compiler, sql, params)
 488        return sql, params
 489
 490
 491@deconstructible
 492class Expression(BaseExpression, Combinable):
 493    """An expression that can be combined with other expressions."""
 494
 495    @cached_property
 496    def identity(self) -> tuple[Any, ...]:
 497        constructor_signature = inspect.signature(self.__init__)
 498        args, kwargs = self._constructor_args  # type: ignore[attr-defined]
 499        signature = constructor_signature.bind_partial(*args, **kwargs)
 500        signature.apply_defaults()
 501        arguments = signature.arguments.items()
 502        identity: list[Any] = [self.__class__]
 503        for arg, value in arguments:
 504            if isinstance(value, fields.Field):
 505                if value.name and value.model:
 506                    value = (value.model.model_options.label, value.name)
 507                else:
 508                    value = type(value)
 509            else:
 510                value = make_hashable(value)
 511            identity.append((arg, value))
 512        return tuple(identity)
 513
 514    def __eq__(self, other: object) -> bool:
 515        if not isinstance(other, Expression):
 516            return NotImplemented
 517        return other.identity == self.identity
 518
 519    def __hash__(self) -> int:
 520        return hash(self.identity)
 521
 522
 523# Type inference for CombinedExpression.output_field.
 524# Missing items will result in FieldError, by design.
 525#
 526# The current approach for NULL is based on lowest common denominator behavior
 527# i.e. if one of the supported databases is raising an error (rather than
 528# return NULL) for `val <op> NULL`, then Plain raises FieldError.
 529
 530_connector_combinations = [
 531    # Numeric operations - operands of same type.
 532    {
 533        connector: [
 534            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 535            (fields.FloatField, fields.FloatField, fields.FloatField),
 536            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
 537        ]
 538        for connector in (
 539            Combinable.ADD,
 540            Combinable.SUB,
 541            Combinable.MUL,
 542            # Behavior for DIV with integer arguments follows Postgres/SQLite,
 543            # not MySQL/Oracle.
 544            Combinable.DIV,
 545            Combinable.MOD,
 546            Combinable.POW,
 547        )
 548    },
 549    # Numeric operations - operands of different type.
 550    {
 551        connector: [
 552            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
 553            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
 554            (fields.IntegerField, fields.FloatField, fields.FloatField),
 555            (fields.FloatField, fields.IntegerField, fields.FloatField),
 556        ]
 557        for connector in (
 558            Combinable.ADD,
 559            Combinable.SUB,
 560            Combinable.MUL,
 561            Combinable.DIV,
 562            Combinable.MOD,
 563        )
 564    },
 565    # Bitwise operators.
 566    {
 567        connector: [
 568            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
 569        ]
 570        for connector in (
 571            Combinable.BITAND,
 572            Combinable.BITOR,
 573            Combinable.BITLEFTSHIFT,
 574            Combinable.BITRIGHTSHIFT,
 575            Combinable.BITXOR,
 576        )
 577    },
 578    # Numeric with NULL.
 579    {
 580        connector: [
 581            (field_type, NoneType, field_type),
 582            (NoneType, field_type, field_type),
 583        ]
 584        for connector in (
 585            Combinable.ADD,
 586            Combinable.SUB,
 587            Combinable.MUL,
 588            Combinable.DIV,
 589            Combinable.MOD,
 590            Combinable.POW,
 591        )
 592        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
 593    },
 594    # Date/DateTimeField/DurationField/TimeField.
 595    {
 596        Combinable.ADD: [
 597            # Date/DateTimeField.
 598            (fields.DateField, fields.DurationField, fields.DateTimeField),
 599            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 600            (fields.DurationField, fields.DateField, fields.DateTimeField),
 601            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
 602            # DurationField.
 603            (fields.DurationField, fields.DurationField, fields.DurationField),
 604            # TimeField.
 605            (fields.TimeField, fields.DurationField, fields.TimeField),
 606            (fields.DurationField, fields.TimeField, fields.TimeField),
 607        ],
 608    },
 609    {
 610        Combinable.SUB: [
 611            # Date/DateTimeField.
 612            (fields.DateField, fields.DurationField, fields.DateTimeField),
 613            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
 614            (fields.DateField, fields.DateField, fields.DurationField),
 615            (fields.DateField, fields.DateTimeField, fields.DurationField),
 616            (fields.DateTimeField, fields.DateField, fields.DurationField),
 617            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
 618            # DurationField.
 619            (fields.DurationField, fields.DurationField, fields.DurationField),
 620            # TimeField.
 621            (fields.TimeField, fields.DurationField, fields.TimeField),
 622            (fields.TimeField, fields.TimeField, fields.DurationField),
 623        ],
 624    },
 625]
 626
 627_connector_combinators = defaultdict(list)
 628
 629
 630def register_combinable_fields(
 631    lhs: type[Field] | type[None],
 632    connector: str,
 633    rhs: type[Field] | type[None],
 634    result: type[Field],
 635) -> None:
 636    """
 637    Register combinable types:
 638        lhs <connector> rhs -> result
 639    e.g.
 640        register_combinable_fields(
 641            IntegerField, Combinable.ADD, FloatField, FloatField
 642        )
 643    """
 644    _connector_combinators[connector].append((lhs, rhs, result))
 645
 646
 647for d in _connector_combinations:
 648    for connector, field_types in d.items():
 649        for lhs, rhs, result in field_types:
 650            register_combinable_fields(lhs, connector, rhs, result)
 651
 652
 653@functools.lru_cache(maxsize=128)
 654def _resolve_combined_type(
 655    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 656) -> type[Field] | None:
 657    combinators = _connector_combinators.get(connector, ())
 658    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 659        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 660            rhs_type, combinator_rhs_type
 661        ):
 662            return combined_type
 663    return None
 664
 665
 666class SQLiteNumericMixin(Expression):
 667    """
 668    Some expressions with output_field=DecimalField() must be cast to
 669    numeric to be properly filtered.
 670    """
 671
 672    def as_sqlite(
 673        self,
 674        compiler: SQLCompiler,
 675        connection: BaseDatabaseWrapper,
 676        **extra_context: Any,
 677    ) -> tuple[str, Sequence[Any]]:
 678        sql, params = self.as_sql(compiler, connection, **extra_context)
 679        try:
 680            if self.output_field.get_internal_type() == "DecimalField":
 681                sql = f"CAST({sql} AS NUMERIC)"
 682        except FieldError:
 683            pass
 684        return sql, params
 685
 686
 687class CombinedExpression(SQLiteNumericMixin, Expression):
 688    def __init__(
 689        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 690    ):
 691        super().__init__(output_field=output_field)
 692        self.connector = connector
 693        self.lhs = lhs
 694        self.rhs = rhs
 695
 696    def __repr__(self) -> str:
 697        return f"<{self.__class__.__name__}: {self}>"
 698
 699    def __str__(self) -> str:
 700        return f"{self.lhs} {self.connector} {self.rhs}"
 701
 702    def get_source_expressions(self) -> list[Any]:
 703        return [self.lhs, self.rhs]
 704
 705    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 706        self.lhs, self.rhs = exprs
 707
 708    def _resolve_output_field(self) -> Field | None:
 709        # We avoid using super() here for reasons given in
 710        # Expression._resolve_output_field()
 711        combined_type = _resolve_combined_type(
 712            self.connector,
 713            type(self.lhs._output_field_or_none),
 714            type(self.rhs._output_field_or_none),
 715        )
 716        if combined_type is None:
 717            raise FieldError(
 718                f"Cannot infer type of {self.connector!r} expression involving these "
 719                f"types: {self.lhs.output_field.__class__.__name__}, "
 720                f"{self.rhs.output_field.__class__.__name__}. You must set "
 721                f"output_field."
 722            )
 723        return combined_type()
 724
 725    def as_sql(
 726        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 727    ) -> tuple[str, list[Any]]:
 728        expressions = []
 729        expression_params = []
 730        sql, params = compiler.compile(self.lhs)
 731        expressions.append(sql)
 732        expression_params.extend(params)
 733        sql, params = compiler.compile(self.rhs)
 734        expressions.append(sql)
 735        expression_params.extend(params)
 736        # order of precedence
 737        expression_wrapper = "(%s)"
 738        sql = connection.ops.combine_expression(self.connector, expressions)
 739        return expression_wrapper % sql, expression_params
 740
 741    def resolve_expression(
 742        self,
 743        query: Any = None,
 744        allow_joins: bool = True,
 745        reuse: Any = None,
 746        summarize: bool = False,
 747        for_save: bool = False,
 748    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
 749        lhs = self.lhs.resolve_expression(
 750            query, allow_joins, reuse, summarize, for_save
 751        )
 752        rhs = self.rhs.resolve_expression(
 753            query, allow_joins, reuse, summarize, for_save
 754        )
 755        if not isinstance(self, DurationExpression | TemporalSubtraction):
 756            try:
 757                lhs_type = lhs.output_field.get_internal_type()
 758            except (AttributeError, FieldError):
 759                lhs_type = None
 760            try:
 761                rhs_type = rhs.output_field.get_internal_type()
 762            except (AttributeError, FieldError):
 763                rhs_type = None
 764            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
 765                return DurationExpression(
 766                    self.lhs, self.connector, self.rhs
 767                ).resolve_expression(
 768                    query,
 769                    allow_joins,
 770                    reuse,
 771                    summarize,
 772                    for_save,
 773                )
 774            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 775            if (
 776                self.connector == self.SUB
 777                and lhs_type in datetime_fields
 778                and lhs_type == rhs_type
 779            ):
 780                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 781                    query,
 782                    allow_joins,
 783                    reuse,
 784                    summarize,
 785                    for_save,
 786                )
 787        c = self.copy()
 788        c.is_summary = summarize
 789        c.lhs = lhs
 790        c.rhs = rhs
 791        return c
 792
 793
 794class DurationExpression(CombinedExpression):
 795    def compile(
 796        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 797    ) -> tuple[str, Sequence[Any]]:
 798        try:
 799            output = side.output_field
 800        except FieldError:
 801            pass
 802        else:
 803            if output.get_internal_type() == "DurationField":
 804                sql, params = compiler.compile(side)
 805                return connection.ops.format_for_duration_arithmetic(sql), params
 806        return compiler.compile(side)
 807
 808    def as_sql(
 809        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 810    ) -> tuple[str, list[Any]]:
 811        if connection.features.has_native_duration_field:
 812            return super().as_sql(compiler, connection)  # type: ignore[misc]
 813        connection.ops.check_expression_support(self)
 814        expressions = []
 815        expression_params = []
 816        sql, params = self.compile(self.lhs, compiler, connection)
 817        expressions.append(sql)
 818        expression_params.extend(params)
 819        sql, params = self.compile(self.rhs, compiler, connection)
 820        expressions.append(sql)
 821        expression_params.extend(params)
 822        # order of precedence
 823        expression_wrapper = "(%s)"
 824        sql = connection.ops.combine_duration_expression(self.connector, expressions)
 825        return expression_wrapper % sql, expression_params
 826
 827    def as_sqlite(
 828        self,
 829        compiler: SQLCompiler,
 830        connection: BaseDatabaseWrapper,
 831        **extra_context: Any,
 832    ) -> tuple[str, Sequence[Any]]:
 833        sql, params = self.as_sql(compiler, connection, **extra_context)
 834        if self.connector in {Combinable.MUL, Combinable.DIV}:
 835            try:
 836                lhs_type = self.lhs.output_field.get_internal_type()
 837                rhs_type = self.rhs.output_field.get_internal_type()
 838            except (AttributeError, FieldError):
 839                pass
 840            else:
 841                allowed_fields = {
 842                    "DecimalField",
 843                    "DurationField",
 844                    "FloatField",
 845                    "IntegerField",
 846                }
 847                if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
 848                    raise DatabaseError(
 849                        f"Invalid arguments for operator {self.connector}."
 850                    )
 851        return sql, params
 852
 853
 854class TemporalSubtraction(CombinedExpression):
 855    output_field = fields.DurationField()
 856
 857    def __init__(self, lhs: Any, rhs: Any):
 858        super().__init__(lhs, self.SUB, rhs)
 859
 860    def as_sql(
 861        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 862    ) -> tuple[str, Sequence[Any]]:
 863        connection.ops.check_expression_support(self)
 864        lhs = compiler.compile(self.lhs)
 865        rhs = compiler.compile(self.rhs)
 866        return connection.ops.subtract_temporals(
 867            self.lhs.output_field.get_internal_type(), lhs, rhs
 868        )
 869
 870
 871@deconstructible(path="plain.models.F")
 872class F(Combinable):
 873    """An object capable of resolving references to existing query objects."""
 874
 875    def __init__(self, name: str):
 876        """
 877        Arguments:
 878         * name: the name of the field this expression references
 879        """
 880        self.name = name
 881
 882    def __repr__(self) -> str:
 883        return f"{self.__class__.__name__}({self.name})"
 884
 885    def resolve_expression(
 886        self,
 887        query: Any = None,
 888        allow_joins: bool = True,
 889        reuse: Any = None,
 890        summarize: bool = False,
 891        for_save: bool = False,
 892    ) -> Any:
 893        return query.resolve_ref(self.name, allow_joins, reuse, summarize)  # type: ignore[union-attr]
 894
 895    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 896        return replacements.get(self, self)
 897
 898    def asc(self, **kwargs: Any) -> OrderBy:
 899        return OrderBy(self, **kwargs)
 900
 901    def desc(self, **kwargs: Any) -> OrderBy:
 902        return OrderBy(self, descending=True, **kwargs)
 903
 904    def __eq__(self, other: object) -> bool:
 905        return (
 906            self.__class__ == other.__class__ and self.name == other.name  # type: ignore[attr-defined]
 907        )
 908
 909    def __hash__(self) -> int:
 910        return hash(self.name)
 911
 912    def copy(self) -> Self:
 913        return copy.copy(self)
 914
 915
 916class ResolvedOuterRef(F):
 917    """
 918    An object that contains a reference to an outer query.
 919
 920    In this case, the reference to the outer query has been resolved because
 921    the inner query has been used as a subquery.
 922    """
 923
 924    contains_aggregate = False
 925    contains_over_clause = False
 926
 927    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 928        raise ValueError(
 929            "This queryset contains a reference to an outer query and may "
 930            "only be used in a subquery."
 931        )
 932
 933    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 934        col = super().resolve_expression(*args, **kwargs)  # type: ignore[misc]
 935        if col.contains_over_clause:
 936            raise NotSupportedError(
 937                f"Referencing outer query window expression is not supported: "
 938                f"{self.name}."
 939            )
 940        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 941        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 942        # into account only many-to-many and one-to-many relationships.
 943        col.possibly_multivalued = LOOKUP_SEP in self.name
 944        return col
 945
 946    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 947        return self
 948
 949    def get_group_by_cols(self) -> list[Any]:
 950        return []
 951
 952
 953class OuterRef(F):
 954    contains_aggregate = False
 955
 956    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
 957        if isinstance(self.name, self.__class__):
 958            return self.name
 959        return ResolvedOuterRef(self.name)
 960
 961    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
 962        return self
 963
 964
 965@deconstructible(path="plain.models.Func")
 966class Func(SQLiteNumericMixin, Expression):
 967    """An SQL function call."""
 968
 969    function = None
 970    template = "%(function)s(%(expressions)s)"
 971    arg_joiner = ", "
 972    arity = None  # The number of arguments the function accepts.
 973
 974    def __init__(
 975        self, *expressions: Any, output_field: Field | None = None, **extra: Any
 976    ):
 977        if self.arity is not None and len(expressions) != self.arity:
 978            raise TypeError(
 979                "'{}' takes exactly {} {} ({} given)".format(
 980                    self.__class__.__name__,
 981                    self.arity,
 982                    "argument" if self.arity == 1 else "arguments",
 983                    len(expressions),
 984                )
 985            )
 986        super().__init__(output_field=output_field)
 987        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
 988        self.extra = extra
 989
 990    def __repr__(self) -> str:
 991        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
 992        extra = {**self.extra, **self._get_repr_options()}
 993        if extra:
 994            extra = ", ".join(
 995                str(key) + "=" + str(val) for key, val in sorted(extra.items())
 996            )
 997            return f"{self.__class__.__name__}({args}, {extra})"
 998        return f"{self.__class__.__name__}({args})"
 999
1000    def _get_repr_options(self) -> dict[str, Any]:
1001        """Return a dict of extra __init__() options to include in the repr."""
1002        return {}
1003
1004    def get_source_expressions(self) -> list[Any]:
1005        return self.source_expressions
1006
1007    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1008        self.source_expressions = list(exprs)
1009
1010    def resolve_expression(
1011        self,
1012        query: Any = None,
1013        allow_joins: bool = True,
1014        reuse: Any = None,
1015        summarize: bool = False,
1016        for_save: bool = False,
1017    ) -> Self:
1018        c = self.copy()
1019        c.is_summary = summarize
1020        for pos, arg in enumerate(c.source_expressions):
1021            c.source_expressions[pos] = arg.resolve_expression(
1022                query, allow_joins, reuse, summarize, for_save
1023            )
1024        return c
1025
1026    def as_sql(
1027        self,
1028        compiler: SQLCompiler,
1029        connection: BaseDatabaseWrapper,
1030        function: str | None = None,
1031        template: str | None = None,
1032        arg_joiner: str | None = None,
1033        **extra_context: Any,
1034    ) -> tuple[str, list[Any]]:
1035        connection.ops.check_expression_support(self)
1036        sql_parts = []
1037        params = []
1038        for arg in self.source_expressions:
1039            try:
1040                arg_sql, arg_params = compiler.compile(arg)
1041            except EmptyResultSet:
1042                empty_result_set_value = getattr(
1043                    arg, "empty_result_set_value", NotImplemented
1044                )
1045                if empty_result_set_value is NotImplemented:
1046                    raise
1047                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
1048            except FullResultSet:
1049                arg_sql, arg_params = compiler.compile(Value(True))
1050            sql_parts.append(arg_sql)
1051            params.extend(arg_params)
1052        data = {**self.extra, **extra_context}
1053        # Use the first supplied value in this order: the parameter to this
1054        # method, a value supplied in __init__()'s **extra (the value in
1055        # `data`), or the value defined on the class.
1056        if function is not None:
1057            data["function"] = function
1058        else:
1059            data.setdefault("function", self.function)
1060        template = template or data.get("template", self.template)
1061        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
1062        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
1063        return template % data, params
1064
1065    def copy(self) -> Self:
1066        copy = super().copy()
1067        copy.source_expressions = self.source_expressions[:]
1068        copy.extra = self.extra.copy()
1069        return copy  # type: ignore[return-value]
1070
1071
1072@deconstructible(path="plain.models.Value")
1073class Value(SQLiteNumericMixin, Expression):
1074    """Represent a wrapped value as a node within an expression."""
1075
1076    # Provide a default value for `for_save` in order to allow unresolved
1077    # instances to be compiled until a decision is taken in #25425.
1078    for_save = False
1079
1080    def __init__(self, value: Any, output_field: Field | None = None):
1081        """
1082        Arguments:
1083         * value: the value this expression represents. The value will be
1084           added into the sql parameter list and properly quoted.
1085
1086         * output_field: an instance of the model field type that this
1087           expression will return, such as IntegerField() or CharField().
1088        """
1089        super().__init__(output_field=output_field)
1090        self.value = value
1091
1092    def __repr__(self) -> str:
1093        return f"{self.__class__.__name__}({self.value!r})"
1094
1095    def as_sql(
1096        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1097    ) -> tuple[str, list[Any]]:
1098        connection.ops.check_expression_support(self)
1099        val = self.value
1100        output_field = self._output_field_or_none
1101        if output_field is not None:
1102            if self.for_save:
1103                val = output_field.get_db_prep_save(val, connection=connection)
1104            else:
1105                val = output_field.get_db_prep_value(val, connection=connection)
1106            if hasattr(output_field, "get_placeholder"):
1107                return output_field.get_placeholder(val, compiler, connection), [val]
1108        if val is None:
1109            # cx_Oracle does not always convert None to the appropriate
1110            # NULL type (like in case expressions using numbers), so we
1111            # use a literal SQL NULL
1112            return "NULL", []
1113        return "%s", [val]
1114
1115    def resolve_expression(
1116        self,
1117        query: Any = None,
1118        allow_joins: bool = True,
1119        reuse: Any = None,
1120        summarize: bool = False,
1121        for_save: bool = False,
1122    ) -> Value:
1123        c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)  # type: ignore[misc]
1124        c.for_save = for_save  # type: ignore[attr-defined]
1125        return c  # type: ignore[return-value]
1126
1127    def get_group_by_cols(self) -> list[Any]:
1128        return []
1129
1130    def _resolve_output_field(self) -> Field | None:
1131        if isinstance(self.value, str):
1132            return fields.CharField()
1133        if isinstance(self.value, bool):
1134            return fields.BooleanField()
1135        if isinstance(self.value, int):
1136            return fields.IntegerField()
1137        if isinstance(self.value, float):
1138            return fields.FloatField()
1139        if isinstance(self.value, datetime.datetime):
1140            return fields.DateTimeField()
1141        if isinstance(self.value, datetime.date):
1142            return fields.DateField()
1143        if isinstance(self.value, datetime.time):
1144            return fields.TimeField()
1145        if isinstance(self.value, datetime.timedelta):
1146            return fields.DurationField()
1147        if isinstance(self.value, Decimal):
1148            return fields.DecimalField()
1149        if isinstance(self.value, bytes):
1150            return fields.BinaryField()
1151        if isinstance(self.value, UUID):
1152            return fields.UUIDField()
1153
1154    @property
1155    def empty_result_set_value(self) -> Any:
1156        return self.value
1157
1158
1159class RawSQL(Expression):
1160    def __init__(
1161        self, sql: str, params: Sequence[Any], output_field: Field | None = None
1162    ):
1163        if output_field is None:
1164            output_field = fields.Field()
1165        self.sql, self.params = sql, params
1166        super().__init__(output_field=output_field)
1167
1168    def __repr__(self) -> str:
1169        return f"{self.__class__.__name__}({self.sql}, {self.params})"
1170
1171    def as_sql(
1172        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1173    ) -> tuple[str, Sequence[Any]]:
1174        return f"({self.sql})", self.params
1175
1176    def get_group_by_cols(self) -> list[RawSQL]:
1177        return [self]
1178
1179
1180class Star(Expression):
1181    def __repr__(self) -> str:
1182        return "'*'"
1183
1184    def as_sql(
1185        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1186    ) -> tuple[str, list[Any]]:
1187        return "*", []
1188
1189
1190class Col(Expression):
1191    contains_column_references = True
1192    possibly_multivalued = False
1193
1194    def __init__(
1195        self, alias: str | None, target: Any, output_field: Field | None = None
1196    ):
1197        if output_field is None:
1198            output_field = target
1199        super().__init__(output_field=output_field)
1200        self.alias, self.target = alias, target
1201
1202    def __repr__(self) -> str:
1203        alias, target = self.alias, self.target
1204        identifiers = (alias, str(target)) if alias else (str(target),)
1205        return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))
1206
1207    def as_sql(
1208        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1209    ) -> tuple[str, list[Any]]:
1210        alias, column = self.alias, self.target.column
1211        identifiers = (alias, column) if alias else (column,)
1212        sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
1213        return sql, []
1214
1215    def relabeled_clone(self, relabels: dict[str, str]) -> Col:
1216        if self.alias is None:
1217            return self
1218        return self.__class__(
1219            relabels.get(self.alias, self.alias), self.target, self.output_field
1220        )
1221
1222    def get_group_by_cols(self) -> list[Col]:
1223        return [self]
1224
1225    def get_db_converters(
1226        self, connection: BaseDatabaseWrapper
1227    ) -> list[Callable[..., Any]]:
1228        if self.target == self.output_field:
1229            return self.output_field.get_db_converters(connection)
1230        return self.output_field.get_db_converters(
1231            connection
1232        ) + self.target.get_db_converters(connection)
1233
1234
1235class Ref(Expression):
1236    """
1237    Reference to column alias of the query. For example, Ref('sum_cost') in
1238    qs.annotate(sum_cost=Sum('cost')) query.
1239    """
1240
1241    def __init__(self, refs: str, source: Any):
1242        super().__init__()
1243        self.refs, self.source = refs, source
1244
1245    def __repr__(self) -> str:
1246        return f"{self.__class__.__name__}({self.refs}, {self.source})"
1247
1248    def get_source_expressions(self) -> list[Any]:
1249        return [self.source]
1250
1251    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1252        (self.source,) = exprs
1253
1254    def resolve_expression(
1255        self,
1256        query: Any = None,
1257        allow_joins: bool = True,
1258        reuse: Any = None,
1259        summarize: bool = False,
1260        for_save: bool = False,
1261    ) -> Ref:
1262        # The sub-expression `source` has already been resolved, as this is
1263        # just a reference to the name of `source`.
1264        return self
1265
1266    def get_refs(self) -> set[str]:
1267        return {self.refs}
1268
1269    def relabeled_clone(self, relabels: dict[str, str]) -> Ref:
1270        return self
1271
1272    def as_sql(
1273        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1274    ) -> tuple[str, list[Any]]:
1275        return connection.ops.quote_name(self.refs), []
1276
1277    def get_group_by_cols(self) -> list[Ref]:
1278        return [self]
1279
1280
1281class ExpressionList(Func):
1282    """
1283    An expression containing multiple expressions. Can be used to provide a
1284    list of expressions as an argument to another expression, like a partition
1285    clause.
1286    """
1287
1288    template = "%(expressions)s"
1289
1290    def __init__(self, *expressions: Any, **extra: Any):
1291        if not expressions:
1292            raise ValueError(
1293                f"{self.__class__.__name__} requires at least one expression."
1294            )
1295        super().__init__(*expressions, **extra)
1296
1297    def __str__(self) -> str:
1298        return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
1299
1300    def as_sqlite(
1301        self,
1302        compiler: SQLCompiler,
1303        connection: BaseDatabaseWrapper,
1304        **extra_context: Any,
1305    ) -> tuple[str, Sequence[Any]]:
1306        # Casting to numeric is unnecessary.
1307        return self.as_sql(compiler, connection, **extra_context)
1308
1309
1310class OrderByList(Func):
1311    template = "ORDER BY %(expressions)s"
1312
1313    def __init__(self, *expressions: Any, **extra: Any):
1314        expressions_tuple = tuple(
1315            (
1316                OrderBy(F(expr[1:]), descending=True)
1317                if isinstance(expr, str) and expr[0] == "-"
1318                else expr
1319            )
1320            for expr in expressions
1321        )
1322        super().__init__(*expressions_tuple, **extra)
1323
1324    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, tuple[Any, ...]]:
1325        if not self.source_expressions:
1326            return "", ()
1327        return super().as_sql(*args, **kwargs)  # type: ignore[misc]
1328
1329    def get_group_by_cols(self) -> list[Any]:
1330        group_by_cols = []
1331        for order_by in self.get_source_expressions():
1332            group_by_cols.extend(order_by.get_group_by_cols())
1333        return group_by_cols
1334
1335
1336@deconstructible(path="plain.models.ExpressionWrapper")
1337class ExpressionWrapper(SQLiteNumericMixin, Expression):
1338    """
1339    An expression that can wrap another expression so that it can provide
1340    extra context to the inner expression, such as the output_field.
1341    """
1342
1343    def __init__(self, expression: Any, output_field: Field):
1344        super().__init__(output_field=output_field)
1345        self.expression = expression
1346
1347    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1348        self.expression = exprs[0]
1349
1350    def get_source_expressions(self) -> list[Any]:
1351        return [self.expression]
1352
1353    def get_group_by_cols(self) -> list[Any]:
1354        if isinstance(self.expression, Expression):
1355            expression = self.expression.copy()
1356            expression.output_field = self.output_field
1357            return expression.get_group_by_cols()
1358        # For non-expressions e.g. an SQL WHERE clause, the entire
1359        # `expression` must be included in the GROUP BY clause.
1360        return super().get_group_by_cols()  # type: ignore[misc]
1361
1362    def as_sql(
1363        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1364    ) -> tuple[str, Sequence[Any]]:
1365        return compiler.compile(self.expression)
1366
1367    def __repr__(self) -> str:
1368        return f"{self.__class__.__name__}({self.expression})"
1369
1370
1371class NegatedExpression(ExpressionWrapper):
1372    """The logical negation of a conditional expression."""
1373
1374    def __init__(self, expression: Any):
1375        super().__init__(expression, output_field=fields.BooleanField())
1376
1377    def __invert__(self) -> Any:
1378        return self.expression.copy()
1379
1380    def as_sql(
1381        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1382    ) -> tuple[str, Sequence[Any]]:
1383        try:
1384            sql, params = super().as_sql(compiler, connection)
1385        except EmptyResultSet:
1386            features = compiler.connection.features
1387            if not features.supports_boolean_expr_in_select_clause:
1388                return "1=1", ()
1389            return compiler.compile(Value(True))
1390        ops = compiler.connection.ops
1391        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
1392        # to be compared to another expression unless they're wrapped in a CASE
1393        # WHEN.
1394        if not ops.conditional_expression_supported_in_where_clause(self.expression):
1395            return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
1396        return f"NOT {sql}", params
1397
1398    def resolve_expression(
1399        self,
1400        query: Any = None,
1401        allow_joins: bool = True,
1402        reuse: Any = None,
1403        summarize: bool = False,
1404        for_save: bool = False,
1405    ) -> NegatedExpression:
1406        resolved = super().resolve_expression(
1407            query, allow_joins, reuse, summarize, for_save
1408        )
1409        if not getattr(resolved.expression, "conditional", False):  # type: ignore[attr-defined]
1410            raise TypeError("Cannot negate non-conditional expressions.")
1411        return resolved  # type: ignore[return-value]
1412
1413    def select_format(
1414        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
1415    ) -> tuple[str, Sequence[Any]]:
1416        # Wrap boolean expressions with a CASE WHEN expression if a database
1417        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
1418        # GROUP BY list.
1419        expression_supported_in_where_clause = (
1420            compiler.connection.ops.conditional_expression_supported_in_where_clause
1421        )
1422        if (
1423            not compiler.connection.features.supports_boolean_expr_in_select_clause
1424            # Avoid double wrapping.
1425            and expression_supported_in_where_clause(self.expression)
1426        ):
1427            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
1428        return sql, params
1429
1430
1431@deconstructible(path="plain.models.When")
1432class When(Expression):
1433    template = "WHEN %(condition)s THEN %(result)s"
1434    # This isn't a complete conditional expression, must be used in Case().
1435    conditional = False
1436
1437    def __init__(
1438        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
1439    ):
1440        lookups_dict: dict[str, Any] | None = lookups or None
1441        if lookups_dict:
1442            if condition is None:
1443                condition, lookups_dict = Q(**lookups_dict), None
1444            elif getattr(condition, "conditional", False):
1445                condition, lookups_dict = Q(condition, **lookups_dict), None
1446        if (
1447            condition is None
1448            or not getattr(condition, "conditional", False)
1449            or lookups_dict
1450        ):
1451            raise TypeError(
1452                "When() supports a Q object, a boolean expression, or lookups "
1453                "as a condition."
1454            )
1455        if isinstance(condition, Q) and not condition:
1456            raise ValueError("An empty Q() can't be used as a When() condition.")
1457        super().__init__(output_field=None)
1458        self.condition = condition
1459        self.result = self._parse_expressions(then)[0]
1460
1461    def __str__(self) -> str:
1462        return f"WHEN {self.condition!r} THEN {self.result!r}"
1463
1464    def __repr__(self) -> str:
1465        return f"<{self.__class__.__name__}: {self}>"
1466
1467    def get_source_expressions(self) -> list[Any]:
1468        return [self.condition, self.result]
1469
1470    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1471        self.condition, self.result = exprs
1472
1473    def get_source_fields(self) -> list[Field | None]:
1474        # We're only interested in the fields of the result expressions.
1475        return [self.result._output_field_or_none]
1476
1477    def resolve_expression(
1478        self,
1479        query: Any = None,
1480        allow_joins: bool = True,
1481        reuse: Any = None,
1482        summarize: bool = False,
1483        for_save: bool = False,
1484    ) -> When:
1485        c = self.copy()
1486        c.is_summary = summarize
1487        if hasattr(c.condition, "resolve_expression"):
1488            c.condition = c.condition.resolve_expression(
1489                query, allow_joins, reuse, summarize, False
1490            )
1491        c.result = c.result.resolve_expression(
1492            query, allow_joins, reuse, summarize, for_save
1493        )
1494        return c
1495
1496    def as_sql(
1497        self,
1498        compiler: SQLCompiler,
1499        connection: BaseDatabaseWrapper,
1500        template: str | None = None,
1501        **extra_context: Any,
1502    ) -> tuple[str, tuple[Any, ...]]:
1503        connection.ops.check_expression_support(self)
1504        template_params = extra_context
1505        sql_params = []
1506        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
1507        condition_sql, condition_params = compiler.compile(self.condition)  # type: ignore[arg-type]
1508        template_params["condition"] = condition_sql
1509        result_sql, result_params = compiler.compile(self.result)
1510        template_params["result"] = result_sql
1511        template = template or self.template
1512        return template % template_params, (
1513            *sql_params,
1514            *condition_params,
1515            *result_params,
1516        )
1517
1518    def get_group_by_cols(self) -> list[Any]:
1519        # This is not a complete expression and cannot be used in GROUP BY.
1520        cols = []
1521        for source in self.get_source_expressions():
1522            cols.extend(source.get_group_by_cols())
1523        return cols
1524
1525
1526@deconstructible(path="plain.models.Case")
1527class Case(SQLiteNumericMixin, Expression):
1528    """
1529    An SQL searched CASE expression:
1530
1531        CASE
1532            WHEN n > 0
1533                THEN 'positive'
1534            WHEN n < 0
1535                THEN 'negative'
1536            ELSE 'zero'
1537        END
1538    """
1539
1540    template = "CASE %(cases)s ELSE %(default)s END"
1541    case_joiner = " "
1542
1543    def __init__(
1544        self,
1545        *cases: When,
1546        default: Any = None,
1547        output_field: Field | None = None,
1548        **extra: Any,
1549    ):
1550        if not all(isinstance(case, When) for case in cases):
1551            raise TypeError("Positional arguments must all be When objects.")
1552        super().__init__(output_field)
1553        self.cases = list(cases)
1554        self.default = self._parse_expressions(default)[0]
1555        self.extra = extra
1556
1557    def __str__(self) -> str:
1558        return "CASE {}, ELSE {!r}".format(
1559            ", ".join(str(c) for c in self.cases),
1560            self.default,
1561        )
1562
1563    def __repr__(self) -> str:
1564        return f"<{self.__class__.__name__}: {self}>"
1565
1566    def get_source_expressions(self) -> list[Any]:
1567        return self.cases + [self.default]
1568
1569    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
1570        *self.cases, self.default = exprs
1571
1572    def resolve_expression(
1573        self,
1574        query: Any = None,
1575        allow_joins: bool = True,
1576        reuse: Any = None,
1577        summarize: bool = False,
1578        for_save: bool = False,
1579    ) -> Case:
1580        c = self.copy()
1581        c.is_summary = summarize
1582        for pos, case in enumerate(c.cases):
1583            c.cases[pos] = case.resolve_expression(
1584                query, allow_joins, reuse, summarize, for_save
1585            )
1586        c.default = c.default.resolve_expression(
1587            query, allow_joins, reuse, summarize, for_save
1588        )
1589        return c
1590
1591    def copy(self) -> Self:
1592        c = super().copy()  # type: ignore[misc]
1593        c.cases = c.cases[:]  # type: ignore[attr-defined]
1594        return c  # type: ignore[return-value]
1595
1596    def as_sql(
1597        self,
1598        compiler: SQLCompiler,
1599        connection: BaseDatabaseWrapper,
1600        template: str | None = None,
1601        case_joiner: str | None = None,
1602        **extra_context: Any,
1603    ) -> tuple[str, list[Any]]:
1604        connection.ops.check_expression_support(self)
1605        if not self.cases:
1606            sql, params = compiler.compile(self.default)
1607            return sql, list(params)
1608        template_params = {**self.extra, **extra_context}
1609        case_parts = []
1610        sql_params = []
1611        default_sql, default_params = compiler.compile(self.default)
1612        for case in self.cases:
1613            try:
1614                case_sql, case_params = compiler.compile(case)
1615            except EmptyResultSet:
1616                continue
1617            except FullResultSet:
1618                default_sql, default_params = compiler.compile(case.result)
1619                break
1620            case_parts.append(case_sql)
1621            sql_params.extend(case_params)
1622        if not case_parts:
1623            return default_sql, list(default_params)
1624        case_joiner = case_joiner or self.case_joiner
1625        template_params["cases"] = case_joiner.join(case_parts)
1626        template_params["default"] = default_sql
1627        sql_params.extend(default_params)
1628        template = template or template_params.get("template", self.template)
1629        sql = template % template_params
1630        if self._output_field_or_none is not None:
1631            sql = connection.ops.unification_cast_sql(self.output_field) % sql
1632        return sql, sql_params
1633
1634    def get_group_by_cols(self) -> list[Any]:
1635        if not self.cases:
1636            return self.default.get_group_by_cols()
1637        return super().get_group_by_cols()  # type: ignore[misc]
1638
1639
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Either a ``QuerySet`` or an ``sql.Query`` may be passed. The underlying
    query is cloned before being flagged as a subquery.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Imported lazily to avoid a circular import.
        from plain.models.sql.query import Query

        # A QuerySet carries its sql.Query on ``sql_query``; a Query is used
        # directly.
        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # Delegate to the wrapped query's output field.
        return self.query.output_field

    def copy(self) -> Self:
        duplicate = super().copy()
        # Clone the inner query so the copy can be altered independently.
        duplicate.query = duplicate.query.clone()  # type: ignore[attr-defined]
        return duplicate  # type: ignore[return-value]

    @property
    def external_aliases(self) -> dict[str, bool]:  # type: ignore[override]
        return self.query.external_aliases  # type: ignore[return-value]

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        inner_sql, params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the inner SQL. These are
        # presumably the parentheses the query emits as a subquery; the
        # template supplies its own wrapping.
        context["subquery"] = inner_sql[1:-1]
        sql_template = template or context.get("template", self.template)
        return sql_template % context, params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1710
1711
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, typed as a boolean."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for its ``exists()`` form.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Some backends (e.g. Oracle) cannot place a bare boolean expression
        # in a SELECT or GROUP BY list; there, the EXISTS() is turned into a
        # 1/0 value with CASE WHEN.
        features = compiler.connection.features
        if features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1730
1731
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    Wrap an expression with an ASC/DESC direction for use in an ORDER BY.

    ``nulls_first`` / ``nulls_last`` request explicit placement of NULL
    values. At most one may be set, and only to ``True``; ``None`` adds no
    NULL-ordering clause, leaving the database's default placement.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Raises:
            ValueError: if both ``nulls_first`` and ``nulls_last`` are truthy,
                if either is ``False``, or if ``expression`` has no
                ``resolve_expression`` attribute.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``<expr> ASC|DESC`` plus any requested NULL ordering.

        When the backend supports a ``NULLS FIRST/LAST`` modifier, it is
        appended. Otherwise, NULL placement is emulated by prefixing an
        ``<expr> IS [NOT] NULL`` sort key. The prefix is skipped when the
        ``order_by_nulls_first`` feature shows the default order for this
        direction already places NULLs as requested. The expression can then
        appear more than once in the SQL, so its parameters are repeated once
        per occurrence.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Repeat the params once per occurrence of the expression. Build a new
        # sequence rather than using ``*=``: on a list, ``*=`` works in place.
        # It would then grow or clear a params list that the compiled
        # expression may still own and reuse on later compiles.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction in place, swapping any explicit NULL placement."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:
        """Set this ordering to ascending, in place."""
        self.descending = False

    def desc(self) -> None:
        """Set this ordering to descending, in place."""
        self.descending = True
1817
1818
class Window(SQLiteNumericMixin, Expression):
    """
    An expression evaluated ``OVER (...)`` a window built from optional
    PARTITION BY, ORDER BY and frame clauses.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Raises:
            ValueError: if ``expression`` is not window-compatible, or if
                ``order_by`` is not a string, an expression, or a list/tuple
                of them.
        """
        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )
        self.frame = frame

        # A single partition term is treated like a one-element sequence.
        if partition_by is not None and not isinstance(partition_by, tuple | list):
            partition_by = (partition_by,)
        self.partition_by = (
            None if partition_by is None else ExpressionList(*partition_by)
        )

        if order_by is None:
            self.order_by = None
        elif isinstance(order_by, list | tuple):
            self.order_by = OrderByList(*order_by)
        elif isinstance(order_by, BaseExpression | str):
            self.order_by = OrderByList(order_by)
        else:
            raise ValueError(
                "Window.order_by must be either a string reference to a "
                "field, an expression, or a list or tuple of them."
            )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        # The window's type is that of the wrapped expression.
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, expr_params = compiler.compile(self.source_expression)

        # Collect each present clause in SQL order: PARTITION BY, ORDER BY, frame.
        clauses: list[str] = []
        clause_params: list[Any] = []
        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)
        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            clauses.append(order_sql)
            clause_params.extend(order_params)
        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        window = " ".join(clauses).strip()
        sql = (template or self.template) % {"expression": expr_sql, "window": window}
        return sql, (*expr_params, *clause_params)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        if not isinstance(self.output_field, fields.DecimalField):
            return self.as_sql(compiler, connection)
        # Casting to numeric must be outside of the window expression. So
        # compile a copy whose inner expression is typed as float, and hand
        # it to the next ``as_sqlite`` in the MRO. That is presumably
        # SQLiteNumericMixin, which applies the outer cast.
        window_copy = self.copy()
        sources = window_copy.get_source_expressions()
        sources[0].output_field = fields.FloatField()
        window_copy.set_source_expressions(sources)
        return super(Window, window_copy).as_sqlite(compiler, connection)

    def __str__(self) -> str:
        source = str(self.source_expression)
        partition = f"PARTITION BY {self.partition_by}" if self.partition_by else ""
        ordering = str(self.order_by or "")
        frame = str(self.frame or "")
        return f"{source} OVER ({partition}{ordering}{frame})"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # Only the partition and ordering terms contribute grouping columns.
        cols: list[Any] = []
        if self.partition_by:
            cols += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            cols += self.order_by.get_group_by_cols()
        return cols
1944
1945
class WindowFrame(Expression, ABC):
    """
    Model the frame clause of a window expression.

    Concrete subclasses set ``frame_type`` and implement
    ``window_frame_start_end()`` to render the bounds for a given backend.
    Bounds are integers or ``None``. As ``__str__`` renders them:
    - a negative ``start`` is that many PRECEDING, ``0`` is CURRENT ROW,
      and ``None`` is UNBOUNDED PRECEDING;
    - a positive ``end`` is that many FOLLOWING, ``0`` is CURRENT ROW, and
      ``None`` is UNBOUNDED FOLLOWING.
    Providing an end is optional; omitting it leaves the frame open to
    UNBOUNDED FOLLOWING, i.e. the end of the partition.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # The bounds are stored as Value expressions so they appear among the
        # source expressions.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        sql = self.template % {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        # Bounds are rendered directly into the SQL, so there are no params.
        return sql, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = f"{abs(lower)} {ops.PRECEDING}"
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = f"{upper} {ops.FOLLOWING}"
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the backend-specific SQL for the (start, end) frame bounds."""
2015
2016
class RowRange(WindowFrame):
    """Window frame rendered with the ``ROWS`` frame type."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # Bound rendering for ROWS frames is delegated to the backend ops.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
2024
2025
class ValueRange(WindowFrame):
    """Window frame rendered with the ``RANGE`` frame type."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # Bound rendering for RANGE frames is delegated to the backend ops.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)