1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from abc import ABC, abstractmethod
8from collections import defaultdict
9from decimal import Decimal
10from functools import cached_property
11from types import NoneType
12from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable
13from uuid import UUID
14
15from plain.models import fields
16from plain.models.constants import LOOKUP_SEP
17from plain.models.db import (
18 DatabaseError,
19 NotSupportedError,
20 db_connection,
21)
22from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
23from plain.models.query_utils import Q
24from plain.utils.deconstruct import deconstructible
25from plain.utils.hashable import make_hashable
26
27if TYPE_CHECKING:
28 from collections.abc import Callable, Iterable, Sequence
29
30 from plain.models.backends.base.base import BaseDatabaseWrapper
31 from plain.models.fields import Field
32 from plain.models.lookups import Lookup, Transform
33 from plain.models.query import QuerySet
34 from plain.models.sql.compiler import SQLCompilable, SQLCompiler
35 from plain.models.sql.query import Query
36
# Public names of this module, i.e. what `from <this module> import *`
# exports. Helper classes used internally (CombinedExpression,
# DurationExpression, TemporalSubtraction, ResolvedOuterRef, ...) are
# deliberately not listed.
__all__ = [
    # Core expression classes
    "F",
    "Value",
    "Case",
    "When",
    "Subquery",
    "Exists",
    "OuterRef",
    "Window",
    "ExpressionWrapper",
    "RawSQL",
    "OrderBy",
    # Base classes (for extension)
    "Func",
    "Expression",
    "Combinable",
    # Window frame specs
    "RowRange",
    "ValueRange",
]
58
59
@runtime_checkable
class ResolvableExpression(Protocol):
    """
    Structural type for anything exposing ``resolve_expression()``.

    Because the protocol is runtime-checkable, ``isinstance(obj,
    ResolvableExpression)`` only verifies that the method exists; the
    signature itself is not compared at runtime.
    """

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Resolve this object against *query* and return the result."""
72
73
74@runtime_checkable
75class ReplaceableExpression(Protocol):
76 """Protocol for expressions that support expression replacement."""
77
78 def replace_expressions(self, replacements: dict[Any, Any]) -> Self: ...
79
80
class Combinable:
    """
    Mixin that lets expression objects be combined with Python operators.

    The arithmetic operators (``+``, ``-``, ``*``, ``/``, ``%``, ``**`` and
    unary ``-``), in either operand position, build a ``CombinedExpression``,
    e.g. ``F('foo') + F('bar')``. SQL bitwise operators are only reachable
    through methods (``bitand()``, ``bitor()``, ``bitxor()``,
    ``bitleftshift()``, ``bitrightshift()``) because ``&``, ``|`` and ``^``
    are reserved for combining boolean (``conditional``) expressions into
    ``Q`` objects.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # "%" is doubled because connectors end up in SQL strings that also go
    # through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors; produced only by the bit*() methods, since the
    # '&', '|' and '^' operators are reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Build ``self <connector> other`` (``other <connector> self`` when
        *reversed* is true). Operands that aren't resolvable expressions are
        wrapped in ``Value`` first.
        """
        operand = other if isinstance(other, ResolvableExpression) else Value(other)
        if not reversed:
            return CombinedExpression(self, connector, operand)
        return CombinedExpression(operand, connector, self)

    # ---- arithmetic, self on the left --------------------------------------

    def __neg__(self) -> CombinedExpression:
        # Negation is expressed as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    # ---- arithmetic, self on the right (reflected) ------------------------

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    # ---- SQL bitwise operations (explicit methods only) -------------------

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # ---- boolean combination: only between two conditional operands -------

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False) and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    # Reflected boolean operators are never supported.

    def __rand__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other: Any) -> None:
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
213
214
215class BaseExpression:
216 """Base class for all query expressions."""
217
218 empty_result_set_value = NotImplemented
219 # aggregate specific fields
220 is_summary = False
221 _output_field_resolved_to_none = False
222 # Can the expression be used in a WHERE clause?
223 filterable = True
224 # Can the expression can be used as a source expression in Window?
225 window_compatible = False
226
227 def __init__(self, output_field: Field | None = None):
228 if output_field is not None:
229 self.output_field = output_field
230
231 def __getstate__(self) -> dict[str, Any]:
232 state = self.__dict__.copy()
233 state.pop("convert_value", None)
234 return state
235
236 def get_db_converters(
237 self, connection: BaseDatabaseWrapper
238 ) -> list[Callable[..., Any]]:
239 converters = []
240 if self.convert_value is not self._convert_value_noop:
241 converters.append(self.convert_value)
242 converters.extend(self.output_field.get_db_converters(connection))
243 return converters
244
245 def get_source_expressions(self) -> list[Any]:
246 return []
247
248 def set_source_expressions(self, exprs: Sequence[Any]) -> None:
249 assert not exprs
250
251 def _parse_expressions(self, *expressions: Any) -> list[Any]:
252 return [
253 arg
254 if isinstance(arg, ResolvableExpression)
255 else (F(arg) if isinstance(arg, str) else Value(arg))
256 for arg in expressions
257 ]
258
259 def as_sql(
260 self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
261 ) -> tuple[str, Sequence[Any]]:
262 """
263 Responsible for returning a (sql, [params]) tuple to be included
264 in the current query.
265
266 Different backends can provide their own implementation, by
267 providing an `as_{vendor}` method and patching the Expression:
268
269 ```
270 def override_as_sql(self, compiler, connection):
271 # custom logic
272 return super().as_sql(compiler, connection)
273 setattr(Expression, 'as_' + connection.vendor, override_as_sql)
274 ```
275
276 Arguments:
277 * compiler: the query compiler responsible for generating the query.
278 Must have a compile method, returning a (sql, [params]) tuple.
279 Calling compiler(value) will return a quoted `value`.
280
281 * connection: the database connection used for the current query.
282
283 Return: (sql, params)
284 Where `sql` is a string containing ordered sql parameters to be
285 replaced with the elements of the list `params`.
286 """
287 raise NotImplementedError("Subclasses must implement as_sql()")
288
289 @cached_property
290 def contains_aggregate(self) -> bool:
291 return any(
292 expr and expr.contains_aggregate for expr in self.get_source_expressions()
293 )
294
295 @cached_property
296 def contains_over_clause(self) -> bool:
297 return any(
298 expr and expr.contains_over_clause for expr in self.get_source_expressions()
299 )
300
301 @cached_property
302 def contains_column_references(self) -> bool:
303 return any(
304 expr and expr.contains_column_references
305 for expr in self.get_source_expressions()
306 )
307
308 def resolve_expression(
309 self,
310 query: Any = None,
311 allow_joins: bool = True,
312 reuse: Any = None,
313 summarize: bool = False,
314 for_save: bool = False,
315 ) -> Self:
316 """
317 Provide the chance to do any preprocessing or validation before being
318 added to the query.
319
320 Arguments:
321 * query: the backend query implementation
322 * allow_joins: boolean allowing or denying use of joins
323 in this query
324 * reuse: a set of reusable joins for multijoins
325 * summarize: a terminal aggregate clause
326 * for_save: whether this expression about to be used in a save or update
327
328 Return: an Expression to be added to the query.
329 """
330 c = self.copy()
331 c.is_summary = summarize
332 c.set_source_expressions(
333 [
334 expr.resolve_expression(query, allow_joins, reuse, summarize)
335 if expr
336 else None
337 for expr in c.get_source_expressions()
338 ]
339 )
340 return c
341
342 @property
343 def conditional(self) -> bool:
344 output_field = getattr(self, "output_field", None)
345 return isinstance(output_field, fields.BooleanField)
346
347 @property
348 def field(self) -> Field:
349 return self.output_field
350
351 @cached_property
352 def output_field(self) -> Field:
353 """Return the output type of this expressions."""
354 output_field = self._resolve_output_field()
355 if output_field is None:
356 self._output_field_resolved_to_none = True
357 raise FieldError("Cannot resolve expression type, unknown output_field")
358 return output_field
359
360 @cached_property
361 def _output_field_or_none(self) -> Field | None:
362 """
363 Return the output field of this expression, or None if
364 _resolve_output_field() didn't return an output type.
365 """
366 try:
367 return self.output_field
368 except FieldError:
369 if not self._output_field_resolved_to_none:
370 raise
371 return None
372
373 def _resolve_output_field(self) -> Field | None:
374 """
375 Attempt to infer the output type of the expression.
376
377 As a guess, if the output fields of all source fields match then simply
378 infer the same type here.
379
380 If a source's output field resolves to None, exclude it from this check.
381 If all sources are None, then an error is raised higher up the stack in
382 the output_field property.
383 """
384 # This guess is mostly a bad idea, but there is quite a lot of code
385 # (especially 3rd party Func subclasses) that depend on it, we'd need a
386 # deprecation path to fix it.
387 sources_iter = (
388 source for source in self.get_source_fields() if source is not None
389 )
390 for output_field in sources_iter:
391 for source in sources_iter:
392 if not isinstance(output_field, source.__class__):
393 raise FieldError(
394 f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
395 "set output_field."
396 )
397 return output_field
398 return None
399
400 @staticmethod
401 def _convert_value_noop(
402 value: Any, expression: Any, connection: BaseDatabaseWrapper
403 ) -> Any:
404 return value
405
406 @cached_property
407 def convert_value(self) -> Callable[[Any, Any, Any], Any]:
408 """
409 Expressions provide their own converters because users have the option
410 of manually specifying the output_field which may be a different type
411 from the one the database returns.
412 """
413 field = self.output_field
414 internal_type = field.get_internal_type()
415 if internal_type == "FloatField":
416 return (
417 lambda value, expression, connection: None
418 if value is None
419 else float(value)
420 )
421 elif internal_type.endswith("IntegerField"):
422 return (
423 lambda value, expression, connection: None
424 if value is None
425 else int(value)
426 )
427 elif internal_type == "DecimalField":
428 return (
429 lambda value, expression, connection: None
430 if value is None
431 else Decimal(value)
432 )
433 return self._convert_value_noop
434
435 def get_lookup(self, lookup: str) -> type[Lookup] | None:
436 return self.output_field.get_lookup(lookup)
437
438 def get_transform(self, name: str) -> type[Transform] | None:
439 return self.output_field.get_transform(name)
440
441 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
442 clone = self.copy()
443 clone.set_source_expressions(
444 [
445 e.relabeled_clone(change_map) if e is not None else None
446 for e in self.get_source_expressions()
447 ]
448 )
449 return clone
450
451 def replace_expressions(self, replacements: dict[BaseExpression, Any]) -> Self:
452 if replacement := replacements.get(self):
453 return replacement
454 clone = self.copy()
455 source_expressions = clone.get_source_expressions()
456 clone.set_source_expressions(
457 [
458 expr.replace_expressions(replacements) if expr else None
459 for expr in source_expressions
460 ]
461 )
462 return clone
463
464 def get_refs(self) -> set[str]:
465 refs = set()
466 for expr in self.get_source_expressions():
467 refs |= expr.get_refs()
468 return refs
469
470 def copy(self) -> Self:
471 return copy.copy(self)
472
473 def prefix_references(self, prefix: str) -> Self:
474 clone = self.copy()
475 clone.set_source_expressions(
476 [
477 F(f"{prefix}{expr.name}")
478 if isinstance(expr, F)
479 else expr.prefix_references(prefix)
480 for expr in self.get_source_expressions()
481 ]
482 )
483 return clone
484
485 def get_group_by_cols(self) -> list[BaseExpression]:
486 if not self.contains_aggregate:
487 return [self]
488 cols: list[BaseExpression] = []
489 for source in self.get_source_expressions():
490 cols.extend(source.get_group_by_cols())
491 return cols
492
493 def get_source_fields(self) -> list[Field | None]:
494 """Return the underlying field types used by this aggregate."""
495 return [e._output_field_or_none for e in self.get_source_expressions()]
496
497 def asc(self, **kwargs: Any) -> OrderBy:
498 return OrderBy(self, **kwargs)
499
500 def desc(self, **kwargs: Any) -> OrderBy:
501 return OrderBy(self, descending=True, **kwargs)
502
503 def reverse_ordering(self) -> Self:
504 return self
505
506 def flatten(self) -> Iterable[Any]:
507 """
508 Recursively yield this expression and all subexpressions, in
509 depth-first order.
510 """
511 yield self
512 for expr in self.get_source_expressions():
513 if expr:
514 if hasattr(expr, "flatten"):
515 yield from expr.flatten()
516 else:
517 yield expr
518
519 def select_format(
520 self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
521 ) -> tuple[str, Sequence[Any]]:
522 """
523 Custom format for select clauses. For example, EXISTS expressions need
524 to be wrapped in CASE WHEN on Oracle.
525 """
526 if output_field := getattr(self, "output_field", None):
527 if select_format := getattr(output_field, "select_format", None):
528 return select_format(compiler, sql, params)
529 return sql, params
530
531
@deconstructible
class Expression(BaseExpression, Combinable):
    """
    An expression that supports the Combinable operators and compares equal
    to another Expression built from equivalent constructor arguments.
    """

    # Set by @deconstructible decorator in __new__
    _constructor_args: tuple[tuple[Any, ...], dict[str, Any]]

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Hashable tuple identifying this expression.

        It holds the class, then one (parameter name, normalized value) pair
        per ``__init__()`` parameter, with defaults filled in. Field values
        attached to a model are normalized to (model label, field name),
        unattached fields to their class; every other value goes through
        make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args
        bound = init_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()

        def normalize(value: Any) -> Any:
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model.model_options.label, value.name)
            return type(value)

        return (
            self.__class__,
            *((name, normalize(value)) for name, value in bound.arguments.items()),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
565
566
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest-common-denominator
# behavior: if any supported database raises an error (rather than returning
# NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to a list of (lhs_type, rhs_type, result_type)
# triples; the loop further down feeds them to register_combinable_fields().

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # The field types are iterated *inside* each connector's list: putting
    # `for field_type in ...` as a second clause of the dict comprehension
    # would overwrite the connector's entry on every iteration and register
    # only the last type (FloatField).
    {
        connector: [
            combination
            for field_type in (
                fields.IntegerField,
                fields.DecimalField,
                fields.FloatField,
            )
            for combination in (
                (field_type, NoneType, field_type),
                (NoneType, field_type, field_type),
            )
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
670
# Registry filled by register_combinable_fields(): connector string ->
# list of (lhs_type, rhs_type, result_type) triples, in registration order.
_connector_combinators: defaultdict[str, list[tuple[type, type, type]]] = defaultdict(list)
672
673
def register_combinable_fields(
    lhs: type[Field] | type[None],
    connector: str,
    rhs: type[Field] | type[None],
    result: type[Field],
) -> None:
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Entries are matched with issubclass() in registration order, so the
    first registered entry that fits a pair of operand types wins.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its answers, including "no match".
    # Without clearing it, a registration made after a lookup would be
    # ignored for type combinations that were already resolved.
    try:
        resolver = _resolve_combined_type
    except NameError:
        # The built-in combinations are registered while this module is
        # being imported, before the resolver exists; nothing is cached yet.
        return
    resolver.cache_clear()
689
690
# Flatten the declarative _connector_combinations table into the
# per-connector registry, preserving list order (which decides match
# priority in _resolve_combined_type()).
# NOTE(review): the loop variables (d, connector, field_types, lhs, rhs,
# result) stay bound as module-level names after this runs.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
695
696
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result field class registered for
    ``lhs_type <connector> rhs_type``, or None when nothing matches.

    Registered entries are tried in order and matched with issubclass(), so
    field subclasses inherit their parent's combinations. Results (including
    None) are memoized per argument triple.
    """
    registered = _connector_combinators.get(connector, ())
    return next(
        (
            result_type
            for lhs_candidate, rhs_candidate, result_type in registered
            if issubclass(lhs_type, lhs_candidate)
            and issubclass(rhs_type, rhs_candidate)
        ),
        None,
    )
708
709
class SQLiteNumericMixin(Expression):
    """
    On SQLite, wrap expressions whose output_field is a DecimalField in
    ``CAST(... AS NUMERIC)`` so they can be filtered correctly.
    """

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # Unknown output type: leave the SQL untouched.
            needs_cast = False
        if needs_cast:
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params
729
730
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    Binary expression ``lhs <connector> rhs``, as built by the Combinable
    operators.

    During resolution it may hand off to DurationExpression (a DurationField
    operand mixed with a different type) or TemporalSubtraction (subtracting
    two values of the same date/time type).
    """

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        """
        Infer the output type from the registered connector combinations.

        Raise FieldError when no combination matches the operand types.
        """
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field(): the "all sources share a
        # type" guess is replaced by the registered combination table.
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Name the already-resolved types. Re-reading .output_field here
            # would raise a different, less helpful FieldError for operands
            # whose type resolved to None (NULL values).
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, "
                f"{rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence: parenthesize the combined operation
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
        """
        Resolve both operands, switching to a specialized class when needed:

        * a DurationField operand combined with a different type becomes a
          DurationExpression;
        * subtracting two values of the same Date/DateTime/Time type becomes
          a TemporalSubtraction.

        Otherwise return a copy holding the resolved operands.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # The specialized subclasses are already resolved forms; don't
        # re-dispatch them.
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
836
837
class DurationExpression(CombinedExpression):
    """
    CombinedExpression in which a DurationField operand is mixed with an
    operand of a different type.

    On backends without a native duration type, duration operands are
    formatted with the backend's duration-arithmetic helpers before being
    combined.
    """

    def compile(
        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compile one operand, adapting it if it is a DurationField."""
        try:
            output = side.output_field
        except FieldError:
            # Unknown type: compile as-is.
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        side_sql, side_params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(side_sql), side_params

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        sql_parts = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = self.compile(operand, compiler, connection)
            sql_parts.append(operand_sql)
            all_params.extend(operand_params)
        combined_sql = connection.ops.combine_duration_expression(
            self.connector, sql_parts
        )
        # Parenthesize to keep operator precedence when nested.
        return "(%s)" % combined_sql, all_params

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Types unknown: let the database decide.
            return sql, params
        # Multiplying/dividing a duration only makes sense with numeric or
        # duration operands.
        numeric_like = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not {lhs_type, rhs_type} <= numeric_like:
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
896
897
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` for two values of the same date/time type; the result is a
    DurationField and the SQL comes from the backend's subtract_temporals().
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        sql, params = connection.ops.subtract_temporals(
            lhs_internal_type, compiled_lhs, compiled_rhs
        )
        return sql, list(params)
914
915
@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    A named reference resolved against a query via ``query.resolve_ref()``.

    Two F objects are equal when they have the same class and name.
    """

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.name})"

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        """Return this reference's replacement, or the reference itself."""
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, F):
            return self.__class__ == other.__class__ and self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> Self:
        """Return a shallow copy of this reference."""
        return copy.copy(self)
959
960
class ResolvedOuterRef(F):
    """
    Reference to a field of an outer query.

    OuterRef becomes this once the inner query is used as a subquery. It
    cannot be compiled on its own. Resolving it goes through
    ``query.resolve_ref()`` like F, then rejects window expressions and
    flags lookups that span relations.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Relabeling leaves the outer reference untouched.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Contributes no GROUP BY columns.
        return []
996
997
class OuterRef(F):
    """
    Reference to a field of the enclosing query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. A nested
    ``OuterRef(OuterRef(...))`` instead unwraps one level and returns the
    inner OuterRef.
    """

    contains_aggregate = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        target = self.name
        if not isinstance(target, self.__class__):
            return ResolvedOuterRef(target)
        return target

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Relabeling leaves the outer reference untouched.
        return self
1008
1009
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    as_sql() fills `template` with the function name (%(function)s) and the
    compiled argument expressions joined by `arg_joiner` (%(expressions)s,
    also available as %(field)s). Keyword arguments passed as **extra to
    __init__() become additional template context; context passed to
    as_sql() overrides them.
    """

    # Name substituted for %(function)s in the template.
    function = None
    # %-style template rendered by as_sql().
    template = "%(function)s(%(expressions)s)"
    # Separator placed between the compiled argument SQL fragments.
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Store the argument expressions and any extra template context.

        Raises TypeError if `arity` is set and the number of positional
        expressions differs from it.
        """
        if self.arity is not None and len(expressions) != self.arity:
            raise TypeError(
                "'{}' takes exactly {} {} ({} given)".format(
                    self.__class__.__name__,
                    self.arity,
                    "argument" if self.arity == 1 else "arguments",
                    len(expressions),
                )
            )
        super().__init__(output_field=output_field)
        # Non-expression arguments are normalized by _parse_expressions().
        self.source_expressions: list[Any] = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        """Return e.g. ``Func(a, b, key=value)``, with extra options sorted by key."""
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        extra = {**self.extra, **self._get_repr_options()}
        if extra:
            extra = ", ".join(
                str(key) + "=" + str(val) for key, val in sorted(extra.items())
            )
            return f"{self.__class__.__name__}({args}, {extra})"
        return f"{self.__class__.__name__}({args})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = list(exprs)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Return a copy with every argument expression resolved against `query`."""
        c = self.copy()
        c.is_summary = summarize
        for pos, arg in enumerate(c.source_expressions):
            c.source_expressions[pos] = arg.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the function call to ``(sql, params)``.

        An argument raising EmptyResultSet is replaced by its
        `empty_result_set_value` (re-raising when it has none); an argument
        raising FullResultSet is replaced by Value(True).
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    arg, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data["function"] = function
        else:
            data.setdefault("function", self.function)
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        # The joined arguments are exposed under both %(expressions)s and
        # %(field)s so templates may use either key.
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self) -> Self:
        """Return a copy whose source_expressions list and extra dict are its own."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
1115
1116
1117@deconstructible(path="plain.models.Value")
1118class Value(SQLiteNumericMixin, Expression):
1119 """Represent a wrapped value as a node within an expression."""
1120
1121 # Provide a default value for `for_save` in order to allow unresolved
1122 # instances to be compiled until a decision is taken in #25425.
1123 for_save = False
1124
1125 def __init__(self, value: Any, output_field: Field | None = None):
1126 """
1127 Arguments:
1128 * value: the value this expression represents. The value will be
1129 added into the sql parameter list and properly quoted.
1130
1131 * output_field: an instance of the model field type that this
1132 expression will return, such as IntegerField() or CharField().
1133 """
1134 super().__init__(output_field=output_field)
1135 self.value = value
1136
1137 def __repr__(self) -> str:
1138 return f"{self.__class__.__name__}({self.value!r})"
1139
1140 def as_sql(
1141 self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1142 ) -> tuple[str, list[Any]]:
1143 connection.ops.check_expression_support(self)
1144 val = self.value
1145 output_field = self._output_field_or_none
1146 if output_field is not None:
1147 if self.for_save:
1148 val = output_field.get_db_prep_save(val, connection=connection)
1149 else:
1150 val = output_field.get_db_prep_value(val, connection=connection)
1151 if hasattr(output_field, "get_placeholder"):
1152 return output_field.get_placeholder(val, compiler, connection), [val]
1153 if val is None:
1154 # cx_Oracle does not always convert None to the appropriate
1155 # NULL type (like in case expressions using numbers), so we
1156 # use a literal SQL NULL
1157 return "NULL", []
1158 return "%s", [val]
1159
1160 def resolve_expression(
1161 self,
1162 query: Any = None,
1163 allow_joins: bool = True,
1164 reuse: Any = None,
1165 summarize: bool = False,
1166 for_save: bool = False,
1167 ) -> Value:
1168 c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1169 c.for_save = for_save
1170 return c
1171
1172 def get_group_by_cols(self) -> list[Any]:
1173 return []
1174
1175 def _resolve_output_field(self) -> Field | None:
1176 if isinstance(self.value, str):
1177 return fields.CharField()
1178 if isinstance(self.value, bool):
1179 return fields.BooleanField()
1180 if isinstance(self.value, int):
1181 return fields.IntegerField()
1182 if isinstance(self.value, float):
1183 return fields.FloatField()
1184 if isinstance(self.value, datetime.datetime):
1185 return fields.DateTimeField()
1186 if isinstance(self.value, datetime.date):
1187 return fields.DateField()
1188 if isinstance(self.value, datetime.time):
1189 return fields.TimeField()
1190 if isinstance(self.value, datetime.timedelta):
1191 return fields.DurationField()
1192 if isinstance(self.value, Decimal):
1193 return fields.DecimalField()
1194 if isinstance(self.value, bytes):
1195 return fields.BinaryField()
1196 if isinstance(self.value, UUID):
1197 return fields.UUIDField()
1198
1199 @property
1200 def empty_result_set_value(self) -> Any:
1201 return self.value
1202
1203
class RawSQL(Expression):
    """
    A raw SQL fragment together with its own query parameters.

    The fragment is emitted wrapped in parentheses. When no output_field is
    given, a plain fields.Field() is used.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped_sql = f"({self.sql})"
        return wrapped_sql, self.params

    def get_group_by_cols(self) -> list[BaseExpression]:
        # The fragment is grouped on as a single unit.
        return [self]
1223
1224
class Star(Expression):
    """Render the bare SQL ``*`` wildcard; it carries no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        wildcard = "*"
        return wildcard, []
1233
1234
class Col(Expression):
    """
    A reference to a target field's column, optionally qualified by a table
    alias. The target itself doubles as the output field when none is given.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        column = self.target.column
        parts = (self.alias, column) if self.alias else (column,)
        quoted = ".".join(compiler.quote_name_unless_alias(part) for part in parts)
        return quoted, []

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        # An unqualified column has nothing to relabel.
        if self.alias is None:
            return self
        new_alias = change_map.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        # The output field's converters run first; when the target field
        # differs from the output field, its converters are appended.
        output_converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return output_converters
        return output_converters + self.target.get_db_converters(connection)
1278
1279
class Ref(Expression):
    """
    A reference to a column alias of the query, e.g. Ref('sum_cost') for the
    alias introduced by qs.annotate(sum_cost=Sum('cost')).
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        # Only the alias name is referenced here; `source` was resolved when
        # the alias was defined, so there is nothing left to resolve.
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, change_map: dict[str, str]) -> Self:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[BaseExpression]:
        return [self]
1324
1325
class ExpressionList(Func):
    """
    A bare list of expressions joined by `arg_joiner`. It can be passed as a
    single argument to another expression, for instance as a window's
    partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # Bypass the SQLite numeric cast: a list of expressions doesn't need
        # one, so compile with the generic as_sql().
        return self.as_sql(compiler, connection, **extra_context)
1353
1354
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a list of ordering expressions.

    A string starting with "-" becomes a descending OrderBy() of the field
    named by the rest of the string. Other arguments are passed through to
    Func. An empty list compiles to an empty SQL string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # str.startswith() is used instead of expr[0] so an empty string
        # doesn't raise an IndexError here.
        expressions_tuple = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions_tuple, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, list[Any]]:
        # With nothing to order by, emit no clause at all.
        if not self.source_expressions:
            return "", []
        sql, params = super().as_sql(*args, **kwargs)
        return sql, list(params)

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering expression."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1380
1381
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to attach extra context to it, chiefly an
    explicit output_field. as_sql() compiles the wrapped expression directly.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression})"

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Give a copy of the inner expression this wrapper's output field
        # before asking it for its columns.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        return compiler.compile(self.expression)
1415
1416
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression (boolean output)."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation yields a copy of the original expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped condition never matches, so its negation always
            # does: emit an always-true predicate.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN; there the negation is spelled as a comparison against 0.
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list. Expressions that as_sql() already wrapped in a CASE
        # WHEN are left alone to avoid double wrapping.
        connection = compiler.connection
        needs_case_wrapper = (
            not connection.features.supports_boolean_expr_in_select_clause
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1475
1476
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case() expression.

    The condition may be a Q object, a conditional expression, or keyword
    lookups. Lookups are combined into a Q object, together with the
    conditional expression when both are given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False
    condition: SQLCompilable

    def __init__(
        self, condition: Q | Expression | None = None, then: Any = None, **lookups: Any
    ):
        """
        Raises TypeError when no usable condition is given, or when lookups are
        combined with a non-conditional `condition`. Raises ValueError for an
        empty Q().
        """
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            # Fold the lookups into a Q object. On success lookups_dict is
            # cleared so the validity check below can detect leftovers.
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition  # type: ignore[assignment]
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """Return a copy with the condition and result resolved against `query`."""
        c = self.copy()
        c.is_summary = summarize
        if isinstance(c.condition, ResolvableExpression):
            # The condition is always resolved with for_save=False; only the
            # result receives the caller's for_save flag.
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``(sql, params)``. Condition parameters come before result
        parameters, matching their order in the template.
        """
        connection.ops.check_expression_support(self)
        template_params = extra_context
        sql_params = []
        # After resolve_expression, condition is WhereNode | resolved Expression (both SQLCompilable)
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        # Instead, the GROUP BY columns of the condition and result are
        # collected.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1571
1572
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    # Separator placed between the compiled WHEN branches.
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Raises TypeError if any positional argument isn't a When instance.
        `default` becomes the ELSE branch; **extra is extra template context.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every branch and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Self:
        # Give the copy its own cases list so resolve_expression() can
        # replace entries without touching the original.
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to ``(sql, params)``.

        Branches whose condition raises EmptyResultSet are dropped. A branch
        whose condition raises FullResultSet replaces the default with its
        result, and later branches are ignored. If no branch survives, only the
        default's SQL is returned.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            # Without WHEN branches, the expression is just its default.
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Default params follow all branch params, matching the template order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Apply the backend's unification cast for the output field type.
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without branches, the expression compiles to just the default.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1685
1686
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Accepts either a QuerySet or an sql.Query and always works on a private
    clone marked as a subquery.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.models.sql.query import Query

        sql_query = query if isinstance(query, Query) else query.sql_query
        self.query = sql_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        # The subquery's type is whatever the inner query reports.
        return self.query.output_field

    def copy(self) -> Self:
        # Clone the inner query so the copy can be altered independently.
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> dict[str, bool]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters (presumably the parentheses the
        # subquery SQL arrives wrapped in) so the template decides the wrapping.
        context["subquery"] = subquery_sql[1:-1]
        template = template or context.get("template", self.template)
        return template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1757
1758
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for its exists() form.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1777
1778
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    One ordering term: an expression plus direction and optional NULL
    placement (nulls_first / nulls_last).
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Raises ValueError if both nulls_first and nulls_last are set, if
        either is False (only True or None are accepted), or if `expression`
        isn't a resolvable expression.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not isinstance(expression, ResolvableExpression):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to ``(sql, params)``, honoring the requested NULL placement."""
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            # Native NULLS LAST / NULLS FIRST syntax is available.
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate NULL placement by first ordering on an IS NULL /
            # IS NOT NULL test of the expression. This is skipped when the
            # backend's default NULL placement (per order_by_nulls_first)
            # already gives the requested order for this direction.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression's SQL appears once per %(expression)s placeholder,
        # so its params must be repeated the same number of times.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """Flip the direction and NULL placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:  # type: ignore[override]
        """Make this ordering ascending, in place."""
        self.descending = False

    def desc(self) -> None:  # type: ignore[override]
        """Make this ordering descending, in place."""
        self.descending = True
1864
1865
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call: ``<expression> OVER (<window>)``.

    The window specification is made of an optional PARTITION BY list, an
    optional ORDER BY list and an optional frame clause, in that order.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True
    partition_by: ExpressionList | None
    order_by: OrderByList | None

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Arguments:
         * expression: the expression evaluated over the window; its
           `window_compatible` attribute must be truthy.
         * partition_by: an expression or a list/tuple of them, wrapped in an
           ExpressionList.
         * order_by: a field-name string, an expression, or a list/tuple of
           them, wrapped in an OrderByList.
         * frame: an optional frame clause (e.g. RowRange or ValueRange).
         * output_field: overrides the output field, which otherwise comes
           from `expression`.

        Raises ValueError if `expression` isn't window compatible or
        `order_by` has an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            partition_by_values = (
                self.partition_by
                if isinstance(self.partition_by, tuple | list)
                else (self.partition_by,)
            )
            self.partition_by = ExpressionList(*partition_by_values)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``(sql, params)``.

        Raises NotSupportedError if the backend has no OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            # Compile the ExpressionList directly with a PARTITION BY template.
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression:
            # the wrapped expression is compiled as a float and the outer
            # handling is delegated to SQLiteNumericMixin.as_sqlite().
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            # copy() doesn't duplicate child expressions, so copy the wrapped
            # expression before overriding its output_field; otherwise the
            # original Window's expression would be mutated.
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection, **extra_context)
        return self.as_sql(compiler, connection, **extra_context)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # Only the partition and ordering expressions contribute columns;
        # the windowed expression itself is excluded (see contains_aggregate).
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1994
1995
class WindowFrame(Expression, ABC):
    """
    Model the frame clause in window expressions.

    The concrete frame types (subclasses) only choose the frame keyword and
    which backend helper renders the bounds. SQL assembly and string rendering
    live here. Providing an end for a frame is optional: the default is
    UNBOUNDED FOLLOWING, i.e. the last row of the partition.

    Bounds are integer offsets relative to the current row:

    - a negative number renders as ``N PRECEDING``;
    - zero renders as CURRENT ROW;
    - a positive number renders as ``N FOLLOWING``;
    - None means unbounded on that side.

    Turning the bounds into SQL for a particular database is delegated to
    that connection's ``ops`` helpers.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
    # Frame keyword emitted in the SQL; set by each concrete subclass.
    frame_type: str

    def __init__(self, start: int | None = None, end: int | None = None):
        # Bounds are stored as Value expressions so they participate in the
        # usual source-expression machinery (copying, resolving, etc.).
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render the frame clause for ``connection``.

        The bound strings come from the subclass's ``window_frame_start_end``.
        They are inlined into the SQL, so the returned params list is always
        empty.
        """
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        return (
            self.template
            % {
                "frame_type": self.frame_type,
                "start": start,
                "end": end,
            },
            [],
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # A frame clause never contributes GROUP BY columns.
        return []

    @staticmethod
    def _describe_bound(value: int | None, unbounded: str) -> str:
        """
        Render one bound offset for ``__str__``.

        Uses the default connection's keywords. ``unbounded`` is returned when
        ``value`` is None.
        """
        if value is None:
            return unbounded
        if value < 0:
            return f"{abs(value)} {db_connection.ops.PRECEDING}"
        if value == 0:
            return db_connection.ops.CURRENT_ROW
        return f"{value} {db_connection.ops.FOLLOWING}"

    def __str__(self) -> str:
        # No connection is available here, so the default connection's
        # keywords are used. This is a readable approximation of as_sql().
        # A positive start renders as FOLLOWING and a negative end as
        # PRECEDING, instead of being misreported as unbounded.
        start = self._describe_bound(
            self.start.value, db_connection.ops.UNBOUNDED_PRECEDING
        )
        end = self._describe_bound(
            self.end.value, db_connection.ops.UNBOUNDED_FOLLOWING
        )
        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    @abstractmethod
    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the SQL ``(start, end)`` bound strings for ``connection``."""
        ...
2065
2066
class RowRange(WindowFrame):
    """
    Window frame rendered as ``ROWS BETWEEN <start> AND <end>``.

    Bound rendering is delegated to the backend's rows-frame helper.
    """

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        backend_ops = connection.ops
        return backend_ops.window_frame_rows_start_end(start, end)
2074
2075
class ValueRange(WindowFrame):
    """
    Window frame rendered as ``RANGE BETWEEN <start> AND <end>``.

    Bound rendering is delegated to the backend's range-frame helper.
    """

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        backend_ops = connection.ops
        return backend_ops.window_frame_range_start_end(start, end)