1from __future__ import annotations
2
3import copy
4import datetime
5import functools
6import inspect
7from collections import defaultdict
8from decimal import Decimal
9from functools import cached_property
10from types import NoneType
11from typing import TYPE_CHECKING, Any
12from uuid import UUID
13
14from plain.models import fields
15from plain.models.constants import LOOKUP_SEP
16from plain.models.db import (
17 DatabaseError,
18 NotSupportedError,
19 db_connection,
20)
21from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
22from plain.models.query_utils import Q
23from plain.utils.deconstruct import deconstructible
24from plain.utils.hashable import make_hashable
25
26if TYPE_CHECKING:
27 from collections.abc import Callable, Iterable, Sequence
28
29 from plain.models.backends.base.base import BaseDatabaseWrapper
30 from plain.models.fields import Field
31 from plain.models.query import QuerySet
32 from plain.models.sql.compiler import SQLCompiler
33 from plain.models.sql.query import Query
34
35
class SQLiteNumericMixin:
    """
    Mixin that adds a NUMERIC cast for decimal-typed expressions on SQLite.

    Expressions whose output_field is a DecimalField must be cast to NUMERIC
    on SQLite so that filtering against them behaves properly.
    """

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile with as_sql() and wrap the SQL in CAST(... AS NUMERIC) when decimal."""
        compiled_sql, compiled_params = self.as_sql(  # type: ignore[attr-defined]
            compiler, connection, **extra_context
        )
        try:
            is_decimal = (
                self.output_field.get_internal_type()  # type: ignore[attr-defined]
                == "DecimalField"
            )
        except FieldError:
            # The output type can't be resolved; leave the SQL as compiled.
            is_decimal = False
        if is_decimal:
            compiled_sql = f"CAST({compiled_sql} AS NUMERIC)"
        return compiled_sql, compiled_params
55
56
class Combinable:
    """
    Mixin that overloads Python operators to build SQL expressions.

    Arithmetic operators (``+``, ``-``, ``*``, ``/``, ``%``, ``**`` and their
    reflected forms) produce a CombinedExpression, e.g. ``F('foo') + F('bar')``.
    ``&``, ``|`` and ``^`` only combine two conditional (boolean) expressions
    into Q objects; bitwise arithmetic goes through the explicit bitand(),
    bitor(), bitxor(), bitleftshift() and bitrightshift() methods.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # Modulo is written as "%%" because the connector can end up in SQL
    # strings that also go through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors. They are reachable only through the .bit*() methods;
    # the & and | Python operators are reserved for boolean logic.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(
        self, other: Any, connector: str, reversed: bool
    ) -> CombinedExpression:
        """
        Join self and other with connector into a CombinedExpression.

        Operands without a resolve_expression attribute are wrapped in Value()
        so that everything is resolvable to an expression. When reversed is
        true, other becomes the left-hand side.
        """
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        if reversed:
            lhs, rhs = operand, self
        else:
            lhs, rhs = self, operand
        return CombinedExpression(lhs, connector, rhs)

    @staticmethod
    def _bitwise_operator_error() -> NotImplementedError:
        """Build the error raised when &, | or ^ is misused for bitwise logic."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    #############
    # OPERATORS #
    #############

    def __neg__(self) -> CombinedExpression:
        # -expr is built as expr * -1.
        return self._combine(-1, self.MUL, False)

    # Forward arithmetic: self <op> other.

    def __add__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, False)

    def __sub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, False)

    def __mul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, False)

    def __mod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, False)

    def __pow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, False)

    # Reflected arithmetic: other <op> self.

    def __radd__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.POW, True)

    # Explicit bitwise arithmetic.

    def bitand(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITAND, False)

    def bitor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other: Any) -> CombinedExpression:
        return self._combine(other, self.BITRIGHTSHIFT, False)

    # Boolean logic: only defined between two conditional expressions, which
    # are combined through Q objects. Anything else is rejected.

    def __and__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) & Q(other)  # type: ignore[unsupported-operator]

    def __or__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) | Q(other)  # type: ignore[unsupported-operator]

    def __xor__(self, other: Any) -> Q:
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise self._bitwise_operator_error()
        return Q(self) ^ Q(other)  # type: ignore[unsupported-operator]

    def __rand__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __ror__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __rxor__(self, other: Any) -> None:
        raise self._bitwise_operator_error()

    def __invert__(self) -> NegatedExpression:
        return NegatedExpression(self)
189
190
class BaseExpression:
    """
    Base class for all query expressions.

    An expression exposes its child expressions through
    get_source_expressions()/set_source_expressions(), compiles itself to SQL
    in as_sql(), and reports its result type through output_field.
    """

    # Value that Func.as_sql() substitutes when compiling this expression
    # raises EmptyResultSet; NotImplemented means "no substitute" and the
    # exception propagates.
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set to True by output_field when _resolve_output_field() returns None,
    # so _output_field_or_none can tell "type unknown" apart from real errors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field: Field | None = None):
        # Only shadow the output_field cached_property when a type is given;
        # otherwise it is inferred lazily by _resolve_output_field().
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self) -> dict[str, Any]:
        # convert_value is a cached_property that may hold a lambda (see
        # convert_value below), and lambdas cannot be pickled; drop the cached
        # value so it is recomputed on demand after unpickling.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Return the converters to apply to values fetched for this expression.

        The expression's own convert_value comes first (omitted when it is the
        no-op), followed by the output field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop  # type: ignore[attr-defined]
            else [self.convert_value]  # type: ignore[attr-defined]
        ) + self.output_field.get_db_converters(connection)  # type: ignore[attr-defined]

    def get_source_expressions(self) -> list[Any]:
        """Return the child expressions; leaf expressions have none."""
        return []

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Replace the child expressions; leaf expressions accept only an empty list."""
        assert not exprs

    def _parse_expressions(self, *expressions: Any) -> list[Any]:
        """
        Normalize constructor arguments into expressions.

        Objects with resolve_expression are kept as-is, strings become F()
        field references, and any other value is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self) -> bool:
        """True if any (non-empty) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self) -> bool:
        """True if any (non-empty) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self) -> bool:
        """True if any (non-empty) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> BaseExpression:
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to the children here.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self) -> bool:
        """True when the output field is a BooleanField (usable with &, |, ^)."""
        return isinstance(self.output_field, fields.BooleanField)  # type: ignore[attr-defined]

    @property
    def field(self) -> Field:
        """Alias for output_field."""
        return self.output_field  # type: ignore[attr-defined]

    @cached_property
    def output_field(self) -> Field:
        """
        Return the output type of this expression.

        Raises FieldError when the type cannot be inferred.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            # Record *why* we failed so _output_field_or_none can swallow it.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self) -> Field | None:
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            # Only "no type could be inferred" maps to None; any other
            # FieldError (e.g. mixed types) is re-raised.
            if not self._output_field_resolved_to_none:
                raise
            return None

    def _resolve_output_field(self) -> Field | None:
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops consume the same generator: the outer loop takes the
        # first non-None source field, the inner loop checks every remaining
        # one against it, and then the first field is returned.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field
        return None

    @staticmethod
    def _convert_value_noop(
        value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Identity converter; get_db_converters() omits it by identity check."""
        return value

    @cached_property
    def convert_value(self) -> Callable[[Any, Any, Any], Any]:
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup: str) -> type | None:
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)  # type: ignore[attr-defined]

    def get_transform(self, name: str) -> type | None:
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)  # type: ignore[attr-defined]

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseExpression:
        """Return a copy whose (non-empty) children are relabeled with change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(
        self, replacements: dict[BaseExpression, Any]
    ) -> BaseExpression:
        """
        Return self's replacement if one is registered, otherwise a copy with
        replacements applied recursively to the (non-empty) children.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self) -> set[str]:
        """
        Return the union of get_refs() over the source expressions.

        Empty (None) source-expression slots are skipped, matching how
        resolve_expression(), relabeled_clone() and flatten() treat them.
        """
        refs: set[str] = set()
        for expr in self.get_source_expressions():
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self) -> BaseExpression:
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix: str) -> BaseExpression:
        """
        Return a copy where every F() child is renamed to prefix + name and
        other children are prefixed recursively. Empty (None) slots stay None.
        """

        def _prefixed(expr: Any) -> Any:
            # Leave empty slots alone, as the other clone helpers do.
            if expr is None:
                return None
            if isinstance(expr, F):
                return F(f"{prefix}{expr.name}")
            return expr.prefix_references(prefix)

        clone = self.copy()
        clone.set_source_expressions(
            [_prefixed(expr) for expr in self.get_source_expressions()]
        )
        return clone

    def get_group_by_cols(self) -> list[BaseExpression]:
        """
        Return the expressions to add to GROUP BY for this expression.

        A non-aggregate expression groups by itself; an aggregate collects the
        group-by columns of its (non-empty) source expressions.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is None:
                continue
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self) -> list[Field | None]:
        """Return the underlying field types used by this expression's sources."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs: Any) -> OrderBy:
        """Return an ascending OrderBy wrapping this expression."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        """Return a descending OrderBy wrapping this expression."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self) -> BaseExpression:
        """Return self; subclasses that carry an ordering override this."""
        return self

    def flatten(self) -> Iterable[Any]:
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):  # type: ignore[attr-defined]
            return self.output_field.select_format(compiler, sql, params)  # type: ignore[attr-defined]
        return sql, params
506
507
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self) -> tuple[Any, ...]:
        """
        Hashable fingerprint of this expression, derived from its constructor call.

        The recorded constructor arguments (``_constructor_args``, presumably
        captured by @deconstructible) are bound to ``__init__``'s signature
        with defaults applied. Field arguments attached to a model reduce to
        ``(model label, field name)``, unattached fields reduce to their class,
        and every other argument goes through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        positional, keyword = self._constructor_args  # type: ignore[attr-defined]
        bound = init_signature.bind_partial(*positional, **keyword)
        bound.apply_defaults()

        def normalize(value: Any) -> Any:
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model.model_options.label, value.name)
            return type(value)

        return (
            self.__class__,
            *((name, normalize(value)) for name, value in bound.arguments.items()),
        )

    def __eq__(self, other: object) -> bool:
        # Two expressions are equal when their constructor identities match.
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)
538
539
# Type inference for CombinedExpression.output_field.
#
# Each dict below maps a connector to a list of
# (lhs field type, rhs field type, result field type) triples. The module
# registers them with register_combinable_fields() in list order, and
# _resolve_combined_type() returns the result type of the first triple whose
# lhs/rhs types are superclasses of the operands' output field types.
# Missing items will result in FieldError, by design: the user must then set
# output_field explicitly.
#
# The current approach for NULL is based on lowest common denominator behavior,
# i.e. if one of the supported databases raises an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type. The result widens to
    # the non-integer operand's type. POW is not listed, so mixed-type POW
    # requires an explicit output_field.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators (integers only).
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL. NoneType matches an operand whose output field
    # resolved to None: CombinedExpression passes type(None) in that case.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
643
# connector -> list of (lhs_type, rhs_type, result_type) triples, kept in
# registration order (earlier registrations take precedence on lookup).
_connector_combinators = defaultdict(list)


@functools.lru_cache(maxsize=128)
def _resolve_combined_type(
    connector: str, lhs_type: type[Field], rhs_type: type[Field]
) -> type[Field] | None:
    """
    Return the result field type for ``lhs_type <connector> rhs_type``.

    Registered triples are scanned in registration order; the first whose
    lhs/rhs types are superclasses of the given types wins. Return None when
    nothing matches.

    Results (including misses) are memoized; register_combinable_fields()
    clears the cache so later registrations are honored.
    """
    combinators = _connector_combinators.get(connector, ())
    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
            rhs_type, combinator_rhs_type
        ):
            return combined_type
    return None


def register_combinable_fields(
    lhs: type[Field], connector: str, rhs: type[Field], result: type[Field]
) -> None:
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    The new combination only applies when no previously registered one
    already matches the same operand types.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes lookups, including misses (None).
    # Invalidate them so this registration is visible to later resolutions
    # instead of being shadowed by a stale cached result.
    _resolve_combined_type.cache_clear()


for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
678
679
class CombinedExpression(SQLiteNumericMixin, Expression):
    """Binary expression ``lhs <connector> rhs``, as built by Combinable operators."""

    def __init__(
        self, lhs: Any, connector: str, rhs: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs, self.rhs = lhs, rhs

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self) -> str:
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self) -> list[Any]:
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self) -> Field | None:
        """
        Infer the result type from the registered connector combinations.

        Raises FieldError when the operand types have no registered result.
        """
        # We avoid using super() here for reasons given in
        # Expression._resolve_output_field()
        lhs_field_type = type(self.lhs._output_field_or_none)
        rhs_field_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(
            self.connector, lhs_field_type, rhs_field_type
        )
        if combined_type is not None:
            return combined_type()
        raise FieldError(
            f"Cannot infer type of {self.connector!r} expression involving these "
            f"types: {self.lhs.output_field.__class__.__name__}, "
            f"{self.rhs.output_field.__class__.__name__}. You must set "
            f"output_field."
        )

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Compile both operands, join them via the backend, and parenthesize."""
        operand_sqls = []
        params: list[Any] = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = compiler.compile(operand)
            operand_sqls.append(operand_sql)
            params.extend(operand_params)
        combined_sql = connection.ops.combine_expression(self.connector, operand_sqls)
        # Parenthesize to keep operator precedence when nested.
        return "(%s)" % combined_sql, params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> CombinedExpression | DurationExpression | TemporalSubtraction:
        """
        Resolve both operands, possibly specializing the expression.

        A plain CombinedExpression becomes a DurationExpression when exactly
        one side is a DurationField, or a TemporalSubtraction when subtracting
        two operands of the same Date/DateTime/Time type.
        """
        resolve_args = (query, allow_joins, reuse, summarize, for_save)
        lhs = self.lhs.resolve_expression(*resolve_args)
        rhs = self.rhs.resolve_expression(*resolve_args)

        if not isinstance(self, DurationExpression | TemporalSubtraction):

            def internal_type(expression: Any) -> str | None:
                # Operands without a resolvable output field count as untyped.
                try:
                    return expression.output_field.get_internal_type()
                except (AttributeError, FieldError):
                    return None

            lhs_type = internal_type(lhs)
            rhs_type = internal_type(rhs)

            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                specialized = DurationExpression(self.lhs, self.connector, self.rhs)
                return specialized.resolve_expression(*resolve_args)

            is_temporal = lhs_type in {"DateField", "DateTimeField", "TimeField"}
            if self.connector == self.SUB and is_temporal and lhs_type == rhs_type:
                specialized = TemporalSubtraction(self.lhs, self.rhs)
                return specialized.resolve_expression(*resolve_args)

        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.lhs, resolved.rhs = lhs, rhs
        return resolved
785
786
class DurationExpression(CombinedExpression):
    """
    CombinedExpression in which a DurationField is mixed with another type.

    On backends without a native duration type, duration operands go through
    connection.ops.format_for_duration_arithmetic() and the pieces are joined
    with connection.ops.combine_duration_expression().
    """

    def compile(
        self, side: Any, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compile one operand, formatting it for duration arithmetic if needed."""
        try:
            output = side.output_field
        except FieldError:
            # Untyped operand: compile it unchanged.
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Compile using the backend's duration-arithmetic helpers when required."""
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        connection.ops.check_expression_support(self)
        compiled_sides = [
            self.compile(side, compiler, connection) for side in (self.lhs, self.rhs)
        ]
        side_sqls = [side_sql for side_sql, _ in compiled_sides]
        params = [
            param for _, side_params in compiled_sides for param in side_params
        ]
        combined_sql = connection.ops.combine_duration_expression(
            self.connector, side_sqls
        )
        # Parenthesize to keep operator precedence when nested.
        return "(%s)" % combined_sql, params

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """
        Compile for SQLite, rejecting * and / between non-numeric, non-duration
        operands with DatabaseError.
        """
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            operand_types = (
                self.lhs.output_field.get_internal_type(),
                self.rhs.output_field.get_internal_type(),
            )
        except (AttributeError, FieldError):
            # Operand types unknown: nothing to validate.
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not all(operand_type in allowed_fields for operand_type in operand_types):
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
845
846
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` between temporal values, always typed as a duration.

    CombinedExpression.resolve_expression() builds it when both operands share
    a Date/DateTime/Time type; the SQL comes from
    connection.ops.subtract_temporals().
    """

    output_field = fields.DurationField()

    def __init__(self, lhs: Any, rhs: Any):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        # The lhs internal type tells the backend which temporal kind to subtract.
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
862
863
@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    Reference to a field by name, resolved against the query at compile time.

    Supports the Combinable arithmetic operators, e.g. ``F("price") * 2``.
    """

    def __init__(self, name: str):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self) -> str:
        return "{}({})".format(self.__class__.__name__, self.name)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Any:
        """Resolve the referenced name through query.resolve_ref()."""
        # for_save is part of the common signature but is not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)  # type: ignore[union-attr]

    def replace_expressions(self, replacements: dict[Any, Any]) -> F:
        """Return the registered replacement for this reference, or self."""
        return replacements.get(self, self)

    def asc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs: Any) -> OrderBy:
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other: object) -> bool:
        # Equal only to an instance of exactly the same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name  # type: ignore[attr-defined]

    def __hash__(self) -> int:
        return hash(self.name)

    def copy(self) -> F:
        return copy.copy(self)
907
908
class ResolvedOuterRef(F):
    """
    An object that contains a reference to an outer query.

    In this case, the reference to the outer query has been resolved because
    the inner query has been used as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args: Any, **kwargs: Any) -> None:
        """Always fails: this reference is only valid inside a subquery."""
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        Resolve against the outer query via F.resolve_expression().

        Raises NotSupportedError if the referenced expression contains a window
        (OVER) clause, and flags the result as possibly multivalued when the
        name spans a relation (contains LOOKUP_SEP).
        """
        resolved = super().resolve_expression(*args, **kwargs)  # type: ignore[misc]
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels: dict[str, str]) -> ResolvedOuterRef:
        # Outer references are not affected by relabeling the inner query.
        return self

    def get_group_by_cols(self) -> list[Any]:
        # Never contributes columns to the inner query's GROUP BY.
        return []
944
945
class OuterRef(F):
    """
    Reference to a field of an outer query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef for ``self.name``. When ``name`` is
    itself an OuterRef (nested outer references), that inner OuterRef is
    returned instead, peeling off one level per resolution.
    """

    # Neither flag is defined on F/Combinable, so declare both explicitly
    # (as ResolvedOuterRef does); an unresolved outer reference is never an
    # aggregate nor a window expression.
    contains_aggregate = False
    contains_over_clause = False

    def resolve_expression(self, *args: Any, **kwargs: Any) -> ResolvedOuterRef | F:
        if isinstance(self.name, self.__class__):
            return self.name
        return ResolvedOuterRef(self.name)

    def relabeled_clone(self, relabels: dict[str, str]) -> OuterRef:
        # Outer references are not affected by relabeling the inner query.
        return self
956
957
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(
        self, *expressions: Any, output_field: Field | None = None, **extra: Any
    ):
        """
        Arguments:
         * expressions: function arguments; strings become F() references and
           other non-expression values become Value() (see _parse_expressions)
         * output_field: explicit result type, otherwise inferred
         * extra: extra template context (e.g. function, template, arg_joiner)

        Raises TypeError when ``arity`` is set and the argument count differs.
        """
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self) -> str:
        rendered_args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({rendered_args})"
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({rendered_args}, {rendered_options})"

    def _get_repr_options(self) -> dict[str, Any]:
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self) -> list[Any]:
        return self.source_expressions

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expressions = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Func:
        """Return a copy whose arguments have all been resolved against query."""
        resolved = self.copy()
        resolved.is_summary = summarize
        # copy() gave the clone its own list, so updating it in place is safe.
        arguments = resolved.source_expressions
        for index, argument in enumerate(arguments):
            arguments[index] = argument.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return resolved

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Render the call by filling ``template`` with the compiled arguments.

        An argument that raises EmptyResultSet is replaced by its
        empty_result_set_value (re-raising when there is none); one that raises
        FullResultSet is replaced by Value(True).
        """

        def compile_argument(argument: Any) -> tuple[str, Sequence[Any]]:
            try:
                return compiler.compile(argument)
            except EmptyResultSet:
                fallback = getattr(argument, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                return compiler.compile(Value(fallback))
            except FullResultSet:
                return compiler.compile(Value(True))

        connection.ops.check_expression_support(self)
        argument_sqls = []
        params: list[Any] = []
        for argument in self.source_expressions:
            argument_sql, argument_params = compile_argument(argument)
            argument_sqls.append(argument_sql)
            params.extend(argument_params)

        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined_arguments = arg_joiner.join(argument_sqls)
        data["expressions"] = joined_arguments
        data["field"] = joined_arguments
        return template % data, params

    def copy(self) -> Func:
        """Shallow copy that also duplicates the argument list and extra dict."""
        clone = super().copy()  # type: ignore[misc]
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
1063
1064
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """A literal Python value wrapped so it can be used as an expression node."""

    # Class-level default so that unresolved instances can still be compiled
    # (pending the decision in #25425); resolve_expression() sets it per copy.
    for_save = False

    def __init__(self, value: Any, output_field: Field | None = None):
        """
        Arguments:
         * value: the Python value this node stands for. It is normally sent
           to the database as a query parameter; ``None`` is rendered as a
           literal NULL.

         * output_field: optional model field instance describing the value's
           type, such as IntegerField() or CharField(). When it is not given,
           _resolve_output_field() can infer one from the value's Python type.
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            # Saving and querying use different preparation hooks.
            prepare = field.get_db_prep_save if self.for_save else field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [prepared]
        if prepared is None:
            # A literal NULL instead of a None parameter: cx_Oracle does not
            # always bind None with a suitable NULL type (e.g. inside CASE
            # expressions over numbers).
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Value:
        """Resolve via the base class and record ``for_save`` on the result."""
        resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)  # type: ignore[misc]
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self) -> list[Any]:
        """A literal contributes no GROUP BY columns."""
        return []

    def _resolve_output_field(self) -> Field | None:
        """
        Infer an output field from the Python type of ``value``.

        The checks run in order, so ``bool`` is matched before ``int`` and
        ``datetime`` before ``date`` (each being a subclass of the latter).
        Any other type yields None.
        """
        inference_table = (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        )
        for python_type, field_class in inference_table:
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self) -> Any:
        """The wrapped value itself."""
        return self.value
1150
1151
class RawSQL(Expression):
    """
    A raw SQL fragment with its own parameters.

    The fragment is emitted wrapped in parentheses, and ``params`` is returned
    as-is alongside it.
    """

    def __init__(
        self, sql: str, params: Sequence[Any], output_field: Field | None = None
    ):
        self.sql = sql
        self.params = params
        # Without an explicit type, fall back to a generic Field.
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}({self.sql}, {self.params})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        wrapped = f"({self.sql})"
        return wrapped, self.params

    def get_group_by_cols(self) -> list[RawSQL]:
        """The whole fragment is its own GROUP BY column."""
        return [self]
1171
1172
class Star(Expression):
    """The SQL ``*`` wildcard; it compiles to ``*`` with no parameters."""

    def __repr__(self) -> str:
        return "'*'"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        wildcard = "*"
        return wildcard, []
1181
1182
class Col(Expression):
    """
    A reference to the column ``target.column``, optionally qualified by a
    table alias. ``output_field`` defaults to ``target``.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(
        self, alias: str | None, target: Any, output_field: Field | None = None
    ):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self) -> str:
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        names = [self.target.column]
        if self.alias:
            names.insert(0, self.alias)
        # Render as ``alias.column`` (or just ``column``), each part quoted
        # unless the compiler treats it as an alias.
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels: dict[str, str]) -> Col:
        """Return a new Col whose alias is remapped via ``relabels``."""
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self) -> list[Col]:
        return [self]

    def get_db_converters(
        self, connection: BaseDatabaseWrapper
    ) -> list[Callable[..., Any]]:
        """
        Converters of the output field, followed by the target's converters
        when the target is a different field.
        """
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if not same_field:
            converters = converters + self.target.get_db_converters(connection)
        return converters
1226
1227
class Ref(Expression):
    """
    Reference to a column alias of the query, e.g. ``Ref('sum_cost')`` in
    ``qs.annotate(sum_cost=Sum('cost'))``.

    ``source`` is the expression the alias names; only the quoted alias
    (``refs``) is emitted as SQL.
    """

    def __init__(self, refs: str, source: Any):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}({self.refs}, {self.source})"

    def get_source_expressions(self) -> list[Any]:
        return [self.source]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        # Exactly one expression is expected; unpacking enforces that.
        [self.source] = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Ref:
        """Return self: the referenced ``source`` is already resolved."""
        return self

    def get_refs(self) -> set[str]:
        return {self.refs}

    def relabeled_clone(self, relabels: dict[str, str]) -> Ref:
        return self

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self) -> list[Ref]:
        return [self]
1272
1273
class ExpressionList(Func):
    """
    A bare list of expressions rendered without a surrounding function call.

    Handy as an argument to another expression, for instance the PARTITION BY
    list of a window.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        if not expressions:
            raise ValueError(
                f"{type(self).__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self) -> str:
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        # Bypass the NUMERIC cast that SQLiteNumericMixin.as_sqlite() would add.
        return self.as_sql(compiler, connection, **extra_context)
1301
1302
class OrderByList(Func):
    """
    An ``ORDER BY`` clause built from a sequence of orderings.

    Strings starting with ``-`` become descending ``OrderBy`` expressions over
    the remaining name; every other item is passed on to ``Func`` unchanged.
    With no orderings at all the clause compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any):
        # ``startswith`` instead of ``expr[0]`` so an empty string doesn't raise
        # IndexError here; it is handed on unchanged like any other name.
        normalized = tuple(
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*normalized, **extra)

    def as_sql(self, *args: Any, **kwargs: Any) -> tuple[str, tuple[Any, ...]]:
        # No orderings: emit nothing rather than a dangling "ORDER BY ".
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)  # type: ignore[misc]

    def get_group_by_cols(self) -> list[Any]:
        """Collect the GROUP BY columns of every ordering."""
        group_by_cols = []
        for order_by in self.get_source_expressions():
            group_by_cols.extend(order_by.get_group_by_cols())
        return group_by_cols
1327
1328
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression to give it extra context, most notably an
    explicit ``output_field``. The wrapper compiles to the inner SQL unchanged.
    """

    def __init__(self, expression: Any, output_field: Field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def get_group_by_cols(self) -> list[Any]:
        if not isinstance(self.expression, Expression):
            # A non-expression (e.g. an SQL WHERE clause) has to go into the
            # GROUP BY clause as a whole.
            return super().get_group_by_cols()  # type: ignore[misc]
        # Ask a copy of the inner expression, carrying this wrapper's output
        # field, so the original expression isn't modified.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        compiled = compiler.compile(self.expression)
        return compiled

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}({self.expression})"
1362
1363
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression (boolean output)."""

    def __init__(self, expression: Any):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self) -> Any:
        # Negating a negation gives back (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner expression matches nothing, so its negation is always
            # true: emit a literal true (or "1=1" where booleans can't be
            # selected).
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        # Some backends (e.g. Oracle) only accept EXISTS() and filters in a
        # comparison when wrapped in CASE WHEN.
        if compiler.connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> NegatedExpression:
        """
        Resolve the wrapped expression.

        Raises TypeError if the resolved inner expression is not conditional.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Make the negation usable in a SELECT / GROUP BY list.

        Backends without boolean-expression support there (e.g. Oracle) get a
        ``CASE WHEN ... THEN 1 ELSE 0 END`` wrapper, unless as_sql() already
        produced a CASE WHEN for this expression.
        """
        connection = compiler.connection
        supported_in_where = connection.ops.conditional_expression_supported_in_where_clause
        if connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        if not supported_in_where(self.expression):
            # Already wrapped by as_sql(); don't wrap twice.
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1422
1423
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch, to be used inside Case().

    The condition may be a Q object, a conditional expression, or keyword
    lookups (combined into a Q object). ``then`` is parsed into an expression
    that becomes the branch result.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition: Any = None, then: Any = None, **lookups: Any):
        """
        Build the branch from ``condition`` and/or keyword ``lookups``.

        Lookups alone become ``Q(**lookups)``; lookups together with a
        conditional ``condition`` become ``Q(condition, **lookups)``.

        Raises TypeError when there is no condition, the condition is not
        conditional, or lookups are given alongside a non-conditional
        condition. Raises ValueError for an empty Q().
        """
        lookups_dict: dict[str, Any] | None = lookups or None
        if lookups_dict:
            if condition is None:
                condition, lookups_dict = Q(**lookups_dict), None
            elif getattr(condition, "conditional", False):
                condition, lookups_dict = Q(condition, **lookups_dict), None
        # lookups_dict is still set only if the lookups could not be merged
        # into the condition above.
        if (
            condition is None
            or not getattr(condition, "conditional", False)
            or lookups_dict
        ):
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self) -> str:
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        return [self.condition, self.result]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.condition, self.result = exprs

    def get_source_fields(self) -> list[Field | None]:
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> When:
        """
        Return a resolved copy.

        The condition (when resolvable) is always resolved with
        ``for_save=False``; only the result receives the caller's ``for_save``.
        """
        c = self.copy()
        c.is_summary = summarize
        if hasattr(c.condition, "resolve_expression"):
            c.condition = c.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        c.result = c.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Render the branch. Params are the condition's followed by the
        result's. Compilation errors from either part, such as EmptyResultSet,
        propagate to the caller.
        """
        connection.ops.check_expression_support(self)
        template_params = extra_context
        # Never extended in this method; contributes nothing to the params.
        sql_params = []
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        template = template or self.template
        return template % template_params, (
            *sql_params,
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self) -> list[Any]:
        # This is not a complete expression and cannot be used in GROUP BY.
        # It therefore reports the GROUP BY columns of its condition and
        # result instead of itself.
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols
1514
1515
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(
        self,
        *cases: When,
        default: Any = None,
        output_field: Field | None = None,
        **extra: Any,
    ):
        """
        Arguments:
         * cases: When() branches, in evaluation order.
         * default: value or expression for the ELSE part (parsed into an
           expression).
         * output_field: optional explicit output field.
         * extra: extra template values (may include ``template``).

        Raises TypeError if any positional argument is not a When instance.
        """
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self) -> str:
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self) -> list[Any]:
        # All branches followed by the default; set_source_expressions()
        # relies on the default being last.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        *self.cases, self.default = exprs

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Case:
        """Return a copy with every branch and the default resolved."""
        c = self.copy()
        c.is_summary = summarize
        # Safe to assign by index: copy() gives c its own cases list.
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self) -> Case:
        c = super().copy()  # type: ignore[misc]
        # Give the copy its own list so per-branch replacement doesn't touch
        # self.cases.
        c.cases = c.cases[:]
        return c

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        case_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the CASE expression.

        Branches raising EmptyResultSet are dropped. The first branch raising
        FullResultSet makes its result the default and ends the branch list.
        When no branch remains, only the default's SQL is returned.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            # No branches: the CASE reduces to its default.
            sql, params = compiler.compile(self.default)
            return sql, list(params)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; leave it out.
                continue
            except FullResultSet:
                # This branch always matches, so its result is effectively
                # the ELSE value and later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, list(default_params)
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Default params come after all branch params, matching the template.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            # Wrap in the backend's unification cast for the output field.
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self) -> list[Any]:
        # Without branches the expression compiles to just the default, so
        # group by whatever the default needs.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()  # type: ignore[misc]
1628
1629
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.

    Accepts either a QuerySet or an ``sql.Query``. The underlying query is
    cloned, and the ``subquery`` flag is set on the clone.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(
        self,
        query: QuerySet[Any] | Query,
        output_field: Field | None = None,
        **extra: Any,
    ):
        # Import here to avoid circular import
        from plain.models.sql.query import Query

        # A QuerySet contributes its ``sql_query``; an sql.Query is used directly.
        source_query = query if isinstance(query, Query) else query.sql_query
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self) -> list[Any]:
        return [self.query]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.query = exprs[0]

    def _resolve_output_field(self) -> Field | None:
        """The output field is the inner query's."""
        return self.query.output_field

    def copy(self) -> Subquery:
        """Return a copy that owns a clone of the inner query."""
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self) -> list[str]:
        return self.query.external_aliases

    def get_external_cols(self) -> list[Any]:
        return self.query.get_external_cols()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the compiled query (presumably
        # its wrapping parentheses); the template supplies its own.
        context["subquery"] = subquery_sql[1:-1]
        if not template:
            template = context.get("template", self.template)
        return template % context, sql_params

    def get_group_by_cols(self) -> list[Any]:
        return self.query.get_group_by_cols(wrapper=self)
1700
1701
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery, with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, query: QuerySet[Any] | Query, **kwargs: Any):
        super().__init__(query, **kwargs)
        # Swap the cloned query for its ``exists()`` form.
        self.query = self.query.exists()

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        """
        Return the SQL unchanged where boolean expressions are allowed in
        SELECT / GROUP BY; otherwise (e.g. Oracle) wrap it in CASE WHEN.
        """
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1720
1721
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    One ORDER BY term, ``<expression> ASC|DESC``, optionally with an explicit
    NULL placement (``nulls_first`` / ``nulls_last``).
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(
        self,
        expression: Any,
        descending: bool = False,
        nulls_first: bool | None = None,
        nulls_last: bool | None = None,
    ):
        """
        Arguments:
         * expression: the expression to order by; it must have a
           ``resolve_expression`` attribute.
         * descending: order descending when true.
         * nulls_first / nulls_last: True to request that placement of NULLs,
           None to leave it to the database. At most one may be True.

        Raises ValueError when both NULL options are set, when either is
        passed as False, or when ``expression`` is not an expression.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.expression = exprs[0]

    def get_source_expressions(self) -> list[Any]:
        return [self.expression]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile the ordering term.

        Where the backend has a NULLS FIRST/LAST modifier it is appended.
        Otherwise the placement is emulated by first sorting on
        ``<expression> IS NULL`` / ``IS NOT NULL``. The emulation is skipped
        when the backend's own NULL ordering already gives the requested
        placement (``order_by_nulls_first`` presumably means "NULLs sort first
        in ascending order" — confirm against the backend features).
        """
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression SQL appears once per ``%(expression)s`` placeholder,
        # so its params are repeated to match. Build a new sequence instead of
        # using ``*=``: on a list that would mutate in place the object the
        # compiled expression returned. That object may be the expression's
        # own state (RawSQL.as_sql() returns its ``params`` attribute), which
        # would corrupt it for every later compilation.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self) -> list[Any]:
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self) -> OrderBy:
        """
        Flip the direction and swap any requested NULL placement, in place.
        Returns self.
        """
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self) -> None:
        """Switch to ascending order (in place)."""
        self.descending = False

    def desc(self) -> None:
        """Switch to descending order (in place)."""
        self.descending = True
1807
1808
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call rendered as ``<expression> OVER (<window>)``.

    ``<window>`` is assembled from an optional PARTITION BY list, an optional
    ORDER BY list and an optional frame clause, in that order.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression: Any,
        partition_by: Any = None,
        order_by: Any = None,
        frame: Any = None,
        output_field: Field | None = None,
    ):
        """
        Arguments:
         * expression: the expression evaluated over the window; it must have
           a truthy ``window_compatible`` attribute.
         * partition_by: a single value or a tuple/list of values, wrapped in
           an ExpressionList.
         * order_by: a string field reference ("-name" for descending), an
           expression, or a list/tuple of those, wrapped in an OrderByList.
         * frame: optional frame clause (e.g. RowRange or ValueRange).
         * output_field: optional explicit output field. When omitted,
           _resolve_output_field() falls back to the source expression's.

        Raises ValueError if ``expression`` isn't window compatible or
        ``order_by`` has an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self) -> Field | None:
        return self.source_expression.output_field

    def get_source_expressions(self) -> list[Any]:
        # Fixed 4-slot layout (unset parts are None); must stay in sync with
        # set_source_expressions().
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to ``<expression> OVER (...)``.

        Params are the source expression's followed by those of the partition,
        ordering and frame clauses. Raises NotSupportedError when the backend
        lacks OVER clause support.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            # Compiled directly (not via compiler.compile) so the ExpressionList
            # can be given a PARTITION BY template.
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """
        SQLite variant.

        For a DecimalField output, the NUMERIC cast is applied to the whole
        window expression rather than inside it. The inner expression is
        compiled with a FloatField output.
        """
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            # NOTE(review): if copy() is shallow, this also changes the output
            # field of self.source_expression (the same object) — confirm.
            source_expressions[0].output_field = fields.FloatField()
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self) -> str:
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        # Only the partition and ordering clauses contribute; the source
        # expression itself is deliberately excluded (see contains_aggregate).
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1929
1930
class WindowFrame(Expression):
    """
    The frame clause of a window expression:
    ``<frame_type> BETWEEN <start> AND <end>``.

    Concrete subclasses (RowRange, ValueRange) set ``frame_type`` and implement
    window_frame_start_end(), which delegates rendering and validation of the
    integer bounds to the backend. In __str__(), a negative start means
    "N PRECEDING", a positive end "N FOLLOWING", 0 "CURRENT ROW", and None an
    unbounded edge. With no end given, the frame therefore displays as
    UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start: int | None = None, end: int | None = None):
        # Both bounds are stored as Value nodes so they show up as source
        # expressions.
        self.start, self.end = Value(start), Value(end)

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        self.start, self.end = exprs

    def get_source_expressions(self) -> list[Any]:
        return [self.start, self.end]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        rendered = self.template % {
            "frame_type": self.frame_type,
            "start": start_sql,
            "end": end_sql,
        }
        return rendered, []

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self) -> list[Any]:
        """A frame clause contributes no GROUP BY columns."""
        return []

    def __str__(self) -> str:
        ops = db_connection.ops
        lower, upper = self.start.value, self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = f"{abs(lower)} {ops.PRECEDING}"
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = f"{upper} {ops.FOLLOWING}"
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        """Return the (start, end) SQL for this frame type; subclass hook."""
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
1999
2000
class RowRange(WindowFrame):
    """A ``ROWS BETWEEN ... AND ...`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend's operations render (and validate) ROWS bounds.
        rows_bounds = connection.ops.window_frame_rows_start_end
        return rows_bounds(start, end)
2008
2009
class ValueRange(WindowFrame):
    """A ``RANGE BETWEEN ... AND ...`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(
        self, connection: BaseDatabaseWrapper, start: int | None, end: int | None
    ) -> tuple[str, str]:
        # The backend's operations render (and validate) RANGE bounds.
        range_bounds = connection.ops.window_frame_range_start_end
        return range_bounds(start, end)