1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from abc import ABC, abstractmethod
8from collections import defaultdict
9from decimal import Decimal
10from functools import cached_property
11from types import NoneType
12from typing import TYPE_CHECKING, Any, Self
13from uuid import UUID
14
15from plain.models import fields
16from plain.models.constants import LOOKUP_SEP
17from plain.models.db import (
18 DatabaseError,
19 NotSupportedError,
20 db_connection,
21)
22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
23from plain.models.query_utils import Q
24from plain.utils.deconstruct import deconstructible
25from plain.utils.hashable import make_hashable
26
27if TYPE_CHECKING:
28 from collections.abc import Callable, Iterable, Sequence
29
30 from plain.models.backends.base.base import BaseDatabaseWrapper
31 from plain.models.fields import Field
32 from plain.models.lookups import Lookup, Transform
33 from plain.models.query import QuerySet
34 from plain.models.sql.compiler import SQLCompiler
35 from plain.models.sql.query import Query
36
37
class Combinable:
    """
    Mixin giving an object the ability to be joined to another operand with
    a connector, producing a combined expression.
    For example F('foo') + F('bar').
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # MOD is a quoted (doubled) % sign because the connector can end up in
    # strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors. They are only produced by the named methods
    # (.bitand(), .bitor(), ...); the Python operators '&', '|' and '^' are
    # reserved for boolean logic on conditional expressions.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Build `self <connector> other`, or `other <connector> self` when
        `reversed` is true. Operands lacking `resolve_expression` are wrapped
        in Value() so that both sides are expressions.
        """
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    ########################
    # ARITHMETIC OPERATORS #
    ########################

    def __neg__(self) -> CombinedExpression:
        # Negation is expressed as a multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    #####################
    # BOOLEAN OPERATORS #
    #####################
    # '&', '|' and '^' only work when both operands report themselves as
    # `conditional`; the operands are then combined as Q objects.

    def __and__(self, other: Any) -> Q:
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)  # type: ignore[unsupported-operator]

    def __or__(self, other: Any) -> Q:
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)  # type: ignore[unsupported-operator]

    def __xor__(self, other: Any) -> Q:
        both_conditional = getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )
        if not both_conditional:
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)  # type: ignore[unsupported-operator]

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)

    # The reflected boolean operators are never supported.

    def __rand__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    ######################
    # BITWISE OPERATIONS #
    ######################

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)
170
171
class BaseExpression:
    """Base class for all query expressions."""

    # Substitute value used by Func.as_sql() when compiling this expression
    # raises EmptyResultSet; NotImplemented means there is no substitute.
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by `output_field` when _resolve_output_field() returned None, so
    # that `_output_field_or_none` can tell that case from other FieldErrors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field: Field | None = None):
        """
        Arguments:
         * output_field: explicit output field. When given it is stored on
           the instance, so the inferred `output_field` is never computed.
        """
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self) -> dict[str, Any]:
        """Return the instance state without the cached `convert_value`."""
        state = self.__dict__.copy()
        # convert_value may be cached as a lambda; drop it and let it be
        # recomputed on demand.
        state.pop("convert_value", None)
        return state

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the converters for values produced by this expression:
        `convert_value` (unless it is the no-op) followed by the output
        field's own converters for `connection`.
        """
        converters = []
        if self.convert_value is not self._convert_value_noop:
            converters.append(self.convert_value)
        converters.extend(self.output_field.get_db_converters(connection))
        return converters

    def get_source_expressions(self) -> list[Any]:
        """Return the child expressions; the base class has none."""
        return []

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the child expressions; the base class accepts none."""
        assert not exprs

    def _parse_expressions(self, *expressions: Any) -> list[Any]:
        """
        Normalize constructor arguments into expressions: objects that have
        `resolve_expression` are kept, strings become F() references and
        anything else is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self) -> bool:
        """True if any (non-empty) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self) -> bool:
        """True if any (non-empty) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self) -> bool:
        """True if any (non-empty) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> BaseExpression:
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save or
           update. NOTE: this base implementation does not forward it to the
           source expressions.

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self) -> bool:
        """True when the output field is a BooleanField."""
        output_field = getattr(self, "output_field", None)
        return isinstance(output_field, fields.BooleanField)

    @property
    def field(self) -> Field:
        """Alias of `output_field`."""
        return self.output_field

    @cached_property
    def output_field(self) -> Field:
        """
        Return the output type of this expression.

        Raise FieldError when the type cannot be inferred.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            # Remember why we failed so _output_field_or_none can return None
            # for this case while re-raising any other FieldError.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self) -> Field | None:
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise
        return None

    def _resolve_output_field(self) -> Field | None:
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops draw from the same generator: the outer loop takes the
        # first field, the inner loop checks every remaining field against
        # it, and the first field is returned once they all match.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field
        return None

    @staticmethod
    def _convert_value_noop(
        value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Converter that returns `value` unchanged."""
        return value

    @cached_property
    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.

        Float, integer and decimal output fields get a converter that casts
        non-None values to float, int and Decimal respectively; every other
        field type gets the no-op converter.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup: str) -> type[Lookup] | None:
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name: str) -> type[Transform] | None:
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseExpression:
        """Return a copy whose source expressions were relabeled with `change_map`."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(
        self, replacements: dict[BaseExpression, Any]
    ) -> BaseExpression:
        """
        Return the replacement registered for this expression, if any;
        otherwise a copy whose source expressions had the replacements
        applied recursively.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self) -> set[str]:
        """Return the union of the refs reported by the source expressions."""
        refs = set()
        for expr in self.get_source_expressions():
            # Source expressions may be None placeholders (the other
            # traversal methods on this class allow for that); skip them.
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self) -> Self:
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix: str) -> Self:
        """
        Return a copy in which every F() reference among the source
        expressions, recursively, has `prefix` prepended to its name.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                # None placeholders are passed through untouched.
                None
                if expr is None
                else (
                    F(f"{prefix}{expr.name}")
                    if isinstance(expr, F)
                    else expr.prefix_references(prefix)
                )
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self) -> list[BaseExpression]:
        """
        Return the expressions to group by: this expression itself when it
        contains no aggregate, otherwise the columns of its sources.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            # A None placeholder contributes no columns.
            if source is None:
                continue
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self) -> list[Field | None]:
        """Return the underlying field types used by this aggregate."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs: Any) -> OrderBy:
        """Return an OrderBy wrapping this expression."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        """Return a descending OrderBy wrapping this expression."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self) -> BaseExpression:
        """Return self; the base expression carries no ordering to reverse."""
        return self

    def flatten(self) -> Iterable[Any]:
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.

        Defers to the output field's `select_format` when it has one,
        otherwise returns `sql` and `params` unchanged.
        """
        if output_field := getattr(self, "output_field", None):
            if select_format := getattr(output_field, "select_format", None):
                return select_format(compiler, sql, params)
        return sql, params
489
490
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Hashable fingerprint used for equality and hashing: the class,
        followed by an (argument name, value) pair for every constructor
        argument after binding and applying defaults.
        """
        constructor_signature = inspect.signature(self.__init__)
        args, kwargs = self._constructor_args  # type: ignore[attr-defined]
        bound = constructor_signature.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        parts: list[Any] = [self.__class__]
        for name, value in bound.arguments.items():
            if not isinstance(value, fields.Field):
                normalized = make_hashable(value)
            elif value.name and value.model:
                # A field attached to a model is identified by model label
                # and field name.
                normalized = (value.model.model_options.label, value.name)
            else:
                # Any other field is identified by its class only.
                normalized = type(value)
            parts.append((name, normalized))
        return tuple(parts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
521
522
523# Type inference for CombinedExpression.output_field.
524# Missing items will result in FieldError, by design.
525#
# The current approach for NULL is based on lowest common denominator behavior,
# i.e. if any of the supported databases raises an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
529
# Each entry maps a connector to a list of (lhs type, rhs type, result type)
# triples. All of them are fed to register_combinable_fields() below.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    {
        # The field types are looped over inside the list, not inside the
        # dict comprehension: a dict keyed by connector keeps only the last
        # value written for each key, which would leave just the final field
        # type registered.
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
626
# Registry: connector -> list of (lhs type, rhs type, result type) triples,
# in registration order.
_connector_combinators = defaultdict(list)


def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Record that combining `lhs` and `rhs` with `connector` yields `result`:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    combination = (lhs, rhs, result)
    registered = _connector_combinators[connector]
    registered.append(combination)
645
646
# Populate the registry at import time with every built-in combination
# declared in _connector_combinations.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
651
652
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result type of the first combination registered for
    `connector` whose lhs/rhs types are base classes of (or equal to)
    `lhs_type` and `rhs_type`, or None when nothing matches.
    """
    registered = _connector_combinators.get(connector, ())
    return next(
        (
            combined_type
            for known_lhs, known_rhs, combined_type in registered
            if issubclass(lhs_type, known_lhs) and issubclass(rhs_type, known_rhs)
        ),
        None,
    )
664
665
class SQLiteNumericMixin(Expression):
    """
    Some expressions with output_field=DecimalField() must be cast to
    numeric to be properly filtered.
    """

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile through as_sql() and wrap the SQL in CAST(... AS NUMERIC)
        when the output field is a DecimalField. An output field that
        cannot be resolved leaves the SQL untouched.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            is_decimal = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            is_decimal = False
        if is_decimal:
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params
685
686
class CombinedExpression(SQLiteNumericMixin, Expression):
    """Two expressions joined by a connector: `lhs <connector> rhs`."""

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        lhs, rhs = exprs
        self.lhs = lhs
        self.rhs = rhs

    def _resolve_output_field(self) -> Field | None:
        """
        Look the output type up in the registered connector combinations,
        based on the types of both operands' output fields.

        Raise FieldError when no combination is registered for them.
        """
        # The generic guess made by BaseExpression._resolve_output_field()
        # is deliberately not used here; see the comment in that method.
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Compile both operands and join them with the backend's connector SQL."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        rhs_sql, rhs_params = compiler.compile(self.rhs)
        combined = connection.ops.combine_expression(
            self.connector, [lhs_sql, rhs_sql]
        )
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % combined, [*lhs_params, *rhs_params]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
        """
        Resolve both operands and return a resolved copy.

        A plain CombinedExpression hands over to DurationExpression when
        exactly one kind of operand is a DurationField, and to
        TemporalSubtraction when subtracting two operands of the same
        date/datetime/time type.
        """

        def internal_type(expr: Any) -> str | None:
            # An operand whose output field is missing or unresolvable is
            # treated as having no internal type.
            try:
                return expr.output_field.get_internal_type()
            except (AttributeError, FieldError):
                return None

        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, (DurationExpression, TemporalSubtraction)):
            lhs_type = internal_type(lhs)
            rhs_type = internal_type(rhs)
            if lhs_type != rhs_type and "DurationField" in (lhs_type, rhs_type):
                duration = DurationExpression(self.lhs, self.connector, self.rhs)
                return duration.resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            if (
                self.connector == self.SUB
                and lhs_type == rhs_type
                and lhs_type in {"DateField", "DateTimeField", "TimeField"}
            ):
                subtraction = TemporalSubtraction(self.lhs, self.rhs)
                return subtraction.resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs = lhs
        resolved.rhs = rhs
        return resolved
792
793
class DurationExpression(CombinedExpression):
    """
    CombinedExpression in which one operand is a DurationField and the other
    is not (see CombinedExpression.resolve_expression()).
    """

    def compile(
        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile one operand; a DurationField operand has its SQL passed
        through the backend's format_for_duration_arithmetic().
        """
        try:
            field = side.output_field
        except FieldError:
            return compiler.compile(side)
        if field.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Use the regular combined SQL on backends with a native duration
        field; otherwise build the expression with the backend's duration
        arithmetic helpers.
        """
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        connection.ops.check_expression_support(self)
        lhs_sql, lhs_params = self.compile(self.lhs, compiler, connection)
        rhs_sql, rhs_params = self.compile(self.rhs, compiler, connection)
        combined = connection.ops.combine_duration_expression(
            self.connector, [lhs_sql, rhs_sql]
        )
        # Parenthesize to preserve the order of precedence.
        return "(%s)" % combined, [*lhs_params, *rhs_params]

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile via as_sql(), then reject multiplication/division whose
        operand types are not decimal, duration, float or integer fields.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Operand types are unknown: nothing to validate.
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not (lhs_type in allowed_fields and rhs_type in allowed_fields):
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
852
853
class TemporalSubtraction(CombinedExpression):
    """Subtraction of two temporal expressions; the result is a duration."""

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        # The connector is always SUB.
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Delegate the SQL to the backend's subtract_temporals()."""
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        operand_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            operand_type, compiled_lhs, compiled_rhs
        )
869
870
@deconstructible(path="plain.models.F")
class F(Combinable):
    """An object capable of resolving references to existing query objects."""

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.name})"

    def __eq__(self, other: object) -> bool:
        # Equal only to an instance of the very same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name  # type: ignore[attr-defined]

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        return copy.copy(self)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Ask `query` to resolve the referenced name."""
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)  # type: ignore[union-attr]

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        """Return the replacement registered for this reference, else self."""
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)
914
915
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        """Always raise: an outer reference cannot be compiled on its own."""
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        Resolve the reference like F does, rejecting window expressions and
        flagging names that contain the lookup separator as possibly
        multivalued.
        """
        resolved = super().resolve_expression(*args, **kwargs)  # type: ignore[misc]
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Relabeling is a no-op for outer references.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Contributes no GROUP BY columns.
        return []
951
952
class OuterRef(F):
    """Reference to a field of an outer query, for use inside a subquery."""

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        """
        Unwrap one level when the name is itself an instance of this class;
        otherwise return a ResolvedOuterRef for the name.
        """
        inner = self.name
        if isinstance(inner, self.__class__):
            return inner
        return ResolvedOuterRef(inner)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Relabeling is a no-op for outer references.
        return self
963
964
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Arguments:
         * expressions: the function arguments; values that are not already
           expressions are converted by _parse_expressions().
         * output_field: optional explicit output field.
         * extra: additional template context, kept on the instance.

        Raise TypeError when `arity` is set and the argument count differs.
        """
        expected = self.arity
        given = len(expressions)
        if expected is not None and given != expected:
            noun = "argument" if expected == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {expected} {noun} "
                f"({given} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        rendered_args = self.arg_joiner.join(map(str, self.source_expressions))
        options = {**self.extra, **self._get_repr_options()}
        cls_name = self.__class__.__name__
        if not options:
            return f"{cls_name}({rendered_args})"
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{cls_name}({rendered_args}, {rendered_options})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy in which every argument has been resolved."""
        clone = self.copy()
        clone.is_summary = summarize
        resolved_args = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in clone.source_expressions
        ]
        clone.source_expressions[:] = resolved_args
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile each argument and interpolate the result into the template.

        `function`, `template` and `arg_joiner` override the values given
        through __init__()'s **extra, which in turn override the class
        attributes.
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Fall back to the argument's declared substitute value,
                # if it has one; otherwise let the exception propagate.
                substitute = getattr(arg, "empty_result_set_value", NotImplemented)
                if substitute is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(substitute))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        context = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `context`), or the value defined on the class.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        if not template:
            template = context.get("template", self.template)
        if not arg_joiner:
            arg_joiner = context.get("arg_joiner", self.arg_joiner)
        joined_args = arg_joiner.join(sql_parts)
        context["expressions"] = joined_args
        context["field"] = joined_args
        return template % context, params

    def copy(self) -> Self:
        """Copy that does not share the arguments list or extra dict."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone  # type: ignore[return-value]
1070
1071
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """Represent a wrapped value as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the value this expression represents. The value will be
           added into the sql parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return a placeholder plus the single prepared parameter.

        When an output field is known, the value is first prepared by it (for
        saving or for lookup, depending on ``for_save``) and the field may
        supply its own placeholder. A None value compiles to a literal NULL
        with no parameters.
        """
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)
                return placeholder, [prepared]
        if prepared is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so a
            # literal SQL NULL is emitted instead of a parameter.
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        """Resolve via the parent class and record ``for_save`` on the result."""
        resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)  # type: ignore[misc]
        resolved.for_save = for_save  # type: ignore[attr-defined]
        return resolved  # type: ignore[return-value]

    def get_group_by_cols(self) -> list[Any]:
        """Return an empty list: this expression contributes no GROUP BY columns."""
        return []

    def _resolve_output_field(self) -> Field | None:
        """
        Pick a field type from the Python type of the wrapped value.

        The checks are ordered: bool is tested before int and datetime before
        date, so the more specific type wins. None is returned when no check
        matches.
        """
        value = self.value
        if isinstance(value, str):
            return fields.CharField()
        if isinstance(value, bool):
            return fields.BooleanField()
        if isinstance(value, int):
            return fields.IntegerField()
        if isinstance(value, float):
            return fields.FloatField()
        if isinstance(value, datetime.datetime):
            return fields.DateTimeField()
        if isinstance(value, datetime.date):
            return fields.DateField()
        if isinstance(value, datetime.time):
            return fields.TimeField()
        if isinstance(value, datetime.timedelta):
            return fields.DurationField()
        if isinstance(value, Decimal):
            return fields.DecimalField()
        if isinstance(value, bytes):
            return fields.BinaryField()
        if isinstance(value, UUID):
            return fields.UUIDField()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        """The wrapped value itself."""
        return self.value
1157
1158
class RawSQL(Expression):
    """An expression built from a raw SQL string and its parameters."""

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        """Store ``sql`` and ``params``; default the output field to a plain Field."""
        resolved_field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=resolved_field)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Return the stored SQL wrapped in parentheses with the stored params."""
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[RawSQL]:
        """Return a one-element list containing this expression."""
        return [self]
1178
1179
class Star(Expression):
    """Expression that compiles to a bare ``*`` with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Return the literal ``*`` and an empty parameter list."""
        return "*", []
1188
1189
class Col(Expression):
    """A reference to ``target``'s column, optionally qualified by an alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        """Store alias and target; ``target`` doubles as the default output field."""
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Return the quoted column name, prefixed by the quoted alias if set."""
        names = [self.target.column]
        if self.alias:
            names.insert(0, self.alias)
        quoted = (compiler.quote_name_unless_alias(name) for name in names)
        return ".".join(quoted), []

    def relabeled_clone(self, relabels: dict[str, str]) -> Col:
        """
        Return a new instance whose alias is looked up in ``relabels``.

        The alias is kept when it has no entry in the mapping. An instance
        without an alias is returned unchanged.
        """
        current = self.alias
        if current is None:
            return self
        new_alias = relabels.get(current, current)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[Col]:
        """Return a one-element list containing this column."""
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the output field's converters, followed by the target's
        converters when the target differs from the output field.
        """
        converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1233
1234
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs: str, source: Any):
        """Store the alias name (``refs``) and the expression it refers to."""
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        """Return a one-element list holding the referenced expression."""
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the referenced expression; ``exprs`` must hold exactly one item."""
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        """Return this instance unchanged."""
        # The sub-expression `source` has already been resolved, as this is
        # just a reference to the name of `source`.
        return self

    def get_refs(self) -> set[str]:
        """Return a set containing only this reference's alias name."""
        return {self.refs}

    def relabeled_clone(self, relabels: dict[str, str]) -> Ref:
        """Return this instance unchanged; ``relabels`` is not consulted."""
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Return the quoted alias name with no parameters."""
        quoted = connection.ops.quote_name(self.refs)
        return quoted, []

    def get_group_by_cols(self) -> list[Ref]:
        """Return a one-element list containing this reference."""
        return [self]
1279
1280
class ExpressionList(Func):
    """
    An expression containing multiple expressions. Can be used to provide a
    list of expressions as an argument to another expression, like a partition
    clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        """Raise ValueError when called without any positional expression."""
        if not expressions:
            cls_name = self.__class__.__name__
            raise ValueError(f"{cls_name} requires at least one expression.")
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile through as_sql() directly."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
1308
1309
class OrderByList(Func):
    """
    A list of ordering expressions rendered through an ``ORDER BY`` template.

    String arguments that start with "-" are converted to a descending
    OrderBy over the rest of the string; every other argument is handed to
    Func unchanged.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # str.startswith() is used rather than indexing the first character so
        # that an empty string is passed on to Func like any other string
        # instead of raising an unrelated IndexError here.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, tuple[Any, ...]]:
        """Return an empty SQL string and no params when there is nothing to order by."""
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)  # type: ignore[misc]

    def get_group_by_cols(self) -> list[Any]:
        """Return the GROUP BY columns of every source expression, in order."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1334
1335
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    An expression that can wrap another expression so that it can provide
    extra context to the inner expression, such as the output_field.
    """

    def __init__(self, expression: Any, output_field: Field):
        """Wrap ``expression`` and attach ``output_field`` to the wrapper."""
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the wrapped expression with the first item of ``exprs``."""
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        """Return a one-element list holding the wrapped expression."""
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        """
        Return the GROUP BY columns for the wrapped expression.

        For an Expression, a copy carrying this wrapper's output field is
        asked for its columns, leaving the wrapped expression untouched.
        """
        wrapped = self.expression
        if not isinstance(wrapped, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()  # type: ignore[misc]
        clone = wrapped.copy()
        clone.output_field = self.output_field
        return clone.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compile the wrapped expression and return its SQL and params."""
        return compiler.compile(self.expression)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"
1369
1370
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression: Any):
        """Wrap ``expression`` with a BooleanField output field."""
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        """Negating again yields a copy of the wrapped expression."""
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile the wrapped expression and negate the resulting SQL.

        When the wrapped expression raises EmptyResultSet, an always-true
        condition is returned instead.
        """
        try:
            inner_sql, inner_params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        if supported_in_where(self.expression):
            return f"NOT {inner_sql}", inner_params
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        return f"CASE WHEN {inner_sql} = 0 THEN 1 ELSE 0 END", inner_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """
        Resolve via the parent class.

        Raise TypeError when the resolved inner expression has no truthy
        ``conditional`` attribute.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)  # type: ignore[attr-defined]
        if not is_conditional:
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved  # type: ignore[return-value]

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """Return the SQL to use in a SELECT list, CASE-wrapped when required."""
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Avoid double wrapping.
        if not supported_in_where(self.expression):
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1429
1430
@deconstructible(path="plain.models.When")
class When(Expression):
    """A single ``WHEN <condition> THEN <result>`` branch."""

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Build the branch from a condition and/or keyword lookups.

        Keyword lookups are folded into a Q object, combined with ``condition``
        when that is itself conditional. Raise TypeError when no usable
        condition results, and ValueError for an empty Q().
        """
        pending_lookups: dict[str, Any] | None = lookups or None
        if pending_lookups:
            if condition is None:
                condition = Q(**pending_lookups)
                pending_lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **pending_lookups)
                pending_lookups = None
        usable = condition is not None and getattr(condition, "conditional", False)
        # Lookups left over at this point could not be merged into a condition.
        if not usable or pending_lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        """Return ``[condition, result]``."""
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Unpack exactly two items from ``exprs`` into condition and result."""
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        """Return a one-element list with the result's output field, or None."""
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """
        Return a copy with condition and result resolved.

        The condition is always resolved with ``for_save`` set to False; only
        the result receives the caller's ``for_save``.
        """
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile condition and result and interpolate them into the template.

        The returned params are the condition's followed by the result's.
        """
        connection.ops.check_expression_support(self)
        context = extra_context
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)  # type: ignore[arg-type]
        context["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        context["result"] = result_sql
        chosen_template = template or self.template
        return chosen_template % context, (*condition_params, *result_params)

    def get_group_by_cols(self) -> list[Any]:
        """Return the GROUP BY columns of the condition followed by the result's."""
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1524
1525
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Store the When branches, the parsed default and extra template context.

        Raise TypeError when any positional argument is not a When instance.
        """
        for case in cases:
            if not isinstance(case, When):
                raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        rendered_cases = ", ".join(str(c) for c in self.cases)
        return "CASE {}, ELSE {!r}".format(rendered_cases, self.default)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        """Return every When branch followed by the default expression."""
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Treat the last item of ``exprs`` as the default, the rest as branches."""
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy whose branches and default have all been resolved."""
        clone = self.copy()
        clone.is_summary = summarize
        for index, when in enumerate(clone.cases):
            clone.cases[index] = when.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self) -> Self:
        """Return a copy that owns its own list of branches."""
        clone = super().copy()  # type: ignore[misc]
        clone.cases = clone.cases[:]  # type: ignore[attr-defined]
        return clone  # type: ignore[return-value]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the branches and default into a CASE expression.

        Branches raising EmptyResultSet are skipped. A branch raising
        FullResultSet ends the scan and its result becomes the default. When
        no branch remains, only the default's SQL is returned.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            default_only_sql, default_only_params = compiler.compile(self.default)
            return default_only_sql, list(default_only_params)
        branch_sqls = []
        params = []
        default_sql, default_params = compiler.compile(self.default)
        for when in self.cases:
            try:
                when_sql, when_params = compiler.compile(when)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(when.result)
                break
            else:
                branch_sqls.append(when_sql)
                params.extend(when_params)
        if not branch_sqls:
            return default_sql, list(default_params)
        joiner = case_joiner or self.case_joiner
        context = {
            **self.extra,
            **extra_context,
            "cases": joiner.join(branch_sqls),
            "default": default_sql,
        }
        params.extend(default_params)
        chosen_template = template or context.get("template", self.template)
        sql = chosen_template % context
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self) -> list[Any]:
        """With no branches, defer to the default expression's GROUP BY columns."""
        if self.cases:
            return super().get_group_by_cols()  # type: ignore[misc]
        return self.default.get_group_by_cols()
1638
1639
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Store a clone of the given query, flagged as a subquery.

        Both sql.Query objects and objects exposing a ``sql_query`` attribute
        (QuerySets) are accepted.
        """
        # Import here to avoid circular import
        from plain.models.sql.query import Query

        source_query = query if isinstance(query, Query) else query.sql_query
        cloned = source_query.clone()
        self.query = cloned
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        """Return a one-element list holding the inner query."""
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the inner query with the first item of ``exprs``."""
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        """Use the inner query's output field."""
        return self.query.output_field

    def copy(self) -> Self:
        """Return a copy that owns its own clone of the inner query."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()  # type: ignore[attr-defined]
        return duplicate  # type: ignore[return-value]

    @property
    def external_aliases(self) -> dict[str, bool]:  # type: ignore[override]
        """The inner query's external aliases."""
        return self.query.external_aliases  # type: ignore[return-value]

    def get_external_cols(self) -> list[Any]:
        """Delegate to the inner query's get_external_cols()."""
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the inner query and interpolate it into the template."""
        connection.ops.check_expression_support(self)
        inner_sql, inner_params = self.query.as_sql(compiler, connection)
        # The first and last characters of the compiled SQL are dropped
        # (presumably wrapping parentheses, which the template re-adds).
        context = {**self.extra, **extra_context, "subquery": inner_sql[1:-1]}
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, inner_params

    def get_group_by_cols(self) -> list[Any]:
        """Ask the inner query for GROUP BY columns, passing this as wrapper."""
        return self.query.get_group_by_cols(wrapper=self)
1710
1711
class Exists(Subquery):
    """A subquery rendered inside ``EXISTS(...)`` with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        """Initialise as a Subquery, then swap in the query's ``exists()`` form."""
        super().__init__(query, **kwargs)
        exists_query = self.query.exists()
        self.query = exists_query

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """Return the SQL to use in a SELECT list, CASE-wrapped when required."""
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1730
1731
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """An ordering term: an expression plus direction and NULL placement."""

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Validate and store the ordering options.

        Raise ValueError when both NULL flags are truthy, when either flag is
        exactly False, or when ``expression`` has no ``resolve_expression``
        attribute.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        # Identity checks on purpose: only the literal False is rejected.
        for flag in (nulls_first, nulls_last):
            if flag is False:
                raise ValueError(
                    "nulls_first and nulls_last values must be True or None."
                )
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the ordered expression with the first item of ``exprs``."""
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        """Return a one-element list holding the ordered expression."""
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the ordering term.

        Backends with a NULLS modifier get ``NULLS FIRST``/``NULLS LAST``
        appended. Others get an ``IS NULL``/``IS NOT NULL`` sort key
        prepended, unless the requested placement is already the one that
        results from ``order_by_nulls_first`` and the direction. The
        expression's params are repeated once per occurrence of the
        expression placeholder in the final template.
        """
        features = connection.features
        template = template or self.template
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        elif self.nulls_last and (
            not self.descending or not features.order_by_nulls_first
        ):
            template = f"%(expression)s IS NULL, {template}"
        elif self.nulls_first and (
            self.descending or not features.order_by_nulls_first
        ):
            template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        direction = "DESC" if self.descending else "ASC"
        placeholders = {
            "expression": expression_sql,
            "ordering": direction,
            **extra_context,
        }
        params *= template.count("%(expression)s")
        rendered = template % placeholders
        return rendered.rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        """Return the GROUP BY columns of the ordered expression."""
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self) -> OrderBy:
        """Flip direction and NULL placement in place, then return this instance."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self) -> None:
        """Switch this instance to ascending order in place."""
        self.descending = False

    def desc(self) -> None:
        """Switch this instance to descending order in place."""
        self.descending = True
1817
1818
class Window(SQLiteNumericMixin, Expression):
    """An expression evaluated over a window: ``<expression> OVER (<window>)``."""

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Normalise the window components and store the source expression.

        ``partition_by`` is wrapped in an ExpressionList and ``order_by`` in an
        OrderByList. Raise ValueError when ``expression`` has no truthy
        ``window_compatible`` attribute or when ``order_by`` is not a string,
        an expression, or a list/tuple.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if partition_by is not None:
            if not isinstance(partition_by, tuple | list):
                partition_by = (partition_by,)
            self.partition_by = ExpressionList(*partition_by)

        if order_by is not None:
            if isinstance(order_by, list | tuple):
                self.order_by = OrderByList(*order_by)
            elif isinstance(order_by, BaseExpression | str):
                self.order_by = OrderByList(order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        """Use the source expression's output field."""
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        """Return ``[source_expression, partition_by, order_by, frame]``."""
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Unpack exactly four items, in get_source_expressions() order."""
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the source expression followed by its OVER clause.

        Raise NotSupportedError when the backend's features report no support
        for OVER clauses. Params are the source expression's followed by
        those of the partition, ordering and frame parts, in that order.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        clauses = []
        clause_params = []

        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            clauses.append(order_sql)
            clause_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        chosen_template = template or self.template
        window = " ".join(clauses).strip()
        sql = chosen_template % {"expression": expr_sql, "window": window}
        return sql, (*params, *clause_params)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile for SQLite.

        For a DecimalField output, the first source expression's output field
        is set to FloatField before delegating to the parent as_sqlite().
        """
        if not isinstance(self.output_field, fields.DecimalField):
            return self.as_sql(compiler, connection)
        # Casting to numeric must be outside of the window expression.
        clone = self.copy()
        sources = clone.get_source_expressions()
        sources[0].output_field = fields.FloatField()
        clone.set_source_expressions(sources)
        return super(Window, clone).as_sqlite(compiler, connection)

    def __str__(self) -> str:
        partition = "PARTITION BY " + str(self.partition_by) if self.partition_by else ""
        ordering = str(self.order_by or "")
        frame = str(self.frame or "")
        return "{} OVER ({}{}{})".format(
            str(self.source_expression), partition, ordering, frame
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """Return the partition's GROUP BY columns followed by the ordering's."""
        columns = []
        if self.partition_by:
            columns += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            columns += self.order_by.get_group_by_cols()
        return columns
1944
1945
class WindowFrame(Expression, ABC):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        """Wrap both boundaries in Value expressions."""
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Unpack exactly two items from ``exprs`` into start and end."""
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        """Return ``[start, end]``."""
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render the frame clause with no parameters.

        The boundary text comes from the subclass's window_frame_start_end().
        """
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {"frame_type": self.frame_type, "start": start, "end": end}
        return self.template % context, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """Return an empty list: a frame contributes no GROUP BY columns."""
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        start_value = self.start.value
        end_value = self.end.value

        # A negative start is N rows preceding, zero is the current row and
        # anything else (None or positive) falls back to unbounded preceding.
        if start_value is not None and start_value < 0:
            start = f"{abs(start_value)} {ops.PRECEDING}"
        elif start_value is not None and start_value == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        # A positive end is N rows following, zero is the current row and
        # anything else (None or negative) falls back to unbounded following.
        if end_value is not None and end_value > 0:
            end = f"{end_value} {ops.FOLLOWING}"
        elif end_value is not None and end_value == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        context = {"frame_type": self.frame_type, "start": start, "end": end}
        return self.template % context

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]: ...
2015
2016
class RowRange(WindowFrame):
    """Window frame whose frame type keyword is ``ROWS``."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The connection's operations object renders both ROWS boundaries.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
2024
2025
class ValueRange(WindowFrame):
    """Window frame whose frame type keyword is ``RANGE``."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The connection's operations object renders both RANGE boundaries.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)