1from __future__ import annotations
   2
   3import copy
   4import datetime
   5import functools
   6import inspect
   7from abc import ABC, abstractmethod
   8from collections import defaultdict
   9from decimal import Decimal
  10from functools import cached_property
  11from types import NoneType
  12from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
  13from uuid import UUID
  14
  15from plain.models import fields
  16from plain.models.constants import LOOKUP_SEP
  17from plain.models.db import (
  18    DatabaseError,
  19    NotSupportedError,
  20    db_connection,
  21)
  22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
  23from plain.models.query_utils import Q
  24from plain.utils.deconstruct import deconstructible
  25from plain.utils.hashable import make_hashable
  26
  27if TYPE_CHECKING:
  28    from collections.abc import Callable, Iterable, Sequence
  29
  30    from plain.models.backends.base.base import BaseDatabaseWrapper
  31    from plain.models.fields import Field
  32    from plain.models.lookups import Lookup, Transform
  33    from plain.models.query import QuerySet
  34    from plain.models.sql.compiler import SQLCompilable, SQLCompiler
  35    from plain.models.sql.query import Query
  36
# Public names exported by ``from <this module> import *``; everything else in
# the module (helpers, mixins, protocols, CombinedExpression, ...) stays
# importable by name but is not part of the star-export.
__all__ = [
    # Core expression classes
    "F",
    "Value",
    "Case",
    "When",
    "Subquery",
    "Exists",
    "OuterRef",
    "Window",
    "ExpressionWrapper",
    "RawSQL",
    "OrderBy",
    # Base classes (for extension)
    "Func",
    "Expression",
    "Combinable",
    # Window frame specs
    "RowRange",
    "ValueRange",
]
  58
  59
  60@runtime_checkable
  61class ResolvableExpression(Protocol):
  62    """Protocol for expressions that can be resolved in query context."""
  63
  64    def resolve_expression(
  65        self,
  66        query: Any = None,
  67        allow_joins: bool = True,
  68        reuse: Any = None,
  69        summarize: bool = False,
  70        for_save: bool = False,
  71    ) -> Any: ...
  72
  73
  74@runtime_checkable
  75class ReplaceableExpression(Protocol):
  76    """Protocol for expressions that support expression replacement."""
  77
  78    def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
  79
  80
  81class Combinable:
  82    """
  83    Provide the ability to combine one or two objects with
  84    some connector. For example F('foo') + F('bar').
  85    """
  86
  87    # Arithmetic connectors
  88    ADD = "+"
  89    SUB = "-"
  90    MUL = "*"
  91    DIV = "/"
  92    POW = "^"
  93    # The following is a quoted % operator - it is quoted because it can be
  94    # used in strings that also have parameter substitution.
  95    MOD = "%%"
  96
  97    # Bitwise operators - note that these are generated by .bitand()
  98    # and .bitor(), the '&' and '|' are reserved for boolean operator
  99    # usage.
 100    BITAND = "&"
 101    BITOR = "|"
 102    BITLEFTSHIFT = "<<"
 103    BITRIGHTSHIFT = ">>"
 104    BITXOR = "#"
 105
 106    def _combine(
 107        self, other: Any, connector: str, reversed: bool
 108    ) -> CombinedExpression:
 109        if not isinstance(other, ResolvableExpression):
 110            # everything must be resolvable to an expression
 111            other = Value(other)
 112
 113        if reversed:
 114            return CombinedExpression(other, connector, self)
 115        return CombinedExpression(self, connector, other)
 116
 117    #############
 118    # OPERATORS #
 119    #############
 120
 121    def __neg__(self) -> CombinedExpression:
 122        return self._combine(-1, self.MUL, False)
 123
 124    def __add__(self, other: Any) -> CombinedExpression:
 125        return self._combine(other, self.ADD, False)
 126
 127    def __sub__(self, other: Any) -> CombinedExpression:
 128        return self._combine(other, self.SUB, False)
 129
 130    def __mul__(self, other: Any) -> CombinedExpression:
 131        return self._combine(other, self.MUL, False)
 132
 133    def __truediv__(self, other: Any) -> CombinedExpression:
 134        return self._combine(other, self.DIV, False)
 135
 136    def __mod__(self, other: Any) -> CombinedExpression:
 137        return self._combine(other, self.MOD, False)
 138
 139    def __pow__(self, other: Any) -> CombinedExpression:
 140        return self._combine(other, self.POW, False)
 141
 142    def __and__(self, other: Any) -> Q:
 143        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 144            return Q(self) & Q(other)
 145        raise NotImplementedError(
 146            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 147        )
 148
 149    def bitand(self, other: Any) -> CombinedExpression:
 150        return self._combine(other, self.BITAND, False)
 151
 152    def bitleftshift(self, other: Any) -> CombinedExpression:
 153        return self._combine(other, self.BITLEFTSHIFT, False)
 154
 155    def bitrightshift(self, other: Any) -> CombinedExpression:
 156        return self._combine(other, self.BITRIGHTSHIFT, False)
 157
 158    def __xor__(self, other: Any) -> Q:
 159        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 160            return Q(self) ^ Q(other)
 161        raise NotImplementedError(
 162            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 163        )
 164
 165    def bitxor(self, other: Any) -> CombinedExpression:
 166        return self._combine(other, self.BITXOR, False)
 167
 168    def __or__(self, other: Any) -> Q:
 169        if getattr(self, "conditional", False) and getattr(other, "conditional", False):
 170            return Q(self) | Q(other)
 171        raise NotImplementedError(
 172            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 173        )
 174
 175    def bitor(self, other: Any) -> CombinedExpression:
 176        return self._combine(other, self.BITOR, False)
 177
 178    def __radd__(self, other: Any) -> CombinedExpression:
 179        return self._combine(other, self.ADD, True)
 180
 181    def __rsub__(self, other: Any) -> CombinedExpression:
 182        return self._combine(other, self.SUB, True)
 183
 184    def __rmul__(self, other: Any) -> CombinedExpression:
 185        return self._combine(other, self.MUL, True)
 186
 187    def __rtruediv__(self, other: Any) -> CombinedExpression:
 188        return self._combine(other, self.DIV, True)
 189
 190    def __rmod__(self, other: Any) -> CombinedExpression:
 191        return self._combine(other, self.MOD, True)
 192
 193    def __rpow__(self, other: Any) -> CombinedExpression:
 194        return self._combine(other, self.POW, True)
 195
 196    def __rand__(self, other: Any) -> None:
 197        raise NotImplementedError(
 198            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 199        )
 200
 201    def __ror__(self, other: Any) -> None:
 202        raise NotImplementedError(
 203            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 204        )
 205
 206    def __rxor__(self, other: Any) -> None:
 207        raise NotImplementedError(
 208            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
 209        )
 210
 211    def __invert__(self) -> NegatedExpression:
 212        return NegatedExpression(self)
 213
 214
 215class BaseExpression:
 216    """Base class for all query expressions."""
 217
 218    empty_result_set_value = NotImplemented
 219    # aggregate specific fields
 220    is_summary = False
 221    _output_field_resolved_to_none = False
 222    # Can the expression be used in a WHERE clause?
 223    filterable = True
 224    # Can the expression can be used as a source expression in Window?
 225    window_compatible = False
 226
 227    def __init__(self, output_field: Field | None = None):
 228        if output_field is not None:
 229            self.output_field = output_field
 230
 231    def __getstate__(self) -> dict[str, Any]:
 232        state = self.__dict__.copy()
 233        state.pop("convert_value", None)
 234        return state
 235
 236    def get_db_converters(
 237        self, connection: BaseDatabaseWrapper
 238    ) -> list[Callable[..., Any]]:
 239        converters = []
 240        if self.convert_value is not self._convert_value_noop:
 241            converters.append(self.convert_value)
 242        converters.extend(self.output_field.get_db_converters(connection))
 243        return converters
 244
 245    def get_source_expressions(self) -> list[Any]:
 246        return []
 247
 248    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 249        assert not exprs
 250
 251    def _parse_expressions(self, *expressions: Any) -> list[Any]:
 252        return [
 253            arg
 254            if isinstance(arg, ResolvableExpression)
 255            else (F(arg) if isinstance(arg, str) else Value(arg))
 256            for arg in expressions
 257        ]
 258
 259    def as_sql(
 260        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 261    ) -> tuple[str, Sequence[Any]]:
 262        """
 263        Responsible for returning a (sql, [params]) tuple to be included
 264        in the current query.
 265
 266        Different backends can provide their own implementation, by
 267        providing an `as_{vendor}` method and patching the Expression:
 268
 269        ```
 270        def override_as_sql(self, compiler, connection):
 271            # custom logic
 272            return super().as_sql(compiler, connection)
 273        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
 274        ```
 275
 276        Arguments:
 277         * compiler: the query compiler responsible for generating the query.
 278           Must have a compile method, returning a (sql, [params]) tuple.
 279           Calling compiler(value) will return a quoted `value`.
 280
 281         * connection: the database connection used for the current query.
 282
 283        Return: (sql, params)
 284          Where `sql` is a string containing ordered sql parameters to be
 285          replaced with the elements of the list `params`.
 286        """
 287        raise NotImplementedError("Subclasses must implement as_sql()")
 288
 289    @cached_property
 290    def contains_aggregate(self) -> bool:
 291        return any(
 292            expr and expr.contains_aggregate for expr in self.get_source_expressions()
 293        )
 294
 295    @cached_property
 296    def contains_over_clause(self) -> bool:
 297        return any(
 298            expr and expr.contains_over_clause for expr in self.get_source_expressions()
 299        )
 300
 301    @cached_property
 302    def contains_column_references(self) -> bool:
 303        return any(
 304            expr and expr.contains_column_references
 305            for expr in self.get_source_expressions()
 306        )
 307
 308    def resolve_expression(
 309        self,
 310        query: Any = None,
 311        allow_joins: bool = True,
 312        reuse: Any = None,
 313        summarize: bool = False,
 314        for_save: bool = False,
 315    ) -> Self:
 316        """
 317        Provide the chance to do any preprocessing or validation before being
 318        added to the query.
 319
 320        Arguments:
 321         * query: the backend query implementation
 322         * allow_joins: boolean allowing or denying use of joins
 323           in this query
 324         * reuse: a set of reusable joins for multijoins
 325         * summarize: a terminal aggregate clause
 326         * for_save: whether this expression about to be used in a save or update
 327
 328        Return: an Expression to be added to the query.
 329        """
 330        c = self.copy()
 331        c.is_summary = summarize
 332        c.set_source_expressions(
 333            [
 334                expr.resolve_expression(query, allow_joins, reuse, summarize)
 335                if expr
 336                else None
 337                for expr in c.get_source_expressions()
 338            ]
 339        )
 340        return c
 341
 342    @property
 343    def conditional(self) -> bool:
 344        output_field = getattr(self, "output_field", None)
 345        return isinstance(output_field, fields.BooleanField)
 346
 347    @property
 348    def field(self) -> Field:
 349        return self.output_field
 350
 351    @cached_property
 352    def output_field(self) -> Field:
 353        """Return the output type of this expressions."""
 354        output_field = self._resolve_output_field()
 355        if output_field is None:
 356            self._output_field_resolved_to_none = True
 357            raise FieldError("Cannot resolve expression type, unknown output_field")
 358        return output_field
 359
 360    @cached_property
 361    def _output_field_or_none(self) -> Field | None:
 362        """
 363        Return the output field of this expression, or None if
 364        _resolve_output_field() didn't return an output type.
 365        """
 366        try:
 367            return self.output_field
 368        except FieldError:
 369            if not self._output_field_resolved_to_none:
 370                raise
 371            return None
 372
 373    def _resolve_output_field(self) -> Field | None:
 374        """
 375        Attempt to infer the output type of the expression.
 376
 377        As a guess, if the output fields of all source fields match then simply
 378        infer the same type here.
 379
 380        If a source's output field resolves to None, exclude it from this check.
 381        If all sources are None, then an error is raised higher up the stack in
 382        the output_field property.
 383        """
 384        # This guess is mostly a bad idea, but there is quite a lot of code
 385        # (especially 3rd party Func subclasses) that depend on it, we'd need a
 386        # deprecation path to fix it.
 387        sources_iter = (
 388            source for source in self.get_source_fields() if source is not None
 389        )
 390        for output_field in sources_iter:
 391            for source in sources_iter:
 392                if not isinstance(output_field, source.__class__):
 393                    raise FieldError(
 394                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
 395                        "set output_field."
 396                    )
 397            return output_field
 398        return None
 399
 400    @staticmethod
 401    def _convert_value_noop(
 402        value: Any, expression: Any, connection: BaseDatabaseWrapper
 403    ) -> Any:
 404        return value
 405
 406    @cached_property
 407    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
 408        """
 409        Expressions provide their own converters because users have the option
 410        of manually specifying the output_field which may be a different type
 411        from the one the database returns.
 412        """
 413        field = self.output_field
 414        internal_type = field.get_internal_type()
 415        if internal_type == "FloatField":
 416            return (
 417                lambda value, expression, connection: None
 418                if value is None
 419                else float(value)
 420            )
 421        elif internal_type.endswith("IntegerField"):
 422            return (
 423                lambda value, expression, connection: None
 424                if value is None
 425                else int(value)
 426            )
 427        elif internal_type == "DecimalField":
 428            return (
 429                lambda value, expression, connection: None
 430                if value is None
 431                else Decimal(value)
 432            )
 433        return self._convert_value_noop
 434
 435    def get_lookup(self, lookup: str) -> type[Lookup] | None:
 436        return self.output_field.get_lookup(lookup)
 437
 438    def get_transform(self, name: str) -> type[Transform] | None:
 439        return self.output_field.get_transform(name)
 440
 441    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
 442        clone = self.copy()
 443        clone.set_source_expressions(
 444            [
 445                e.relabeled_clone(change_map) if e is not None else None
 446                for e in self.get_source_expressions()
 447            ]
 448        )
 449        return clone
 450
 451    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
 452        if replacement := replacements.get(self):
 453            return replacement
 454        clone = self.copy()
 455        source_expressions = clone.get_source_expressions()
 456        clone.set_source_expressions(
 457            [
 458                expr.replace_expressions(replacements) if expr else None
 459                for expr in source_expressions
 460            ]
 461        )
 462        return clone
 463
 464    def get_refs(self) -> set[str]:
 465        refs = set()
 466        for expr in self.get_source_expressions():
 467            refs |= expr.get_refs()
 468        return refs
 469
 470    def copy(self) -> Self:
 471        return copy.copy(self)
 472
 473    def prefix_references(self, prefix: str) -> Self:
 474        clone = self.copy()
 475        clone.set_source_expressions(
 476            [
 477                F(f"{prefix}{expr.name}")
 478                if isinstance(expr, F)
 479                else expr.prefix_references(prefix)
 480                for expr in self.get_source_expressions()
 481            ]
 482        )
 483        return clone
 484
 485    def get_group_by_cols(self) -> list[BaseExpression]:
 486        if not self.contains_aggregate:
 487            return [self]
 488        cols: list[BaseExpression] = []
 489        for source in self.get_source_expressions():
 490            cols.extend(source.get_group_by_cols())
 491        return cols
 492
 493    def get_source_fields(self) -> list[Field | None]:
 494        """Return the underlying field types used by this aggregate."""
 495        return [e._output_field_or_none for e in self.get_source_expressions()]
 496
 497    def asc(self, **kwargs: Any) -> OrderBy:
 498        return OrderBy(self, **kwargs)
 499
 500    def desc(self, **kwargs: Any) -> OrderBy:
 501        return OrderBy(self, descending=True, **kwargs)
 502
 503    def reverse_ordering(self) -> Self:
 504        return self
 505
 506    def flatten(self) -> Iterable[Any]:
 507        """
 508        Recursively yield this expression and all subexpressions, in
 509        depth-first order.
 510        """
 511        yield self
 512        for expr in self.get_source_expressions():
 513            if expr:
 514                if hasattr(expr, "flatten"):
 515                    yield from expr.flatten()
 516                else:
 517                    yield expr
 518
 519    def select_format(
 520        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
 521    ) -> tuple[str, Sequence[Any]]:
 522        """
 523        Custom format for select clauses. For example, EXISTS expressions need
 524        to be wrapped in CASE WHEN on Oracle.
 525        """
 526        if output_field := getattr(self, "output_field", None):
 527            if select_format := getattr(output_field, "select_format", None):
 528                return select_format(compiler, sql, params)
 529        return sql, params
 530
 531
 532@deconstructible
 533class Expression(BaseExpression, Combinable):
 534    """An expression that can be combined with other expressions."""
 535
 536    # Set by @deconstructible decorator in __new__
 537    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]
 538
 539    @cached_property
 540    def identity(self) -> tuple[Any, ...]:
 541        constructor_signature = inspect.signature(self.__init__)
 542        args, kwargs = self._constructor_args
 543        signature = constructor_signature.bind_partial(*args, **kwargs)
 544        signature.apply_defaults()
 545        arguments = signature.arguments.items()
 546        identity: list[Any] = [self.__class__]
 547        for arg, value in arguments:
 548            if isinstance(value, fields.Field):
 549                if value.name and value.model:
 550                    value = (value.model.model_options.label, value.name)
 551                else:
 552                    value = type(value)
 553            else:
 554                value = make_hashable(value)
 555            identity.append((arg, value))
 556        return tuple(identity)
 557
 558    def __eq__(self, other: object) -> bool:
 559        if not isinstance(other, Expression):
 560            return NotImplemented
 561        return other.identity == self.identity
 562
 563    def __hash__(self) -> int:
 564        return hash(self.identity)
 565
 566
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# Each dict below maps a connector to a list of
# (lhs_field_type, rhs_field_type, result_field_type) triples. They are passed
# to register_combinable_fields() at import time, in list order.
# _resolve_combined_type() returns the result of the FIRST triple whose
# lhs/rhs types match the operands via issubclass(), so entry order matters.
# NoneType stands for an operand whose output field resolved to None.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    # Integer mixed with Decimal yields Decimal; Integer mixed with Float
    # yields Float. Note: POW is not listed for mixed numeric types.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # The untyped (NoneType) side adopts the numeric type of the other side.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            # Subtracting a duration keeps a point in time; subtracting two
            # dates/datetimes yields a duration.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
 670
 671_connector_combinators = defaultdict(list)
 672
 673
 674def register_combinable_fields(
 675    lhs: type[Field] | type[None],
 676    connector: str,
 677    rhs: type[Field] | type[None],
 678    result: type[Field],
 679) -> None:
 680    """
 681    Register combinable types:
 682        lhs <connector> rhs -> result
 683    e.g.
 684        register_combinable_fields(
 685            IntegerField, Combinable.ADD, FloatField, FloatField
 686        )
 687    """
 688    _connector_combinators[connector].append((lhs, rhs, result))
 689
 690
# Register the built-in combinations declared in _connector_combinations,
# preserving their order (_resolve_combined_type() returns the first matching
# entry). Note: the loop variables (d, connector, field_types, lhs, rhs,
# result) remain bound as module-level names after this loop.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
 695
 696
 697@functools.lru_cache(maxsize=128)
 698def _resolve_combined_type(
 699    connector: str, lhs_type: type[Field], rhs_type: type[Field]
 700) -> type[Field] | None:
 701    combinators = _connector_combinators.get(connector, ())
 702    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
 703        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
 704            rhs_type, combinator_rhs_type
 705        ):
 706            return combined_type
 707    return None
 708
 709
 710class SQLiteNumericMixin(Expression):
 711    """
 712    Some expressions with output_field=DecimalField() must be cast to
 713    numeric to be properly filtered.
 714    """
 715
 716    def as_sqlite(
 717        self,
 718        compiler: SQLCompiler,
 719        connection: BaseDatabaseWrapper,
 720        **extra_context: Any,
 721    ) -> tuple[str, Sequence[Any]]:
 722        sql, params = self.as_sql(compiler, connection, **extra_context)
 723        try:
 724            if self.output_field.get_internal_type() == "DecimalField":
 725                sql = f"CAST({sql} AS NUMERIC)"
 726        except FieldError:
 727            pass
 728        return sql, params
 729
 730
 731class CombinedExpression(SQLiteNumericMixin, Expression):
 732    def __init__(
 733        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
 734    ):
 735        super().__init__(output_field=output_field)
 736        self.connector = connector
 737        self.lhs = lhs
 738        self.rhs = rhs
 739
 740    def __repr__(self) -> str:
 741        return f"<{self.__class__.__name__}: {self}>"
 742
 743    def __str__(self) -> str:
 744        return f"{self.lhs} {self.connector} {self.rhs}"
 745
 746    def get_source_expressions(self) -> list[Any]:
 747        return [self.lhs, self.rhs]
 748
 749    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
 750        self.lhs, self.rhs = exprs
 751
 752    def _resolve_output_field(self) -> Field | None:
 753        # We avoid using super() here for reasons given in
 754        # Expression._resolve_output_field()
 755        combined_type = _resolve_combined_type(
 756            self.connector,
 757            type(self.lhs._output_field_or_none),
 758            type(self.rhs._output_field_or_none),
 759        )
 760        if combined_type is None:
 761            raise FieldError(
 762                f"Cannot infer type of {self.connector!r} expression involving these "
 763                f"types: {self.lhs.output_field.__class__.__name__}, "
 764                f"{self.rhs.output_field.__class__.__name__}. You must set "
 765                f"output_field."
 766            )
 767        return combined_type()
 768
 769    def as_sql(
 770        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 771    ) -> tuple[str, list[Any]]:
 772        expressions = []
 773        expression_params = []
 774        sql, params = compiler.compile(self.lhs)
 775        expressions.append(sql)
 776        expression_params.extend(params)
 777        sql, params = compiler.compile(self.rhs)
 778        expressions.append(sql)
 779        expression_params.extend(params)
 780        # order of precedence
 781        expression_wrapper = "(%s)"
 782        sql = connection.ops.combine_expression(self.connector, expressions)
 783        return expression_wrapper % sql, expression_params
 784
 785    def resolve_expression(
 786        self,
 787        query: Any = None,
 788        allow_joins: bool = True,
 789        reuse: Any = None,
 790        summarize: bool = False,
 791        for_save: bool = False,
 792    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
 793        lhs = self.lhs.resolve_expression(
 794            query, allow_joins, reuse, summarize, for_save
 795        )
 796        rhs = self.rhs.resolve_expression(
 797            query, allow_joins, reuse, summarize, for_save
 798        )
 799        if not isinstance(self, DurationExpression | TemporalSubtraction):
 800            try:
 801                lhs_type = lhs.output_field.get_internal_type()
 802            except (AttributeError, FieldError):
 803                lhs_type = None
 804            try:
 805                rhs_type = rhs.output_field.get_internal_type()
 806            except (AttributeError, FieldError):
 807                rhs_type = None
 808            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
 809                return DurationExpression(
 810                    self.lhs, self.connector, self.rhs
 811                ).resolve_expression(
 812                    query,
 813                    allow_joins,
 814                    reuse,
 815                    summarize,
 816                    for_save,
 817                )
 818            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
 819            if (
 820                self.connector == self.SUB
 821                and lhs_type in datetime_fields
 822                and lhs_type == rhs_type
 823            ):
 824                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
 825                    query,
 826                    allow_joins,
 827                    reuse,
 828                    summarize,
 829                    for_save,
 830                )
 831        c = self.copy()
 832        c.is_summary = summarize
 833        c.lhs = lhs
 834        c.rhs = rhs
 835        return c
 836
 837
 838class DurationExpression(CombinedExpression):
 839    def compile(
 840        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 841    ) -> tuple[str, Sequence[Any]]:
 842        try:
 843            output = side.output_field
 844        except FieldError:
 845            pass
 846        else:
 847            if output.get_internal_type() == "DurationField":
 848                sql, params = compiler.compile(side)
 849                return connection.ops.format_for_duration_arithmetic(sql), params
 850        return compiler.compile(side)
 851
 852    def as_sql(
 853        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 854    ) -> tuple[str, list[Any]]:
 855        if connection.features.has_native_duration_field:
 856            return super().as_sql(compiler, connection)
 857        connection.ops.check_expression_support(self)
 858        expressions = []
 859        expression_params = []
 860        sql, params = self.compile(self.lhs, compiler, connection)
 861        expressions.append(sql)
 862        expression_params.extend(params)
 863        sql, params = self.compile(self.rhs, compiler, connection)
 864        expressions.append(sql)
 865        expression_params.extend(params)
 866        # order of precedence
 867        expression_wrapper = "(%s)"
 868        sql = connection.ops.combine_duration_expression(self.connector, expressions)
 869        return expression_wrapper % sql, expression_params
 870
 871    def as_sqlite(
 872        self,
 873        compiler: SQLCompiler,
 874        connection: BaseDatabaseWrapper,
 875        **extra_context: Any,
 876    ) -> tuple[str, Sequence[Any]]:
 877        sql, params = self.as_sql(compiler, connection, **extra_context)
 878        if self.connector in {Combinable.MUL, Combinable.DIV}:
 879            try:
 880                lhs_type = self.lhs.output_field.get_internal_type()
 881                rhs_type = self.rhs.output_field.get_internal_type()
 882            except (AttributeError, FieldError):
 883                pass
 884            else:
 885                allowed_fields = {
 886                    "DecimalField",
 887                    "DurationField",
 888                    "FloatField",
 889                    "IntegerField",
 890                }
 891                if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
 892                    raise DatabaseError(
 893                        f"Invalid arguments for operator {self.connector}."
 894                    )
 895        return sql, params
 896
 897
 898class TemporalSubtraction(CombinedExpression):
 899    output_field = fields.DurationField()
 900
 901    def __init__(self, lhs: Any, rhs: Any):
 902        super().__init__(lhs, self.SUB, rhs)
 903
 904    def as_sql(
 905        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 906    ) -> tuple[str, list[Any]]:
 907        connection.ops.check_expression_support(self)
 908        lhs = compiler.compile(self.lhs)
 909        rhs = compiler.compile(self.rhs)
 910        sql, params = connection.ops.subtract_temporals(
 911            self.lhs.output_field.get_internal_type(), lhs, rhs
 912        )
 913        return sql, list(params)
 914
 915
 916@deconstructible(path="plain.models.F")
 917class F(Combinable):
 918    """An object capable of resolving references to existing query objects."""
 919
 920    def __init__(self, name: str):
 921        """
 922        Arguments:
 923         * name: the name of the field this expression references
 924        """
 925        self.name = name
 926
 927    def __repr__(self) -> str:
 928        return f"{self.__class__.__name__}({self.name})"
 929
 930    def resolve_expression(
 931        self,
 932        query: Any = None,
 933        allow_joins: bool = True,
 934        reuse: Any = None,
 935        summarize: bool = False,
 936        for_save: bool = False,
 937    ) -> Any:
 938        return query.resolve_ref(self.name, allow_joins, reuse, summarize)
 939
 940    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
 941        return replacements.get(self, self)
 942
 943    def asc(self, **kwargs: Any) -> OrderBy:
 944        return OrderBy(self, **kwargs)
 945
 946    def desc(self, **kwargs: Any) -> OrderBy:
 947        return OrderBy(self, descending=True, **kwargs)
 948
 949    def __eq__(self, other: object) -> bool:
 950        if not isinstance(other, F):
 951            return NotImplemented
 952        return self.__class__ == other.__class__ and self.name == other.name
 953
 954    def __hash__(self) -> int:
 955        return hash(self.name)
 956
 957    def copy(self) -> Self:
 958        return copy.copy(self)
 959
 960
 961class ResolvedOuterRef(F):
 962    """
 963    An object that contains a reference to an outer query.
 964
 965    In this case, the reference to the outer query has been resolved because
 966    the inner query has been used as a subquery.
 967    """
 968
 969    contains_aggregate = False
 970    contains_over_clause = False
 971
 972    def as_sql(self, *args: Any, **kwargs: Any) -> None:
 973        raise ValueError(
 974            "This queryset contains a reference to an outer query and may "
 975            "only be used in a subquery."
 976        )
 977
 978    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
 979        col = super().resolve_expression(*args, **kwargs)
 980        if col.contains_over_clause:
 981            raise NotSupportedError(
 982                f"Referencing outer query window expression is not supported: "
 983                f"{self.name}."
 984            )
 985        # FIXME: Rename possibly_multivalued to multivalued and fix detection
 986        # for non-multivalued JOINs (e.g. foreign key fields). This should take
 987        # into accountย only many-to-many and one-to-many relationships.
 988        col.possibly_multivalued = LOOKUP_SEP in self.name
 989        return col
 990
 991    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
 992        return self
 993
 994    def get_group_by_cols(self) -> list[Any]:
 995        return []
 996
 997
 998class OuterRef(F):
 999    contains_aggregate = False
1000
1001    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
1002        if isinstance(self.name, self.__class__):
1003            return self.name
1004        return ResolvedOuterRef(self.name)
1005
1006    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
1007        return self
1008
1009
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    The call is rendered through ``template``, with ``function`` as the SQL
    function name and the compiled argument expressions joined by
    ``arg_joiner``. When ``arity`` is set, exactly that many positional
    expressions are required. Extra keyword arguments passed to __init__()
    become template context.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{class_name}({args}, {rendered})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        resolved = self.copy()
        resolved.is_summary = summarize
        # Replace the copy's arguments in place (copy() gave it its own list).
        resolved.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in resolved.source_expressions
        ]
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        compiled_args: list[str] = []
        params: list[Any] = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Substitute the argument's "empty" value when it defines one.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                # A condition matching everything becomes a literal TRUE.
                arg_sql, arg_params = compiler.compile(Value(True))
            compiled_args.append(arg_sql)
            params.extend(arg_params)
        context = {**self.extra, **extra_context}
        # Precedence for function/template/arg_joiner: the argument to this
        # method, then a value from __init__()'s **extra (already merged into
        # `context`), then the class attribute.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        chosen_template = template or context.get("template", self.template)
        joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        joined = joiner.join(compiled_args)
        context["expressions"] = joined
        context["field"] = joined
        return chosen_template % context, params

    def copy(self) -> Self:
        duplicate = super().copy()
        # Copy the mutable containers so the clone can be edited independently.
        duplicate.source_expressions = self.source_expressions[:]
        duplicate.extra = self.extra.copy()
        return duplicate
1115
1116
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)
                return placeholder, [prepared]
        if prepared is None:
            # Emit a literal SQL NULL: cx_Oracle does not always bind None
            # with the appropriate NULL type (e.g. in CASE expressions over
            # numbers).
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remember whether this value is being saved so as_sql() can choose
        # get_db_prep_save() over get_db_prep_value().
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        # Checked in order: bool must precede int (bool subclasses int) and
        # datetime.datetime must precede datetime.date for the same reason.
        candidates = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in candidates:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1202
1203
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters.

    The fragment is wrapped in parentheses when compiled. The output field
    defaults to a generic ``fields.Field()``.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        effective_field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=effective_field)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped_sql = f"({self.sql})"
        return wrapped_sql, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The whole fragment is grouped on as a unit.
        return [self]
1223
1224
class Star(Expression):
    """Compiles to a bare ``*`` with no parameters; its repr is ``'*'``."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1233
1234
class Col(Expression):
    """
    A reference to ``target.column``, optionally qualified by a table alias.

    ``output_field`` defaults to ``target``.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        joined = ", ".join(parts)
        return f"{self.__class__.__name__}({joined})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quote = compiler.quote_name_unless_alias
        column = self.target.column
        if not self.alias:
            return quote(column), []
        # Quote the alias first, then the column.
        return f"{quote(self.alias)}.{quote(column)}", []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        converters = self.output_field.get_db_converters(connection)
        # When the output field differs from the target, apply the output
        # field's converters first and then the target's.
        if not (self.target == self.output_field):
            converters = converters + self.target.get_db_converters(connection)
        return converters
1278
1279
class Ref(Expression):
    """
    A reference to a column alias of the query.

    For example, ``Ref('sum_cost')`` for ``qs.annotate(sum_cost=Sum('cost'))``.
    Compiles to the quoted alias name.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` was already resolved; this only names it, so nothing to do.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1324
1325
class ExpressionList(Func):
    """
    A bare list of expressions joined by ``arg_joiner``, usable as an
    argument to another expression (for example a partition clause).

    At least one expression is required.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # No numeric cast is needed here; use the generic as_sql() directly.
        return self.as_sql(compiler, connection, **extra_context)
1353
1354
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a list of expressions.

    String entries that start with ``"-"`` become a descending ``OrderBy`` of
    the named field (the prefix is stripped). All other entries are passed to
    ``Func`` unchanged. With no expressions, the clause compiles to an empty
    string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # startswith() rather than expr[0]: an empty string must not raise
        # IndexError here.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # An empty ordering renders nothing rather than a bare "ORDER BY ".
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        # Collect the group-by columns of every ordering expression.
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1380
1381
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to supply extra context for it, such as an
    explicit ``output_field``.

    Compiling the wrapper compiles the wrapped expression as-is.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group on a copy of the inner expression that carries this wrapper's
        # output field.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.expression})"
1415
1416
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation gives back (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # NOT of a condition that matches nothing matches everything.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        ops = compiler.connection.ops
        needs_case_wrapper = (
            not compiler.connection.features.supports_boolean_expr_in_select_clause
            # Avoid double wrapping.
            and ops.conditional_expression_supported_in_where_clause(self.expression)
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1475
1476
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    A single ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    ``condition`` may be a Q object or a conditional (boolean) expression.
    Keyword ``lookups`` are folded into a Q object, combined with
    ``condition`` when that is also conditional. ``then`` is the branch
    result and is parsed like any other expression argument.

    Raises:
        TypeError: if no usable conditional condition results from the
            arguments.
        ValueError: if the condition is an empty Q().
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        # Any lookups left over here could not be combined with `condition`.
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        c = self.copy()
        c.is_summary = summarize
        if isinstance(c.condition, ResolvableExpression):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's for_save.
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        # **extra_context is a fresh dict on every call, so it can be used
        # directly as the template context.
        template_params = extra_context
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        # Parameters follow placeholder order: condition first, then result.
        return template % template_params, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # A When() isn't a complete expression and can't be grouped on by
        # itself, so group on the columns of its condition and result.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1571
1572
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() objects. ``default`` becomes the
    ELSE branch and is parsed like any other expression argument. Extra
    keyword arguments are merged into the SQL template context.
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # The When() cases in order, followed by the default as the last item.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Inverse of get_source_expressions(): the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every case and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        c = super().copy()
        # Give the copy its own list so the in-place case replacement in
        # resolve_expression() doesn't touch the original's cases.
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        # With no WHEN branches the expression reduces to its default.
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # A branch whose condition can never match is dropped.
                continue
            except FullResultSet:
                # A branch whose condition always matches becomes the
                # effective default; any later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        # No branch survived: emit the (possibly replaced) default alone.
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # The default's params come after all WHEN params, matching the
        # template's placeholder order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        # When an output field is known, wrap the SQL in the backend's
        # unification cast (a %-style template from connection.ops).
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without cases, group on whatever the default groups on.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1685
1686
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery expression.

    The wrapped query may contain OuterRef() references to the outer query;
    these are resolved when the subquery is applied to that query. Both
    QuerySet and sql.Query inputs are accepted. The underlying sql.Query is
    cloned and flagged with ``subquery = True``.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Imported locally to avoid a circular import.
        from plain.models.sql.query import Query

        # A QuerySet exposes its sql.Query as ``sql_query``; a Query is used as-is.
        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        first_expr = exprs[0]
        self.query = first_expr

    def _resolve_output_field(self) -> Field | None:
        return self.query.output_field

    def copy(self) -> Self:
        # Shallow copies would share the query; give the copy its own clone.
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        inner_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters (presumably the parentheses the
        # compiled subquery carries) so the template controls the wrapping.
        context["subquery"] = inner_sql[1:-1]
        sql_template = template or context.get("template", self.template)
        return sql_template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1757
1758
class Exists(Subquery):
    """``EXISTS(...)`` over a subquery, with a BooleanField output."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for the result of its ``exists()`` method.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Backends that cannot place a bare boolean expression in the SELECT
        # or GROUP BY list (e.g. Oracle) get it wrapped in a CASE WHEN.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1777
1778
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    An expression with an ORDER BY direction and optional NULLS placement.

    ``descending`` selects DESC over ASC. ``nulls_first`` and ``nulls_last``
    are mutually exclusive and must each be ``True`` or ``None``.

    On backends with ``supports_order_by_nulls_modifier`` they render as
    ``NULLS FIRST`` / ``NULLS LAST``. Otherwise they are emulated by
    prepending an ``IS NULL`` / ``IS NOT NULL`` sort key.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``<expression> ASC|DESC`` plus any NULLS handling.

        When NULLS placement is emulated, ``%(expression)s`` occurs more than
        once in the template. The expression's params are therefore repeated
        once per occurrence.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate the modifier with a leading NULL test. This is skipped
            # when ``order_by_nulls_first`` together with the direction
            # indicates the backend's default already gives that placement.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Build a new params object instead of using ``*=``. The value
        # returned by compiler.compile() may be a list still owned by the
        # compiled expression, and in-place repetition would grow that list
        # on every compile.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction and swap NULLS placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # type: ignore[override]
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        self.descending = True
1864
1865
class Window(SQLiteNumericMixin, Expression):
    """
    A window-function expression rendered as ``<expression> OVER (<window>)``.

    - ``expression`` must have a truthy ``window_compatible`` attribute;
      otherwise ValueError is raised.
    - ``partition_by`` may be a single item or a list/tuple. It is wrapped
      in an ExpressionList.
    - ``order_by`` may be a string, an expression, or a list/tuple of them.
      It is wrapped in an OrderByList; any other type raises ValueError.
    - ``frame`` is an optional frame clause (e.g. RowRange or ValueRange).
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        # Expression params precede window-clause params, matching their
        # textual order in the template.
        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            source_expressions[0].output_field = fields.FloatField()
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection, **extra_context)
        return self.as_sql(compiler, connection, **extra_context)

    def __str__(self) -> str:
        # Separate the window clauses with spaces. Concatenating them directly
        # glued the PARTITION BY list onto the ordering/frame text.
        clauses = []
        if self.partition_by:
            clauses.append(f"PARTITION BY {self.partition_by}")
        if self.order_by:
            clauses.append(str(self.order_by))
        if self.frame:
            clauses.append(str(self.frame))
        return f"{self.source_expression} OVER ({' '.join(clauses)})"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1994
1995
class WindowFrame(Expression, ABC):
    """
    Model the frame clause in window expressions.

    Subclasses set ``frame_type`` (the keyword placed before BETWEEN) and
    implement ``window_frame_start_end()``, which asks the backend to render
    the bounds. All other processing happens here; it is by no means meant
    to be complete validation.

    ``start`` and ``end`` are integer offsets from the current row or
    ``None`` for an unbounded edge. In the string form:

    - negative values render as PRECEDING;
    - zero renders as CURRENT ROW;
    - positive values render as FOLLOWING.

    Whether a backend accepts every combination is decided by its
    ``connection.ops`` during ``as_sql``. Providing an end is optional; it
    defaults to UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        # Bounds are rendered inline by the backend, so there are no params.
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        start_value = self.start.value
        end_value = self.end.value

        if start_value is None:
            start = ops.UNBOUNDED_PRECEDING
        elif start_value < 0:
            start = f"{abs(start_value)} {ops.PRECEDING}"
        elif start_value == 0:
            start = ops.CURRENT_ROW
        else:
            # A positive start was previously shown as UNBOUNDED PRECEDING.
            start = f"{start_value} {ops.FOLLOWING}"

        if end_value is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif end_value > 0:
            end = f"{end_value} {ops.FOLLOWING}"
        elif end_value == 0:
            end = ops.CURRENT_ROW
        else:
            # A negative end was previously shown as UNBOUNDED FOLLOWING.
            end = f"{abs(end_value)} {ops.PRECEDING}"

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the backend-rendered ``(start, end)`` bound SQL fragments."""
2065
2066
class RowRange(WindowFrame):
    """Frame clause using ``ROWS``; bounds come from ``window_frame_rows_start_end``."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
2074
2075
class ValueRange(WindowFrame):
    """Frame clause using ``RANGE``; bounds come from ``window_frame_range_start_end``."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)