1"""
2Create SQL statements for QuerySets.
3
4The code in here encapsulates all of the SQL construction so that QuerySets
5themselves do not have to (and could be backed by things other than SQL
6databases). The abstraction barrier only works one way: this module has to know
7all about the internals of models in order to get the information it needs.
8"""
9
10from __future__ import annotations
11
12import copy
13import difflib
14import functools
15import sys
16from collections import Counter
17from collections.abc import Callable, Iterable, Iterator, Mapping
18from collections.abc import Iterator as TypingIterator
19from functools import cached_property
20from itertools import chain, count, product
21from string import ascii_uppercase
22from typing import (
23 TYPE_CHECKING,
24 Any,
25 Literal,
26 NamedTuple,
27 Self,
28 TypeVar,
29 cast,
30 overload,
31)
32
33from plain.models.aggregates import Count
34from plain.models.constants import LOOKUP_SEP
35from plain.models.db import NotSupportedError, db_connection
36from plain.models.exceptions import FieldDoesNotExist, FieldError
37from plain.models.expressions import (
38 BaseExpression,
39 Col,
40 Exists,
41 F,
42 OuterRef,
43 Ref,
44 ResolvableExpression,
45 ResolvedOuterRef,
46 Value,
47)
48from plain.models.fields import Field
49from plain.models.fields.related_lookups import MultiColSource
50from plain.models.lookups import Lookup
51from plain.models.query_utils import (
52 PathInfo,
53 Q,
54 check_rel_lookup_compatibility,
55 refs_expression,
56)
57from plain.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
58from plain.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin
59from plain.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
60from plain.utils.regex_helper import _lazy_re_compile
61from plain.utils.tree import Node
62
63if TYPE_CHECKING:
64 from plain.models import Model
65 from plain.models.backends.base.base import BaseDatabaseWrapper
66 from plain.models.fields.related import RelatedField
67 from plain.models.fields.reverse_related import ForeignObjectRel
68 from plain.models.meta import Meta
69 from plain.models.sql.compiler import SQLCompiler, SqlWithParams
70
71# Quotation marks ('"`[]), whitespace characters, semicolons, or inline
72# SQL comments are forbidden in column aliases.
73FORBIDDEN_ALIAS_PATTERN = _lazy_re_compile(r"['`\"\]\[;\s]|--|/\*|\*/")
74
75# Inspired from
76# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
77EXPLAIN_OPTIONS_PATTERN = _lazy_re_compile(r"[\w\-]+")
78
79
def get_field_names_from_opts(meta: Meta | None) -> set[str]:
    """Collect every referencable field name from *meta*.

    Concrete fields contribute both their ``name`` and ``attname``;
    non-concrete fields contribute only their ``name``. Returns an empty
    set when no meta is available.
    """
    if meta is None:
        return set()
    names: set[str] = set()
    for field in meta.get_fields():
        names.add(field.name)
        if field.concrete:
            names.add(field.attname)
    return names
88
89
def get_children_from_q(q: Q) -> TypingIterator[tuple[str, Any]]:
    """Yield the leaf (lookup, value) pairs of a Q tree, depth-first."""
    pending = list(q.children)
    while pending:
        child = pending.pop(0)
        if isinstance(child, Node):
            # Flatten nested nodes in place, preserving depth-first order.
            pending = list(child.children) + pending
        else:
            yield child
96
97
class JoinInfo(NamedTuple):
    """Information about a join operation in a query."""

    # Field at the end of the lookup path (the one being joined through).
    final_field: Field[Any]
    # Fields targeted on the far side of the join.
    targets: tuple[Field[Any], ...]
    # Meta of the model reached by the join path.
    meta: Meta
    # Table aliases produced for the join, in join order.
    joins: list[str]
    # PathInfo entries describing each hop of the relation path.
    path: list[PathInfo]
    # Callable mapping (field, alias) to the expression used to select it.
    transform_function: Callable[[Field[Any], str | None], BaseExpression]
107
108
class RawQuery:
    """A single raw SQL query."""

    def __init__(self, sql: str, params: tuple[Any, ...] | dict[str, Any] = ()):
        self.sql = sql
        self.params = params
        self.cursor: Any = None

        # Mimic the attributes of a regular query so that a compiler can be
        # used to process this query's results.
        self.low_mark = 0  # Used for offset/limit
        self.high_mark = None  # Used for offset/limit
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self) -> RawQuery:
        """Alias of clone(); a raw query carries no extra chaining state."""
        return self.clone()

    def clone(self) -> RawQuery:
        """Return a fresh RawQuery with the same SQL and parameters."""
        return RawQuery(self.sql, params=self.params)

    def get_columns(self) -> list[str]:
        """Return the result set's column names, executing the query first
        if it has not run yet."""
        if self.cursor is None:
            self._execute_query()
        convert = db_connection.introspection.identifier_converter
        return [convert(description[0]) for description in self.cursor.description]

    def __iter__(self) -> TypingIterator[Any]:
        # Execute anew for every iterator. Caching rows would trade RAM for
        # speed; that optimization is left to callers.
        self._execute_query()
        if db_connection.features.can_use_chunked_reads:
            rows = self.cursor
        else:
            # Without chunked reads the entire result set must be realized
            # up front.
            rows = list(self.cursor)
        return iter(rows)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    @property
    def params_type(self) -> type[dict] | type[tuple] | None:
        """dict or tuple depending on the parameter style, None if unset."""
        if self.params is None:
            return None
        if isinstance(self.params, Mapping):
            return dict
        return tuple

    def __str__(self) -> str:
        kind = self.params_type
        if kind is None:
            return self.sql
        return self.sql % kind(self.params)

    def _execute_query(self) -> None:
        # Adapt parameters to the database, as much as possible considering
        # that the target type isn't known. See #17755.
        kind = self.params_type
        adapter = db_connection.ops.adapt_unknown_value
        if kind is tuple:
            assert isinstance(self.params, tuple)
            params = tuple(adapter(val) for val in self.params)
        elif kind is dict:
            assert isinstance(self.params, dict)
            params = {key: adapter(val) for key, val in self.params.items()}
        elif kind is None:
            params = None
        else:
            raise RuntimeError(f"Unexpected params type: {kind}")

        self.cursor = db_connection.cursor()
        self.cursor.execute(self.sql, params)
182
183
class ExplainInfo(NamedTuple):
    """Information about an EXPLAIN query."""

    # Requested output format, or None for the backend default.
    format: str | None
    # Extra EXPLAIN options (names are validated in Query.explain()).
    options: dict[str, Any]
189
190
class TransformWrapper:
    """Callable wrapper for transform functions with a mutable flag.

    Replaces a bare ``functools.partial`` so that the ``has_transforms``
    attribute can be assigned after construction while keeping proper
    type checking for the call signature.
    """

    def __init__(
        self,
        func: Callable[..., BaseExpression],
        **kwargs: Any,
    ):
        # Bind the extra keyword arguments now; apply them on every call.
        self._func = func
        self._bound_kwargs = kwargs
        self.has_transforms: bool = False

    def __call__(self, field: Field[Any], alias: str | None) -> BaseExpression:
        return self._func(field, alias, **self._bound_kwargs)
208
209
# Type variable bound to Query so that Query.chain() can preserve the
# concrete subclass type of the query being chained.
QueryType = TypeVar("QueryType", bound="Query")
211
212
class Query(BaseExpression):
    """A single SQL query.

    Class-level attributes below are immutable defaults; clone() relies on
    instances only overriding them in ``__dict__`` when they change.
    """

    alias_prefix = "T"
    empty_result_set_value = None
    subq_aliases = frozenset([alias_prefix])

    base_table_class = BaseTable
    join_class = Join

    default_cols = True
    default_ordering = True
    standard_ordering = True

    filter_is_sticky = False
    subquery = False

    # SQL-related attributes.
    # Select and related select clauses are expressions to use in the SELECT
    # clause of the query. The select is used for cases where we want to set up
    # the select clause to contain other than default fields (values(),
    # subqueries...). Note that annotations go to annotations dictionary.
    select = ()
    # The group_by attribute can have one of the following forms:
    #  - None: no group by at all in the query
    #  - A tuple of expressions: group by (at least) those expressions.
    #    String refs are also allowed for now.
    #  - True: group by all select fields of the model
    # See compiler.get_group_by() for details.
    group_by = None
    order_by = ()
    low_mark = 0  # Used for offset/limit.
    high_mark = None  # Used for offset/limit.
    distinct = False
    distinct_fields = ()
    select_for_update = False
    select_for_update_nowait = False
    select_for_update_skip_locked = False
    select_for_update_of = ()
    select_for_no_key_update = False
    select_related: bool | dict[str, Any] = False
    has_select_fields = False
    # Arbitrary limit for select_related to prevent infinite recursion.
    max_depth = 5
    # Holds the selects defined by a call to values() or values_list()
    # excluding annotation_select and extra_select.
    values_select = ()

    # SQL annotation-related attributes.
    annotation_select_mask = None
    _annotation_select_cache = None

    # These are for extensions. The contents are more or less appended verbatim
    # to the appropriate clause.
    extra_select_mask = None
    _extra_select_cache = None

    extra_tables = ()
    extra_order_by = ()

    # A tuple that is a set of model field names and either True, if these are
    # the fields to defer, or False if these are the only fields to load.
    deferred_loading = (frozenset(), True)

    explain_info = None
278
279 def __init__(self, model: type[Model] | None, alias_cols: bool = True):
280 self.model = model
281 self.alias_refcount = {}
282 # alias_map is the most important data structure regarding joins.
283 # It's used for recording which joins exist in the query and what
284 # types they are. The key is the alias of the joined table (possibly
285 # the table name) and the value is a Join-like object (see
286 # sql.datastructures.Join for more information).
287 self.alias_map = {}
288 # Whether to provide alias to columns during reference resolving.
289 self.alias_cols = alias_cols
290 # Sometimes the query contains references to aliases in outer queries (as
291 # a result of split_exclude). Correct alias quoting needs to know these
292 # aliases too.
293 # Map external tables to whether they are aliased.
294 self.external_aliases = {}
295 self.table_map = {} # Maps table names to list of aliases.
296 self.used_aliases = set()
297
298 self.where = WhereNode()
299 # Maps alias -> Annotation Expression.
300 self.annotations = {}
301 # These are for extensions. The contents are more or less appended
302 # verbatim to the appropriate clause.
303 self.extra = {} # Maps col_alias -> (col_sql, params).
304
305 self._filtered_relations = {}
306
307 @property
308 def output_field(self) -> Field | None:
309 if len(self.select) == 1:
310 select = self.select[0]
311 return getattr(select, "target", None) or select.field
312 elif len(self.annotation_select) == 1:
313 return next(iter(self.annotation_select.values())).output_field
314
315 @cached_property
316 def base_table(self) -> str | None:
317 for alias in self.alias_map:
318 return alias
319
320 def __str__(self) -> str:
321 """
322 Return the query as a string of SQL with the parameter values
323 substituted in (use sql_with_params() to see the unsubstituted string).
324
325 Parameter values won't necessarily be quoted correctly, since that is
326 done by the database interface at execution time.
327 """
328 sql, params = self.sql_with_params()
329 return sql % params
330
331 def sql_with_params(self) -> SqlWithParams:
332 """
333 Return the query as an SQL string and the parameters that will be
334 substituted into the query.
335 """
336 return self.get_compiler().as_sql()
337
338 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
339 """Limit the amount of work when a Query is deepcopied."""
340 result = self.clone()
341 memo[id(self)] = result
342 return result
343
344 def get_compiler(self, *, elide_empty: bool = True) -> SQLCompiler:
345 """Return a compiler instance for this query."""
346 return db_connection.ops.get_compiler_for(self, elide_empty)
347
    def clone(self) -> Self:
        """
        Return a copy of the current Query. A lightweight alternative to
        deepcopy().
        """
        # Build an empty instance without running __init__, then swap in the
        # concrete class so the copy matches self's (sub)class.
        obj = Empty()
        obj.__class__ = self.__class__
        obj = cast(Self, obj)  # Type checker doesn't understand __class__ reassignment
        # Copy references to everything.
        obj.__dict__ = self.__dict__.copy()
        # Clone attributes that can't use shallow copy.
        obj.alias_refcount = self.alias_refcount.copy()
        obj.alias_map = self.alias_map.copy()
        obj.external_aliases = self.external_aliases.copy()
        obj.table_map = self.table_map.copy()
        obj.where = self.where.clone()
        obj.annotations = self.annotations.copy()
        if self.annotation_select_mask is not None:
            obj.annotation_select_mask = self.annotation_select_mask.copy()
        # _annotation_select_cache cannot be copied, as doing so breaks the
        # (necessary) state in which both annotations and
        # _annotation_select_cache point to the same underlying objects.
        # It will get re-populated in the cloned queryset the next time it's
        # used.
        obj._annotation_select_cache = None
        obj.extra = self.extra.copy()
        if self.extra_select_mask is not None:
            obj.extra_select_mask = self.extra_select_mask.copy()
        if self._extra_select_cache is not None:
            obj._extra_select_cache = self._extra_select_cache.copy()
        if self.select_related is not False:
            # Use deepcopy because select_related stores fields in nested
            # dicts.
            obj.select_related = copy.deepcopy(obj.select_related)
        if "subq_aliases" in self.__dict__:
            # Only copy when set on the instance; otherwise the class-level
            # default remains shared on purpose.
            obj.subq_aliases = self.subq_aliases.copy()
        obj.used_aliases = self.used_aliases.copy()
        obj._filtered_relations = self._filtered_relations.copy()
        # Clear the cached_property, if it exists.
        obj.__dict__.pop("base_table", None)
        return obj
389
390 @overload
391 def chain(self, klass: None = None) -> Self: ...
392
393 @overload
394 def chain(self, klass: type[QueryType]) -> QueryType: ...
395
396 def chain(self, klass: type[Query] | None = None) -> Query:
397 """
398 Return a copy of the current Query that's ready for another operation.
399 The klass argument changes the type of the Query, e.g. UpdateQuery.
400 """
401 obj = self.clone()
402 if klass and obj.__class__ != klass:
403 obj.__class__ = klass
404 if not obj.filter_is_sticky:
405 obj.used_aliases = set()
406 obj.filter_is_sticky = False
407 if hasattr(obj, "_setup_query"):
408 obj._setup_query() # type: ignore[operator]
409 return obj
410
411 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
412 clone = self.clone()
413 clone.change_aliases(change_map)
414 return clone
415
416 def _get_col(self, target: Any, field: Field, alias: str | None) -> Col:
417 if not self.alias_cols:
418 alias = None
419 return target.get_col(alias, field)
420
    def get_aggregation(self, aggregate_exprs: dict[str, Any]) -> dict[str, Any]:
        """
        Return the dictionary with the values of the existing aggregations.

        ``aggregate_exprs`` maps result aliases to aggregate expressions; the
        return value maps those aliases to the computed values.
        """
        if not aggregate_exprs:
            return {}
        aggregates = {}
        for alias, aggregate_expr in aggregate_exprs.items():
            self.check_alias(alias)
            aggregate = aggregate_expr.resolve_expression(
                self, allow_joins=True, reuse=None, summarize=True
            )
            if not aggregate.contains_aggregate:
                raise TypeError(f"{alias} is not an aggregate expression")
            aggregates[alias] = aggregate
        # Existing usage of aggregation can be determined by the presence of
        # selected aggregates but also by filters against aliased aggregates.
        _, having, qualify = self.where.split_having_qualify()
        has_existing_aggregation = (
            any(
                getattr(annotation, "contains_aggregate", True)
                for annotation in self.annotations.values()
            )
            or having
        )
        # Decide if we need to use a subquery.
        #
        # Existing aggregations would cause incorrect results as
        # get_aggregation() must produce just one result and thus must not use
        # GROUP BY.
        #
        # If the query has limit or distinct, or uses set operations, then
        # those operations must be done in a subquery so that the query
        # aggregates on the limit and/or distinct results instead of applying
        # the distinct and limit after the aggregation.
        if (
            isinstance(self.group_by, tuple)
            or self.is_sliced
            or has_existing_aggregation
            or qualify
            or self.distinct
        ):
            from plain.models.sql.subqueries import AggregateQuery

            inner_query = self.clone()
            inner_query.subquery = True
            outer_query = AggregateQuery(self.model, inner_query)
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.set_annotation_mask(self.annotation_select)
            # Queries with distinct_fields need ordering and when a limit is
            # applied we must take the slice from the ordered query. Otherwise
            # no need for ordering.
            inner_query.clear_ordering(force=False)
            if not inner_query.distinct:
                # If the inner query uses default select and it has some
                # aggregate annotations, then we must make sure the inner
                # query is grouped by the main model's primary key. However,
                # clearing the select clause can alter results if distinct is
                # used.
                if inner_query.default_cols and has_existing_aggregation:
                    assert self.model is not None, "Aggregation requires a model"
                    inner_query.group_by = (
                        self.model._model_meta.get_forward_field("id").get_col(
                            inner_query.get_initial_alias()
                        ),
                    )
                inner_query.default_cols = False
                if not qualify:
                    # Mask existing annotations that are not referenced by
                    # aggregates to be pushed to the outer query unless
                    # filtering against window functions is involved as it
                    # requires complex realising.
                    annotation_mask = set()
                    for aggregate in aggregates.values():
                        annotation_mask |= aggregate.get_refs()
                    inner_query.set_annotation_mask(annotation_mask)

            # Add aggregates to the outer AggregateQuery. This requires making
            # sure all columns referenced by the aggregates are selected in the
            # inner query. It is achieved by retrieving all column references
            # by the aggregates, explicitly selecting them in the inner query,
            # and making sure the aggregates are repointed to them.
            col_refs = {}
            for alias, aggregate in aggregates.items():
                replacements = {}
                for col in self._gen_cols([aggregate], resolve_refs=False):
                    if not (col_ref := col_refs.get(col)):
                        index = len(col_refs) + 1
                        col_alias = f"__col{index}"
                        col_ref = Ref(col_alias, col)
                        col_refs[col] = col_ref
                        inner_query.annotations[col_alias] = col
                        inner_query.append_annotation_mask([col_alias])
                    replacements[col] = col_ref
                outer_query.annotations[alias] = aggregate.replace_expressions(
                    replacements
                )
            if (
                inner_query.select == ()
                and not inner_query.default_cols
                and not inner_query.annotation_select_mask
            ):
                # In case of Model.objects[0:3].count(), there would be no
                # field selected in the inner query, yet we must use a subquery.
                # So, make sure at least one field is selected.
                assert self.model is not None, "Count with slicing requires a model"
                inner_query.select = (
                    self.model._model_meta.get_forward_field("id").get_col(
                        inner_query.get_initial_alias()
                    ),
                )
        else:
            # No subquery needed: aggregate directly over this query.
            outer_query = self
            self.select = ()
            self.default_cols = False
            self.extra = {}
            if self.annotations:
                # Inline reference to existing annotations and mask them as
                # they are unnecessary given only the summarized aggregations
                # are requested.
                replacements = {
                    Ref(alias, annotation): annotation
                    for alias, annotation in self.annotations.items()
                }
                self.annotations = {
                    alias: aggregate.replace_expressions(replacements)
                    for alias, aggregate in aggregates.items()
                }
            else:
                self.annotations = aggregates
            self.set_annotation_mask(aggregates)

        # Fallback values used when the query matches no rows at all.
        empty_set_result = [
            expression.empty_result_set_value
            for expression in outer_query.annotation_select.values()
        ]
        elide_empty = not any(result is NotImplemented for result in empty_set_result)
        outer_query.clear_ordering(force=True)
        outer_query.clear_limits()
        outer_query.select_for_update = False
        outer_query.select_related = False
        compiler = outer_query.get_compiler(elide_empty=elide_empty)
        result = compiler.execute_sql(SINGLE)
        if result is None:
            result = empty_set_result
        else:
            converters = compiler.get_converters(outer_query.annotation_select.values())
            result = next(compiler.apply_converters((result,), converters))

        return dict(zip(outer_query.annotation_select, result))
572
573 def get_count(self) -> int:
574 """
575 Perform a COUNT() query using the current filter constraints.
576 """
577 obj = self.clone()
578 return obj.get_aggregation({"__count": Count("*")})["__count"]
579
    def has_filters(self) -> bool:
        # True when the WHERE tree contains at least one condition.
        return bool(self.where)
582
    def exists(self, limit: bool = True) -> Self:
        """
        Return a trimmed clone suitable for an EXISTS-style check: select
        clause and ordering cleared, a constant selected, and (optionally)
        LIMIT 1 applied.
        """
        q = self.clone()
        if not (q.distinct and q.is_sliced):
            if q.group_by is True:
                assert self.model is not None, "GROUP BY requires a model"
                q.add_fields(
                    (f.attname for f in self.model._model_meta.concrete_fields), False
                )
                # Disable GROUP BY aliases to avoid orphaning references to the
                # SELECT clause which is about to be cleared.
                q.set_group_by(allow_aliases=False)
            q.clear_select_clause()
            q.clear_ordering(force=True)
        if limit:
            q.set_limits(high=1)
        # Select a constant so no real columns need to be fetched.
        q.add_annotation(Value(1), "a")
        return q
600
601 def has_results(self) -> bool:
602 q = self.exists()
603 compiler = q.get_compiler()
604 return compiler.has_results()
605
606 def explain(self, format: str | None = None, **options: Any) -> str:
607 q = self.clone()
608 for option_name in options:
609 if (
610 not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name)
611 or "--" in option_name
612 ):
613 raise ValueError(f"Invalid option name: {option_name!r}.")
614 q.explain_info = ExplainInfo(format, options)
615 compiler = q.get_compiler()
616 return "\n".join(compiler.explain_query())
617
    def combine(self, rhs: Query, connector: str) -> None:
        """
        Merge the 'rhs' query into the current one, with any 'rhs' effects
        being applied *after* (that is, "to the right of") anything in the
        current query. 'rhs' is not modified during a call to this function.

        The 'connector' parameter describes how to connect filters from the
        'rhs' query (AND or OR).
        """
        if self.model != rhs.model:
            raise TypeError("Cannot combine queries on two different base models.")
        if self.is_sliced:
            raise TypeError("Cannot combine queries once a slice has been taken.")
        if self.distinct != rhs.distinct:
            raise TypeError("Cannot combine a unique query with a non-unique query.")
        if self.distinct_fields != rhs.distinct_fields:
            raise TypeError("Cannot combine queries with different distinct fields.")

        # If lhs and rhs shares the same alias prefix, it is possible to have
        # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
        # as T4 -> T6 while combining two querysets. To prevent this, change an
        # alias prefix of the rhs and update current aliases accordingly,
        # except if the alias is the base table since it must be present in the
        # query on both sides.
        initial_alias = self.get_initial_alias()
        assert initial_alias is not None
        rhs.bump_prefix(self, exclude={initial_alias})

        # Work out how to relabel the rhs aliases, if necessary.
        change_map = {}
        conjunction = connector == AND

        # Determine which existing joins can be reused. When combining the
        # query with AND we must recreate all joins for m2m filters. When
        # combining with OR we can reuse joins. The reason is that in AND
        # case a single row can't fulfill a condition like:
        #     revrel__col=1 & revrel__col=2
        # But, there might be two different related rows matching this
        # condition. In OR case a single True is enough, so single row is
        # enough, too.
        #
        # Note that we will be creating duplicate joins for non-m2m joins in
        # the AND case. The results will be correct but this creates too many
        # joins. This is something that could be fixed later on.
        reuse = set() if conjunction else set(self.alias_map)
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == INNER
        )
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        rhs_tables = list(rhs.alias_map)[1:]
        for alias in rhs_tables:
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
            # combined query must have two joins, too.
            reuse.discard(new_alias)
            if alias != new_alias:
                change_map[alias] = new_alias
            if not rhs.alias_refcount[alias]:
                # The alias was unused in the rhs query. Unref it so that it
                # will be unused in the new query, too. We have to add and
                # unref the alias so that join promotion has information of
                # the join type for the unused alias.
                self.unref_alias(new_alias)
        joinpromoter.add_votes(rhs_votes)
        joinpromoter.update_join_types(self)

        # Combine subqueries aliases to ensure aliases relabelling properly
        # handle subqueries when combining where and select clauses.
        self.subq_aliases |= rhs.subq_aliases

        # Now relabel a copy of the rhs where-clause and add it to the current
        # one.
        w = rhs.where.clone()
        w.relabel_aliases(change_map)
        self.where.add(w, connector)

        # Selection columns and extra extensions are those provided by 'rhs'.
        if rhs.select:
            self.set_select([col.relabeled_clone(change_map) for col in rhs.select])
        else:
            self.select = ()

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra and rhs.extra:
                raise ValueError(
                    "When merging querysets using 'or', you cannot have "
                    "extra(select=...) on both sides."
                )
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables

        # Ordering uses the 'rhs' ordering, unless it has none, in which case
        # the current ordering is used.
        self.order_by = rhs.order_by or self.order_by
        self.extra_order_by = rhs.extra_order_by or self.extra_order_by
733
    def _get_defer_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """
        Build the select mask for .defer(): every concrete field NOT named in
        ``mask`` is loaded, relational entries recurse into the related model,
        and any leftover entries must name reverse relationships.
        """
        from plain.models.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        # The primary key is always loaded, regardless of deferral.
        select_mask[meta.get_forward_field("id")] = {}
        # All concrete fields that are not part of the defer mask must be
        # loaded. If a relational field is encountered it gets added to the
        # mask for it be considered if `select_related` and the cycle continues
        # by recursively calling this function.
        for field in meta.concrete_fields:
            # pop() consumes the entry so only unmatched names remain below.
            field_mask = mask.pop(field.name, None)
            if field_mask is None:
                select_mask.setdefault(field, {})
            elif field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                field_select_mask = select_mask.setdefault(field, {})
                related_model = field.remote_field.model
                self._get_defer_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        # Remaining defer entries must be references to reverse relationships.
        # The following code is expected to raise FieldError if it encounters
        # a malformed defer entry.
        for field_name, field_mask in mask.items():
            if filtered_relation := self._filtered_relations.get(field_name):
                relation = meta.get_reverse_relation(filtered_relation.relation_name)
                field_select_mask = select_mask.setdefault((field_name, relation), {})
                field = relation.field
            else:
                field = meta.get_reverse_relation(field_name).field
                field_select_mask = select_mask.setdefault(field, {})
            related_model = field.model
            self._get_defer_select_mask(
                related_model._model_meta, field_mask, field_select_mask
            )
        return select_mask
777
778 def _get_only_select_mask(
779 self,
780 meta: Meta,
781 mask: dict[str, Any],
782 select_mask: dict[Any, Any] | None = None,
783 ) -> dict[Any, Any]:
784 from plain.models.fields.related import RelatedField
785
786 if select_mask is None:
787 select_mask = {}
788 select_mask[meta.get_forward_field("id")] = {}
789 # Only include fields mentioned in the mask.
790 for field_name, field_mask in mask.items():
791 field = meta.get_field(field_name)
792 field_select_mask = select_mask.setdefault(field, {})
793 if field_mask:
794 if not isinstance(field, RelatedField):
795 raise FieldError(next(iter(field_mask)))
796 related_model = field.remote_field.model
797 self._get_only_select_mask(
798 related_model._model_meta, field_mask, field_select_mask
799 )
800 return select_mask
801
802 def get_select_mask(self) -> dict[Any, Any]:
803 """
804 Convert the self.deferred_loading data structure to an alternate data
805 structure, describing the field that *will* be loaded. This is used to
806 compute the columns to select from the database and also by the
807 QuerySet class to work out which fields are being initialized on each
808 model. Models that have all their fields included aren't mentioned in
809 the result, only those that have field restrictions in place.
810 """
811 field_names, defer = self.deferred_loading
812 if not field_names:
813 return {}
814 mask = {}
815 for field_name in field_names:
816 part_mask = mask
817 for part in field_name.split(LOOKUP_SEP):
818 part_mask = part_mask.setdefault(part, {})
819 assert self.model is not None, "Deferred/only field loading requires a model"
820 meta = self.model._model_meta
821 if defer:
822 return self._get_defer_select_mask(meta, mask)
823 return self._get_only_select_mask(meta, mask)
824
825 def table_alias(
826 self, table_name: str, create: bool = False, filtered_relation: Any = None
827 ) -> tuple[str, bool]:
828 """
829 Return a table alias for the given table_name and whether this is a
830 new alias or not.
831
832 If 'create' is true, a new alias is always created. Otherwise, the
833 most recently created alias for the table (if one exists) is reused.
834 """
835 alias_list = self.table_map.get(table_name)
836 if not create and alias_list:
837 alias = alias_list[0]
838 self.alias_refcount[alias] += 1
839 return alias, False
840
841 # Create a new alias for this table.
842 if alias_list:
843 alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1) # noqa: UP031
844 alias_list.append(alias)
845 else:
846 # The first occurrence of a table uses the table name directly.
847 alias = (
848 filtered_relation.alias if filtered_relation is not None else table_name
849 )
850 self.table_map[table_name] = [alias]
851 self.alias_refcount[alias] = 1
852 return alias, True
853
    def ref_alias(self, alias: str) -> None:
        """Increase the reference count for this alias by one."""
        self.alias_refcount[alias] += 1
857
    def unref_alias(self, alias: str, amount: int = 1) -> None:
        """Decrease the reference count for this alias by ``amount``."""
        self.alias_refcount[alias] -= amount
861
862 def promote_joins(self, aliases: set[str] | list[str]) -> None:
863 """
864 Promote recursively the join type of given aliases and its children to
865 an outer join. If 'unconditional' is False, only promote the join if
866 it is nullable or the parent join is an outer join.
867
868 The children promotion is done to avoid join chains that contain a LOUTER
869 b INNER c. So, if we have currently a INNER b INNER c and a->b is promoted,
870 then we must also promote b->c automatically, or otherwise the promotion
871 of a->b doesn't actually change anything in the query results.
872 """
873 aliases = list(aliases)
874 while aliases:
875 alias = aliases.pop(0)
876 if self.alias_map[alias].join_type is None:
877 # This is the base table (first FROM entry) - this table
878 # isn't really joined at all in the query, so we should not
879 # alter its join type.
880 continue
881 # Only the first alias (skipped above) should have None join_type
882 assert self.alias_map[alias].join_type is not None
883 parent_alias = self.alias_map[alias].parent_alias
884 parent_louter = (
885 parent_alias and self.alias_map[parent_alias].join_type == LOUTER
886 )
887 already_louter = self.alias_map[alias].join_type == LOUTER
888 if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
889 self.alias_map[alias] = self.alias_map[alias].promote()
890 # Join type of 'alias' changed, so re-examine all aliases that
891 # refer to this one.
892 aliases.extend(
893 join
894 for join in self.alias_map
895 if self.alias_map[join].parent_alias == alias
896 and join not in aliases
897 )
898
899 def demote_joins(self, aliases: set[str] | list[str]) -> None:
900 """
901 Change join type from LOUTER to INNER for all joins in aliases.
902
903 Similarly to promote_joins(), this method must ensure no join chains
904 containing first an outer, then an inner join are generated. If we
905 are demoting b->c join in chain a LOUTER b LOUTER c then we must
906 demote a->b automatically, or otherwise the demotion of b->c doesn't
907 actually change anything in the query results. .
908 """
909 aliases = list(aliases)
910 while aliases:
911 alias = aliases.pop(0)
912 if self.alias_map[alias].join_type == LOUTER:
913 self.alias_map[alias] = self.alias_map[alias].demote()
914 parent_alias = self.alias_map[alias].parent_alias
915 if self.alias_map[parent_alias].join_type == INNER:
916 aliases.append(parent_alias)
917
918 def reset_refcounts(self, to_counts: dict[str, int]) -> None:
919 """
920 Reset reference counts for aliases so that they match the value passed
921 in `to_counts`.
922 """
923 for alias, cur_refcount in self.alias_refcount.copy().items():
924 unref_amount = cur_refcount - to_counts.get(alias, 0)
925 self.unref_alias(alias, unref_amount)
926
    def change_aliases(self, change_map: dict[str, str]) -> None:
        """
        Change the aliases in change_map (which maps old-alias -> new-alias),
        relabelling any references to them in select columns and the where
        clause.
        """
        # If keys and values of change_map were to intersect, an alias might be
        # updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
        # on their order in change_map.
        assert set(change_map).isdisjoint(change_map.values())

        # 1. Update references in "select" (normal columns plus aliases),
        # "group by" and "where".
        self.where.relabel_aliases(change_map)
        if isinstance(self.group_by, tuple):
            self.group_by = tuple(
                [col.relabeled_clone(change_map) for col in self.group_by]
            )
        self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
        # `and` keeps a falsy (empty) annotations mapping unchanged instead of
        # building a new dict.
        self.annotations = self.annotations and {
            key: col.relabeled_clone(change_map)
            for key, col in self.annotations.items()
        }

        # 2. Rename the alias in the internal table/alias datastructures.
        for old_alias, new_alias in change_map.items():
            if old_alias not in self.alias_map:
                continue
            alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
            self.alias_map[new_alias] = alias_data
            self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
            del self.alias_refcount[old_alias]
            del self.alias_map[old_alias]

            # Keep table_map's per-table alias list in sync with the rename.
            table_aliases = self.table_map[alias_data.table_name]
            for pos, alias in enumerate(table_aliases):
                if alias == old_alias:
                    table_aliases[pos] = new_alias
                    break
        self.external_aliases = {
            # Table is aliased or it's being changed and thus is aliased.
            change_map.get(alias, alias): (aliased or alias in change_map)
            for alias, aliased in self.external_aliases.items()
        }
971
    def bump_prefix(
        self, other_query: Query, exclude: set[str] | dict[str, str] | None = None
    ) -> None:
        """
        Change the alias prefix to the next letter in the alphabet in a way
        that the other query's aliases and this query's aliases will not
        conflict. Even tables that previously had no alias will get an alias
        after this call. To prevent changing aliases use the exclude parameter.
        """

        def prefix_gen() -> TypingIterator[str]:
            """
            Generate a sequence of characters in alphabetical order:
            -> 'A', 'B', 'C', ...

            When the alphabet is finished, the sequence will continue with the
            Cartesian product:
            -> 'AA', 'AB', 'AC', ...
            """
            alphabet = ascii_uppercase
            prefix = chr(ord(self.alias_prefix) + 1)
            yield prefix
            for n in count(1):
                # The first multi-char round starts at the bumped prefix;
                # subsequent rounds cover the whole alphabet.
                seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
                for s in product(seq, repeat=n):
                    yield "".join(s)
                prefix = None

        if self.alias_prefix != other_query.alias_prefix:
            # No clashes between self and outer query should be possible.
            return

        # Explicitly avoid infinite loop. The constant divider is based on how
        # much depth recursive subquery references add to the stack. This value
        # might need to be adjusted when adding or removing function calls from
        # the code path in charge of performing these operations.
        local_recursion_limit = sys.getrecursionlimit() // 16
        for pos, prefix in enumerate(prefix_gen()):
            if prefix not in self.subq_aliases:
                self.alias_prefix = prefix
                break
            if pos > local_recursion_limit:
                raise RecursionError(
                    "Maximum recursion depth exceeded: too many subqueries."
                )
        # Record the chosen prefix on both queries so later bumps skip it too.
        self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
        other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
        if exclude is None:
            exclude = {}
        # Rename every existing alias to use the new prefix, except aliases
        # the caller asked to keep untouched.
        self.change_aliases(
            {
                alias: "%s%d" % (self.alias_prefix, pos)  # noqa: UP031
                for pos, alias in enumerate(self.alias_map)
                if alias not in exclude
            }
        )
1028
1029 def get_initial_alias(self) -> str | None:
1030 """
1031 Return the first alias for this query, after increasing its reference
1032 count.
1033 """
1034 if self.alias_map:
1035 alias = self.base_table
1036 self.ref_alias(alias)
1037 elif self.model:
1038 alias = self.join(
1039 self.base_table_class(self.model.model_options.db_table, None)
1040 )
1041 else:
1042 alias = None
1043 return alias
1044
1045 def count_active_tables(self) -> int:
1046 """
1047 Return the number of tables in this query with a non-zero reference
1048 count. After execution, the reference counts are zeroed, so tables
1049 added in compiler will not be seen by this method.
1050 """
1051 return len([1 for count in self.alias_refcount.values() if count])
1052
    def join(
        self,
        join: BaseTable | Join,
        reuse: set[str] | None = None,
        reuse_with_filtered_relation: bool = False,
    ) -> str:
        """
        Return an alias for the 'join', either reusing an existing alias for
        that join or creating a new one. 'join' is either a base_table_class or
        join_class.

        The 'reuse' parameter can be either None which means all joins are
        reusable, or it can be a set containing the aliases that can be reused.

        The 'reuse_with_filtered_relation' parameter is used when computing
        FilteredRelation instances.

        A join is always created as LOUTER if the lhs alias is LOUTER to make
        sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
        joins are created as LOUTER if the join is nullable.
        """
        if reuse_with_filtered_relation and reuse:
            # Filtered relations only reuse aliases explicitly allowed via
            # 'reuse' and considered equivalent by Join.equals().
            reuse_aliases = [
                a for a, j in self.alias_map.items() if a in reuse and j.equals(join)
            ]
        else:
            reuse_aliases = [
                a
                for a, j in self.alias_map.items()
                if (reuse is None or a in reuse) and j == join
            ]
        if reuse_aliases:
            if join.table_alias in reuse_aliases:
                reuse_alias = join.table_alias
            else:
                # Reuse the most recent alias of the joined table
                # (a many-to-many relation may be joined multiple times).
                reuse_alias = reuse_aliases[-1]
            self.ref_alias(reuse_alias)
            return reuse_alias

        # No reuse is possible, so we need a new alias.
        alias, _ = self.table_alias(
            join.table_name, create=True, filtered_relation=join.filtered_relation
        )
        if isinstance(join, Join):
            # Propagate LOUTER from the parent so an INNER join never hangs
            # off an outer join; nullable joins also become LOUTER.
            if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
                join_type = LOUTER
            else:
                join_type = INNER
            join.join_type = join_type
        join.table_alias = alias
        self.alias_map[alias] = join
        return alias
1107
1108 def check_alias(self, alias: str) -> None:
1109 if FORBIDDEN_ALIAS_PATTERN.search(alias):
1110 raise ValueError(
1111 "Column aliases cannot contain whitespace characters, quotation marks, "
1112 "semicolons, or SQL comments."
1113 )
1114
1115 def add_annotation(
1116 self, annotation: BaseExpression, alias: str, select: bool = True
1117 ) -> None:
1118 """Add a single annotation expression to the Query."""
1119 self.check_alias(alias)
1120 annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
1121 if select:
1122 self.append_annotation_mask([alias])
1123 else:
1124 self.set_annotation_mask(set(self.annotation_select).difference({alias}))
1125 self.annotations[alias] = annotation
1126
    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Resolve this query for use as a subquery expression within `query`.

        Return a clone marked as a subquery whose alias prefix cannot clash
        with the outer query's, and with the outer query's aliases recorded
        as external references.
        """
        clone = self.clone()
        # Subqueries need to use a different set of aliases than the outer query.
        clone.bump_prefix(query)
        clone.subquery = True
        clone.where.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        for key, value in clone.annotations.items():
            resolved = value.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
            if hasattr(resolved, "external_aliases"):
                # Nested resolvables must also know about this clone's
                # external aliases.
                resolved.external_aliases.update(clone.external_aliases)
            clone.annotations[key] = resolved
        # Outer query's aliases are considered external.
        for alias, table in query.alias_map.items():
            # An alias counts as "aliased" when it differs from the plain
            # table name the entry refers to.
            clone.external_aliases[alias] = (
                isinstance(table, Join)
                and table.join_field.related_model.model_options.db_table != alias
            ) or (
                isinstance(table, BaseTable) and table.table_name != table.table_alias
            )
        return clone
1156
1157 def get_external_cols(self) -> list[Col]:
1158 exprs = chain(self.annotations.values(), self.where.children)
1159 return [
1160 col
1161 for col in self._gen_cols(exprs, include_external=True)
1162 if col.alias in self.external_aliases
1163 ]
1164
1165 def get_group_by_cols(
1166 self, wrapper: BaseExpression | None = None
1167 ) -> list[BaseExpression]:
1168 # If wrapper is referenced by an alias for an explicit GROUP BY through
1169 # values() a reference to this expression and not the self must be
1170 # returned to ensure external column references are not grouped against
1171 # as well.
1172 external_cols = self.get_external_cols()
1173 if any(col.possibly_multivalued for col in external_cols):
1174 return [wrapper or self]
1175 # Cast needed because list is invariant: list[Col] is not list[BaseExpression]
1176 return cast(list[BaseExpression], external_cols)
1177
1178 def as_sql(
1179 self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
1180 ) -> SqlWithParams:
1181 # Some backends (e.g. Oracle) raise an error when a subquery contains
1182 # unnecessary ORDER BY clause.
1183 if (
1184 self.subquery
1185 and not db_connection.features.ignores_unnecessary_order_by_in_subqueries
1186 ):
1187 self.clear_ordering(force=False)
1188 sql, params = self.get_compiler().as_sql()
1189 if self.subquery:
1190 sql = f"({sql})"
1191 return sql, params
1192
1193 def resolve_lookup_value(
1194 self, value: Any, can_reuse: set[str] | None, allow_joins: bool
1195 ) -> Any:
1196 if isinstance(value, ResolvableExpression):
1197 value = value.resolve_expression(
1198 self,
1199 reuse=can_reuse,
1200 allow_joins=allow_joins,
1201 )
1202 elif isinstance(value, list | tuple):
1203 # The items of the iterable may be expressions and therefore need
1204 # to be resolved independently.
1205 values = (
1206 self.resolve_lookup_value(sub_value, can_reuse, allow_joins)
1207 for sub_value in value
1208 )
1209 type_ = type(value)
1210 if hasattr(type_, "_make"): # namedtuple
1211 return type_(*values)
1212 return type_(values)
1213 return value
1214
1215 def solve_lookup_type(
1216 self, lookup: str, summarize: bool = False
1217 ) -> tuple[
1218 list[str] | tuple[str, ...], tuple[str, ...], BaseExpression | Literal[False]
1219 ]:
1220 """
1221 Solve the lookup type from the lookup (e.g.: 'foobar__id__icontains').
1222 """
1223 lookup_splitted = lookup.split(LOOKUP_SEP)
1224 if self.annotations:
1225 annotation, expression_lookups = refs_expression(
1226 lookup_splitted, self.annotations
1227 )
1228 if annotation:
1229 expression = self.annotations[annotation]
1230 if summarize:
1231 expression = Ref(annotation, expression)
1232 return expression_lookups, (), expression
1233 assert self.model is not None, "Field lookups require a model"
1234 meta = self.model._model_meta
1235 _, field, _, lookup_parts = self.names_to_path(lookup_splitted, meta)
1236 field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
1237 if len(lookup_parts) > 1 and not field_parts:
1238 raise FieldError(
1239 f'Invalid lookup "{lookup}" for model {meta.model.__name__}".'
1240 )
1241 return lookup_parts, tuple(field_parts), False
1242
1243 def check_query_object_type(
1244 self, value: Any, meta: Meta, field: Field | ForeignObjectRel
1245 ) -> None:
1246 """
1247 Check whether the object passed while querying is of the correct type.
1248 If not, raise a ValueError specifying the wrong object.
1249 """
1250 from plain.models import Model
1251
1252 if isinstance(value, Model):
1253 if not check_rel_lookup_compatibility(value._model_meta.model, meta, field):
1254 raise ValueError(
1255 f'Cannot query "{value}": Must be "{meta.model.model_options.object_name}" instance.'
1256 )
1257
1258 def check_related_objects(
1259 self, field: RelatedField | ForeignObjectRel, value: Any, meta: Meta
1260 ) -> None:
1261 """Check the type of object passed to query relations."""
1262 from plain.models import Model
1263
1264 # Check that the field and the queryset use the same model in a
1265 # query like .filter(author=Author.query.all()). For example, the
1266 # meta would be Author's (from the author field) and value.model
1267 # would be Author.query.all() queryset's .model (Author also).
1268 # The field is the related field on the lhs side.
1269 if (
1270 isinstance(value, Query)
1271 and not value.has_select_fields
1272 and not check_rel_lookup_compatibility(value.model, meta, field)
1273 ):
1274 raise ValueError(
1275 f'Cannot use QuerySet for "{value.model.model_options.object_name}": Use a QuerySet for "{meta.model.model_options.object_name}".'
1276 )
1277 elif isinstance(value, Model):
1278 self.check_query_object_type(value, meta, field)
1279 elif isinstance(value, Iterable):
1280 for v in value:
1281 self.check_query_object_type(v, meta, field)
1282
1283 def check_filterable(self, expression: Any) -> None:
1284 """Raise an error if expression cannot be used in a WHERE clause."""
1285 if isinstance(expression, ResolvableExpression) and not getattr(
1286 expression, "filterable", True
1287 ):
1288 raise NotSupportedError(
1289 expression.__class__.__name__ + " is disallowed in the filter clause."
1290 )
1291 if hasattr(expression, "get_source_expressions"):
1292 for expr in expression.get_source_expressions():
1293 self.check_filterable(expr)
1294
    def build_lookup(
        self, lookups: list[str], lhs: BaseExpression | MultiColSource, rhs: Any
    ) -> Lookup | None:
        """
        Try to extract transforms and lookup from given lhs.

        The lhs value is something that works like SQLExpression.
        The rhs value is what the lookup is going to compare against.
        The lookups is a list of names to extract using get_lookup()
        and get_transform().

        Return None when no lookup (nor transform-with-exact fallback) can be
        resolved for the final name.
        """
        # __exact is the default lookup if one isn't given.
        *transforms, lookup_name = lookups or ["exact"]
        if transforms:
            if isinstance(lhs, MultiColSource):
                raise FieldError(
                    "Transforms are not supported on multi-column relations."
                )
            # At this point, lhs must be BaseExpression
            for name in transforms:
                lhs = self.try_transform(lhs, name)
        # First try get_lookup() so that the lookup takes precedence if the lhs
        # supports both transform and lookup for the name.
        lookup_class = lhs.get_lookup(lookup_name)
        if not lookup_class:
            # A lookup wasn't found. Try to interpret the name as a transform
            # and do an Exact lookup against it.
            if isinstance(lhs, MultiColSource):
                raise FieldError(
                    "Transforms are not supported on multi-column relations."
                )
            lhs = self.try_transform(lhs, lookup_name)
            lookup_name = "exact"
            lookup_class = lhs.get_lookup(lookup_name)
            if not lookup_class:
                return

        lookup = lookup_class(lhs, rhs)
        # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
        # uses of None as a query value unless the lookup supports it.
        if lookup.rhs is None and not lookup.can_use_none_as_rhs:
            if lookup_name not in ("exact", "iexact"):
                raise ValueError("Cannot use None as a query value")
            # exact/iexact against None becomes an IS NULL test.
            isnull_lookup = lhs.get_lookup("isnull")
            assert isnull_lookup is not None
            return isnull_lookup(lhs, True)

        return lookup
1343
1344 def try_transform(self, lhs: BaseExpression, name: str) -> BaseExpression:
1345 """
1346 Helper method for build_lookup(). Try to fetch and initialize
1347 a transform for name parameter from lhs.
1348 """
1349 transform_class = lhs.get_transform(name)
1350 if transform_class:
1351 return transform_class(lhs)
1352 else:
1353 output_field = lhs.output_field.__class__
1354 suggested_lookups = difflib.get_close_matches(
1355 name, output_field.get_lookups()
1356 )
1357 if suggested_lookups:
1358 suggestion = ", perhaps you meant {}?".format(
1359 " or ".join(suggested_lookups)
1360 )
1361 else:
1362 suggestion = "."
1363 raise FieldError(
1364 f"Unsupported lookup '{name}' for {output_field.__name__} or join on the field not "
1365 f"permitted{suggestion}"
1366 )
1367
1368 def build_filter(
1369 self,
1370 filter_expr: tuple[str, Any] | Q | BaseExpression,
1371 branch_negated: bool = False,
1372 current_negated: bool = False,
1373 can_reuse: set[str] | None = None,
1374 allow_joins: bool = True,
1375 split_subq: bool = True,
1376 reuse_with_filtered_relation: bool = False,
1377 check_filterable: bool = True,
1378 summarize: bool = False,
1379 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1380 from plain.models.fields.related import RelatedField
1381
1382 """
1383 Build a WhereNode for a single filter clause but don't add it
1384 to this Query. Query.add_q() will then add this filter to the where
1385 Node.
1386
1387 The 'branch_negated' tells us if the current branch contains any
1388 negations. This will be used to determine if subqueries are needed.
1389
1390 The 'current_negated' is used to determine if the current filter is
1391 negated or not and this will be used to determine if IS NULL filtering
1392 is needed.
1393
1394 The difference between current_negated and branch_negated is that
1395 branch_negated is set on first negation, but current_negated is
1396 flipped for each negation.
1397
1398 Note that add_filter will not do any negating itself, that is done
1399 upper in the code by add_q().
1400
1401 The 'can_reuse' is a set of reusable joins for multijoins.
1402
1403 If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
1404 will be reused.
1405
1406 The method will create a filter clause that can be added to the current
1407 query. However, if the filter isn't added to the query then the caller
1408 is responsible for unreffing the joins used.
1409 """
1410 if isinstance(filter_expr, dict):
1411 raise FieldError("Cannot parse keyword query as dict")
1412 if isinstance(filter_expr, Q):
1413 return self._add_q(
1414 filter_expr,
1415 branch_negated=branch_negated,
1416 current_negated=current_negated,
1417 used_aliases=can_reuse,
1418 allow_joins=allow_joins,
1419 split_subq=split_subq,
1420 check_filterable=check_filterable,
1421 summarize=summarize,
1422 )
1423 if isinstance(filter_expr, ResolvableExpression):
1424 if not getattr(filter_expr, "conditional", False):
1425 raise TypeError("Cannot filter against a non-conditional expression.")
1426 condition = filter_expr.resolve_expression(
1427 self, allow_joins=allow_joins, summarize=summarize
1428 )
1429 if not isinstance(condition, Lookup):
1430 condition = self.build_lookup(["exact"], condition, True)
1431 return WhereNode([condition], connector=AND), set()
1432 if isinstance(filter_expr, BaseExpression):
1433 raise TypeError(f"Unexpected BaseExpression type: {type(filter_expr)}")
1434 arg, value = filter_expr
1435 if not arg:
1436 raise FieldError(f"Cannot parse keyword query {arg!r}")
1437 lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
1438
1439 if check_filterable:
1440 self.check_filterable(reffed_expression)
1441
1442 if not allow_joins and len(parts) > 1:
1443 raise FieldError("Joined field references are not permitted in this query")
1444
1445 pre_joins = self.alias_refcount.copy()
1446 value = self.resolve_lookup_value(value, can_reuse, allow_joins)
1447 used_joins = {
1448 k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
1449 }
1450
1451 if check_filterable:
1452 self.check_filterable(value)
1453
1454 if reffed_expression:
1455 condition = self.build_lookup(list(lookups), reffed_expression, value)
1456 return WhereNode([condition], connector=AND), set()
1457
1458 assert self.model is not None, "Building filters requires a model"
1459 meta = self.model._model_meta
1460 alias = self.get_initial_alias()
1461 assert alias is not None
1462 allow_many = not branch_negated or not split_subq
1463
1464 try:
1465 join_info = self.setup_joins(
1466 list(parts),
1467 meta,
1468 alias,
1469 can_reuse=can_reuse,
1470 allow_many=allow_many,
1471 reuse_with_filtered_relation=reuse_with_filtered_relation,
1472 )
1473
1474 # Prevent iterator from being consumed by check_related_objects()
1475 if isinstance(value, Iterator):
1476 value = list(value)
1477 from plain.models.fields.related import RelatedField
1478 from plain.models.fields.reverse_related import ForeignObjectRel
1479
1480 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1481 self.check_related_objects(join_info.final_field, value, join_info.meta)
1482
1483 # split_exclude() needs to know which joins were generated for the
1484 # lookup parts
1485 self._lookup_joins = join_info.joins
1486 except MultiJoin as e:
1487 return self.split_exclude(
1488 filter_expr,
1489 can_reuse or set(),
1490 e.names_with_path,
1491 )
1492
1493 # Update used_joins before trimming since they are reused to determine
1494 # which joins could be later promoted to INNER.
1495 used_joins.update(join_info.joins)
1496 targets, alias, join_list = self.trim_joins(
1497 join_info.targets, join_info.joins, join_info.path
1498 )
1499 if can_reuse is not None:
1500 can_reuse.update(join_list)
1501
1502 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1503 if len(targets) == 1:
1504 col = self._get_col(targets[0], join_info.final_field, alias)
1505 else:
1506 col = MultiColSource(
1507 alias, targets, join_info.targets, join_info.final_field
1508 )
1509 else:
1510 col = self._get_col(targets[0], join_info.final_field, alias)
1511
1512 condition = self.build_lookup(list(lookups), col, value)
1513 assert condition is not None
1514 lookup_type = condition.lookup_name
1515 clause = WhereNode([condition], connector=AND)
1516
1517 require_outer = (
1518 lookup_type == "isnull" and condition.rhs is True and not current_negated
1519 )
1520 if (
1521 current_negated
1522 and (lookup_type != "isnull" or condition.rhs is False)
1523 and condition.rhs is not None
1524 ):
1525 require_outer = True
1526 if lookup_type != "isnull":
1527 # The condition added here will be SQL like this:
1528 # NOT (col IS NOT NULL), where the first NOT is added in
1529 # upper layers of code. The reason for addition is that if col
1530 # is null, then col != someval will result in SQL "unknown"
1531 # which isn't the same as in Python. The Python None handling
1532 # is wanted, and it can be gotten by
1533 # (col IS NULL OR col != someval)
1534 # <=>
1535 # NOT (col IS NOT NULL AND col = someval).
1536 if (
1537 self.is_nullable(targets[0])
1538 or self.alias_map[join_list[-1]].join_type == LOUTER
1539 ):
1540 lookup_class = targets[0].get_lookup("isnull")
1541 assert lookup_class is not None
1542 col = self._get_col(targets[0], join_info.targets[0], alias)
1543 clause.add(lookup_class(col, False), AND)
1544 # If someval is a nullable column, someval IS NOT NULL is
1545 # added.
1546 if isinstance(value, Col) and self.is_nullable(value.target):
1547 lookup_class = value.target.get_lookup("isnull")
1548 assert lookup_class is not None
1549 clause.add(lookup_class(value, False), AND)
1550 return clause, used_joins if not require_outer else ()
1551
1552 def add_filter(self, filter_lhs: str, filter_rhs: Any) -> None:
1553 self.add_q(Q((filter_lhs, filter_rhs)))
1554
1555 def add_q(self, q_object: Q) -> None:
1556 """
1557 A preprocessor for the internal _add_q(). Responsible for doing final
1558 join promotion.
1559 """
1560 # For join promotion this case is doing an AND for the added q_object
1561 # and existing conditions. So, any existing inner join forces the join
1562 # type to remain inner. Existing outer joins can however be demoted.
1563 # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
1564 # rel_a doesn't produce any rows, then the whole condition must fail.
1565 # So, demotion is OK.
1566 existing_inner = {
1567 a for a in self.alias_map if self.alias_map[a].join_type == INNER
1568 }
1569 clause, _ = self._add_q(q_object, self.used_aliases)
1570 if clause:
1571 self.where.add(clause, AND)
1572 self.demote_joins(existing_inner)
1573
1574 def build_where(
1575 self, filter_expr: tuple[str, Any] | Q | BaseExpression
1576 ) -> WhereNode:
1577 return self.build_filter(filter_expr, allow_joins=False)[0]
1578
1579 def clear_where(self) -> None:
1580 self.where = WhereNode()
1581
    def _add_q(
        self,
        q_object: Q,
        used_aliases: set[str] | None,
        branch_negated: bool = False,
        current_negated: bool = False,
        allow_joins: bool = True,
        split_subq: bool = True,
        check_filterable: bool = True,
        summarize: bool = False,
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        Add a Q-object to the current filter.

        Return the built WhereNode and the set of joins that must stay INNER
        for the added condition (see build_filter()).
        """
        connector = q_object.connector
        # current_negated flips on every negation level; branch_negated
        # latches once any enclosing node is negated.
        current_negated ^= q_object.negated
        branch_negated = branch_negated or q_object.negated
        target_clause = WhereNode(connector=connector, negated=q_object.negated)
        # Collects per-child "needed inner" votes to decide join promotion
        # once all children have been built.
        joinpromoter = JoinPromoter(
            q_object.connector, len(q_object.children), current_negated
        )
        for child in q_object.children:
            child_clause, needed_inner = self.build_filter(
                child,
                can_reuse=used_aliases,
                branch_negated=branch_negated,
                current_negated=current_negated,
                allow_joins=allow_joins,
                split_subq=split_subq,
                check_filterable=check_filterable,
                summarize=summarize,
            )
            joinpromoter.add_votes(needed_inner)
            if child_clause:
                target_clause.add(child_clause, connector)
        needed_inner = joinpromoter.update_join_types(self)
        return target_clause, needed_inner
1617
    def build_filtered_relation_q(
        self,
        q_object: Q,
        reuse: set[str],
        branch_negated: bool = False,
        current_negated: bool = False,
    ) -> WhereNode:
        """Add a FilteredRelation object to the current filter."""
        connector = q_object.connector
        # Track negation state the same way _add_q() does.
        current_negated ^= q_object.negated
        branch_negated = branch_negated or q_object.negated
        target_clause = WhereNode(connector=connector, negated=q_object.negated)
        for child in q_object.children:
            if isinstance(child, Node):
                # Nested Q object: recurse with the same reusable-join set.
                child_clause = self.build_filtered_relation_q(
                    child,
                    reuse=reuse,
                    branch_negated=branch_negated,
                    current_negated=current_negated,
                )
            else:
                # Leaf condition: only joins listed in `reuse` may be reused.
                child_clause, _ = self.build_filter(
                    child,
                    can_reuse=reuse,
                    branch_negated=branch_negated,
                    current_negated=current_negated,
                    allow_joins=True,
                    split_subq=False,
                    reuse_with_filtered_relation=True,
                )
            target_clause.add(child_clause, connector)
        return target_clause
1650
1651 def add_filtered_relation(self, filtered_relation: Any, alias: str) -> None:
1652 filtered_relation.alias = alias
1653 lookups = dict(get_children_from_q(filtered_relation.condition))
1654 relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
1655 filtered_relation.relation_name
1656 )
1657 if relation_lookup_parts:
1658 raise ValueError(
1659 "FilteredRelation's relation_name cannot contain lookups "
1660 f"(got {filtered_relation.relation_name!r})."
1661 )
1662 for lookup in chain(lookups):
1663 lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
1664 shift = 2 if not lookup_parts else 1
1665 lookup_field_path = lookup_field_parts[:-shift]
1666 for idx, lookup_field_part in enumerate(lookup_field_path):
1667 if len(relation_field_parts) > idx:
1668 if relation_field_parts[idx] != lookup_field_part:
1669 raise ValueError(
1670 "FilteredRelation's condition doesn't support "
1671 f"relations outside the {filtered_relation.relation_name!r} (got {lookup!r})."
1672 )
1673 else:
1674 raise ValueError(
1675 "FilteredRelation's condition doesn't support nested "
1676 f"relations deeper than the relation_name (got {lookup!r} for "
1677 f"{filtered_relation.relation_name!r})."
1678 )
1679 self._filtered_relations[filtered_relation.alias] = filtered_relation
1680
    def names_to_path(
        self,
        names: list[str],
        meta: Meta,
        allow_many: bool = True,
        fail_on_missing: bool = False,
    ) -> tuple[list[Any], Field | ForeignObjectRel, tuple[Field, ...], list[str]]:
        """
        Walk the list of names and turns them into PathInfo tuples. A single
        name in 'names' can generate multiple PathInfos (m2m, for example).

        'names' is the path of names to travel, 'meta' is the Meta we
        start the name resolving from, 'allow_many' is as for setup_joins().
        If fail_on_missing is set to True, then a name that can't be resolved
        will generate a FieldError.

        Return a list of PathInfo tuples. In addition return the final field
        (the last used join field) and target (which is a field guaranteed to
        contain the same value as the final field). Finally, return those names
        that weren't found (which are likely transforms and the final lookup).
        """
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            # (name, [PathInfo, ...]) accumulated for MultiJoin reporting.
            cur_names_with_path = (name, [])

            field = None
            filtered_relation = None
            try:
                if meta is None:
                    raise FieldDoesNotExist
                field = meta.get_field(name)
            except FieldDoesNotExist:
                if name in self.annotation_select:
                    # The name refers to a selected annotation, not a field.
                    field = self.annotation_select[name].output_field
                elif name in self._filtered_relations and pos == 0:
                    # Filtered relations may only appear as the first name.
                    filtered_relation = self._filtered_relations[name]
                    if LOOKUP_SEP in filtered_relation.relation_name:
                        parts = filtered_relation.relation_name.split(LOOKUP_SEP)
                        filtered_relation_path, field, _, _ = self.names_to_path(
                            parts,
                            meta,
                            allow_many,
                            fail_on_missing,
                        )
                        path.extend(filtered_relation_path[:-1])
                    else:
                        field = meta.get_field(filtered_relation.relation_name)
            if field is None:
                # We didn't find the current field, so move position back
                # one step.
                pos -= 1
                if pos == -1 or fail_on_missing:
                    available = sorted(
                        [
                            *get_field_names_from_opts(meta),
                            *self.annotation_select,
                            *self._filtered_relations,
                        ]
                    )
                    raise FieldError(
                        "Cannot resolve keyword '{}' into field. "
                        "Choices are: {}".format(name, ", ".join(available))
                    )
                # Remaining names are treated as lookups/transforms on the
                # last resolved field (returned as the fourth element).
                break

            # Lazy import to avoid circular dependency
            from plain.models.fields.related import ForeignKeyField as FK
            from plain.models.fields.related import ManyToManyField as M2M
            from plain.models.fields.reverse_related import ForeignObjectRel as FORel

            if isinstance(field, FK | M2M | FORel):
                # Relational field: extend the join path and keep resolving
                # from the related model's Meta.
                pathinfos: list[PathInfo]
                if filtered_relation:
                    pathinfos = field.get_path_info(filtered_relation)
                else:
                    pathinfos = field.path_infos
                if not allow_many:
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            # Multi-valued join encountered where not allowed;
                            # report how far resolution got.
                            cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
                            names_with_path.append(cur_names_with_path)
                            raise MultiJoin(pos + 1, names_with_path)
                last = pathinfos[-1]
                path.extend(pathinfos)
                final_field = last.join_field
                meta = last.to_meta
                targets = last.target_fields
                cur_names_with_path[1].extend(pathinfos)
                names_with_path.append(cur_names_with_path)
            else:
                # Local non-relational field.
                final_field = field
                targets = (field,)
                if fail_on_missing and pos + 1 != len(names):
                    raise FieldError(
                        f"Cannot resolve keyword {names[pos + 1]!r} into field. Join on '{name}'"
                        " not permitted."
                    )
                break
        return path, final_field, targets, names[pos + 1 :]
1781
    def setup_joins(
        self,
        names: list[str],
        meta: Meta,
        alias: str,
        can_reuse: set[str] | None = None,
        allow_many: bool = True,
        reuse_with_filtered_relation: bool = False,
    ) -> JoinInfo:
        """
        Compute the necessary table joins for the passage through the fields
        given in 'names'. 'meta' is the Meta for the current model
        (which gives the table we are starting from), 'alias' is the alias for
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
        parameter and force the relation on the given connections.

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.

        Return the final field involved in the joins, the target field (used
        for any 'where' constraint), the final 'opts' value, the joins, the
        field path traveled to generate the joins, and a transform function
        that takes a field and alias and is equivalent to `field.get_col(alias)`
        in the simple case but wraps field transforms if they were included in
        names.

        The target field is the field containing the concrete value. Final
        field can be something different, for example foreign key pointing to
        that value. Final field is needed for example in some value
        conversions (convert 'obj' in fk__id=obj to pk val using the foreign
        key field for example).
        """
        joins = [alias]
        # The transform can't be applied yet, as joins must be trimmed later.
        # To avoid making every caller of this method look up transforms
        # directly, compute transforms here and create a partial that converts
        # fields to the appropriate wrapped version.

        def _base_transformer(field: Field, alias: str | None) -> Col:
            # Innermost transformer: just produce the column reference
            # (alias-less when the query doesn't qualify columns).
            if not self.alias_cols:
                alias = None
            return field.get_col(alias)

        final_transformer: TransformWrapper | Callable[[Field, str | None], Col] = (
            _base_transformer
        )

        # Try resolving all the names as fields first. If there's an error,
        # treat trailing names as lookups until a field can be resolved.
        last_field_exception = None
        for pivot in range(len(names), 0, -1):
            try:
                path, final_field, targets, rest = self.names_to_path(
                    names[:pivot],
                    meta,
                    allow_many,
                    fail_on_missing=True,
                )
            except FieldError as exc:
                if pivot == 1:
                    # The first item cannot be a lookup, so it's safe
                    # to raise the field error here.
                    raise
                else:
                    # Remember the error: it is re-raised later if the trailing
                    # names turn out not to be valid transforms either.
                    last_field_exception = exc
            else:
                # The transforms are the remaining items that couldn't be
                # resolved into fields.
                transforms = names[pivot:]
                break
        for name in transforms:
            # Each trailing name wraps the previous transformer; 'name' and
            # 'previous' are bound as keyword defaults via TransformWrapper to
            # avoid the late-binding closure pitfall.

            def transform(
                field: Field, alias: str | None, *, name: str, previous: Any
            ) -> BaseExpression:
                try:
                    wrapped = previous(field, alias)
                    return self.try_transform(wrapped, name)
                except FieldError:
                    # FieldError is raised if the transform doesn't exist.
                    if isinstance(final_field, Field) and last_field_exception:
                        raise last_field_exception
                    else:
                        raise

            final_transformer = TransformWrapper(
                transform, name=name, previous=final_transformer
            )
            final_transformer.has_transforms = True
        # Then, add the path to the query's joins. Note that we can't trim
        # joins at this stage - we will need the information about join type
        # of the trimmed joins.
        for join in path:
            if join.filtered_relation:
                filtered_relation = join.filtered_relation.clone()
                table_alias = filtered_relation.alias
            else:
                filtered_relation = None
                table_alias = None
            meta = join.to_meta
            if join.direct:
                nullable = self.is_nullable(join.join_field)
            else:
                # Reverse joins may have no matching row on the other side.
                nullable = True
            connection = self.join_class(
                meta.model.model_options.db_table,
                alias,
                table_alias,
                INNER,
                join.join_field,
                nullable,
                filtered_relation=filtered_relation,
            )
            # Multi-valued hops may only reuse explicitly allowed aliases;
            # single-valued hops are always reusable (reuse=None).
            reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
            alias = self.join(
                connection,
                reuse=reuse,
                reuse_with_filtered_relation=reuse_with_filtered_relation,
            )
            joins.append(alias)
            if filtered_relation:
                # Record the full alias path so the relation's condition can
                # be resolved against the right joins later.
                filtered_relation.path = joins[:]
        return JoinInfo(final_field, targets, meta, joins, path, final_transformer)
1912
    def trim_joins(
        self, targets: tuple[Field, ...], joins: list[str], path: list[Any]
    ) -> tuple[tuple[Field, ...], str, list[str]]:
        """
        The 'target' parameter is the final field being joined to, 'joins'
        is the full list of join aliases. The 'path' contain the PathInfos
        used to create the joins.

        Return the final target field and table alias and the new active
        joins.

        Always trim any direct join if the target column is already in the
        previous table. Can't trim reverse joins as it's unknown if there's
        anything on the other side of the join.
        """
        joins = joins[:]
        # Walk the path backwards, peeling off one join per iteration while
        # it is provably redundant.
        for pos, info in enumerate(reversed(path)):
            if len(joins) == 1 or not info.direct:
                break
            if info.filtered_relation:
                # A filtered relation's join carries extra conditions; it
                # cannot be trimmed away.
                break
            join_targets = {t.column for t in info.join_field.foreign_related_fields}
            cur_targets = {t.column for t in targets}
            if not cur_targets.issubset(join_targets):
                # Some target column only exists on the joined table.
                break
            # Remap each target to the equivalent column on the previous
            # (left-hand) side of the join.
            targets_dict = {
                r[1].column: r[0]
                for r in info.join_field.related_fields
                if r[1].column in cur_targets
            }
            targets = tuple(targets_dict[t.column] for t in targets)
            self.unref_alias(joins.pop())
        return targets, joins[-1], joins
1946
1947 @classmethod
1948 def _gen_cols(
1949 cls,
1950 exprs: Iterable[Any],
1951 include_external: bool = False,
1952 resolve_refs: bool = True,
1953 ) -> TypingIterator[Col]:
1954 for expr in exprs:
1955 if isinstance(expr, Col):
1956 yield expr
1957 elif include_external and callable(
1958 getattr(expr, "get_external_cols", None)
1959 ):
1960 yield from expr.get_external_cols()
1961 elif hasattr(expr, "get_source_expressions"):
1962 if not resolve_refs and isinstance(expr, Ref):
1963 continue
1964 yield from cls._gen_cols(
1965 expr.get_source_expressions(),
1966 include_external=include_external,
1967 resolve_refs=resolve_refs,
1968 )
1969
1970 @classmethod
1971 def _gen_col_aliases(cls, exprs: Iterable[Any]) -> TypingIterator[str]:
1972 yield from (expr.alias for expr in cls._gen_cols(exprs))
1973
    def resolve_ref(
        self,
        name: str,
        allow_joins: bool = True,
        reuse: set[str] | None = None,
        summarize: bool = False,
    ) -> BaseExpression:
        """
        Resolve a lookup-path string (e.g. from F()) into an expression.

        'name' may reference an annotation (optionally with transforms) or a
        model field path. 'allow_joins' forbids resolutions that would need a
        join; 'reuse' collects the join aliases used; 'summarize' indicates an
        aggregate() context where annotations must be wrapped in Ref.
        Raises FieldError when the reference cannot be resolved.
        """
        annotation = self.annotations.get(name)
        if annotation is not None:
            if not allow_joins:
                # Reject annotations that themselves depend on joined tables.
                for alias in self._gen_col_aliases([annotation]):
                    if isinstance(self.alias_map[alias], Join):
                        raise FieldError(
                            "Joined field references are not permitted in this query"
                        )
            if summarize:
                # Summarize currently means we are doing an aggregate() query
                # which is executed as a wrapped subquery if any of the
                # aggregate() elements reference an existing annotation. In
                # that case we need to return a Ref to the subquery's annotation.
                if name not in self.annotation_select:
                    raise FieldError(
                        f"Cannot aggregate over the '{name}' alias. Use annotate() "
                        "to promote it."
                    )
                return Ref(name, self.annotation_select[name])
            else:
                return annotation
        else:
            field_list = name.split(LOOKUP_SEP)
            # The first path component may be an annotation followed by
            # transform names (e.g. "created__year").
            annotation = self.annotations.get(field_list[0])
            if annotation is not None:
                for transform in field_list[1:]:
                    annotation = self.try_transform(annotation, transform)
                return annotation
            initial_alias = self.get_initial_alias()
            assert initial_alias is not None
            assert self.model is not None, "Resolving field references requires a model"
            meta = self.model._model_meta
            join_info = self.setup_joins(
                field_list,
                meta,
                initial_alias,
                can_reuse=reuse,
            )
            targets, final_alias, join_list = self.trim_joins(
                join_info.targets, join_info.joins, join_info.path
            )
            if not allow_joins and len(join_list) > 1:
                raise FieldError(
                    "Joined field references are not permitted in this query"
                )
            if len(targets) > 1:
                raise FieldError(
                    "Referencing multicolumn fields with F() objects isn't supported"
                )
            # Verify that the last lookup in name is a field or a transform:
            # transform_function() raises FieldError if not.
            transform = join_info.transform_function(targets[0], final_alias)
            if reuse is not None:
                reuse.update(join_list)
            return transform
2036
    def split_exclude(
        self,
        filter_expr: tuple[str, Any],
        can_reuse: set[str],
        names_with_path: list[tuple[str, list[Any]]],
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
        original exclude filter (filter_expr) and the portion up to the first
        N-to-many relation field.

        For example, if the origin filter is ~Q(child__name='foo'), filter_expr
        is ('child__name', 'foo') and can_reuse is a set of joins usable for
        filters in the original query.

        We will turn this into equivalent of:
            WHERE NOT EXISTS(
                SELECT 1
                FROM child
                WHERE name = 'foo' AND child.parent_id = parent.id
                LIMIT 1
            )
        """
        # Generate the inner query.
        query = self.__class__(self.model)
        query._filtered_relations = self._filtered_relations
        filter_lhs, filter_rhs = filter_expr
        if isinstance(filter_rhs, OuterRef):
            # Re-wrap so the reference resolves one query level further out
            # once it is used inside the subquery.
            filter_rhs = OuterRef(filter_rhs)
        elif isinstance(filter_rhs, F):
            # An F() in the outer query becomes an OuterRef in the subquery.
            filter_rhs = OuterRef(filter_rhs.name)
        query.add_filter(filter_lhs, filter_rhs)
        query.clear_ordering(force=True)
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_prefix, contains_louter = query.trim_start(names_with_path)

        col = query.select[0]
        select_field = col.target
        alias = col.alias
        if alias in can_reuse:
            id_field = select_field.model._model_meta.get_forward_field("id")
            # Need to add a restriction so that outer query's filters are in effect for
            # the subquery, too.
            query.bump_prefix(self)
            lookup_class = select_field.get_lookup("exact")
            # Note that the query.select[0].alias is different from alias
            # due to bump_prefix above.
            lookup = lookup_class(
                id_field.get_col(query.select[0].alias), id_field.get_col(alias)
            )
            query.where.add(lookup, AND)
            query.external_aliases[alias] = True

        # Correlate the subquery to the outer query via the trimmed prefix.
        lookup_class = select_field.get_lookup("exact")
        lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
        query.where.add(lookup, AND)
        condition, needed_inner = self.build_filter(Exists(query))

        if contains_louter:
            or_null_condition, _ = self.build_filter(
                (f"{trimmed_prefix}__isnull", True),
                current_negated=True,
                branch_negated=True,
                can_reuse=can_reuse,
            )
            condition.add(or_null_condition, OR)
            # Note that the end result will be:
            # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
            # This might look crazy but due to how IN works, this seems to be
            # correct. If the IS NOT NULL check is removed then outercol NOT
            # IN will return UNKNOWN. If the IS NULL check is removed, then if
            # outercol IS NULL we will not match the row.
        return condition, needed_inner
2112
2113 def set_empty(self) -> None:
2114 self.where.add(NothingNode(), AND)
2115
2116 def is_empty(self) -> bool:
2117 return any(isinstance(c, NothingNode) for c in self.where.children)
2118
2119 def set_limits(self, low: int | None = None, high: int | None = None) -> None:
2120 """
2121 Adjust the limits on the rows retrieved. Use low/high to set these,
2122 as it makes it more Pythonic to read and write. When the SQL query is
2123 created, convert them to the appropriate offset and limit values.
2124
2125 Apply any limits passed in here to the existing constraints. Add low
2126 to the current low value and clamp both to any existing high value.
2127 """
2128 if high is not None:
2129 if self.high_mark is not None:
2130 self.high_mark = min(self.high_mark, self.low_mark + high)
2131 else:
2132 self.high_mark = self.low_mark + high
2133 if low is not None:
2134 if self.high_mark is not None:
2135 self.low_mark = min(self.high_mark, self.low_mark + low)
2136 else:
2137 self.low_mark = self.low_mark + low
2138
2139 if self.low_mark == self.high_mark:
2140 self.set_empty()
2141
2142 def clear_limits(self) -> None:
2143 """Clear any existing limits."""
2144 self.low_mark, self.high_mark = 0, None
2145
2146 @property
2147 def is_sliced(self) -> bool:
2148 return self.low_mark != 0 or self.high_mark is not None
2149
2150 def has_limit_one(self) -> bool:
2151 return self.high_mark is not None and (self.high_mark - self.low_mark) == 1
2152
2153 def can_filter(self) -> bool:
2154 """
2155 Return True if adding filters to this instance is still possible.
2156
2157 Typically, this means no limits or offsets have been put on the results.
2158 """
2159 return not self.is_sliced
2160
2161 def clear_select_clause(self) -> None:
2162 """Remove all fields from SELECT clause."""
2163 self.select = ()
2164 self.default_cols = False
2165 self.select_related = False
2166 self.set_extra_mask(())
2167 self.set_annotation_mask(())
2168
2169 def clear_select_fields(self) -> None:
2170 """
2171 Clear the list of fields to select (but not extra_select columns).
2172 Some queryset types completely replace any existing list of select
2173 columns.
2174 """
2175 self.select = ()
2176 self.values_select = ()
2177
2178 def add_select_col(self, col: BaseExpression, name: str) -> None:
2179 self.select += (col,)
2180 self.values_select += (name,)
2181
2182 def set_select(self, cols: list[Col] | tuple[Col, ...]) -> None:
2183 self.default_cols = False
2184 self.select = tuple(cols)
2185
2186 def add_distinct_fields(self, *field_names: str) -> None:
2187 """
2188 Add and resolve the given fields to the query's "distinct on" clause.
2189 """
2190 self.distinct_fields = field_names
2191 self.distinct = True
2192
    def add_fields(
        self, field_names: list[str] | TypingIterator[str], allow_m2m: bool = True
    ) -> None:
        """
        Add the given (model) fields to the select set. Add the field names in
        the order specified.

        Raises FieldError when a name cannot be resolved (or resolves through
        a disallowed many-valued relation).
        """
        alias = self.get_initial_alias()
        assert alias is not None
        assert self.model is not None, "add_fields() requires a model"
        meta = self.model._model_meta

        try:
            cols = []
            for name in field_names:
                # Join promotion note - we must not remove any rows here, so
                # if there is no existing joins, use outer join.
                join_info = self.setup_joins(
                    name.split(LOOKUP_SEP), meta, alias, allow_many=allow_m2m
                )
                targets, final_alias, joins = self.trim_joins(
                    join_info.targets,
                    join_info.joins,
                    join_info.path,
                )
                for target in targets:
                    cols.append(join_info.transform_function(target, final_alias))
            if cols:
                self.set_select(cols)
        except MultiJoin:
            # 'name' still holds the offending field name from the loop.
            raise FieldError(f"Invalid field name: '{name}'")
        except FieldError:
            if LOOKUP_SEP in name:
                # For lookups spanning over relationships, show the error
                # from the model on which the lookup failed.
                raise
            elif name in self.annotations:
                raise FieldError(
                    f"Cannot select the '{name}' alias. Use annotate() to promote it."
                )
            else:
                # Re-raise with the full list of valid choices to aid debugging.
                names = sorted(
                    [
                        *get_field_names_from_opts(meta),
                        *self.extra,
                        *self.annotation_select,
                        *self._filtered_relations,
                    ]
                )
                raise FieldError(
                    "Cannot resolve keyword {!r} into field. Choices are: {}".format(
                        name, ", ".join(names)
                    )
                )
2247
    def add_ordering(self, *ordering: str | BaseExpression) -> None:
        """
        Add items from the 'ordering' sequence to the query's "order by"
        clause. These items are either field names (not column names) --
        possibly with a direction prefix ('-' or '?') -- or OrderBy
        expressions.

        If 'ordering' is empty, clear all ordering from the query.
        """
        errors = []
        for item in ordering:
            if isinstance(item, str):
                if item == "?":
                    # Random ordering needs no validation.
                    continue
                item = item.removeprefix("-")
                if item in self.annotations:
                    continue
                if self.extra and item in self.extra:
                    continue
                # names_to_path() validates the lookup. A descriptive
                # FieldError will be raised if it isn't valid.
                assert self.model is not None, "ORDER BY field names require a model"
                self.names_to_path(item.split(LOOKUP_SEP), self.model._model_meta)
            elif not isinstance(item, ResolvableExpression):
                errors.append(item)
            # Applies to both strings and expressions: bare aggregates are
            # not allowed in order_by().
            if getattr(item, "contains_aggregate", False):
                raise FieldError(
                    "Using an aggregate in order_by() without also including "
                    f"it in annotate() is not allowed: {item}"
                )
        if errors:
            raise FieldError(f"Invalid order_by arguments: {errors}")
        if ordering:
            self.order_by += ordering
        else:
            self.default_ordering = False
2284
2285 def clear_ordering(self, force: bool = False, clear_default: bool = True) -> None:
2286 """
2287 Remove any ordering settings if the current query allows it without
2288 side effects, set 'force' to True to clear the ordering regardless.
2289 If 'clear_default' is True, there will be no ordering in the resulting
2290 query (not even the model's default).
2291 """
2292 if not force and (
2293 self.is_sliced or self.distinct_fields or self.select_for_update
2294 ):
2295 return
2296 self.order_by = ()
2297 self.extra_order_by = ()
2298 if clear_default:
2299 self.default_ordering = False
2300
    def set_group_by(self, allow_aliases: bool = True) -> None:
        """
        Expand the GROUP BY clause required by the query.

        This will usually be the set of all non-aggregate fields in the
        return data. If the database backend supports grouping by the
        primary key, and the query would be equivalent, the optimization
        will be made automatically.
        """
        if allow_aliases and self.values_select:
            # If grouping by aliases is allowed assign selected value aliases
            # by moving them to annotations.
            group_by_annotations = {}
            values_select = {}
            for alias, expr in zip(self.values_select, self.select):
                if isinstance(expr, Col):
                    values_select[alias] = expr
                else:
                    group_by_annotations[alias] = expr
            # Existing annotations win over promoted ones on name clashes.
            self.annotations = {**group_by_annotations, **self.annotations}
            self.append_annotation_mask(group_by_annotations)
            self.select = tuple(values_select.values())
            self.values_select = tuple(values_select)
        group_by = list(self.select)
        for alias, annotation in self.annotation_select.items():
            # Annotations without GROUP BY columns (e.g. constants) are skipped.
            if not (group_by_cols := annotation.get_group_by_cols()):
                continue
            if allow_aliases and not annotation.contains_aggregate:
                # Group on the annotation alias rather than its expansion.
                group_by.append(Ref(alias, annotation))
            else:
                group_by.extend(group_by_cols)
        self.group_by = tuple(group_by)
2333
2334 def add_select_related(self, fields: list[str]) -> None:
2335 """
2336 Set up the select_related data structure so that we only select
2337 certain related models (as opposed to all models, when
2338 self.select_related=True).
2339 """
2340 if isinstance(self.select_related, bool):
2341 field_dict: dict[str, Any] = {}
2342 else:
2343 field_dict = self.select_related
2344 for field in fields:
2345 d = field_dict
2346 for part in field.split(LOOKUP_SEP):
2347 d = d.setdefault(part, {})
2348 self.select_related = field_dict
2349
    def add_extra(
        self,
        select: dict[str, str],
        select_params: list[Any] | None,
        where: list[str],
        params: list[Any],
        tables: list[str],
        order_by: tuple[str, ...],
    ) -> None:
        """
        Add data to the various extra_* attributes for user-created additions
        to the query.
        """
        if select:
            # We need to pair any placeholder markers in the 'select'
            # dictionary with their parameters in 'select_params' so that
            # subsequent updates to the select dictionary also adjust the
            # parameters appropriately.
            select_pairs = {}
            if select_params:
                param_iter = iter(select_params)
            else:
                param_iter = iter([])
            for name, entry in select.items():
                self.check_alias(name)
                entry = str(entry)
                entry_params = []
                # Scan for unescaped '%s' placeholders; '%%s' is a literal.
                pos = entry.find("%s")
                while pos != -1:
                    if pos == 0 or entry[pos - 1] != "%":
                        entry_params.append(next(param_iter))
                    pos = entry.find("%s", pos + 2)
                select_pairs[name] = (entry, entry_params)
            self.extra.update(select_pairs)
        if where or params:
            # Raw WHERE fragments are ANDed in verbatim with their params.
            self.where.add(ExtraWhere(where, params), AND)
        if tables:
            self.extra_tables += tuple(tables)
        if order_by:
            self.extra_order_by = order_by
2390
2391 def clear_deferred_loading(self) -> None:
2392 """Remove any fields from the deferred loading set."""
2393 self.deferred_loading = (frozenset(), True)
2394
2395 def add_deferred_loading(self, field_names: frozenset[str]) -> None:
2396 """
2397 Add the given list of model field names to the set of fields to
2398 exclude from loading from the database when automatic column selection
2399 is done. Add the new field names to any existing field names that
2400 are deferred (or removed from any existing field names that are marked
2401 as the only ones for immediate loading).
2402 """
2403 # Fields on related models are stored in the literal double-underscore
2404 # format, so that we can use a set datastructure. We do the foo__bar
2405 # splitting and handling when computing the SQL column names (as part of
2406 # get_columns()).
2407 existing, defer = self.deferred_loading
2408 existing_set = set(existing)
2409 if defer:
2410 # Add to existing deferred names.
2411 self.deferred_loading = frozenset(existing_set.union(field_names)), True
2412 else:
2413 # Remove names from the set of any existing "immediate load" names.
2414 if new_existing := existing_set.difference(field_names):
2415 self.deferred_loading = frozenset(new_existing), False
2416 else:
2417 self.clear_deferred_loading()
2418 if new_only := set(field_names).difference(existing_set):
2419 self.deferred_loading = frozenset(new_only), True
2420
2421 def add_immediate_loading(self, field_names: list[str] | set[str]) -> None:
2422 """
2423 Add the given list of model field names to the set of fields to
2424 retrieve when the SQL is executed ("immediate loading" fields). The
2425 field names replace any existing immediate loading field names. If
2426 there are field names already specified for deferred loading, remove
2427 those names from the new field_names before storing the new names
2428 for immediate loading. (That is, immediate loading overrides any
2429 existing immediate values, but respects existing deferrals.)
2430 """
2431 existing, defer = self.deferred_loading
2432 field_names_set = set(field_names)
2433
2434 if defer:
2435 # Remove any existing deferred names from the current set before
2436 # setting the new names.
2437 self.deferred_loading = (
2438 frozenset(field_names_set.difference(existing)),
2439 False,
2440 )
2441 else:
2442 # Replace any existing "immediate load" field names.
2443 self.deferred_loading = frozenset(field_names_set), False
2444
2445 def set_annotation_mask(
2446 self,
2447 names: set[str]
2448 | frozenset[str]
2449 | list[str]
2450 | tuple[str, ...]
2451 | dict[str, Any]
2452 | None,
2453 ) -> None:
2454 """Set the mask of annotations that will be returned by the SELECT."""
2455 if names is None:
2456 self.annotation_select_mask = None
2457 else:
2458 self.annotation_select_mask = set(names)
2459 self._annotation_select_cache = None
2460
2461 def append_annotation_mask(self, names: list[str] | dict[str, Any]) -> None:
2462 if self.annotation_select_mask is not None:
2463 self.set_annotation_mask(self.annotation_select_mask.union(names))
2464
2465 def set_extra_mask(
2466 self, names: set[str] | list[str] | tuple[str, ...] | None
2467 ) -> None:
2468 """
2469 Set the mask of extra select items that will be returned by SELECT.
2470 Don't remove them from the Query since they might be used later.
2471 """
2472 if names is None:
2473 self.extra_select_mask = None
2474 else:
2475 self.extra_select_mask = set(names)
2476 self._extra_select_cache = None
2477
    def set_values(self, fields: list[str]) -> None:
        """
        Configure the query for a values()/values_list()-style selection.

        'fields' names the columns to return; when empty, all concrete model
        fields are selected. Annotations and extra selects named in 'fields'
        are kept via their masks; GROUP BY references are re-resolved so they
        don't point at cleared SELECT entries.
        """
        self.select_related = False
        self.clear_deferred_loading()
        self.clear_select_fields()
        self.has_select_fields = True

        if fields:
            field_names = []
            extra_names = []
            annotation_names = []
            if not self.extra and not self.annotations:
                # Shortcut - if there are no extra or annotations, then
                # the values() clause must be just field names.
                field_names = list(fields)
            else:
                self.default_cols = False
                # Partition requested names into extras, annotations, and
                # plain model fields.
                for f in fields:
                    if f in self.extra_select:
                        extra_names.append(f)
                    elif f in self.annotation_select:
                        annotation_names.append(f)
                    else:
                        field_names.append(f)
            self.set_extra_mask(extra_names)
            self.set_annotation_mask(annotation_names)
            selected = frozenset(field_names + extra_names + annotation_names)
        else:
            assert self.model is not None, "Default values query requires a model"
            field_names = [f.attname for f in self.model._model_meta.concrete_fields]
            selected = frozenset(field_names)
        # Selected annotations must be known before setting the GROUP BY
        # clause.
        if self.group_by is True:
            assert self.model is not None, "GROUP BY True requires a model"
            self.add_fields(
                (f.attname for f in self.model._model_meta.concrete_fields), False
            )
            # Disable GROUP BY aliases to avoid orphaning references to the
            # SELECT clause which is about to be cleared.
            self.set_group_by(allow_aliases=False)
            self.clear_select_fields()
        elif self.group_by:
            # Resolve GROUP BY annotation references if they are not part of
            # the selected fields anymore.
            group_by = []
            for expr in self.group_by:
                if isinstance(expr, Ref) and expr.refs not in selected:
                    expr = self.annotations[expr.refs]
                group_by.append(expr)
            self.group_by = tuple(group_by)

        self.values_select = tuple(field_names)
        self.add_fields(field_names, True)
2531
2532 @property
2533 def annotation_select(self) -> dict[str, BaseExpression]:
2534 """
2535 Return the dictionary of aggregate columns that are not masked and
2536 should be used in the SELECT clause. Cache this result for performance.
2537 """
2538 if self._annotation_select_cache is not None:
2539 return self._annotation_select_cache
2540 elif not self.annotations:
2541 return {}
2542 elif self.annotation_select_mask is not None:
2543 self._annotation_select_cache = {
2544 k: v
2545 for k, v in self.annotations.items()
2546 if k in self.annotation_select_mask
2547 }
2548 return self._annotation_select_cache
2549 else:
2550 return self.annotations
2551
2552 @property
2553 def extra_select(self) -> dict[str, tuple[str, list[Any]]]:
2554 if self._extra_select_cache is not None:
2555 return self._extra_select_cache
2556 if not self.extra:
2557 return {}
2558 elif self.extra_select_mask is not None:
2559 self._extra_select_cache = {
2560 k: v for k, v in self.extra.items() if k in self.extra_select_mask
2561 }
2562 return self._extra_select_cache
2563 else:
2564 return self.extra
2565
    def trim_start(
        self, names_with_path: list[tuple[str, list[Any]]]
    ) -> tuple[str, bool]:
        """
        Trim joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure that are m2m joins.

        Also set the select column so the start matches the join.

        This method is meant to be used for generating the subquery joins &
        cols in split_exclude().

        Return a lookup usable for doing outerq.filter(lookup=self) and a
        boolean indicating if the joins in the prefix contain a LEFT OUTER join.
        """
        all_paths = []
        for _, paths in names_with_path:
            all_paths.extend(paths)
        contains_louter = False
        # Trim and operate only on tables that were generated for
        # the lookup part of the query. That is, avoid trimming
        # joins generated for F() expressions.
        lookup_tables = [
            t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
        ]
        # Drop single-valued leading hops; stop at the first m2m hop.
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # The path.join_field is a Rel, lets get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix.
        paths_in_prefix = trimmed_paths
        trimmed_prefix = []
        for name, path in names_with_path:
            if paths_in_prefix - len(path) < 0:
                break
            trimmed_prefix.append(name)
            paths_in_prefix -= len(path)
        trimmed_prefix.append(join_field.foreign_related_fields[0].name)
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        # Lets still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for:
        # - LEFT JOINs because we would miss those rows that have nothing on
        #   the outer side,
        # - INNER JOINs from filtered relations because we would miss their
        #   filters.
        first_join = self.alias_map[lookup_tables[trimmed_paths + 1]]
        if first_join.join_type != LOUTER and not first_join.filtered_relation:
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
        else:
            # TODO: It might be possible to trim more joins from the start of the
            # inner query if it happens to have a longer join chain containing the
            # values in select_fields. Lets punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a join_class instead of a
        # base_table_class reference. But the first entry in the query's FROM
        # clause must not be a JOIN.
        for table in self.alias_map:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = self.base_table_class(
                    self.alias_map[table].table_name,
                    table,
                )
                break
        self.set_select([f.get_col(select_alias) for f in select_fields])
        return trimmed_prefix, contains_louter
2639
2640 def is_nullable(self, field: Field) -> bool:
2641 """Check if the given field should be treated as nullable."""
2642 # QuerySet does not have knowledge of which connection is going to be
2643 # used. For the single-database setup we always reference the default
2644 # connection here.
2645 return field.allow_null
2646
2647
def get_order_dir(field: str, default: str = "ASC") -> tuple[str, str]:
    """
    Split an ordering specification into (field name, direction).

    A leading '-' always flips the sort order, e.g. '-foo' yields
    ('foo', 'DESC'). A name with no prefix (or a '+' prefix) sorts in
    the direction given by 'default'.
    """
    directions = ORDER_DIR[default]
    descending = field[0] == "-"
    if descending:
        return field[1:], directions[1]
    return field, directions[0]
2660
2661
class JoinPromoter:
    """
    A class to abstract away join promotion problems for complex filter
    conditions.
    """

    def __init__(self, connector: str, num_children: int, negated: bool):
        self.connector = connector
        self.negated = negated
        # Negation flips how the connector affects join promotion:
        # NOT (a AND b) must be treated like (NOT a) OR (NOT b).
        if negated:
            self.effective_connector = OR if connector == AND else AND
        else:
            self.effective_connector = self.connector
        self.num_children = num_children
        # Tally mapping each table alias to how many child conditions
        # voted for it as required for inner and/or outer joins.
        self.votes = Counter()

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__qualname__}(connector={self.connector!r}, "
            f"num_children={self.num_children!r}, negated={self.negated!r})"
        )

    def add_votes(self, votes: Any) -> None:
        """Record one vote per item of *votes*, which may be any iterable."""
        self.votes.update(votes)

    def update_join_types(self, query: Query) -> set[str]:
        """
        Adjust join types so the generated query is as efficient as
        possible while staying correct: demote as many joins as possible
        to INNER, but never turn an OUTER join into INNER when doing so
        could drop valid rows from the result.
        """
        promote: set[str] = set()
        demote: set[str] = set()
        # effective_connector already accounts for negation, so
        # NOT (a AND b) is handled like (a OR b) here.
        for alias, num_votes in self.votes.items():
            is_or = self.effective_connector == OR
            unanimous = num_votes == self.num_children
            # In the OR case a join that is not shared by every child must
            # be LEFT OUTER: with rel_a__col=1 | rel_b__col=2, an INNER
            # join to rel_a would discard rows whose only match comes
            # through rel_b (e.g. a missing reverse FK row or a NULL FK).
            # Promote such joins now; a later vote may demote them again.
            if is_or and not unanimous:
                promote.add(alias)
            # INNER is safe when some filter can only match given a joined
            # row. Under AND a single vote suffices: if the join produces
            # NULL, that child's comparison is false (NULL = anything is
            # false), so the whole conjunction fails anyway. Under OR a
            # unanimous vote means every branch needs the join, e.g.
            # rel_a__col__icontains=Alex | rel_a__col__icontains=Russell.
            if self.effective_connector == AND or (is_or and unanimous):
                demote.add(alias)
        # Mixed case: (rel_a__col=1 | rel_b__col=2) & rel_a__col__gte=0
        # first promotes rel_a and rel_b for the OR clause, then the AND
        # clause's single vote demotes rel_a back to INNER. That is
        # correct: without a rel_a row the __gte clause cannot hold, so
        # the whole conjunction is false regardless. The same reasoning
        # applies with the clauses swapped, or with the __gte clause
        # replaced by an OR over rel_a alone.
        query.promote_joins(promote)
        query.demote_joins(demote)
        return demote