1"""
2Create SQL statements for QuerySets.
3
4The code in here encapsulates all of the SQL construction so that QuerySets
5themselves do not have to (and could be backed by things other than SQL
6databases). The abstraction barrier only works one way: this module has to know
7all about the internals of models in order to get the information it needs.
8"""
9
10from __future__ import annotations
11
12import copy
13import difflib
14import functools
15import sys
16from collections import Counter
17from collections.abc import Callable, Iterable, Iterator, Mapping
18from collections.abc import Iterator as TypingIterator
19from functools import cached_property
20from itertools import chain, count, product
21from string import ascii_uppercase
22from typing import (
23 TYPE_CHECKING,
24 Any,
25 Literal,
26 NamedTuple,
27 Self,
28 TypeVar,
29 cast,
30 overload,
31)
32
33from plain.models.aggregates import Count
34from plain.models.constants import LOOKUP_SEP
35from plain.models.db import NotSupportedError, db_connection
36from plain.models.exceptions import FieldDoesNotExist, FieldError
37from plain.models.expressions import (
38 BaseExpression,
39 Col,
40 Exists,
41 F,
42 OuterRef,
43 Ref,
44 ResolvableExpression,
45 ResolvedOuterRef,
46 Value,
47)
48from plain.models.fields import Field
49from plain.models.fields.related_lookups import MultiColSource
50from plain.models.lookups import Lookup
51from plain.models.query_utils import (
52 PathInfo,
53 Q,
54 check_rel_lookup_compatibility,
55 refs_expression,
56)
57from plain.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
58from plain.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin
59from plain.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
60from plain.utils.regex_helper import _lazy_re_compile
61from plain.utils.tree import Node
62
63if TYPE_CHECKING:
64 from plain.models import Model
65 from plain.models.backends.base.base import BaseDatabaseWrapper
66 from plain.models.fields.related import RelatedField
67 from plain.models.fields.reverse_related import ForeignObjectRel
68 from plain.models.meta import Meta
69 from plain.models.sql.compiler import SQLCompiler, SqlWithParams
70
# Public API of this module.
__all__ = ["Query", "RawQuery"]

# Quotation marks ('"`[]), whitespace characters, semicolons, or inline
# SQL comments are forbidden in column aliases (guards against SQL injection
# through user-supplied alias names).
FORBIDDEN_ALIAS_PATTERN = _lazy_re_compile(r"['`\"\]\[;\s]|--|/\*|\*/")

# Pattern for valid EXPLAIN option names. Inspired from
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
EXPLAIN_OPTIONS_PATTERN = _lazy_re_compile(r"[\w\-]+")
80
81
def get_field_names_from_opts(meta: Meta | None) -> set[str]:
    """Return every referenceable field name from model meta.

    Concrete fields contribute both their ``name`` and ``attname``;
    non-concrete fields contribute only their ``name``. An absent meta
    yields an empty set.
    """
    if meta is None:
        return set()
    names: set[str] = set()
    for field in meta.get_fields():
        names.add(field.name)
        if field.concrete:
            names.add(field.attname)
    return names
90
91
def get_children_from_q(q: Q) -> TypingIterator[tuple[str, Any]]:
    """Yield the leaf (lookup, value) pairs of a Q tree, depth-first.

    Interior Node instances are recursed into; everything else is assumed
    to be a leaf child and yielded as-is.
    """
    for child in q.children:
        if not isinstance(child, Node):
            yield child
        else:
            yield from get_children_from_q(child)
98
99
class JoinInfo(NamedTuple):
    """Information about a join operation in a query."""

    # The last field traversed in the lookup path.
    final_field: Field[Any]
    # The fields on the far side of the final join.
    targets: tuple[Field[Any], ...]
    # Meta of the model reached at the end of the path.
    meta: Meta
    # Table aliases created/reused while walking the path.
    joins: list[str]
    # The PathInfo steps describing each relational hop.
    path: list[PathInfo]
    # Callable mapping (field, alias) to the expression used for the target.
    transform_function: Callable[[Field[Any], str | None], BaseExpression]
109
110
class RawQuery:
    """A single raw SQL query.

    Wraps a raw SQL string plus its parameters, and mirrors just enough of
    the regular Query interface (low_mark/high_mark, extra_select,
    annotation_select) for the SQL compiler machinery to process results.
    """

    def __init__(
        self, sql: str, params: tuple[Any, ...] | dict[str, Any] | None = ()
    ):
        # Fix: the annotation now includes None -- params_type and
        # _execute_query() explicitly support params being None.
        self.params = params
        self.sql = sql
        # Populated lazily by _execute_query().
        self.cursor: Any = None

        # Mirror some properties of a normal query so that
        # the compiler can be used to process results.
        self.low_mark, self.high_mark = 0, None  # Used for offset/limit
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self) -> RawQuery:
        """Return a copy ready for further modification."""
        return self.clone()

    def clone(self) -> RawQuery:
        """Return a fresh RawQuery with the same SQL and parameters."""
        return RawQuery(self.sql, params=self.params)

    def get_columns(self) -> list[str]:
        """Return result-set column names, executing the query if needed."""
        if self.cursor is None:
            self._execute_query()
        converter = db_connection.introspection.identifier_converter
        return [
            converter(column_model_meta[0])
            for column_model_meta in self.cursor.description
        ]

    def __iter__(self) -> TypingIterator[Any]:
        # Always execute a new query for a new iterator.
        # This could be optimized with a cache at the expense of RAM.
        self._execute_query()
        if not db_connection.features.can_use_chunked_reads:
            # If the database can't use chunked reads we need to make sure we
            # evaluate the entire query up front.
            result = list(self.cursor)
        else:
            result = self.cursor
        return iter(result)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    @property
    def params_type(self) -> type[dict] | type[tuple] | None:
        """Return dict or tuple matching the params style, or None."""
        if self.params is None:
            return None
        return dict if isinstance(self.params, Mapping) else tuple

    def __str__(self) -> str:
        if self.params_type is None:
            return self.sql
        return self.sql % self.params_type(self.params)

    def _execute_query(self) -> None:
        """Run self.sql on a fresh cursor, adapting parameters best-effort."""
        # Adapt parameters to the database, as much as possible considering
        # that the target type isn't known. See #17755.
        params_type = self.params_type
        adapter = db_connection.ops.adapt_unknown_value
        if params_type is tuple:
            assert isinstance(self.params, tuple)
            params = tuple(adapter(val) for val in self.params)
        elif params_type is dict:
            assert isinstance(self.params, dict)
            params = {key: adapter(val) for key, val in self.params.items()}
        elif params_type is None:
            params = None
        else:
            raise RuntimeError(f"Unexpected params type: {params_type}")

        self.cursor = db_connection.cursor()
        self.cursor.execute(self.sql, params)
184
185
class ExplainInfo(NamedTuple):
    """Information about an EXPLAIN query."""

    # Output format requested by the caller (e.g. backend-specific), or None.
    format: str | None
    # Additional backend-specific EXPLAIN options, keyed by option name.
    options: dict[str, Any]
191
192
class TransformWrapper:
    """Wrapper for transform functions that supports the has_transforms attribute.

    This replaces functools.partial for transform functions, allowing proper
    type checking while supporting dynamic attribute assignment.
    """

    def __init__(
        self,
        func: Callable[..., BaseExpression],
        **kwargs: Any,
    ):
        # Keep the callable and its bound keyword arguments separately;
        # calling is equivalent to functools.partial(func, **kwargs).
        self._func = func
        self._bound_kwargs = kwargs
        self.has_transforms: bool = False

    def __call__(self, field: Field[Any], alias: str | None) -> BaseExpression:
        return self._func(field, alias, **self._bound_kwargs)
210
211
212QueryType = TypeVar("QueryType", bound="Query")
213
214
class Query(BaseExpression):
    """A single SQL query.

    Holds the full state needed to build one SELECT (or, via subclasses,
    UPDATE/DELETE/aggregate) statement: joins, where clause, annotations,
    ordering, limits and select-related configuration.
    """

    # Prefix for generated table aliases (T1, T2, ...).
    alias_prefix = "T"
    empty_result_set_value = None
    subq_aliases = frozenset([alias_prefix])

    base_table_class = BaseTable
    join_class = Join

    default_cols = True
    default_ordering = True
    standard_ordering = True

    filter_is_sticky = False
    subquery = False

    # SQL-related attributes.
    # Select and related select clauses are expressions to use in the SELECT
    # clause of the query. The select is used for cases where we want to set up
    # the select clause to contain other than default fields (values(),
    # subqueries...). Note that annotations go to annotations dictionary.
    select = ()
    # The group_by attribute can have one of the following forms:
    # - None: no group by at all in the query
    # - A tuple of expressions: group by (at least) those expressions.
    # String refs are also allowed for now.
    # - True: group by all select fields of the model
    # See compiler.get_group_by() for details.
    group_by = None
    order_by = ()
    low_mark = 0  # Used for offset/limit.
    high_mark = None  # Used for offset/limit.
    distinct = False
    distinct_fields = ()
    select_for_update = False
    select_for_update_nowait = False
    select_for_update_skip_locked = False
    select_for_update_of = ()
    select_for_no_key_update = False
    select_related: bool | dict[str, Any] = False
    has_select_fields = False
    # Arbitrary limit for select_related to prevent infinite recursion.
    max_depth = 5
    # Holds the selects defined by a call to values() or values_list()
    # excluding annotation_select and extra_select.
    values_select = ()

    # SQL annotation-related attributes.
    annotation_select_mask = None
    _annotation_select_cache = None

    # These are for extensions. The contents are more or less appended verbatim
    # to the appropriate clause.
    extra_select_mask = None
    _extra_select_cache = None

    extra_tables = ()
    extra_order_by = ()

    # A tuple that is a set of model field names and either True, if these are
    # the fields to defer, or False if these are the only fields to load.
    deferred_loading = (frozenset(), True)

    explain_info = None
280
    def __init__(self, model: type[Model] | None, alias_cols: bool = True):
        """Initialize query state for ``model``.

        model may be None; operations that require one assert before use.
        alias_cols controls whether resolved column references carry a
        table alias.
        """
        self.model = model
        self.alias_refcount = {}
        # alias_map is the most important data structure regarding joins.
        # It's used for recording which joins exist in the query and what
        # types they are. The key is the alias of the joined table (possibly
        # the table name) and the value is a Join-like object (see
        # sql.datastructures.Join for more information).
        self.alias_map = {}
        # Whether to provide alias to columns during reference resolving.
        self.alias_cols = alias_cols
        # Sometimes the query contains references to aliases in outer queries (as
        # a result of split_exclude). Correct alias quoting needs to know these
        # aliases too.
        # Map external tables to whether they are aliased.
        self.external_aliases = {}
        self.table_map = {}  # Maps table names to list of aliases.
        self.used_aliases = set()

        self.where = WhereNode()
        # Maps alias -> Annotation Expression.
        self.annotations = {}
        # These are for extensions. The contents are more or less appended
        # verbatim to the appropriate clause.
        self.extra = {}  # Maps col_alias -> (col_sql, params).

        self._filtered_relations = {}
308
309 @property
310 def output_field(self) -> Field | None:
311 if len(self.select) == 1:
312 select = self.select[0]
313 return getattr(select, "target", None) or select.field
314 elif len(self.annotation_select) == 1:
315 return next(iter(self.annotation_select.values())).output_field
316
317 @cached_property
318 def base_table(self) -> str | None:
319 for alias in self.alias_map:
320 return alias
321
322 def __str__(self) -> str:
323 """
324 Return the query as a string of SQL with the parameter values
325 substituted in (use sql_with_params() to see the unsubstituted string).
326
327 Parameter values won't necessarily be quoted correctly, since that is
328 done by the database interface at execution time.
329 """
330 sql, params = self.sql_with_params()
331 return sql % params
332
333 def sql_with_params(self) -> SqlWithParams:
334 """
335 Return the query as an SQL string and the parameters that will be
336 substituted into the query.
337 """
338 return self.get_compiler().as_sql()
339
340 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
341 """Limit the amount of work when a Query is deepcopied."""
342 result = self.clone()
343 memo[id(self)] = result
344 return result
345
346 def get_compiler(self, *, elide_empty: bool = True) -> SQLCompiler:
347 """Return a compiler instance for this query."""
348 return db_connection.ops.get_compiler_for(self, elide_empty)
349
    def clone(self) -> Self:
        """
        Return a copy of the current Query. A lightweight alternative to
        deepcopy().

        Shallow-copies __dict__, then re-copies only the mutable containers
        that must not be shared between the original and the clone.
        """
        obj = Empty()
        obj.__class__ = self.__class__
        obj = cast(Self, obj)  # Type checker doesn't understand __class__ reassignment
        # Copy references to everything.
        obj.__dict__ = self.__dict__.copy()
        # Clone attributes that can't use shallow copy.
        obj.alias_refcount = self.alias_refcount.copy()
        obj.alias_map = self.alias_map.copy()
        obj.external_aliases = self.external_aliases.copy()
        obj.table_map = self.table_map.copy()
        obj.where = self.where.clone()
        obj.annotations = self.annotations.copy()
        if self.annotation_select_mask is not None:
            obj.annotation_select_mask = self.annotation_select_mask.copy()
        # _annotation_select_cache cannot be copied, as doing so breaks the
        # (necessary) state in which both annotations and
        # _annotation_select_cache point to the same underlying objects.
        # It will get re-populated in the cloned queryset the next time it's
        # used.
        obj._annotation_select_cache = None
        obj.extra = self.extra.copy()
        if self.extra_select_mask is not None:
            obj.extra_select_mask = self.extra_select_mask.copy()
        if self._extra_select_cache is not None:
            obj._extra_select_cache = self._extra_select_cache.copy()
        if self.select_related is not False:
            # Use deepcopy because select_related stores fields in nested
            # dicts.
            obj.select_related = copy.deepcopy(obj.select_related)
        if "subq_aliases" in self.__dict__:
            obj.subq_aliases = self.subq_aliases.copy()
        obj.used_aliases = self.used_aliases.copy()
        obj._filtered_relations = self._filtered_relations.copy()
        # Clear the cached_property, if it exists.
        obj.__dict__.pop("base_table", None)
        return obj
391
    @overload
    def chain(self, klass: None = None) -> Self: ...

    @overload
    def chain(self, klass: type[QueryType]) -> QueryType: ...

    def chain(self, klass: type[Query] | None = None) -> Query:
        """
        Return a copy of the current Query that's ready for another operation.
        The klass argument changes the type of the Query, e.g. UpdateQuery.
        """
        obj = self.clone()
        if klass and obj.__class__ != klass:
            obj.__class__ = klass
        if not obj.filter_is_sticky:
            obj.used_aliases = set()
        # Stickiness applies to a single subsequent filter only.
        obj.filter_is_sticky = False
        if hasattr(obj, "_setup_query"):
            # Subclasses (e.g. delete/update queries) may need extra setup.
            obj._setup_query()  # type: ignore[call-non-callable]
        return obj
412
413 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
414 clone = self.clone()
415 clone.change_aliases(change_map)
416 return clone
417
418 def _get_col(self, target: Any, field: Field, alias: str | None) -> Col:
419 if not self.alias_cols:
420 alias = None
421 return target.get_col(alias, field)
422
    def get_aggregation(self, aggregate_exprs: dict[str, Any]) -> dict[str, Any]:
        """
        Return the dictionary with the values of the existing aggregations.

        aggregate_exprs maps result alias -> aggregate expression. The query
        is executed immediately and a dict of alias -> value is returned.
        When grouping/slicing/distinct would interfere, the current query is
        pushed down into a subquery and aggregated from an outer query.
        """
        if not aggregate_exprs:
            return {}
        aggregates = {}
        for alias, aggregate_expr in aggregate_exprs.items():
            self.check_alias(alias)
            aggregate = aggregate_expr.resolve_expression(
                self, allow_joins=True, reuse=None, summarize=True
            )
            if not aggregate.contains_aggregate:
                raise TypeError(f"{alias} is not an aggregate expression")
            aggregates[alias] = aggregate
        # Existing usage of aggregation can be determined by the presence of
        # selected aggregates but also by filters against aliased aggregates.
        _, having, qualify = self.where.split_having_qualify()
        has_existing_aggregation = (
            any(
                getattr(annotation, "contains_aggregate", True)
                for annotation in self.annotations.values()
            )
            or having
        )
        # Decide if we need to use a subquery.
        #
        # Existing aggregations would cause incorrect results as
        # get_aggregation() must produce just one result and thus must not use
        # GROUP BY.
        #
        # If the query has limit or distinct, or uses set operations, then
        # those operations must be done in a subquery so that the query
        # aggregates on the limit and/or distinct results instead of applying
        # the distinct and limit after the aggregation.
        if (
            isinstance(self.group_by, tuple)
            or self.is_sliced
            or has_existing_aggregation
            or qualify
            or self.distinct
        ):
            from plain.models.sql.subqueries import AggregateQuery

            inner_query = self.clone()
            inner_query.subquery = True
            outer_query = AggregateQuery(self.model, inner_query)
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.set_annotation_mask(self.annotation_select)
            # Queries with distinct_fields need ordering and when a limit is
            # applied we must take the slice from the ordered query. Otherwise
            # no need for ordering.
            inner_query.clear_ordering(force=False)
            if not inner_query.distinct:
                # If the inner query uses default select and it has some
                # aggregate annotations, then we must make sure the inner
                # query is grouped by the main model's primary key. However,
                # clearing the select clause can alter results if distinct is
                # used.
                if inner_query.default_cols and has_existing_aggregation:
                    assert self.model is not None, "Aggregation requires a model"
                    inner_query.group_by = (
                        self.model._model_meta.get_forward_field("id").get_col(
                            inner_query.get_initial_alias()
                        ),
                    )
                inner_query.default_cols = False
                if not qualify:
                    # Mask existing annotations that are not referenced by
                    # aggregates to be pushed to the outer query unless
                    # filtering against window functions is involved as it
                    # requires complex realising.
                    annotation_mask = set()
                    for aggregate in aggregates.values():
                        annotation_mask |= aggregate.get_refs()
                    inner_query.set_annotation_mask(annotation_mask)

            # Add aggregates to the outer AggregateQuery. This requires making
            # sure all columns referenced by the aggregates are selected in the
            # inner query. It is achieved by retrieving all column references
            # by the aggregates, explicitly selecting them in the inner query,
            # and making sure the aggregates are repointed to them.
            col_refs = {}
            for alias, aggregate in aggregates.items():
                replacements = {}
                for col in self._gen_cols([aggregate], resolve_refs=False):
                    if not (col_ref := col_refs.get(col)):
                        index = len(col_refs) + 1
                        col_alias = f"__col{index}"
                        col_ref = Ref(col_alias, col)
                        col_refs[col] = col_ref
                        inner_query.annotations[col_alias] = col
                        inner_query.append_annotation_mask([col_alias])
                    replacements[col] = col_ref
                outer_query.annotations[alias] = aggregate.replace_expressions(
                    replacements
                )
            if (
                inner_query.select == ()
                and not inner_query.default_cols
                and not inner_query.annotation_select_mask
            ):
                # In case of Model.objects[0:3].count(), there would be no
                # field selected in the inner query, yet we must use a subquery.
                # So, make sure at least one field is selected.
                assert self.model is not None, "Count with slicing requires a model"
                inner_query.select = (
                    self.model._model_meta.get_forward_field("id").get_col(
                        inner_query.get_initial_alias()
                    ),
                )
        else:
            outer_query = self
            self.select = ()
            self.default_cols = False
            self.extra = {}
            if self.annotations:
                # Inline reference to existing annotations and mask them as
                # they are unnecessary given only the summarized aggregations
                # are requested.
                replacements = {
                    Ref(alias, annotation): annotation
                    for alias, annotation in self.annotations.items()
                }
                self.annotations = {
                    alias: aggregate.replace_expressions(replacements)
                    for alias, aggregate in aggregates.items()
                }
            else:
                self.annotations = aggregates
            self.set_annotation_mask(aggregates)

        empty_set_result = [
            expression.empty_result_set_value
            for expression in outer_query.annotation_select.values()
        ]
        elide_empty = not any(result is NotImplemented for result in empty_set_result)
        outer_query.clear_ordering(force=True)
        outer_query.clear_limits()
        outer_query.select_for_update = False
        outer_query.select_related = False
        compiler = outer_query.get_compiler(elide_empty=elide_empty)
        result = compiler.execute_sql(SINGLE)
        if result is None:
            result = empty_set_result
        else:
            converters = compiler.get_converters(outer_query.annotation_select.values())
            result = next(compiler.apply_converters((result,), converters))

        return dict(zip(outer_query.annotation_select, result))
574
575 def get_count(self) -> int:
576 """
577 Perform a COUNT() query using the current filter constraints.
578 """
579 obj = self.clone()
580 return obj.get_aggregation({"__count": Count("*")})["__count"]
581
582 def has_filters(self) -> bool:
583 return bool(self.where)
584
    def exists(self, limit: bool = True) -> Self:
        """
        Return a clone reduced to an efficient "does any row match" probe.

        The clone selects a constant instead of real columns and, when
        ``limit`` is True, is capped at a single row.
        """
        q = self.clone()
        if not (q.distinct and q.is_sliced):
            if q.group_by is True:
                assert self.model is not None, "GROUP BY requires a model"
                q.add_fields(
                    (f.attname for f in self.model._model_meta.concrete_fields), False
                )
                # Disable GROUP BY aliases to avoid orphaning references to the
                # SELECT clause which is about to be cleared.
                q.set_group_by(allow_aliases=False)
            q.clear_select_clause()
        q.clear_ordering(force=True)
        if limit:
            q.set_limits(high=1)
        q.add_annotation(Value(1), "a")
        return q
602
603 def has_results(self) -> bool:
604 q = self.exists()
605 compiler = q.get_compiler()
606 return compiler.has_results()
607
608 def explain(self, format: str | None = None, **options: Any) -> str:
609 q = self.clone()
610 for option_name in options:
611 if (
612 not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name)
613 or "--" in option_name
614 ):
615 raise ValueError(f"Invalid option name: {option_name!r}.")
616 q.explain_info = ExplainInfo(format, options)
617 compiler = q.get_compiler()
618 return "\n".join(compiler.explain_query())
619
    def combine(self, rhs: Query, connector: str) -> None:
        """
        Merge the 'rhs' query into the current one (with any 'rhs' effects
        being applied *after* (that is, "to the right of") anything in the
        current query. 'rhs' is not modified during a call to this function.

        The 'connector' parameter describes how to connect filters from the
        'rhs' query (AND or OR).
        """
        if self.model != rhs.model:
            raise TypeError("Cannot combine queries on two different base models.")
        if self.is_sliced:
            raise TypeError("Cannot combine queries once a slice has been taken.")
        if self.distinct != rhs.distinct:
            raise TypeError("Cannot combine a unique query with a non-unique query.")
        if self.distinct_fields != rhs.distinct_fields:
            raise TypeError("Cannot combine queries with different distinct fields.")

        # If lhs and rhs shares the same alias prefix, it is possible to have
        # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
        # as T4 -> T6 while combining two querysets. To prevent this, change an
        # alias prefix of the rhs and update current aliases accordingly,
        # except if the alias is the base table since it must be present in the
        # query on both sides.
        initial_alias = self.get_initial_alias()
        assert initial_alias is not None
        rhs.bump_prefix(self, exclude={initial_alias})

        # Work out how to relabel the rhs aliases, if necessary.
        change_map = {}
        conjunction = connector == AND

        # Determine which existing joins can be reused. When combining the
        # query with AND we must recreate all joins for m2m filters. When
        # combining with OR we can reuse joins. The reason is that in AND
        # case a single row can't fulfill a condition like:
        #     revrel__col=1 & revrel__col=2
        # But, there might be two different related rows matching this
        # condition. In OR case a single True is enough, so single row is
        # enough, too.
        #
        # Note that we will be creating duplicate joins for non-m2m joins in
        # the AND case. The results will be correct but this creates too many
        # joins. This is something that could be fixed later on.
        reuse = set() if conjunction else set(self.alias_map)
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == INNER
        )
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        rhs_tables = list(rhs.alias_map)[1:]
        for alias in rhs_tables:
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
            # combined query must have two joins, too.
            reuse.discard(new_alias)
            if alias != new_alias:
                change_map[alias] = new_alias
            if not rhs.alias_refcount[alias]:
                # The alias was unused in the rhs query. Unref it so that it
                # will be unused in the new query, too. We have to add and
                # unref the alias so that join promotion has information of
                # the join type for the unused alias.
                self.unref_alias(new_alias)
        joinpromoter.add_votes(rhs_votes)
        joinpromoter.update_join_types(self)

        # Combine subqueries aliases to ensure aliases relabelling properly
        # handle subqueries when combining where and select clauses.
        self.subq_aliases |= rhs.subq_aliases

        # Now relabel a copy of the rhs where-clause and add it to the current
        # one.
        w = rhs.where.clone()
        w.relabel_aliases(change_map)
        self.where.add(w, connector)

        # Selection columns and extra extensions are those provided by 'rhs'.
        if rhs.select:
            self.set_select([col.relabeled_clone(change_map) for col in rhs.select])
        else:
            self.select = ()

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra and rhs.extra:
                raise ValueError(
                    "When merging querysets using 'or', you cannot have "
                    "extra(select=...) on both sides."
                )
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables

        # Ordering uses the 'rhs' ordering, unless it has none, in which case
        # the current ordering is used.
        self.order_by = rhs.order_by or self.order_by
        self.extra_order_by = rhs.extra_order_by or self.extra_order_by
735
    def _get_defer_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """Build the select mask for defer(): load everything except masked
        fields, recursing into related models for nested defer entries."""
        from plain.models.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        # The primary key is always loaded.
        select_mask[meta.get_forward_field("id")] = {}
        # All concrete fields that are not part of the defer mask must be
        # loaded. If a relational field is encountered it gets added to the
        # mask for it be considered if `select_related` and the cycle continues
        # by recursively calling this function.
        for field in meta.concrete_fields:
            field_mask = mask.pop(field.name, None)
            if field_mask is None:
                select_mask.setdefault(field, {})
            elif field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                field_select_mask = select_mask.setdefault(field, {})
                related_model = field.remote_field.model
                self._get_defer_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        # Remaining defer entries must be references to reverse relationships.
        # The following code is expected to raise FieldError if it encounters
        # a malformed defer entry.
        for field_name, field_mask in mask.items():
            if filtered_relation := self._filtered_relations.get(field_name):
                relation = meta.get_reverse_relation(filtered_relation.relation_name)
                field_select_mask = select_mask.setdefault((field_name, relation), {})
                field = relation.field
            else:
                field = meta.get_reverse_relation(field_name).field
                field_select_mask = select_mask.setdefault(field, {})
            related_model = field.model
            self._get_defer_select_mask(
                related_model._model_meta, field_mask, field_select_mask
            )
        return select_mask
779
    def _get_only_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """Build the select mask for only(): load just the masked fields
        (plus the primary key), recursing into related models."""
        from plain.models.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        # The primary key is always loaded.
        select_mask[meta.get_forward_field("id")] = {}
        # Only include fields mentioned in the mask.
        for field_name, field_mask in mask.items():
            field = meta.get_field(field_name)
            field_select_mask = select_mask.setdefault(field, {})
            if field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                related_model = field.remote_field.model
                self._get_only_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        return select_mask
803
804 def get_select_mask(self) -> dict[Any, Any]:
805 """
806 Convert the self.deferred_loading data structure to an alternate data
807 structure, describing the field that *will* be loaded. This is used to
808 compute the columns to select from the database and also by the
809 QuerySet class to work out which fields are being initialized on each
810 model. Models that have all their fields included aren't mentioned in
811 the result, only those that have field restrictions in place.
812 """
813 field_names, defer = self.deferred_loading
814 if not field_names:
815 return {}
816 mask = {}
817 for field_name in field_names:
818 part_mask = mask
819 for part in field_name.split(LOOKUP_SEP):
820 part_mask = part_mask.setdefault(part, {})
821 assert self.model is not None, "Deferred/only field loading requires a model"
822 meta = self.model._model_meta
823 if defer:
824 return self._get_defer_select_mask(meta, mask)
825 return self._get_only_select_mask(meta, mask)
826
827 def table_alias(
828 self, table_name: str, create: bool = False, filtered_relation: Any = None
829 ) -> tuple[str, bool]:
830 """
831 Return a table alias for the given table_name and whether this is a
832 new alias or not.
833
834 If 'create' is true, a new alias is always created. Otherwise, the
835 most recently created alias for the table (if one exists) is reused.
836 """
837 alias_list = self.table_map.get(table_name)
838 if not create and alias_list:
839 alias = alias_list[0]
840 self.alias_refcount[alias] += 1
841 return alias, False
842
843 # Create a new alias for this table.
844 if alias_list:
845 alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1) # noqa: UP031
846 alias_list.append(alias)
847 else:
848 # The first occurrence of a table uses the table name directly.
849 alias = (
850 filtered_relation.alias if filtered_relation is not None else table_name
851 )
852 self.table_map[table_name] = [alias]
853 self.alias_refcount[alias] = 1
854 return alias, True
855
856 def ref_alias(self, alias: str) -> None:
857 """Increases the reference count for this alias."""
858 self.alias_refcount[alias] += 1
859
860 def unref_alias(self, alias: str, amount: int = 1) -> None:
861 """Decreases the reference count for this alias."""
862 self.alias_refcount[alias] -= amount
863
    def promote_joins(self, aliases: set[str] | list[str]) -> None:
        """
        Promote recursively the join type of given aliases and its children to
        an outer join. A join is only promoted if it is nullable or the parent
        join is an outer join.

        The children promotion is done to avoid join chains that contain a LOUTER
        b INNER c. So, if we have currently a INNER b INNER c and a->b is promoted,
        then we must also promote b->c automatically, or otherwise the promotion
        of a->b doesn't actually change anything in the query results.
        """
        aliases = list(aliases)
        while aliases:
            alias = aliases.pop(0)
            if self.alias_map[alias].join_type is None:
                # This is the base table (first FROM entry) - this table
                # isn't really joined at all in the query, so we should not
                # alter its join type.
                continue
            # Only the first alias (skipped above) should have None join_type
            assert self.alias_map[alias].join_type is not None
            parent_alias = self.alias_map[alias].parent_alias
            parent_louter = (
                parent_alias and self.alias_map[parent_alias].join_type == LOUTER
            )
            already_louter = self.alias_map[alias].join_type == LOUTER
            if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
                self.alias_map[alias] = self.alias_map[alias].promote()
                # Join type of 'alias' changed, so re-examine all aliases that
                # refer to this one.
                aliases.extend(
                    join
                    for join in self.alias_map
                    if self.alias_map[join].parent_alias == alias
                    and join not in aliases
                )
900
901 def demote_joins(self, aliases: set[str] | list[str]) -> None:
902 """
903 Change join type from LOUTER to INNER for all joins in aliases.
904
905 Similarly to promote_joins(), this method must ensure no join chains
906 containing first an outer, then an inner join are generated. If we
907 are demoting b->c join in chain a LOUTER b LOUTER c then we must
908 demote a->b automatically, or otherwise the demotion of b->c doesn't
909 actually change anything in the query results. .
910 """
911 aliases = list(aliases)
912 while aliases:
913 alias = aliases.pop(0)
914 if self.alias_map[alias].join_type == LOUTER:
915 self.alias_map[alias] = self.alias_map[alias].demote()
916 parent_alias = self.alias_map[alias].parent_alias
917 if self.alias_map[parent_alias].join_type == INNER:
918 aliases.append(parent_alias)
919
920 def reset_refcounts(self, to_counts: dict[str, int]) -> None:
921 """
922 Reset reference counts for aliases so that they match the value passed
923 in `to_counts`.
924 """
925 for alias, cur_refcount in self.alias_refcount.copy().items():
926 unref_amount = cur_refcount - to_counts.get(alias, 0)
927 self.unref_alias(alias, unref_amount)
928
929 def change_aliases(self, change_map: dict[str, str]) -> None:
930 """
931 Change the aliases in change_map (which maps old-alias -> new-alias),
932 relabelling any references to them in select columns and the where
933 clause.
934 """
935 # If keys and values of change_map were to intersect, an alias might be
936 # updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
937 # on their order in change_map.
938 assert set(change_map).isdisjoint(change_map.values())
939
940 # 1. Update references in "select" (normal columns plus aliases),
941 # "group by" and "where".
942 self.where.relabel_aliases(change_map)
943 if isinstance(self.group_by, tuple):
944 self.group_by = tuple(
945 [col.relabeled_clone(change_map) for col in self.group_by]
946 )
947 self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
948 self.annotations = self.annotations and {
949 key: col.relabeled_clone(change_map)
950 for key, col in self.annotations.items()
951 }
952
953 # 2. Rename the alias in the internal table/alias datastructures.
954 for old_alias, new_alias in change_map.items():
955 if old_alias not in self.alias_map:
956 continue
957 alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
958 self.alias_map[new_alias] = alias_data
959 self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
960 del self.alias_refcount[old_alias]
961 del self.alias_map[old_alias]
962
963 table_aliases = self.table_map[alias_data.table_name]
964 for pos, alias in enumerate(table_aliases):
965 if alias == old_alias:
966 table_aliases[pos] = new_alias
967 break
968 self.external_aliases = {
969 # Table is aliased or it's being changed and thus is aliased.
970 change_map.get(alias, alias): (aliased or alias in change_map)
971 for alias, aliased in self.external_aliases.items()
972 }
973
    def bump_prefix(
        self, other_query: Query, exclude: set[str] | dict[str, str] | None = None
    ) -> None:
        """
        Change the alias prefix to the next letter in the alphabet in a way
        that the other query's aliases and this query's aliases will not
        conflict. Even tables that previously had no alias will get an alias
        after this call. To prevent changing aliases use the exclude parameter.
        """

        def prefix_gen() -> TypingIterator[str]:
            """
            Generate a sequence of characters in alphabetical order:
            -> 'A', 'B', 'C', ...

            When the alphabet is finished, the sequence will continue with the
            Cartesian product:
            -> 'AA', 'AB', 'AC', ...
            """
            alphabet = ascii_uppercase
            # Start just past the current prefix so the first candidates are
            # the single letters that follow it.
            prefix = chr(ord(self.alias_prefix) + 1)
            yield prefix
            for n in count(1):
                # The first pass (n=1) resumes from `prefix`; once `prefix`
                # is cleared, later passes cover the full alphabet for each
                # product length.
                seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
                for s in product(seq, repeat=n):
                    yield "".join(s)
                prefix = None

        if self.alias_prefix != other_query.alias_prefix:
            # No clashes between self and outer query should be possible.
            return

        # Explicitly avoid infinite loop. The constant divider is based on how
        # much depth recursive subquery references add to the stack. This value
        # might need to be adjusted when adding or removing function calls from
        # the code path in charge of performing these operations.
        local_recursion_limit = sys.getrecursionlimit() // 16
        for pos, prefix in enumerate(prefix_gen()):
            if prefix not in self.subq_aliases:
                # Found a prefix no subquery is using yet.
                self.alias_prefix = prefix
                break
            if pos > local_recursion_limit:
                raise RecursionError(
                    "Maximum recursion depth exceeded: too many subqueries."
                )
        # Both queries must know about every prefix in use so that later
        # bumps keep avoiding them.
        self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
        other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
        if exclude is None:
            exclude = {}
        # Rename every non-excluded alias to "<new prefix><position>".
        self.change_aliases(
            {
                alias: "%s%d" % (self.alias_prefix, pos)  # noqa: UP031
                for pos, alias in enumerate(self.alias_map)
                if alias not in exclude
            }
        )
1030
1031 def get_initial_alias(self) -> str | None:
1032 """
1033 Return the first alias for this query, after increasing its reference
1034 count.
1035 """
1036 if self.alias_map:
1037 alias = self.base_table
1038 self.ref_alias(alias)
1039 elif self.model:
1040 alias = self.join(
1041 self.base_table_class(self.model.model_options.db_table, None)
1042 )
1043 else:
1044 alias = None
1045 return alias
1046
1047 def count_active_tables(self) -> int:
1048 """
1049 Return the number of tables in this query with a non-zero reference
1050 count. After execution, the reference counts are zeroed, so tables
1051 added in compiler will not be seen by this method.
1052 """
1053 return len([1 for count in self.alias_refcount.values() if count])
1054
1055 def join(
1056 self,
1057 join: BaseTable | Join,
1058 reuse: set[str] | None = None,
1059 reuse_with_filtered_relation: bool = False,
1060 ) -> str:
1061 """
1062 Return an alias for the 'join', either reusing an existing alias for
1063 that join or creating a new one. 'join' is either a base_table_class or
1064 join_class.
1065
1066 The 'reuse' parameter can be either None which means all joins are
1067 reusable, or it can be a set containing the aliases that can be reused.
1068
1069 The 'reuse_with_filtered_relation' parameter is used when computing
1070 FilteredRelation instances.
1071
1072 A join is always created as LOUTER if the lhs alias is LOUTER to make
1073 sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
1074 joins are created as LOUTER if the join is nullable.
1075 """
1076 if reuse_with_filtered_relation and reuse:
1077 reuse_aliases = [
1078 a for a, j in self.alias_map.items() if a in reuse and j.equals(join)
1079 ]
1080 else:
1081 reuse_aliases = [
1082 a
1083 for a, j in self.alias_map.items()
1084 if (reuse is None or a in reuse) and j == join
1085 ]
1086 if reuse_aliases:
1087 if join.table_alias in reuse_aliases:
1088 reuse_alias = join.table_alias
1089 else:
1090 # Reuse the most recent alias of the joined table
1091 # (a many-to-many relation may be joined multiple times).
1092 reuse_alias = reuse_aliases[-1]
1093 self.ref_alias(reuse_alias)
1094 return reuse_alias
1095
1096 # No reuse is possible, so we need a new alias.
1097 alias, _ = self.table_alias(
1098 join.table_name, create=True, filtered_relation=join.filtered_relation
1099 )
1100 if isinstance(join, Join):
1101 if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
1102 join_type = LOUTER
1103 else:
1104 join_type = INNER
1105 join.join_type = join_type
1106 join.table_alias = alias
1107 self.alias_map[alias] = join
1108 return alias
1109
1110 def check_alias(self, alias: str) -> None:
1111 if FORBIDDEN_ALIAS_PATTERN.search(alias):
1112 raise ValueError(
1113 "Column aliases cannot contain whitespace characters, quotation marks, "
1114 "semicolons, or SQL comments."
1115 )
1116
1117 def add_annotation(
1118 self, annotation: BaseExpression, alias: str, select: bool = True
1119 ) -> None:
1120 """Add a single annotation expression to the Query."""
1121 self.check_alias(alias)
1122 annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
1123 if select:
1124 self.append_annotation_mask([alias])
1125 else:
1126 self.set_annotation_mask(set(self.annotation_select).difference({alias}))
1127 self.annotations[alias] = annotation
1128
1129 def resolve_expression(
1130 self,
1131 query: Any = None,
1132 allow_joins: bool = True,
1133 reuse: Any = None,
1134 summarize: bool = False,
1135 for_save: bool = False,
1136 ) -> Self:
1137 clone = self.clone()
1138 # Subqueries need to use a different set of aliases than the outer query.
1139 clone.bump_prefix(query)
1140 clone.subquery = True
1141 clone.where.resolve_expression(query, allow_joins, reuse, summarize, for_save)
1142 for key, value in clone.annotations.items():
1143 resolved = value.resolve_expression(
1144 query, allow_joins, reuse, summarize, for_save
1145 )
1146 if hasattr(resolved, "external_aliases"):
1147 resolved.external_aliases.update(clone.external_aliases)
1148 clone.annotations[key] = resolved
1149 # Outer query's aliases are considered external.
1150 for alias, table in query.alias_map.items():
1151 clone.external_aliases[alias] = (
1152 isinstance(table, Join)
1153 and table.join_field.related_model.model_options.db_table != alias
1154 ) or (
1155 isinstance(table, BaseTable) and table.table_name != table.table_alias
1156 )
1157 return clone
1158
1159 def get_external_cols(self) -> list[Col]:
1160 exprs = chain(self.annotations.values(), self.where.children)
1161 return [
1162 col
1163 for col in self._gen_cols(exprs, include_external=True)
1164 if col.alias in self.external_aliases
1165 ]
1166
1167 def get_group_by_cols( # type: ignore[override]
1168 self, wrapper: BaseExpression | None = None
1169 ) -> list[BaseExpression]:
1170 # If wrapper is referenced by an alias for an explicit GROUP BY through
1171 # values() a reference to this expression and not the self must be
1172 # returned to ensure external column references are not grouped against
1173 # as well.
1174 external_cols = self.get_external_cols()
1175 if any(col.possibly_multivalued for col in external_cols):
1176 return [wrapper or self]
1177 # Cast needed because list is invariant: list[Col] is not list[BaseExpression]
1178 return cast(list[BaseExpression], external_cols)
1179
1180 def as_sql(
1181 self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1182 ) -> SqlWithParams:
1183 # Some backends (e.g. Oracle) raise an error when a subquery contains
1184 # unnecessary ORDER BY clause.
1185 if (
1186 self.subquery
1187 and not db_connection.features.ignores_unnecessary_order_by_in_subqueries
1188 ):
1189 self.clear_ordering(force=False)
1190 sql, params = self.get_compiler().as_sql()
1191 if self.subquery:
1192 sql = f"({sql})"
1193 return sql, params
1194
1195 def resolve_lookup_value(
1196 self, value: Any, can_reuse: set[str] | None, allow_joins: bool
1197 ) -> Any:
1198 if isinstance(value, ResolvableExpression):
1199 value = value.resolve_expression(
1200 self,
1201 reuse=can_reuse,
1202 allow_joins=allow_joins,
1203 )
1204 elif isinstance(value, list | tuple):
1205 # The items of the iterable may be expressions and therefore need
1206 # to be resolved independently.
1207 values = (
1208 self.resolve_lookup_value(sub_value, can_reuse, allow_joins)
1209 for sub_value in value
1210 )
1211 type_ = type(value)
1212 if hasattr(type_, "_make"): # namedtuple
1213 return type_(*values)
1214 return type_(values)
1215 return value
1216
1217 def solve_lookup_type(
1218 self, lookup: str, summarize: bool = False
1219 ) -> tuple[
1220 list[str] | tuple[str, ...], tuple[str, ...], BaseExpression | Literal[False]
1221 ]:
1222 """
1223 Solve the lookup type from the lookup (e.g.: 'foobar__id__icontains').
1224 """
1225 lookup_splitted = lookup.split(LOOKUP_SEP)
1226 if self.annotations:
1227 annotation, expression_lookups = refs_expression(
1228 lookup_splitted, self.annotations
1229 )
1230 if annotation:
1231 expression = self.annotations[annotation]
1232 if summarize:
1233 expression = Ref(annotation, expression)
1234 return expression_lookups, (), expression
1235 assert self.model is not None, "Field lookups require a model"
1236 meta = self.model._model_meta
1237 _, field, _, lookup_parts = self.names_to_path(lookup_splitted, meta)
1238 field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
1239 if len(lookup_parts) > 1 and not field_parts:
1240 raise FieldError(
1241 f'Invalid lookup "{lookup}" for model {meta.model.__name__}".'
1242 )
1243 return lookup_parts, tuple(field_parts), False
1244
1245 def check_query_object_type(
1246 self, value: Any, meta: Meta, field: Field | ForeignObjectRel
1247 ) -> None:
1248 """
1249 Check whether the object passed while querying is of the correct type.
1250 If not, raise a ValueError specifying the wrong object.
1251 """
1252 from plain.models import Model
1253
1254 if isinstance(value, Model):
1255 if not check_rel_lookup_compatibility(value._model_meta.model, meta, field):
1256 raise ValueError(
1257 f'Cannot query "{value}": Must be "{meta.model.model_options.object_name}" instance.'
1258 )
1259
1260 def check_related_objects(
1261 self, field: RelatedField | ForeignObjectRel, value: Any, meta: Meta
1262 ) -> None:
1263 """Check the type of object passed to query relations."""
1264 from plain.models import Model
1265
1266 # Check that the field and the queryset use the same model in a
1267 # query like .filter(author=Author.query.all()). For example, the
1268 # meta would be Author's (from the author field) and value.model
1269 # would be Author.query.all() queryset's .model (Author also).
1270 # The field is the related field on the lhs side.
1271 if (
1272 isinstance(value, Query)
1273 and not value.has_select_fields
1274 and not check_rel_lookup_compatibility(value.model, meta, field)
1275 ):
1276 raise ValueError(
1277 f'Cannot use QuerySet for "{value.model.model_options.object_name}": Use a QuerySet for "{meta.model.model_options.object_name}".'
1278 )
1279 elif isinstance(value, Model):
1280 self.check_query_object_type(value, meta, field)
1281 elif isinstance(value, Iterable):
1282 for v in value:
1283 self.check_query_object_type(v, meta, field)
1284
1285 def check_filterable(self, expression: Any) -> None:
1286 """Raise an error if expression cannot be used in a WHERE clause."""
1287 if isinstance(expression, ResolvableExpression) and not getattr(
1288 expression, "filterable", True
1289 ):
1290 raise NotSupportedError(
1291 expression.__class__.__name__ + " is disallowed in the filter clause."
1292 )
1293 if hasattr(expression, "get_source_expressions"):
1294 for expr in expression.get_source_expressions():
1295 self.check_filterable(expr)
1296
1297 def build_lookup(
1298 self, lookups: list[str], lhs: BaseExpression | MultiColSource, rhs: Any
1299 ) -> Lookup | None:
1300 """
1301 Try to extract transforms and lookup from given lhs.
1302
1303 The lhs value is something that works like SQLExpression.
1304 The rhs value is what the lookup is going to compare against.
1305 The lookups is a list of names to extract using get_lookup()
1306 and get_transform().
1307 """
1308 # __exact is the default lookup if one isn't given.
1309 *transforms, lookup_name = lookups or ["exact"]
1310 if transforms:
1311 if isinstance(lhs, MultiColSource):
1312 raise FieldError(
1313 "Transforms are not supported on multi-column relations."
1314 )
1315 # At this point, lhs must be BaseExpression
1316 for name in transforms:
1317 lhs = self.try_transform(lhs, name)
1318 # First try get_lookup() so that the lookup takes precedence if the lhs
1319 # supports both transform and lookup for the name.
1320 lookup_class = lhs.get_lookup(lookup_name)
1321 if not lookup_class:
1322 # A lookup wasn't found. Try to interpret the name as a transform
1323 # and do an Exact lookup against it.
1324 if isinstance(lhs, MultiColSource):
1325 raise FieldError(
1326 "Transforms are not supported on multi-column relations."
1327 )
1328 lhs = self.try_transform(lhs, lookup_name)
1329 lookup_name = "exact"
1330 lookup_class = lhs.get_lookup(lookup_name)
1331 if not lookup_class:
1332 return
1333
1334 lookup = lookup_class(lhs, rhs)
1335 # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
1336 # uses of None as a query value unless the lookup supports it.
1337 if lookup.rhs is None and not lookup.can_use_none_as_rhs:
1338 if lookup_name not in ("exact", "iexact"):
1339 raise ValueError("Cannot use None as a query value")
1340 isnull_lookup = lhs.get_lookup("isnull")
1341 assert isnull_lookup is not None
1342 return isnull_lookup(lhs, True)
1343
1344 return lookup
1345
1346 def try_transform(self, lhs: BaseExpression, name: str) -> BaseExpression:
1347 """
1348 Helper method for build_lookup(). Try to fetch and initialize
1349 a transform for name parameter from lhs.
1350 """
1351 transform_class = lhs.get_transform(name)
1352 if transform_class:
1353 return transform_class(lhs)
1354 else:
1355 output_field = lhs.output_field.__class__
1356 suggested_lookups = difflib.get_close_matches(
1357 name, output_field.get_lookups()
1358 )
1359 if suggested_lookups:
1360 suggestion = ", perhaps you meant {}?".format(
1361 " or ".join(suggested_lookups)
1362 )
1363 else:
1364 suggestion = "."
1365 raise FieldError(
1366 f"Unsupported lookup '{name}' for {output_field.__name__} or join on the field not "
1367 f"permitted{suggestion}"
1368 )
1369
1370 def build_filter(
1371 self,
1372 filter_expr: tuple[str, Any] | Q | BaseExpression,
1373 branch_negated: bool = False,
1374 current_negated: bool = False,
1375 can_reuse: set[str] | None = None,
1376 allow_joins: bool = True,
1377 split_subq: bool = True,
1378 reuse_with_filtered_relation: bool = False,
1379 check_filterable: bool = True,
1380 summarize: bool = False,
1381 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1382 from plain.models.fields.related import RelatedField
1383
1384 """
1385 Build a WhereNode for a single filter clause but don't add it
1386 to this Query. Query.add_q() will then add this filter to the where
1387 Node.
1388
1389 The 'branch_negated' tells us if the current branch contains any
1390 negations. This will be used to determine if subqueries are needed.
1391
1392 The 'current_negated' is used to determine if the current filter is
1393 negated or not and this will be used to determine if IS NULL filtering
1394 is needed.
1395
1396 The difference between current_negated and branch_negated is that
1397 branch_negated is set on first negation, but current_negated is
1398 flipped for each negation.
1399
1400 Note that add_filter will not do any negating itself, that is done
1401 upper in the code by add_q().
1402
1403 The 'can_reuse' is a set of reusable joins for multijoins.
1404
1405 If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
1406 will be reused.
1407
1408 The method will create a filter clause that can be added to the current
1409 query. However, if the filter isn't added to the query then the caller
1410 is responsible for unreffing the joins used.
1411 """
1412 if isinstance(filter_expr, dict):
1413 raise FieldError("Cannot parse keyword query as dict")
1414 if isinstance(filter_expr, Q):
1415 return self._add_q(
1416 filter_expr,
1417 branch_negated=branch_negated,
1418 current_negated=current_negated,
1419 used_aliases=can_reuse,
1420 allow_joins=allow_joins,
1421 split_subq=split_subq,
1422 check_filterable=check_filterable,
1423 summarize=summarize,
1424 )
1425 if isinstance(filter_expr, ResolvableExpression):
1426 if not getattr(filter_expr, "conditional", False):
1427 raise TypeError("Cannot filter against a non-conditional expression.")
1428 condition = filter_expr.resolve_expression( # type: ignore[call-non-callable]
1429 self, allow_joins=allow_joins, summarize=summarize
1430 )
1431 if not isinstance(condition, Lookup):
1432 condition = self.build_lookup(["exact"], condition, True)
1433 return WhereNode([condition], connector=AND), set()
1434 if isinstance(filter_expr, BaseExpression):
1435 raise TypeError(f"Unexpected BaseExpression type: {type(filter_expr)}")
1436 arg, value = filter_expr
1437 if not arg:
1438 raise FieldError(f"Cannot parse keyword query {arg!r}")
1439 lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
1440
1441 if check_filterable:
1442 self.check_filterable(reffed_expression)
1443
1444 if not allow_joins and len(parts) > 1:
1445 raise FieldError("Joined field references are not permitted in this query")
1446
1447 pre_joins = self.alias_refcount.copy()
1448 value = self.resolve_lookup_value(value, can_reuse, allow_joins)
1449 used_joins = {
1450 k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
1451 }
1452
1453 if check_filterable:
1454 self.check_filterable(value)
1455
1456 if reffed_expression:
1457 condition = self.build_lookup(list(lookups), reffed_expression, value)
1458 return WhereNode([condition], connector=AND), set()
1459
1460 assert self.model is not None, "Building filters requires a model"
1461 meta = self.model._model_meta
1462 alias = self.get_initial_alias()
1463 assert alias is not None
1464 allow_many = not branch_negated or not split_subq
1465
1466 try:
1467 join_info = self.setup_joins(
1468 list(parts),
1469 meta,
1470 alias,
1471 can_reuse=can_reuse,
1472 allow_many=allow_many,
1473 reuse_with_filtered_relation=reuse_with_filtered_relation,
1474 )
1475
1476 # Prevent iterator from being consumed by check_related_objects()
1477 if isinstance(value, Iterator):
1478 value = list(value)
1479 from plain.models.fields.related import RelatedField
1480 from plain.models.fields.reverse_related import ForeignObjectRel
1481
1482 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1483 self.check_related_objects(join_info.final_field, value, join_info.meta)
1484
1485 # split_exclude() needs to know which joins were generated for the
1486 # lookup parts
1487 self._lookup_joins = join_info.joins
1488 except MultiJoin as e:
1489 return self.split_exclude(
1490 filter_expr,
1491 can_reuse or set(),
1492 e.names_with_path,
1493 )
1494
1495 # Update used_joins before trimming since they are reused to determine
1496 # which joins could be later promoted to INNER.
1497 used_joins.update(join_info.joins)
1498 targets, alias, join_list = self.trim_joins(
1499 join_info.targets, join_info.joins, join_info.path
1500 )
1501 if can_reuse is not None:
1502 can_reuse.update(join_list)
1503
1504 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1505 if len(targets) == 1:
1506 col = self._get_col(targets[0], join_info.final_field, alias)
1507 else:
1508 col = MultiColSource(
1509 alias, targets, join_info.targets, join_info.final_field
1510 )
1511 else:
1512 col = self._get_col(targets[0], join_info.final_field, alias)
1513
1514 condition = self.build_lookup(list(lookups), col, value)
1515 assert condition is not None
1516 lookup_type = condition.lookup_name
1517 clause = WhereNode([condition], connector=AND)
1518
1519 require_outer = (
1520 lookup_type == "isnull" and condition.rhs is True and not current_negated
1521 )
1522 if (
1523 current_negated
1524 and (lookup_type != "isnull" or condition.rhs is False)
1525 and condition.rhs is not None
1526 ):
1527 require_outer = True
1528 if lookup_type != "isnull":
1529 # The condition added here will be SQL like this:
1530 # NOT (col IS NOT NULL), where the first NOT is added in
1531 # upper layers of code. The reason for addition is that if col
1532 # is null, then col != someval will result in SQL "unknown"
1533 # which isn't the same as in Python. The Python None handling
1534 # is wanted, and it can be gotten by
1535 # (col IS NULL OR col != someval)
1536 # <=>
1537 # NOT (col IS NOT NULL AND col = someval).
1538 if (
1539 self.is_nullable(targets[0])
1540 or self.alias_map[join_list[-1]].join_type == LOUTER
1541 ):
1542 lookup_class = targets[0].get_lookup("isnull")
1543 assert lookup_class is not None
1544 col = self._get_col(targets[0], join_info.targets[0], alias)
1545 clause.add(lookup_class(col, False), AND)
1546 # If someval is a nullable column, someval IS NOT NULL is
1547 # added.
1548 if isinstance(value, Col) and self.is_nullable(value.target):
1549 lookup_class = value.target.get_lookup("isnull")
1550 assert lookup_class is not None
1551 clause.add(lookup_class(value, False), AND)
1552 return clause, used_joins if not require_outer else ()
1553
1554 def add_filter(self, filter_lhs: str, filter_rhs: Any) -> None:
1555 self.add_q(Q((filter_lhs, filter_rhs)))
1556
1557 def add_q(self, q_object: Q) -> None:
1558 """
1559 A preprocessor for the internal _add_q(). Responsible for doing final
1560 join promotion.
1561 """
1562 # For join promotion this case is doing an AND for the added q_object
1563 # and existing conditions. So, any existing inner join forces the join
1564 # type to remain inner. Existing outer joins can however be demoted.
1565 # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
1566 # rel_a doesn't produce any rows, then the whole condition must fail.
1567 # So, demotion is OK.
1568 existing_inner = {
1569 a for a in self.alias_map if self.alias_map[a].join_type == INNER
1570 }
1571 clause, _ = self._add_q(q_object, self.used_aliases)
1572 if clause:
1573 self.where.add(clause, AND)
1574 self.demote_joins(existing_inner)
1575
1576 def build_where(
1577 self, filter_expr: tuple[str, Any] | Q | BaseExpression
1578 ) -> WhereNode:
1579 return self.build_filter(filter_expr, allow_joins=False)[0]
1580
1581 def clear_where(self) -> None:
1582 self.where = WhereNode()
1583
1584 def _add_q(
1585 self,
1586 q_object: Q,
1587 used_aliases: set[str] | None,
1588 branch_negated: bool = False,
1589 current_negated: bool = False,
1590 allow_joins: bool = True,
1591 split_subq: bool = True,
1592 check_filterable: bool = True,
1593 summarize: bool = False,
1594 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1595 """Add a Q-object to the current filter."""
1596 connector = q_object.connector
1597 current_negated ^= q_object.negated
1598 branch_negated = branch_negated or q_object.negated
1599 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1600 joinpromoter = JoinPromoter(
1601 q_object.connector, len(q_object.children), current_negated
1602 )
1603 for child in q_object.children:
1604 child_clause, needed_inner = self.build_filter(
1605 child,
1606 can_reuse=used_aliases,
1607 branch_negated=branch_negated,
1608 current_negated=current_negated,
1609 allow_joins=allow_joins,
1610 split_subq=split_subq,
1611 check_filterable=check_filterable,
1612 summarize=summarize,
1613 )
1614 joinpromoter.add_votes(needed_inner)
1615 if child_clause:
1616 target_clause.add(child_clause, connector)
1617 needed_inner = joinpromoter.update_join_types(self)
1618 return target_clause, needed_inner
1619
1620 def build_filtered_relation_q(
1621 self,
1622 q_object: Q,
1623 reuse: set[str],
1624 branch_negated: bool = False,
1625 current_negated: bool = False,
1626 ) -> WhereNode:
1627 """Add a FilteredRelation object to the current filter."""
1628 connector = q_object.connector
1629 current_negated ^= q_object.negated
1630 branch_negated = branch_negated or q_object.negated
1631 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1632 for child in q_object.children:
1633 if isinstance(child, Node):
1634 child_clause = self.build_filtered_relation_q(
1635 child,
1636 reuse=reuse,
1637 branch_negated=branch_negated,
1638 current_negated=current_negated,
1639 )
1640 else:
1641 child_clause, _ = self.build_filter(
1642 child,
1643 can_reuse=reuse,
1644 branch_negated=branch_negated,
1645 current_negated=current_negated,
1646 allow_joins=True,
1647 split_subq=False,
1648 reuse_with_filtered_relation=True,
1649 )
1650 target_clause.add(child_clause, connector)
1651 return target_clause
1652
1653 def add_filtered_relation(self, filtered_relation: Any, alias: str) -> None:
1654 filtered_relation.alias = alias
1655 lookups = dict(get_children_from_q(filtered_relation.condition))
1656 relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
1657 filtered_relation.relation_name
1658 )
1659 if relation_lookup_parts:
1660 raise ValueError(
1661 "FilteredRelation's relation_name cannot contain lookups "
1662 f"(got {filtered_relation.relation_name!r})."
1663 )
1664 for lookup in chain(lookups):
1665 lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
1666 shift = 2 if not lookup_parts else 1
1667 lookup_field_path = lookup_field_parts[:-shift]
1668 for idx, lookup_field_part in enumerate(lookup_field_path):
1669 if len(relation_field_parts) > idx:
1670 if relation_field_parts[idx] != lookup_field_part:
1671 raise ValueError(
1672 "FilteredRelation's condition doesn't support "
1673 f"relations outside the {filtered_relation.relation_name!r} (got {lookup!r})."
1674 )
1675 else:
1676 raise ValueError(
1677 "FilteredRelation's condition doesn't support nested "
1678 f"relations deeper than the relation_name (got {lookup!r} for "
1679 f"{filtered_relation.relation_name!r})."
1680 )
1681 self._filtered_relations[filtered_relation.alias] = filtered_relation
1682
    def names_to_path(
        self,
        names: list[str],
        meta: Meta,
        allow_many: bool = True,
        fail_on_missing: bool = False,
    ) -> tuple[list[Any], Field | ForeignObjectRel, tuple[Field, ...], list[str]]:
        """
        Walk the list of names and turns them into PathInfo tuples. A single
        name in 'names' can generate multiple PathInfos (m2m, for example).

        'names' is the path of names to travel, 'meta' is the Meta we
        start the name resolving from, 'allow_many' is as for setup_joins().
        If fail_on_missing is set to True, then a name that can't be resolved
        will generate a FieldError.

        Return a list of PathInfo tuples. In addition return the final field
        (the last used join field) and target (which is a field guaranteed to
        contain the same value as the final field). Finally, return those names
        that weren't found (which are likely transforms and the final lookup).

        Raise MultiJoin if a multi-valued relation is traversed while
        'allow_many' is False.
        """
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            # Track (name, [PathInfos]) so callers (e.g. split_exclude) can
            # map each name back to the joins it produced.
            cur_names_with_path = (name, [])

            field = None
            filtered_relation = None
            try:
                if meta is None:
                    # A previous segment ended on a non-relational value, so
                    # there is no model to resolve further names against.
                    raise FieldDoesNotExist
                field = meta.get_field(name)
            except FieldDoesNotExist:
                if name in self.annotation_select:
                    field = self.annotation_select[name].output_field
                elif name in self._filtered_relations and pos == 0:
                    # FilteredRelation aliases are only valid as the first
                    # segment of a lookup path.
                    filtered_relation = self._filtered_relations[name]
                    if LOOKUP_SEP in filtered_relation.relation_name:
                        parts = filtered_relation.relation_name.split(LOOKUP_SEP)
                        filtered_relation_path, field, _, _ = self.names_to_path(
                            parts,
                            meta,
                            allow_many,
                            fail_on_missing,
                        )
                        # Keep all but the final hop; the last join is added
                        # below with the filtered relation applied to it.
                        path.extend(filtered_relation_path[:-1])
                    else:
                        field = meta.get_field(filtered_relation.relation_name)
            if field is None:
                # We didn't find the current field, so move position back
                # one step.
                pos -= 1
                if pos == -1 or fail_on_missing:
                    available = sorted(
                        [
                            *get_field_names_from_opts(meta),
                            *self.annotation_select,
                            *self._filtered_relations,
                        ]
                    )
                    raise FieldError(
                        "Cannot resolve keyword '{}' into field. "
                        "Choices are: {}".format(name, ", ".join(available))
                    )
                # The remaining names are handed back to the caller via
                # names[pos + 1:]; they are likely transforms/lookups.
                break

            # Lazy import to avoid circular dependency
            from plain.models.fields.related import ForeignKeyField as FK
            from plain.models.fields.related import ManyToManyField as M2M
            from plain.models.fields.reverse_related import ForeignObjectRel as FORel

            if isinstance(field, FK | M2M | FORel):
                # Relational field: expand it into one or more PathInfos.
                pathinfos: list[PathInfo]
                if filtered_relation:
                    pathinfos = field.get_path_info(filtered_relation)
                else:
                    pathinfos = field.path_infos
                if not allow_many:
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            # Record the joins up to (and including) the
                            # multi-valued hop before bailing out.
                            cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
                            names_with_path.append(cur_names_with_path)
                            raise MultiJoin(pos + 1, names_with_path)
                last = pathinfos[-1]
                path.extend(pathinfos)
                final_field = last.join_field
                # Resolve the next name against the related model's meta.
                meta = last.to_meta
                targets = last.target_fields
                cur_names_with_path[1].extend(pathinfos)
                names_with_path.append(cur_names_with_path)
            else:
                # Local non-relational field.
                final_field = field
                targets = (field,)
                if fail_on_missing and pos + 1 != len(names):
                    raise FieldError(
                        f"Cannot resolve keyword {names[pos + 1]!r} into field. Join on '{name}'"
                        " not permitted."
                    )
                break
        return path, final_field, targets, names[pos + 1 :]
1783
    def setup_joins(
        self,
        names: list[str],
        meta: Meta,
        alias: str,
        can_reuse: set[str] | None = None,
        allow_many: bool = True,
        reuse_with_filtered_relation: bool = False,
    ) -> JoinInfo:
        """
        Compute the necessary table joins for the passage through the fields
        given in 'names'. 'meta' is the Meta for the current model
        (which gives the table we are starting from), 'alias' is the alias for
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
        parameter and force the relation on the given connections.

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.

        Return the final field involved in the joins, the target field (used
        for any 'where' constraint), the final 'opts' value, the joins, the
        field path traveled to generate the joins, and a transform function
        that takes a field and alias and is equivalent to `field.get_col(alias)`
        in the simple case but wraps field transforms if they were included in
        names.

        The target field is the field containing the concrete value. Final
        field can be something different, for example foreign key pointing to
        that value. Final field is needed for example in some value
        conversions (convert 'obj' in fk__id=obj to pk val using the foreign
        key field for example).
        """
        joins = [alias]
        # The transform can't be applied yet, as joins must be trimmed later.
        # To avoid making every caller of this method look up transforms
        # directly, compute transforms here and create a partial that converts
        # fields to the appropriate wrapped version.

        def _base_transformer(field: Field, alias: str | None) -> Col:
            # Identity transform: just produce the column, unqualified when
            # the query doesn't prefix columns with table aliases.
            if not self.alias_cols:
                alias = None
            return field.get_col(alias)

        final_transformer: TransformWrapper | Callable[[Field, str | None], Col] = (
            _base_transformer
        )

        # Try resolving all the names as fields first. If there's an error,
        # treat trailing names as lookups until a field can be resolved.
        last_field_exception = None
        for pivot in range(len(names), 0, -1):
            try:
                path, final_field, targets, rest = self.names_to_path(
                    names[:pivot],
                    meta,
                    allow_many,
                    fail_on_missing=True,
                )
            except FieldError as exc:
                if pivot == 1:
                    # The first item cannot be a lookup, so it's safe
                    # to raise the field error here.
                    raise
                else:
                    last_field_exception = exc
            else:
                # The transforms are the remaining items that couldn't be
                # resolved into fields.
                transforms = names[pivot:]
                break
        for name in transforms:
            # Each layer wraps the previous one; TransformWrapper binds
            # 'name' and 'previous' eagerly, so the loop variable is not
            # captured late by the closure.
            def transform(
                field: Field, alias: str | None, *, name: str, previous: Any
            ) -> BaseExpression:
                try:
                    wrapped = previous(field, alias)
                    return self.try_transform(wrapped, name)
                except FieldError:
                    # FieldError is raised if the transform doesn't exist.
                    if isinstance(final_field, Field) and last_field_exception:
                        raise last_field_exception
                    else:
                        raise

            final_transformer = TransformWrapper(
                transform, name=name, previous=final_transformer
            )
            final_transformer.has_transforms = True
        # Then, add the path to the query's joins. Note that we can't trim
        # joins at this stage - we will need the information about join type
        # of the trimmed joins.
        for join in path:
            if join.filtered_relation:
                filtered_relation = join.filtered_relation.clone()
                table_alias = filtered_relation.alias
            else:
                filtered_relation = None
                table_alias = None
            meta = join.to_meta
            if join.direct:
                nullable = self.is_nullable(join.join_field)
            else:
                # Reverse relations may have no matching row on the far side.
                nullable = True
            connection = self.join_class(
                meta.model.model_options.db_table,
                alias,
                table_alias,
                INNER,
                join.join_field,
                nullable,
                filtered_relation=filtered_relation,
            )
            reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
            alias = self.join(
                connection,
                reuse=reuse,
                reuse_with_filtered_relation=reuse_with_filtered_relation,
            )
            joins.append(alias)
            if filtered_relation:
                filtered_relation.path = joins[:]
        return JoinInfo(final_field, targets, meta, joins, path, final_transformer)
1914
    def trim_joins(
        self, targets: tuple[Field, ...], joins: list[str], path: list[Any]
    ) -> tuple[tuple[Field, ...], str, list[str]]:
        """
        The 'target' parameter is the final field being joined to, 'joins'
        is the full list of join aliases. The 'path' contain the PathInfos
        used to create the joins.

        Return the final target field and table alias and the new active
        joins.

        Always trim any direct join if the target column is already in the
        previous table. Can't trim reverse joins as it's unknown if there's
        anything on the other side of the join.
        """
        joins = joins[:]
        for pos, info in enumerate(reversed(path)):
            if len(joins) == 1 or not info.direct:
                # Nothing left to trim, or a reverse join whose far side may
                # be empty - stop here.
                break
            if info.filtered_relation:
                # Trimming would drop the relation's extra filter condition.
                break
            join_targets = {t.column for t in info.join_field.foreign_related_fields}
            cur_targets = {t.column for t in targets}
            if not cur_targets.issubset(join_targets):
                # Not every wanted column lives on the joined-to side.
                break
            # Re-point each target at the equivalent column on the previous
            # (referencing) table so the last join can be dropped.
            targets_dict = {
                r[1].column: r[0]
                for r in info.join_field.related_fields
                if r[1].column in cur_targets
            }
            targets = tuple(targets_dict[t.column] for t in targets)
            self.unref_alias(joins.pop())
        return targets, joins[-1], joins
1948
1949 @classmethod
1950 def _gen_cols(
1951 cls,
1952 exprs: Iterable[Any],
1953 include_external: bool = False,
1954 resolve_refs: bool = True,
1955 ) -> TypingIterator[Col]:
1956 for expr in exprs:
1957 if isinstance(expr, Col):
1958 yield expr
1959 elif include_external and callable(
1960 getattr(expr, "get_external_cols", None)
1961 ):
1962 yield from expr.get_external_cols()
1963 elif hasattr(expr, "get_source_expressions"):
1964 if not resolve_refs and isinstance(expr, Ref):
1965 continue
1966 yield from cls._gen_cols(
1967 expr.get_source_expressions(),
1968 include_external=include_external,
1969 resolve_refs=resolve_refs,
1970 )
1971
1972 @classmethod
1973 def _gen_col_aliases(cls, exprs: Iterable[Any]) -> TypingIterator[str]:
1974 yield from (expr.alias for expr in cls._gen_cols(exprs))
1975
    def resolve_ref(
        self,
        name: str,
        allow_joins: bool = True,
        reuse: set[str] | None = None,
        summarize: bool = False,
    ) -> BaseExpression:
        """
        Resolve a string reference into an expression for this query.

        'name' is either an annotation alias or a LOOKUP_SEP-separated field
        path. 'allow_joins' forbids references that would require a JOIN,
        'reuse' is the set of reusable join aliases (updated in place with
        joins actually used), and 'summarize' signals the reference occurs
        inside an aggregate() over this query.
        """
        annotation = self.annotations.get(name)
        if annotation is not None:
            if not allow_joins:
                # Reject annotations whose columns come from joined tables.
                for alias in self._gen_col_aliases([annotation]):
                    if isinstance(self.alias_map[alias], Join):
                        raise FieldError(
                            "Joined field references are not permitted in this query"
                        )
            if summarize:
                # Summarize currently means we are doing an aggregate() query
                # which is executed as a wrapped subquery if any of the
                # aggregate() elements reference an existing annotation. In
                # that case we need to return a Ref to the subquery's annotation.
                if name not in self.annotation_select:
                    raise FieldError(
                        f"Cannot aggregate over the '{name}' alias. Use annotate() "
                        "to promote it."
                    )
                return Ref(name, self.annotation_select[name])
            else:
                return annotation
        else:
            field_list = name.split(LOOKUP_SEP)
            annotation = self.annotations.get(field_list[0])
            if annotation is not None:
                # Annotation followed by transforms (e.g. alias__transform).
                for transform in field_list[1:]:
                    annotation = self.try_transform(annotation, transform)
                return annotation
            initial_alias = self.get_initial_alias()
            assert initial_alias is not None
            assert self.model is not None, "Resolving field references requires a model"
            meta = self.model._model_meta
            join_info = self.setup_joins(
                field_list,
                meta,
                initial_alias,
                can_reuse=reuse,
            )
            targets, final_alias, join_list = self.trim_joins(
                join_info.targets, join_info.joins, join_info.path
            )
            if not allow_joins and len(join_list) > 1:
                raise FieldError(
                    "Joined field references are not permitted in this query"
                )
            if len(targets) > 1:
                raise FieldError(
                    "Referencing multicolumn fields with F() objects isn't supported"
                )
            # Verify that the last lookup in name is a field or a transform:
            # transform_function() raises FieldError if not.
            transform = join_info.transform_function(targets[0], final_alias)
            if reuse is not None:
                reuse.update(join_list)
            return transform
2038
    def split_exclude(
        self,
        filter_expr: tuple[str, Any],
        can_reuse: set[str],
        names_with_path: list[tuple[str, list[Any]]],
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
        original exclude filter (filter_expr) and the portion up to the first
        N-to-many relation field.

        For example, if the origin filter is ~Q(child__name='foo'), filter_expr
        is ('child__name', 'foo') and can_reuse is a set of joins usable for
        filters in the original query.

        We will turn this into equivalent of:
            WHERE NOT EXISTS(
                SELECT 1
                FROM child
                WHERE name = 'foo' AND child.parent_id = parent.id
                LIMIT 1
            )
        """
        # Generate the inner query.
        query = self.__class__(self.model)
        query._filtered_relations = self._filtered_relations
        filter_lhs, filter_rhs = filter_expr
        if isinstance(filter_rhs, OuterRef):
            # Re-wrap so the reference resolves across one more level of
            # query nesting.
            filter_rhs = OuterRef(filter_rhs)
        elif isinstance(filter_rhs, F):
            # An F() in the outer query becomes an OuterRef in the subquery.
            filter_rhs = OuterRef(filter_rhs.name)
        query.add_filter(filter_lhs, filter_rhs)
        query.clear_ordering(force=True)
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_prefix, contains_louter = query.trim_start(names_with_path)

        col = query.select[0]
        select_field = col.target
        alias = col.alias
        if alias in can_reuse:
            id_field = select_field.model._model_meta.get_forward_field("id")
            # Need to add a restriction so that outer query's filters are in effect for
            # the subquery, too.
            query.bump_prefix(self)
            lookup_class = select_field.get_lookup("exact")
            # Note that the query.select[0].alias is different from alias
            # due to bump_prefix above.
            lookup = lookup_class(
                id_field.get_col(query.select[0].alias), id_field.get_col(alias)
            )
            query.where.add(lookup, AND)
            query.external_aliases[alias] = True

        # Tie the subquery's selected column to the outer query via the
        # trimmed lookup prefix.
        lookup_class = select_field.get_lookup("exact")
        lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
        query.where.add(lookup, AND)
        condition, needed_inner = self.build_filter(Exists(query))

        if contains_louter:
            or_null_condition, _ = self.build_filter(
                (f"{trimmed_prefix}__isnull", True),
                current_negated=True,
                branch_negated=True,
                can_reuse=can_reuse,
            )
            condition.add(or_null_condition, OR)
            # Note that the end result will be:
            # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
            # This might look crazy but due to how IN works, this seems to be
            # correct. If the IS NOT NULL check is removed then outercol NOT
            # IN will return UNKNOWN. If the IS NULL check is removed, then if
            # outercol IS NULL we will not match the row.
        return condition, needed_inner
2114
2115 def set_empty(self) -> None:
2116 self.where.add(NothingNode(), AND)
2117
2118 def is_empty(self) -> bool:
2119 return any(isinstance(c, NothingNode) for c in self.where.children)
2120
2121 def set_limits(self, low: int | None = None, high: int | None = None) -> None:
2122 """
2123 Adjust the limits on the rows retrieved. Use low/high to set these,
2124 as it makes it more Pythonic to read and write. When the SQL query is
2125 created, convert them to the appropriate offset and limit values.
2126
2127 Apply any limits passed in here to the existing constraints. Add low
2128 to the current low value and clamp both to any existing high value.
2129 """
2130 if high is not None:
2131 if self.high_mark is not None:
2132 self.high_mark = min(self.high_mark, self.low_mark + high)
2133 else:
2134 self.high_mark = self.low_mark + high
2135 if low is not None:
2136 if self.high_mark is not None:
2137 self.low_mark = min(self.high_mark, self.low_mark + low)
2138 else:
2139 self.low_mark = self.low_mark + low
2140
2141 if self.low_mark == self.high_mark:
2142 self.set_empty()
2143
2144 def clear_limits(self) -> None:
2145 """Clear any existing limits."""
2146 self.low_mark, self.high_mark = 0, None
2147
2148 @property
2149 def is_sliced(self) -> bool:
2150 return self.low_mark != 0 or self.high_mark is not None
2151
2152 def has_limit_one(self) -> bool:
2153 return self.high_mark is not None and (self.high_mark - self.low_mark) == 1
2154
2155 def can_filter(self) -> bool:
2156 """
2157 Return True if adding filters to this instance is still possible.
2158
2159 Typically, this means no limits or offsets have been put on the results.
2160 """
2161 return not self.is_sliced
2162
2163 def clear_select_clause(self) -> None:
2164 """Remove all fields from SELECT clause."""
2165 self.select = ()
2166 self.default_cols = False
2167 self.select_related = False
2168 self.set_extra_mask(())
2169 self.set_annotation_mask(())
2170
2171 def clear_select_fields(self) -> None:
2172 """
2173 Clear the list of fields to select (but not extra_select columns).
2174 Some queryset types completely replace any existing list of select
2175 columns.
2176 """
2177 self.select = ()
2178 self.values_select = ()
2179
2180 def add_select_col(self, col: BaseExpression, name: str) -> None:
2181 self.select += (col,)
2182 self.values_select += (name,)
2183
2184 def set_select(self, cols: list[Col] | tuple[Col, ...]) -> None:
2185 self.default_cols = False
2186 self.select = tuple(cols)
2187
2188 def add_distinct_fields(self, *field_names: str) -> None:
2189 """
2190 Add and resolve the given fields to the query's "distinct on" clause.
2191 """
2192 self.distinct_fields = field_names
2193 self.distinct = True
2194
    def add_fields(
        self, field_names: list[str] | TypingIterator[str], allow_m2m: bool = True
    ) -> None:
        """
        Add the given (model) fields to the select set. Add the field names in
        the order specified.

        Raise FieldError when a name cannot be resolved or would require a
        disallowed multi-valued join.
        """
        alias = self.get_initial_alias()
        assert alias is not None
        assert self.model is not None, "add_fields() requires a model"
        meta = self.model._model_meta

        try:
            cols = []
            for name in field_names:
                # Join promotion note - we must not remove any rows here, so
                # if there is no existing joins, use outer join.
                join_info = self.setup_joins(
                    name.split(LOOKUP_SEP), meta, alias, allow_many=allow_m2m
                )
                targets, final_alias, joins = self.trim_joins(
                    join_info.targets,
                    join_info.joins,
                    join_info.path,
                )
                for target in targets:
                    cols.append(join_info.transform_function(target, final_alias))
            if cols:
                self.set_select(cols)
        except MultiJoin:
            # 'name' still refers to the field that raised inside the loop.
            raise FieldError(f"Invalid field name: '{name}'")
        except FieldError:
            if LOOKUP_SEP in name:
                # For lookups spanning over relationships, show the error
                # from the model on which the lookup failed.
                raise
            elif name in self.annotations:
                raise FieldError(
                    f"Cannot select the '{name}' alias. Use annotate() to promote it."
                )
            else:
                # Build a helpful message listing every valid choice.
                names = sorted(
                    [
                        *get_field_names_from_opts(meta),
                        *self.extra,
                        *self.annotation_select,
                        *self._filtered_relations,
                    ]
                )
                raise FieldError(
                    "Cannot resolve keyword {!r} into field. Choices are: {}".format(
                        name, ", ".join(names)
                    )
                )
2249
2250 def add_ordering(self, *ordering: str | BaseExpression) -> None:
2251 """
2252 Add items from the 'ordering' sequence to the query's "order by"
2253 clause. These items are either field names (not column names) --
2254 possibly with a direction prefix ('-' or '?') -- or OrderBy
2255 expressions.
2256
2257 If 'ordering' is empty, clear all ordering from the query.
2258 """
2259 errors = []
2260 for item in ordering:
2261 if isinstance(item, str):
2262 if item == "?":
2263 continue
2264 item = item.removeprefix("-")
2265 if item in self.annotations:
2266 continue
2267 if self.extra and item in self.extra:
2268 continue
2269 # names_to_path() validates the lookup. A descriptive
2270 # FieldError will be raise if it's not.
2271 assert self.model is not None, "ORDER BY field names require a model"
2272 self.names_to_path(item.split(LOOKUP_SEP), self.model._model_meta)
2273 elif not isinstance(item, ResolvableExpression):
2274 errors.append(item)
2275 if getattr(item, "contains_aggregate", False):
2276 raise FieldError(
2277 "Using an aggregate in order_by() without also including "
2278 f"it in annotate() is not allowed: {item}"
2279 )
2280 if errors:
2281 raise FieldError(f"Invalid order_by arguments: {errors}")
2282 if ordering:
2283 self.order_by += ordering
2284 else:
2285 self.default_ordering = False
2286
2287 def clear_ordering(self, force: bool = False, clear_default: bool = True) -> None:
2288 """
2289 Remove any ordering settings if the current query allows it without
2290 side effects, set 'force' to True to clear the ordering regardless.
2291 If 'clear_default' is True, there will be no ordering in the resulting
2292 query (not even the model's default).
2293 """
2294 if not force and (
2295 self.is_sliced or self.distinct_fields or self.select_for_update
2296 ):
2297 return
2298 self.order_by = ()
2299 self.extra_order_by = ()
2300 if clear_default:
2301 self.default_ordering = False
2302
    def set_group_by(self, allow_aliases: bool = True) -> None:
        """
        Expand the GROUP BY clause required by the query.

        This will usually be the set of all non-aggregate fields in the
        return data. If the database backend supports grouping by the
        primary key, and the query would be equivalent, the optimization
        will be made automatically.
        """
        if allow_aliases and self.values_select:
            # If grouping by aliases is allowed assign selected value aliases
            # by moving them to annotations.
            group_by_annotations = {}
            values_select = {}
            for alias, expr in zip(self.values_select, self.select):
                if isinstance(expr, Col):
                    values_select[alias] = expr
                else:
                    # Non-column expressions must be grouped by alias, which
                    # requires promoting them to annotations.
                    group_by_annotations[alias] = expr
            self.annotations = {**group_by_annotations, **self.annotations}
            self.append_annotation_mask(group_by_annotations)
            self.select = tuple(values_select.values())
            self.values_select = tuple(values_select)
        group_by = list(self.select)
        for alias, annotation in self.annotation_select.items():
            if not (group_by_cols := annotation.get_group_by_cols()):
                # Expressions with no group-by columns are omitted entirely.
                continue
            if allow_aliases and not annotation.contains_aggregate:
                group_by.append(Ref(alias, annotation))
            else:
                group_by.extend(group_by_cols)
        self.group_by = tuple(group_by)
2335
2336 def add_select_related(self, fields: list[str]) -> None:
2337 """
2338 Set up the select_related data structure so that we only select
2339 certain related models (as opposed to all models, when
2340 self.select_related=True).
2341 """
2342 if isinstance(self.select_related, bool):
2343 field_dict: dict[str, Any] = {}
2344 else:
2345 field_dict = self.select_related
2346 for field in fields:
2347 d = field_dict
2348 for part in field.split(LOOKUP_SEP):
2349 d = d.setdefault(part, {})
2350 self.select_related = field_dict
2351
2352 def add_extra(
2353 self,
2354 select: dict[str, str],
2355 select_params: list[Any] | None,
2356 where: list[str],
2357 params: list[Any],
2358 tables: list[str],
2359 order_by: tuple[str, ...],
2360 ) -> None:
2361 """
2362 Add data to the various extra_* attributes for user-created additions
2363 to the query.
2364 """
2365 if select:
2366 # We need to pair any placeholder markers in the 'select'
2367 # dictionary with their parameters in 'select_params' so that
2368 # subsequent updates to the select dictionary also adjust the
2369 # parameters appropriately.
2370 select_pairs = {}
2371 if select_params:
2372 param_iter = iter(select_params)
2373 else:
2374 param_iter = iter([])
2375 for name, entry in select.items():
2376 self.check_alias(name)
2377 entry = str(entry)
2378 entry_params = []
2379 pos = entry.find("%s")
2380 while pos != -1:
2381 if pos == 0 or entry[pos - 1] != "%":
2382 entry_params.append(next(param_iter))
2383 pos = entry.find("%s", pos + 2)
2384 select_pairs[name] = (entry, entry_params)
2385 self.extra.update(select_pairs)
2386 if where or params:
2387 self.where.add(ExtraWhere(where, params), AND)
2388 if tables:
2389 self.extra_tables += tuple(tables)
2390 if order_by:
2391 self.extra_order_by = order_by
2392
2393 def clear_deferred_loading(self) -> None:
2394 """Remove any fields from the deferred loading set."""
2395 self.deferred_loading = (frozenset(), True)
2396
2397 def add_deferred_loading(self, field_names: frozenset[str]) -> None:
2398 """
2399 Add the given list of model field names to the set of fields to
2400 exclude from loading from the database when automatic column selection
2401 is done. Add the new field names to any existing field names that
2402 are deferred (or removed from any existing field names that are marked
2403 as the only ones for immediate loading).
2404 """
2405 # Fields on related models are stored in the literal double-underscore
2406 # format, so that we can use a set datastructure. We do the foo__bar
2407 # splitting and handling when computing the SQL column names (as part of
2408 # get_columns()).
2409 existing, defer = self.deferred_loading
2410 existing_set = set(existing)
2411 if defer:
2412 # Add to existing deferred names.
2413 self.deferred_loading = frozenset(existing_set.union(field_names)), True
2414 else:
2415 # Remove names from the set of any existing "immediate load" names.
2416 if new_existing := existing_set.difference(field_names):
2417 self.deferred_loading = frozenset(new_existing), False
2418 else:
2419 self.clear_deferred_loading()
2420 if new_only := set(field_names).difference(existing_set):
2421 self.deferred_loading = frozenset(new_only), True
2422
2423 def add_immediate_loading(self, field_names: list[str] | set[str]) -> None:
2424 """
2425 Add the given list of model field names to the set of fields to
2426 retrieve when the SQL is executed ("immediate loading" fields). The
2427 field names replace any existing immediate loading field names. If
2428 there are field names already specified for deferred loading, remove
2429 those names from the new field_names before storing the new names
2430 for immediate loading. (That is, immediate loading overrides any
2431 existing immediate values, but respects existing deferrals.)
2432 """
2433 existing, defer = self.deferred_loading
2434 field_names_set = set(field_names)
2435
2436 if defer:
2437 # Remove any existing deferred names from the current set before
2438 # setting the new names.
2439 self.deferred_loading = (
2440 frozenset(field_names_set.difference(existing)),
2441 False,
2442 )
2443 else:
2444 # Replace any existing "immediate load" field names.
2445 self.deferred_loading = frozenset(field_names_set), False
2446
2447 def set_annotation_mask(
2448 self,
2449 names: set[str]
2450 | frozenset[str]
2451 | list[str]
2452 | tuple[str, ...]
2453 | dict[str, Any]
2454 | None,
2455 ) -> None:
2456 """Set the mask of annotations that will be returned by the SELECT."""
2457 if names is None:
2458 self.annotation_select_mask = None
2459 else:
2460 self.annotation_select_mask = set(names)
2461 self._annotation_select_cache = None
2462
2463 def append_annotation_mask(self, names: list[str] | dict[str, Any]) -> None:
2464 if self.annotation_select_mask is not None:
2465 self.set_annotation_mask(self.annotation_select_mask.union(names))
2466
2467 def set_extra_mask(
2468 self, names: set[str] | list[str] | tuple[str, ...] | None
2469 ) -> None:
2470 """
2471 Set the mask of extra select items that will be returned by SELECT.
2472 Don't remove them from the Query since they might be used later.
2473 """
2474 if names is None:
2475 self.extra_select_mask = None
2476 else:
2477 self.extra_select_mask = set(names)
2478 self._extra_select_cache = None
2479
    def set_values(self, fields: list[str]) -> None:
        """
        Configure the query for a values()/values_list() style selection of
        the given field, extra, and annotation names (or all concrete fields
        when 'fields' is empty).
        """
        self.select_related = False
        self.clear_deferred_loading()
        self.clear_select_fields()
        self.has_select_fields = True

        if fields:
            field_names = []
            extra_names = []
            annotation_names = []
            if not self.extra and not self.annotations:
                # Shortcut - if there are no extra or annotations, then
                # the values() clause must be just field names.
                field_names = list(fields)
            else:
                self.default_cols = False
                # Dispatch each requested name into its proper bucket.
                for f in fields:
                    if f in self.extra_select:
                        extra_names.append(f)
                    elif f in self.annotation_select:
                        annotation_names.append(f)
                    else:
                        field_names.append(f)
            self.set_extra_mask(extra_names)
            self.set_annotation_mask(annotation_names)
            selected = frozenset(field_names + extra_names + annotation_names)
        else:
            assert self.model is not None, "Default values query requires a model"
            field_names = [f.attname for f in self.model._model_meta.concrete_fields]
            selected = frozenset(field_names)
        # Selected annotations must be known before setting the GROUP BY
        # clause.
        if self.group_by is True:
            assert self.model is not None, "GROUP BY True requires a model"
            self.add_fields(
                (f.attname for f in self.model._model_meta.concrete_fields), False
            )
            # Disable GROUP BY aliases to avoid orphaning references to the
            # SELECT clause which is about to be cleared.
            self.set_group_by(allow_aliases=False)
            self.clear_select_fields()
        elif self.group_by:
            # Resolve GROUP BY annotation references if they are not part of
            # the selected fields anymore.
            group_by = []
            for expr in self.group_by:
                if isinstance(expr, Ref) and expr.refs not in selected:
                    expr = self.annotations[expr.refs]
                group_by.append(expr)
            self.group_by = tuple(group_by)

        self.values_select = tuple(field_names)
        self.add_fields(field_names, True)
2533
2534 @property
2535 def annotation_select(self) -> dict[str, BaseExpression]:
2536 """
2537 Return the dictionary of aggregate columns that are not masked and
2538 should be used in the SELECT clause. Cache this result for performance.
2539 """
2540 if self._annotation_select_cache is not None:
2541 return self._annotation_select_cache
2542 elif not self.annotations:
2543 return {}
2544 elif self.annotation_select_mask is not None:
2545 self._annotation_select_cache = {
2546 k: v
2547 for k, v in self.annotations.items()
2548 if k in self.annotation_select_mask
2549 }
2550 return self._annotation_select_cache
2551 else:
2552 return self.annotations
2553
2554 @property
2555 def extra_select(self) -> dict[str, tuple[str, list[Any]]]:
2556 if self._extra_select_cache is not None:
2557 return self._extra_select_cache
2558 if not self.extra:
2559 return {}
2560 elif self.extra_select_mask is not None:
2561 self._extra_select_cache = {
2562 k: v for k, v in self.extra.items() if k in self.extra_select_mask
2563 }
2564 return self._extra_select_cache
2565 else:
2566 return self.extra
2567
    def trim_start(
        self, names_with_path: list[tuple[str, list[Any]]]
    ) -> tuple[str, bool]:
        """
        Trim joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure that are m2m joins.

        Also set the select column so the start matches the join.

        This method is meant to be used for generating the subquery joins &
        cols in split_exclude().

        Return a lookup usable for doing outerq.filter(lookup=self) and a
        boolean indicating if the joins in the prefix contain a LEFT OUTER join.
        """
        all_paths = []
        for _, paths in names_with_path:
            all_paths.extend(paths)
        contains_louter = False
        # Trim and operate only on tables that were generated for
        # the lookup part of the query. That is, avoid trimming
        # joins generated for F() expressions.
        lookup_tables = [
            t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
        ]
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                # Stop at the first multi-valued hop; everything before it
                # has been trimmed.
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # NOTE: 'path' and 'trimmed_paths' deliberately leak out of the loop
        # above and refer to the first m2m hop.
        # The path.join_field is a Rel, lets get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix.
        paths_in_prefix = trimmed_paths
        trimmed_prefix = []
        for name, path in names_with_path:
            if paths_in_prefix - len(path) < 0:
                break
            trimmed_prefix.append(name)
            paths_in_prefix -= len(path)
        trimmed_prefix.append(join_field.foreign_related_fields[0].name)
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        # Lets still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for:
        # - LEFT JOINs because we would miss those rows that have nothing on
        #   the outer side,
        # - INNER JOINs from filtered relations because we would miss their
        #   filters.
        first_join = self.alias_map[lookup_tables[trimmed_paths + 1]]
        if first_join.join_type != LOUTER and not first_join.filtered_relation:
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
        else:
            # TODO: It might be possible to trim more joins from the start of the
            # inner query if it happens to have a longer join chain containing the
            # values in select_fields. Lets punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a join_class instead of a
        # base_table_class reference. But the first entry in the query's FROM
        # clause must not be a JOIN.
        for table in self.alias_map:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = self.base_table_class(
                    self.alias_map[table].table_name,
                    table,
                )
                break
        self.set_select([f.get_col(select_alias) for f in select_fields])
        return trimmed_prefix, contains_louter
2641
2642 def is_nullable(self, field: Field) -> bool:
2643 """Check if the given field should be treated as nullable."""
2644 # QuerySet does not have knowledge of which connection is going to be
2645 # used. For the single-database setup we always reference the default
2646 # connection here.
2647 return field.allow_null
2648
2649
def get_order_dir(field: str, default: str = "ASC") -> tuple[str, str]:
    """
    Split an ordering specification into (field name, direction).

    For example, '-foo' is returned as ('foo', 'DESC'). The 'default'
    parameter indicates which way an unprefixed (or '+'-prefixed) name
    should sort; a '-' prefix always sorts the opposite way.
    """
    directions = ORDER_DIR[default]
    is_descending = field[0] == "-"
    if is_descending:
        return field[1:], directions[1]
    return field, directions[0]
2662
2663
class JoinPromoter:
    """
    Abstract away join promotion problems for complex filter conditions.

    Counts, per table alias, how many child conditions of a connector
    require that join, then decides which joins can safely be INNER and
    which must be LEFT OUTER to avoid dropping valid rows.
    """

    def __init__(self, connector: str, num_children: int, negated: bool):
        self.connector = connector
        self.negated = negated
        # Negation flips the connector for promotion purposes:
        # NOT (a AND b) is treated like (NOT a) OR (NOT b).
        if negated:
            self.effective_connector = OR if connector == AND else AND
        else:
            self.effective_connector = connector
        self.num_children = num_children
        # Maps table alias to how many times it is seen as required for
        # inner and/or outer joins.
        self.votes = Counter()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__qualname__}(connector={self.connector!r}, "
            f"num_children={self.num_children!r}, negated={self.negated!r})"
        )

    def add_votes(self, votes: Any) -> None:
        """Record one vote per item of *votes* (any iterable of aliases)."""
        self.votes.update(votes)

    def update_join_types(self, query: Query) -> set[str]:
        """
        Adjust join types on *query* so the generated SQL is as efficient
        as possible while staying correct: demote joins to INNER wherever
        that cannot remove results, and promote to LEFT OUTER wherever an
        INNER join could drop valid rows. Return the demoted aliases.
        """
        promote: set[str] = set()
        demote: set[str] = set()
        for alias, vote_count in self.votes.items():
            # OR case: if not every child requires this join, an INNER
            # join could itself remove valid matches. Consider
            # rel_a__col=1 | rel_b__col=2 — if the rel_a join yields no
            # row (reverse FK, or NULL in a direct FK) but rel_b matches
            # with col=2, an INNER join to rel_a would discard that valid
            # result. So promote any existing INNER to LOUTER (a later
            # demotion vote may undo this).
            if self.effective_connector == OR and vote_count < self.num_children:
                promote.add(alias)
            # AND case: a single child needing the join is enough for
            # INNER — if the join produced NULL, that child's comparison
            # could never match (NULL=anything is false). OR case: INNER
            # is safe only when every child voted for the join, e.g.
            # (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell)
            # cannot match at all if rel_a yields no rows.
            if self.effective_connector == AND or (
                self.effective_connector == OR and vote_count == self.num_children
            ):
                demote.add(alias)
        # Mixed example: (rel_a__col=1 | rel_b__col=2) & rel_a__col__gte=0.
        # The OR clause first promotes rel_a and rel_b to LOUTER; the AND
        # clause's single vote then demotes rel_a back to INNER. That is
        # safe: if rel_a yields no rows, the __gte clause is false and so
        # is the whole condition. The same holds with the clauses swapped,
        # or with the __gte replaced by rel_a__col=1 | rel_a__col=2.
        query.promote_joins(promote)
        query.demote_joins(demote)
        return demote