1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from functools import cached_property
8from types import NoneType
9from uuid import UUID
10
11from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
12from plain.models import fields
13from plain.models.constants import LOOKUP_SEP
14from plain.models.db import (
15 DatabaseError,
16 NotSupportedError,
17 db_connection,
18)
19from plain.models.query_utils import Q
20from plain.utils.deconstruct import deconstructible
21from plain.utils.hashable import make_hashable
22
23
class SQLiteNumericMixin:
    """
    Mixin adding an ``as_sqlite()`` hook for numeric expressions.

    When the expression's output_field is a DecimalField, the compiled SQL is
    wrapped in ``CAST(... AS NUMERIC)`` so SQLite filters it correctly.
    Expressions whose output field cannot be resolved are left untouched.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        compiled_sql, compiled_params = self.as_sql(
            compiler, connection, **extra_context
        )
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # No resolvable output type: emit the SQL unchanged.
            needs_cast = False
        if needs_cast:
            compiled_sql = f"CAST({compiled_sql} AS NUMERIC)"
        return compiled_sql, compiled_params
38
39
class Combinable:
    """
    Operator-overloading mixin for query expressions.

    Arithmetic operators (``+ - * / % **``, their reflected forms, and unary
    ``-``) build a CombinedExpression, e.g. ``F('foo') + F('bar')``; operands
    without a resolve_expression() method are wrapped in Value() first.
    ``&``, ``|`` and ``^`` only combine two conditional expressions (into Q
    objects); bitwise arithmetic must go through the ``bit*()`` methods.
    ``~`` yields a NegatedExpression.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The % operator is written doubled ("%%") because connector strings may
    # end up in SQL that also goes through %-style parameter substitution.
    MOD = "%%"

    # Bitwise connectors - these are produced by .bitand(), .bitor(), etc.;
    # the Python '&' and '|' operators are reserved for boolean logic.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Return ``self <connector> other`` as a CombinedExpression, or
        ``other <connector> self`` when ``reversed`` is true.
        """
        if hasattr(other, "resolve_expression"):
            operand = other
        else:
            # Everything must be resolvable to an expression.
            operand = Value(other)
        if not reversed:
            return CombinedExpression(self, connector, operand)
        return CombinedExpression(operand, connector, self)

    #############
    # OPERATORS #
    #############

    def __neg__(self):
        # Unary minus is expressed as multiplication by -1.
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    def __and__(self, other):
        # Only two conditional expressions may be AND-ed (as Q objects).
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)

    def __xor__(self, other):
        # Only two conditional expressions may be XOR-ed (as Q objects).
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def __or__(self, other):
        # Only two conditional expressions may be OR-ed (as Q objects).
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    # Reflected logical operators are always rejected.

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self):
        return NegatedExpression(self)
170
171
class BaseExpression:
    """
    Base class for all query expressions.

    An expression exposes its children through get_source_expressions() /
    set_source_expressions() and renders itself with as_sql(). The generic
    tree walkers below (resolve_expression, relabeled_clone,
    replace_expressions, get_refs, prefix_references, get_group_by_cols,
    flatten) skip child slots that hold None.
    """

    # Value compiled in place of this expression when compiling it raises
    # EmptyResultSet (see e.g. Func.as_sql()); NotImplemented means there is
    # no substitute and the exception propagates.
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by the output_field property when _resolve_output_field() returned
    # None, so _output_field_or_none can tell "no inferable type" apart from
    # other FieldErrors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance, where it takes
        # precedence over the lazily inferred output_field cached_property.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        # convert_value is a cached_property that may hold a lambda, which
        # pickle cannot serialize; drop it so it is recomputed on demand.
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters for values read back from the database: this
        expression's convert_value (omitted when it is the no-op) followed by
        the output field's own converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions (a leaf expression has none)."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; a leaf only accepts an empty list."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Coerce each argument to an expression: objects that already have
        resolve_expression() are kept, strings become F() references and any
        other value is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """Whether any (non-None) child expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """Whether any (non-None) child expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """Whether any (non-None) child expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        # Note: for_save is not forwarded to the children here.
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """Whether this expression produces a boolean value."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias for output_field."""
        return self.output_field

    @cached_property
    def output_field(self):
        """
        Return the output type of this expression.

        Raises FieldError when the type cannot be inferred.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.

        Any other FieldError (e.g. mixed source types) propagates.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops consume the same generator: the outer loop takes the
        # first non-None field, the inner loop checks every remaining field
        # against it, and the outer loop then returns after one iteration.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose non-None children are relabeled with change_map."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return ``replacements[self]`` if present (and truthy); otherwise a copy
        whose children have been replaced recursively.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over the non-None children."""
        refs = set()
        for expr in self.get_source_expressions():
            # Child slots may be None, as the other tree walkers allow.
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every direct F() child is renamed to
        ``prefix + name`` and other non-None children are prefixed recursively.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                None
                if expr is None
                else (
                    F(f"{prefix}{expr.name}")
                    if isinstance(expr, F)
                    else expr.prefix_references(prefix)
                )
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return the expressions to add to GROUP BY: this expression itself when
        it contains no aggregate, otherwise the group-by columns of its
        non-None children.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            if source is not None:
                cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the output field (or None) of each child expression."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses, delegated to the output field's
        select_format() when it defines one.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
470
471
@deconstructible
class Expression(BaseExpression, Combinable):
    """
    An expression that can be combined with other expressions.

    Equality and hashing are defined by ``identity``: the class together with
    the arguments the instance was constructed with.
    """

    @cached_property
    def identity(self):
        """
        Return a hashable tuple identifying this expression.

        The tuple is the class followed by one ``(parameter name, value)``
        pair per ``__init__`` parameter, with defaults filled in. A Field
        argument becomes ``(model label, field name)`` when both its name and
        model are set, otherwise its type; every other value is passed
        through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        # NOTE(review): _constructor_args is presumably recorded by the
        # @deconstructible decorator; it is not set anywhere in this class.
        init_args, init_kwargs = self._constructor_args
        bound = init_signature.bind_partial(*init_args, **init_kwargs)
        bound.apply_defaults()

        def normalize(value):
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model._meta.label, value.name)
            return type(value)

        return (
            self.__class__,
            *((name, normalize(value)) for name, value in bound.arguments.items()),
        )

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
502
503
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# return NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each entry maps a connector to a list of (lhs_type, rhs_type, result_type)
# rules. The rules are registered in the order listed here, and
# _resolve_combined_type() returns the first rule whose types match via
# issubclass(), so reordering entries can change the inferred type.

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type. The result takes the
    # non-integer operand's type. POW is not included here.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators (integers only).
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL. NoneType is the type of an operand whose output
    # field resolved to None; the result keeps the numeric operand's type.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
607
# Registry of type-inference rules: connector -> list of
# (lhs_type, rhs_type, result_type) tuples, in registration order.
_connector_combinators = defaultdict(list)
609
610
def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )

    Rules are tried in registration order and the first match wins, so a rule
    cannot override an earlier rule matching the same operand types.
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type() memoizes its answers (misses included), so a
    # rule added after some lookups would otherwise be ignored for those
    # operand types. The resolver is defined later in this module, so it does
    # not exist yet while the built-in rules are registered at import time.
    resolver = globals().get("_resolve_combined_type")
    if resolver is not None:
        resolver.cache_clear()
621
622
def _register_default_combinators():
    """Register every built-in rule listed in _connector_combinations."""
    # Looping inside a function keeps the loop variables (formerly `d`,
    # `connector`, `field_types`, `lhs`, `rhs`, `result`) from leaking into
    # the module namespace as stray globals.
    for rules_by_connector in _connector_combinations:
        for connector, field_types in rules_by_connector.items():
            for lhs, rhs, result in field_types:
                register_combinable_fields(lhs, connector, rhs, result)


_register_default_combinators()
627
628
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the result field class for ``lhs_type <connector> rhs_type``.

    Rules registered for ``connector`` are tried in order and the first one
    whose operand types are superclasses of (or equal to) the given types
    wins. Returns None when no rule matches. Results are memoized.
    """
    candidate_rules = _connector_combinators.get(connector, ())
    return next(
        (
            result_type
            for rule_lhs, rule_rhs, result_type in candidate_rules
            if issubclass(lhs_type, rule_lhs) and issubclass(rhs_type, rule_rhs)
        ),
        None,
    )
637
638
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    A binary expression ``lhs <connector> rhs`` such as ``F("a") + F("b")``.

    Unless output_field is given, the output type is looked up in the
    registered connector combinations. On resolution, an expression mixing a
    DurationField operand with a non-duration operand becomes a
    DurationExpression, and a subtraction of two values of the same
    date/datetime/time type becomes a TemporalSubtraction.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Infer the output field from the connector and operand types.

        Raises FieldError when no registered combination matches.
        """
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field()
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the operand types already computed. Re-reading
            # self.lhs.output_field here would raise an unrelated "Cannot
            # resolve expression type" FieldError for an operand whose type
            # resolved to None (e.g. Value(None)), hiding this message.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, {rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(self, compiler, connection):
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # order of precedence
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve both operands, possibly switching to a DurationExpression or
        TemporalSubtraction based on the resolved operand types.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        # The specialized subclasses are already the result of this dispatch.
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
735
736
class DurationExpression(CombinedExpression):
    """
    A CombinedExpression involving a DurationField operand.

    On backends without a native duration type, DurationField operands are
    passed through connection.ops.format_for_duration_arithmetic() and the
    two operands are joined by connection.ops.combine_duration_expression().
    """

    def compile(self, side, compiler, connection):
        """
        Compile one operand, formatting it for duration arithmetic when its
        output field is a DurationField.
        """
        try:
            output = side.output_field
        except FieldError:
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        sql, params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(sql), params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        operand_sqls = []
        all_params = []
        for operand in (self.lhs, self.rhs):
            operand_sql, operand_params = self.compile(operand, compiler, connection)
            operand_sqls.append(operand_sql)
            all_params.extend(operand_params)
        combined = connection.ops.combine_duration_expression(
            self.connector, operand_sqls
        )
        # Parenthesize to keep the operator precedence when nested.
        return f"({combined})", all_params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            return sql, params
        numeric_like = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        # Multiplying/dividing is only allowed between numeric/duration types.
        if not (lhs_type in numeric_like and rhs_type in numeric_like):
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
786
787
class TemporalSubtraction(CombinedExpression):
    """
    Subtraction of two date/datetime/time values, yielding a duration.

    SQL generation is delegated to connection.ops.subtract_temporals(), which
    receives the left operand's internal type and both compiled operands.
    """

    # A single DurationField instance shared by every TemporalSubtraction.
    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        lhs_internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            lhs_internal_type, compiled_lhs, compiled_rhs
        )
801
802
@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    A reference to a named value of the query, e.g. ``F("price") * 2``.

    The name is resolved lazily through ``query.resolve_ref()``. Two F objects
    are equal when they have the same class and name.
    """

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{type(self).__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        # Swap in a replacement if one is registered for this reference.
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
839
840
class ResolvedOuterRef(F):
    """
    A reference to a column of an outer query, produced when an OuterRef is
    resolved because the inner query is being used as a subquery.

    It cannot be compiled on its own: as_sql() always raises.
    """

    # F is not a BaseExpression, so these flags are declared explicitly.
    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Outer references are not affected by relabeling of the inner query.
        return self

    def get_group_by_cols(self):
        return []
876
877
class OuterRef(F):
    """
    A reference to a field of the outer query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. If ``name`` is itself an instance
    of this class (e.g. ``OuterRef(OuterRef("x"))``), that inner reference is
    returned unresolved instead.
    """

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        if not isinstance(target, self.__class__):
            return ResolvedOuterRef(target)
        return target

    def relabeled_clone(self, relabels):
        return self
888
889
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call.

    The SQL is rendered as ``template % data``. ``data`` holds ``function``,
    ``expressions`` (the compiled arguments joined with ``arg_joiner``, also
    exposed as ``field``), plus any extra keyword arguments given to
    __init__() or as_sql().
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """
        Arguments:
         * expressions: the function arguments; strings become F()
           references and other non-expression values are wrapped in Value().
         * output_field: optional explicit output field.
         * extra: extra template values; may also override ``function``,
           ``template`` or ``arg_joiner``.

        Raises TypeError when ``arity`` is set and a different number of
        expressions is given.
        """
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        rendered_args = self.arg_joiner.join(
            str(arg) for arg in self.source_expressions
        )
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({rendered_args})"
        rendered_options = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({rendered_args}, {rendered_options})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        # copy() gave the clone its own list, so it is updated in place.
        for index, source in enumerate(clone.source_expressions):
            clone.source_expressions[index] = source.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return clone

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)
        compiled_args = []
        collected_params = []
        for source in self.source_expressions:
            try:
                source_sql, source_params = compiler.compile(source)
            except EmptyResultSet:
                # Substitute the argument's declared fallback, if it has one.
                fallback = getattr(source, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                source_sql, source_params = compiler.compile(Value(fallback))
            except FullResultSet:
                source_sql, source_params = compiler.compile(Value(True))
            compiled_args.append(source_sql)
            collected_params.extend(source_params)

        context = {**self.extra, **extra_context}
        # Precedence for each setting: the argument passed to this method,
        # then a value from __init__()'s **extra (already in `context`), then
        # the class attribute.
        if function is None:
            context.setdefault("function", self.function)
        else:
            context["function"] = function
        chosen_template = template or context.get("template", self.template)
        joiner = arg_joiner or context.get("arg_joiner", self.arg_joiner)
        context["expressions"] = context["field"] = joiner.join(compiled_args)
        return chosen_template % context, collected_params

    def copy(self):
        # Give the clone its own argument list and extra dict so mutating
        # one does not affect the other.
        duplicate = super().copy()
        duplicate.source_expressions = self.source_expressions[:]
        duplicate.extra = self.extra.copy()
        return duplicate
988
989
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """A literal Python value used as a node within an expression."""

    # Provide a default value for `for_save` in order to allow unresolved
    # instances to be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the Python value to represent. It is normally passed to the
           database as a query parameter; None compiles to a literal NULL.
         * output_field: an optional field instance describing the value,
           such as IntegerField() or CharField(). When omitted, the type is
           inferred from the Python type of ``value`` where possible.
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            prepare = (
                field.get_db_prep_save if self.for_save else field.get_db_prep_value
            )
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                return field.get_placeholder(prepared, compiler, connection), [
                    prepared
                ]
        if prepared is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        # A constant never needs to appear in GROUP BY.
        return []

    def _resolve_output_field(self):
        # Checked in order: bool must precede int and datetime must precede
        # date, because each is a subclass of the latter.
        for python_type, field_class in (
            (str, fields.CharField),
            (bool, fields.BooleanField),
            (int, fields.IntegerField),
            (float, fields.FloatField),
            (datetime.datetime, fields.DateTimeField),
            (datetime.date, fields.DateField),
            (datetime.time, fields.TimeField),
            (datetime.timedelta, fields.DurationField),
            (Decimal, fields.DecimalField),
            (bytes, fields.BinaryField),
            (UUID, fields.UUIDField),
        ):
            if isinstance(self.value, python_type):
                return field_class()
        return None

    @property
    def empty_result_set_value(self):
        return self.value
1068
1069
class RawSQL(Expression):
    """
    A raw SQL fragment plus its parameters, emitted wrapped in parentheses.

    When no ``output_field`` is given, a generic ``fields.Field()`` is used.
    The fragment is also its own GROUP BY column.
    """

    def __init__(self, sql, params, output_field=None):
        self.sql = sql
        self.params = params
        super().__init__(
            output_field=fields.Field() if output_field is None else output_field
        )

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        # NOTE: the params object is returned as-is (not copied).
        return "({})".format(self.sql), self.params

    def get_group_by_cols(self):
        return [self]
1085
1086
class Star(Expression):
    """Render the bare SQL ``*`` token, with no parameters."""

    def __repr__(self):
        return repr("*")

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params
1093
1094
class Col(Expression):
    """
    A reference to a table column, optionally qualified by a table alias.

    ``target`` is the field that owns the column; it also serves as the
    output field unless a different one is supplied.
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        super().__init__(output_field=target if output_field is None else output_field)
        self.alias = alias
        self.target = target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        # Render "alias"."column" (or just "column" when unqualified).
        names = [self.target.column]
        if self.alias:
            names.insert(0, self.alias)
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        # Output-field converters run first; the target's converters are
        # appended only when the two fields differ.
        output_converters = self.output_field.get_db_converters(connection)
        if self.target == self.output_field:
            return output_converters
        return output_converters + self.target.get_db_converters(connection)
1132
1133
class Ref(Expression):
    """
    A reference to a column alias of the query, e.g. Ref('sum_cost') for the
    annotation in qs.annotate(sum_cost=Sum('cost')).

    ``refs`` is the alias name and ``source`` the already-resolved expression
    it names. The reference compiles to the quoted alias only.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # `source` was resolved before this reference was created; there is
        # nothing left to resolve.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        # The alias name does not depend on table aliases.
        return self

    def as_sql(self, compiler, connection):
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
1171
1172
class ExpressionList(Func):
    """
    A bare sequence of expressions joined by ``arg_joiner``, for use as an
    argument to another expression (e.g. a PARTITION BY clause).

    At least one expression is required.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # A list of expressions is never cast to NUMERIC, so bypass any
        # SQLite-specific wrapping and compile it directly.
        return self.as_sql(compiler, connection, **extra_context)
1195
1196
class OrderByList(Func):
    """
    A list of ordering terms rendered through the ``ORDER BY`` template.

    String arguments starting with ``-`` become descending
    ``OrderBy(F(name))`` terms; every other argument is passed to Func as-is.
    An empty list compiles to an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        expressions = (
            # startswith() (unlike expr[0]) does not raise IndexError on "".
            OrderBy(F(expr[1:]), descending=True)
            if isinstance(expr, str) and expr.startswith("-")
            else expr
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        # Union of the GROUP BY columns of every ordering term, in order.
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1221
1222
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap an expression to attach extra context to it, most commonly an
    explicit ``output_field``. The wrapper compiles to the inner SQL as-is.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Ask a copy of the inner expression, carrying this wrapper's output
        # field, so the original inner expression is left untouched.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.expression)
1254
1255
class NegatedExpression(ExpressionWrapper):
    """Logical NOT of a conditional expression; its output is a BooleanField."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation gives back (a copy of) the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner condition can never match, so its negation always does.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause
        )
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        if supported_in_where(self.expression):
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if getattr(resolved.expression, "conditional", False):
            return resolved
        raise TypeError("Cannot negate non-conditional expressions.")

    def select_format(self, compiler, sql, params):
        connection = compiler.connection
        if connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends (e.g. Oracle) that can't select a boolean expression get a
        # CASE WHEN; skip it when as_sql() already produced one, to avoid
        # double wrapping.
        if connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1305
1306
@deconstructible(path="plain.models.When")
class When(Expression):
    """
    One ``WHEN <condition> THEN <result>`` branch of a Case().

    The condition may be a Q object, a conditional expression, or keyword
    lookups. Lookups are folded into a Q object, and are combined with the
    condition when one is also given.
    """

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        if lookups:
            if condition is None:
                condition, lookups = Q(**lookups), None
            elif getattr(condition, "conditional", False):
                condition, lookups = Q(condition, **lookups), None
        # Leftover lookups mean they could not be merged with `condition`.
        is_valid = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not lookups
        )
        if not is_valid:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN {!r} THEN {!r}".format(self.condition, self.result)

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result determines the output type, not the condition.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is never resolved "for save"; only the result is.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY;
        # report the columns of its condition and result instead.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1381
1382
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() instances; ``default`` supplies the
    ELSE branch. Extra keyword arguments are kept in ``self.extra`` and merged
    into the template parameters by as_sql().
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    # Separator placed between the rendered WHEN clauses.
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        # _parse_expressions() normalises `default` to an expression;
        # resolve_expression() and as_sql() rely on it being one.
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE {}, ELSE {!r}".format(
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        # All WHEN branches followed by the default, as one flat list.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # Inverse of get_source_expressions(): the last item is the default.
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        # Give the clone its own list so resolve_expression() can replace
        # entries without touching the original's cases.
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        # With no WHEN branches the CASE collapses to its default.
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; drop it.
                continue
            except FullResultSet:
                # This branch always matches: its result becomes the ELSE
                # value and later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        # No WHEN clause survived (all dropped, or the first compiled one
        # always matched): emit only the default.
        if not case_parts:
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # Default params come last, matching the ELSE position in the template.
        sql_params.extend(default_params)
        # An explicit `template` argument wins; otherwise a "template" key in
        # extra/extra_context, then the class template.
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        # With a known output field, wrap the whole CASE using the backend's
        # unification_cast_sql() template.
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        # A Case without branches is just its default.
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()
1478
1479
class Subquery(BaseExpression, Combinable):
    """
    A nested query used as an expression. It may contain OuterRef()
    references to the outer query, which are resolved when the subquery is
    applied to that query.

    Accepts either a QuerySet (its ``.query`` is used) or a bare sql.Query;
    either way a clone is stored, so the caller's query is not modified.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        # Clone the inner query too so the copies don't share it.
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        context = {**self.extra, **extra_context}
        subquery_sql, sql_params = self.query.as_sql(compiler, connection)
        # Strip the first and last characters (presumably the parentheses
        # Query.as_sql() puts around a subquery) so the template decides the
        # wrapping.
        context["subquery"] = subquery_sql[1:-1]
        chosen_template = template or context.get("template", self.template)
        return chosen_template % context, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
1530
1531
class Exists(Subquery):
    """An ``EXISTS(...)`` test around a subquery, with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Swap the cloned query for its exists() form (presumably a trimmed
        # query suited to EXISTS()).
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Backends (e.g. Oracle) that can't select a bare boolean expression
        # in SELECT or GROUP BY get the EXISTS() wrapped in a CASE WHEN.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1548
1549
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    One term of an ORDER BY clause: an expression, its direction, and an
    optional NULLS FIRST / NULLS LAST placement.

    On backends without a native NULLS FIRST/LAST modifier, the placement is
    emulated by prefixing an ``<expression> IS [NOT] NULL`` sort key.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Arguments:
        * expression: the expression to sort by; must be resolvable.
        * descending: sort DESC instead of ASC.
        * nulls_first / nulls_last: True to force NULL placement, None to leave
          it to the database. They are mutually exclusive, and False is
          rejected.

        Raises ValueError for invalid null-placement flags or when
        ``expression`` has no resolve_expression().
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression}, descending={self.descending})"

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = f"{template} NULLS LAST"
            elif self.nulls_first:
                template = f"{template} NULLS FIRST"
        else:
            # Emulate the placement with an extra leading sort key. This is
            # skipped when `order_by_nulls_first` (presumably: NULLs sort first
            # by default) already yields the requested placement for this
            # direction.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NULL, {template}"
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = f"%(expression)s IS NOT NULL, {template}"
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression is rendered once per "%(expression)s" in the template,
        # so its params must be repeated to match.
        # Build a new sequence rather than using `*=`: on a list, `*=` extends
        # the object returned by compile() in place. Some expressions
        # (e.g. RawSQL.as_sql) return their own params list, so `*=` would
        # grow that list on every compilation.
        params = params * template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        """Flip the direction and swap NULLS FIRST/LAST in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        """Set ascending order in place."""
        self.descending = False

    def desc(self):
        """Set descending order in place."""
        self.descending = True
1623
1624
class Window(SQLiteNumericMixin, Expression):
    """
    A window function call rendered as ``<expression> OVER (<window>)``.

    ``partition_by`` (a single value or a tuple/list) is wrapped in an
    ExpressionList. ``order_by`` (a string, an expression, or a tuple/list of
    them) is wrapped in an OrderByList. ``frame`` is an optional frame clause
    expression (e.g. RowRange/ValueRange). ``expression`` must have a truthy
    ``window_compatible`` attribute.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            # Normalise a single partition term into a 1-tuple.
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        # The window has the same type as the windowed expression.
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        # PARTITION BY, ORDER BY and the frame are appended in that order and
        # joined with spaces inside OVER (...).
        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression.
            # The inner expression is given a FloatField output so that only
            # the outer SQLiteNumericMixin cast applies.
            # NOTE(review): if Expression.copy() is shallow (Case.copy and
            # Subquery.copy re-copy their mutable members, which suggests so),
            # source_expressions[0] is the same object as
            # self.source_expression. The assignment below would then also
            # change output_field on self's source expression. Confirm whether
            # that is intended.
            copy = self.copy()
            source_expressions = copy.get_source_expressions()
            source_expressions[0].output_field = fields.FloatField()
            copy.set_source_expressions(source_expressions)
            return super(Window, copy).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        # Debug rendering only. The partition, ordering and frame parts are
        # concatenated without separators.
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        # Only PARTITION BY and ORDER BY columns are reported, never the
        # windowed expression itself.
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1738
1739
class WindowFrame(Expression):
    """
    Base class for a window frame clause, rendered as
    ``<frame_type> BETWEEN <start> AND <end>``.

    ``start`` and ``end`` are stored wrapped in Value. Subclasses define
    ``frame_type`` and window_frame_start_end(), which converts the raw bounds
    into SQL through the connection's ops. Both bounds are optional. In
    __str__, a missing start reads as UNBOUNDED PRECEDING and a missing end as
    UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start, self.end = Value(start), Value(end)

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        start_sql, end_sql = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        frame = {"frame_type": self.frame_type, "start": start_sql, "end": end_sql}
        return self.template % frame, []

    def __repr__(self):
        return "<{}: {}>".format(self.__class__.__name__, self)

    def get_group_by_cols(self):
        # A frame clause contributes no GROUP BY columns.
        return []

    def __str__(self):
        # Uses the default db_connection's keywords. A negative start means
        # N PRECEDING and a positive end means N FOLLOWING; 0 is CURRENT ROW.
        # Anything else (including None) is UNBOUNDED.
        ops = db_connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is not None and lower < 0:
            start = f"{abs(lower)} {ops.PRECEDING}"
        elif lower is not None and lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is not None and upper > 0:
            end = f"{upper} {ops.FOLLOWING}"
        elif upper is not None and upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
1804
1805
class RowRange(WindowFrame):
    """A ``ROWS`` window frame; bounds are rendered by the backend's ops."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
1811
1812
class ValueRange(WindowFrame):
    """A ``RANGE`` window frame; bounds are rendered by the backend's ops."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)