1"""
2Create SQL statements for QuerySets.
3
4The code in here encapsulates all of the SQL construction so that QuerySets
5themselves do not have to. This module has to know all about the internals of
6models in order to get the information it needs.
7"""
8
9from __future__ import annotations
10
11import copy
12import difflib
13import functools
14import sys
15from collections import Counter
16from collections.abc import Callable, Iterable, Iterator, Mapping
17from collections.abc import Iterator as TypingIterator
18from functools import cached_property
19from itertools import chain, count, product
20from string import ascii_uppercase
21from typing import (
22 TYPE_CHECKING,
23 Any,
24 Literal,
25 NamedTuple,
26 Self,
27 TypeVar,
28 cast,
29 overload,
30)
31
32from plain.models.aggregates import Count
33from plain.models.constants import LOOKUP_SEP, OnConflict
34from plain.models.db import NotSupportedError, db_connection
35from plain.models.exceptions import FieldDoesNotExist, FieldError
36from plain.models.expressions import (
37 BaseExpression,
38 Col,
39 Exists,
40 F,
41 OuterRef,
42 Ref,
43 ResolvableExpression,
44 ResolvedOuterRef,
45 Value,
46)
47from plain.models.fields import Field
48from plain.models.fields.related_lookups import MultiColSource
49from plain.models.lookups import Lookup
50from plain.models.query_utils import (
51 PathInfo,
52 Q,
53 check_rel_lookup_compatibility,
54 refs_expression,
55)
56from plain.models.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
57from plain.models.sql.datastructures import BaseTable, Empty, Join, MultiJoin
58from plain.models.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
59from plain.utils.regex_helper import _lazy_re_compile
60from plain.utils.tree import Node
61
62if TYPE_CHECKING:
63 from plain.models import Model
64 from plain.models.fields.related import RelatedField
65 from plain.models.fields.reverse_related import ForeignObjectRel
66 from plain.models.meta import Meta
67 from plain.models.postgres.wrapper import DatabaseWrapper
68 from plain.models.sql.compiler import (
69 SQLAggregateCompiler,
70 SQLCompiler,
71 SQLDeleteCompiler,
72 SQLInsertCompiler,
73 SQLUpdateCompiler,
74 SqlWithParams,
75 )
76
77__all__ = [
78 "Query",
79 "RawQuery",
80 "DeleteQuery",
81 "UpdateQuery",
82 "InsertQuery",
83 "AggregateQuery",
84]
85
86
87# Quotation marks ('"`[]), whitespace characters, semicolons, or inline
88# SQL comments are forbidden in column aliases.
89FORBIDDEN_ALIAS_PATTERN = _lazy_re_compile(r"['`\"\]\[;\s]|--|/\*|\*/")
90
91# Inspired from
92# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
93EXPLAIN_OPTIONS_PATTERN = _lazy_re_compile(r"[\w\-]+")
94
95
def get_field_names_from_opts(meta: Meta | None) -> set[str]:
    """Return every referencable field name on ``meta``.

    Concrete fields contribute both their ``name`` and ``attname``;
    non-concrete fields contribute only their ``name``. A missing meta
    yields the empty set.
    """
    if meta is None:
        return set()
    names: set[str] = set()
    for field in meta.get_fields():
        names.add(field.name)
        if field.concrete:
            names.add(field.attname)
    return names
104
105
def get_children_from_q(q: Q) -> TypingIterator[tuple[str, Any]]:
    """Walk a Q tree depth-first, yielding only the leaf (lookup, value) pairs."""
    for child in q.children:
        if not isinstance(child, Node):
            yield child
        else:
            yield from get_children_from_q(child)
112
113
class JoinInfo(NamedTuple):
    """Information about a join operation in a query.

    Bundles the resolved final field, the target fields on the joined side,
    the relevant model meta, the ordered list of join aliases, the relation
    path, and a callable that maps a (field, alias) pair to an expression.
    """

    final_field: Field[Any]
    targets: tuple[Field[Any], ...]
    meta: Meta
    joins: list[str]
    path: list[PathInfo]
    transform_function: Callable[[Field[Any], str | None], BaseExpression]
123
124
class RawQuery:
    """A single raw SQL query.

    Wraps a SQL string plus its parameters (a sequence for ``%s`` style or a
    mapping for ``%(name)s`` style) and executes it lazily against the shared
    database connection.
    """

    def __init__(self, sql: str, params: tuple[Any, ...] | dict[str, Any] = ()):
        self.params = params
        self.sql = sql
        # Populated by _execute_query(); stays None until the query runs.
        self.cursor: Any = None

        # Mirror some properties of a normal query so that
        # the compiler can be used to process results.
        self.low_mark, self.high_mark = 0, None  # Used for offset/limit
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self) -> RawQuery:
        """Return a copy ready for further modification (API parity with Query)."""
        return self.clone()

    def clone(self) -> RawQuery:
        """Return a fresh, unexecuted copy of this query.

        Uses type(self) so subclasses clone to their own class instead of
        being downgraded to the base RawQuery.
        """
        return type(self)(self.sql, params=self.params)

    def get_columns(self) -> list[str]:
        """Return the result set's column names, executing the query if needed."""
        if self.cursor is None:
            self._execute_query()
        return [column_meta[0] for column_meta in self.cursor.description]

    def __iter__(self) -> TypingIterator[Any]:
        # Always execute a new query for a new iterator.
        # This could be optimized with a cache at the expense of RAM.
        self._execute_query()
        return iter(self.cursor)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    @property
    def params_type(self) -> type[dict] | type[tuple] | None:
        """Container type used to interpolate params into the SQL, or None."""
        if self.params is None:
            return None
        # Mapping params interpolate as %(name)s, sequences as %s.
        return dict if isinstance(self.params, Mapping) else tuple

    def __str__(self) -> str:
        if self.params_type is None:
            return self.sql
        return self.sql % self.params_type(self.params)

    def _execute_query(self) -> None:
        # Runs the SQL on a fresh cursor from the shared connection.
        self.cursor = db_connection.cursor()
        self.cursor.execute(self.sql, self.params)
173
174
class ExplainInfo(NamedTuple):
    """Information about an EXPLAIN query.

    Built by Query.explain(); consumed by the compiler when rendering the
    EXPLAIN statement.
    """

    # Requested output format, or None for the backend default.
    format: str | None
    # Backend-specific EXPLAIN options keyed by (validated) option name.
    options: dict[str, Any]
180
181
class TransformWrapper:
    """Callable wrapper for transform functions with a ``has_transforms`` flag.

    Replaces functools.partial for transform functions: it pre-binds keyword
    arguments while remaining an ordinary object, so the ``has_transforms``
    attribute can be set and type-checked.
    """

    def __init__(
        self,
        func: Callable[..., BaseExpression],
        **kwargs: Any,
    ):
        self._func = func
        self._bound_kwargs = kwargs
        self.has_transforms: bool = False

    def __call__(self, field: Field[Any], alias: str | None) -> BaseExpression:
        # Equivalent to functools.partial(func, **kwargs)(field, alias).
        return self._func(field, alias, **self._bound_kwargs)
199
200
201QueryType = TypeVar("QueryType", bound="Query")
202
203
class Query(BaseExpression):
    """A single SQL query."""

    # Prefix for auto-generated join aliases (see table_alias()).
    alias_prefix = "T"
    # Value this expression contributes when the result set is empty.
    empty_result_set_value = None
    # Alias prefixes in use; merged with the rhs query's set in combine().
    subq_aliases = frozenset([alias_prefix])

    # Datastructure classes used when building the FROM clause.
    base_table_class = BaseTable
    join_class = Join

    # Whether the default (model) columns are selected.
    default_cols = True
    default_ordering = True
    standard_ordering = True

    # When True, used_aliases is preserved across chain() calls.
    filter_is_sticky = False
    # True when this query is embedded as a subquery of another query.
    subquery = False

    # SQL-related attributes.
    # Select and related select clauses are expressions to use in the SELECT
    # clause of the query. The select is used for cases where we want to set up
    # the select clause to contain other than default fields (values(),
    # subqueries...). Note that annotations go to annotations dictionary.
    select = ()
    # The group_by attribute can have one of the following forms:
    # - None: no group by at all in the query
    # - A tuple of expressions: group by (at least) those expressions.
    # String refs are also allowed for now.
    # - True: group by all select fields of the model
    # See compiler.get_group_by() for details.
    group_by = None
    order_by = ()
    low_mark = 0  # Used for offset/limit.
    high_mark = None  # Used for offset/limit.
    distinct = False
    distinct_fields = ()
    select_for_update = False
    select_for_update_nowait = False
    select_for_update_skip_locked = False
    select_for_update_of = ()
    select_for_no_key_update = False
    select_related: bool | dict[str, Any] = False
    has_select_fields = False
    # Arbitrary limit for select_related to prevent infinite recursion.
    max_depth = 5
    # Holds the selects defined by a call to values() or values_list()
    # excluding annotation_select and extra_select.
    values_select = ()

    # SQL annotation-related attributes.
    annotation_select_mask = None
    _annotation_select_cache = None

    # These are for extensions. The contents are more or less appended verbatim
    # to the appropriate clause.
    extra_select_mask = None
    _extra_select_cache = None

    extra_tables = ()
    extra_order_by = ()

    # A tuple that is a set of model field names and either True, if these are
    # the fields to defer, or False if these are the only fields to load.
    deferred_loading = (frozenset(), True)

    # Set by explain(); carries the requested EXPLAIN format and options.
    explain_info = None
269
    def __init__(self, model: type[Model] | None, alias_cols: bool = True):
        """Initialize empty per-instance query state for ``model``."""
        self.model = model
        # Number of times each alias is referenced in the query.
        self.alias_refcount: dict[str, int] = {}
        # alias_map is the most important data structure regarding joins.
        # It's used for recording which joins exist in the query and what
        # types they are. The key is the alias of the joined table (possibly
        # the table name) and the value is a Join-like object (see
        # sql.datastructures.Join for more information).
        self.alias_map: dict[str, Any] = {}
        # Whether to provide alias to columns during reference resolving.
        self.alias_cols = alias_cols
        # Sometimes the query contains references to aliases in outer queries (as
        # a result of split_exclude). Correct alias quoting needs to know these
        # aliases too.
        # Map external tables to whether they are aliased.
        self.external_aliases: dict[str, bool] = {}
        self.table_map: dict[str, list[str]] = {}  # Maps table names to list of aliases.
        self.used_aliases: set[str] = set()

        # Root of the WHERE clause condition tree.
        self.where = WhereNode()
        # Maps alias -> Annotation Expression.
        self.annotations: dict[str, Any] = {}
        # These are for extensions. The contents are more or less appended
        # verbatim to the appropriate clause.
        self.extra: dict[str, Any] = {}  # Maps col_alias -> (col_sql, params).

        self._filtered_relations: dict[str, Any] = {}
297
298 @property
299 def output_field(self) -> Field | None:
300 if len(self.select) == 1:
301 select = self.select[0]
302 return getattr(select, "target", None) or select.field
303 elif len(self.annotation_select) == 1:
304 return next(iter(self.annotation_select.values())).output_field
305
306 @cached_property
307 def base_table(self) -> str | None:
308 for alias in self.alias_map:
309 return alias
310
311 def __str__(self) -> str:
312 """
313 Return the query as a string of SQL with the parameter values
314 substituted in (use sql_with_params() to see the unsubstituted string).
315
316 Parameter values won't necessarily be quoted correctly, since that is
317 done by the database interface at execution time.
318 """
319 sql, params = self.sql_with_params()
320 return sql % params
321
322 def sql_with_params(self) -> SqlWithParams:
323 """
324 Return the query as an SQL string and the parameters that will be
325 substituted into the query.
326 """
327 return self.get_compiler().as_sql()
328
329 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
330 """Limit the amount of work when a Query is deepcopied."""
331 result = self.clone()
332 memo[id(self)] = result
333 return result
334
335 def get_compiler(self, *, elide_empty: bool = True) -> SQLCompiler:
336 """Return a compiler instance for this query."""
337 # Import compilers here to avoid circular imports at module load time
338 from plain.models.sql.compiler import SQLCompiler as Compiler
339
340 # db_connection is a proxy that acts as DatabaseWrapper
341 connection = cast("DatabaseWrapper", db_connection)
342 return Compiler(self, connection, elide_empty)
343
    def clone(self) -> Self:
        """
        Return a copy of the current Query. A lightweight alternative to
        deepcopy().

        Shallow-copies the instance dict, then selectively re-copies the
        mutable structures that must not be shared between the original and
        the clone.
        """
        # Build an empty instance of the right class without running __init__.
        obj = Empty()
        obj.__class__ = self.__class__
        obj = cast(Self, obj)  # Type checker doesn't understand __class__ reassignment
        # Copy references to everything.
        obj.__dict__ = self.__dict__.copy()
        # Clone attributes that can't use shallow copy.
        obj.alias_refcount = self.alias_refcount.copy()
        obj.alias_map = self.alias_map.copy()
        obj.external_aliases = self.external_aliases.copy()
        obj.table_map = self.table_map.copy()
        obj.where = self.where.clone()
        obj.annotations = self.annotations.copy()
        if self.annotation_select_mask is not None:
            obj.annotation_select_mask = self.annotation_select_mask.copy()
        # _annotation_select_cache cannot be copied, as doing so breaks the
        # (necessary) state in which both annotations and
        # _annotation_select_cache point to the same underlying objects.
        # It will get re-populated in the cloned queryset the next time it's
        # used.
        obj._annotation_select_cache = None
        obj.extra = self.extra.copy()
        if self.extra_select_mask is not None:
            obj.extra_select_mask = self.extra_select_mask.copy()
        if self._extra_select_cache is not None:
            obj._extra_select_cache = self._extra_select_cache.copy()
        if self.select_related is not False:
            # Use deepcopy because select_related stores fields in nested
            # dicts.
            obj.select_related = copy.deepcopy(obj.select_related)
        if "subq_aliases" in self.__dict__:
            obj.subq_aliases = self.subq_aliases.copy()
        obj.used_aliases = self.used_aliases.copy()
        obj._filtered_relations = self._filtered_relations.copy()
        # Clear the cached_property, if it exists.
        obj.__dict__.pop("base_table", None)
        return obj
385
386 @overload
387 def chain(self, klass: None = None) -> Self: ...
388
389 @overload
390 def chain(self, klass: type[QueryType]) -> QueryType: ...
391
392 def chain(self, klass: type[Query] | None = None) -> Query:
393 """
394 Return a copy of the current Query that's ready for another operation.
395 The klass argument changes the type of the Query, e.g. UpdateQuery.
396 """
397 obj = self.clone()
398 if klass and obj.__class__ != klass:
399 obj.__class__ = klass
400 if not obj.filter_is_sticky:
401 obj.used_aliases = set()
402 obj.filter_is_sticky = False
403 if hasattr(obj, "_setup_query"):
404 obj._setup_query() # type: ignore[operator]
405 return obj
406
407 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
408 clone = self.clone()
409 clone.change_aliases(change_map)
410 return clone
411
412 def _get_col(self, target: Any, field: Field, alias: str | None) -> Col:
413 if not self.alias_cols:
414 alias = None
415 return target.get_col(alias, field)
416
    def get_aggregation(self, aggregate_exprs: dict[str, Any]) -> dict[str, Any]:
        """
        Return the dictionary with the values of the existing aggregations.

        ``aggregate_exprs`` maps output aliases to aggregate expressions; the
        return value maps the same aliases to database-computed results. An
        empty mapping short-circuits to ``{}`` without touching the database.
        Depending on the current query state (grouping, slicing, distinct,
        existing aggregation, window-function filters) the aggregation is
        either applied directly to this query or wrapped in a subquery.

        Raises TypeError if any expression does not contain an aggregate.
        """
        if not aggregate_exprs:
            return {}
        aggregates = {}
        for alias, aggregate_expr in aggregate_exprs.items():
            self.check_alias(alias)
            aggregate = aggregate_expr.resolve_expression(
                self, allow_joins=True, reuse=None, summarize=True
            )
            if not aggregate.contains_aggregate:
                raise TypeError(f"{alias} is not an aggregate expression")
            aggregates[alias] = aggregate
        # Existing usage of aggregation can be determined by the presence of
        # selected aggregates but also by filters against aliased aggregates.
        _, having, qualify = self.where.split_having_qualify()
        has_existing_aggregation = (
            any(
                getattr(annotation, "contains_aggregate", True)
                for annotation in self.annotations.values()
            )
            or having
        )
        # Decide if we need to use a subquery.
        #
        # Existing aggregations would cause incorrect results as
        # get_aggregation() must produce just one result and thus must not use
        # GROUP BY.
        #
        # If the query has limit or distinct, or uses set operations, then
        # those operations must be done in a subquery so that the query
        # aggregates on the limit and/or distinct results instead of applying
        # the distinct and limit after the aggregation.
        if (
            isinstance(self.group_by, tuple)
            or self.is_sliced
            or has_existing_aggregation
            or qualify
            or self.distinct
        ):
            inner_query = self.clone()
            inner_query.subquery = True
            outer_query = AggregateQuery(self.model, inner_query)
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.set_annotation_mask(self.annotation_select)
            # Queries with distinct_fields need ordering and when a limit is
            # applied we must take the slice from the ordered query. Otherwise
            # no need for ordering.
            inner_query.clear_ordering(force=False)
            if not inner_query.distinct:
                # If the inner query uses default select and it has some
                # aggregate annotations, then we must make sure the inner
                # query is grouped by the main model's primary key. However,
                # clearing the select clause can alter results if distinct is
                # used.
                if inner_query.default_cols and has_existing_aggregation:
                    assert self.model is not None, "Aggregation requires a model"
                    inner_query.group_by = (
                        self.model._model_meta.get_forward_field("id").get_col(
                            inner_query.get_initial_alias()
                        ),
                    )
                inner_query.default_cols = False
                if not qualify:
                    # Mask existing annotations that are not referenced by
                    # aggregates to be pushed to the outer query unless
                    # filtering against window functions is involved as it
                    # requires complex realising.
                    annotation_mask = set()
                    for aggregate in aggregates.values():
                        annotation_mask |= aggregate.get_refs()
                    inner_query.set_annotation_mask(annotation_mask)

            # Add aggregates to the outer AggregateQuery. This requires making
            # sure all columns referenced by the aggregates are selected in the
            # inner query. It is achieved by retrieving all column references
            # by the aggregates, explicitly selecting them in the inner query,
            # and making sure the aggregates are repointed to them.
            col_refs = {}
            for alias, aggregate in aggregates.items():
                replacements = {}
                for col in self._gen_cols([aggregate], resolve_refs=False):
                    if not (col_ref := col_refs.get(col)):
                        index = len(col_refs) + 1
                        col_alias = f"__col{index}"
                        col_ref = Ref(col_alias, col)
                        col_refs[col] = col_ref
                        inner_query.annotations[col_alias] = col
                        inner_query.append_annotation_mask([col_alias])
                    replacements[col] = col_ref
                outer_query.annotations[alias] = aggregate.replace_expressions(
                    replacements
                )
            if (
                inner_query.select == ()
                and not inner_query.default_cols
                and not inner_query.annotation_select_mask
            ):
                # In case of Model.objects[0:3].count(), there would be no
                # field selected in the inner query, yet we must use a subquery.
                # So, make sure at least one field is selected.
                assert self.model is not None, "Count with slicing requires a model"
                inner_query.select = (
                    self.model._model_meta.get_forward_field("id").get_col(
                        inner_query.get_initial_alias()
                    ),
                )
        else:
            # Safe to aggregate in place: strip selection down to just the
            # requested aggregates.
            outer_query = self
            self.select = ()
            self.default_cols = False
            self.extra = {}
            if self.annotations:
                # Inline reference to existing annotations and mask them as
                # they are unnecessary given only the summarized aggregations
                # are requested.
                replacements = {
                    Ref(alias, annotation): annotation
                    for alias, annotation in self.annotations.items()
                }
                self.annotations = {
                    alias: aggregate.replace_expressions(replacements)
                    for alias, aggregate in aggregates.items()
                }
            else:
                self.annotations = aggregates
            self.set_annotation_mask(aggregates)

        # Per-alias fallback values used when the result set is empty.
        empty_set_result = [
            expression.empty_result_set_value
            for expression in outer_query.annotation_select.values()
        ]
        elide_empty = not any(result is NotImplemented for result in empty_set_result)
        outer_query.clear_ordering(force=True)
        outer_query.clear_limits()
        outer_query.select_for_update = False
        outer_query.select_related = False
        compiler = outer_query.get_compiler(elide_empty=elide_empty)
        result = compiler.execute_sql(SINGLE)
        if result is None:
            result = empty_set_result
        else:
            converters = compiler.get_converters(outer_query.annotation_select.values())
            result = next(compiler.apply_converters((result,), converters))

        return dict(zip(outer_query.annotation_select, result))
566
567 def get_count(self) -> int:
568 """
569 Perform a COUNT() query using the current filter constraints.
570 """
571 obj = self.clone()
572 return obj.get_aggregation({"__count": Count("*")})["__count"]
573
    def has_filters(self) -> bool:
        """Return True if any filter condition has been added to this query."""
        # The WhereNode is falsy while it holds no conditions.
        return bool(self.where)
576
    def exists(self, limit: bool = True) -> Self:
        """
        Return a clone reduced to a cheap EXISTS-style probe.

        Ordering is dropped, a constant ``1`` is annotated under the alias
        "a", and when ``limit`` is true the result set is capped at one row.
        The select clause is cleared unless the query is both distinct and
        sliced (where clearing could change the result).
        """
        q = self.clone()
        if not (q.distinct and q.is_sliced):
            if q.group_by is True:
                assert self.model is not None, "GROUP BY requires a model"
                q.add_fields(
                    (f.attname for f in self.model._model_meta.concrete_fields), False
                )
                # Disable GROUP BY aliases to avoid orphaning references to the
                # SELECT clause which is about to be cleared.
                q.set_group_by(allow_aliases=False)
            q.clear_select_clause()
        q.clear_ordering(force=True)
        if limit:
            q.set_limits(high=1)
        q.add_annotation(Value(1), "a")
        return q
594
595 def has_results(self) -> bool:
596 q = self.exists()
597 compiler = q.get_compiler()
598 return compiler.has_results()
599
600 def explain(self, format: str | None = None, **options: Any) -> str:
601 q = self.clone()
602 for option_name in options:
603 if (
604 not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name)
605 or "--" in option_name
606 ):
607 raise ValueError(f"Invalid option name: {option_name!r}.")
608 q.explain_info = ExplainInfo(format, options)
609 compiler = q.get_compiler()
610 return "\n".join(compiler.explain_query())
611
    def combine(self, rhs: Query, connector: str) -> None:
        """
        Merge the 'rhs' query into the current one (with any 'rhs' effects
        being applied *after* (that is, "to the right of") anything in the
        current query. 'rhs' is not modified during a call to this function.

        The 'connector' parameter describes how to connect filters from the
        'rhs' query.

        Raises TypeError if the two queries use different models, if this
        query is sliced, or if the queries differ on distinct/distinct_fields.
        """
        if self.model != rhs.model:
            raise TypeError("Cannot combine queries on two different base models.")
        if self.is_sliced:
            raise TypeError("Cannot combine queries once a slice has been taken.")
        if self.distinct != rhs.distinct:
            raise TypeError("Cannot combine a unique query with a non-unique query.")
        if self.distinct_fields != rhs.distinct_fields:
            raise TypeError("Cannot combine queries with different distinct fields.")

        # If lhs and rhs shares the same alias prefix, it is possible to have
        # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
        # as T4 -> T6 while combining two querysets. To prevent this, change an
        # alias prefix of the rhs and update current aliases accordingly,
        # except if the alias is the base table since it must be present in the
        # query on both sides.
        initial_alias = self.get_initial_alias()
        assert initial_alias is not None
        rhs.bump_prefix(self, exclude={initial_alias})

        # Work out how to relabel the rhs aliases, if necessary.
        change_map = {}
        conjunction = connector == AND

        # Determine which existing joins can be reused. When combining the
        # query with AND we must recreate all joins for m2m filters. When
        # combining with OR we can reuse joins. The reason is that in AND
        # case a single row can't fulfill a condition like:
        #     revrel__col=1 & revrel__col=2
        # But, there might be two different related rows matching this
        # condition. In OR case a single True is enough, so single row is
        # enough, too.
        #
        # Note that we will be creating duplicate joins for non-m2m joins in
        # the AND case. The results will be correct but this creates too many
        # joins. This is something that could be fixed later on.
        reuse = set() if conjunction else set(self.alias_map)
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == INNER
        )
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        rhs_tables = list(rhs.alias_map)[1:]
        for alias in rhs_tables:
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
            # combined query must have two joins, too.
            reuse.discard(new_alias)
            if alias != new_alias:
                change_map[alias] = new_alias
            if not rhs.alias_refcount[alias]:
                # The alias was unused in the rhs query. Unref it so that it
                # will be unused in the new query, too. We have to add and
                # unref the alias so that join promotion has information of
                # the join type for the unused alias.
                self.unref_alias(new_alias)
        joinpromoter.add_votes(rhs_votes)
        joinpromoter.update_join_types(self)

        # Combine subqueries aliases to ensure aliases relabelling properly
        # handle subqueries when combining where and select clauses.
        self.subq_aliases |= rhs.subq_aliases

        # Now relabel a copy of the rhs where-clause and add it to the current
        # one.
        w = rhs.where.clone()
        w.relabel_aliases(change_map)
        self.where.add(w, connector)

        # Selection columns and extra extensions are those provided by 'rhs'.
        if rhs.select:
            self.set_select([col.relabeled_clone(change_map) for col in rhs.select])
        else:
            self.select = ()

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra and rhs.extra:
                raise ValueError(
                    "When merging querysets using 'or', you cannot have "
                    "extra(select=...) on both sides."
                )
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables

        # Ordering uses the 'rhs' ordering, unless it has none, in which case
        # the current ordering is used.
        self.order_by = rhs.order_by or self.order_by
        self.extra_order_by = rhs.extra_order_by or self.extra_order_by
727
    def _get_defer_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """
        Build a select mask for defer(): include every concrete field that is
        NOT named in ``mask``, recursing into related models for nested defer
        entries. ``mask`` is consumed (entries are popped) as it is processed.
        Always includes the "id" field. Raises FieldError for malformed
        entries.
        """
        from plain.models.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        select_mask[meta.get_forward_field("id")] = {}
        # All concrete fields that are not part of the defer mask must be
        # loaded. If a relational field is encountered it gets added to the
        # mask for it be considered if `select_related` and the cycle continues
        # by recursively calling this function.
        for field in meta.concrete_fields:
            field_mask = mask.pop(field.name, None)
            if field_mask is None:
                select_mask.setdefault(field, {})
            elif field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                field_select_mask = select_mask.setdefault(field, {})
                related_model = field.remote_field.model
                self._get_defer_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        # Remaining defer entries must be references to reverse relationships.
        # The following code is expected to raise FieldError if it encounters
        # a malformed defer entry.
        for field_name, field_mask in mask.items():
            if filtered_relation := self._filtered_relations.get(field_name):
                relation = meta.get_reverse_relation(filtered_relation.relation_name)
                field_select_mask = select_mask.setdefault((field_name, relation), {})
                field = relation.field
            else:
                field = meta.get_reverse_relation(field_name).field
                field_select_mask = select_mask.setdefault(field, {})
            related_model = field.model
            self._get_defer_select_mask(
                related_model._model_meta, field_mask, field_select_mask
            )
        return select_mask
771
    def _get_only_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """
        Build a select mask for only(): include exactly the fields named in
        ``mask`` (plus the "id" field), recursing into related models for
        nested entries. Raises FieldError when a nested mask is given for a
        non-relational field.
        """
        from plain.models.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        select_mask[meta.get_forward_field("id")] = {}
        # Only include fields mentioned in the mask.
        for field_name, field_mask in mask.items():
            field = meta.get_field(field_name)
            field_select_mask = select_mask.setdefault(field, {})
            if field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                related_model = field.remote_field.model
                self._get_only_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        return select_mask
795
    def get_select_mask(self) -> dict[Any, Any]:
        """
        Convert the self.deferred_loading data structure to an alternate data
        structure, describing the field that *will* be loaded. This is used to
        compute the columns to select from the database and also by the
        QuerySet class to work out which fields are being initialized on each
        model. Models that have all their fields included aren't mentioned in
        the result, only those that have field restrictions in place.
        """
        field_names, defer = self.deferred_loading
        if not field_names:
            return {}
        # Expand "a__b__c" style names into a nested dict mask.
        mask = {}
        for field_name in field_names:
            part_mask = mask
            for part in field_name.split(LOOKUP_SEP):
                part_mask = part_mask.setdefault(part, {})
        assert self.model is not None, "Deferred/only field loading requires a model"
        meta = self.model._model_meta
        # defer=True means "load everything except mask"; False means "only mask".
        if defer:
            return self._get_defer_select_mask(meta, mask)
        return self._get_only_select_mask(meta, mask)
818
    def table_alias(
        self, table_name: str, create: bool = False, filtered_relation: Any = None
    ) -> tuple[str, bool]:
        """
        Return a table alias for the given table_name and whether this is a
        new alias or not.

        If 'create' is true, a new alias is always created. Otherwise, the
        most recently created alias for the table (if one exists) is reused.
        """
        alias_list = self.table_map.get(table_name)
        if not create and alias_list:
            alias = alias_list[0]
            self.alias_refcount[alias] += 1
            return alias, False

        # Create a new alias for this table.
        if alias_list:
            # Subsequent occurrences get a generated alias: prefix + ordinal.
            alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1)  # noqa: UP031
            alias_list.append(alias)
        else:
            # The first occurrence of a table uses the table name directly.
            alias = (
                filtered_relation.alias if filtered_relation is not None else table_name
            )
            self.table_map[table_name] = [alias]
        self.alias_refcount[alias] = 1
        return alias, True
847
    def ref_alias(self, alias: str) -> None:
        """Increase the reference count for ``alias`` by one."""
        self.alias_refcount[alias] += 1
851
    def unref_alias(self, alias: str, amount: int = 1) -> None:
        """Decrease the reference count for ``alias`` by ``amount``."""
        self.alias_refcount[alias] -= amount
855
    def promote_joins(self, aliases: set[str] | list[str]) -> None:
        """
        Promote recursively the join type of given aliases and its children to
        an outer join. A join is only promoted when it is nullable or its
        parent join is already an outer join.

        The children promotion is done to avoid join chains that contain a LOUTER
        b INNER c. So, if we have currently a INNER b INNER c and a->b is promoted,
        then we must also promote b->c automatically, or otherwise the promotion
        of a->b doesn't actually change anything in the query results.
        """
        # Process as a worklist: promoting one alias may enqueue its children.
        aliases = list(aliases)
        while aliases:
            alias = aliases.pop(0)
            if self.alias_map[alias].join_type is None:
                # This is the base table (first FROM entry) - this table
                # isn't really joined at all in the query, so we should not
                # alter its join type.
                continue
            # Only the first alias (skipped above) should have None join_type
            assert self.alias_map[alias].join_type is not None
            parent_alias = self.alias_map[alias].parent_alias
            parent_louter = (
                parent_alias and self.alias_map[parent_alias].join_type == LOUTER
            )
            already_louter = self.alias_map[alias].join_type == LOUTER
            if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
                self.alias_map[alias] = self.alias_map[alias].promote()
                # Join type of 'alias' changed, so re-examine all aliases that
                # refer to this one.
                aliases.extend(
                    join
                    for join in self.alias_map
                    if self.alias_map[join].parent_alias == alias
                    and join not in aliases
                )
892
893 def demote_joins(self, aliases: set[str] | list[str]) -> None:
894 """
895 Change join type from LOUTER to INNER for all joins in aliases.
896
897 Similarly to promote_joins(), this method must ensure no join chains
898 containing first an outer, then an inner join are generated. If we
899 are demoting b->c join in chain a LOUTER b LOUTER c then we must
900 demote a->b automatically, or otherwise the demotion of b->c doesn't
901 actually change anything in the query results. .
902 """
903 aliases = list(aliases)
904 while aliases:
905 alias = aliases.pop(0)
906 if self.alias_map[alias].join_type == LOUTER:
907 self.alias_map[alias] = self.alias_map[alias].demote()
908 parent_alias = self.alias_map[alias].parent_alias
909 if self.alias_map[parent_alias].join_type == INNER:
910 aliases.append(parent_alias)
911
912 def reset_refcounts(self, to_counts: dict[str, int]) -> None:
913 """
914 Reset reference counts for aliases so that they match the value passed
915 in `to_counts`.
916 """
917 for alias, cur_refcount in self.alias_refcount.copy().items():
918 unref_amount = cur_refcount - to_counts.get(alias, 0)
919 self.unref_alias(alias, unref_amount)
920
921 def change_aliases(self, change_map: dict[str, str]) -> None:
922 """
923 Change the aliases in change_map (which maps old-alias -> new-alias),
924 relabelling any references to them in select columns and the where
925 clause.
926 """
927 # If keys and values of change_map were to intersect, an alias might be
928 # updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
929 # on their order in change_map.
930 assert set(change_map).isdisjoint(change_map.values())
931
932 # 1. Update references in "select" (normal columns plus aliases),
933 # "group by" and "where".
934 self.where.relabel_aliases(change_map)
935 if isinstance(self.group_by, tuple):
936 self.group_by = tuple(
937 [col.relabeled_clone(change_map) for col in self.group_by]
938 )
939 self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
940 self.annotations = self.annotations and {
941 key: col.relabeled_clone(change_map)
942 for key, col in self.annotations.items()
943 }
944
945 # 2. Rename the alias in the internal table/alias datastructures.
946 for old_alias, new_alias in change_map.items():
947 if old_alias not in self.alias_map:
948 continue
949 alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
950 self.alias_map[new_alias] = alias_data
951 self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
952 del self.alias_refcount[old_alias]
953 del self.alias_map[old_alias]
954
955 table_aliases = self.table_map[alias_data.table_name]
956 for pos, alias in enumerate(table_aliases):
957 if alias == old_alias:
958 table_aliases[pos] = new_alias
959 break
960 self.external_aliases = {
961 # Table is aliased or it's being changed and thus is aliased.
962 change_map.get(alias, alias): (aliased or alias in change_map)
963 for alias, aliased in self.external_aliases.items()
964 }
965
    def bump_prefix(
        self, other_query: Query, exclude: set[str] | dict[str, str] | None = None
    ) -> None:
        """
        Change the alias prefix to the next letter in the alphabet in a way
        that the other query's aliases and this query's aliases will not
        conflict. Even tables that previously had no alias will get an alias
        after this call. To prevent changing aliases use the exclude parameter.
        """

        def prefix_gen() -> TypingIterator[str]:
            """
            Generate a sequence of characters in alphabetical order:
            -> 'A', 'B', 'C', ...

            When the alphabet is finished, the sequence will continue with the
            Cartesian product:
            -> 'AA', 'AB', 'AC', ...
            """
            alphabet = ascii_uppercase
            # Start one letter past the current prefix so the first candidate
            # already differs from self.alias_prefix.
            prefix = chr(ord(self.alias_prefix) + 1)
            yield prefix
            for n in count(1):
                # On the first pass resume the alphabet from the bumped prefix;
                # once prefix is reset to None, use the full alphabet.
                seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
                for s in product(seq, repeat=n):
                    yield "".join(s)
                prefix = None

        if self.alias_prefix != other_query.alias_prefix:
            # No clashes between self and outer query should be possible.
            return

        # Explicitly avoid infinite loop. The constant divider is based on how
        # much depth recursive subquery references add to the stack. This value
        # might need to be adjusted when adding or removing function calls from
        # the code path in charge of performing these operations.
        local_recursion_limit = sys.getrecursionlimit() // 16
        for pos, prefix in enumerate(prefix_gen()):
            if prefix not in self.subq_aliases:
                self.alias_prefix = prefix
                break
            if pos > local_recursion_limit:
                raise RecursionError(
                    "Maximum recursion depth exceeded: too many subqueries."
                )
        # Record the chosen prefix on both queries so neither reuses it later.
        self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
        other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
        if exclude is None:
            exclude = {}
        # Rename every non-excluded alias to "<new prefix><index>".
        self.change_aliases(
            {
                alias: "%s%d" % (self.alias_prefix, pos)  # noqa: UP031
                for pos, alias in enumerate(self.alias_map)
                if alias not in exclude
            }
        )
1022
1023 def get_initial_alias(self) -> str | None:
1024 """
1025 Return the first alias for this query, after increasing its reference
1026 count.
1027 """
1028 if self.alias_map:
1029 alias = self.base_table
1030 self.ref_alias(alias)
1031 elif self.model:
1032 alias = self.join(
1033 self.base_table_class(self.model.model_options.db_table, None)
1034 )
1035 else:
1036 alias = None
1037 return alias
1038
1039 def count_active_tables(self) -> int:
1040 """
1041 Return the number of tables in this query with a non-zero reference
1042 count. After execution, the reference counts are zeroed, so tables
1043 added in compiler will not be seen by this method.
1044 """
1045 return len([1 for count in self.alias_refcount.values() if count])
1046
1047 def join(
1048 self,
1049 join: BaseTable | Join,
1050 reuse: set[str] | None = None,
1051 reuse_with_filtered_relation: bool = False,
1052 ) -> str:
1053 """
1054 Return an alias for the 'join', either reusing an existing alias for
1055 that join or creating a new one. 'join' is either a base_table_class or
1056 join_class.
1057
1058 The 'reuse' parameter can be either None which means all joins are
1059 reusable, or it can be a set containing the aliases that can be reused.
1060
1061 The 'reuse_with_filtered_relation' parameter is used when computing
1062 FilteredRelation instances.
1063
1064 A join is always created as LOUTER if the lhs alias is LOUTER to make
1065 sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
1066 joins are created as LOUTER if the join is nullable.
1067 """
1068 if reuse_with_filtered_relation and reuse:
1069 reuse_aliases = [
1070 a for a, j in self.alias_map.items() if a in reuse and j.equals(join)
1071 ]
1072 else:
1073 reuse_aliases = [
1074 a
1075 for a, j in self.alias_map.items()
1076 if (reuse is None or a in reuse) and j == join
1077 ]
1078 if reuse_aliases:
1079 if join.table_alias in reuse_aliases:
1080 reuse_alias = join.table_alias
1081 else:
1082 # Reuse the most recent alias of the joined table
1083 # (a many-to-many relation may be joined multiple times).
1084 reuse_alias = reuse_aliases[-1]
1085 self.ref_alias(reuse_alias)
1086 return reuse_alias
1087
1088 # No reuse is possible, so we need a new alias.
1089 alias, _ = self.table_alias(
1090 join.table_name, create=True, filtered_relation=join.filtered_relation
1091 )
1092 if isinstance(join, Join):
1093 if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
1094 join_type = LOUTER
1095 else:
1096 join_type = INNER
1097 join.join_type = join_type
1098 join.table_alias = alias
1099 self.alias_map[alias] = join
1100 return alias
1101
1102 def check_alias(self, alias: str) -> None:
1103 if FORBIDDEN_ALIAS_PATTERN.search(alias):
1104 raise ValueError(
1105 "Column aliases cannot contain whitespace characters, quotation marks, "
1106 "semicolons, or SQL comments."
1107 )
1108
1109 def add_annotation(
1110 self, annotation: BaseExpression, alias: str, select: bool = True
1111 ) -> None:
1112 """Add a single annotation expression to the Query."""
1113 self.check_alias(alias)
1114 annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
1115 if select:
1116 self.append_annotation_mask([alias])
1117 else:
1118 self.set_annotation_mask(set(self.annotation_select).difference({alias}))
1119 self.annotations[alias] = annotation
1120
1121 def resolve_expression(
1122 self,
1123 query: Any = None,
1124 allow_joins: bool = True,
1125 reuse: Any = None,
1126 summarize: bool = False,
1127 for_save: bool = False,
1128 ) -> Self:
1129 clone = self.clone()
1130 # Subqueries need to use a different set of aliases than the outer query.
1131 clone.bump_prefix(query)
1132 clone.subquery = True
1133 clone.where.resolve_expression(query, allow_joins, reuse, summarize, for_save)
1134 for key, value in clone.annotations.items():
1135 resolved = value.resolve_expression(
1136 query, allow_joins, reuse, summarize, for_save
1137 )
1138 if hasattr(resolved, "external_aliases"):
1139 resolved.external_aliases.update(clone.external_aliases)
1140 clone.annotations[key] = resolved
1141 # Outer query's aliases are considered external.
1142 for alias, table in query.alias_map.items():
1143 clone.external_aliases[alias] = (
1144 isinstance(table, Join)
1145 and table.join_field.related_model.model_options.db_table != alias
1146 ) or (
1147 isinstance(table, BaseTable) and table.table_name != table.table_alias
1148 )
1149 return clone
1150
1151 def get_external_cols(self) -> list[Col]:
1152 exprs = chain(self.annotations.values(), self.where.children)
1153 return [
1154 col
1155 for col in self._gen_cols(exprs, include_external=True)
1156 if col.alias in self.external_aliases
1157 ]
1158
1159 def get_group_by_cols(
1160 self, wrapper: BaseExpression | None = None
1161 ) -> list[BaseExpression]:
1162 # If wrapper is referenced by an alias for an explicit GROUP BY through
1163 # values() a reference to this expression and not the self must be
1164 # returned to ensure external column references are not grouped against
1165 # as well.
1166 external_cols = self.get_external_cols()
1167 if any(col.possibly_multivalued for col in external_cols):
1168 return [wrapper or self]
1169 # Cast needed because list is invariant: list[Col] is not list[BaseExpression]
1170 return cast(list[BaseExpression], external_cols)
1171
1172 def as_sql(
1173 self, compiler: SQLCompiler, connection: DatabaseWrapper
1174 ) -> SqlWithParams:
1175 sql, params = self.get_compiler().as_sql()
1176 if self.subquery:
1177 sql = f"({sql})"
1178 return sql, params
1179
1180 def resolve_lookup_value(
1181 self, value: Any, can_reuse: set[str] | None, allow_joins: bool
1182 ) -> Any:
1183 if isinstance(value, ResolvableExpression):
1184 value = value.resolve_expression(
1185 self,
1186 reuse=can_reuse,
1187 allow_joins=allow_joins,
1188 )
1189 elif isinstance(value, list | tuple):
1190 # The items of the iterable may be expressions and therefore need
1191 # to be resolved independently.
1192 values = (
1193 self.resolve_lookup_value(sub_value, can_reuse, allow_joins)
1194 for sub_value in value
1195 )
1196 type_ = type(value)
1197 if hasattr(type_, "_make"): # namedtuple
1198 return type_(*values)
1199 return type_(values)
1200 return value
1201
1202 def solve_lookup_type(
1203 self, lookup: str, summarize: bool = False
1204 ) -> tuple[
1205 list[str] | tuple[str, ...], tuple[str, ...], BaseExpression | Literal[False]
1206 ]:
1207 """
1208 Solve the lookup type from the lookup (e.g.: 'foobar__id__icontains').
1209 """
1210 lookup_splitted = lookup.split(LOOKUP_SEP)
1211 if self.annotations:
1212 annotation, expression_lookups = refs_expression(
1213 lookup_splitted, self.annotations
1214 )
1215 if annotation:
1216 expression = self.annotations[annotation]
1217 if summarize:
1218 expression = Ref(annotation, expression)
1219 return expression_lookups, (), expression
1220 assert self.model is not None, "Field lookups require a model"
1221 meta = self.model._model_meta
1222 _, field, _, lookup_parts = self.names_to_path(lookup_splitted, meta)
1223 field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
1224 if len(lookup_parts) > 1 and not field_parts:
1225 raise FieldError(
1226 f'Invalid lookup "{lookup}" for model {meta.model.__name__}".'
1227 )
1228 return lookup_parts, tuple(field_parts), False
1229
1230 def check_query_object_type(
1231 self, value: Any, meta: Meta, field: Field | ForeignObjectRel
1232 ) -> None:
1233 """
1234 Check whether the object passed while querying is of the correct type.
1235 If not, raise a ValueError specifying the wrong object.
1236 """
1237 from plain.models import Model
1238
1239 if isinstance(value, Model):
1240 if not check_rel_lookup_compatibility(value._model_meta.model, meta, field):
1241 raise ValueError(
1242 f'Cannot query "{value}": Must be "{meta.model.model_options.object_name}" instance.'
1243 )
1244
1245 def check_related_objects(
1246 self, field: RelatedField | ForeignObjectRel, value: Any, meta: Meta
1247 ) -> None:
1248 """Check the type of object passed to query relations."""
1249 from plain.models import Model
1250
1251 # Check that the field and the queryset use the same model in a
1252 # query like .filter(author=Author.query.all()). For example, the
1253 # meta would be Author's (from the author field) and value.model
1254 # would be Author.query.all() queryset's .model (Author also).
1255 # The field is the related field on the lhs side.
1256 if (
1257 isinstance(value, Query)
1258 and not value.has_select_fields
1259 and not check_rel_lookup_compatibility(value.model, meta, field)
1260 ):
1261 raise ValueError(
1262 f'Cannot use QuerySet for "{value.model.model_options.object_name}": Use a QuerySet for "{meta.model.model_options.object_name}".'
1263 )
1264 elif isinstance(value, Model):
1265 self.check_query_object_type(value, meta, field)
1266 elif isinstance(value, Iterable):
1267 for v in value:
1268 self.check_query_object_type(v, meta, field)
1269
1270 def check_filterable(self, expression: Any) -> None:
1271 """Raise an error if expression cannot be used in a WHERE clause."""
1272 if isinstance(expression, ResolvableExpression) and not getattr(
1273 expression, "filterable", True
1274 ):
1275 raise NotSupportedError(
1276 expression.__class__.__name__ + " is disallowed in the filter clause."
1277 )
1278 if hasattr(expression, "get_source_expressions"):
1279 for expr in expression.get_source_expressions():
1280 self.check_filterable(expr)
1281
1282 def build_lookup(
1283 self, lookups: list[str], lhs: BaseExpression | MultiColSource, rhs: Any
1284 ) -> Lookup | None:
1285 """
1286 Try to extract transforms and lookup from given lhs.
1287
1288 The lhs value is something that works like SQLExpression.
1289 The rhs value is what the lookup is going to compare against.
1290 The lookups is a list of names to extract using get_lookup()
1291 and get_transform().
1292 """
1293 # __exact is the default lookup if one isn't given.
1294 *transforms, lookup_name = lookups or ["exact"]
1295 if transforms:
1296 if isinstance(lhs, MultiColSource):
1297 raise FieldError(
1298 "Transforms are not supported on multi-column relations."
1299 )
1300 # At this point, lhs must be BaseExpression
1301 for name in transforms:
1302 lhs = self.try_transform(lhs, name)
1303 # First try get_lookup() so that the lookup takes precedence if the lhs
1304 # supports both transform and lookup for the name.
1305 lookup_class = lhs.get_lookup(lookup_name)
1306 if not lookup_class:
1307 # A lookup wasn't found. Try to interpret the name as a transform
1308 # and do an Exact lookup against it.
1309 if isinstance(lhs, MultiColSource):
1310 raise FieldError(
1311 "Transforms are not supported on multi-column relations."
1312 )
1313 lhs = self.try_transform(lhs, lookup_name)
1314 lookup_name = "exact"
1315 lookup_class = lhs.get_lookup(lookup_name)
1316 if not lookup_class:
1317 return
1318
1319 lookup = lookup_class(lhs, rhs)
1320 # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
1321 # uses of None as a query value unless the lookup supports it.
1322 if lookup.rhs is None and not lookup.can_use_none_as_rhs:
1323 if lookup_name not in ("exact", "iexact"):
1324 raise ValueError("Cannot use None as a query value")
1325 isnull_lookup = lhs.get_lookup("isnull")
1326 assert isnull_lookup is not None
1327 return isnull_lookup(lhs, True)
1328
1329 return lookup
1330
1331 def try_transform(self, lhs: BaseExpression, name: str) -> BaseExpression:
1332 """
1333 Helper method for build_lookup(). Try to fetch and initialize
1334 a transform for name parameter from lhs.
1335 """
1336 transform_class = lhs.get_transform(name)
1337 if transform_class:
1338 return transform_class(lhs)
1339 else:
1340 output_field = lhs.output_field.__class__
1341 suggested_lookups = difflib.get_close_matches(
1342 name, output_field.get_lookups()
1343 )
1344 if suggested_lookups:
1345 suggestion = ", perhaps you meant {}?".format(
1346 " or ".join(suggested_lookups)
1347 )
1348 else:
1349 suggestion = "."
1350 raise FieldError(
1351 f"Unsupported lookup '{name}' for {output_field.__name__} or join on the field not "
1352 f"permitted{suggestion}"
1353 )
1354
1355 def build_filter(
1356 self,
1357 filter_expr: tuple[str, Any] | Q | BaseExpression,
1358 branch_negated: bool = False,
1359 current_negated: bool = False,
1360 can_reuse: set[str] | None = None,
1361 allow_joins: bool = True,
1362 split_subq: bool = True,
1363 reuse_with_filtered_relation: bool = False,
1364 check_filterable: bool = True,
1365 summarize: bool = False,
1366 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1367 from plain.models.fields.related import RelatedField
1368
1369 """
1370 Build a WhereNode for a single filter clause but don't add it
1371 to this Query. Query.add_q() will then add this filter to the where
1372 Node.
1373
1374 The 'branch_negated' tells us if the current branch contains any
1375 negations. This will be used to determine if subqueries are needed.
1376
1377 The 'current_negated' is used to determine if the current filter is
1378 negated or not and this will be used to determine if IS NULL filtering
1379 is needed.
1380
1381 The difference between current_negated and branch_negated is that
1382 branch_negated is set on first negation, but current_negated is
1383 flipped for each negation.
1384
1385 Note that add_filter will not do any negating itself, that is done
1386 upper in the code by add_q().
1387
1388 The 'can_reuse' is a set of reusable joins for multijoins.
1389
1390 If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
1391 will be reused.
1392
1393 The method will create a filter clause that can be added to the current
1394 query. However, if the filter isn't added to the query then the caller
1395 is responsible for unreffing the joins used.
1396 """
1397 if isinstance(filter_expr, dict):
1398 raise FieldError("Cannot parse keyword query as dict")
1399 if isinstance(filter_expr, Q):
1400 return self._add_q(
1401 filter_expr,
1402 branch_negated=branch_negated,
1403 current_negated=current_negated,
1404 used_aliases=can_reuse,
1405 allow_joins=allow_joins,
1406 split_subq=split_subq,
1407 check_filterable=check_filterable,
1408 summarize=summarize,
1409 )
1410 if isinstance(filter_expr, ResolvableExpression):
1411 if not getattr(filter_expr, "conditional", False):
1412 raise TypeError("Cannot filter against a non-conditional expression.")
1413 condition = filter_expr.resolve_expression(
1414 self, allow_joins=allow_joins, summarize=summarize
1415 )
1416 if not isinstance(condition, Lookup):
1417 condition = self.build_lookup(["exact"], condition, True)
1418 return WhereNode([condition], connector=AND), set()
1419 if isinstance(filter_expr, BaseExpression):
1420 raise TypeError(f"Unexpected BaseExpression type: {type(filter_expr)}")
1421 arg, value = filter_expr
1422 if not arg:
1423 raise FieldError(f"Cannot parse keyword query {arg!r}")
1424 lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
1425
1426 if check_filterable:
1427 self.check_filterable(reffed_expression)
1428
1429 if not allow_joins and len(parts) > 1:
1430 raise FieldError("Joined field references are not permitted in this query")
1431
1432 pre_joins = self.alias_refcount.copy()
1433 value = self.resolve_lookup_value(value, can_reuse, allow_joins)
1434 used_joins = {
1435 k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
1436 }
1437
1438 if check_filterable:
1439 self.check_filterable(value)
1440
1441 if reffed_expression:
1442 condition = self.build_lookup(list(lookups), reffed_expression, value)
1443 return WhereNode([condition], connector=AND), set()
1444
1445 assert self.model is not None, "Building filters requires a model"
1446 meta = self.model._model_meta
1447 alias = self.get_initial_alias()
1448 assert alias is not None
1449 allow_many = not branch_negated or not split_subq
1450
1451 try:
1452 join_info = self.setup_joins(
1453 list(parts),
1454 meta,
1455 alias,
1456 can_reuse=can_reuse,
1457 allow_many=allow_many,
1458 reuse_with_filtered_relation=reuse_with_filtered_relation,
1459 )
1460
1461 # Prevent iterator from being consumed by check_related_objects()
1462 if isinstance(value, Iterator):
1463 value = list(value)
1464 from plain.models.fields.related import RelatedField
1465 from plain.models.fields.reverse_related import ForeignObjectRel
1466
1467 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1468 self.check_related_objects(join_info.final_field, value, join_info.meta)
1469
1470 # split_exclude() needs to know which joins were generated for the
1471 # lookup parts
1472 self._lookup_joins = join_info.joins
1473 except MultiJoin as e:
1474 return self.split_exclude(
1475 filter_expr,
1476 can_reuse or set(),
1477 e.names_with_path,
1478 )
1479
1480 # Update used_joins before trimming since they are reused to determine
1481 # which joins could be later promoted to INNER.
1482 used_joins.update(join_info.joins)
1483 targets, alias, join_list = self.trim_joins(
1484 join_info.targets, join_info.joins, join_info.path
1485 )
1486 if can_reuse is not None:
1487 can_reuse.update(join_list)
1488
1489 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1490 if len(targets) == 1:
1491 col = self._get_col(targets[0], join_info.final_field, alias)
1492 else:
1493 col = MultiColSource(
1494 alias, targets, join_info.targets, join_info.final_field
1495 )
1496 else:
1497 col = self._get_col(targets[0], join_info.final_field, alias)
1498
1499 condition = self.build_lookup(list(lookups), col, value)
1500 assert condition is not None
1501 lookup_type = condition.lookup_name
1502 clause = WhereNode([condition], connector=AND)
1503
1504 require_outer = (
1505 lookup_type == "isnull" and condition.rhs is True and not current_negated
1506 )
1507 if (
1508 current_negated
1509 and (lookup_type != "isnull" or condition.rhs is False)
1510 and condition.rhs is not None
1511 ):
1512 require_outer = True
1513 if lookup_type != "isnull":
1514 # The condition added here will be SQL like this:
1515 # NOT (col IS NOT NULL), where the first NOT is added in
1516 # upper layers of code. The reason for addition is that if col
1517 # is null, then col != someval will result in SQL "unknown"
1518 # which isn't the same as in Python. The Python None handling
1519 # is wanted, and it can be gotten by
1520 # (col IS NULL OR col != someval)
1521 # <=>
1522 # NOT (col IS NOT NULL AND col = someval).
1523 if (
1524 self.is_nullable(targets[0])
1525 or self.alias_map[join_list[-1]].join_type == LOUTER
1526 ):
1527 lookup_class = targets[0].get_lookup("isnull")
1528 assert lookup_class is not None
1529 col = self._get_col(targets[0], join_info.targets[0], alias)
1530 clause.add(lookup_class(col, False), AND)
1531 # If someval is a nullable column, someval IS NOT NULL is
1532 # added.
1533 if isinstance(value, Col) and self.is_nullable(value.target):
1534 lookup_class = value.target.get_lookup("isnull")
1535 assert lookup_class is not None
1536 clause.add(lookup_class(value, False), AND)
1537 return clause, used_joins if not require_outer else ()
1538
1539 def add_filter(self, filter_lhs: str, filter_rhs: Any) -> None:
1540 self.add_q(Q((filter_lhs, filter_rhs)))
1541
1542 def add_q(self, q_object: Q) -> None:
1543 """
1544 A preprocessor for the internal _add_q(). Responsible for doing final
1545 join promotion.
1546 """
1547 # For join promotion this case is doing an AND for the added q_object
1548 # and existing conditions. So, any existing inner join forces the join
1549 # type to remain inner. Existing outer joins can however be demoted.
1550 # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
1551 # rel_a doesn't produce any rows, then the whole condition must fail.
1552 # So, demotion is OK.
1553 existing_inner = {
1554 a for a in self.alias_map if self.alias_map[a].join_type == INNER
1555 }
1556 clause, _ = self._add_q(q_object, self.used_aliases)
1557 if clause:
1558 self.where.add(clause, AND)
1559 self.demote_joins(existing_inner)
1560
1561 def build_where(
1562 self, filter_expr: tuple[str, Any] | Q | BaseExpression
1563 ) -> WhereNode:
1564 return self.build_filter(filter_expr, allow_joins=False)[0]
1565
1566 def clear_where(self) -> None:
1567 self.where = WhereNode()
1568
1569 def _add_q(
1570 self,
1571 q_object: Q,
1572 used_aliases: set[str] | None,
1573 branch_negated: bool = False,
1574 current_negated: bool = False,
1575 allow_joins: bool = True,
1576 split_subq: bool = True,
1577 check_filterable: bool = True,
1578 summarize: bool = False,
1579 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1580 """Add a Q-object to the current filter."""
1581 connector = q_object.connector
1582 current_negated ^= q_object.negated
1583 branch_negated = branch_negated or q_object.negated
1584 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1585 joinpromoter = JoinPromoter(
1586 q_object.connector, len(q_object.children), current_negated
1587 )
1588 for child in q_object.children:
1589 child_clause, needed_inner = self.build_filter(
1590 child,
1591 can_reuse=used_aliases,
1592 branch_negated=branch_negated,
1593 current_negated=current_negated,
1594 allow_joins=allow_joins,
1595 split_subq=split_subq,
1596 check_filterable=check_filterable,
1597 summarize=summarize,
1598 )
1599 joinpromoter.add_votes(needed_inner)
1600 if child_clause:
1601 target_clause.add(child_clause, connector)
1602 needed_inner = joinpromoter.update_join_types(self)
1603 return target_clause, needed_inner
1604
1605 def build_filtered_relation_q(
1606 self,
1607 q_object: Q,
1608 reuse: set[str],
1609 branch_negated: bool = False,
1610 current_negated: bool = False,
1611 ) -> WhereNode:
1612 """Add a FilteredRelation object to the current filter."""
1613 connector = q_object.connector
1614 current_negated ^= q_object.negated
1615 branch_negated = branch_negated or q_object.negated
1616 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1617 for child in q_object.children:
1618 if isinstance(child, Node):
1619 child_clause = self.build_filtered_relation_q(
1620 child,
1621 reuse=reuse,
1622 branch_negated=branch_negated,
1623 current_negated=current_negated,
1624 )
1625 else:
1626 child_clause, _ = self.build_filter(
1627 child,
1628 can_reuse=reuse,
1629 branch_negated=branch_negated,
1630 current_negated=current_negated,
1631 allow_joins=True,
1632 split_subq=False,
1633 reuse_with_filtered_relation=True,
1634 )
1635 target_clause.add(child_clause, connector)
1636 return target_clause
1637
1638 def add_filtered_relation(self, filtered_relation: Any, alias: str) -> None:
1639 filtered_relation.alias = alias
1640 lookups = dict(get_children_from_q(filtered_relation.condition))
1641 relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
1642 filtered_relation.relation_name
1643 )
1644 if relation_lookup_parts:
1645 raise ValueError(
1646 "FilteredRelation's relation_name cannot contain lookups "
1647 f"(got {filtered_relation.relation_name!r})."
1648 )
1649 for lookup in chain(lookups):
1650 lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
1651 shift = 2 if not lookup_parts else 1
1652 lookup_field_path = lookup_field_parts[:-shift]
1653 for idx, lookup_field_part in enumerate(lookup_field_path):
1654 if len(relation_field_parts) > idx:
1655 if relation_field_parts[idx] != lookup_field_part:
1656 raise ValueError(
1657 "FilteredRelation's condition doesn't support "
1658 f"relations outside the {filtered_relation.relation_name!r} (got {lookup!r})."
1659 )
1660 else:
1661 raise ValueError(
1662 "FilteredRelation's condition doesn't support nested "
1663 f"relations deeper than the relation_name (got {lookup!r} for "
1664 f"{filtered_relation.relation_name!r})."
1665 )
1666 self._filtered_relations[filtered_relation.alias] = filtered_relation
1667
    def names_to_path(
        self,
        names: list[str],
        meta: Meta,
        allow_many: bool = True,
        fail_on_missing: bool = False,
    ) -> tuple[list[Any], Field | ForeignObjectRel, tuple[Field, ...], list[str]]:
        """
        Walk the list of names and turns them into PathInfo tuples. A single
        name in 'names' can generate multiple PathInfos (m2m, for example).

        'names' is the path of names to travel, 'meta' is the Meta we
        start the name resolving from, 'allow_many' is as for setup_joins().
        If fail_on_missing is set to True, then a name that can't be resolved
        will generate a FieldError.

        Return a list of PathInfo tuples. In addition return the final field
        (the last used join field) and target (which is a field guaranteed to
        contain the same value as the final field). Finally, return those names
        that weren't found (which are likely transforms and the final lookup).
        """
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            # (name, [PathInfos consumed by this name]) -- collected so that a
            # MultiJoin error can report the path traversed so far.
            cur_names_with_path = (name, [])

            field = None
            filtered_relation = None
            try:
                if meta is None:
                    # A previous hop ended the relation chain; treat any
                    # remaining name as unresolvable.
                    raise FieldDoesNotExist
                field = meta.get_field(name)
            except FieldDoesNotExist:
                if name in self.annotation_select:
                    # The name refers to a selected annotation, not a field.
                    field = self.annotation_select[name].output_field
                elif name in self._filtered_relations and pos == 0:
                    # A FilteredRelation alias may only start the path.
                    filtered_relation = self._filtered_relations[name]
                    if LOOKUP_SEP in filtered_relation.relation_name:
                        parts = filtered_relation.relation_name.split(LOOKUP_SEP)
                        filtered_relation_path, field, _, _ = self.names_to_path(
                            parts,
                            meta,
                            allow_many,
                            fail_on_missing,
                        )
                        # Keep all but the final hop; the last hop is re-added
                        # below via the resolved field's own path info.
                        path.extend(filtered_relation_path[:-1])
                    else:
                        field = meta.get_field(filtered_relation.relation_name)
                if field is None:
                    # We didn't find the current field, so move position back
                    # one step.
                    pos -= 1
                    if pos == -1 or fail_on_missing:
                        available = sorted(
                            [
                                *get_field_names_from_opts(meta),
                                *self.annotation_select,
                                *self._filtered_relations,
                            ]
                        )
                        raise FieldError(
                            "Cannot resolve keyword '{}' into field. "
                            "Choices are: {}".format(name, ", ".join(available))
                        )
                    # Leave the remaining names as the unresolved tail.
                    break

            # Lazy import to avoid circular dependency
            from plain.models.fields.related import ForeignKeyField as FK
            from plain.models.fields.related import ManyToManyField as M2M
            from plain.models.fields.reverse_related import ForeignObjectRel as FORel

            if isinstance(field, FK | M2M | FORel):
                pathinfos: list[PathInfo]
                if filtered_relation:
                    pathinfos = field.get_path_info(filtered_relation)
                else:
                    pathinfos = field.path_infos
                if not allow_many:
                    # Reject traversal through multi-valued (m2m) hops,
                    # reporting the partial path walked so far.
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
                            names_with_path.append(cur_names_with_path)
                            raise MultiJoin(pos + 1, names_with_path)
                last = pathinfos[-1]
                path.extend(pathinfos)
                final_field = last.join_field
                meta = last.to_meta
                targets = last.target_fields
                cur_names_with_path[1].extend(pathinfos)
                names_with_path.append(cur_names_with_path)
            else:
                # Local non-relational field.
                final_field = field
                targets = (field,)
                if fail_on_missing and pos + 1 != len(names):
                    raise FieldError(
                        f"Cannot resolve keyword {names[pos + 1]!r} into field. Join on '{name}'"
                        " not permitted."
                    )
                break
        # NOTE: relies on the loop variable `pos` leaking out of the for-loop;
        # names[pos + 1:] is the unresolved tail (likely transforms/lookups).
        return path, final_field, targets, names[pos + 1 :]
1768
    def setup_joins(
        self,
        names: list[str],
        meta: Meta,
        alias: str,
        can_reuse: set[str] | None = None,
        allow_many: bool = True,
        reuse_with_filtered_relation: bool = False,
    ) -> JoinInfo:
        """
        Compute the necessary table joins for the passage through the fields
        given in 'names'. 'meta' is the Meta for the current model
        (which gives the table we are starting from), 'alias' is the alias for
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
        parameter and force the relation on the given connections.

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.

        Return the final field involved in the joins, the target field (used
        for any 'where' constraint), the final 'opts' value, the joins, the
        field path traveled to generate the joins, and a transform function
        that takes a field and alias and is equivalent to `field.get_col(alias)`
        in the simple case but wraps field transforms if they were included in
        names.

        The target field is the field containing the concrete value. Final
        field can be something different, for example foreign key pointing to
        that value. Final field is needed for example in some value
        conversions (convert 'obj' in fk__id=obj to pk val using the foreign
        key field for example).
        """
        joins = [alias]
        # The transform can't be applied yet, as joins must be trimmed later.
        # To avoid making every caller of this method look up transforms
        # directly, compute transforms here and create a partial that converts
        # fields to the appropriate wrapped version.

        def _base_transformer(field: Field, alias: str | None) -> Col:
            # Plain column reference; the alias is dropped entirely when this
            # query does not qualify columns with table aliases.
            if not self.alias_cols:
                alias = None
            return field.get_col(alias)

        final_transformer: TransformWrapper | Callable[[Field, str | None], Col] = (
            _base_transformer
        )

        # Try resolving all the names as fields first. If there's an error,
        # treat trailing names as lookups until a field can be resolved.
        last_field_exception = None
        for pivot in range(len(names), 0, -1):
            try:
                path, final_field, targets, rest = self.names_to_path(
                    names[:pivot],
                    meta,
                    allow_many,
                    fail_on_missing=True,
                )
            except FieldError as exc:
                if pivot == 1:
                    # The first item cannot be a lookup, so it's safe
                    # to raise the field error here.
                    raise
                else:
                    last_field_exception = exc
            else:
                # The transforms are the remaining items that couldn't be
                # resolved into fields.
                transforms = names[pivot:]
                break
        for name in transforms:
            # `name` and `previous` are keyword-only parameters supplied by
            # the wrapper so each layer captures the current loop values
            # instead of late-binding the loop variable.
            def transform(
                field: Field, alias: str | None, *, name: str, previous: Any
            ) -> BaseExpression:
                try:
                    wrapped = previous(field, alias)
                    return self.try_transform(wrapped, name)
                except FieldError:
                    # FieldError is raised if the transform doesn't exist.
                    if isinstance(final_field, Field) and last_field_exception:
                        raise last_field_exception
                    else:
                        raise

            final_transformer = TransformWrapper(
                transform, name=name, previous=final_transformer
            )
            final_transformer.has_transforms = True
        # Then, add the path to the query's joins. Note that we can't trim
        # joins at this stage - we will need the information about join type
        # of the trimmed joins.
        for join in path:
            if join.filtered_relation:
                filtered_relation = join.filtered_relation.clone()
                table_alias = filtered_relation.alias
            else:
                filtered_relation = None
                table_alias = None
            meta = join.to_meta
            if join.direct:
                nullable = self.is_nullable(join.join_field)
            else:
                # Reverse joins may have no matching row on the other side,
                # so treat them as nullable.
                nullable = True
            connection = self.join_class(
                meta.model.model_options.db_table,
                alias,
                table_alias,
                INNER,
                join.join_field,
                nullable,
                filtered_relation=filtered_relation,
            )
            reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
            alias = self.join(
                connection,
                reuse=reuse,
                reuse_with_filtered_relation=reuse_with_filtered_relation,
            )
            joins.append(alias)
            if filtered_relation:
                filtered_relation.path = joins[:]
        return JoinInfo(final_field, targets, meta, joins, path, final_transformer)
1899
1900 def trim_joins(
1901 self, targets: tuple[Field, ...], joins: list[str], path: list[Any]
1902 ) -> tuple[tuple[Field, ...], str, list[str]]:
1903 """
1904 The 'target' parameter is the final field being joined to, 'joins'
1905 is the full list of join aliases. The 'path' contain the PathInfos
1906 used to create the joins.
1907
1908 Return the final target field and table alias and the new active
1909 joins.
1910
1911 Always trim any direct join if the target column is already in the
1912 previous table. Can't trim reverse joins as it's unknown if there's
1913 anything on the other side of the join.
1914 """
1915 joins = joins[:]
1916 for pos, info in enumerate(reversed(path)):
1917 if len(joins) == 1 or not info.direct:
1918 break
1919 if info.filtered_relation:
1920 break
1921 join_targets = {t.column for t in info.join_field.foreign_related_fields}
1922 cur_targets = {t.column for t in targets}
1923 if not cur_targets.issubset(join_targets):
1924 break
1925 targets_dict = {
1926 r[1].column: r[0]
1927 for r in info.join_field.related_fields
1928 if r[1].column in cur_targets
1929 }
1930 targets = tuple(targets_dict[t.column] for t in targets)
1931 self.unref_alias(joins.pop())
1932 return targets, joins[-1], joins
1933
1934 @classmethod
1935 def _gen_cols(
1936 cls,
1937 exprs: Iterable[Any],
1938 include_external: bool = False,
1939 resolve_refs: bool = True,
1940 ) -> TypingIterator[Col]:
1941 for expr in exprs:
1942 if isinstance(expr, Col):
1943 yield expr
1944 elif include_external and callable(
1945 getattr(expr, "get_external_cols", None)
1946 ):
1947 yield from expr.get_external_cols()
1948 elif hasattr(expr, "get_source_expressions"):
1949 if not resolve_refs and isinstance(expr, Ref):
1950 continue
1951 yield from cls._gen_cols(
1952 expr.get_source_expressions(),
1953 include_external=include_external,
1954 resolve_refs=resolve_refs,
1955 )
1956
1957 @classmethod
1958 def _gen_col_aliases(cls, exprs: Iterable[Any]) -> TypingIterator[str]:
1959 yield from (expr.alias for expr in cls._gen_cols(exprs))
1960
    def resolve_ref(
        self,
        name: str,
        allow_joins: bool = True,
        reuse: set[str] | None = None,
        summarize: bool = False,
    ) -> BaseExpression:
        """
        Resolve 'name' (possibly LOOKUP_SEP-separated) into an expression:
        an existing annotation, a transform applied to an annotation, or a
        column reached by joining through the model's fields.

        'allow_joins' forbids references that require a table join. 'reuse',
        when given, is the set of reusable join aliases and is updated with
        any joins created here. 'summarize' indicates an aggregate() query
        executed as a wrapped subquery; annotation references are then
        returned as Refs into the subquery's annotation select.

        Raise FieldError for disallowed joins, unknown/unpromoted aliases,
        or multicolumn field references.
        """
        annotation = self.annotations.get(name)
        if annotation is not None:
            if not allow_joins:
                # Reject annotations whose columns live on joined tables.
                for alias in self._gen_col_aliases([annotation]):
                    if isinstance(self.alias_map[alias], Join):
                        raise FieldError(
                            "Joined field references are not permitted in this query"
                        )
            if summarize:
                # Summarize currently means we are doing an aggregate() query
                # which is executed as a wrapped subquery if any of the
                # aggregate() elements reference an existing annotation. In
                # that case we need to return a Ref to the subquery's annotation.
                if name not in self.annotation_select:
                    raise FieldError(
                        f"Cannot aggregate over the '{name}' alias. Use annotate() "
                        "to promote it."
                    )
                return Ref(name, self.annotation_select[name])
            else:
                return annotation
        else:
            field_list = name.split(LOOKUP_SEP)
            annotation = self.annotations.get(field_list[0])
            if annotation is not None:
                # The first part names an annotation; treat the remaining
                # parts as transforms applied to it.
                for transform in field_list[1:]:
                    annotation = self.try_transform(annotation, transform)
                return annotation
            initial_alias = self.get_initial_alias()
            assert initial_alias is not None
            assert self.model is not None, "Resolving field references requires a model"
            meta = self.model._model_meta
            join_info = self.setup_joins(
                field_list,
                meta,
                initial_alias,
                can_reuse=reuse,
            )
            targets, final_alias, join_list = self.trim_joins(
                join_info.targets, join_info.joins, join_info.path
            )
            if not allow_joins and len(join_list) > 1:
                raise FieldError(
                    "Joined field references are not permitted in this query"
                )
            if len(targets) > 1:
                raise FieldError(
                    "Referencing multicolumn fields with F() objects isn't supported"
                )
            # Verify that the last lookup in name is a field or a transform:
            # transform_function() raises FieldError if not.
            transform = join_info.transform_function(targets[0], final_alias)
            if reuse is not None:
                reuse.update(join_list)
            return transform
2023
    def split_exclude(
        self,
        filter_expr: tuple[str, Any],
        can_reuse: set[str],
        names_with_path: list[tuple[str, list[Any]]],
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
        original exclude filter (filter_expr) and the portion up to the first
        N-to-many relation field.

        For example, if the origin filter is ~Q(child__name='foo'), filter_expr
        is ('child__name', 'foo') and can_reuse is a set of joins usable for
        filters in the original query.

        We will turn this into equivalent of:
            WHERE NOT EXISTS(
                SELECT 1
                FROM child
                WHERE name = 'foo' AND child.parent_id = parent.id
                LIMIT 1
            )

        Return the built condition node and the set of joins the condition
        needs to be inner joins.
        """
        # Generate the inner query.
        query = self.__class__(self.model)
        query._filtered_relations = self._filtered_relations
        filter_lhs, filter_rhs = filter_expr
        # Outer-query references on the right-hand side must survive
        # resolution inside the nested subquery, so (re)wrap them in OuterRef.
        if isinstance(filter_rhs, OuterRef):
            filter_rhs = OuterRef(filter_rhs)
        elif isinstance(filter_rhs, F):
            filter_rhs = OuterRef(filter_rhs.name)
        query.add_filter(filter_lhs, filter_rhs)
        query.clear_ordering(force=True)
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_prefix, contains_louter = query.trim_start(names_with_path)

        # The first selected column of the trimmed subquery is what the outer
        # query correlates against.
        col = query.select[0]
        select_field = col.target
        alias = col.alias
        if alias in can_reuse:
            id_field = select_field.model._model_meta.get_forward_field("id")
            # Need to add a restriction so that outer query's filters are in effect for
            # the subquery, too.
            query.bump_prefix(self)
            lookup_class = select_field.get_lookup("exact")
            # Note that the query.select[0].alias is different from alias
            # due to bump_prefix above.
            lookup = lookup_class(
                id_field.get_col(query.select[0].alias), id_field.get_col(alias)
            )
            query.where.add(lookup, AND)
            query.external_aliases[alias] = True

        lookup_class = select_field.get_lookup("exact")
        lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
        query.where.add(lookup, AND)
        condition, needed_inner = self.build_filter(Exists(query))

        if contains_louter:
            or_null_condition, _ = self.build_filter(
                (f"{trimmed_prefix}__isnull", True),
                current_negated=True,
                branch_negated=True,
                can_reuse=can_reuse,
            )
            condition.add(or_null_condition, OR)
            # Note that the end result will be:
            # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
            # This might look crazy but due to how IN works, this seems to be
            # correct. If the IS NOT NULL check is removed then outercol NOT
            # IN will return UNKNOWN. If the IS NULL check is removed, then if
            # outercol IS NULL we will not match the row.
        return condition, needed_inner
2099
2100 def set_empty(self) -> None:
2101 self.where.add(NothingNode(), AND)
2102
2103 def is_empty(self) -> bool:
2104 return any(isinstance(c, NothingNode) for c in self.where.children)
2105
2106 def set_limits(self, low: int | None = None, high: int | None = None) -> None:
2107 """
2108 Adjust the limits on the rows retrieved. Use low/high to set these,
2109 as it makes it more Pythonic to read and write. When the SQL query is
2110 created, convert them to the appropriate offset and limit values.
2111
2112 Apply any limits passed in here to the existing constraints. Add low
2113 to the current low value and clamp both to any existing high value.
2114 """
2115 if high is not None:
2116 if self.high_mark is not None:
2117 self.high_mark = min(self.high_mark, self.low_mark + high)
2118 else:
2119 self.high_mark = self.low_mark + high
2120 if low is not None:
2121 if self.high_mark is not None:
2122 self.low_mark = min(self.high_mark, self.low_mark + low)
2123 else:
2124 self.low_mark = self.low_mark + low
2125
2126 if self.low_mark == self.high_mark:
2127 self.set_empty()
2128
2129 def clear_limits(self) -> None:
2130 """Clear any existing limits."""
2131 self.low_mark, self.high_mark = 0, None
2132
2133 @property
2134 def is_sliced(self) -> bool:
2135 return self.low_mark != 0 or self.high_mark is not None
2136
2137 def has_limit_one(self) -> bool:
2138 return self.high_mark is not None and (self.high_mark - self.low_mark) == 1
2139
2140 def can_filter(self) -> bool:
2141 """
2142 Return True if adding filters to this instance is still possible.
2143
2144 Typically, this means no limits or offsets have been put on the results.
2145 """
2146 return not self.is_sliced
2147
2148 def clear_select_clause(self) -> None:
2149 """Remove all fields from SELECT clause."""
2150 self.select = ()
2151 self.default_cols = False
2152 self.select_related = False
2153 self.set_extra_mask(())
2154 self.set_annotation_mask(())
2155
2156 def clear_select_fields(self) -> None:
2157 """
2158 Clear the list of fields to select (but not extra_select columns).
2159 Some queryset types completely replace any existing list of select
2160 columns.
2161 """
2162 self.select = ()
2163 self.values_select = ()
2164
2165 def add_select_col(self, col: BaseExpression, name: str) -> None:
2166 self.select += (col,)
2167 self.values_select += (name,)
2168
2169 def set_select(self, cols: list[Col] | tuple[Col, ...]) -> None:
2170 self.default_cols = False
2171 self.select = tuple(cols)
2172
2173 def add_distinct_fields(self, *field_names: str) -> None:
2174 """
2175 Add and resolve the given fields to the query's "distinct on" clause.
2176 """
2177 self.distinct_fields = field_names
2178 self.distinct = True
2179
    def add_fields(
        self, field_names: list[str] | TypingIterator[str], allow_m2m: bool = True
    ) -> None:
        """
        Add the given (model) fields to the select set. Add the field names in
        the order specified.

        Raise FieldError for unknown names, multi-valued joins when
        'allow_m2m' is False, and unpromoted annotation aliases.
        """
        alias = self.get_initial_alias()
        assert alias is not None
        assert self.model is not None, "add_fields() requires a model"
        meta = self.model._model_meta

        try:
            cols = []
            for name in field_names:
                # Join promotion note - we must not remove any rows here, so
                # if there is no existing joins, use outer join.
                join_info = self.setup_joins(
                    name.split(LOOKUP_SEP), meta, alias, allow_many=allow_m2m
                )
                targets, final_alias, joins = self.trim_joins(
                    join_info.targets,
                    join_info.joins,
                    join_info.path,
                )
                for target in targets:
                    cols.append(join_info.transform_function(target, final_alias))
            if cols:
                self.set_select(cols)
        except MultiJoin:
            # `name` is the entry being processed when the join failed.
            raise FieldError(f"Invalid field name: '{name}'")
        except FieldError:
            if LOOKUP_SEP in name:
                # For lookups spanning over relationships, show the error
                # from the model on which the lookup failed.
                raise
            elif name in self.annotations:
                raise FieldError(
                    f"Cannot select the '{name}' alias. Use annotate() to promote it."
                )
            else:
                # Build a helpful message listing every selectable name.
                names = sorted(
                    [
                        *get_field_names_from_opts(meta),
                        *self.extra,
                        *self.annotation_select,
                        *self._filtered_relations,
                    ]
                )
                raise FieldError(
                    "Cannot resolve keyword {!r} into field. Choices are: {}".format(
                        name, ", ".join(names)
                    )
                )
2234
    def add_ordering(self, *ordering: str | BaseExpression) -> None:
        """
        Add items from the 'ordering' sequence to the query's "order by"
        clause. These items are either field names (not column names) --
        possibly with a direction prefix ('-' or '?') -- or OrderBy
        expressions.

        If 'ordering' is empty, clear all ordering from the query (including
        the model's default ordering).
        """
        errors = []
        for item in ordering:
            if isinstance(item, str):
                if item == "?":
                    # Random ordering needs no validation.
                    continue
                # Strip the descending prefix before validating the name.
                item = item.removeprefix("-")
                if item in self.annotations:
                    continue
                if self.extra and item in self.extra:
                    continue
                # names_to_path() validates the lookup. A descriptive
                # FieldError will be raised if it's not.
                assert self.model is not None, "ORDER BY field names require a model"
                self.names_to_path(item.split(LOOKUP_SEP), self.model._model_meta)
            elif not isinstance(item, ResolvableExpression):
                errors.append(item)
            # Bare aggregates are rejected regardless of the item's type.
            if getattr(item, "contains_aggregate", False):
                raise FieldError(
                    "Using an aggregate in order_by() without also including "
                    f"it in annotate() is not allowed: {item}"
                )
        if errors:
            raise FieldError(f"Invalid order_by arguments: {errors}")
        if ordering:
            self.order_by += ordering
        else:
            # An empty call disables the model's default ordering.
            self.default_ordering = False
2271
2272 def clear_ordering(self, force: bool = False, clear_default: bool = True) -> None:
2273 """
2274 Remove any ordering settings if the current query allows it without
2275 side effects, set 'force' to True to clear the ordering regardless.
2276 If 'clear_default' is True, there will be no ordering in the resulting
2277 query (not even the model's default).
2278 """
2279 if not force and (
2280 self.is_sliced or self.distinct_fields or self.select_for_update
2281 ):
2282 return
2283 self.order_by = ()
2284 self.extra_order_by = ()
2285 if clear_default:
2286 self.default_ordering = False
2287
    def set_group_by(self, allow_aliases: bool = True) -> None:
        """
        Expand the GROUP BY clause required by the query.

        This will usually be the set of all non-aggregate fields in the
        return data. If the database backend supports grouping by the
        primary key, and the query would be equivalent, the optimization
        will be made automatically.
        """
        if allow_aliases and self.values_select:
            # If grouping by aliases is allowed assign selected value aliases
            # by moving them to annotations.
            group_by_annotations = {}
            values_select = {}
            for alias, expr in zip(self.values_select, self.select):
                if isinstance(expr, Col):
                    values_select[alias] = expr
                else:
                    group_by_annotations[alias] = expr
            self.annotations = {**group_by_annotations, **self.annotations}
            self.append_annotation_mask(group_by_annotations)
            self.select = tuple(values_select.values())
            self.values_select = tuple(values_select)
        group_by = list(self.select)
        for alias, annotation in self.annotation_select.items():
            # Skip annotations that contribute no GROUP BY columns.
            if not (group_by_cols := annotation.get_group_by_cols()):
                continue
            if allow_aliases and not annotation.contains_aggregate:
                # Non-aggregate annotations can be grouped by their alias.
                group_by.append(Ref(alias, annotation))
            else:
                group_by.extend(group_by_cols)
        self.group_by = tuple(group_by)
2320
2321 def add_select_related(self, fields: list[str]) -> None:
2322 """
2323 Set up the select_related data structure so that we only select
2324 certain related models (as opposed to all models, when
2325 self.select_related=True).
2326 """
2327 if isinstance(self.select_related, bool):
2328 field_dict: dict[str, Any] = {}
2329 else:
2330 field_dict = self.select_related
2331 for field in fields:
2332 d = field_dict
2333 for part in field.split(LOOKUP_SEP):
2334 d = d.setdefault(part, {})
2335 self.select_related = field_dict
2336
    def add_extra(
        self,
        select: dict[str, str],
        select_params: list[Any] | None,
        where: list[str],
        params: list[Any],
        tables: list[str],
        order_by: tuple[str, ...],
    ) -> None:
        """
        Add data to the various extra_* attributes for user-created additions
        to the query.
        """
        if select:
            # We need to pair any placeholder markers in the 'select'
            # dictionary with their parameters in 'select_params' so that
            # subsequent updates to the select dictionary also adjust the
            # parameters appropriately.
            select_pairs = {}
            if select_params:
                param_iter = iter(select_params)
            else:
                param_iter = iter([])
            for name, entry in select.items():
                self.check_alias(name)
                entry = str(entry)
                entry_params = []
                # Walk every '%s' placeholder, skipping ones escaped as
                # '%%s', and pair each with the next positional parameter.
                pos = entry.find("%s")
                while pos != -1:
                    if pos == 0 or entry[pos - 1] != "%":
                        entry_params.append(next(param_iter))
                    pos = entry.find("%s", pos + 2)
                select_pairs[name] = (entry, entry_params)
            self.extra.update(select_pairs)
        if where or params:
            self.where.add(ExtraWhere(where, params), AND)
        if tables:
            self.extra_tables += tuple(tables)
        if order_by:
            self.extra_order_by = order_by
2377
2378 def clear_deferred_loading(self) -> None:
2379 """Remove any fields from the deferred loading set."""
2380 self.deferred_loading = (frozenset(), True)
2381
2382 def add_deferred_loading(self, field_names: frozenset[str]) -> None:
2383 """
2384 Add the given list of model field names to the set of fields to
2385 exclude from loading from the database when automatic column selection
2386 is done. Add the new field names to any existing field names that
2387 are deferred (or removed from any existing field names that are marked
2388 as the only ones for immediate loading).
2389 """
2390 # Fields on related models are stored in the literal double-underscore
2391 # format, so that we can use a set datastructure. We do the foo__bar
2392 # splitting and handling when computing the SQL column names (as part of
2393 # get_columns()).
2394 existing, defer = self.deferred_loading
2395 existing_set = set(existing)
2396 if defer:
2397 # Add to existing deferred names.
2398 self.deferred_loading = frozenset(existing_set.union(field_names)), True
2399 else:
2400 # Remove names from the set of any existing "immediate load" names.
2401 if new_existing := existing_set.difference(field_names):
2402 self.deferred_loading = frozenset(new_existing), False
2403 else:
2404 self.clear_deferred_loading()
2405 if new_only := set(field_names).difference(existing_set):
2406 self.deferred_loading = frozenset(new_only), True
2407
2408 def add_immediate_loading(self, field_names: list[str] | set[str]) -> None:
2409 """
2410 Add the given list of model field names to the set of fields to
2411 retrieve when the SQL is executed ("immediate loading" fields). The
2412 field names replace any existing immediate loading field names. If
2413 there are field names already specified for deferred loading, remove
2414 those names from the new field_names before storing the new names
2415 for immediate loading. (That is, immediate loading overrides any
2416 existing immediate values, but respects existing deferrals.)
2417 """
2418 existing, defer = self.deferred_loading
2419 field_names_set = set(field_names)
2420
2421 if defer:
2422 # Remove any existing deferred names from the current set before
2423 # setting the new names.
2424 self.deferred_loading = (
2425 frozenset(field_names_set.difference(existing)),
2426 False,
2427 )
2428 else:
2429 # Replace any existing "immediate load" field names.
2430 self.deferred_loading = frozenset(field_names_set), False
2431
2432 def set_annotation_mask(
2433 self,
2434 names: set[str]
2435 | frozenset[str]
2436 | list[str]
2437 | tuple[str, ...]
2438 | dict[str, Any]
2439 | None,
2440 ) -> None:
2441 """Set the mask of annotations that will be returned by the SELECT."""
2442 if names is None:
2443 self.annotation_select_mask = None
2444 else:
2445 self.annotation_select_mask = set(names)
2446 self._annotation_select_cache = None
2447
2448 def append_annotation_mask(self, names: list[str] | dict[str, Any]) -> None:
2449 if self.annotation_select_mask is not None:
2450 self.set_annotation_mask(self.annotation_select_mask.union(names))
2451
2452 def set_extra_mask(
2453 self, names: set[str] | list[str] | tuple[str, ...] | None
2454 ) -> None:
2455 """
2456 Set the mask of extra select items that will be returned by SELECT.
2457 Don't remove them from the Query since they might be used later.
2458 """
2459 if names is None:
2460 self.extra_select_mask = None
2461 else:
2462 self.extra_select_mask = set(names)
2463 self._extra_select_cache = None
2464
    def set_values(self, fields: list[str]) -> None:
        """
        Restrict the query to return only the given fields, splitting them
        into plain field names, extra-select names, and annotation names, and
        updating the masks and GROUP BY clause accordingly. With no fields,
        select all of the model's concrete fields.
        """
        self.select_related = False
        self.clear_deferred_loading()
        self.clear_select_fields()
        self.has_select_fields = True

        if fields:
            field_names = []
            extra_names = []
            annotation_names = []
            if not self.extra and not self.annotations:
                # Shortcut - if there are no extra or annotations, then
                # the values() clause must be just field names.
                field_names = list(fields)
            else:
                self.default_cols = False
                for f in fields:
                    if f in self.extra_select:
                        extra_names.append(f)
                    elif f in self.annotation_select:
                        annotation_names.append(f)
                    else:
                        field_names.append(f)
                self.set_extra_mask(extra_names)
                self.set_annotation_mask(annotation_names)
            selected = frozenset(field_names + extra_names + annotation_names)
        else:
            assert self.model is not None, "Default values query requires a model"
            field_names = [f.attname for f in self.model._model_meta.concrete_fields]
            selected = frozenset(field_names)
        # Selected annotations must be known before setting the GROUP BY
        # clause.
        if self.group_by is True:
            assert self.model is not None, "GROUP BY True requires a model"
            self.add_fields(
                (f.attname for f in self.model._model_meta.concrete_fields), False
            )
            # Disable GROUP BY aliases to avoid orphaning references to the
            # SELECT clause which is about to be cleared.
            self.set_group_by(allow_aliases=False)
            self.clear_select_fields()
        elif self.group_by:
            # Resolve GROUP BY annotation references if they are not part of
            # the selected fields anymore.
            group_by = []
            for expr in self.group_by:
                if isinstance(expr, Ref) and expr.refs not in selected:
                    expr = self.annotations[expr.refs]
                group_by.append(expr)
            self.group_by = tuple(group_by)

        self.values_select = tuple(field_names)
        self.add_fields(field_names, True)
2518
2519 @property
2520 def annotation_select(self) -> dict[str, BaseExpression]:
2521 """
2522 Return the dictionary of aggregate columns that are not masked and
2523 should be used in the SELECT clause. Cache this result for performance.
2524 """
2525 if self._annotation_select_cache is not None:
2526 return self._annotation_select_cache
2527 elif not self.annotations:
2528 return {}
2529 elif self.annotation_select_mask is not None:
2530 self._annotation_select_cache = {
2531 k: v
2532 for k, v in self.annotations.items()
2533 if k in self.annotation_select_mask
2534 }
2535 return self._annotation_select_cache
2536 else:
2537 return self.annotations
2538
2539 @property
2540 def extra_select(self) -> dict[str, tuple[str, list[Any]]]:
2541 if self._extra_select_cache is not None:
2542 return self._extra_select_cache
2543 if not self.extra:
2544 return {}
2545 elif self.extra_select_mask is not None:
2546 self._extra_select_cache = {
2547 k: v for k, v in self.extra.items() if k in self.extra_select_mask
2548 }
2549 return self._extra_select_cache
2550 else:
2551 return self.extra
2552
    def trim_start(
        self, names_with_path: list[tuple[str, list[Any]]]
    ) -> tuple[str, bool]:
        """
        Trim joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure that are m2m joins.

        Also set the select column so the start matches the join.

        This method is meant to be used for generating the subquery joins &
        cols in split_exclude().

        Return a lookup usable for doing outerq.filter(lookup=self) and a
        boolean indicating if the joins in the prefix contain a LEFT OUTER join.
        """
        all_paths = []
        for _, paths in names_with_path:
            all_paths.extend(paths)
        contains_louter = False
        # Trim and operate only on tables that were generated for
        # the lookup part of the query. That is, avoid trimming
        # joins generated for F() expressions.
        lookup_tables = [
            t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
        ]
        # `trimmed_paths` ends up counting the leading single-valued
        # (non-m2m) joins that can be dropped from the subquery.
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # The path.join_field is a Rel, lets get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix.
        paths_in_prefix = trimmed_paths
        trimmed_prefix = []
        for name, path in names_with_path:
            if paths_in_prefix - len(path) < 0:
                break
            trimmed_prefix.append(name)
            paths_in_prefix -= len(path)
        trimmed_prefix.append(join_field.foreign_related_fields[0].name)
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        # Lets still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for:
        # - LEFT JOINs because we would miss those rows that have nothing on
        #   the outer side,
        # - INNER JOINs from filtered relations because we would miss their
        #   filters.
        first_join = self.alias_map[lookup_tables[trimmed_paths + 1]]
        if first_join.join_type != LOUTER and not first_join.filtered_relation:
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
        else:
            # TODO: It might be possible to trim more joins from the start of the
            # inner query if it happens to have a longer join chain containing the
            # values in select_fields. Lets punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a join_class instead of a
        # base_table_class reference. But the first entry in the query's FROM
        # clause must not be a JOIN.
        for table in self.alias_map:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = self.base_table_class(
                    self.alias_map[table].table_name,
                    table,
                )
                break
        self.set_select([f.get_col(select_alias) for f in select_fields])
        return trimmed_prefix, contains_louter
2626
2627 def is_nullable(self, field: Field) -> bool:
2628 """Check if the given field should be treated as nullable."""
2629 # QuerySet does not have knowledge of which connection is going to be
2630 # used. For the single-database setup we always reference the default
2631 # connection here.
2632 return field.allow_null
2633
2634
def get_order_dir(field: str, default: str = "ASC") -> tuple[str, str]:
    """
    Return the field name and direction for an order specification. For
    example, '-foo' is returned as ('foo', 'DESC').

    The 'default' param is used to indicate which way no prefix (or a '+'
    prefix) should sort. The '-' prefix always sorts the opposite way.
    """
    # ORDER_DIR maps the default direction to a (same, opposite) pair.
    same_dir, opposite_dir = ORDER_DIR[default]
    if field[0] == "-":
        return field[1:], opposite_dir
    return field, same_dir
2647
2648
class JoinPromoter:
    """
    A class to abstract away join promotion problems for complex filter
    conditions.
    """

    def __init__(self, connector: str, num_children: int, negated: bool):
        self.connector = connector
        self.negated = negated
        # Under negation De Morgan applies: NOT (a AND b) behaves like
        # (NOT a OR NOT b) for join-promotion purposes, so flip the
        # effective connector.
        if not negated:
            self.effective_connector = connector
        elif connector == AND:
            self.effective_connector = OR
        else:
            self.effective_connector = AND
        self.num_children = num_children
        # Tally of table alias -> number of child conditions that require
        # the join (as inner and/or outer).
        self.votes = Counter()

    def __repr__(self) -> str:
        cls_name = self.__class__.__qualname__
        return (
            f"{cls_name}(connector={self.connector!r}, "
            f"num_children={self.num_children!r}, negated={self.negated!r})"
        )

    def add_votes(self, votes: Any) -> None:
        """Record one vote per item of *votes* (any iterable of aliases)."""
        self.votes.update(votes)

    def update_join_types(self, query: Query) -> set[str]:
        """
        Change join types so that the generated query is as efficient as
        possible, but still correct. So, change as many joins as possible
        to INNER, but don't make OUTER joins INNER if that could remove
        results from the query.
        """
        promote = set()
        demote = set()
        # effective_connector (not connector) is used so that NOT (a AND b)
        # is treated like (a OR b) for promotion decisions.
        for table, vote_count in self.votes.items():
            # OR case: if the join isn't required by every child, an INNER
            # JOIN could drop valid rows. Example: rel_a__col=1 | rel_b__col=2
            # — if rel_a produces no row (reverse FK, or NULL direct FK) but
            # rel_b matches, an INNER join to rel_a would discard that match.
            # Promote such joins to LOUTER (a later demotion may undo this).
            if self.effective_connector == OR and vote_count < self.num_children:
                promote.add(table)
            # AND case: a single vote suffices for INNER — if the joined row
            # is NULL, the child filter (e.g. rel_a__col=1) can't match,
            # since NULL compares false to everything.
            # OR case: unanimous votes also allow INNER — e.g.
            # (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell):
            # if rel_a yields no rows, no branch can match.
            if self.effective_connector == AND or (
                self.effective_connector == OR and vote_count == self.num_children
            ):
                demote.add(table)
            # Mixed case, e.g. (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0:
            # the OR clause first promotes rel_a and rel_b to LOUTER; the AND
            # clause then votes inner for rel_a, demoting it back to INNER.
            # That is safe: if rel_a has no rows, rel_a__col__gte=0 is false
            # and the whole conjunction is false anyway. The same reasoning
            # holds with the clauses swapped, or with the __gte clause
            # replaced by rel_a__col=1|rel_a__col=2.
        query.promote_joins(promote)
        query.demote_joins(demote)
        return demote
2739
2740
2741# ##### Query subclasses (merged from subqueries.py) #####
2742
2743
class DeleteQuery(Query):
    """A DELETE SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLDeleteCompiler:
        from plain.models.sql.compiler import SQLDeleteCompiler

        return SQLDeleteCompiler(
            self, cast("DatabaseWrapper", db_connection), elide_empty
        )

    def do_query(self, table: str, where: Any) -> int:
        """Execute a DELETE against *table* with *where*; return rows deleted."""
        from plain.models.sql.constants import CURSOR

        # Restrict the query to the single target table before compiling.
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        with cursor:
            return cursor.rowcount

    def delete_batch(self, id_list: list[Any]) -> int:
        """
        Set up and execute delete queries for all the objects in id_list.

        More than one physical query may be executed if there are a
        lot of values in id_list.
        """
        from plain.models.sql.constants import GET_ITERATOR_CHUNK_SIZE

        assert self.model is not None, "DELETE requires a model"
        id_field = self.model._model_meta.get_forward_field("id")
        table_name = self.model.model_options.db_table
        num_deleted = 0
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = id_list[start : start + GET_ITERATOR_CHUNK_SIZE]
            # Each chunk gets a fresh WHERE clause with its own id__in filter.
            self.clear_where()
            self.add_filter(f"{id_field.attname}__in", chunk)
            num_deleted += self.do_query(table_name, self.where)
        return num_deleted
2786
2787
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self) -> None:
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples destined for the SET clause.
        self.values: list[tuple[Any, Any, Any]] = []
        self.related_ids: dict[Any, list[Any]] | None = None
        # Per-ancestor-model queued updates, coalesced into one query each.
        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}

    def get_compiler(self, *, elide_empty: bool = True) -> SQLUpdateCompiler:
        from plain.models.sql.compiler import SQLUpdateCompiler

        return SQLUpdateCompiler(
            self, cast("DatabaseWrapper", db_connection), elide_empty
        )

    def clone(self) -> UpdateQuery:
        cloned = super().clone()
        cloned.related_updates = self.related_updates.copy()
        return cloned

    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
        """Apply *values* to the rows in *id_list*, chunking large id lists."""
        from plain.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS

        self.add_update_values(values)
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter("id__in", id_list[start : start + GET_ITERATOR_CHUNK_SIZE])
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values: dict[str, Any]) -> None:
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.
        """
        from plain.models.fields.related import ManyToManyField

        assert self.model is not None, "UPDATE requires model metadata"
        meta = self.model._model_meta
        values_seq = []
        for name, value in values.items():
            field = meta.get_field(name)
            # NOTE(review): this expression is tautologically true for every
            # combination of auto_created/concrete; preserved verbatim so the
            # attribute-access side effects and behavior stay identical.
            direct = (
                not (field.auto_created and not field.concrete) or not field.concrete
            )
            model = field.model
            if not direct or isinstance(field, ManyToManyField):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            if model is not meta.model:
                # The field lives on an ancestor model: queue a separate
                # update for that table instead of including it here.
                self.add_related_update(model, field, value)
                continue
            values_seq.append((field, model, value))
        return self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, value in values_seq:
            if isinstance(value, ResolvableExpression):
                # Resolve expressions here so that annotations are no longer
                # needed afterwards.
                value = value.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, value))

    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
        """
        Queue a (field, None, value) triple on the update for ancestor *model*.

        Updates are coalesced so that only one update query per ancestor is run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self) -> list[UpdateQuery]:
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.
        """
        if not self.related_updates:
            return []
        queries = []
        for ancestor, triples in self.related_updates.items():
            query = UpdateQuery(ancestor)
            query.values = triples
            if self.related_ids is not None:
                query.add_filter("id__in", self.related_ids[ancestor])
            queries.append(query)
        return queries
2891
2892
class InsertQuery(Query):
    """An INSERT SQL query."""

    def __init__(
        self,
        *args: Any,
        on_conflict: OnConflict | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Populated later by insert_values().
        self.fields: list[Field] = []
        self.objs: list[Any] = []
        # Conflict-handling configuration consumed by the compiler.
        self.on_conflict = on_conflict
        self.update_fields: list[Field] = update_fields if update_fields else []
        self.unique_fields: list[Field] = unique_fields if unique_fields else []

    def get_compiler(self, *, elide_empty: bool = True) -> SQLInsertCompiler:
        from plain.models.sql.compiler import SQLInsertCompiler

        return SQLInsertCompiler(
            self, cast("DatabaseWrapper", db_connection), elide_empty
        )

    def __str__(self) -> str:
        raise NotImplementedError(
            "InsertQuery does not support __str__(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def sql_with_params(self) -> Any:
        raise NotImplementedError(
            "InsertQuery does not support sql_with_params(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """Record the fields and objects to insert for the compiler to consume."""
        # `raw` is stored for the compiler — presumably it bypasses value
        # preparation; confirm against SQLInsertCompiler.
        self.fields = fields
        self.objs = objs
        self.raw = raw
2933
2934
class AggregateQuery(Query):
    """
    Take another query as a parameter to the FROM clause and only select the
    elements in the provided list.
    """

    def __init__(self, model: Any, inner_query: Any) -> None:
        # Stash the inner query before Query.__init__ runs so it is
        # available during base-class setup.
        self.inner_query = inner_query
        super().__init__(model)

    def get_compiler(self, *, elide_empty: bool = True) -> SQLAggregateCompiler:
        from plain.models.sql.compiler import SQLAggregateCompiler

        return SQLAggregateCompiler(
            self, cast("DatabaseWrapper", db_connection), elide_empty
        )