# plain.models

Model your data and store it in a database.
```python
# app/users/models.py
from plain import models
from plain.passwords.models import PasswordField


class User(models.Model):
    email = models.EmailField()
    password = PasswordField()
    is_admin = models.BooleanField(default=False)
    created_at = models.DateTimeField(auto_now_add=True)

    def __str__(self):
        return self.email
```
Create, update, and delete instances of your models:

```python
from .models import User

# Create a new user
user = User.objects.create(
    email="test@example.com",
    password="password",
)

# Update a user
user.email = "new@example.com"
user.save()

# Delete a user
user.delete()

# Query for users
admin_users = User.objects.filter(is_admin=True)
```
## Installation

```python
# app/settings.py
INSTALLED_PACKAGES = [
    ...
    "plain.models",
]
```
To connect to a database, you can provide a `DATABASE_URL` environment variable.

```sh
DATABASE_URL=postgresql://user:password@localhost:5432/dbname
```
Or you can manually define the `DATABASES` setting.

```python
# app/settings.py
DATABASES = {
    "default": {
        "ENGINE": "plain.models.backends.postgresql",
        "NAME": "dbname",
        "USER": "user",
        "PASSWORD": "password",
        "HOST": "localhost",
        "PORT": "5432",
    }
}
```
Multiple backends are supported, including Postgres, MySQL, and SQLite.
Querying
Migrations
Fields
Validation
Indexes and constraints
Managers
Forms
1import copy
2import datetime
3import functools
4import inspect
5from collections import defaultdict
6from decimal import Decimal
7from types import NoneType
8from uuid import UUID
9
10from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
11from plain.models import fields
12from plain.models.constants import LOOKUP_SEP
13from plain.models.db import DatabaseError, NotSupportedError, connection
14from plain.models.query_utils import Q
15from plain.utils.deconstruct import deconstructible
16from plain.utils.functional import cached_property
17from plain.utils.hashable import make_hashable
18
19
class SQLiteNumericMixin:
    """
    Mixin that, on SQLite, wraps the compiled SQL in ``CAST(... AS NUMERIC)``
    when the expression's output field is a DecimalField, so the value is
    treated numerically (e.g. when filtering against it).
    """

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        try:
            needs_cast = self.output_field.get_internal_type() == "DecimalField"
        except FieldError:
            # The output type can't be determined; emit the SQL unchanged.
            needs_cast = False
        if needs_cast:
            sql = f"CAST({sql} AS NUMERIC)"
        return sql, params
34
35
class Combinable:
    """
    Mixin giving expressions Python operator support.

    Arithmetic operators (``+ - * / % **`` and unary ``-``) build a
    CombinedExpression, e.g. ``F('foo') + F('bar')``. The ``&``, ``|`` and
    ``^`` operators are reserved for combining two conditional (boolean)
    expressions into Q objects; bitwise arithmetic goes through the explicit
    ``bitand()``/``bitor()``/``bitxor()``/shift methods instead.
    """

    # Arithmetic connectors
    ADD = "+"
    SUB = "-"
    MUL = "*"
    DIV = "/"
    POW = "^"
    # "%" is written doubled because the connector may end up in SQL strings
    # that also go through parameter substitution.
    MOD = "%%"

    # Bitwise connectors. Only the bit*() methods produce these; the Python
    # "&" and "|" operators are reserved for boolean combination.
    BITAND = "&"
    BITOR = "|"
    BITLEFTSHIFT = "<<"
    BITRIGHTSHIFT = ">>"
    BITXOR = "#"

    def _combine(self, other, connector, reversed):
        """
        Build ``self <connector> other``, or ``other <connector> self`` when
        *reversed* is true. Non-expression operands are wrapped in Value().
        """
        operand = other if hasattr(other, "resolve_expression") else Value(other)
        lhs, rhs = (operand, self) if reversed else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    #######################
    # ARITHMETIC OPERATORS #
    #######################

    def __neg__(self):
        return self._combine(-1, self.MUL, False)

    def __add__(self, other):
        return self._combine(other, self.ADD, False)

    def __sub__(self, other):
        return self._combine(other, self.SUB, False)

    def __mul__(self, other):
        return self._combine(other, self.MUL, False)

    def __truediv__(self, other):
        return self._combine(other, self.DIV, False)

    def __mod__(self, other):
        return self._combine(other, self.MOD, False)

    def __pow__(self, other):
        return self._combine(other, self.POW, False)

    # Reflected forms put the plain operand on the left-hand side.

    def __radd__(self, other):
        return self._combine(other, self.ADD, True)

    def __rsub__(self, other):
        return self._combine(other, self.SUB, True)

    def __rmul__(self, other):
        return self._combine(other, self.MUL, True)

    def __rtruediv__(self, other):
        return self._combine(other, self.DIV, True)

    def __rmod__(self, other):
        return self._combine(other, self.MOD, True)

    def __rpow__(self, other):
        return self._combine(other, self.POW, True)

    ####################
    # BOOLEAN OPERATORS #
    ####################

    def __and__(self, other):
        # Only two conditional expressions may be joined with "&".
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) & Q(other)

    def __or__(self, other):
        # Only two conditional expressions may be joined with "|".
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) | Q(other)

    def __xor__(self, other):
        # Only two conditional expressions may be joined with "^".
        if not (
            getattr(self, "conditional", False)
            and getattr(other, "conditional", False)
        ):
            raise NotImplementedError(
                "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
            )
        return Q(self) ^ Q(other)

    def __rand__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __ror__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __rxor__(self, other):
        raise NotImplementedError(
            "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
        )

    def __invert__(self):
        return NegatedExpression(self)

    ####################
    # BITWISE METHODS  #
    ####################

    def bitand(self, other):
        return self._combine(other, self.BITAND, False)

    def bitor(self, other):
        return self._combine(other, self.BITOR, False)

    def bitxor(self, other):
        return self._combine(other, self.BITXOR, False)

    def bitleftshift(self, other):
        return self._combine(other, self.BITLEFTSHIFT, False)

    def bitrightshift(self, other):
        return self._combine(other, self.BITRIGHTSHIFT, False)
166
167
class BaseExpression:
    """
    Base class for all query expressions.

    Subclasses expose their children through get_source_expressions() /
    set_source_expressions() and compile to SQL through as_sql(). The output
    field is either passed to __init__() or inferred lazily from the source
    expressions by the ``output_field`` cached property.
    """

    # Value a caller (e.g. Func.as_sql) may substitute when compiling this
    # expression raises EmptyResultSet; NotImplemented means "no substitute".
    empty_result_set_value = NotImplemented
    # aggregate specific fields
    is_summary = False
    # Set by the output_field property when inference returned None, so that
    # _output_field_or_none can tell that case apart from other FieldErrors.
    _output_field_resolved_to_none = False
    # Can the expression be used in a WHERE clause?
    filterable = True
    # Can the expression be used as a source expression in Window?
    window_compatible = False

    def __init__(self, output_field=None):
        # An explicit output_field is stored on the instance; otherwise the
        # ``output_field`` cached_property infers it on first access.
        if output_field is not None:
            self.output_field = output_field

    def __getstate__(self):
        """
        Return the pickle state without the cached ``convert_value``: that
        cache may hold one of the lambdas built in convert_value(), which
        pickle cannot serialize.
        """
        state = self.__dict__.copy()
        state.pop("convert_value", None)
        return state

    def get_db_converters(self, connection):
        """
        Return the converters to apply to values read from the database:
        this expression's own converter (omitted when it is the no-op one),
        followed by the output field's converters.
        """
        return (
            []
            if self.convert_value is self._convert_value_noop
            else [self.convert_value]
        ) + self.output_field.get_db_converters(connection)

    def get_source_expressions(self):
        """Return the child expressions; the base class has none."""
        return []

    def set_source_expressions(self, exprs):
        """Replace the child expressions; the base class accepts only none."""
        assert not exprs

    def _parse_expressions(self, *expressions):
        """
        Normalize constructor arguments into expressions: objects with a
        ``resolve_expression`` method pass through unchanged, strings become
        F() field references, and any other value is wrapped in Value().
        """
        return [
            arg
            if hasattr(arg, "resolve_expression")
            else (F(arg) if isinstance(arg, str) else Value(arg))
            for arg in expressions
        ]

    def as_sql(self, compiler, connection):
        """
        Responsible for returning a (sql, [params]) tuple to be included
        in the current query.

        Different backends can provide their own implementation, by
        providing an `as_{vendor}` method and patching the Expression:

        ```
        def override_as_sql(self, compiler, connection):
            # custom logic
            return super().as_sql(compiler, connection)
        setattr(Expression, 'as_' + connection.vendor, override_as_sql)
        ```

        Arguments:
         * compiler: the query compiler responsible for generating the query.
           Must have a compile method, returning a (sql, [params]) tuple.
           Calling compiler(value) will return a quoted `value`.

         * connection: the database connection used for the current query.

        Return: (sql, params)
          Where `sql` is a string containing ordered sql parameters to be
          replaced with the elements of the list `params`.
        """
        raise NotImplementedError("Subclasses must implement as_sql()")

    @cached_property
    def contains_aggregate(self):
        """True if any (non-None) source expression contains an aggregate."""
        return any(
            expr and expr.contains_aggregate for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_over_clause(self):
        """True if any (non-None) source expression contains an OVER clause."""
        return any(
            expr and expr.contains_over_clause for expr in self.get_source_expressions()
        )

    @cached_property
    def contains_column_references(self):
        """True if any (non-None) source expression references a column."""
        return any(
            expr and expr.contains_column_references
            for expr in self.get_source_expressions()
        )

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Provide the chance to do any preprocessing or validation before being
        added to the query.

        Arguments:
         * query: the backend query implementation
         * allow_joins: boolean allowing or denying use of joins
           in this query
         * reuse: a set of reusable joins for multijoins
         * summarize: a terminal aggregate clause
         * for_save: whether this expression about to be used in a save or update

        Return: an Expression to be added to the query.

        Works on a copy; None source expressions are kept as None. Note that
        ``for_save`` is not forwarded to the source expressions here.
        """
        c = self.copy()
        c.is_summary = summarize
        c.set_source_expressions(
            [
                expr.resolve_expression(query, allow_joins, reuse, summarize)
                if expr
                else None
                for expr in c.get_source_expressions()
            ]
        )
        return c

    @property
    def conditional(self):
        """True when the output field is a BooleanField."""
        return isinstance(self.output_field, fields.BooleanField)

    @property
    def field(self):
        """Alias for ``output_field``."""
        return self.output_field

    @cached_property
    def output_field(self):
        """Return the output type of this expression."""
        output_field = self._resolve_output_field()
        if output_field is None:
            # Record *why* we failed so _output_field_or_none can return None
            # for this case while still propagating other FieldErrors.
            self._output_field_resolved_to_none = True
            raise FieldError("Cannot resolve expression type, unknown output_field")
        return output_field

    @cached_property
    def _output_field_or_none(self):
        """
        Return the output field of this expression, or None if
        _resolve_output_field() didn't return an output type.
        Any other FieldError (e.g. mixed source types) is re-raised.
        """
        try:
            return self.output_field
        except FieldError:
            if not self._output_field_resolved_to_none:
                raise

    def _resolve_output_field(self):
        """
        Attempt to infer the output type of the expression.

        As a guess, if the output fields of all source fields match then simply
        infer the same type here.

        If a source's output field resolves to None, exclude it from this check.
        If all sources are None, then an error is raised higher up the stack in
        the output_field property.
        """
        # This guess is mostly a bad idea, but there is quite a lot of code
        # (especially 3rd party Func subclasses) that depend on it, we'd need a
        # deprecation path to fix it.
        sources_iter = (
            source for source in self.get_source_fields() if source is not None
        )
        # Both loops share ONE generator: the outer loop takes the first
        # non-None source as the candidate, the inner loop consumes all the
        # remaining ones, so the outer body runs at most once. Each remaining
        # source must be of a class the candidate is an instance of.
        # With no non-None sources, this falls through and returns None.
        for output_field in sources_iter:
            for source in sources_iter:
                if not isinstance(output_field, source.__class__):
                    raise FieldError(
                        f"Expression contains mixed types: {output_field.__class__.__name__}, {source.__class__.__name__}. You must "
                        "set output_field."
                    )
            return output_field

    @staticmethod
    def _convert_value_noop(value, expression, connection):
        """Converter that returns *value* unchanged."""
        return value

    @cached_property
    def convert_value(self):
        """
        Expressions provide their own converters because users have the option
        of manually specifying the output_field which may be a different type
        from the one the database returns.

        Float, *IntegerField and Decimal output types get a converter that
        coerces non-None values to float/int/Decimal; anything else gets
        the no-op converter.
        """
        field = self.output_field
        internal_type = field.get_internal_type()
        if internal_type == "FloatField":
            return (
                lambda value, expression, connection: None
                if value is None
                else float(value)
            )
        elif internal_type.endswith("IntegerField"):
            return (
                lambda value, expression, connection: None
                if value is None
                else int(value)
            )
        elif internal_type == "DecimalField":
            return (
                lambda value, expression, connection: None
                if value is None
                else Decimal(value)
            )
        return self._convert_value_noop

    def get_lookup(self, lookup):
        """Delegate lookup resolution to the output field."""
        return self.output_field.get_lookup(lookup)

    def get_transform(self, name):
        """Delegate transform resolution to the output field."""
        return self.output_field.get_transform(name)

    def relabeled_clone(self, change_map):
        """Return a copy whose (non-None) sources are relabeled with *change_map*."""
        clone = self.copy()
        clone.set_source_expressions(
            [
                e.relabeled_clone(change_map) if e is not None else None
                for e in self.get_source_expressions()
            ]
        )
        return clone

    def replace_expressions(self, replacements):
        """
        Return the truthy replacement mapped to this expression in
        *replacements*, or else a copy with replacements applied recursively
        to the (truthy) source expressions.
        """
        if replacement := replacements.get(self):
            return replacement
        clone = self.copy()
        source_expressions = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr.replace_expressions(replacements) if expr else None
                for expr in source_expressions
            ]
        )
        return clone

    def get_refs(self):
        """
        Return the union of get_refs() over the source expressions.
        NOTE(review): unlike the methods above, None sources are not skipped.
        """
        refs = set()
        for expr in self.get_source_expressions():
            refs |= expr.get_refs()
        return refs

    def copy(self):
        """Return a shallow copy of this expression."""
        return copy.copy(self)

    def prefix_references(self, prefix):
        """
        Return a copy in which every F() source is replaced by an F() whose
        name is prefixed with *prefix*; other sources are processed
        recursively.
        """
        clone = self.copy()
        clone.set_source_expressions(
            [
                F(f"{prefix}{expr.name}")
                if isinstance(expr, F)
                else expr.prefix_references(prefix)
                for expr in self.get_source_expressions()
            ]
        )
        return clone

    def get_group_by_cols(self):
        """
        Return [self] when this expression contains no aggregate; otherwise
        collect the GROUP BY columns of the source expressions.
        """
        if not self.contains_aggregate:
            return [self]
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def get_source_fields(self):
        """Return the output field (or None) of each source expression."""
        return [e._output_field_or_none for e in self.get_source_expressions()]

    def asc(self, **kwargs):
        """Return an ascending OrderBy for this expression."""
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        """Return a descending OrderBy for this expression."""
        return OrderBy(self, descending=True, **kwargs)

    def reverse_ordering(self):
        """Return self; plain expressions carry no ordering to reverse."""
        return self

    def flatten(self):
        """
        Recursively yield this expression and all subexpressions, in
        depth-first order.
        """
        yield self
        for expr in self.get_source_expressions():
            if expr:
                if hasattr(expr, "flatten"):
                    yield from expr.flatten()
                else:
                    # Sources without flatten() are yielded as leaves.
                    yield expr

    def select_format(self, compiler, sql, params):
        """
        Custom format for select clauses. For example, EXISTS expressions need
        to be wrapped in CASE WHEN on Oracle.

        Delegates to the output field's select_format() when it has one.
        """
        if hasattr(self.output_field, "select_format"):
            return self.output_field.select_format(compiler, sql, params)
        return sql, params
466
467
@deconstructible
class Expression(BaseExpression, Combinable):
    """An expression that can be combined with other expressions."""

    @cached_property
    def identity(self):
        """
        Hashable description of this expression used for equality and
        hashing: the class, followed by ``(parameter_name, value)`` for every
        ``__init__`` parameter (defaults included).

        Model fields are reduced to ``(model label, field name)`` when bound
        to a model, or to their type otherwise; other values are made
        hashable with make_hashable().
        """
        args, kwargs = self._constructor_args
        bound = inspect.signature(self.__init__).bind_partial(*args, **kwargs)
        bound.apply_defaults()

        def normalize(value):
            if not isinstance(value, fields.Field):
                return make_hashable(value)
            if value.name and value.model:
                return (value.model._meta.label, value.name)
            return type(value)

        return (
            self.__class__,
            *((name, normalize(value)) for name, value in bound.arguments.items()),
        )

    def __eq__(self, other):
        if isinstance(other, Expression):
            return other.identity == self.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
498
499
# Type inference for CombinedExpression.output_field.
# Missing items will result in FieldError, by design.
#
# The current approach for NULL is based on lowest common denominator behavior
# i.e. if one of the supported databases is raising an error (rather than
# returning NULL) for `val <op> NULL`, then Plain raises FieldError.
#
# Each dict maps a connector to a list of (lhs type, rhs type, result type).

_connector_combinations = [
    # Numeric operations - operands of same type.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
            (fields.FloatField, fields.FloatField, fields.FloatField),
            (fields.DecimalField, fields.DecimalField, fields.DecimalField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            # Behavior for DIV with integer arguments follows Postgres/SQLite,
            # not MySQL/Oracle.
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
    },
    # Numeric operations - operands of different type.
    {
        connector: [
            (fields.IntegerField, fields.DecimalField, fields.DecimalField),
            (fields.DecimalField, fields.IntegerField, fields.DecimalField),
            (fields.IntegerField, fields.FloatField, fields.FloatField),
            (fields.FloatField, fields.IntegerField, fields.FloatField),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
        )
    },
    # Bitwise operators.
    {
        connector: [
            (fields.IntegerField, fields.IntegerField, fields.IntegerField),
        ]
        for connector in (
            Combinable.BITAND,
            Combinable.BITOR,
            Combinable.BITLEFTSHIFT,
            Combinable.BITRIGHTSHIFT,
            Combinable.BITXOR,
        )
    },
    # Numeric with NULL.
    {
        connector: [
            (field_type, NoneType, field_type),
            (NoneType, field_type, field_type),
        ]
        for connector in (
            Combinable.ADD,
            Combinable.SUB,
            Combinable.MUL,
            Combinable.DIV,
            Combinable.MOD,
            Combinable.POW,
        )
        for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
    },
    # Date/DateTimeField/DurationField/TimeField.
    {
        Combinable.ADD: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DurationField, fields.DateField, fields.DateTimeField),
            (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.DurationField, fields.TimeField, fields.TimeField),
        ],
    },
    {
        Combinable.SUB: [
            # Date/DateTimeField.
            (fields.DateField, fields.DurationField, fields.DateTimeField),
            (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
            (fields.DateField, fields.DateField, fields.DurationField),
            (fields.DateField, fields.DateTimeField, fields.DurationField),
            (fields.DateTimeField, fields.DateField, fields.DurationField),
            (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
            # DurationField.
            (fields.DurationField, fields.DurationField, fields.DurationField),
            # TimeField.
            (fields.TimeField, fields.DurationField, fields.TimeField),
            (fields.TimeField, fields.TimeField, fields.DurationField),
        ],
    },
]

# connector -> list of (lhs type, rhs type, result type), in registration order.
_connector_combinators = defaultdict(list)


@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
    """
    Return the field class produced by ``lhs_type <connector> rhs_type``.

    The first registered combinator whose lhs/rhs types are superclasses of
    the given types wins. Return None when no combinator matches.

    Results (including misses) are memoized; register_combinable_fields()
    clears the cache so later registrations take effect.
    """
    combinators = _connector_combinators.get(connector, ())
    for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
        if issubclass(lhs_type, combinator_lhs_type) and issubclass(
            rhs_type, combinator_rhs_type
        ):
            return combined_type
    return None


def register_combinable_fields(lhs, connector, rhs, result):
    """
    Register combinable types:
        lhs <connector> rhs -> result
    e.g.
        register_combinable_fields(
            IntegerField, Combinable.ADD, FloatField, FloatField
        )
    """
    _connector_combinators[connector].append((lhs, rhs, result))
    # _resolve_combined_type is memoized: without this, a lookup cached
    # before the registration (in particular a cached miss returning None)
    # would keep hiding the newly registered combination.
    _resolve_combined_type.cache_clear()


for d in _connector_combinations:
    for connector, field_types in d.items():
        for lhs, rhs, result in field_types:
            register_combinable_fields(lhs, connector, rhs, result)
633
634
class CombinedExpression(SQLiteNumericMixin, Expression):
    """
    A binary expression ``lhs <connector> rhs``.

    On resolution, a combination where exactly one operand is a
    DurationField becomes a DurationExpression, and a subtraction between
    two operands of the same Date/DateTime/Time type becomes a
    TemporalSubtraction.
    """

    def __init__(self, lhs, connector, rhs, output_field=None):
        super().__init__(output_field=output_field)
        self.connector = connector
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def __str__(self):
        return f"{self.lhs} {self.connector} {self.rhs}"

    def get_source_expressions(self):
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs):
        self.lhs, self.rhs = exprs

    def _resolve_output_field(self):
        """
        Infer the output field from the registered combinator table.

        Raises FieldError when no combination is registered for the operand
        types; callers must then pass output_field explicitly.
        """
        # We avoid using super() here for reasons given in
        # BaseExpression._resolve_output_field()
        lhs_type = type(self.lhs._output_field_or_none)
        rhs_type = type(self.rhs._output_field_or_none)
        combined_type = _resolve_combined_type(self.connector, lhs_type, rhs_type)
        if combined_type is None:
            # Report the types that were actually looked up. Re-reading
            # ``self.lhs.output_field`` here would instead raise an unrelated
            # "unknown output_field" FieldError whenever an operand's type
            # resolved to None (e.g. Value(None)).
            raise FieldError(
                f"Cannot infer type of {self.connector!r} expression involving these "
                f"types: {lhs_type.__name__}, {rhs_type.__name__}. You must set "
                f"output_field."
            )
        return combined_type()

    def as_sql(self, compiler, connection):
        expressions = []
        expression_params = []
        sql, params = compiler.compile(self.lhs)
        expressions.append(sql)
        expression_params.extend(params)
        sql, params = compiler.compile(self.rhs)
        expressions.append(sql)
        expression_params.extend(params)
        # Parenthesize so the combination keeps its precedence when nested.
        expression_wrapper = "(%s)"
        sql = connection.ops.combine_expression(self.connector, expressions)
        return expression_wrapper % sql, expression_params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        rhs = self.rhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(self, DurationExpression | TemporalSubtraction):
            # Operand types are best-effort: unresolvable ones become None.
            try:
                lhs_type = lhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                lhs_type = None
            try:
                rhs_type = rhs.output_field.get_internal_type()
            except (AttributeError, FieldError):
                rhs_type = None
            # Mixed duration / non-duration operands need duration arithmetic.
            if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
                return DurationExpression(
                    self.lhs, self.connector, self.rhs
                ).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
            # Same-type temporal subtraction yields a duration.
            datetime_fields = {"DateField", "DateTimeField", "TimeField"}
            if (
                self.connector == self.SUB
                and lhs_type in datetime_fields
                and lhs_type == rhs_type
            ):
                return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
                    query,
                    allow_joins,
                    reuse,
                    summarize,
                    for_save,
                )
        c = self.copy()
        c.is_summary = summarize
        c.lhs = lhs
        c.rhs = rhs
        return c
731
732
class DurationExpression(CombinedExpression):
    """
    Combination where one operand is a DurationField and the other is not.

    Backends with a native duration type use plain CombinedExpression SQL.
    Otherwise each DurationField operand is formatted with
    ``format_for_duration_arithmetic()`` and the pair is joined with
    ``combine_duration_expression()``.
    """

    def compile(self, side, compiler, connection):
        """Compile one operand, adapting it if it is a DurationField."""
        try:
            output = side.output_field
        except FieldError:
            # Unknown type: compile the operand as-is.
            return compiler.compile(side)
        if output.get_internal_type() != "DurationField":
            return compiler.compile(side)
        side_sql, side_params = compiler.compile(side)
        return connection.ops.format_for_duration_arithmetic(side_sql), side_params

    def as_sql(self, compiler, connection):
        if connection.features.has_native_duration_field:
            return super().as_sql(compiler, connection)
        connection.ops.check_expression_support(self)
        parts = []
        all_params = []
        for side in (self.lhs, self.rhs):
            side_sql, side_params = self.compile(side, compiler, connection)
            parts.append(side_sql)
            all_params.extend(side_params)
        combined = connection.ops.combine_duration_expression(self.connector, parts)
        # Parenthesize so the combination keeps its precedence when nested.
        return "(%s)" % combined, all_params

    def as_sqlite(self, compiler, connection, **extra_context):
        sql, params = self.as_sql(compiler, connection, **extra_context)
        if self.connector not in {Combinable.MUL, Combinable.DIV}:
            return sql, params
        try:
            lhs_type = self.lhs.output_field.get_internal_type()
            rhs_type = self.rhs.output_field.get_internal_type()
        except (AttributeError, FieldError):
            # Types unknown: skip validation.
            return sql, params
        # On SQLite, "*" and "/" are only accepted between numeric/duration
        # operands.
        numeric_types = {
            "DecimalField",
            "DurationField",
            "FloatField",
            "IntegerField",
        }
        if lhs_type not in numeric_types or rhs_type not in numeric_types:
            raise DatabaseError(
                f"Invalid arguments for operator {self.connector}."
            )
        return sql, params
782
783
class TemporalSubtraction(CombinedExpression):
    """
    ``lhs - rhs`` between two temporal values, producing a DurationField.

    CombinedExpression.resolve_expression() builds one when both operands
    share the same Date/DateTime/Time type. SQL generation is delegated to
    the backend's ``subtract_temporals()``.
    """

    output_field = fields.DurationField()

    def __init__(self, lhs, rhs):
        super().__init__(lhs, self.SUB, rhs)

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        compiled_lhs, compiled_rhs = (
            compiler.compile(side) for side in (self.lhs, self.rhs)
        )
        internal_type = self.lhs.output_field.get_internal_type()
        return connection.ops.subtract_temporals(
            internal_type, compiled_lhs, compiled_rhs
        )
797
798
@deconstructible(path="plain.models.F")
class F(Combinable):
    """
    A by-name reference to a field, resolved through ``query.resolve_ref()``
    when the expression is resolved.
    """

    def __init__(self, name):
        """
        Arguments:
         * name: the name of the field this expression references
        """
        self.name = name

    def __repr__(self):
        return f"{type(self).__name__}({self.name})"

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # for_save is accepted for interface compatibility but not forwarded.
        return query.resolve_ref(self.name, allow_joins, reuse, summarize)

    def replace_expressions(self, replacements):
        return replacements.get(self, self)

    def asc(self, **kwargs):
        return OrderBy(self, **kwargs)

    def desc(self, **kwargs):
        return OrderBy(self, descending=True, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class with the same name.
        if self.__class__ != other.__class__:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def copy(self):
        return copy.copy(self)
835
836
class ResolvedOuterRef(F):
    """
    A reference to a field of an outer query, produced once the inner query
    is used as a subquery. It cannot be compiled on its own.
    """

    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, *args, **kwargs):
        raise ValueError(
            "This queryset contains a reference to an outer query and may "
            "only be used in a subquery."
        )

    def resolve_expression(self, *args, **kwargs):
        resolved = super().resolve_expression(*args, **kwargs)
        if resolved.contains_over_clause:
            # Window expressions of the outer query can't be referenced.
            raise NotSupportedError(
                f"Referencing outer query window expression is not supported: "
                f"{self.name}."
            )
        # FIXME: Rename possibly_multivalued to multivalued and fix detection
        # for non-multivalued JOINs (e.g. foreign key fields). This should take
        # into account only many-to-many and one-to-many relationships.
        resolved.possibly_multivalued = LOOKUP_SEP in self.name
        return resolved

    def relabeled_clone(self, relabels):
        # Returned as-is, regardless of *relabels*.
        return self

    def get_group_by_cols(self):
        return []
872
873
class OuterRef(F):
    """
    A reference to a field of the outer query, for use inside a subquery.

    Resolving it yields a ResolvedOuterRef. When the wrapped name is itself
    an OuterRef (nested outer references), resolution returns that inner
    OuterRef instead.
    """

    contains_aggregate = False

    def resolve_expression(self, *args, **kwargs):
        target = self.name
        if not isinstance(target, self.__class__):
            return ResolvedOuterRef(target)
        return target

    def relabeled_clone(self, relabels):
        # Returned as-is, regardless of *relabels*.
        return self
884
885
@deconstructible(path="plain.models.Func")
class Func(SQLiteNumericMixin, Expression):
    """
    An SQL function call rendered from ``template``.

    The template is %-formatted with ``self.extra``, any extra context given
    to as_sql(), ``function`` and ``expressions``/``field`` (the compiled
    arguments joined by ``arg_joiner``).
    """

    function = None
    template = "%(function)s(%(expressions)s)"
    arg_joiner = ", "
    arity = None  # The number of arguments the function accepts.

    def __init__(self, *expressions, output_field=None, **extra):
        if self.arity is not None and len(expressions) != self.arity:
            noun = "argument" if self.arity == 1 else "arguments"
            raise TypeError(
                f"'{self.__class__.__name__}' takes exactly {self.arity} {noun} "
                f"({len(expressions)} given)"
            )
        super().__init__(output_field=output_field)
        self.source_expressions = self._parse_expressions(*expressions)
        self.extra = extra

    def __repr__(self):
        args = self.arg_joiner.join(map(str, self.source_expressions))
        options = {**self.extra, **self._get_repr_options()}
        if not options:
            return f"{self.__class__.__name__}({args})"
        rendered = ", ".join(
            f"{key!s}={val!s}" for key, val in sorted(options.items())
        )
        return f"{self.__class__.__name__}({args}, {rendered})"

    def _get_repr_options(self):
        """Return a dict of extra __init__() options to include in the repr."""
        return {}

    def get_source_expressions(self):
        return self.source_expressions

    def set_source_expressions(self, exprs):
        self.source_expressions = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        # copy() gave the clone its own list, so it is safe to fill in place.
        c.source_expressions[:] = [
            arg.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for arg in c.source_expressions
        ]
        return c

    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        connection.ops.check_expression_support(self)

        def compile_arg(arg):
            # An argument that matches nothing falls back to its
            # empty_result_set_value (if it defines one); an argument that
            # matches everything compiles as TRUE.
            try:
                return compiler.compile(arg)
            except EmptyResultSet:
                fallback = getattr(arg, "empty_result_set_value", NotImplemented)
                if fallback is NotImplemented:
                    raise
                return compiler.compile(Value(fallback))
            except FullResultSet:
                return compiler.compile(Value(True))

        compiled = [compile_arg(arg) for arg in self.source_expressions]
        sql_parts = [arg_sql for arg_sql, _ in compiled]
        params = [param for _, arg_params in compiled for param in arg_params]

        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is None:
            data.setdefault("function", self.function)
        else:
            data["function"] = function
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params

    def copy(self):
        # Give the clone its own argument list and extra dict.
        clone = super().copy()
        clone.source_expressions = self.source_expressions[:]
        clone.extra = self.extra.copy()
        return clone
984
985
986@deconstructible(path="plain.models.Value")
987class Value(SQLiteNumericMixin, Expression):
988 """Represent a wrapped value as a node within an expression."""
989
990 # Provide a default value for `for_save` in order to allow unresolved
991 # instances to be compiled until a decision is taken in #25425.
992 for_save = False
993
994 def __init__(self, value, output_field=None):
995 """
996 Arguments:
997 * value: the value this expression represents. The value will be
998 added into the sql parameter list and properly quoted.
999
1000 * output_field: an instance of the model field type that this
1001 expression will return, such as IntegerField() or CharField().
1002 """
1003 super().__init__(output_field=output_field)
1004 self.value = value
1005
1006 def __repr__(self):
1007 return f"{self.__class__.__name__}({self.value!r})"
1008
1009 def as_sql(self, compiler, connection):
1010 connection.ops.check_expression_support(self)
1011 val = self.value
1012 output_field = self._output_field_or_none
1013 if output_field is not None:
1014 if self.for_save:
1015 val = output_field.get_db_prep_save(val, connection=connection)
1016 else:
1017 val = output_field.get_db_prep_value(val, connection=connection)
1018 if hasattr(output_field, "get_placeholder"):
1019 return output_field.get_placeholder(val, compiler, connection), [val]
1020 if val is None:
1021 # cx_Oracle does not always convert None to the appropriate
1022 # NULL type (like in case expressions using numbers), so we
1023 # use a literal SQL NULL
1024 return "NULL", []
1025 return "%s", [val]
1026
1027 def resolve_expression(
1028 self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
1029 ):
1030 c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
1031 c.for_save = for_save
1032 return c
1033
1034 def get_group_by_cols(self):
1035 return []
1036
1037 def _resolve_output_field(self):
1038 if isinstance(self.value, str):
1039 return fields.CharField()
1040 if isinstance(self.value, bool):
1041 return fields.BooleanField()
1042 if isinstance(self.value, int):
1043 return fields.IntegerField()
1044 if isinstance(self.value, float):
1045 return fields.FloatField()
1046 if isinstance(self.value, datetime.datetime):
1047 return fields.DateTimeField()
1048 if isinstance(self.value, datetime.date):
1049 return fields.DateField()
1050 if isinstance(self.value, datetime.time):
1051 return fields.TimeField()
1052 if isinstance(self.value, datetime.timedelta):
1053 return fields.DurationField()
1054 if isinstance(self.value, Decimal):
1055 return fields.DecimalField()
1056 if isinstance(self.value, bytes):
1057 return fields.BinaryField()
1058 if isinstance(self.value, UUID):
1059 return fields.UUIDField()
1060
1061 @property
1062 def empty_result_set_value(self):
1063 return self.value
1064
1065
class RawSQL(Expression):
    """A literal SQL fragment plus its parameters, emitted inside parentheses."""

    def __init__(self, sql, params, output_field=None):
        # Without an explicit type, fall back to a generic Field().
        field = fields.Field() if output_field is None else output_field
        self.sql = sql
        self.params = params
        super().__init__(output_field=field)

    def __repr__(self):
        return "{}({}, {})".format(type(self).__name__, self.sql, self.params)

    def as_sql(self, compiler, connection):
        wrapped = "({})".format(self.sql)
        return wrapped, self.params

    def get_group_by_cols(self):
        return [self]
1081
1082
class Star(Expression):
    """Render as the bare SQL ``*`` token, with no parameters."""

    def as_sql(self, compiler, connection):
        sql, params = "*", []
        return sql, params

    def __repr__(self):
        return "'*'"
1089
1090
class Col(Expression):
    """A reference to a field's column, optionally qualified by a table alias."""

    contains_column_references = True
    possibly_multivalued = False

    def __init__(self, alias, target, output_field=None):
        # The target field doubles as the output type unless one is given.
        super().__init__(
            output_field=target if output_field is None else output_field
        )
        self.alias = alias
        self.target = target

    def __repr__(self):
        parts = [str(self.target)]
        if self.alias:
            parts.insert(0, self.alias)
        return "{}({})".format(type(self).__name__, ", ".join(parts))

    def as_sql(self, compiler, connection):
        column = self.target.column
        names = [self.alias, column] if self.alias else [column]
        quoted = [compiler.quote_name_unless_alias(name) for name in names]
        return ".".join(quoted), []

    def relabeled_clone(self, relabels):
        # Unaliased columns have nothing to relabel.
        if self.alias is None:
            return self
        new_alias = relabels.get(self.alias, self.alias)
        return self.__class__(new_alias, self.target, self.output_field)

    def get_group_by_cols(self):
        return [self]

    def get_db_converters(self, connection):
        same_field = self.target == self.output_field
        converters = self.output_field.get_db_converters(connection)
        if not same_field:
            # Apply the output field's converters first, then the target's.
            converters = converters + self.target.get_db_converters(connection)
        return converters
1128
1129
class Ref(Expression):
    """
    Refer to a column alias already selected by the query, e.g.
    ``Ref('sum_cost')`` for the alias created by
    ``qs.annotate(sum_cost=Sum('cost'))``.
    """

    def __init__(self, refs, source):
        super().__init__()
        self.refs = refs
        self.source = source

    def __repr__(self):
        return "{}({}, {})".format(type(self).__name__, self.refs, self.source)

    def get_source_expressions(self):
        return [self.source]

    def set_source_expressions(self, exprs):
        [self.source] = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # `source` was resolved when the alias was created; this node only
        # names it, so there is nothing further to resolve.
        return self

    def get_refs(self):
        return {self.refs}

    def relabeled_clone(self, relabels):
        return self

    def as_sql(self, compiler, connection):
        quoted = connection.ops.quote_name(self.refs)
        return quoted, []

    def get_group_by_cols(self):
        return [self]
1167
1168
class ExpressionList(Func):
    """
    Several expressions rendered as one, joined by ``arg_joiner``. Useful for
    passing a list of expressions as a single argument to another expression,
    such as a partition clause.
    """

    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) == 0:
            raise ValueError(
                f"{type(self).__name__} requires at least one expression."
            )
        super().__init__(*expressions, **extra)

    def __str__(self):
        return self.arg_joiner.join(map(str, self.source_expressions))

    def as_sqlite(self, compiler, connection, **extra_context):
        # No numeric cast is needed for a bare list; render it normally.
        return self.as_sql(compiler, connection, **extra_context)
1191
1192
class OrderByList(Func):
    """
    A list of orderings rendered as an ``ORDER BY`` clause.

    String arguments beginning with ``-`` (e.g. ``"-created_at"``) are
    shorthand for a descending ``OrderBy`` on the named field. Every other
    argument is passed to ``Func`` unchanged. With no arguments the clause
    renders as an empty string.
    """

    template = "ORDER BY %(expressions)s"

    def __init__(self, *expressions, **extra):
        # str.startswith() (rather than expr[0]) keeps an empty string from
        # failing here with a bare IndexError. It is passed through like any
        # other string argument instead.
        expressions = (
            (
                OrderBy(F(expr[1:]), descending=True)
                if isinstance(expr, str) and expr.startswith("-")
                else expr
            )
            for expr in expressions
        )
        super().__init__(*expressions, **extra)

    def as_sql(self, *args, **kwargs):
        # An empty ordering list contributes no SQL at all.
        if not self.source_expressions:
            return "", ()
        return super().as_sql(*args, **kwargs)

    def get_group_by_cols(self):
        return [
            col
            for order_by in self.get_source_expressions()
            for col in order_by.get_group_by_cols()
        ]
1217
1218
@deconstructible(path="plain.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
    """
    Wrap another expression in order to supply extra context for it, most
    notably an explicit ``output_field``. The wrapped expression's SQL is
    emitted unchanged.
    """

    def __init__(self, expression, output_field):
        super().__init__(output_field=output_field)
        self.expression = expression

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_group_by_cols(self):
        if not isinstance(self.expression, Expression):
            # A non-expression (e.g. an SQL WHERE node) must be grouped on as
            # a whole, so defer to the default behavior for this wrapper.
            return super().get_group_by_cols()
        inner = self.expression.copy()
        inner.output_field = self.output_field
        return inner.get_group_by_cols()

    def as_sql(self, compiler, connection):
        return compiler.compile(self.expression)

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.expression)
1250
1251
class NegatedExpression(ExpressionWrapper):
    """Boolean NOT applied to a conditional expression."""

    def __init__(self, expression):
        super().__init__(expression, output_field=fields.BooleanField())

    def __invert__(self):
        # Double negation yields (a copy of) the original expression.
        return self.expression.copy()

    def as_sql(self, compiler, connection):
        try:
            sql, params = super().as_sql(compiler, connection)
        except EmptyResultSet:
            # The inner condition matches nothing, so its negation is
            # always true.
            if compiler.connection.features.supports_boolean_expr_in_select_clause:
                return compiler.compile(Value(True))
            return "1=1", ()
        ops = compiler.connection.ops
        if ops.conditional_expression_supported_in_where_clause(self.expression):
            return f"NOT {sql}", params
        # Some backends (e.g. Oracle) only allow EXISTS() and filters to be
        # compared with another expression when wrapped in CASE WHEN.
        return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        is_conditional = getattr(resolved.expression, "conditional", False)
        if not is_conditional:
            raise TypeError("Cannot negate non-conditional expressions.")
        return resolved

    def select_format(self, compiler, sql, params):
        # Backends without boolean expressions in SELECT/GROUP BY (e.g.
        # Oracle) get a CASE WHEN wrapper, unless as_sql() already produced
        # one (i.e. the expression isn't supported in a WHERE clause).
        conn = compiler.connection
        supported_in_where = conn.ops.conditional_expression_supported_in_where_clause
        needs_case = (
            not conn.features.supports_boolean_expr_in_select_clause
            and supported_in_where(self.expression)
        )
        if needs_case:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
1301
1302
@deconstructible(path="plain.models.When")
class When(Expression):
    """One ``WHEN ... THEN ...`` branch of a Case() expression."""

    template = "WHEN %(condition)s THEN %(result)s"
    # Not a complete conditional expression on its own; use it in Case().
    conditional = False

    def __init__(self, condition=None, then=None, **lookups):
        # Keyword lookups are folded into a Q object, combined with
        # `condition` when that is itself conditional.
        if lookups:
            if condition is None:
                condition, lookups = Q(**lookups), None
            elif getattr(condition, "conditional", False):
                condition, lookups = Q(condition, **lookups), None
        is_valid_condition = (
            condition is not None
            and getattr(condition, "conditional", False)
            and not lookups
        )
        if not is_valid_condition:
            raise TypeError(
                "When() supports a Q object, a boolean expression, or lookups "
                "as a condition."
            )
        if isinstance(condition, Q) and not condition:
            raise ValueError("An empty Q() can't be used as a When() condition.")
        super().__init__(output_field=None)
        self.condition = condition
        self.result = self._parse_expressions(then)[0]

    def __str__(self):
        return "WHEN {!r} THEN {!r}".format(self.condition, self.result)

    def __repr__(self):
        return f"<{type(self).__name__}: {self}>"

    def get_source_expressions(self):
        return [self.condition, self.result]

    def set_source_expressions(self, exprs):
        self.condition, self.result = exprs

    def get_source_fields(self):
        # Only the result's type matters for the enclosing Case's output.
        return [self.result._output_field_or_none]

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        if hasattr(clone.condition, "resolve_expression"):
            # The condition is always resolved with for_save=False.
            clone.condition = clone.condition.resolve_expression(
                query, allow_joins, reuse, summarize, False
            )
        clone.result = clone.result.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        condition_sql, condition_params = compiler.compile(self.condition)
        result_sql, result_params = compiler.compile(self.result)
        extra_context["condition"] = condition_sql
        extra_context["result"] = result_sql
        sql = (template or self.template) % extra_context
        return sql, (*condition_params, *result_params)

    def get_group_by_cols(self):
        # Group by whatever the condition and the result reference.
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]
1377
1378
@deconstructible(path="plain.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    A searched SQL CASE expression built from When() branches, e.g.:

        CASE
            WHEN n > 0 THEN 'positive'
            WHEN n < 0 THEN 'negative'
            ELSE 'zero'
        END
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if any(not isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        branches = ", ".join(map(str, self.cases))
        return f"CASE {branches}, ELSE {self.default!r}"

    def __repr__(self):
        return f"<{type(self).__name__}: {self}>"

    def get_source_expressions(self):
        return [*self.cases, self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        clone = self.copy()
        clone.is_summary = summarize
        clone.cases = [
            case.resolve_expression(query, allow_joins, reuse, summarize, for_save)
            for case in clone.cases
        ]
        clone.default = clone.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return clone

    def copy(self):
        clone = super().copy()
        # Give the clone its own list of branches.
        clone.cases = list(clone.cases)
        return clone

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        connection.ops.check_expression_support(self)
        # Without any WHEN branch the whole CASE reduces to its default.
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        default_sql, default_params = compiler.compile(self.default)
        when_sqls = []
        params = []
        for case in self.cases:
            try:
                when_sql, when_params = compiler.compile(case)
            except EmptyResultSet:
                # This branch can never match; leave it out.
                continue
            except FullResultSet:
                # This branch always matches: its result becomes the ELSE
                # value and any later branches are unreachable.
                default_sql, default_params = compiler.compile(case.result)
                break
            when_sqls.append(when_sql)
            params.extend(when_params)
        if not when_sqls:
            return default_sql, default_params
        joiner = case_joiner or self.case_joiner
        template_params["cases"] = joiner.join(when_sqls)
        template_params["default"] = default_sql
        params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, params

    def get_group_by_cols(self):
        if self.cases:
            return super().get_group_by_cols()
        return self.default.get_group_by_cols()
1474
1475
class Subquery(BaseExpression, Combinable):
    """
    An explicit subquery. It may contain OuterRef() references to the outer
    query, which are resolved when the subquery is applied to that query.
    """

    template = "(%(subquery)s)"
    contains_aggregate = False
    empty_result_set_value = None

    def __init__(self, queryset, output_field=None, **extra):
        # Accept both QuerySet objects (via their .query) and sql.Query objects.
        source_query = getattr(queryset, "query", queryset)
        self.query = source_query.clone()
        self.query.subquery = True
        self.extra = extra
        super().__init__(output_field)

    def get_source_expressions(self):
        return [self.query]

    def set_source_expressions(self, exprs):
        self.query = exprs[0]

    def _resolve_output_field(self):
        return self.query.output_field

    def copy(self):
        duplicate = super().copy()
        duplicate.query = duplicate.query.clone()
        return duplicate

    @property
    def external_aliases(self):
        return self.query.external_aliases

    def get_external_cols(self):
        return self.query.get_external_cols()

    def as_sql(self, compiler, connection, template=None, **extra_context):
        connection.ops.check_expression_support(self)
        template_params = {**self.extra, **extra_context}
        query_sql, query_params = self.query.as_sql(compiler, connection)
        # Drop the first and last characters of the query's SQL; presumably
        # the surrounding parentheses, since the template supplies its own.
        template_params["subquery"] = query_sql[1:-1]
        template = template or template_params.get("template", self.template)
        return template % template_params, query_params

    def get_group_by_cols(self):
        return self.query.get_group_by_cols(wrapper=self)
1526
1527
class Exists(Subquery):
    """An ``EXISTS(...)`` test over a subquery; evaluates to a boolean."""

    template = "EXISTS(%(subquery)s)"
    output_field = fields.BooleanField()
    empty_result_set_value = False

    def __init__(self, queryset, **kwargs):
        super().__init__(queryset, **kwargs)
        self.query = self.query.exists()

    def select_format(self, compiler, sql, params):
        if compiler.connection.features.supports_boolean_expr_in_select_clause:
            return sql, params
        # Backends (e.g. Oracle) that can't put a boolean expression in the
        # SELECT or GROUP BY list get a CASE WHEN wrapper instead.
        return f"CASE WHEN {sql} THEN 1 ELSE 0 END", params
1544
1545
@deconstructible(path="plain.models.OrderBy")
class OrderBy(Expression):
    """
    Sort on an expression, ascending or descending, optionally forcing NULLs
    first or last.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            type(self).__name__, self.expression, self.descending
        )

    def get_source_expressions(self):
        return [self.expression]

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        features = connection.features
        if features.supports_order_by_nulls_modifier:
            if self.nulls_last:
                template += " NULLS LAST"
            elif self.nulls_first:
                template += " NULLS FIRST"
        # Without NULLS FIRST/LAST support, emulate them by first sorting on
        # an IS [NOT] NULL test. This is skipped when the backend's default
        # NULL placement (per `order_by_nulls_first`) already gives the order.
        elif self.nulls_last and (
            not self.descending or not features.order_by_nulls_first
        ):
            template = "%(expression)s IS NULL, " + template
        elif self.nulls_first and (
            self.descending or not features.order_by_nulls_first
        ):
            template = "%(expression)s IS NOT NULL, " + template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression may appear more than once in the template, so its
        # parameters are repeated once per occurrence.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def get_group_by_cols(self):
        return [
            col
            for source in self.get_source_expressions()
            for col in source.get_group_by_cols()
        ]

    def reverse_ordering(self):
        """Flip the direction and swap any explicit NULL placement, in place."""
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_first, self.nulls_last = None, True
        elif self.nulls_last:
            self.nulls_first, self.nulls_last = True, None
        return self

    def asc(self):
        self.descending = False

    def desc(self):
        self.descending = True
1619
1620
class Window(SQLiteNumericMixin, Expression):
    """
    An expression evaluated over a window: ``<expression> OVER (...)`` with
    optional PARTITION BY, ORDER BY and frame clauses.
    """

    template = "%(expression)s OVER (%(window)s)"
    # Although the main expression may either be an aggregate or an
    # expression with an aggregate function, the GROUP BY that will
    # be introduced in the query as a result is not desired.
    contains_aggregate = False
    contains_over_clause = True

    def __init__(
        self,
        expression,
        partition_by=None,
        order_by=None,
        frame=None,
        output_field=None,
    ):
        """
        Arguments:
         * expression: must have a truthy ``window_compatible`` attribute.
         * partition_by: a single item or a tuple/list of items, wrapped in
           an ExpressionList.
         * order_by: an expression, a field-name string (prefix "-" for
           descending), or a list/tuple of them, wrapped in an OrderByList.
         * frame: optional frame clause (e.g. RowRange or ValueRange).
         * output_field: optional; otherwise taken from ``expression``.

        Raises ValueError for an incompatible expression or an order_by of
        an unsupported type.
        """
        self.partition_by = partition_by
        self.order_by = order_by
        self.frame = frame

        if not getattr(expression, "window_compatible", False):
            raise ValueError(
                f"Expression '{expression.__class__.__name__}' isn't compatible with OVER clauses."
            )

        if self.partition_by is not None:
            if not isinstance(self.partition_by, tuple | list):
                self.partition_by = (self.partition_by,)
            self.partition_by = ExpressionList(*self.partition_by)

        if self.order_by is not None:
            if isinstance(self.order_by, list | tuple):
                self.order_by = OrderByList(*self.order_by)
            elif isinstance(self.order_by, BaseExpression | str):
                self.order_by = OrderByList(self.order_by)
            else:
                raise ValueError(
                    "Window.order_by must be either a string reference to a "
                    "field, an expression, or a list or tuple of them."
                )
        super().__init__(output_field=output_field)
        self.source_expression = self._parse_expressions(expression)[0]

    def _resolve_output_field(self):
        return self.source_expression.output_field

    def get_source_expressions(self):
        return [self.source_expression, self.partition_by, self.order_by, self.frame]

    def set_source_expressions(self, exprs):
        self.source_expression, self.partition_by, self.order_by, self.frame = exprs

    def as_sql(self, compiler, connection, template=None):
        connection.ops.check_expression_support(self)
        if not connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        expr_sql, params = compiler.compile(self.source_expression)
        window_sql, window_params = [], ()

        if self.partition_by is not None:
            sql_expr, sql_params = self.partition_by.as_sql(
                compiler=compiler,
                connection=connection,
                template="PARTITION BY %(expressions)s",
            )
            window_sql.append(sql_expr)
            window_params += tuple(sql_params)

        if self.order_by is not None:
            order_sql, order_params = compiler.compile(self.order_by)
            window_sql.append(order_sql)
            window_params += tuple(order_params)

        if self.frame:
            frame_sql, frame_params = compiler.compile(self.frame)
            window_sql.append(frame_sql)
            window_params += tuple(frame_params)

        template = template or self.template

        return (
            template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
            (*params, *window_params),
        )

    def as_sqlite(self, compiler, connection):
        if isinstance(self.output_field, fields.DecimalField):
            # Casting to numeric must be outside of the window expression, so
            # compile a clone whose inner expression reports FloatField. The
            # inner expression is copied as well. Assigning output_field on
            # the instance this Window references would otherwise leak the
            # FloatField override into this Window's own expression tree.
            clone = self.copy()
            source_expressions = clone.get_source_expressions()
            inner = source_expressions[0].copy()
            inner.output_field = fields.FloatField()
            source_expressions[0] = inner
            clone.set_source_expressions(source_expressions)
            return super(Window, clone).as_sqlite(compiler, connection)
        return self.as_sql(compiler, connection)

    def __str__(self):
        return "{} OVER ({}{}{})".format(
            str(self.source_expression),
            "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
            str(self.order_by or ""),
            str(self.frame or ""),
        )

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self}>"

    def get_group_by_cols(self):
        group_by_cols = []
        if self.partition_by:
            group_by_cols.extend(self.partition_by.get_group_by_cols())
        if self.order_by is not None:
            group_by_cols.extend(self.order_by.get_group_by_cols())
        return group_by_cols
1734
1735
class WindowFrame(Expression):
    """
    Base for the frame clause of a window expression (``ROWS``/``RANGE``
    ``BETWEEN ... AND ...``). Subclasses set ``frame_type`` and implement
    window_frame_start_end(). The shared handling lives here and makes no
    claim to complete validation.

    Both bounds are optional. A missing start is shown as UNBOUNDED
    PRECEDING and a missing end as UNBOUNDED FOLLOWING.
    """

    template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

    def __init__(self, start=None, end=None):
        self.start = Value(start)
        self.end = Value(end)

    def get_source_expressions(self):
        return [self.start, self.end]

    def set_source_expressions(self, exprs):
        self.start, self.end = exprs

    def as_sql(self, compiler, connection):
        connection.ops.check_expression_support(self)
        # The backend turns the raw bound values into SQL keywords.
        start, end = self.window_frame_start_end(
            connection, self.start.value, self.end.value
        )
        substitutions = {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }
        return self.template % substitutions, []

    def __repr__(self):
        return f"<{type(self).__name__}: {self}>"

    def get_group_by_cols(self):
        return []

    def __str__(self):
        # Negative start = N PRECEDING, zero = CURRENT ROW,
        # positive end = N FOLLOWING; anything else is unbounded.
        ops = connection.ops
        lower = self.start.value
        upper = self.end.value

        if lower is None:
            start = ops.UNBOUNDED_PRECEDING
        elif lower < 0:
            start = "%d %s" % (abs(lower), ops.PRECEDING)  # noqa: UP031
        elif lower == 0:
            start = ops.CURRENT_ROW
        else:
            start = ops.UNBOUNDED_PRECEDING

        if upper is None:
            end = ops.UNBOUNDED_FOLLOWING
        elif upper > 0:
            end = "%d %s" % (upper, ops.FOLLOWING)  # noqa: UP031
        elif upper == 0:
            end = ops.CURRENT_ROW
        else:
            end = ops.UNBOUNDED_FOLLOWING

        return self.template % {
            "frame_type": self.frame_type,
            "start": start,
            "end": end,
        }

    def window_frame_start_end(self, connection, start, end):
        raise NotImplementedError("Subclasses must implement window_frame_start_end().")
1800
1801
class RowRange(WindowFrame):
    """A ``ROWS`` window frame."""

    frame_type = "ROWS"

    def window_frame_start_end(self, connection, start, end):
        # Delegate bound rendering to the backend's ROWS-frame helper.
        ops = connection.ops
        return ops.window_frame_rows_start_end(start, end)
1807
1808
class ValueRange(WindowFrame):
    """A ``RANGE`` window frame."""

    frame_type = "RANGE"

    def window_frame_start_end(self, connection, start, end):
        # Delegate bound rendering to the backend's RANGE-frame helper.
        ops = connection.ops
        return ops.window_frame_range_start_end(start, end)