plain.models
Model your data and store it in a database.
# app/users/models.py
from plain import models
from plain.passwords.models import PasswordField
class User(models.Model):
    """Example model: one row per user account."""

    # Unique across all users.
    email = models.EmailField(unique=True)
    password = PasswordField()
    # New users are not staff unless explicitly flagged.
    is_staff = models.BooleanField(default=False)
    # auto_now_add=True is passed so the value is filled in on creation.
    created_at = models.DateTimeField(auto_now_add=True)

    def __str__(self):
        return self.email
Create, update, and delete instances of your models:
from .models import User
# Create a new user
user = User.objects.create(
email="user@example.com",
password="password",
)
# Update a user
user.email = "new@example.com"
user.save()
# Delete a user
user.delete()
# Query for users
staff_users = User.objects.filter(is_staff=True)
Installation
# app/settings.py
INSTALLED_PACKAGES = [
...
"plain.models",
]
To connect to a database, you can provide a `DATABASE_URL` environment variable.
DATABASE_URL=postgresql://user:password@localhost:5432/dbname
Or you can manually define the `DATABASES` setting.
# app/settings.py
DATABASES = {
"default": {
"ENGINE": "plain.models.backends.postgresql",
"NAME": "dbname",
"USER": "user",
"PASSWORD": "password",
"HOST": "localhost",
"PORT": "5432",
}
}
Multiple backends are supported, including Postgres, MySQL, and SQLite.
Querying
Migrations
Fields
Validation
Indexes and constraints
Managers
Forms
1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from types import NoneType
8from uuid import UUID
9
10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
11from plain.models import fields
12from plain.models.constants import LOOKUP_SEP
13from plain.models.db import DatabaseError, NotSupportedError, connection
14from plain.models.query_utils import Q
15from plain.utils.deconstruct import deconstructible
16from plain.utils.functional import cached_property
17from plain.utils.hashable import make_hashable
18
19
class SQLiteNumericMixin:
    """
    SQLite helper: expressions whose output_field is a DecimalField are
    wrapped in CAST(... AS NUMERIC) so that filtering on them works.
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        """Return as_sql()'s (sql, params), casting decimal output to NUMERIC."""
        statement, bindings = self.as_sql(compiler, connection, **extra_context)
        try:
            is_decimal = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # The output type is unknown; leave the SQL untouched.
            is_decimal = False
        if is_decimal:
            statement = f"CAST({statement} AS NUMERIC)"
        return statement, bindings
34
35
class Combinable:
    """
    Mixin giving an object the ability to be joined with another operand
    through a connector, e.g. F('foo') + F('bar').
    """

    # Arithmetic connectors.
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # The modulo operator is stored as a quoted "%" because it can end up in
    # strings that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors. These are only produced by the explicit .bitand(),
    # .bitor(), ... methods; the '&', '|' and '^' Python operators are
    # reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """Build a CombinedExpression; `reversed` puts `other` on the left."""
        # Anything that is not already an expression is wrapped in Value.
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    @staticmethod
    def _unsupported_logical_operator():
        """Build the error raised when &, | or ^ is used on non-conditionals."""
        return NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def _both_conditional(self, other):
        """Tell whether both operands advertise a truthy `conditional`."""
        return getattr(self, "conditional", False) and getattr(
            other, "conditional", False
        )

    ########################
    # ARITHMETIC OPERATORS #
    ########################

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    ##################################
    # REFLECTED ARITHMETIC OPERATORS #
    ##################################

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    #####################
    # BOOLEAN OPERATORS #
    #####################

    def __and__(self, other):
        if not self._both_conditional(other):
            raise self._unsupported_logical_operator()
        return Q(self) & Q(other)

    def __or__(self, other):
        if not self._both_conditional(other):
            raise self._unsupported_logical_operator()
        return Q(self) | Q(other)

    def __xor__(self, other):
        if not self._both_conditional(other):
            raise self._unsupported_logical_operator()
        return Q(self) ^ Q(other)

    def __rand__(self, other):
        raise self._unsupported_logical_operator()

    def __ror__(self, other):
        raise self._unsupported_logical_operator()

    def __rxor__(self, other):
        raise self._unsupported_logical_operator()

    def __invert__(self):
        return NegatedExpression(self)

    #####################
    # BITWISE OPERATORS #
    #####################

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)
166
167
class BaseExpression:
    """Base class for all query expressions.

    Subclasses expose their children through get_source_expressions() and
    set_source_expressions() and render themselves with as_sql(). A source
    expression slot may be None; most tree-walking helpers here skip or pass
    through such empty slots.
    """

    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by `output_field` when _resolve_output_field() returned None, so
    # `_output_field_or_none` can tell that case apart from other FieldErrors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # Only override the lazily computed `output_field` when one is given.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        """Return the instance dict without the cached `convert_value`."""
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """Return this expression's converter (if not the no-op) followed by
        the output field's own converters."""
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; the base class has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base class accepts only an
        empty sequence."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """Coerce raw arguments to expressions: strings become F(),
        other non-expression values become Value()."""
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any non-empty source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """True if any non-empty source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """True if any non-empty source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression is about to be used in a save or
           update (not forwarded to the source expressions here)

        Return: an Expression to be added to the query.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """True when the output field is a BooleanField."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias of `output_field`."""
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expression.

        Raises FieldError when the type cannot be resolved.
        """
        output_field = self._resolve_output_field()
        if output_field is None:
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        """
        try:
            return self.output_field
        except FieldError:
            # Only swallow the "resolved to None" failure; any other
            # FieldError is a genuine error and propagates.
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops share one generator: the outer loop takes the first
        # source field, the inner loop consumes the rest and compares each
        # against it, and the method returns after that single outer pass.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        "Expression contains mixed types: {}, {}. You must "
                        "set output_field.".format(
                            output_field.__class__.__name__,
                            source.__class__.__name__,
                        )
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        """Converter that returns the database value unchanged."""
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        # Each converter passes None through and otherwise coerces the value
        # to the Python type matching the output field.
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose non-empty sources are relabeled with
        `change_map`."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """Return the replacement registered for this expression if there is a
        truthy one, else a copy with the replacements applied to its sources."""
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """Return the union of get_refs() over all source expressions.

        Empty (None) source slots are skipped, as the other tree-walking
        helpers of this class do, instead of raising AttributeError.
        """
        refs = set()
        for expr in self.get_source_expressions():
            if expr is None:
                continue
            refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """Return a copy in which every F() reference in the tree has
        `prefix` prepended to its name."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """Return the expressions to group by: this expression itself when it
        contains no aggregate, otherwise those of its sources."""
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the underlying field types used by this aggregate."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        """Return an OrderBy for this expression (default direction)."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        """Return an OrderBy for this expression with descending=True."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        """Return self; the base expression has no ordering to reverse."""
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
469
470
@deconstructible
class Expression(BaseExpression, Combinable):
    """A BaseExpression that also supports the Combinable operators."""

    @cached_property
    def identity(self):
        """Hashable tuple of the class and its bound constructor arguments.

        Field arguments are represented by (model label, field name) when the
        field has both a name and a model, otherwise by the field's type; all
        other arguments go through make_hashable().
        """
        init_signature = inspect.signature(self.__init__)
        args, kwargs = self._constructor_args
        bound = init_signature.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        parts = [self.__class__]
        for param_name, param_value in bound.arguments.items():
            if not isinstance(param_value, fields.Field):
                token = make_hashable(param_value)
            elif param_value.name and param_value.model:
                token = (param_value.model._meta.label, param_value.name)
            else:
                token = type(param_value)
            parts.append((param_name, token))
        return tuple(parts)

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
501
502
503# Type inference for CombinedExpression.output_field.
504# Missing items will result in FieldError, by design.
505#
506# The current approach for NULL is based on lowest common denominator behavior
507# i.e. if one of the supported databases is raising an error (rather than
508# return NULL) for `val <op> NULL`, then Plain raises FieldError.
509
# Each dict maps a connector to a list of (lhs_type, rhs_type, result_type)
# tuples. The dicts are flattened, in this order, into a per-connector
# registry by the registration loop that follows this table.
_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    # (POW is not listed for mixed operand types.)
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    # NOTE(review): `field_type` is the innermost loop and the dict key is
    # only `connector`, so each connector's entry is overwritten per field
    # type and ends up holding only the last one (FloatField) - confirm this
    # is intended.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]
606
# Registry: connector -> list of (lhs_type, rhs_type, result_type) tuples,
# kept in registration order.
_connector_combinators = defaultdict(list)


def register_combinable_fields(lhs, connector, rhs, result):
    """
    Record that combining two field types yields a given result type:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    combination = (lhs, rhs, result)
    _connector_combinators[connector].append(combination)
620
621
# Flatten the declarative table into the per-connector registry at import
# time, preserving the order in which the combinations are listed.
for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
626
627
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """Return the result type registered for `lhs_type <connector> rhs_type`.

    The first registered combination whose operand types are base classes of
    (or equal to) the given types wins. Returns None when nothing matches.
    """
    registered = _connector_combinators.get(connector, ())
    for known_lhs, known_rhs, combined_type in registered:
        if not issubclass(lhs_type, known_lhs):
            continue
        if issubclass(rhs_type, known_rhs):
            return combined_type
    return None
636
637
class CombinedExpression(SQLiteNumericMixin, Expression):
    """Two expressions joined by a connector, rendered as `(lhs <op> rhs)`."""

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """Infer the output field from the registered type combinations.

        Raises FieldError when no combination is registered for the connector
        and the operand types.
        """
        # We avoid using super() here for reasons given in
        # Expression._resolve_output_field()
        lhs_field = self.lhs._output_field_or_none
        rhs_field = self.rhs._output_field_or_none
        combined_type = _resolve_combined_type(
            self.connector,
            type(lhs_field),
            type(rhs_field),
        )
        if combined_type is None:
            # Name the types from the values resolved above. Going through
            # `.output_field` again would raise a different, generic
            # FieldError whenever one side resolved to None, hiding this
            # message.
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_field.__class__.__name__}, "
                f"{rhs_field.__class__.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(self, compiler, connection):
        """Compile both sides and join them with the backend's operator."""
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # Parenthesize the result to preserve order of precedence.
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve both sides, switching to a specialised subclass if needed.

        A plain CombinedExpression is replaced by a DurationExpression when
        exactly one side is a DurationField, and by a TemporalSubtraction when
        two date/time values of the same type are subtracted.
        """
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            # A side whose type cannot be determined is treated as None.
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                # The specialised expression is built from the unresolved
                # sides and resolved again from scratch.
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
734
735
class DurationExpression(CombinedExpression):
    """CombinedExpression used when exactly one operand is a DurationField."""

    def compile(self, side, compiler, connection):
        """Compile one operand, formatting DurationField SQL for arithmetic."""
        try:
            side_field = side.output_field
        except FieldError:
            side_field = None
        needs_formatting = (
            side_field is not None
            and side_field.get_internal_type() == "DurationField"
        )
        sql, params = compiler.compile(side)
        if needs_formatting:
            sql = connection.ops.format_for_duration_arithmetic(sql)
        return sql, params

    def as_sql(self, compiler, connection):
        # Backends with a native duration type use the generic rendering.
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        fragments = []
        bindings = []
        for side in (self.lhs, self.rhs):
            sql, params = self.compile(side, compiler, connection)
            fragments.append(sql)
            bindings.extend(params)
        combined = connection.ops.combine_duration_expression(
            self.connector, fragments
        )
        # Parenthesize the result to preserve order of precedence.
        return f"({combined})", bindings

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render for SQLite, rejecting MUL/DIV on unsupported operand types."""
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Operand types are unknown: nothing to validate.
            return sql, params
        allowed_fields = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if not (lhs_type in allowed_fields and rhs_type in allowed_fields):
            raise DatabaseError(f"Invalid arguments for operator {self.connector}.")
        return sql, params
785
786
class TemporalSubtraction(CombinedExpression):
    """Subtraction of two date/time expressions; the result is a duration."""

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        """Delegate the subtraction SQL to the backend's subtract_temporals()."""
        connection.ops.check_expression_support(self)
        compiled_lhs = compiler.compile(self.lhs)
        compiled_rhs = compiler.compile(self.rhs)
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
800
801
@deconstructible(path="plain.models.F")
class F(Combinable):
    """A named reference to an existing query object, resolved against a query."""

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}({self.name})"

    def __eq__(self, other):
        # Equal only to objects of exactly the same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Ask the query to resolve the reference; `for_save` is not used."""
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        """Return the registered replacement for this reference, else self."""
        return replacements.get(self, self)

    def copy(self):
        return copy.copy(self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)
838
839
class ResolvedOuterRef(F):
    """
    A reference to an outer query that has already been resolved.

    The reference counts as resolved because the inner query has been used
    as a subquery.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        spans_lookup = LOOKUP_SEP in self.name
        resolved.possibly_multivalued = spans_lookup
        return resolved

    def relabeled_clone(self, relabels):
        """Relabeling does not apply to this reference; return it unchanged."""
        return self

    def get_group_by_cols(self):
        """This reference contributes no GROUP BY columns."""
        return []
875
876
class OuterRef(F):
    """An F-style reference that resolves to a ResolvedOuterRef."""

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        """Unwrap one level of nesting, or produce a ResolvedOuterRef."""
        target = self.name
        # A nested OuterRef(OuterRef(...)) resolves to its inner reference.
        if isinstance(target, self.__class__):
            return target
        return ResolvedOuterRef(target)

    def relabeled_clone(self, relabels):
        """Relabeling does not apply to this reference; return it unchanged."""
        return self
887
888
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """An SQL function call."""

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        """Store the parsed argument expressions and the extra template context.

        Raises TypeError when `arity` is set and the argument count differs.
        """
        given = len(expressions)
        if self.arity is not None and given != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({given} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        class_name = self.__class__.__name__
        args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{class_name}({args})"
        rendered = ", ".join(f"{key!s}={val!s}" for key, val in sorted(options.items()))
        return f"{class_name}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy whose argument expressions are resolved in place."""
        resolved = self.copy()
        resolved.is_summary = summarize
        for index, argument in enumerate(resolved.source_expressions):
            resolved.source_expressions[index] = argument.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return resolved

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """Render the call by filling `template` with the compiled arguments."""
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Fall back to the argument's declared empty-result value,
                # if it has one; otherwise let the exception propagate.
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(fallback))
            except FullResultSet:
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        joined_arguments = arg_joiner.join(sql_parts)
        data["expressions"] = data["field"] = joined_arguments
        return template % data, params

    def copy(self):
        """Return a copy with its own argument list and extra dict."""
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
987
988
@deconstructible(path="plain.models.Value")
class Value(SQLiteNumericMixin, Expression):
    """A plain Python value wrapped so it can act as an expression node."""

    # Class-level default for `for_save`, so that instances which were never
    # resolved can still be compiled until a decision is taken in #25425.
    for_save = False

    def __init__(self, value, output_field=None):
        """
        Arguments:
         * value: the value this expression stands for. It is handed to the
           database through the SQL parameter list and properly quoted.

         * output_field: an instance of the model field type that this
           expression will return, such as IntegerField() or CharField().
        """
        super().__init__(output_field=output_field)
        self.value = value

    def __repr__(self):
        return f"{self.__class__.__name__}({self.value!r})"

    def as_sql(self, compiler, connection):
        """Return the placeholder SQL and the parameter list for the value."""
        connection.ops.check_expression_support(self)
        prepared = self.value
        field = self._output_field_or_none
        if field is not None:
            # Let the output field convert the value for the database, using
            # the save-specific conversion when resolved with for_save=True.
            if self.for_save:
                prepare = field.get_db_prep_save
            else:
                prepare = field.get_db_prep_value
            prepared = prepare(prepared, connection=connection)
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(prepared, compiler, connection)
                return placeholder, [prepared]
        if prepared is None:
            # cx_Oracle does not always convert None to the appropriate
            # NULL type (like in case expressions using numbers), so we
            # use a literal SQL NULL
            return "NULL", []
        return "%s", [prepared]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve as usual and record `for_save` on the resolved copy."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        resolved.for_save = for_save
        return resolved

    def get_group_by_cols(self):
        # A wrapped value contributes no columns to GROUP BY.
        return []

    def _resolve_output_field(self):
        """
        Guess a field type from the Python type of the wrapped value.

        Return None when the type is not one of the recognized ones.
        """
        # The order is significant: bool is a subclass of int and
        # datetime.datetime is a subclass of datetime.date, so the more
        # specific type has to be tested first.
        field_name_by_type = (
            (str, "CharField"),
            (bool, "BooleanField"),
            (int, "IntegerField"),
            (float, "FloatField"),
            (datetime.datetime, "DateTimeField"),
            (datetime.date, "DateField"),
            (datetime.time, "TimeField"),
            (datetime.timedelta, "DurationField"),
            (Decimal, "DecimalField"),
            (bytes, "BinaryField"),
            (UUID, "UUIDField"),
        )
        for python_type, field_name in field_name_by_type:
            if isinstance(self.value, python_type):
                return getattr(fields, field_name)()
        return None

    @property
    def empty_result_set_value(self):
        # The wrapped value itself is reported as the empty-result value.
        return self.value
1067
1068
class RawSQL(Expression):
    """
    An expression built from a raw SQL fragment and its parameters.

    The fragment is rendered wrapped in parentheses and `params` is returned
    unchanged as the parameter list.
    """

    def __init__(self, sql, params, output_field=None):
        """
        Arguments:
         * sql: the raw SQL fragment.
         * params: the parameters that go with the fragment.
         * output_field: field instance describing the result; a generic
           `fields.Field()` is used when omitted.
        """
        if output_field is None:
            output_field = fields.Field()
        self.sql, self.params = sql, params
        super().__init__(output_field=output_field)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.sql}, {self.params})"

    def as_sql(self, compiler, connection):
        return "(%s)" % self.sql, self.params

    def get_group_by_cols(self):
        return [self]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression against `query`.

        When a query with a model is supplied, each parent model is scanned
        and, for the first local field whose column name occurs in the SQL
        text (case-insensitive substring match), the field reference is
        resolved on the query; scanning of that parent's fields then stops.
        """
        # Resolve parents fields used in raw SQL.
        # `query` defaults to None, so guard before dereferencing it; without
        # the guard the documented default raised AttributeError.
        if query is not None and query.model:
            for parent in query.model._meta.get_parent_list():
                for parent_field in parent._meta.local_fields:
                    _, column_name = parent_field.get_attname_column()
                    if column_name.lower() in self.sql.lower():
                        query.resolve_ref(
                            parent_field.name, allow_joins, reuse, summarize
                        )
                        break
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
1101
1102
class Star(Expression):
    """Expression that renders as the SQL `*` wildcard with no parameters."""

    def __repr__(self):
        return "'*'"

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params
1109
1110
class Col(Expression):
    """
    A reference to the column of `target`, optionally qualified by a table
    alias. Renders as the quoted `alias.column` (or just the column).
    """

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        # The target field doubles as the output field unless one is given.
        if output_field is None:
            super().__init__(output_field=target)
        else:
            super().__init__(output_field=output_field)
        self.alias = alias
        self.target = target

    def __repr__(self):
        target_label = str(self.target)
        if self.alias:
            parts = (self.alias, target_label)
        else:
            parts = (target_label,)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        table_alias, column = self.alias, self.target.column
        names = (table_alias, column) if table_alias else (column,)
        quote = compiler.quote_name_unless_alias
        return ".".join(quote(name) for name in names), []

    def relabeled_clone(self, relabels):
        """Return a clone using the relabeled alias; `self` when there is no alias."""
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        """
        Return the output field's converters, followed by the target's
        converters when the target is a different field.
        """
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if same_field:
            return converters
        return converters + self.target.get_db_converters(connection)
1148
1149
class Ref(Expression):
    """
    Reference to column alias of the query. For example, Ref('sum_cost') in
    qs.annotate(sum_cost=Sum('cost')) query.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return f"{self.__class__.__name__}({self.refs}, {self.source})"

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        # Exactly one source expression is expected; unpacking enforces it.
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Nothing to resolve: `source` was resolved already and this node
        # only refers to it by name.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        # Relabeling leaves the reference untouched.
        return self

    def as_sql(self, compiler, connection):
        # Only the quoted alias name is emitted, without parameters.
        quoted_alias = connection.ops.quote_name(self.refs)
        return quoted_alias, []

    def get_group_by_cols(self):
        return [self]
1187
1188
class ExpressionList(Func):
    """
    A group of several expressions. It lets a list of expressions be passed
    as a single argument to another expression, e.g. for a partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        """Raise ValueError when called without any expression."""
        if not expressions:
            owner = self.__class__.__name__
            raise ValueError("%s requires at least one expression." % owner)
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
1211
1212
class OrderByList(Func):
    """An `ORDER BY` clause made of one or more ordering expressions."""

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        """
        Accept ordering expressions. A string starting with "-" is turned
        into a descending OrderBy on the field named by the rest of the
        string; everything else is passed through unchanged.
        """
        normalized = []
        for expr in expressions:
            if isinstance(expr, str) and expr[0] == "-":
                expr = OrderBy(F(expr[1:]), descending=True)
            normalized.append(expr)
        super().__init__(*normalized, **extra)

    def as_sql(self, *args, **kwargs):
        # With nothing to order by, emit no clause at all.
        if self.source_expressions:
            return super().as_sql(*args, **kwargs)
        return "", ()

    def get_group_by_cols(self):
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1237
1238
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression in order to supply extra context to it, such
    as the output_field.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # For non-expressions e.g. an SQL WHERE clause, the entire
            # `expression` must be included in the GROUP BY clause.
            return super().get_group_by_cols()
        # Ask a copy of the inner expression, carrying the wrapper's output
        # field, for its GROUP BY columns.
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        # The wrapper adds no SQL of its own.
        return compiler.compile(self.expression)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.expression})"
1270
1271
class NegatedExpression(ExpressionWrapper):
    """The logical negation of a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Negating a negation gives back a copy of the wrapped expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The wrapped expression matches nothing, so its negation is
            # always true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
        # to be compared to another expression unless they're wrapped in a CASE
        # WHEN.
        supported_in_where = (
            compiler.connection.ops.conditional_expression_supported_in_where_clause(
                self.expression
            )
        )
        if supported_in_where:
            return f"NOT {sql}", params
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve, then raise TypeError if the wrapped expression is not conditional."""
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if not is_conditional:
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(self, compiler, sql, params):
        # Wrap boolean expressions with a CASE WHEN expression if a database
        # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
        # GROUP BY list.
        backend = compiler.connection
        supported_in_where = backend.ops.conditional_expression_supported_in_where_clause
        if backend.features.supports_boolean_expr_in_select_clause:
            return sql, params
        if not supported_in_where(self.expression):
            # Avoid double wrapping.
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1321
1322
@deconstructible(path="plain.models.When")
class When(Expression):
    """A single `WHEN <condition> THEN <result>` branch for use inside Case()."""

    template = "WHEN %(condition)s THEN %(result)s"
    # This isn't a complete conditional expression, must be used in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        """
        Arguments:
         * condition: a Q object or another object whose `conditional`
           attribute is true. May be omitted when lookups are given.
         * then: the result of the branch.
         * lookups: keyword lookups, combined with `condition` (if any)
           into a Q object.

        Raise TypeError for an unusable condition and ValueError for an
        empty Q().
        """
        if lookups:
            if condition is None:
                condition = Q(**lookups)
                lookups = None
            elif getattr(condition, "conditional", False):
                condition = Q(condition, **lookups)
                lookups = None
        usable = condition is not None and getattr(condition, "conditional", False)
        # Lookups that are still pending here could not be merged into the
        # condition above.
        if not usable or lookups:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return f"WHEN {self.condition!r} THEN {self.result!r}"

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        condition, result = exprs
        self.condition = condition
        self.result = result

    def get_source_fields(self):
        # We're only interested in the fields of the result expressions.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with the condition and the result resolved against `query`."""
        resolved = self.copy()
        resolved.is_summary = summarize
        if hasattr(resolved.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False.
            resolved.condition = resolved.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        resolved.result = resolved.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Render the branch; the params are the condition's followed by the result's."""
        connection.ops.check_expression_support(self)
        template_params = extra_context
        condition_sql, condition_params = compiler.compile(self.condition)
        template_params["condition"] = condition_sql
        result_sql, result_params = compiler.compile(self.result)
        template_params["result"] = result_sql
        chosen_template = template or self.template
        return chosen_template % template_params, (
            *condition_params,
            *result_params,
        )

    def get_group_by_cols(self):
        # This is not a complete expression and cannot be used in GROUP BY.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1397
1398
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        """
        Arguments:
         * cases: When objects, in the order they should be tested.
         * default: the result used for the ELSE part.
         * output_field: optional field instance describing the result.
         * extra: additional values made available to the SQL template.

        Raise TypeError if any positional argument is not a When.
        """
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        branches = ", ".join(str(c) for c in self.cases)
        return "CASE {}, ELSE {!r}".format(branches, self.default)

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_source_expressions(self):
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        # The last expression is the default, everything before are cases.
        *cases, default = exprs
        self.cases = cases
        self.default = default

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Return a copy with every case and the default resolved against `query`."""
        resolved = self.copy()
        resolved.is_summary = summarize
        resolved.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in resolved.cases
        ]
        resolved.default = resolved.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return resolved

    def copy(self):
        duplicate = super().copy()
        # Give the copy its own list of cases.
        duplicate.cases = duplicate.cases[:]
        return duplicate

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Render the CASE expression.

        Without any usable case, only the default (or the result of a case
        that always matches) is rendered.
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = dict(self.extra)
        template_params.update(extra_context)
        rendered_cases = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                # This case can never match; leave it out.
                continue
            except FullResultSet:
                # This case always matches: its result takes the place of
                # the default and the remaining cases are not compiled.
                default_sql, default_params = compiler.compile(case.result)
                break
            else:
                rendered_cases.append(case_sql)
                sql_params.extend(case_params)
        if not rendered_cases:
            return default_sql, default_params
        joiner = case_joiner or self.case_joiner
        template_params["cases"] = joiner.join(rendered_cases)
        template_params["default"] = default_sql
        sql_params.extend(default_params)
        chosen_template = template or template_params.get("template", self.template)
        sql = chosen_template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1494
1495
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query which will be resolved when it is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        """
        Arguments:
         * queryset: a QuerySet or an sql.Query object; its query is cloned.
         * output_field: optional field instance describing the result.
         * extra: additional values made available to the SQL template.
        """
        # Allow the usage of both QuerySet and sql.Query objects.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        # The copy gets its own clone of the query.
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = dict(self.extra)
        template_params.update(extra_context)
        query_sql, sql_params = self.query.as_sql(compiler, connection)
        # Drop the first and last character of the query's SQL before
        # handing it to the template.
        template_params["subquery"] = query_sql[1:-1]

        chosen_template = template or template_params.get("template", self.template)
        return chosen_template % template_params, sql_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
1546
1547
class Exists(Subquery):
    """A subquery rendered inside `EXISTS(...)`, with a boolean output field."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        # Replace the cloned query by its exists() form.
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        # Wrap EXISTS() with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1564
1565
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """An expression paired with a sort direction and optional NULL placement."""

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        """
        Arguments:
         * expression: the object to order by; it must provide
           `resolve_expression`.
         * descending: sort in descending order when true.
         * nulls_first / nulls_last: True or None, and never both true.

        Raise ValueError when any of these constraints is violated.
        """
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        """Render the ordering term, including NULL placement where requested."""
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        # Without a NULLS modifier, an extra leading sort key on the
        # expression's nullness is added instead, except for the combinations
        # of direction and `order_by_nulls_first` excluded below.
        elif self.nulls_last and (
            not self.descending or not features.order_by_nulls_first
        ):
            template = "%%(expression)s IS NULL, %s" % template
        elif self.nulls_first and (
            self.descending or not features.order_by_nulls_first
        ):
            template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
        }
        placeholders.update(extra_context)
        # The expression may occur more than once in the template; repeat
        # its params once per occurrence.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the direction and swap the NULL placement in place; return self."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last, self.nulls_first = True, None
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
1641
1642
class Window(SQLiteNumericMixin, Expression):
    """
    An expression evaluated over a window: `<expression> OVER (<window>)`,
    where the window may hold a partition, an ordering and a frame.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        """
        Arguments:
         * expression: the expression to evaluate; its `window_compatible`
           attribute must be true.
         * partition_by: a single expression or a tuple/list of them.
         * order_by: a string field reference, an expression, or a
           list/tuple of them.
         * frame: optional frame expression.
         * output_field: optional field instance describing the result.

        Raise ValueError for an incompatible expression or an unsupported
        `order_by` type.
        """
        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                "Expression '%s' isn't compatible with OVER clauses."
                % expression.__class__.__name__
            )

        if partition_by is not None:
            if not isinstance(partition_by, (tuple, list)):
                partition_by = (partition_by,)
            partition_by = ExpressionList(*partition_by)

        if order_by is not None:
            if isinstance(order_by, (list, tuple)):
                order_by = OrderByList(*order_by)
            elif isinstance(order_by, (BaseExpression, str)):
                order_by = OrderByList(order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )

        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        source_expression, partition_by, order_by, frame = exprs
        self.source_expression = source_expression
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

    def as_sql(self, compiler, connection, template=None):
        """
        Render the window expression.

        Raise NotSupportedError if the backend reports no support for the
        OVER clause.
        """
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        clauses = []
        clause_params = []

        if self.partition_by is not None:
            partition_sql, partition_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            clauses.append(partition_sql)
            clause_params.extend(partition_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            clauses.append(order_sql)
            clause_params.extend(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            clauses.append(frame_sql)
            clause_params.extend(frame_params)

        chosen_template = template or self.template
        window_sql = " ".join(clauses).strip()
        sql = chosen_template % {"expression": expr_sql, "window": window_sql}
        return sql, (*params, *clause_params)

    def as_sqlite(self, compiler, connection):
        if not isinstance(self.output_field, fields.DecimalField):
            return self.as_sql(compiler, connection)
        # Casting to numeric must be outside of the window expression.
        clone = self.copy()
        sources = clone.get_source_expressions()
        sources[0].output_field = fields.FloatField()
        clone.set_source_expressions(sources)
        return super(Window, clone).as_sqlite(compiler, connection)

    def __str__(self):
        source_text = str(self.source_expression)
        partition_text = (
            "PARTITION BY " + str(self.partition_by) if self.partition_by else ""
        )
        order_text = str(self.order_by or "")
        frame_text = str(self.frame or "")
        return "{} OVER ({}{}{})".format(
            source_text, partition_text, order_text, frame_text
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        """Return the GROUP BY columns of the partition and of the ordering."""
        cols = []
        if self.partition_by:
            cols += self.partition_by.get_group_by_cols()
        if self.order_by is not None:
            cols += self.order_by.get_group_by_cols()
        return cols
1757
1758
class WindowFrame(Expression):
    """
    Model the frame clause in window expressions. There are two types of frame
    clauses which are subclasses, however, all processing and validation (by no
    means intended to be complete) is done here. Thus, providing an end for a
    frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
    row in the frame).
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        # Both bounds are stored wrapped in Value().
        self.start = Value(start)
        self.end = Value(end)

    def set_source_expressions(self, exprs):
        start, end = exprs
        self.start = start
        self.end = end

    def get_source_expressions(self):
        return [self.start, self.end]

    def as_sql(self, compiler, connection):
        """Render the frame clause; the bounds come from window_frame_start_end()."""
        connection.ops.check_expression_support(self)
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        context = {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }
        return self.template % context, []

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        ops = connection.ops

        # A negative start counts preceding rows, zero is the current row,
        # anything else (including None) is unbounded.
        start_value = self.start.value
        if start_value is not None and start_value < 0:
            start = "%d %s" % (abs(start_value), ops.PRECEDING)
        elif start_value is not None and start_value == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        # A positive end counts following rows, zero is the current row,
        # anything else (including None) is unbounded.
        end_value = self.end.value
        if end_value is not None and end_value > 0:
            end = "%d %s" % (end_value, ops.FOLLOWING)
        elif end_value is not None and end_value == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
1823
1824
class RowRange(WindowFrame):
    """Window frame of type ROWS."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # The backend operations object computes the frame bounds.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
1830
1831
class ValueRange(WindowFrame):
    """Window frame of type RANGE."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # The backend operations object computes the frame bounds.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)