1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from abc import ABC, abstractmethod
8from collections import defaultdict
9from decimal import Decimal
10from functools import cached_property
11from types import NoneType
12from typing import TYPE_CHECKING, Any, Protocol, Self, cast, runtime_checkable
13from uuid import UUID
14
15from plain.models import fields
16from plain.models.constants import LOOKUP_SEP
17from plain.models.db import (
18 DatabaseError,
19 NotSupportedError,
20 db_connection,
21)
22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
23from plain.models.query_utils import Q
24from plain.utils.deconstruct import deconstructible
25from plain.utils.hashable import make_hashable
26
27if TYPE_CHECKING:
28 from collections.abc import Callable, Iterable, Sequence
29
30 from plain.models.backends.base.base import BaseDatabaseWrapper
31 from plain.models.fields import Field
32 from plain.models.lookups import Lookup, Transform
33 from plain.models.query import QuerySet
34 from plain.models.sql.compiler import SQLCompilable, SQLCompiler
35 from plain.models.sql.query import Query
36
37
38@runtime_checkable
39class ResolvableExpression(Protocol):
40 """Protocol for expressions that can be resolved in query context."""
41
42 def resolve_expression(
43 self,
44 query: Any = None,
45 allow_joins: bool = True,
46 reuse: Any = None,
47 summarize: bool = False,
48 for_save: bool = False,
49 ) -> Any: ...
50
51
52@runtime_checkable
53class ReplaceableExpression(Protocol):
54 """Protocol for expressions that support expression replacement."""
55
56 def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
57
58
class Combinable:
    """
    Mixin that gives an object arithmetic operators building
    ``CombinedExpression`` trees, e.g. ``F('foo') + F('bar')``.

    ``&``, ``|`` and ``^`` are reserved for joining two conditional
    (boolean) expressions into ``Q`` objects; bitwise arithmetic must go
    through the explicit ``bitand()``/``bitor()``/``bitxor()``/shift methods.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # A quoted "%" operator: the connector can end up in strings that also
    # go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors - produced only by the bit*() methods below; the
    # "&" and "|" operators are reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    @staticmethod
    def _bitwise_operator_error() -> NotImplementedError:
        """Build the error raised when &, | or ^ are misused for bit math."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Join ``self`` and ``other`` with ``connector``.

        Non-resolvable operands (plain Python values) are wrapped in
        ``Value`` first; ``reversed`` places ``other`` on the left.
        """
        operand = other if isinstance(other, ResolvableExpression) else Value(other)
        if reversed:
            lhs, rhs = operand, self
        else:
            lhs, rhs = self, operand
        return CombinedExpression(lhs, connector, rhs)

    #############
    # OPERATORS #
    #############

    def __neg__(self) -> CombinedExpression:
        # Negation is expressed as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    # -- arithmetic, expression on the left --------------------------------

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    # -- arithmetic, expression on the right -------------------------------

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    # -- explicit bitwise arithmetic ---------------------------------------

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    # -- boolean combination (both operands must be conditional) -----------

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) & Q(other)

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) ^ Q(other)

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) | Q(other)

    # Reflected boolean operators are always rejected.

    def __rand__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __ror__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __rxor__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
191
192
193class BaseExpression:
194 """Base class for all query expressions."""
195
196 empty_result_set_value = NotImplemented
197 # aggregate specific fields
198 is_summary = False
199 _output_field_resolved_to_none = False
200 # Can the expression be used in a WHERE clause?
201 filterable = True
202 # Can the expression can be used as a source expression in Window?
203 window_compatible = False
204
205 def __init__(self, output_field: Field | None = None):
206 if output_field is not None:
207 self.output_field = output_field
208
209 def __getstate__(self) -> dict[str, Any]:
210 state = self.__dict__.copy()
211 state.pop("convert_value", None)
212 return state
213
214 def get_db_converters(
215 self, connection: BaseDatabaseWrapper
216 ) -> list[Callable[..., Any]]:
217 converters = []
218 if self.convert_value is not self._convert_value_noop:
219 converters.append(self.convert_value)
220 converters.extend(self.output_field.get_db_converters(connection))
221 return converters
222
223 def get_source_expressions(self) -> list[Any]:
224 return []
225
226 def set_source_expressions(self, exprs: Sequence[Any]) -> None:
227 assert not exprs
228
229 def _parse_expressions(self, *expressions: Any) -> list[Any]:
230 return [
231 arg
232 if isinstance(arg, ResolvableExpression)
233 else (F(arg) if isinstance(arg, str) else Value(arg))
234 for arg in expressions
235 ]
236
237 def as_sql(
238 self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
239 ) -> tuple[str, Sequence[Any]]:
240 """
241 Responsible for returning a (sql, [params]) tuple to be included
242 in the current query.
243
244 Different backends can provide their own implementation, by
245 providing an `as_{vendor}` method and patching the Expression:
246
247 ```
248 def override_as_sql(self, compiler, connection):
249 # custom logic
250 return super().as_sql(compiler, connection)
251 setattr(Expression, 'as_' + connection.vendor, override_as_sql)
252 ```
253
254 Arguments:
255 * compiler: the query compiler responsible for generating the query.
256 Must have a compile method, returning a (sql, [params]) tuple.
257 Calling compiler(value) will return a quoted `value`.
258
259 * connection: the database connection used for the current query.
260
261 Return: (sql, params)
262 Where `sql` is a string containing ordered sql parameters to be
263 replaced with the elements of the list `params`.
264 """
265 raise NotImplementedError("Subclasses must implement as_sql()")
266
267 @cached_property
268 def contains_aggregate(self) -> bool:
269 return any(
270 expr and expr.contains_aggregate for expr in self.get_source_expressions()
271 )
272
273 @cached_property
274 def contains_over_clause(self) -> bool:
275 return any(
276 expr and expr.contains_over_clause for expr in self.get_source_expressions()
277 )
278
279 @cached_property
280 def contains_column_references(self) -> bool:
281 return any(
282 expr and expr.contains_column_references
283 for expr in self.get_source_expressions()
284 )
285
286 def resolve_expression(
287 self,
288 query: Any = None,
289 allow_joins: bool = True,
290 reuse: Any = None,
291 summarize: bool = False,
292 for_save: bool = False,
293 ) -> Self:
294 """
295 Provide the chance to do any preprocessing or validation before being
296 added to the query.
297
298 Arguments:
299 * query: the backend query implementation
300 * allow_joins: boolean allowing or denying use of joins
301 in this query
302 * reuse: a set of reusable joins for multijoins
303 * summarize: a terminal aggregate clause
304 * for_save: whether this expression about to be used in a save or update
305
306 Return: an Expression to be added to the query.
307 """
308 c = self.copy()
309 c.is_summary = summarize
310 c.set_source_expressions(
311 [
312 expr.resolve_expression(query, allow_joins, reuse, summarize)
313 if expr
314 else None
315 for expr in c.get_source_expressions()
316 ]
317 )
318 return c
319
320 @property
321 def conditional(self) -> bool:
322 output_field = getattr(self, "output_field", None)
323 return isinstance(output_field, fields.BooleanField)
324
325 @property
326 def field(self) -> Field:
327 return self.output_field
328
329 @cached_property
330 def output_field(self) -> Field:
331 """Return the output type of this expressions."""
332 output_field = self._resolve_output_field()
333 if output_field is None:
334 self._output_field_resolved_to_none = True
335 raise FieldError("Cannot resolve expression type, unknown output_field")
336 return output_field
337
338 @cached_property
339 def _output_field_or_none(self) -> Field | None:
340 """
341 Return the output field of this expression, or None if
342 _resolve_output_field() didn't return an output type.
343 """
344 try:
345 return self.output_field
346 except FieldError:
347 if not self._output_field_resolved_to_none:
348 raise
349 return None
350
351 def _resolve_output_field(self) -> Field | None:
352 """
353 Attempt to infer the output type of the expression.
354
355 As a guess, if the output fields of all source fields match then simply
356 infer the same type here.
357
358 If a source's output field resolves to None, exclude it from this check.
359 If all sources are None, then an error is raised higher up the stack in
360 the output_field property.
361 """
362 # This guess is mostly a bad idea, but there is quite a lot of code
363 # (especially 3rd party Func subclasses) that depend on it, we'd need a
364 # deprecation path to fix it.
365 sources_iter = (
366 source for source in self.get_source_fields() if source is not None
367 )
368 for output_field in sources_iter:
369 for source in sources_iter:
370 if not isinstance(output_field, source.__class__):
371 raise FieldError(
372 f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
373 "set output_field."
374 )
375 return output_field
376 return None
377
378 @staticmethod
379 def _convert_value_noop(
380 value: Any, expression: Any, connection: BaseDatabaseWrapper
381 ) -> Any:
382 return value
383
384 @cached_property
385 def convert_value(self) -> Callable[[Any, Any, Any], Any]:
386 """
387 Expressions provide their own converters because users have the option
388 of manually specifying the output_field which may be a different type
389 from the one the database returns.
390 """
391 field = self.output_field
392 internal_type = field.get_internal_type()
393 if internal_type == "FloatField":
394 return (
395 lambda value, expression, connection: None
396 if value is None
397 else float(value)
398 )
399 elif internal_type.endswith("IntegerField"):
400 return (
401 lambda value, expression, connection: None
402 if value is None
403 else int(value)
404 )
405 elif internal_type == "DecimalField":
406 return (
407 lambda value, expression, connection: None
408 if value is None
409 else Decimal(value)
410 )
411 return self._convert_value_noop
412
413 def get_lookup(self, lookup: str) -> type[Lookup] | None:
414 return self.output_field.get_lookup(lookup)
415
416 def get_transform(self, name: str) -> type[Transform] | None:
417 return self.output_field.get_transform(name)
418
419 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
420 clone = self.copy()
421 clone.set_source_expressions(
422 [
423 e.relabeled_clone(change_map) if e is not None else None
424 for e in self.get_source_expressions()
425 ]
426 )
427 return clone
428
429 def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
430 if replacement := replacements.get(self):
431 return replacement
432 clone = self.copy()
433 source_expressions = clone.get_source_expressions()
434 clone.set_source_expressions(
435 [
436 expr.replace_expressions(replacements) if expr else None
437 for expr in source_expressions
438 ]
439 )
440 return clone
441
442 def get_refs(self) -> set[str]:
443 refs = set()
444 for expr in self.get_source_expressions():
445 refs |= expr.get_refs()
446 return refs
447
448 def copy(self) -> Self:
449 return copy.copy(self)
450
451 def prefix_references(self, prefix: str) -> Self:
452 clone = self.copy()
453 clone.set_source_expressions(
454 [
455 F(f"{prefix}{expr.name}")
456 if isinstance(expr, F)
457 else expr.prefix_references(prefix)
458 for expr in self.get_source_expressions()
459 ]
460 )
461 return clone
462
463 def get_group_by_cols(self) -> list[BaseExpression]:
464 if not self.contains_aggregate:
465 return [self]
466 cols: list[BaseExpression] = []
467 for source in self.get_source_expressions():
468 cols.extend(source.get_group_by_cols())
469 return cols
470
471 def get_source_fields(self) -> list[Field | None]:
472 """Return the underlying field types used by this aggregate."""
473 return [e._output_field_or_none for e in self.get_source_expressions()]
474
475 def asc(self, **kwargs: Any) -> OrderBy:
476 return OrderBy(self, **kwargs)
477
478 def desc(self, **kwargs: Any) -> OrderBy:
479 return OrderBy(self, descending=True, **kwargs)
480
481 def reverse_ordering(self) -> Self:
482 return self
483
484 def flatten(self) -> Iterable[Any]:
485 """
486 Recursively yield this expression and all subexpressions, in
487 depth-first order.
488 """
489 yield self
490 for expr in self.get_source_expressions():
491 if expr:
492 if hasattr(expr, "flatten"):
493 yield from expr.flatten()
494 else:
495 yield expr
496
497 def select_format(
498 self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
499 ) -> tuple[str, Sequence[Any]]:
500 """
501 Custom format for select clauses. For example, EXISTS expressions need
502 to be wrapped in CASE WHEN on Oracle.
503 """
504 if output_field := getattr(self, "output_field", None):
505 if select_format := getattr(output_field, "select_format", None):
506 return select_format(compiler, sql, params)
507 return sql, params
508
509
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    # Set by @deconstructible decorator in __new__
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Hashable fingerprint used for equality and hashing.

        Consists of the class followed by ``(parameter_name, value)`` pairs for
        every ``__init__`` parameter (defaults applied). Model-bound fields are
        reduced to ``(model label, field name)``, unbound fields to their type,
        and any other value is passed through ``make_hashable``.
        """

        def normalize(value: Any) -> Any:
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model.model_options.label, value.name)
            return type(value)

        init_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args
        bound = init_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()
        pairs = [(name, normalize(value)) for name, value in bound.arguments.items()]
        return (self.__class__, *pairs)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
543
544
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# Each dict maps a connector to (lhs_field_type, rhs_field_type,
# result_field_type) triples. The loop after register_combinable_fields()
# registers them in list order, and _resolve_combined_type() returns the first
# triple whose lhs/rhs types are superclasses of the operand types, so earlier
# entries take precedence.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    # (POW is not registered for mixed numeric types.)
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # NoneType matches an operand whose output field resolved to None
    # (CombinedExpression looks types up via type(_output_field_or_none)),
    # presumably e.g. a NULL literal; the result keeps the numeric type.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
648
# Registry: connector -> list of (lhs_type, rhs_type, result_type) triples in
# registration order. Filled by register_combinable_fields() and read by
# _resolve_combined_type().
_connector_combinators: defaultdict[
    str,
    list[tuple[type[Field] | type[None], type[Field] | type[None], type[Field]]],
] = defaultdict(list)
650
651
def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    The new triple is appended after existing ones for ``connector``, so an
    earlier registration that also matches still takes precedence.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes lookups with lru_cache. Without
    # clearing it, lookups cached before this call (including "no match"
    # Nones) would keep ignoring the new registration. The resolver is
    # defined later in this module, so it does not exist yet while the
    # built-in combinations are registered at import time.
    resolver = globals().get("_resolve_combined_type")
    if resolver is not None:
        resolver.cache_clear()
667
668
# Register the built-in combinations. Iteration order here is the lookup
# precedence used by _resolve_combined_type(), which returns the first match.
# NOTE(review): the loop variables (d, connector, field_types, lhs, rhs,
# result) remain bound as module-level names after import.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
673
674
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result field type of ``lhs_type <connector> rhs_type``.

    Registered triples for ``connector`` are scanned in registration order;
    the first whose lhs/rhs types are superclasses of the given types wins.
    Returns None when nothing matches. Results are memoized per argument
    tuple.
    """
    candidates = _connector_combinators.get(connector, ())
    return next(
        (
            result_type
            for expected_lhs, expected_rhs, result_type in candidates
            if issubclass(lhs_type, expected_lhs)
            and issubclass(rhs_type, expected_rhs)
        ),
        None,
    )
686
687
class SQLiteNumericMixin(Expression):
    """
    On SQLite, wrap the SQL of expressions whose output field is a
    DecimalField in ``CAST(... AS NUMERIC)`` so they filter correctly.
    """

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        # An expression whose type cannot be resolved is left uncast.
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            needs_cast = False
        if needs_cast:
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params
707
708
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``.

    On resolution, a combination where exactly one operand is a DurationField
    is rewritten as a DurationExpression. Subtraction of two operands of the
    same date/time type is rewritten as a TemporalSubtraction.
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        """
        Infer the result type from the registered connector combinations.

        Raises FieldError when no combination matches the operand types.
        """
        # We avoid using super() here for reasons given in
        # Expression._resolve_output_field()
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types already inferred. Re-reading ``output_field``
            # would raise a different FieldError ("unknown output_field") when
            # an operand's type resolved to None, hiding this message.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, "
                f"{rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Specialized subclasses are not re-dispatched, which avoids
        # rewriting them again on every resolution.
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
814
815
class DurationExpression(CombinedExpression):
    """
    Combination involving a DurationField operand.

    Backends with a native duration type compile it like any
    CombinedExpression. Otherwise duration operands are formatted for duration
    arithmetic and combined with ``combine_duration_expression``.
    """

    def compile(
        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compile one operand, adapting it if it is a DurationField."""
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        sql_parts: list[str] = []
        all_params: list[Any] = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            sql_parts.append(side_sql)
            all_params.extend(side_params)
        combined = connection.ops.combine_duration_expression(
            self.connector, sql_parts
        )
        # Parenthesize to preserve operator precedence in the outer query.
        return f"({combined})", all_params

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        # Only numeric/duration operands may be multiplied or divided here.
        permitted = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if lhs_type not in permitted or rhs_type not in permitted:
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
874
875
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` for two operands of the same date/time type.

    Always typed as a DurationField; the SQL comes from the backend's
    ``subtract_temporals``.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        internal_type = self.lhs.output_field.get_internal_type()
        sql, params = connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
        return sql, list(params)
892
893
@deconstructible(path="plain.models.F")
class F(Combinable):
    """A by-name reference to a field, resolved against the query later."""

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Delegate to ``query.resolve_ref``; ``for_save`` is not forwarded."""
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        """Return the mapped replacement for this reference, or self."""
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        # Equal only to the same concrete class with the same name.
        if isinstance(other, F):
            return self.__class__ == other.__class__ and self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        return copy.copy(self)
937
938
class ResolvedOuterRef(F):
    """
    Reference to a column of an enclosing query, as produced by
    ``OuterRef.resolve_expression()``.

    It cannot be compiled on its own (``as_sql`` always raises); it is only
    meaningful once the inner query is used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        message = (
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )
        raise ValueError(message)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """Resolve the reference, rejecting outer window expressions."""
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        """Outer references are not affected by inner-query relabeling."""
        return self

    def get_group_by_cols(self) -> list[Any]:
        """Outer references never contribute GROUP BY columns."""
        return []
974
975
class OuterRef(F):
    """Reference to a field of the query enclosing a subquery."""

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        target = self.name
        # Nested OuterRef(OuterRef(...)): hand back the inner reference
        # unresolved, presumably so a further-out query level resolves it.
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        """Relabeling does not apply to outer references."""
        return self
986
987
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    Subclasses set ``function`` (the SQL function name) and may override
    ``template``, ``arg_joiner`` and ``arity``. Extra keyword arguments passed
    to ``__init__()`` are stored in ``self.extra`` and merged into the template
    context by ``as_sql()``; ``function``, ``template`` and ``arg_joiner`` can
    be overridden that way as well.
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Use ``expressions`` (parsed via ``_parse_expressions()``) as the
        function's arguments.

        Raises:
            TypeError: if ``arity`` is set and ``len(expressions)`` differs
                from it.
        """
        if self.arity is not None and len(expressions) != self.arity:
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__,
                    self.arity,
                    "argument" if self.arity == 1 else "arguments",
                    len(expressions),
                )
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({args})"
        # Render into a separate name so `options` keeps a single (dict) type.
        rendered = ", ".join(
            str(key) + "=" + str(val) for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({args}, {rendered})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy whose arguments have each been resolved."""
        c = self.copy()
        c.is_summary = summarize
        # copy() gives the clone its own list, so it is safe to fill in place.
        for pos, arg in enumerate(c.source_expressions):
            c.source_expressions[pos] = arg.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile each argument and interpolate them into the template.

        An argument that raises EmptyResultSet is replaced by its
        ``empty_result_set_value`` (re-raising if it has none). An argument
        that raises FullResultSet is replaced by ``Value(True)``.
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    arg, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data["function"] = function
        else:
            data.setdefault("function", self.function)
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self) -> Self:
        """Copy, also duplicating the argument list and the extra dict."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return cast(Self, clone)
1093
1094
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """
    A literal value wrapped as an expression node.

    Non-None values are compiled as a ``%s`` placeholder plus a parameter
    (or the output field's own placeholder); None compiles to SQL ``NULL``.
    """

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [
                    prepared
                ]
        if prepared is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # Remember whether the value is being prepared for saving.
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        # Checked in order: bool before int (bool subclasses int) and
        # datetime before date (datetime subclasses date).
        type_to_field = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in type_to_field:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1180
1181
class RawSQL(Expression):
    """
    A raw SQL fragment with its parameters, used as an expression.

    The fragment is wrapped in parentheses when compiled. Without an explicit
    ``output_field`` a generic ``fields.Field()`` is used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped = "({})".format(self.sql)
        return wrapped, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1201
1202
class Star(Expression):
    """The bare ``*`` token as an expression; compiles to ``*`` with no params."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1211
1212
class Col(Expression):
    """
    A column of ``target``, optionally qualified by a table ``alias``.

    The output field defaults to ``target`` itself.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        column = self.target.column
        names = [self.alias, column] if self.alias else [column]
        quote = compiler.quote_name_unless_alias
        return ".".join(quote(name) for name in names), []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # An unqualified column has no alias to rename.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        # Output-field converters run first; the target's own converters are
        # appended only when the output field differs from the target.
        output_converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return output_converters
        return output_converters + self.target.get_db_converters(connection)
1256
1257
class Ref(Expression):
    """
    Reference to a column alias of the query, e.g. ``Ref('sum_cost')`` for
    the annotation in ``qs.annotate(sum_cost=Sum('cost'))``.

    Compiles to the quoted alias name ``refs``; ``source`` is the expression
    the alias stands for.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # `source` is already resolved; this is only a reference to its name.
        return self

    def get_refs(self) -> set[str]:
        return set([self.refs])

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1302
1303
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause. Compiles to the arguments joined by ``arg_joiner``, with no
    surrounding function name.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        """
        Raises:
            ValueError: if no expressions are given.
        """
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # Casting to numeric is unnecessary, so bypass the SQLite numeric
        # mixin and compile with the generic as_sql().
        return self.as_sql(compiler, connection, **extra_context)
1331
1332
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from expressions.

    A string item starting with ``-`` becomes a descending
    ``OrderBy(F(<rest>))``; other items are passed through unchanged.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        converted: list[Any] = []
        for item in expressions:
            if isinstance(item, str) and item[0] == "-":
                converted.append(OrderBy(F(item[1:]), descending=True))
            else:
                converted.append(item)
        super().__init__(*converted, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # An empty ordering renders as nothing rather than a bare "ORDER BY".
        if self.source_expressions:
            sql, params = super().as_sql(*args, **kwargs)
            return sql, list(params)
        return "", []

    def get_group_by_cols(self) -> list[Any]:
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1358
1359
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    Compiles exactly as the wrapped expression does.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        first = exprs[0]
        self.expression = first

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group by a copy of the inner expression carrying this wrapper's
        # output field, leaving the original inner expression untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.expression)
1393
1394
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression (boolean output)."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation yields a copy of the wrapped expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped condition matches nothing, so its negation is
            # always true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """
        Raises:
            TypeError: if the resolved inner expression is not conditional.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        connection = compiler.connection
        needs_case_wrapper = (
            not connection.features.supports_boolean_expr_in_select_clause
            # Avoid double wrapping.
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1453
1454
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional expression, or keyword
    lookups (which are folded into a Q, together with a conditional
    ``condition`` if one is also given).
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Raises:
            TypeError: if no usable conditional condition results.
            ValueError: if the condition is an empty Q().
        """
        pending: dict[str, Any] | None = lookups or None
        if pending:
            if condition is None:
                condition, pending = Q(**pending), None
            elif getattr(condition, "conditional", False):
                condition, pending = Q(condition, **pending), None
        is_valid = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not pending
        )
        if not is_valid:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return "WHEN {!r} THEN {!r}".format(self.condition, self.result)

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        condition, result = exprs
        self.condition = condition
        self.result = result

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        clone = self.copy()
        clone.is_summary = summarize
        if isinstance(clone.condition, ResolvableExpression):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's `for_save`.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        context = {**extra_context, "condition": condition_sql, "result": result_sql}
        chosen_template = template or self.template
        return chosen_template % context, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1549
1550
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Raises:
            TypeError: if any positional argument is not a When.
        """
        for case in cases:
            if not isinstance(case, When):
                raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        joined = ", ".join(str(c) for c in self.cases)
        return f"CASE {joined}, ELSE {self.default!r}"

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self) -> list[Any]:
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # The last expression is the default; everything before it is a case.
        *cases, default = exprs
        self.cases = cases
        self.default = default

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        clone = self.copy()
        clone.is_summary = summarize
        clone.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in clone.cases
        ]
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self) -> Self:
        clone = super().copy()
        clone.cases = list(clone.cases)
        return cast(Self, clone)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        if not self.cases:
            only_sql, only_params = compiler.compile(self.default)
            return only_sql, list(only_params)
        default_sql, default_params = compiler.compile(self.default)
        branch_sql: list[str] = []
        branch_params: list[Any] = []
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # A branch that can never match is dropped.
                continue
            except FullResultSet:
                # A branch that always matches supplies the ELSE value and
                # makes any later branches unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            branch_sql.append(case_sql)
            branch_params.extend(case_params)
        if not branch_sql:
            return default_sql, list(default_params)
        context = {**self.extra, **extra_context}
        context["cases"] = (case_joiner or self.case_joiner).join(branch_sql)
        context["default"] = default_sql
        chosen_template = template or context.get("template", self.template)
        sql = chosen_template % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, [*branch_params, *default_params]

    def get_group_by_cols(self) -> list[Any]:
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1663
1664
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Accepts either a QuerySet or an sql.Query; a clone of the underlying
    query is stored and flagged as a subquery.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.models.sql.query import Query

        # A Query is used as-is; a QuerySet contributes its sql_query.
        source_query = query if isinstance(query, Query) else query.sql_query
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        first = exprs[0]
        self.query = first

    def _resolve_output_field(self) -> Field | None:
        return self.query.output_field

    def copy(self) -> Self:
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return cast(Self, duplicate)

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        context = {
            **self.extra,
            **extra_context,
            # Drop the first and last characters of the compiled query.
            # NOTE(review): presumably the parentheses the query compiler
            # wraps around a subquery, re-added by the template — confirm.
            "subquery": subquery_sql[1:-1],
        }
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1735
1736
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Replace the cloned query with its exists() form.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1755
1756
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    Sort by ``expression``, ascending or descending, optionally forcing NULLs
    first or last (``nulls_first`` / ``nulls_last`` are True or None).
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Raises:
            ValueError: if both null placements are requested, if either is
                passed as False, or if ``expression`` is not resolvable.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if any(flag is False for flag in (nulls_first, nulls_last)):
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        first = exprs[0]
        self.expression = first

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def _with_nulls_placement(
        self, template: str, connection: BaseDatabaseWrapper
    ) -> str:
        """Extend ``template`` to honor nulls_first / nulls_last."""
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                return f"{template} NULLS LAST"
            if self.nulls_first:
                return f"{template} NULLS FIRST"
            return template
        # Without NULLS FIRST/LAST support, emulate them with a leading
        # "IS [NOT] NULL" sort key unless the backend's default already
        # places NULLs as requested.
        if self.nulls_last and not (
            self.descending and features.order_by_nulls_first
        ):
            return f"%(expression)s IS NULL, {template}"
        if self.nulls_first and not (
            not self.descending and features.order_by_nulls_first
        ):
            return f"%(expression)s IS NOT NULL, {template}"
        return template

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        template = self._with_nulls_placement(template or self.template, connection)
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression may occur more than once in the template, so repeat
        # its parameters once per occurrence.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and null placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self) -> None:  # type: ignore[override]
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        self.descending = True
1842
1843
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call: ``<expression> OVER (<partition> <order> <frame>)``.

    ``partition_by`` may be a single item or a list/tuple and is wrapped in an
    ExpressionList. ``order_by`` may be a string, an expression, or a
    list/tuple of them and is wrapped in an OrderByList. ``frame`` is an
    optional frame-clause expression compiled after them.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Raises:
            ValueError: if ``expression`` is not window-compatible, or
                ``order_by`` is of an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Raises:
            NotSupportedError: if the backend lacks OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression:
            # compile a copy whose inner expression reports FloatField, and
            # let SQLiteNumericMixin cast the whole window expression.
            window = self.copy()
            source_expressions = window.get_source_expressions()
            # self.copy() does not copy the inner expression, so it is still
            # shared with `self`. Copy it before overriding its output field
            # so that compiling does not permanently retype the original.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            window.set_source_expressions(source_expressions)
            return super(Window, window).as_sqlite(
                compiler, connection, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1972
1973
class WindowFrame(Expression, ABC):
    """
    Model the frame clause in window expressions.

    There are two types of frame clauses, which are subclasses. SQL
    generation is shared here. Rendering and validation of the bounds (by no
    means intended to be complete) are delegated to the backend's operations
    object through ``window_frame_start_end``.

    Providing an end for a frame is optional. ``end=None`` means UNBOUNDED
    FOLLOWING, so the frame runs to the last row of the partition.
    ``start=None`` means UNBOUNDED PRECEDING.

    ``__str__`` renders offsets by sign: a negative value is N PRECEDING,
    zero is CURRENT ROW, and a positive value is N FOLLOWING. Whether a given
    combination is accepted is decided by ``connection.ops`` at compile time.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    # SQL frame-unit keyword emitted in the template; set by each subclass.
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Bounds are stored as Value expressions so they are exposed through
        # get_source_expressions() like any other child expression.
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, params)`` for the frame clause.

        The bound SQL comes from the subclass's ``window_frame_start_end``
        hook and is inlined into the template, so the params list is always
        empty.
        """
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # A frame clause contributes no GROUP BY columns.
        return []

    @staticmethod
    def _format_frame_bound(value: Any, unbounded: str, ops: Any) -> str:
        """
        Render one frame bound for ``__str__``.

        ``None`` renders as ``unbounded``. Otherwise the sign of ``value``
        selects PRECEDING (negative), CURRENT ROW (zero) or FOLLOWING
        (positive), using the keyword constants on ``ops``.
        """
        if value is None:
            return unbounded
        if value == 0:
            return ops.CURRENT_ROW
        if value < 0:
            return f"{abs(value)} {ops.PRECEDING}"
        return f"{value} {ops.FOLLOWING}"

    def __str__(self) -> str:
        """
        Render the frame clause for debugging, without a compiler.

        Keywords come from the default connection's operations. The stored
        offsets are shown as-is: a positive ``start`` appears as N FOLLOWING
        and a negative ``end`` as N PRECEDING, rather than being reported as
        UNBOUNDED.
        """
        ops = db_connection.ops
        return self.template % {
            "frame_type": self.frame_type,
            "start": self._format_frame_bound(
                self.start.value, ops.UNBOUNDED_PRECEDING, ops
            ),
            "end": self._format_frame_bound(
                self.end.value, ops.UNBOUNDED_FOLLOWING, ops
            ),
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the rendered ``(start, end)`` bound SQL for this frame type."""
        ...
2043
2044
class RowRange(WindowFrame):
    """Window frame using the SQL ``ROWS`` frame unit."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend operations object renders the ROWS-specific bound SQL.
        backend_ops = connection.ops
        bounds = backend_ops.window_frame_rows_start_end(start, end)
        return bounds
2052
2053
class ValueRange(WindowFrame):
    """Window frame using the SQL ``RANGE`` frame unit."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend operations object renders the RANGE-specific bound SQL.
        backend_ops = connection.ops
        bounds = backend_ops.window_frame_range_start_end(start, end)
        return bounds