1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from abc import ABC, abstractmethod
8from collections import defaultdict
9from decimal import Decimal
10from functools import cached_property
11from types import NoneType
12from typing import TYPE_CHECKING, Any, Protocol, Self, cast, runtime_checkable
13from uuid import UUID
14
15from plain.models import fields
16from plain.models.constants import LOOKUP_SEP
17from plain.models.db import (
18 DatabaseError,
19 NotSupportedError,
20 db_connection,
21)
22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
23from plain.models.query_utils import Q
24from plain.utils.deconstruct import deconstructible
25from plain.utils.hashable import make_hashable
26
27if TYPE_CHECKING:
28 from collections.abc import Callable, Iterable, Sequence
29
30 from plain.models.backends.base.base import BaseDatabaseWrapper
31 from plain.models.fields import Field
32 from plain.models.lookups import Lookup, Transform
33 from plain.models.query import QuerySet
34 from plain.models.sql.compiler import SQLCompilable, SQLCompiler
35 from plain.models.sql.query import Query
36
37
@runtime_checkable
class ResolvableExpression(Protocol):
    """
    Structural type for anything that can be resolved against a query.

    Because the protocol is runtime-checkable, ``isinstance(obj,
    ResolvableExpression)`` is true for any object exposing a
    ``resolve_expression`` method, whatever its base classes.
    """

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Return the resolved form of this object for use in ``query``."""
50
51
52@runtime_checkable
53class ReplaceableExpression(Protocol):
54 """Protocol for expressions that support expression replacement."""
55
56 def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
57
58
class Combinable:
    """
    Provide the ability to combine one or two objects with
    some connector. For example F('foo') + F('bar').

    Arithmetic operators (and their reflected forms) build CombinedExpression
    nodes, wrapping non-expression operands in Value. The ``&``, ``|`` and
    ``^`` operators only combine two conditional expressions into Q objects.
    Bitwise SQL operators are exposed as the named ``bit*()`` methods.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The following is a quoted % operator - it is quoted because it can be
    # used in strings that also have parameter substitution.
    MOD = "%%"

    # Bitwise operators - note that these are generated by .bitand()
    # and .bitor(), the '&' and '|' are reserved for boolean operator
    # usage.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        # Anything without resolve_expression() (ints, strings, dates, ...)
        # is wrapped so that both operands are expressions.
        if not isinstance(other, ResolvableExpression):
            other = Value(other)
        lhs, rhs = (other, self) if reversed else (self, other)
        return CombinedExpression(lhs, connector, rhs)

    @staticmethod
    def _bitwise_operator_error() -> NotImplementedError:
        """Build the error for &, |, ^ used on non-conditional operands."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    #############
    # OPERATORS #
    #############

    # Arithmetic.

    def __neg__(self) -> CombinedExpression:
        # -expr is rendered as ``expr * -1``.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: ``other <op> self`` with a non-expression on the left.

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    # Bitwise SQL operators (named methods, since &, |, ^ are logical here).

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # Logical operators: only valid between two conditional expressions.

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) & Q(other)

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) | Q(other)

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) ^ Q(other)

    # Reflected logical operators are always rejected.

    def __rand__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __ror__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __rxor__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
191
192
class BaseExpression:
    """
    Base class for all query expressions.

    Child expressions are exposed through get_source_expressions() and
    set_source_expressions(). Most of the generic behavior below (resolution,
    relabeling, replacement, flattening, aggregate/window detection) walks
    those children.
    """

    # Substitute value used by callers such as Func.as_sql() when compiling
    # this expression raises EmptyResultSet; NotImplemented means "no
    # substitute, re-raise".
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by output_field when _resolve_output_field() returns None, so that
    # _output_field_or_none can tell "no type" apart from a genuine error.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field: Field | None = None):
        # An explicit output_field is stored on the instance, where it shadows
        # the lazily inferred output_field cached_property below.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self) -> dict[str, Any]:
        """Return picklable state, dropping the cached convert_value."""
        # convert_value may cache a lambda, and lambdas cannot be pickled; the
        # cached_property recomputes it on first access after unpickling.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the converters to apply to values read from the database.

        This expression's own converter comes first, and is omitted when it
        is the no-op. The output field's converters follow.
        """
        converters: list[Callable[..., Any]] = []
        # Identity check: convert_value returns the _convert_value_noop
        # function itself when no conversion is needed.
        if self.convert_value is not self._convert_value_noop:
            converters.append(self.convert_value)
        converters.extend(self.output_field.get_db_converters(connection))
        return converters

    def get_source_expressions(self) -> list[Any]:
        """Return the child expressions; a leaf expression has none."""
        return []

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the child expressions; a leaf only accepts an empty sequence."""
        assert not exprs

    def _parse_expressions(self, *expressions: Any) -> list[Any]:
        """
        Normalize raw arguments into expressions.

        Resolvable objects pass through unchanged, strings become F()
        references and any other value is wrapped in Value().
        """
        return [
            arg
            if isinstance(arg, ResolvableExpression)
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self) -> bool:
        """True if any (non-None) child expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self) -> bool:
        """True if any (non-None) child expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self) -> bool:
        """True if any (non-None) child expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save or update

        Return: an Expression to be added to the query.

        The default implementation copies self, records ``summarize`` in
        ``is_summary`` and resolves every non-None child. Note that
        ``for_save`` is not passed on to the children.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self) -> bool:
        """
        True when the output field is a BooleanField.

        Combinable's ``&``, ``|`` and ``^`` require both operands to be
        conditional.
        """
        # getattr's default only covers AttributeError; a FieldError raised
        # while resolving output_field still propagates.
        output_field = getattr(self, "output_field", None)
        return isinstance(output_field, fields.BooleanField)

    @property
    def field(self) -> Field:
        """Alias for output_field."""
        return self.output_field

    @cached_property
    def output_field(self) -> Field:
        """
        Return the output type of this expression.

        Raise FieldError when no type can be inferred. An output_field passed
        to __init__ is stored on the instance and takes precedence.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            # Record why resolution failed so _output_field_or_none can
            # swallow this particular FieldError.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self) -> Field | None:
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.

        Any other FieldError (e.g. mixed source types) is re-raised.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise
        return None

    def _resolve_output_field(self) -> Field | None:
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops share one generator: the outer loop takes the first
        # non-None source, the inner loop consumes all remaining sources and
        # checks each is an instance-compatible type, then the outer loop
        # returns after its single iteration.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field
        return None

    @staticmethod
    def _convert_value_noop(
        value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Identity converter; get_db_converters() skips it by identity."""
        return value

    @cached_property
    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.

        Return a ``(value, expression, connection) -> value`` callable. It
        keeps None as None and otherwise coerces to float, int or Decimal
        according to the output field's internal type. For other types it
        returns the _convert_value_noop function.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        """Delegate lookup discovery to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name: str) -> type[Transform] | None:
        """Delegate transform discovery to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        """Return a copy whose non-None children were relabeled with change_map."""
        # NOTE(review): change_map presumably maps old table aliases to new
        # ones; confirm against the query code that calls this.
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
        """
        Return ``replacements[self]`` if present. Otherwise return a copy in
        which each non-None child has been replaced recursively.
        """
        # The walrus test is truthiness-based: a falsy replacement value is
        # treated as if no replacement were registered.
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self) -> set[str]:
        """Return the union of get_refs() over all children."""
        # Unlike most walkers here, None children are not skipped.
        refs: set[str] = set()
        for expr in self.get_source_expressions():
            refs |= expr.get_refs()
        return refs

    def copy(self) -> Self:
        """Return a shallow copy (attributes such as lists are shared)."""
        return copy.copy(self)

    def prefix_references(self, prefix: str) -> Self:
        """
        Return a copy where each direct F child becomes F(prefix + name).
        Every other child has prefix_references() applied recursively.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self) -> list[BaseExpression]:
        """
        Return the expressions to GROUP BY for this expression.

        A non-aggregate expression groups by itself. An expression containing
        an aggregate contributes its children's group-by columns instead.
        """
        if not self.contains_aggregate:
            return [self]
        cols: list[BaseExpression] = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self) -> list[Field | None]:
        """
        Return the underlying field types used by this aggregate.

        Each entry is a child's output field, or None when the child has no
        resolvable type.
        """
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs: Any) -> OrderBy:
        """Return an ascending OrderBy for this expression (kwargs forwarded)."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        """Return a descending OrderBy for this expression (kwargs forwarded)."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self) -> Self:
        """Return self; only ordering-aware subclasses have anything to reverse."""
        return self

    def flatten(self) -> Iterable[Any]:
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.

        None children are skipped. Children without flatten() are yielded
        as-is.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        # Delegate to the output field's select_format() when it defines one;
        # otherwise return the SQL unchanged.
        if output_field := getattr(self, "output_field", None):
            if select_format := getattr(output_field, "select_format", None):
                return select_format(compiler, sql, params)
        return sql, params
508
509
@deconstructible
class Expression(BaseExpression, Combinable):
    """
    An expression that can be combined with other expressions.

    Equality and hashing are structural. Two expressions compare equal when
    they have the same class and were constructed with equivalent arguments
    (see ``identity``).
    """

    # Set by @deconstructible decorator in __new__
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Return a hashable key: the class followed by ``(name, value)`` pairs
        for every __init__ parameter, with defaults filled in.
        """

        def normalize(value: Any) -> Any:
            # Field instances are reduced to a stable key: (model label,
            # field name) when bound to a model, otherwise the field class.
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model.model_options.label, value.name)
            return type(value)

        init_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args
        bound = init_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()
        return (
            self.__class__,
            *((name, normalize(value)) for name, value in bound.arguments.items()),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
543
544
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest-common-denominator
# behavior, i.e. if any of the supported databases raises an error (rather
# than returning NULL) for `val <op> NULL`, then Plain raises FieldError.
551
# Each dict maps a connector to a list of (lhs_type, rhs_type, result_type)
# rules. The rules are fed to register_combinable_fields() at import time.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # The rules for every numeric field type are collected into a single list
    # per connector. Iterating the field types as a second `for` clause of the
    # dict comprehension would repeat each connector key, so only the last
    # field type's rules would survive.
    {
        connector: [
            rule
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for rule in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# connector -> ordered list of (lhs_type, rhs_type, result_type) rules,
# appended to by register_combinable_fields().
_connector_combinators: defaultdict[str, list[tuple[Any, Any, Any]]] = (
    defaultdict(list)
)
650
651
def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Rules are matched in registration order, so for a given connector an
    earlier rule wins over a later one that also matches.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its answers with lru_cache. Without a
    # reset, a lookup made before this registration would keep returning its
    # stale (often None) result.
    try:
        resolver = _resolve_combined_type
    except NameError:
        # Called while this module is still initializing, before the resolver
        # (and therefore its cache) exists: nothing to invalidate.
        return
    resolver.cache_clear()
667
668
def _register_builtin_combinations() -> None:
    """
    Register every rule declared in _connector_combinations.

    Wrapped in a function so the loop variables stay local instead of
    leaking into the module namespace.
    """
    for group in _connector_combinations:
        for connector, field_types in group.items():
            for lhs, rhs, result in field_types:
                register_combinable_fields(lhs, connector, rhs, result)


_register_builtin_combinations()
673
674
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the field class produced by ``lhs_type <connector> rhs_type``.

    The first registered rule whose operand classes are base classes of (or
    equal to) the given types decides the result. ``None`` means no rule
    matched. Results are memoized per (connector, lhs_type, rhs_type).
    """
    matches = (
        combined_type
        for rule_lhs, rule_rhs, combined_type in _connector_combinators.get(
            connector, ()
        )
        if issubclass(lhs_type, rule_lhs) and issubclass(rhs_type, rule_rhs)
    )
    return next(matches, None)
686
687
class SQLiteNumericMixin(Expression):
    """
    Some expressions with output_field=DecimalField() must be cast to
    numeric to be properly filtered.

    The SQLite-specific as_sqlite() wraps the generic SQL in
    ``CAST(... AS NUMERIC)`` when the output field is a DecimalField.
    """

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            internal_type = self.output_field.get_internal_type()
        except FieldError:
            # No resolvable output type: emit the SQL unchanged.
            return sql, params
        if internal_type == "DecimalField":
            return f"CAST({sql} AS NUMERIC)", params
        return sql, params
707
708
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``, e.g. ``F("a") + F("b")``.

    During resolution it may be replaced by a DurationExpression (exactly one
    operand has DurationField type) or by a TemporalSubtraction (subtracting
    two operands of the same date/time type).
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector, self.lhs, self.rhs = connector, lhs, rhs

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        lhs, rhs = exprs
        self.lhs = lhs
        self.rhs = rhs

    def _resolve_output_field(self) -> Field | None:
        # Deliberately not delegating to super(): see the caveat in
        # BaseExpression._resolve_output_field(). The type comes from the
        # registered connector rules instead.
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        operand_sql: list[str] = []
        operand_params: list[Any] = []
        for operand in (self.lhs, self.rhs):
            sql, params = compiler.compile(operand)
            operand_sql.append(sql)
            operand_params.extend(params)
        combined = connection.ops.combine_expression(self.connector, operand_sql)
        # Parenthesize to preserve precedence inside larger expressions.
        return f"({combined})", operand_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
        """
        Resolve both operands. Where the operand types call for it, return a
        resolved DurationExpression or TemporalSubtraction instead of a copy
        of self.
        """
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)

        if not isinstance(self, (DurationExpression, TemporalSubtraction)):

            def internal_type_or_none(expression: Any) -> str | None:
                try:
                    return expression.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type_or_none(lhs)
            rhs_type = internal_type_or_none(rhs)
            specialized = None
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                specialized = DurationExpression(self.lhs, self.connector, self.rhs)
            elif (
                self.connector == self.SUB
                and lhs_type in {"DateField", "DateTimeField", "TimeField"}
                and lhs_type == rhs_type
            ):
                specialized = TemporalSubtraction(self.lhs, self.rhs)
            if specialized is not None:
                # Rebuilt from the unresolved operands and resolved afresh.
                return specialized.resolve_expression(*resolve_args)

        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = lhs
        resolved.rhs = rhs
        return resolved
814
815
class DurationExpression(CombinedExpression):
    """
    CombinedExpression variant for arithmetic involving a DurationField.

    On backends without a native duration type, DurationField operands are
    formatted for duration arithmetic. The operands are then joined by
    ``connection.ops.combine_duration_expression()``.
    """

    def compile(
        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compile one operand, formatting DurationField operands specially."""
        try:
            output = side.output_field
        except FieldError:
            # Untyped operand: compile it as-is.
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        operand_sql: list[str] = []
        operand_params: list[Any] = []
        for operand in (self.lhs, self.rhs):
            sql, params = self.compile(operand, compiler, connection)
            operand_sql.append(sql)
            operand_params.extend(params)
        combined = connection.ops.combine_duration_expression(
            self.connector, operand_sql
        )
        # Parenthesize to preserve precedence inside larger expressions.
        return f"({combined})", operand_params

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Operand types unknown: skip the validation below.
            return sql, params
        numeric_or_duration = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        # Multiplying/dividing a duration is only accepted with numeric or
        # duration operands.
        if not (lhs_type in numeric_or_duration and rhs_type in numeric_or_duration):
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
874
875
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` between temporal values, typed as a DurationField.

    CombinedExpression.resolve_expression() produces this when subtracting
    two DateField, DateTimeField or TimeField operands of the same type. SQL
    generation is delegated to ``connection.ops.subtract_temporals()``.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        # subtract_temporals() receives the lhs internal type (presumably to
        # pick type-specific SQL; confirm in the backend operations).
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            lhs_internal_type, compiled_lhs, compiled_rhs
        )
891
892
@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    An object capable of resolving references to existing query objects.

    ``F("name")`` is turned into a concrete expression by the query's
    ``resolve_ref()`` at resolve time. It supports the Combinable arithmetic
    operators.
    """

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        # for_save is accepted but not passed on to resolve_ref().
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        # F has no children: it is either replaced wholesale or kept as-is.
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, F):
            return NotImplemented
        # Subclasses (e.g. OuterRef) never compare equal to a plain F.
        same_class = self.__class__ == other.__class__
        return same_class and self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        return copy.copy(self)
936
937
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        # Compiling the reference directly (outside a subquery) is an error.
        message = (
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )
        raise ValueError(message)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                "Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Relabeling is a no-op for outer references.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Outer references contribute no GROUP BY columns.
        return []
970
971 def get_group_by_cols(self) -> list[Any]:
972 return []
973
974
class OuterRef(F):
    """
    Reference to a field of the enclosing query, for use inside a subquery.

    Resolves to a ResolvedOuterRef. For a nested ``OuterRef(OuterRef(...))``,
    each resolution peels off one level.
    """

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        # name may itself be an OuterRef when references are nested.
        target = self.name
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Relabeling is a no-op for outer references.
        return self
985
986
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    Subclasses set `function` (the SQL function name) and may override
    `template`, `arg_joiner` and `arity` (the exact number of positional
    expressions required, or None to accept any number). Extra keyword
    arguments given to __init__() are stored in `extra` and become template
    substitution values in as_sql().
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        expected = self.arity
        given = len(expressions)
        if expected is not None and given != expected:
            noun = "argument" if expected == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {expected} {noun} ({given} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        joined_args = self.arg_joiner.join(map(str, self.source_expressions))
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({joined_args})"
        # `!s` keeps plain str() rendering for keys and values.
        rendered = ", ".join(f"{key!s}={val!s}" for key, val in sorted(options.items()))
        return f"{class_name}({joined_args}, {rendered})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy whose argument expressions have all been resolved."""
        resolved = self.copy()
        resolved.is_summary = summarize
        # Slice assignment replaces the contents in place, keeping the (copied)
        # list object itself.
        resolved.source_expressions[:] = [
            source.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for source in resolved.source_expressions
        ]
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the call to SQL and return `(sql, params)`.

        Explicit `function`, `template` and `arg_joiner` arguments win.
        Otherwise values from `extra_context` are used, then the `extra` given
        to __init__(), then the class attributes. An argument that raises
        EmptyResultSet is replaced by its `empty_result_set_value` (the error
        propagates if it has none). An argument that raises FullResultSet is
        replaced by a literal True.
        """
        connection.ops.check_expression_support(self)
        compiled_args: list[str] = []
        all_params: list[Any] = []
        for source in self.source_expressions:
            try:
                source_sql, source_params = compiler.compile(source)
            except EmptyResultSet:
                fallback = getattr(source, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                source_sql, source_params = compiler.compile(Value(fallback))
            except FullResultSet:
                source_sql, source_params = compiler.compile(Value(True))
            compiled_args.append(source_sql)
            all_params.extend(source_params)
        context = {**self.extra, **extra_context}
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        sql_template = template or context.get("template", self.template)
        joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        context["expressions"] = context["field"] = joiner.join(compiled_args)
        return sql_template % context, all_params

    def copy(self) -> Self:
        """Copy with private `source_expressions` and `extra` containers."""
        duplicate = super().copy()
        duplicate.source_expressions = list(self.source_expressions)
        duplicate.extra = dict(self.extra)
        return cast(Self, duplicate)
1092
1093
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Compile to a single placeholder bound to the (prepared) value.

        When an output field is known, the value is first run through its
        get_db_prep_save() (when `for_save`) or get_db_prep_value(). A field
        with get_placeholder() supplies its own placeholder SQL. A None value
        without such a placeholder becomes a literal NULL with no params.
        """
        connection.ops.check_expression_support(self)
        field = self._output_field_or_none
        prepared = self.value
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)
                return placeholder, [prepared]
        if prepared is not None:
            return "%s", [prepared]
        # cx_Oracle does not always convert None to the appropriate NULL type
        # (like in case expressions using numbers), so emit a literal SQL NULL.
        return "NULL", []

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        """Resolve as usual, then remember whether this is for a save."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        return []

    def _resolve_output_field(self) -> Field | None:
        """Infer a field type from the Python type of the wrapped value."""
        # Checked in order: bool must precede int (bool subclasses int) and
        # datetime must precede date (datetime subclasses date).
        candidates = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in candidates:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        return self.value
1179
1180
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters.

    The fragment is wrapped in parentheses when compiled, and `params` is
    returned as-is. Without an explicit output field a generic `Field()` is
    used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped = "({})".format(self.sql)
        return wrapped, self.params

    def get_group_by_cols(self) -> list[RawSQL]:
        # The whole fragment is grouped on as a unit.
        return [self]
1200
1201
class Star(Expression):
    """The bare SQL `*` token; compiles with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        no_params: list[Any] = []
        return "*", no_params
1210
1211
class Col(Expression):
    """
    A column reference, optionally qualified by a table alias.

    `target` is the model field that owns the column. It doubles as the output
    field unless a different one is given.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render `alias.column` (or just `column`), quoting each part as needed."""
        column = self.target.column
        names = [self.alias, column] if self.alias else [column]
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels: dict[str, str]) -> Col:
        """Return a copy using the relabeled alias; unaliased columns are reused."""
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[Col]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the output field's converters.

        When the target differs from the output field, the target's converters
        are appended after them.
        """
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1255
1256
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.

    Compiles to the quoted alias name. `source` is the already-resolved
    expression the alias stands for.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # Nothing left to resolve: `source` was resolved before this
        # reference to its name was created.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, relabels: dict[str, str]) -> Ref:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[Ref]:
        return [self]
1301
1302
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.

    Compiles to the member expressions joined by `arg_joiner`, with no
    surrounding function call. At least one expression is required.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                "{} requires at least one expression.".format(self.__class__.__name__)
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # Bypass the inherited SQLite numeric cast; it is unnecessary here.
        return self.as_sql(compiler, connection, **extra_context)
1330
1331
class OrderByList(Func):
    """
    An `ORDER BY ...` clause built from expressions or field-name strings.

    A string starting with "-" becomes a descending OrderBy of the named field.
    An empty list compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        normalized: list[Any] = []
        for item in expressions:
            if isinstance(item, str) and item[0] == "-":
                item = OrderBy(F(item[1:]), descending=True)
            normalized.append(item)
        super().__init__(*normalized, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, tuple[Any, ...]]:
        if self.source_expressions:
            sql, params = super().as_sql(*args, **kwargs)
            return sql, tuple(params)
        return "", ()

    def get_group_by_cols(self) -> list[Any]:
        return [
            column
            for ordering in self.get_source_expressions()
            for column in ordering.get_group_by_cols()
        ]
1357
1358
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.

    Compiles exactly as the wrapped expression does.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Group on a copy of the inner expression carrying this wrapper's
        # output field, leaving the original inner expression untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.expression)
1392
1393
class NegatedExpression(ExpressionWrapper):
    """
    The logical negation of a conditional expression.

    Always boolean-valued. Inverting it again returns a copy of the wrapped
    expression rather than double-negating.
    """

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped condition raised EmptyResultSet; its negation is
            # compiled as a constant true instead.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """Resolve, then reject wrapped expressions that aren't conditional."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list. The second condition avoids double wrapping.
        backend = compiler.connection
        needs_case_wrap = (
            not backend.features.supports_boolean_expr_in_select_clause
            and backend.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case_wrap:
            return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
        return sql, params
1452
1453
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One `WHEN <condition> THEN <result>` branch of a Case().

    The condition may be a Q object, a conditional expression, or field
    lookups given as keyword arguments. Lookups are combined with a
    conditional `condition` into a single Q.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        if lookups:
            if condition is None:
                condition = Q(**lookups)
                lookups = {}
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **lookups)
                lookups = {}
        is_conditional = condition is not None and getattr(
            condition, "conditional", False
        )
        # Leftover lookups here mean a non-conditional condition was combined
        # with lookups, which is rejected too.
        if not is_conditional or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return "WHEN {!r} THEN {!r}".format(self.condition, self.result)

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """
        Return a resolved copy.

        The condition is resolved (if resolvable) with for_save=False. The
        result is resolved with the caller's for_save.
        """
        resolved = self.copy()
        resolved.is_summary = summarize
        if isinstance(resolved.condition, ResolvableExpression):
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            column
            for source in self.get_source_expressions()
            for column in source.get_group_by_cols()
        ]
1548
1549
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        joined = ", ".join(str(case) for case in self.cases)
        return f"CASE {joined}, ELSE {self.default!r}"

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self) -> list[Any]:
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *branches, fallback = exprs
        self.cases = branches
        self.default = fallback

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every WHEN branch and the default resolved."""
        resolved = self.copy()
        resolved.is_summary = summarize
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        resolved.cases[:] = [
            case.resolve_expression(*resolve_args) for case in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(*resolve_args)
        return resolved

    def copy(self) -> Self:
        """Copy with a private list of cases."""
        duplicate = super().copy()
        duplicate.cases = list(duplicate.cases)
        return cast(Self, duplicate)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to a CASE expression, returning `(sql, params)`.

        A branch raising EmptyResultSet is dropped. A branch raising
        FullResultSet ends the scan: its result replaces the default, and the
        remaining branches are skipped. With no surviving branches, only the
        default is compiled. When an output field is known, the SQL is wrapped
        by the backend's unification_cast_sql().
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            default_only_sql, default_only_params = compiler.compile(self.default)
            return default_only_sql, list(default_only_params)
        context = {**self.extra, **extra_context}
        default_sql, default_params = compiler.compile(self.default)
        when_sqls: list[str] = []
        params: list[Any] = []
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            when_sqls.append(case_sql)
            params.extend(case_params)
        if not when_sqls:
            return default_sql, list(default_params)
        context["cases"] = (case_joiner or self.case_joiner).join(when_sqls)
        context["default"] = default_sql
        params.extend(default_params)
        sql = (template or context.get("template", self.template)) % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self) -> list[Any]:
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1662
1663
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Accepts a QuerySet or an sql.Query and works on a private clone of the
    underlying sql.Query, marked as a subquery.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.models.sql.query import Query

        # Allow the usage of both QuerySet and sql.Query objects.
        sql_query = query if isinstance(query, Query) else query.sql_query
        self.query = sql_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        return self.query.output_field

    def copy(self) -> Self:
        """Copy with a cloned inner query."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return cast(Self, duplicate)

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (presumably
        # the parentheses it is wrapped in); the template supplies its own.
        context["subquery"] = subquery_sql[1:-1]
        sql = (template or context.get("template", self.template)) % context
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1734
1735
class Exists(Subquery):
    """
    An `EXISTS(...)` test over a subquery, with a boolean output field.

    The wrapped query is replaced by its exists() form. On EmptyResultSet this
    expression evaluates to False.
    """

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1754
1755
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    A single ORDER BY term: an expression, a direction, and NULL placement.

    `nulls_first` / `nulls_last` accept only True or None and are mutually
    exclusive. On backends without a NULLS FIRST/LAST modifier the placement is
    emulated by prepending an `IS NULL` / `IS NOT NULL` sort key on the same
    expression.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to `<expression> ASC|DESC`, plus any NULL-placement handling.

        The expression's params are repeated once per occurrence of the
        expression in the final template, so emulated NULL placement (which
        references the expression twice) binds them twice.
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate NULL placement with a leading boolean sort key. It is
            # skipped when the direction combined with `order_by_nulls_first`
            # (presumably the backend's default NULL placement) already yields
            # the requested order.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # Repeat params once per occurrence of the expression. Use a plain
        # (non-augmented) multiply: the compiled params may be a list owned by
        # the expression itself (RawSQL.as_sql() returns self.params), and
        # `params *= n` would mutate that list in place, duplicating the stored
        # params again on every compile.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and swap NULL placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:
        self.descending = False

    def desc(self) -> None:
        self.descending = True
1841
1842
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call: `<expression> OVER (<window>)`.

    The window combines an optional PARTITION BY list, ORDER BY list and frame
    clause. `expression` must be flagged `window_compatible`.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            # A single item or a tuple/list of items.
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to `<expression> OVER (...)`.

        Raises NotSupportedError when the backend lacks OVER support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            # self.copy() is shallow, so the inner expression is still shared
            # with `self`. Copy it before overriding its output field;
            # otherwise compiling would mutate self.source_expression.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # Only the window's partition and ordering columns; the windowed
        # expression itself is not grouped on.
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1968
1969
class WindowFrame(Expression, ABC):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).

    `start` and `end` are stored wrapped in Value(). Subclasses set
    `frame_type` and implement window_frame_start_end() to turn the raw bounds
    into SQL for a given connection.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def _render_frame(self, start: str, end: str) -> str:
        """Fill `template` with this frame's type and the given bound strings."""
        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        lower_sql, upper_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return self._render_frame(lower_sql, upper_sql), []

    def __repr__(self) -> str:
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_group_by_cols(self) -> list[Any]:
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        lower = self.start.value
        upper = self.end.value

        # Negative start: N PRECEDING; zero: CURRENT ROW; otherwise unbounded.
        start = ops.UNBOUNDED_PRECEDING
        if lower is not None:
            if lower < 0:
                start = f"{abs(lower)} {ops.PRECEDING}"
            elif lower == 0:
                start = ops.CURRENT_ROW

        # Positive end: N FOLLOWING; zero: CURRENT ROW; otherwise unbounded.
        end = ops.UNBOUNDED_FOLLOWING
        if upper is not None:
            if upper > 0:
                end = f"{upper} {ops.FOLLOWING}"
            elif upper == 0:
                end = ops.CURRENT_ROW

        return self._render_frame(start, end)

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]: ...
2039
2040
class RowRange(WindowFrame):
    """Window frame rendered as ``ROWS BETWEEN <start> AND <end>``."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Delegate bound rendering to the backend's ROWS-frame helper."""
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
2048
2049
class ValueRange(WindowFrame):
    """Window frame rendered as ``RANGE BETWEEN <start> AND <end>``."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Delegate bound rendering to the backend's RANGE-frame helper."""
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)