1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from collections import defaultdict
8from decimal import Decimal
9from functools import cached_property
10from types import NoneType
11from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
12from uuid import UUID
13
14from plain.postgres import fields
15from plain.postgres.constants import LOOKUP_SEP
16from plain.postgres.db import NotSupportedError
17from plain.postgres.dialect import (
18 CURRENT_ROW,
19 FOLLOWING,
20 PRECEDING,
21 UNBOUNDED_FOLLOWING,
22 UNBOUNDED_PRECEDING,
23 combine_expression,
24 quote_name,
25 subtract_temporals,
26 window_frame_range_start_end,
27 window_frame_rows_start_end,
28)
29from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
30from plain.postgres.query_utils import Q
31from plain.utils.deconstruct import deconstructible
32from plain.utils.hashable import make_hashable
33
34if TYPE_CHECKING:
35 from collections.abc import Callable, Iterable, Sequence
36
37 from plain.postgres.connection import DatabaseConnection
38 from plain.postgres.fields import Field
39 from plain.postgres.lookups import Lookup, Transform
40 from plain.postgres.query import QuerySet
41 from plain.postgres.sql.compiler import SQLCompilable, SQLCompiler
42 from plain.postgres.sql.query import Query
43
44__all__ = [
45 # Core expression classes
46 "F",
47 "Value",
48 "Case",
49 "When",
50 "Subquery",
51 "Exists",
52 "OuterRef",
53 "Window",
54 "ExpressionWrapper",
55 "RawSQL",
56 "OrderBy",
57 # Base classes (for extension)
58 "Func",
59 "Expression",
60 "Combinable",
61 # Window frame specs
62 "RowRange",
63 "ValueRange",
64]
65
66
@runtime_checkable
class ResolvableExpression(Protocol):
    """
    Structural type for anything that offers ``resolve_expression()``.

    Because the protocol is runtime-checkable, ``isinstance(obj,
    ResolvableExpression)`` only verifies that the method is present; its
    signature is not inspected.
    """

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Resolve this object in the context of ``query`` and return the result."""
79
80
81@runtime_checkable
82class ReplaceableExpression(Protocol):
83 """Protocol for expressions that support expression replacement."""
84
85 def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
86
87
class Combinable:
    """
    Mixin that gives expression-like objects Python operator support.

    Arithmetic operators such as ``F('foo') + F('bar')`` build a
    CombinedExpression. The ``&``, ``|`` and ``^`` operators are reserved for
    combining two conditional (boolean) expressions into a Q object; bitwise
    arithmetic is spelled with the explicit ``bit*()`` methods instead.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # Modulo is written as a quoted "%%" because connectors can end up in
    # strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors. These are produced only by bitand(), bitor(),
    # bitxor() and the shift methods; the '&' and '|' operators are reserved
    # for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        # Anything that cannot resolve itself is wrapped as a literal Value so
        # that both operands are expressions.
        operand = other if isinstance(other, ResolvableExpression) else Value(other)
        if not reversed:
            return CombinedExpression(self, connector, operand)
        return CombinedExpression(operand, connector, self)

    @staticmethod
    def _bitwise_operator_error() -> NotImplementedError:
        # Shared error for &, | and ^ used outside of boolean combination.
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    ##############
    # ARITHMETIC #
    ##############

    def __neg__(self) -> CombinedExpression:
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    ###########
    # BITWISE #
    ###########

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    ###########
    # BOOLEAN #
    ###########

    # &, | and ^ are only meaningful when both sides are conditional; any
    # other use is rejected in favour of the explicit bit*() methods.

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) & Q(other)

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) | Q(other)

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) ^ Q(other)

    def __rand__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __ror__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __rxor__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
220
221
222class BaseExpression:
223 """Base class for all query expressions."""
224
225 empty_result_set_value = NotImplemented
226 # aggregate specific fields
227 is_summary = False
228 _output_field_resolved_to_none = False
229 # Can the expression be used in a WHERE clause?
230 filterable = True
231 # Can the expression can be used as a source expression in Window?
232 window_compatible = False
233
234 def __init__(self, output_field: Field | None = None):
235 if output_field is not None:
236 self.output_field = output_field
237
238 def __getstate__(self) -> dict[str, Any]:
239 state = self.__dict__.copy()
240 state.pop("convert_value", None)
241 return state
242
243 def get_db_converters(
244 self, connection: DatabaseConnection
245 ) -> list[Callable[..., Any]]:
246 converters = []
247 if self.convert_value is not self._convert_value_noop:
248 converters.append(self.convert_value)
249 converters.extend(self.output_field.get_db_converters(connection))
250 return converters
251
252 def get_source_expressions(self) -> list[Any]:
253 return []
254
255 def set_source_expressions(self, exprs: Sequence[Any]) -> None:
256 assert not exprs
257
258 def _parse_expressions(self, *expressions: Any) -> list[Any]:
259 return [
260 arg
261 if isinstance(arg, ResolvableExpression)
262 else (F(arg) if isinstance(arg, str) else Value(arg))
263 for arg in expressions
264 ]
265
266 def as_sql(
267 self, compiler: SQLCompiler, connection: DatabaseConnection
268 ) -> tuple[str, Sequence[Any]]:
269 """
270 Return a (sql, params) tuple to be included in the current query.
271
272 Arguments:
273 * compiler: the query compiler responsible for generating the query.
274 Must have a compile method, returning a (sql, [params]) tuple.
275 Calling compiler(value) will return a quoted `value`.
276
277 * connection: the database connection used for the current query.
278
279 Return: (sql, params)
280 Where `sql` is a string containing ordered sql parameters to be
281 replaced with the elements of the list `params`.
282 """
283 raise NotImplementedError("Subclasses must implement as_sql()")
284
285 @cached_property
286 def contains_aggregate(self) -> bool:
287 return any(
288 expr and expr.contains_aggregate for expr in self.get_source_expressions()
289 )
290
291 @cached_property
292 def contains_over_clause(self) -> bool:
293 return any(
294 expr and expr.contains_over_clause for expr in self.get_source_expressions()
295 )
296
297 @cached_property
298 def contains_column_references(self) -> bool:
299 return any(
300 expr and expr.contains_column_references
301 for expr in self.get_source_expressions()
302 )
303
304 def resolve_expression(
305 self,
306 query: Any = None,
307 allow_joins: bool = True,
308 reuse: Any = None,
309 summarize: bool = False,
310 for_save: bool = False,
311 ) -> Self:
312 """
313 Provide the chance to do any preprocessing or validation before being
314 added to the query.
315
316 Arguments:
317 * query: the backend query implementation
318 * allow_joins: boolean allowing or denying use of joins
319 in this query
320 * reuse: a set of reusable joins for multijoins
321 * summarize: a terminal aggregate clause
322 * for_save: whether this expression about to be used in a save or update
323
324 Return: an Expression to be added to the query.
325 """
326 c = self.copy()
327 c.is_summary = summarize
328 c.set_source_expressions(
329 [
330 expr.resolve_expression(query, allow_joins, reuse, summarize)
331 if expr
332 else None
333 for expr in c.get_source_expressions()
334 ]
335 )
336 return c
337
338 @property
339 def conditional(self) -> bool:
340 output_field = getattr(self, "output_field", None)
341 return isinstance(output_field, fields.BooleanField)
342
343 @property
344 def field(self) -> Field:
345 return self.output_field
346
347 @cached_property
348 def output_field(self) -> Field:
349 """Return the output type of this expressions."""
350 output_field = self._resolve_output_field()
351 if output_field is None:
352 self._output_field_resolved_to_none = True
353 raise FieldError("Cannot resolve expression type, unknown output_field")
354 return output_field
355
356 @cached_property
357 def _output_field_or_none(self) -> Field | None:
358 """
359 Return the output field of this expression, or None if
360 _resolve_output_field() didn't return an output type.
361 """
362 try:
363 return self.output_field
364 except FieldError:
365 if not self._output_field_resolved_to_none:
366 raise
367 return None
368
369 def _resolve_output_field(self) -> Field | None:
370 """
371 Attempt to infer the output type of the expression.
372
373 As a guess, if the output fields of all source fields match then simply
374 infer the same type here.
375
376 If a source's output field resolves to None, exclude it from this check.
377 If all sources are None, then an error is raised higher up the stack in
378 the output_field property.
379 """
380 # This guess is mostly a bad idea, but there is quite a lot of code
381 # (especially 3rd party Func subclasses) that depend on it, we'd need a
382 # deprecation path to fix it.
383 sources_iter = (
384 source for source in self.get_source_fields() if source is not None
385 )
386 for output_field in sources_iter:
387 for source in sources_iter:
388 if not isinstance(output_field, source.__class__):
389 raise FieldError(
390 f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
391 "set output_field."
392 )
393 return output_field
394 return None
395
396 @staticmethod
397 def _convert_value_noop(
398 value: Any, expression: Any, connection: DatabaseConnection
399 ) -> Any:
400 return value
401
402 @cached_property
403 def convert_value(self) -> Callable[[Any, Any, Any], Any]:
404 """
405 Expressions provide their own converters because users have the option
406 of manually specifying the output_field which may be a different type
407 from the one the database returns.
408 """
409 field = self.output_field
410 internal_type = field.get_internal_type()
411 if internal_type == "FloatField":
412 return (
413 lambda value, expression, connection: None
414 if value is None
415 else float(value)
416 )
417 elif internal_type.endswith("IntegerField"):
418 return (
419 lambda value, expression, connection: None
420 if value is None
421 else int(value)
422 )
423 elif internal_type == "DecimalField":
424 return (
425 lambda value, expression, connection: None
426 if value is None
427 else Decimal(value)
428 )
429 return self._convert_value_noop
430
431 def get_lookup(self, lookup: str) -> type[Lookup] | None:
432 return self.output_field.get_lookup(lookup)
433
434 def get_transform(self, name: str) -> type[Transform] | None:
435 return self.output_field.get_transform(name) # type: ignore[return-type]
436
437 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
438 clone = self.copy()
439 clone.set_source_expressions(
440 [
441 e.relabeled_clone(change_map) if e is not None else None
442 for e in self.get_source_expressions()
443 ]
444 )
445 return clone
446
447 def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
448 if replacement := replacements.get(self):
449 return replacement
450 clone = self.copy()
451 source_expressions = clone.get_source_expressions()
452 clone.set_source_expressions(
453 [
454 expr.replace_expressions(replacements) if expr else None
455 for expr in source_expressions
456 ]
457 )
458 return clone
459
460 def get_refs(self) -> set[str]:
461 refs = set()
462 for expr in self.get_source_expressions():
463 refs |= expr.get_refs()
464 return refs
465
466 def copy(self) -> Self:
467 return copy.copy(self)
468
469 def prefix_references(self, prefix: str) -> Self:
470 clone = self.copy()
471 clone.set_source_expressions(
472 [
473 F(f"{prefix}{expr.name}")
474 if isinstance(expr, F)
475 else expr.prefix_references(prefix)
476 for expr in self.get_source_expressions()
477 ]
478 )
479 return clone
480
481 def get_group_by_cols(self) -> list[BaseExpression]:
482 if not self.contains_aggregate:
483 return [self]
484 cols: list[BaseExpression] = []
485 for source in self.get_source_expressions():
486 cols.extend(source.get_group_by_cols())
487 return cols
488
489 def get_source_fields(self) -> list[Field | None]:
490 """Return the underlying field types used by this aggregate."""
491 return [e._output_field_or_none for e in self.get_source_expressions()]
492
493 def asc(self, **kwargs: Any) -> OrderBy:
494 return OrderBy(self, **kwargs)
495
496 def desc(self, **kwargs: Any) -> OrderBy:
497 return OrderBy(self, descending=True, **kwargs)
498
499 def reverse_ordering(self) -> Self:
500 return self
501
502 def flatten(self) -> Iterable[Any]:
503 """
504 Recursively yield this expression and all subexpressions, in
505 depth-first order.
506 """
507 yield self
508 for expr in self.get_source_expressions():
509 if expr:
510 if hasattr(expr, "flatten"):
511 yield from expr.flatten()
512 else:
513 yield expr
514
515 def select_format(
516 self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
517 ) -> tuple[str, Sequence[Any]]:
518 """Custom format for select clauses."""
519 if output_field := getattr(self, "output_field", None):
520 if select_format := getattr(output_field, "select_format", None):
521 return select_format(compiler, sql, params)
522 return sql, params
523
524
@deconstructible
class Expression(BaseExpression, Combinable):
    """
    A combinable expression whose equality and hash derive from the arguments
    it was constructed with.
    """

    # Populated by the @deconstructible decorator in __new__: the positional
    # and keyword arguments originally passed to the constructor.
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Hashable tuple of the class followed by one (parameter name, value)
        pair per constructor parameter, with defaults filled in.

        Field values are reduced to (model label, field name) when both the
        field's name and model are set, or to the field's type otherwise; all
        other values go through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args
        bound = init_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()
        parts: list[Any] = [self.__class__]
        for param_name, raw in bound.arguments.items():
            if not isinstance(raw, fields.Field):
                component = make_hashable(raw)
            elif raw.name and raw.model:
                component = (raw.model.model_options.label, raw.name)
            else:
                component = type(raw)
            parts.append((param_name, component))
        return tuple(parts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
558
559
# Type inference for CombinedExpression.output_field.
# Combinations missing from this table result in FieldError, by design.
#
# NULL handling follows lowest-common-denominator behavior, i.e. if a
# supported database raises an error (rather than returning NULL) for
# `val <op> NULL`, then Plain raises FieldError.
566
# Each dict maps a connector to a list of (lhs type, rhs type, result type)
# triples. They are registered in list order by the loop that follows
# register_combinable_fields(), and _resolve_combined_type() returns the first
# triple whose lhs/rhs types match via issubclass(), so order matters.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type (POW is not included).
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL. An operand whose output field resolved to None is
    # looked up as NoneType (see CombinedExpression._resolve_output_field()).
    # The triples for every numeric type are gathered into a single list per
    # connector. A "for field_type" clause at the dict level would rebind the
    # same key on each iteration, so only the last field type would survive.
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
661
_connector_combinators = defaultdict(list)


@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str,
    lhs_type: type[Field] | type[None],
    rhs_type: type[Field] | type[None],
) -> type[Field] | None:
    """
    Return the result field type for ``lhs_type <connector> rhs_type``.

    Registered combinators for the connector are scanned in registration
    order, and the first one whose lhs/rhs types are superclasses of the given
    types wins. Returns None when nothing matches. Results (including None)
    are memoized; register_combinable_fields() clears the cache.
    """
    combinators = _connector_combinators.get(connector, ())
    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
            rhs_type, combinator_rhs_type
        ):
            return combined_type
    return None


def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its answers, including "no match".
    # Drop them so lookups made before this registration don't hide it.
    _resolve_combined_type.cache_clear()


for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
699
700
class CombinedExpression(Expression):
    """
    Binary expression ``lhs <connector> rhs``, as built by the Combinable
    operators.

    Without an explicit output_field, the result type is looked up from the
    registered connector combinations via _resolve_combined_type().
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        """
        Infer the result field from the operands' field types and the connector.

        An operand whose output field resolved to None takes part in the
        lookup as NoneType.

        Raises:
            FieldError: if no registered combination matches.
        """
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field()
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types that were actually looked up. Re-reading
            # operand.output_field here would raise a different, misleading
            # FieldError when an operand's type resolved to None.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, "
                f"{rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Compile both operands, join them with the connector and parenthesize."""
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence
        expression_wrapper = "(%s)"
        sql = combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | TemporalSubtraction:
        """
        Resolve both operands and return a resolved copy.

        A subtraction whose operands are both DateField, both DateTimeField or
        both TimeField is rewritten into a TemporalSubtraction instead.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
796
797
class TemporalSubtraction(CombinedExpression):
    """
    Subtraction of two temporal operands of the same kind, producing a
    DurationField result.

    CombinedExpression.resolve_expression() substitutes this class when both
    operands of ``-`` are DateField, DateTimeField or TimeField values.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        # The dialect helper receives the lhs internal type together with
        # both compiled (sql, params) pairs.
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        sql, params = subtract_temporals(lhs_internal_type, compiled_lhs, compiled_rhs)
        return sql, list(params)
813
814
@deconstructible(path="plain.postgres.F")
class F(Combinable):
    """Named reference to an existing query object, resolved against the query."""

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        # Delegate name lookup to the query; for_save is not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        # F has no sub-expressions: either it is replaced wholesale or kept.
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, F):
            same_class = self.__class__ == other.__class__
            return same_class and self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        return copy.copy(self)
858
859
class ResolvedOuterRef(F):
    """
    Reference to a column of an outer query.

    Produced from an OuterRef once the inner query it appears in has been
    used as a subquery, at which point the outer reference can be resolved.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        # Compiling this on its own is always an error.
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Returned unchanged: relabeling is not applied to outer references.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Contributes no GROUP BY columns.
        return []
895
896
class OuterRef(F):
    """
    Reference to a field of the enclosing query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. For a nested
    OuterRef(OuterRef(...)) the inner OuterRef is returned unresolved, so one
    layer is peeled off per resolution.
    """

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        target = self.name
        if not isinstance(target, self.__class__):
            return ResolvedOuterRef(target)
        return target

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Returned unchanged: relabeling is not applied to outer references.
        return self
907
908
@deconstructible(path="plain.postgres.Func")
class Func(Expression):
    """
    An SQL function call.

    Rendering is driven by class attributes that subclasses override:
    ``function`` (the SQL function name), ``template`` (a %-style format
    string), ``arg_joiner`` (separator between rendered arguments) and
    ``arity`` (required argument count, or None for any). Extra keyword
    arguments given to __init__() are kept in ``self.extra`` and take part in
    template substitution.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        required = self.arity
        if required is not None and len(expressions) != required:
            noun = "argument" if required == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {required} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        rendered_args = self.arg_joiner.join(
            str(arg) for arg in self.source_expressions
        )
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({rendered_args})"
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{class_name}({rendered_args}, {rendered_options})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        clone = self.copy()
        clone.is_summary = summarize
        # Update the clone's (already copied) list in place.
        clone.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in clone.source_expressions
        ]
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        def compile_argument(arg: Any) -> Any:
            try:
                return compiler.compile(arg)
            except EmptyResultSet:
                # Fall back to the argument's empty-result value when it
                # declares one; otherwise let the exception propagate.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                return compiler.compile(Value(fallback))
            except FullResultSet:
                return compiler.compile(Value(True))

        sql_parts: list[str] = []
        params: list[Any] = []
        for arg in self.source_expressions:
            arg_sql, arg_params = compile_argument(arg)
            sql_parts.append(arg_sql)
            params.extend(arg_params)

        # Precedence for each setting: the argument passed to this method,
        # then a value from __init__()'s **extra (or extra_context), then the
        # class attribute.
        data = {**self.extra, **extra_context}
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined = arg_joiner.join(sql_parts)
        data["expressions"] = joined
        data["field"] = joined
        return template % data, params

    def copy(self) -> Self:
        clone = super().copy()
        # Detach the mutable containers so the clone can be changed safely.
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
1013
1014
@deconstructible(path="plain.postgres.Value")
class Value(Expression):
    """
    A literal Python value used as a node within an expression.

    When an output field is known, the value is prepared through that field
    before compilation. _resolve_output_field() below maps common Python types
    (str, bool, int, float, datetime/date/time/timedelta, Decimal, bytes, UUID)
    to a field instance; presumably the base class consults it when no
    explicit ``output_field`` was given (TODO confirm in BaseExpression).
    """

    # Unresolved Value instances still need to be compilable, so they default
    # to the non-save preparation path; resolve_expression() stores the real
    # flag on each resolved copy. A final decision is tracked in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Args:
            value: the Python value this node stands for. It is added to the
                SQL parameter list instead of being spliced into the SQL text.
            output_field: optional model field instance (e.g. IntegerField()
                or CharField()) describing the type this expression yields.
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        field = self._output_field_or_none
        if field is None:
            prepared = self.value
        else:
            # Saving and querying may prepare values differently.
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(self.value, connection=connection)
            # Some fields render their own placeholder SQL around the parameter.
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [prepared]  # type: ignore[call-non-callable]
        # A missing value becomes a literal NULL with no parameter.
        if prepared is None:
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remember the save context so as_sql() picks the matching prep method.
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        # A literal contributes no GROUP BY columns.
        return []

    def _resolve_output_field(self) -> Field | None:
        # Order matters: bool is a subclass of int and datetime.datetime is a
        # subclass of datetime.date, so the more specific types come first.
        value = self.value
        for python_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        """
        Substitute used when compiling this node raises EmptyResultSet.

        For example, Func.as_sql compiles Value(empty_result_set_value) in that
        case. For a Value this is simply the wrapped value.
        """
        return self.value
1096
1097
class RawSQL(Expression):
    """
    A verbatim SQL fragment plus its parameters, emitted wrapped in parentheses.

    Without an explicit ``output_field`` a generic ``fields.Field()`` is used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=field)

    def __repr__(self) -> str:
        return "{}({}, {})".format(type(self).__name__, self.sql, self.params)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        # Parameters are returned as stored, not copied.
        return "({})".format(self.sql), self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The fragment is opaque, so it is grouped on as a single unit.
        return [self]
1117
1118
class Star(Expression):
    """Expression node that compiles to the bare SQL ``*`` token with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        wildcard = "*"
        return wildcard, []
1127
1128
class Col(Expression):
    """
    A reference to a table column, optionally qualified by a table alias.

    ``target`` is the field the column belongs to (its ``.column`` is the
    column name). Without an explicit ``output_field``, ``target`` itself is
    used as the output field.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias, self.target = alias, target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        column = self.target.column
        names = [self.alias, column] if self.alias else [column]
        # Each part is quoted unless the compiler recognizes it as an alias.
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # An unqualified column has no alias to rename.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: DatabaseConnection
    ) -> list[Callable[..., Any]]:
        if self.target == self.output_field:
            return self.output_field.get_db_converters(connection)
        # The output field differs from the column's own field: the output
        # field's converters are listed first, then the target field's.
        output_converters = self.output_field.get_db_converters(connection)
        return output_converters + self.target.get_db_converters(connection)
1172
1173
class Ref(Expression):
    """
    A reference, by name, to a column alias of the query.

    For example, ``Ref('sum_cost')`` points at the alias introduced by
    ``qs.annotate(sum_cost=Sum('cost'))``. ``source`` is the expression that
    alias stands for; it is kept only as a source expression, because the SQL
    emitted is just the quoted alias name.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` was already resolved when the alias was created; the
        # reference itself is only a name, so there is nothing to do.
        return self

    def get_refs(self) -> set[str]:
        return set([self.refs])

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # The reference is by column-alias name only; nothing to relabel.
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        quoted_alias = quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1218
1219
class ExpressionList(Func):
    """
    Several expressions rendered one after another, joined by ``arg_joiner``.

    Useful wherever another expression takes a list of expressions as one
    argument, such as a window's partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))
1238
1239
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from its terms.

    A string term starting with ``-`` is turned into a descending
    ``OrderBy(F(<name>))``; all other terms are passed through unchanged.
    An empty list compiles to no clause at all.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        terms = []
        for expr in expressions:
            if isinstance(expr, str) and expr[0] == "-":
                expr = OrderBy(F(expr[1:]), descending=True)
            terms.append(expr)
        super().__init__(*terms, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        if self.source_expressions:
            sql, params = super().as_sql(*args, **kwargs)
            return sql, list(params)
        return "", []

    def get_group_by_cols(self) -> list[Any]:
        # Group by every column contributed by any ordering term.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1265
1266
@deconstructible(path="plain.postgres.ExpressionWrapper")
class ExpressionWrapper(Expression):
    """
    Wrap another expression (or other compilable node) to give it extra
    context, chiefly an explicit ``output_field``.

    It compiles exactly as the wrapped node does.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # A non-expression (e.g. an SQL WHERE clause) cannot be broken
            # down further, so the whole wrapped node goes into GROUP BY.
            return super().get_group_by_cols()
        # Ask a copy that carries this wrapper's output field, leaving the
        # wrapped expression itself untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, self.expression)
1300
1301
class NegatedExpression(ExpressionWrapper):
    """
    ``NOT <expr>`` for a conditional expression, typed as a BooleanField.

    Inverting it again yields a copy of the original expression.
    """

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Double negation cancels out.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        try:
            inner_sql, inner_params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The negation of "matches nothing" is a constant TRUE.
            return compiler.compile(Value(True))
        else:
            return f"NOT {inner_sql}", inner_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Only expressions flagged as conditional may be negated.
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # The boolean SQL is used as-is in the SELECT list.
        return sql, params
1340
1341
@deconstructible(path="plain.postgres.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch, meant to be used inside Case().

    The condition may be a Q object, a conditional expression, or keyword
    lookups. Lookups alone become ``Q(**lookups)``. Lookups given together
    with a conditional ``condition`` become ``Q(condition, **lookups)``.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        pending_lookups: dict[str, Any] | None = lookups or None
        if pending_lookups:
            if condition is None:
                condition = Q(**pending_lookups)
                pending_lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **pending_lookups)
                pending_lookups = None
        # Lookups still pending here were paired with a non-conditional
        # condition, which is rejected just like a missing/non-conditional one.
        is_conditional = condition is not None and getattr(
            condition, "conditional", False
        )
        if not is_conditional or pending_lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return "WHEN {!r} THEN {!r}".format(self.condition, self.result)

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        condition, result = exprs
        self.condition = condition
        self.result = result

    def get_source_fields(self) -> list[Field | None]:
        # Only the result determines the output type; the condition is ignored.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        clone = self.copy()
        clone.is_summary = summarize
        if isinstance(clone.condition, ResolvableExpression):
            # The condition is always resolved outside of a save context.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # After resolve_expression, condition is WhereNode | resolved Expression
        # (both SQLCompilable).
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        # **extra_context is a fresh dict per call, so filling it in is safe.
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # A lone WHEN branch is never a GROUP BY column itself; contribute the
        # columns of its condition and result instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1435
1436
@deconstructible(path="plain.postgres.Case")
class Case(Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments are When() branches, tried in order; ``default`` is
    the ELSE value. Extra keyword arguments are kept in ``self.extra`` and
    merged into the SQL template context (an ``extra["template"]`` overrides
    the class template).
    """

    # `cases` and `default` are filled in by as_sql(); `case_joiner` separates
    # the WHEN branches.
    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        # The default is normalized through _parse_expressions so that
        # as_sql() can always compile it.
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # All branches first, then the default as the last element.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Mirror of get_source_expressions(): the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a resolved copy; every branch and the default are resolved."""
        c = self.copy()
        c.is_summary = summarize
        # copy() gave the clone its own `cases` list, so replacing items in
        # place does not touch the original Case.
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        """Return a copy that has its own list of When branches."""
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to ``CASE WHEN ... ELSE ... END``, simplifying where possible.

        Returns:
            (sql, params), where params list the surviving branches' params in
            order followed by the ELSE value's params (matching the textual
            order of the SQL).
        """
        # No branches at all: the expression is just its default.
        if not self.cases:
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch's condition can never match: drop it.
                continue
            except FullResultSet:
                # This branch always matches, so later branches are
                # unreachable: its result becomes the ELSE value.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        # Every branch was dropped (or the first always matched): emit the
        # ELSE value alone, without a CASE wrapper.
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        # With a known output field, wrap the SQL in the connection's
        # unification cast pattern. NOTE(review): presumably this keeps every
        # branch at one SQL type; confirm in unification_cast_sql().
        if self._output_field_or_none is not None:
            sql = connection.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without branches the CASE compiles to its default, so group on that.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1548
1549
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery, built from a QuerySet or an sql.Query.

    The query is cloned and marked as a subquery. It may contain OuterRef()
    references to the outer query, which are resolved when it is applied to
    that query. Extra keyword arguments feed the SQL template context.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.postgres.sql.query import Query

        # A QuerySet is unwrapped to its sql.Query; a Query is used as given.
        source = query if isinstance(query, Query) else query.sql_query
        self.query = source.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # The subquery yields whatever its inner query outputs.
        return self.query.output_field

    def copy(self) -> Self:
        duplicate = super().copy()
        # Give the copy its own inner query so the two can diverge.
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last character of the inner SQL. NOTE(review):
        # presumably the parentheses Query.as_sql adds for subqueries; the
        # template supplies its own. Confirm in Query.as_sql.
        template_params = {
            **self.extra,
            **extra_context,
            "subquery": subquery_sql[1:-1],
        }
        chosen = template or template_params.get("template", self.template)
        return chosen % template_params, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1619
1620
class Exists(Subquery):
    """
    ``EXISTS(<subquery>)`` as a boolean expression.

    The subquery compiles to FALSE when it cannot match anything.
    """

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Replace the cloned query with its exists() form.
        exists_query = self.query.exists()
        self.query = exists_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # The boolean SQL is used as-is in the SELECT list.
        return sql, params
1635
1636
@deconstructible(path="plain.postgres.OrderBy")
class OrderBy(Expression):
    """
    A single ORDER BY term: an expression, a direction and an optional
    NULLS FIRST / NULLS LAST modifier.

    Args:
        expression: the resolvable expression to order by.
        descending: sort DESC when true, ASC otherwise.
        nulls_first / nulls_last: True or None; at most one may be True.

    Raises:
        ValueError: if both nulls flags are True, if either is False, or if
            ``expression`` is not a resolvable expression.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        template = template or self.template
        # Handle NULLS FIRST/LAST modifiers
        if self.nulls_last:
            template = f"{template} NULLS LAST"
        elif self.nulls_first:
            template = f"{template} NULLS FIRST"
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Repeat the expression's params once per %(expression)s occurrence in
        # the template. Build a new sequence instead of using `*=`: compile()
        # can return a list the inner expression still owns (a raw-SQL
        # expression may return its own params list), and an in-place multiply
        # would permanently empty or duplicate it.
        params = params * template.count("%(expression)s")
        # rstrip() removes trailing whitespace from the rendered template.
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction and swap NULLS FIRST/LAST in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # type: ignore[override]
        # Mutates this term in place (unlike the inherited variant, returns None).
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        # Mutates this term in place (unlike the inherited variant, returns None).
        self.descending = True
1712
1713
class Window(Expression):
    """
    ``<expression> OVER (<window>)``, with an optional PARTITION BY, ORDER BY
    and frame clause inside the parentheses.

    ``partition_by`` may be a single item or a list/tuple (wrapped in an
    ExpressionList). ``order_by`` may be a string, an expression, or a
    list/tuple of them (wrapped in an OrderByList). ``frame`` is compiled as-is
    when truthy.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if partition_by is not None:
            if not isinstance(partition_by, tuple | list):
                partition_by = (partition_by,)
            self.partition_by = ExpressionList(*partition_by)

        if order_by is not None:
            if isinstance(order_by, list | tuple):
                order_terms = tuple(order_by)
            elif isinstance(order_by, BaseExpression | str):
                order_terms = (order_by,)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
            self.order_by = OrderByList(*order_terms)
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        # The window yields the same type as the expression it evaluates.
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        expr_sql, params = compiler.compile(self.source_expression)
        # Clauses and their params are collected in SQL order:
        # PARTITION BY, then ORDER BY, then the frame.
        clauses: list[str] = []
        clause_params: list[Any] = []

        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)

        if self.order_by is not None:
            ordering_sql, ordering_params = compiler.compile(self.order_by)
            clauses.append(ordering_sql)
            clause_params.extend(ordering_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        window = " ".join(clauses).strip()
        sql = (template or self.template) % {"expression": expr_sql, "window": window}
        return sql, (*params, *clause_params)

    def __str__(self) -> str:
        source = str(self.source_expression)
        partition = "PARTITION BY " + str(self.partition_by) if self.partition_by else ""
        ordering = str(self.order_by or "")
        frame = str(self.frame or "")
        return f"{source} OVER ({partition}{ordering}{frame})"

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_group_by_cols(self) -> list[Any]:
        # Only the partition and ordering columns matter for grouping; the
        # windowed expression itself does not.
        cols: list[Any] = []
        if self.partition_by:
            cols += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            cols += self.order_by.get_group_by_cols()
        return cols
1824
1825
class WindowFrame(Expression):
    """
    Model the frame clause of a window expression:
    ``<frame_type> BETWEEN <start> AND <end>``.

    Subclasses set ``frame_type`` and implement window_frame_start_end(),
    which turns the integer offsets into SQL bound text (any validation of
    the offsets happens there). Both bounds are optional. An omitted
    ``start`` means UNBOUNDED PRECEDING; an omitted ``end`` means UNBOUNDED
    FOLLOWING, i.e. the frame runs to the last row of the partition.

    __str__ renders an offset as follows: None -> unbounded, 0 -> CURRENT ROW,
    negative -> ``N PRECEDING``, positive -> ``N FOLLOWING``.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Offsets are wrapped in Value so they are exposed as source expressions.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        return []

    def __str__(self) -> str:
        return self.template % {
            "frame_type": self.frame_type,
            "start": self._describe_bound(self.start.value, UNBOUNDED_PRECEDING),
            "end": self._describe_bound(self.end.value, UNBOUNDED_FOLLOWING),
        }

    @staticmethod
    def _describe_bound(offset: int | None, unbounded: str) -> str:
        """Render one frame bound for display; `unbounded` is used for None."""
        if offset is None:
            return unbounded
        if offset == 0:
            return CURRENT_ROW
        # Show the offset in its real direction rather than collapsing a
        # positive start / negative end to an UNBOUNDED bound.
        if offset < 0:
            return f"{abs(offset)} {PRECEDING}"
        return f"{offset} {FOLLOWING}"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the window frame start and end for the given connection."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end()")
1895
1896
class RowRange(WindowFrame):
    """
    Window frame rendered as ``ROWS BETWEEN <start> AND <end>``.

    Bounds are formatted by the dialect's window_frame_rows_start_end().
    """

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The connection is not needed; bound formatting is dialect-level.
        bounds = window_frame_rows_start_end(start, end)
        return bounds
1904
1905
class ValueRange(WindowFrame):
    """
    Window frame rendered as ``RANGE BETWEEN <start> AND <end>``.

    Bounds are formatted by the dialect's window_frame_range_start_end().
    """

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: DatabaseConnection, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The connection is not needed; bound formatting is dialect-level.
        bounds = window_frame_range_start_end(start, end)
        return bounds