1"""
2Create SQL statements for QuerySets.
3
4The code in here encapsulates all of the SQL construction so that QuerySets
5themselves do not have to. This module has to know all about the internals of
6models in order to get the information it needs.
7"""
8
9from __future__ import annotations
10
11import copy
12import difflib
13import functools
14import sys
15from collections import Counter
16from collections.abc import Callable, Iterable, Iterator, Mapping
17from collections.abc import Iterator as TypingIterator
18from functools import cached_property
19from itertools import chain, count, product
20from string import ascii_uppercase
21from typing import (
22 TYPE_CHECKING,
23 Any,
24 Literal,
25 NamedTuple,
26 Self,
27 TypeVar,
28 cast,
29 overload,
30)
31
32import psycopg
33
34from plain.postgres.aggregates import Count
35from plain.postgres.constants import LOOKUP_SEP, OnConflict
36from plain.postgres.db import get_connection
37from plain.postgres.exceptions import FieldDoesNotExist, FieldError
38from plain.postgres.expressions import (
39 BaseExpression,
40 Col,
41 Exists,
42 F,
43 OuterRef,
44 Ref,
45 ResolvableExpression,
46 ResolvedOuterRef,
47 Value,
48)
49from plain.postgres.fields import Field
50from plain.postgres.fields.base import ColumnField
51from plain.postgres.lookups import Lookup
52from plain.postgres.query_utils import (
53 PathInfo,
54 Q,
55 check_rel_lookup_compatibility,
56 refs_expression,
57)
58from plain.postgres.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
59from plain.postgres.sql.datastructures import BaseTable, Empty, Join, MultiJoin
60from plain.postgres.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
61from plain.utils.regex_helper import _lazy_re_compile
62from plain.utils.tree import Node
63
64if TYPE_CHECKING:
65 from plain.postgres import Model
66 from plain.postgres.connection import DatabaseConnection
67 from plain.postgres.fields.related import RelatedField
68 from plain.postgres.fields.reverse_related import ForeignObjectRel
69 from plain.postgres.meta import Meta
70 from plain.postgres.sql.compiler import (
71 SQLAggregateCompiler,
72 SQLCompiler,
73 SQLDeleteCompiler,
74 SQLInsertCompiler,
75 SQLUpdateCompiler,
76 SqlWithParams,
77 )
78
# Public query classes exported by ``from ... import *``.
__all__ = [
    "Query",
    "RawQuery",
    "DeleteQuery",
    "UpdateQuery",
    "InsertQuery",
    "AggregateQuery",
]


# Quotation marks ('"`[]), whitespace characters, semicolons, or inline
# SQL comments are forbidden in column aliases.
FORBIDDEN_ALIAS_PATTERN = _lazy_re_compile(r"['`\"\]\[;\s]|--|/\*|\*/")

# Inspired by
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
# Query.explain() requires every option name to fully match this pattern
# (word characters and hyphens only) and additionally rejects "--".
EXPLAIN_OPTIONS_PATTERN = _lazy_re_compile(r"[\w\-]+")
96
97
def get_field_names_from_opts(meta: Meta | None) -> set[str]:
    """
    Return every name by which the fields of ``meta`` can be referenced.

    Each field contributes its ``name``; concrete fields also contribute
    their ``attname``. A ``None`` meta yields an empty set.
    """
    names: set[str] = set()
    if meta is None:
        return names
    for field in meta.get_fields():
        names.add(field.name)
        if field.concrete:
            names.add(field.attname)
    return names
106
107
def get_children_from_q(q: Q) -> TypingIterator[tuple[str, Any]]:
    """
    Yield the leaf children of ``q``, flattening nested nodes depth first.

    Any child that is itself a ``Node`` is descended into recursively, so only
    non-node children are produced, in tree order.
    """
    for child in q.children:
        if not isinstance(child, Node):
            yield child
            continue
        yield from get_children_from_q(child)
114
115
class JoinInfo(NamedTuple):
    """Information about a join operation in a query.

    As a NamedTuple, field order is part of the interface: instances may be
    unpacked positionally by callers.
    """

    # Presumably the last field reached along the join path — confirm.
    final_field: Field[Any]
    # Presumably the target field(s) the lookup resolves to — confirm.
    targets: tuple[Field[Any], ...]
    # Presumably the Meta of the model reached by the joins — confirm.
    meta: Meta
    # Table aliases of the joins involved (assumed to be in path order).
    joins: list[str]
    # PathInfo entries describing the traversed relations.
    path: list[PathInfo]
    # Callable taking (field, alias) and returning an expression; same call
    # signature as TransformWrapper.__call__.
    transform_function: Callable[[Field[Any], str | None], BaseExpression]
125
126
class RawQuery:
    """A single raw SQL query, executed on demand."""

    def __init__(self, sql: str, params: tuple[Any, ...] | dict[str, Any] = ()):
        self.params = params
        self.sql = sql
        # Cursor of the most recent execution; None until the query runs.
        self.cursor: Any = None

        # The compiler is reused to process results, so expose the handful of
        # attributes it expects from a regular Query.
        self.low_mark, self.high_mark = 0, None  # Used for offset/limit
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self) -> RawQuery:
        """Return a copy ready for another operation (identical to clone())."""
        return self.clone()

    def clone(self) -> RawQuery:
        """Return a new, not-yet-executed RawQuery with the same SQL and params."""
        return RawQuery(self.sql, params=self.params)

    def get_columns(self) -> list[str]:
        """Return the result column names, running the query first if needed."""
        if self.cursor is None:
            self._execute_query()
        return [description[0] for description in self.cursor.description]

    def __iter__(self) -> TypingIterator[Any]:
        # Every iteration re-executes the query; nothing is cached between
        # iterations (a cache would trade RAM for fewer round trips).
        self._execute_query()
        return iter(self.cursor)

    def __repr__(self) -> str:
        return f"<{type(self).__name__}: {self}>"

    @property
    def params_type(self) -> type[dict] | type[tuple] | None:
        """Container type used to interpolate params, or None if params is None."""
        params = self.params
        if params is None:
            return None
        if isinstance(params, Mapping):
            return dict
        return tuple

    def __str__(self) -> str:
        container = self.params_type
        return self.sql if container is None else self.sql % container(self.params)

    def _execute_query(self) -> None:
        # The new cursor is stored before executing, exactly as callers expect.
        self.cursor = get_connection().cursor()
        self.cursor.execute(self.sql, self.params)
175
176
class ExplainInfo(NamedTuple):
    """Information about an EXPLAIN query.

    Built by Query.explain() from its ``format`` argument and keyword options.
    """

    # Requested output format, or None when explain() was given no format.
    format: str | None
    # Keyword options passed to explain(); their names have already been
    # validated against EXPLAIN_OPTIONS_PATTERN.
    options: dict[str, Any]
182
183
class TransformWrapper:
    """Callable that pre-binds keyword arguments to a transform function.

    Used instead of a bare ``functools.partial`` so that the ``has_transforms``
    flag is a declared, typed attribute that can be assigned after
    construction.
    """

    def __init__(
        self,
        func: Callable[..., BaseExpression],
        **kwargs: Any,
    ):
        # Initialised to False; left for callers to update.
        self.has_transforms: bool = False
        self._partial = functools.partial(func, **kwargs)

    def __call__(self, field: Field[Any], alias: str | None) -> BaseExpression:
        bound = self._partial
        return bound(field, alias)
201
202
# Type variable for Query.chain()'s overloads: chaining with a Query subclass
# as ``klass`` is typed as returning an instance of that subclass.
QueryType = TypeVar("QueryType", bound="Query")
204
205
class Query(BaseExpression):
    """A single SQL query."""

    # Prefix for generated table aliases, e.g. "T3" (see table_alias()).
    alias_prefix = "T"
    # Presumably the value this query yields when used as an expression that
    # is known to match nothing — confirm against BaseExpression.
    empty_result_set_value = None
    # Alias prefixes known to this query; combine() unions in the rhs's set.
    subq_aliases = frozenset([alias_prefix])

    # Classes used to represent the base table and joined tables.
    base_table_class = BaseTable
    join_class = Join

    # True while the SELECT uses the model's default columns; get_aggregation()
    # clears it when the select list is built explicitly.
    default_cols = True
    default_ordering = True
    standard_ordering = True

    # If True, the next chain() keeps used_aliases instead of resetting them;
    # chain() always resets this flag to False afterwards.
    filter_is_sticky = False
    # True for a query compiled as a subquery (get_aggregation() sets it on
    # its inner query).
    subquery = False

    # SQL-related attributes.
    # Select and related select clauses are expressions to use in the SELECT
    # clause of the query. The select is used for cases where we want to set up
    # the select clause to contain other than default fields (values(),
    # subqueries...). Note that annotations go to the annotations dictionary.
    select: tuple[BaseExpression, ...] = ()
    # The group_by attribute can have one of the following forms:
    #  - None: no group by at all in the query
    #  - A tuple of expressions: group by (at least) those expressions.
    #    String refs are also allowed for now.
    #  - True: group by all select fields of the model
    # See compiler.get_group_by() for details.
    group_by = None
    order_by = ()
    low_mark = 0  # Used for offset/limit.
    high_mark = None  # Used for offset/limit.
    distinct = False
    distinct_fields: tuple[str, ...] = ()
    # Row-locking options (SELECT ... FOR UPDATE and its variants).
    select_for_update = False
    select_for_update_nowait = False
    select_for_update_skip_locked = False
    select_for_update_of: tuple[str, ...] = ()
    select_for_no_key_update = False
    # False, or the relations to follow; clone() deep-copies it because the
    # related fields are stored in nested dicts.
    select_related: bool | dict[str, Any] = False
    has_select_fields = False
    # Arbitrary limit for select_related to prevent infinite recursion.
    max_depth = 5
    # Holds the selects defined by a call to values() or values_list()
    # excluding annotation_select and extra_select.
    values_select: tuple[str, ...] = ()

    # SQL annotation-related attributes.
    # Presumably the set of annotation names to select (None = no mask).
    annotation_select_mask = None
    # Cache backing annotation_select; clone() resets it to None.
    _annotation_select_cache = None

    # These are for extensions. The contents are more or less appended verbatim
    # to the appropriate clause.
    extra_select_mask = None
    _extra_select_cache = None

    extra_tables: tuple[str, ...] = ()
    extra_order_by: tuple[str, ...] = ()

    # A pair (field_names, defer). field_names is a set of model field names
    # (which may be LOOKUP_SEP-separated paths). If defer is True, those fields
    # are deferred; if False, they are the only fields to load. See
    # get_select_mask().
    deferred_loading = (frozenset(), True)

    # Set to an ExplainInfo by explain(); None for ordinary queries.
    explain_info = None
271
272 def __init__(self, model: type[Model] | None, alias_cols: bool = True):
273 self.model = model
274 self.alias_refcount = {}
275 # alias_map is the most important data structure regarding joins.
276 # It's used for recording which joins exist in the query and what
277 # types they are. The key is the alias of the joined table (possibly
278 # the table name) and the value is a Join-like object (see
279 # sql.datastructures.Join for more information).
280 self.alias_map = {}
281 # Whether to provide alias to columns during reference resolving.
282 self.alias_cols = alias_cols
283 # Sometimes the query contains references to aliases in outer queries (as
284 # a result of split_exclude). Correct alias quoting needs to know these
285 # aliases too.
286 # Map external tables to whether they are aliased.
287 self.external_aliases = {}
288 self.table_map = {} # Maps table names to list of aliases.
289 self.used_aliases = set()
290
291 self.where = WhereNode()
292 # Maps alias -> Annotation Expression.
293 self.annotations = {}
294 # These are for extensions. The contents are more or less appended
295 # verbatim to the appropriate clause.
296 self.extra = {} # Maps col_alias -> (col_sql, params).
297
298 self._filtered_relations = {}
299
300 @property
301 def output_field(self) -> Field | None:
302 if len(self.select) == 1:
303 select = self.select[0]
304 return getattr(select, "target", None) or select.field
305 elif len(self.annotation_select) == 1:
306 return next(iter(self.annotation_select.values())).output_field
307
308 @cached_property
309 def base_table(self) -> str | None:
310 for alias in self.alias_map:
311 return alias
312
313 def __str__(self) -> str:
314 """
315 Return the query as a string of SQL with the parameter values
316 substituted in (use sql_with_params() to see the unsubstituted string).
317
318 Parameter values won't necessarily be quoted correctly, since that is
319 done by the database interface at execution time.
320 """
321 sql, params = self.sql_with_params()
322 return sql % params
323
324 def sql_with_params(self) -> SqlWithParams:
325 """
326 Return the query as an SQL string and the parameters that will be
327 substituted into the query.
328 """
329 return self.get_compiler().as_sql()
330
331 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
332 """Limit the amount of work when a Query is deepcopied."""
333 result = self.clone()
334 memo[id(self)] = result
335 return result
336
337 def get_compiler(self, *, elide_empty: bool = True) -> SQLCompiler:
338 """Return a compiler instance for this query."""
339 # Import compilers here to avoid circular imports at module load time
340 from plain.postgres.sql.compiler import SQLCompiler as Compiler
341
342 return Compiler(self, get_connection(), elide_empty)
343
344 def clone(self) -> Self:
345 """
346 Return a copy of the current Query. A lightweight alternative to
347 deepcopy().
348 """
349 obj = Empty()
350 obj.__class__ = self.__class__
351 obj = cast(Self, obj) # Type checker doesn't understand __class__ reassignment
352 # Copy references to everything.
353 obj.__dict__ = self.__dict__.copy()
354 # Clone attributes that can't use shallow copy.
355 obj.alias_refcount = self.alias_refcount.copy()
356 obj.alias_map = self.alias_map.copy()
357 obj.external_aliases = self.external_aliases.copy()
358 obj.table_map = self.table_map.copy()
359 obj.where = self.where.clone()
360 obj.annotations = self.annotations.copy()
361 if self.annotation_select_mask is not None:
362 obj.annotation_select_mask = self.annotation_select_mask.copy()
363 # _annotation_select_cache cannot be copied, as doing so breaks the
364 # (necessary) state in which both annotations and
365 # _annotation_select_cache point to the same underlying objects.
366 # It will get re-populated in the cloned queryset the next time it's
367 # used.
368 obj._annotation_select_cache = None
369 obj.extra = self.extra.copy()
370 if self.extra_select_mask is not None:
371 obj.extra_select_mask = self.extra_select_mask.copy()
372 if self._extra_select_cache is not None:
373 obj._extra_select_cache = self._extra_select_cache.copy()
374 if self.select_related is not False:
375 # Use deepcopy because select_related stores fields in nested
376 # dicts.
377 obj.select_related = copy.deepcopy(obj.select_related)
378 if "subq_aliases" in self.__dict__:
379 obj.subq_aliases = self.subq_aliases.copy()
380 obj.used_aliases = self.used_aliases.copy()
381 obj._filtered_relations = self._filtered_relations.copy()
382 # Clear the cached_property, if it exists.
383 obj.__dict__.pop("base_table", None)
384 return obj
385
386 @overload
387 def chain(self, klass: None = None) -> Self: ...
388
389 @overload
390 def chain(self, klass: type[QueryType]) -> QueryType: ...
391
392 def chain(self, klass: type[Query] | None = None) -> Query:
393 """
394 Return a copy of the current Query that's ready for another operation.
395 The klass argument changes the type of the Query, e.g. UpdateQuery.
396 """
397 obj = self.clone()
398 if klass and obj.__class__ != klass:
399 obj.__class__ = klass
400 if not obj.filter_is_sticky:
401 obj.used_aliases = set()
402 obj.filter_is_sticky = False
403 if hasattr(obj, "_setup_query"):
404 obj._setup_query() # ty: ignore[call-non-callable]
405 return obj
406
407 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
408 clone = self.clone()
409 clone.change_aliases(change_map)
410 return clone
411
412 def _get_col(self, target: Any, field: Field, alias: str | None) -> Col:
413 if not self.alias_cols:
414 alias = None
415 return target.get_col(alias, field)
416
    def get_aggregation(self, aggregate_exprs: dict[str, Any]) -> dict[str, Any]:
        """
        Return the dictionary with the values of the existing aggregations.

        ``aggregate_exprs`` maps result aliases to expressions. Each one is
        resolved against this query with ``summarize=True`` and must contain
        an aggregate; otherwise TypeError is raised. An empty mapping returns
        ``{}`` without running a query.

        When the query groups, is sliced, has aggregating annotations or
        HAVING/QUALIFY filters, or is DISTINCT, the aggregation runs over a
        clone wrapped in an AggregateQuery. Otherwise this query itself is
        rewritten in place (select, extra and annotations are replaced), which
        is why callers such as get_count() pass a clone.
        """
        if not aggregate_exprs:
            return {}
        aggregates = {}
        for alias, aggregate_expr in aggregate_exprs.items():
            self.check_alias(alias)
            aggregate = aggregate_expr.resolve_expression(
                self, allow_joins=True, reuse=None, summarize=True
            )
            if not aggregate.contains_aggregate:
                raise TypeError(f"{alias} is not an aggregate expression")
            aggregates[alias] = aggregate
        # Existing usage of aggregation can be determined by the presence of
        # selected aggregates but also by filters against aliased aggregates.
        # Annotations without a contains_aggregate attribute are treated as
        # aggregating (getattr default True).
        _, having, qualify = self.where.split_having_qualify()
        has_existing_aggregation = (
            any(
                getattr(annotation, "contains_aggregate", True)
                for annotation in self.annotations.values()
            )
            or having
        )
        # Decide if we need to use a subquery.
        #
        # Existing aggregations would cause incorrect results as
        # get_aggregation() must produce just one result and thus must not use
        # GROUP BY.
        #
        # If the query has a limit or distinct, those operations must be done
        # in a subquery so that the query aggregates on the limited and/or
        # distinct results instead of applying the distinct and limit after
        # the aggregation.
        if (
            isinstance(self.group_by, tuple)
            or self.is_sliced
            or has_existing_aggregation
            or qualify
            or self.distinct
        ):
            inner_query = self.clone()
            inner_query.subquery = True
            outer_query = AggregateQuery(self.model, inner_query)
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.set_annotation_mask(self.annotation_select)
            # Queries with distinct_fields need ordering and when a limit is
            # applied we must take the slice from the ordered query. Otherwise
            # no need for ordering.
            inner_query.clear_ordering(force=False)
            if not inner_query.distinct:
                # If the inner query uses default select and it has some
                # aggregate annotations, then we must make sure the inner
                # query is grouped by the main model's primary key. However,
                # clearing the select clause can alter results if distinct is
                # used.
                if inner_query.default_cols and has_existing_aggregation:
                    assert self.model is not None, "Aggregation requires a model"
                    inner_query.group_by = (
                        self.model._model_meta.get_forward_field("id").get_col(
                            inner_query.get_initial_alias()
                        ),
                    )
                inner_query.default_cols = False
                if not qualify:
                    # Mask existing annotations that are not referenced by
                    # aggregates to be pushed to the outer query unless
                    # filtering against window functions is involved as it
                    # requires complex re-aliasing.
                    # NOTE(review): only refs used by the aggregates are kept.
                    # If self.group_by is a tuple that references annotations,
                    # those annotations may be masked out of the inner query —
                    # confirm the compiler copes with that.
                    annotation_mask = set()
                    for aggregate in aggregates.values():
                        annotation_mask |= aggregate.get_refs()
                    inner_query.set_annotation_mask(annotation_mask)

            # Add aggregates to the outer AggregateQuery. This requires making
            # sure all columns referenced by the aggregates are selected in the
            # inner query. It is achieved by retrieving all column references
            # by the aggregates, explicitly selecting them in the inner query,
            # and making sure the aggregates are repointed to them.
            col_refs = {}
            for alias, aggregate in aggregates.items():
                replacements = {}
                for col in self._gen_cols([aggregate], resolve_refs=False):
                    # Each distinct column is selected once in the inner query
                    # as "__colN" and shared by every aggregate that uses it.
                    if not (col_ref := col_refs.get(col)):
                        index = len(col_refs) + 1
                        col_alias = f"__col{index}"
                        col_ref = Ref(col_alias, col)
                        col_refs[col] = col_ref
                        inner_query.annotations[col_alias] = col
                        inner_query.append_annotation_mask([col_alias])
                    replacements[col] = col_ref
                outer_query.annotations[alias] = aggregate.replace_expressions(
                    replacements
                )
            if (
                inner_query.select == ()
                and not inner_query.default_cols
                and not inner_query.annotation_select_mask
            ):
                # In case of Model.objects[0:3].count(), there would be no
                # field selected in the inner query, yet we must use a subquery.
                # So, make sure at least one field is selected.
                assert self.model is not None, "Count with slicing requires a model"
                inner_query.select = (
                    self.model._model_meta.get_forward_field("id").get_col(
                        inner_query.get_initial_alias()
                    ),
                )
        else:
            # No subquery needed: rewrite this query in place so that it
            # selects only the requested aggregates.
            outer_query = self
            self.select = ()
            self.default_cols = False
            self.extra = {}
            if self.annotations:
                # Inline reference to existing annotations and mask them as
                # they are unnecessary given only the summarized aggregations
                # are requested.
                replacements = {
                    Ref(alias, annotation): annotation
                    for alias, annotation in self.annotations.items()
                }
                self.annotations = {
                    alias: aggregate.replace_expressions(replacements)
                    for alias, aggregate in aggregates.items()
                }
            else:
                self.annotations = aggregates
            self.set_annotation_mask(aggregates)

        # Fallback values returned when the query yields no row. If any of
        # them is NotImplemented, empty-result elision is disabled.
        empty_set_result = [
            expression.empty_result_set_value
            for expression in outer_query.annotation_select.values()
        ]
        elide_empty = not any(result is NotImplemented for result in empty_set_result)
        outer_query.clear_ordering(force=True)
        outer_query.clear_limits()
        outer_query.select_for_update = False
        outer_query.select_related = False
        compiler = outer_query.get_compiler(elide_empty=elide_empty)
        result = compiler.execute_sql(SINGLE)
        if result is None:
            result = empty_set_result
        else:
            from plain.postgres.sql.compiler import apply_converters, get_converters

            converters = get_converters(
                outer_query.annotation_select.values(), compiler.connection
            )
            result = next(apply_converters((result,), converters, compiler.connection))

        return dict(zip(outer_query.annotation_select, result))
570
571 def get_count(self) -> int:
572 """
573 Perform a COUNT() query using the current filter constraints.
574 """
575 obj = self.clone()
576 return obj.get_aggregation({"__count": Count("*")})["__count"]
577
578 def has_filters(self) -> bool:
579 return bool(self.where)
580
581 def exists(self, limit: bool = True) -> Self:
582 q = self.clone()
583 if not (q.distinct and q.is_sliced):
584 if q.group_by is True:
585 assert self.model is not None, "GROUP BY requires a model"
586 q.add_fields(
587 (f.attname for f in self.model._model_meta.concrete_fields), False
588 )
589 # Disable GROUP BY aliases to avoid orphaning references to the
590 # SELECT clause which is about to be cleared.
591 q.set_group_by(allow_aliases=False)
592 q.clear_select_clause()
593 q.clear_ordering(force=True)
594 if limit:
595 q.set_limits(high=1)
596 q.add_annotation(Value(1), "a")
597 return q
598
599 def has_results(self) -> bool:
600 q = self.exists()
601 compiler = q.get_compiler()
602 return compiler.has_results()
603
604 def explain(self, format: str | None = None, **options: Any) -> str:
605 q = self.clone()
606 for option_name in options:
607 if (
608 not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name)
609 or "--" in option_name
610 ):
611 raise ValueError(f"Invalid option name: {option_name!r}.")
612 q.explain_info = ExplainInfo(format, options)
613 compiler = q.get_compiler()
614 return "\n".join(compiler.explain_query())
615
    def combine(self, rhs: Query, connector: str) -> None:
        """
        Merge the 'rhs' query into the current one (with any 'rhs' effects
        being applied *after* (that is, "to the right of") anything in the
        current query).

        The 'connector' parameter (AND or OR) describes how to connect filters
        from the 'rhs' query.

        Raises TypeError if the queries have different models, if this query
        is sliced, or if their distinct/distinct_fields settings differ.
        Raises ValueError when OR-combining queries that both have extra().

        NOTE(review): this was documented as not modifying 'rhs', but
        rhs.bump_prefix() below is called only for its side effect, so rhs's
        alias prefix (and presumably its aliases) change in place. Confirm,
        and avoid reusing rhs afterwards if so.
        """
        if self.model != rhs.model:
            raise TypeError("Cannot combine queries on two different base models.")
        if self.is_sliced:
            raise TypeError("Cannot combine queries once a slice has been taken.")
        if self.distinct != rhs.distinct:
            raise TypeError("Cannot combine a unique query with a non-unique query.")
        if self.distinct_fields != rhs.distinct_fields:
            raise TypeError("Cannot combine queries with different distinct fields.")

        # If lhs and rhs share the same alias prefix, it is possible to have
        # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
        # as T4 -> T6 while combining two querysets. To prevent this, change an
        # alias prefix of the rhs and update current aliases accordingly,
        # except if the alias is the base table since it must be present in the
        # query on both sides.
        initial_alias = self.get_initial_alias()
        assert initial_alias is not None
        rhs.bump_prefix(self, exclude={initial_alias})

        # Work out how to relabel the rhs aliases, if necessary.
        change_map = {}
        conjunction = connector == AND

        # Determine which existing joins can be reused. When combining the
        # query with AND we must recreate all joins for m2m filters. When
        # combining with OR we can reuse joins. The reason is that in AND
        # case a single row can't fulfill a condition like:
        #     revrel__col=1 & revrel__col=2
        # But, there might be two different related rows matching this
        # condition. In OR case a single True is enough, so single row is
        # enough, too.
        #
        # Note that we will be creating duplicate joins for non-m2m joins in
        # the AND case. The results will be correct but this creates too many
        # joins. This is something that could be fixed later on.
        reuse = set() if conjunction else set(self.alias_map)
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == INNER
        )
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        rhs_tables = list(rhs.alias_map)[1:]
        for alias in rhs_tables:
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
            # combined query must have two joins, too.
            reuse.discard(new_alias)
            if alias != new_alias:
                change_map[alias] = new_alias
            if not rhs.alias_refcount[alias]:
                # The alias was unused in the rhs query. Unref it so that it
                # will be unused in the new query, too. We have to add and
                # unref the alias so that join promotion has information of
                # the join type for the unused alias.
                self.unref_alias(new_alias)
        joinpromoter.add_votes(rhs_votes)
        joinpromoter.update_join_types(self)

        # Combine subqueries aliases to ensure aliases relabelling properly
        # handle subqueries when combining where and select clauses.
        self.subq_aliases |= rhs.subq_aliases

        # Now relabel a copy of the rhs where-clause and add it to the current
        # one.
        w = rhs.where.clone()
        w.relabel_aliases(change_map)
        self.where.add(w, connector)

        # Selection columns and extra extensions are those provided by 'rhs'.
        if rhs.select:
            self.set_select([col.relabeled_clone(change_map) for col in rhs.select])
        else:
            self.select = ()

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra and rhs.extra:
                raise ValueError(
                    "When merging querysets using 'or', you cannot have "
                    "extra(select=...) on both sides."
                )
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables

        # Ordering uses the 'rhs' ordering, unless it has none, in which case
        # the current ordering is used.
        self.order_by = rhs.order_by or self.order_by
        self.extra_order_by = rhs.extra_order_by or self.extra_order_by
731
732 def _get_defer_select_mask(
733 self,
734 meta: Meta,
735 mask: dict[str, Any],
736 select_mask: dict[Any, Any] | None = None,
737 ) -> dict[Any, Any]:
738 from plain.postgres.fields.related import RelatedField
739
740 if select_mask is None:
741 select_mask = {}
742 select_mask[meta.get_forward_field("id")] = {}
743 # All concrete fields that are not part of the defer mask must be
744 # loaded. If a relational field is encountered it gets added to the
745 # mask for it be considered if `select_related` and the cycle continues
746 # by recursively calling this function.
747 for field in meta.concrete_fields:
748 field_mask = mask.pop(field.name, None)
749 if field_mask is None:
750 select_mask.setdefault(field, {})
751 elif field_mask:
752 if not isinstance(field, RelatedField):
753 raise FieldError(next(iter(field_mask)))
754 field_select_mask = select_mask.setdefault(field, {})
755 related_model = field.remote_field.model
756 self._get_defer_select_mask(
757 related_model._model_meta, field_mask, field_select_mask
758 )
759 # Remaining defer entries must be references to reverse relationships.
760 # The following code is expected to raise FieldError if it encounters
761 # a malformed defer entry.
762 for field_name, field_mask in mask.items():
763 if filtered_relation := self._filtered_relations.get(field_name):
764 relation = meta.get_reverse_relation(filtered_relation.relation_name)
765 field_select_mask = select_mask.setdefault((field_name, relation), {})
766 field = relation.field
767 else:
768 field = meta.get_reverse_relation(field_name).field
769 field_select_mask = select_mask.setdefault(field, {})
770 related_model = field.model
771 self._get_defer_select_mask(
772 related_model._model_meta, field_mask, field_select_mask
773 )
774 return select_mask
775
776 def _get_only_select_mask(
777 self,
778 meta: Meta,
779 mask: dict[str, Any],
780 select_mask: dict[Any, Any] | None = None,
781 ) -> dict[Any, Any]:
782 from plain.postgres.fields.related import RelatedField
783
784 if select_mask is None:
785 select_mask = {}
786 select_mask[meta.get_forward_field("id")] = {}
787 # Only include fields mentioned in the mask.
788 for field_name, field_mask in mask.items():
789 field = meta.get_field(field_name)
790 field_select_mask = select_mask.setdefault(field, {})
791 if field_mask:
792 if not isinstance(field, RelatedField):
793 raise FieldError(next(iter(field_mask)))
794 related_model = field.remote_field.model
795 self._get_only_select_mask(
796 related_model._model_meta, field_mask, field_select_mask
797 )
798 return select_mask
799
800 def get_select_mask(self) -> dict[Any, Any]:
801 """
802 Convert the self.deferred_loading data structure to an alternate data
803 structure, describing the field that *will* be loaded. This is used to
804 compute the columns to select from the database and also by the
805 QuerySet class to work out which fields are being initialized on each
806 model. Models that have all their fields included aren't mentioned in
807 the result, only those that have field restrictions in place.
808 """
809 field_names, defer = self.deferred_loading
810 if not field_names:
811 return {}
812 mask = {}
813 for field_name in field_names:
814 part_mask = mask
815 for part in field_name.split(LOOKUP_SEP):
816 part_mask = part_mask.setdefault(part, {})
817 assert self.model is not None, "Deferred/only field loading requires a model"
818 meta = self.model._model_meta
819 if defer:
820 return self._get_defer_select_mask(meta, mask)
821 return self._get_only_select_mask(meta, mask)
822
823 def table_alias(
824 self, table_name: str, create: bool = False, filtered_relation: Any = None
825 ) -> tuple[str, bool]:
826 """
827 Return a table alias for the given table_name and whether this is a
828 new alias or not.
829
830 If 'create' is true, a new alias is always created. Otherwise, the
831 most recently created alias for the table (if one exists) is reused.
832 """
833 alias_list = self.table_map.get(table_name)
834 if not create and alias_list:
835 alias = alias_list[0]
836 self.alias_refcount[alias] += 1
837 return alias, False
838
839 # Create a new alias for this table.
840 if alias_list:
841 alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1) # noqa: UP031
842 alias_list.append(alias)
843 else:
844 # The first occurrence of a table uses the table name directly.
845 alias = (
846 filtered_relation.alias if filtered_relation is not None else table_name
847 )
848 self.table_map[table_name] = [alias]
849 self.alias_refcount[alias] = 1
850 return alias, True
851
852 def ref_alias(self, alias: str) -> None:
853 """Increases the reference count for this alias."""
854 self.alias_refcount[alias] += 1
855
856 def unref_alias(self, alias: str, amount: int = 1) -> None:
857 """Decreases the reference count for this alias."""
858 self.alias_refcount[alias] -= amount
859
860 def promote_joins(self, aliases: set[str] | list[str]) -> None:
861 """
862 Promote recursively the join type of given aliases and its children to
863 an outer join. If 'unconditional' is False, only promote the join if
864 it is nullable or the parent join is an outer join.
865
866 The children promotion is done to avoid join chains that contain a LOUTER
867 b INNER c. So, if we have currently a INNER b INNER c and a->b is promoted,
868 then we must also promote b->c automatically, or otherwise the promotion
869 of a->b doesn't actually change anything in the query results.
870 """
871 aliases = list(aliases)
872 while aliases:
873 alias = aliases.pop(0)
874 if self.alias_map[alias].join_type is None:
875 # This is the base table (first FROM entry) - this table
876 # isn't really joined at all in the query, so we should not
877 # alter its join type.
878 continue
879 # Only the first alias (skipped above) should have None join_type
880 assert self.alias_map[alias].join_type is not None
881 parent_alias = self.alias_map[alias].parent_alias
882 parent_louter = (
883 parent_alias and self.alias_map[parent_alias].join_type == LOUTER
884 )
885 already_louter = self.alias_map[alias].join_type == LOUTER
886 if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
887 self.alias_map[alias] = self.alias_map[alias].promote()
888 # Join type of 'alias' changed, so re-examine all aliases that
889 # refer to this one.
890 aliases.extend(
891 join
892 for join in self.alias_map
893 if self.alias_map[join].parent_alias == alias
894 and join not in aliases
895 )
896
    def demote_joins(self, aliases: set[str] | list[str]) -> None:
        """
        Change join type from LOUTER to INNER for all joins in aliases.

        Aliases whose join type is not LOUTER are left unchanged.

        Similarly to promote_joins(), this method is meant to ensure no join
        chains containing first an outer, then an inner join are generated.
        If we are demoting b->c join in chain a LOUTER b LOUTER c then we must
        demote a->b automatically, or otherwise the demotion of b->c doesn't
        actually change anything in the query results.
        """
        aliases = list(aliases)
        while aliases:
            alias = aliases.pop(0)
            if self.alias_map[alias].join_type == LOUTER:
                self.alias_map[alias] = self.alias_map[alias].demote()
                parent_alias = self.alias_map[alias].parent_alias
                # NOTE(review): the parent is re-queued only when it is
                # already INNER. The LOUTER check at the top of this loop
                # then skips it, so as written a LOUTER parent is never
                # demoted here. That does not match the docstring; confirm
                # whether this should compare against LOUTER before changing.
                if self.alias_map[parent_alias].join_type == INNER:
                    aliases.append(parent_alias)
915
916 def reset_refcounts(self, to_counts: dict[str, int]) -> None:
917 """
918 Reset reference counts for aliases so that they match the value passed
919 in `to_counts`.
920 """
921 for alias, cur_refcount in self.alias_refcount.copy().items():
922 unref_amount = cur_refcount - to_counts.get(alias, 0)
923 self.unref_alias(alias, unref_amount)
924
    def change_aliases(self, change_map: dict[str, str]) -> None:
        """
        Change the aliases in change_map (which maps old-alias -> new-alias),
        relabelling any references to them in select columns and the where
        clause.

        The work happens in three steps:

        - Step 1 relabels expressions: where, group_by (only when it is a
          tuple), select and annotations.
        - Step 2 renames the entries in alias_map, alias_refcount and
          table_map. Old aliases that are absent from alias_map are skipped
          in this step.
        - Finally, external_aliases is rebuilt with renamed keys. Any alias
          that was renamed is marked as aliased (True).

        Asserts that change_map's keys and values do not intersect.
        """
        # If keys and values of change_map were to intersect, an alias might be
        # updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
        # on their order in change_map.
        assert set(change_map).isdisjoint(change_map.values())

        # 1. Update references in "select" (normal columns plus aliases),
        # "group by" and "where".
        self.where.relabel_aliases(change_map)
        if isinstance(self.group_by, tuple):
            self.group_by = tuple(
                [col.relabeled_clone(change_map) for col in self.group_by]
            )
        self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
        # A falsy annotations value (e.g. an empty dict) is kept as-is by `and`.
        self.annotations = self.annotations and {
            key: col.relabeled_clone(change_map)
            for key, col in self.annotations.items()
        }

        # 2. Rename the alias in the internal table/alias datastructures.
        for old_alias, new_alias in change_map.items():
            if old_alias not in self.alias_map:
                continue
            alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
            # The entry is stored under new_alias before old_alias is deleted.
            # Assuming alias_map is an insertion-ordered dict, the renamed
            # entry therefore moves to the end of its iteration order.
            self.alias_map[new_alias] = alias_data
            self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
            del self.alias_refcount[old_alias]
            del self.alias_map[old_alias]

            table_aliases = self.table_map[alias_data.table_name]
            # Replace only the first occurrence of old_alias for this table.
            for pos, alias in enumerate(table_aliases):
                if alias == old_alias:
                    table_aliases[pos] = new_alias
                    break
        self.external_aliases = {
            # Table is aliased or it's being changed and thus is aliased.
            change_map.get(alias, alias): (aliased or alias in change_map)
            for alias, aliased in self.external_aliases.items()
        }
969
    def bump_prefix(
        self, other_query: Query, exclude: set[str] | dict[str, str] | None = None
    ) -> None:
        """
        Change the alias prefix to the next letter in the alphabet in a way
        that the other query's aliases and this query's aliases will not
        conflict. Even tables that previously had no alias will get an alias
        after this call. To prevent changing aliases use the exclude parameter.

        If the two queries already use different prefixes, return immediately
        and change nothing.

        Otherwise:

        - Pick the first candidate prefix not in self.subq_aliases.
        - Record the chosen prefix in both queries' subq_aliases.
        - Rename every alias in alias_map to ``<prefix><position>`` through
          change_aliases(), except aliases found in ``exclude``. Only
          membership in ``exclude`` is checked, so a set or a dict works.

        Raise RecursionError once the candidate index exceeds
        sys.getrecursionlimit() // 16 without finding a free prefix.
        """

        def prefix_gen() -> TypingIterator[str]:
            """
            Generate a sequence of characters in alphabetical order:
                -> 'A', 'B', 'C', ...

            When the alphabet is finished, the sequence will continue with the
            Cartesian product:
                -> 'AA', 'AB', 'AC', ...
            """
            alphabet = ascii_uppercase
            # NOTE(review): ord() needs a one-character string. A
            # multi-character alias_prefix (which this generator can itself
            # yield, e.g. 'AA') would raise TypeError here. For 'Z' the first
            # candidate is '[', and alphabet.index('[') below would raise
            # ValueError if the generator were resumed. Confirm such
            # prefixes cannot reach this code.
            prefix = chr(ord(self.alias_prefix) + 1)
            yield prefix
            for n in count(1):
                # The first round starts at `prefix`; later rounds use the
                # whole alphabet.
                seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
                for s in product(seq, repeat=n):
                    yield "".join(s)
                prefix = None

        if self.alias_prefix != other_query.alias_prefix:
            # No clashes between self and outer query should be possible.
            return

        # Explicitly avoid infinite loop. The constant divider is based on how
        # much depth recursive subquery references add to the stack. This value
        # might need to be adjusted when adding or removing function calls from
        # the code path in charge of performing these operations.
        local_recursion_limit = sys.getrecursionlimit() // 16
        for pos, prefix in enumerate(prefix_gen()):
            if prefix not in self.subq_aliases:
                self.alias_prefix = prefix
                break
            # Reached only after a candidate was rejected.
            if pos > local_recursion_limit:
                raise RecursionError(
                    "Maximum recursion depth exceeded: too many subqueries."
                )
        self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
        other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
        if exclude is None:
            exclude = {}
        self.change_aliases(
            {
                alias: "%s%d" % (self.alias_prefix, pos)  # noqa: UP031
                for pos, alias in enumerate(self.alias_map)
                if alias not in exclude
            }
        )
1026
1027 def get_initial_alias(self) -> str | None:
1028 """
1029 Return the first alias for this query, after increasing its reference
1030 count.
1031 """
1032 if self.alias_map:
1033 alias = self.base_table
1034 self.ref_alias(alias) # ty: ignore[invalid-argument-type]
1035 elif self.model:
1036 alias = self.join(
1037 self.base_table_class(self.model.model_options.db_table, None) # ty: ignore[invalid-argument-type]
1038 )
1039 else:
1040 alias = None
1041 return alias
1042
1043 def count_active_tables(self) -> int:
1044 """
1045 Return the number of tables in this query with a non-zero reference
1046 count. After execution, the reference counts are zeroed, so tables
1047 added in compiler will not be seen by this method.
1048 """
1049 return len([1 for count in self.alias_refcount.values() if count])
1050
1051 def join(
1052 self,
1053 join: BaseTable | Join,
1054 reuse: set[str] | None = None,
1055 reuse_with_filtered_relation: bool = False,
1056 ) -> str:
1057 """
1058 Return an alias for the 'join', either reusing an existing alias for
1059 that join or creating a new one. 'join' is either a base_table_class or
1060 join_class.
1061
1062 The 'reuse' parameter can be either None which means all joins are
1063 reusable, or it can be a set containing the aliases that can be reused.
1064
1065 The 'reuse_with_filtered_relation' parameter is used when computing
1066 FilteredRelation instances.
1067
1068 A join is always created as LOUTER if the lhs alias is LOUTER to make
1069 sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
1070 joins are created as LOUTER if the join is nullable.
1071 """
1072 if reuse_with_filtered_relation and reuse:
1073 reuse_aliases = [
1074 a for a, j in self.alias_map.items() if a in reuse and j.equals(join)
1075 ]
1076 else:
1077 reuse_aliases = [
1078 a
1079 for a, j in self.alias_map.items()
1080 if (reuse is None or a in reuse) and j == join
1081 ]
1082 if reuse_aliases:
1083 if join.table_alias in reuse_aliases:
1084 reuse_alias = join.table_alias
1085 else:
1086 # Reuse the most recent alias of the joined table
1087 # (a many-to-many relation may be joined multiple times).
1088 reuse_alias = reuse_aliases[-1]
1089 self.ref_alias(reuse_alias)
1090 return reuse_alias
1091
1092 # No reuse is possible, so we need a new alias.
1093 alias, _ = self.table_alias(
1094 join.table_name, create=True, filtered_relation=join.filtered_relation
1095 )
1096 if isinstance(join, Join):
1097 if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
1098 join_type = LOUTER
1099 else:
1100 join_type = INNER
1101 join.join_type = join_type
1102 join.table_alias = alias
1103 self.alias_map[alias] = join
1104 return alias
1105
1106 def check_alias(self, alias: str) -> None:
1107 if FORBIDDEN_ALIAS_PATTERN.search(alias):
1108 raise ValueError(
1109 "Column aliases cannot contain whitespace characters, quotation marks, "
1110 "semicolons, or SQL comments."
1111 )
1112
1113 def add_annotation(
1114 self, annotation: BaseExpression, alias: str, select: bool = True
1115 ) -> None:
1116 """Add a single annotation expression to the Query."""
1117 self.check_alias(alias)
1118 annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
1119 if select:
1120 self.append_annotation_mask([alias])
1121 else:
1122 self.set_annotation_mask(set(self.annotation_select).difference({alias}))
1123 self.annotations[alias] = annotation
1124
    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Return a clone of this query prepared for use as a subquery of ``query``.

        The clone:

        - gets a fresh alias prefix (via bump_prefix());
        - is flagged as a subquery;
        - has its where clause and annotations resolved against ``query``;
        - records every alias of ``query`` in external_aliases.

        NOTE(review): despite the ``None`` default, ``query`` is dereferenced
        (bump_prefix() reads its alias_prefix, and its alias_map is iterated
        below). Callers presumably always pass the outer Query.
        """
        clone = self.clone()
        # Subqueries need to use a different set of aliases than the outer query.
        clone.bump_prefix(query)
        clone.subquery = True
        # NOTE(review): the return value is discarded. This relies on any
        # in-place effects of the where node's resolve_expression(); confirm.
        clone.where.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        for key, value in clone.annotations.items():
            resolved = value.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
            # Resolved annotations that carry their own external_aliases
            # (e.g. nested queries) get this clone's entries merged in.
            if hasattr(resolved, "external_aliases"):
                resolved.external_aliases.update(clone.external_aliases)
            clone.annotations[key] = resolved
        # Outer query's aliases are considered external. The stored flag is
        # True when the outer table is referenced under an alias that differs
        # from its table name.
        for alias, table in query.alias_map.items():
            clone.external_aliases[alias] = (
                isinstance(table, Join)
                and table.join_field.related_model.model_options.db_table != alias
            ) or (
                isinstance(table, BaseTable) and table.table_name != table.table_alias
            )
        return clone
1154
1155 def get_external_cols(self) -> list[Col]:
1156 exprs = chain(self.annotations.values(), self.where.children)
1157 return [
1158 col
1159 for col in self._gen_cols(exprs, include_external=True)
1160 if col.alias in self.external_aliases
1161 ]
1162
1163 def get_group_by_cols(
1164 self, wrapper: BaseExpression | None = None
1165 ) -> list[BaseExpression]:
1166 # If wrapper is referenced by an alias for an explicit GROUP BY through
1167 # values() a reference to this expression and not the self must be
1168 # returned to ensure external column references are not grouped against
1169 # as well.
1170 external_cols = self.get_external_cols()
1171 if any(col.possibly_multivalued for col in external_cols):
1172 return [wrapper or self]
1173 # Cast needed because list is invariant: list[Col] is not list[BaseExpression]
1174 return cast(list[BaseExpression], external_cols)
1175
1176 def as_sql(
1177 self, compiler: SQLCompiler, connection: DatabaseConnection
1178 ) -> SqlWithParams:
1179 sql, params = self.get_compiler().as_sql()
1180 if self.subquery:
1181 sql = f"({sql})"
1182 return sql, params
1183
1184 def resolve_lookup_value(
1185 self, value: Any, can_reuse: set[str] | None, allow_joins: bool
1186 ) -> Any:
1187 if isinstance(value, ResolvableExpression):
1188 value = value.resolve_expression(
1189 self,
1190 reuse=can_reuse,
1191 allow_joins=allow_joins,
1192 )
1193 elif isinstance(value, list | tuple):
1194 # The items of the iterable may be expressions and therefore need
1195 # to be resolved independently.
1196 values = (
1197 self.resolve_lookup_value(sub_value, can_reuse, allow_joins)
1198 for sub_value in value
1199 )
1200 type_ = type(value)
1201 if hasattr(type_, "_make"): # namedtuple
1202 return type_(*values)
1203 return type_(values)
1204 return value
1205
1206 def solve_lookup_type(
1207 self, lookup: str, summarize: bool = False
1208 ) -> tuple[
1209 list[str] | tuple[str, ...], tuple[str, ...], BaseExpression | Literal[False]
1210 ]:
1211 """
1212 Solve the lookup type from the lookup (e.g.: 'foobar__id__icontains').
1213 """
1214 lookup_splitted = lookup.split(LOOKUP_SEP)
1215 if self.annotations:
1216 annotation, expression_lookups = refs_expression(
1217 lookup_splitted, self.annotations
1218 )
1219 if annotation:
1220 expression = self.annotations[annotation]
1221 if summarize:
1222 expression = Ref(annotation, expression)
1223 return expression_lookups, (), expression
1224 assert self.model is not None, "Field lookups require a model"
1225 meta = self.model._model_meta
1226 _, field, _, lookup_parts = self.names_to_path(lookup_splitted, meta)
1227 field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
1228 if len(lookup_parts) > 1 and not field_parts:
1229 raise FieldError(
1230 f'Invalid lookup "{lookup}" for model {meta.model.__name__}".'
1231 )
1232 return lookup_parts, tuple(field_parts), False
1233
1234 def check_query_object_type(
1235 self, value: Any, meta: Meta, field: Field | ForeignObjectRel
1236 ) -> None:
1237 """
1238 Check whether the object passed while querying is of the correct type.
1239 If not, raise a ValueError specifying the wrong object.
1240 """
1241 from plain.postgres import Model
1242
1243 if isinstance(value, Model):
1244 if not check_rel_lookup_compatibility(value._model_meta.model, meta, field):
1245 raise ValueError(
1246 f'Cannot query "{value}": Must be "{meta.model.model_options.object_name}" instance.'
1247 )
1248
1249 def check_related_objects(
1250 self, field: RelatedField | ForeignObjectRel, value: Any, meta: Meta
1251 ) -> None:
1252 """Check the type of object passed to query relations."""
1253 from plain.postgres import Model
1254
1255 # Check that the field and the queryset use the same model in a
1256 # query like .filter(author=Author.query.all()). For example, the
1257 # meta would be Author's (from the author field) and value.model
1258 # would be Author.query.all() queryset's .model (Author also).
1259 # The field is the related field on the lhs side.
1260 if (
1261 isinstance(value, Query)
1262 and not value.has_select_fields
1263 and not check_rel_lookup_compatibility(value.model, meta, field)
1264 ):
1265 raise ValueError(
1266 f'Cannot use QuerySet for "{value.model.model_options.object_name}": Use a QuerySet for "{meta.model.model_options.object_name}".'
1267 )
1268 elif isinstance(value, Model):
1269 self.check_query_object_type(value, meta, field)
1270 elif isinstance(value, Iterable):
1271 for v in value:
1272 self.check_query_object_type(v, meta, field)
1273
1274 def check_filterable(self, expression: Any) -> None:
1275 """Raise an error if expression cannot be used in a WHERE clause."""
1276 if isinstance(expression, ResolvableExpression) and not getattr(
1277 expression, "filterable", True
1278 ):
1279 raise psycopg.NotSupportedError(
1280 expression.__class__.__name__ + " is disallowed in the filter clause."
1281 )
1282 if hasattr(expression, "get_source_expressions"):
1283 for expr in expression.get_source_expressions():
1284 self.check_filterable(expr)
1285
1286 def build_lookup(
1287 self, lookups: list[str], lhs: BaseExpression, rhs: Any
1288 ) -> Lookup | None:
1289 """
1290 Try to extract transforms and lookup from given lhs.
1291
1292 The lhs value is something that works like SQLExpression.
1293 The rhs value is what the lookup is going to compare against.
1294 The lookups is a list of names to extract using get_lookup()
1295 and get_transform().
1296 """
1297 # __exact is the default lookup if one isn't given.
1298 *transforms, lookup_name = lookups or ["exact"]
1299 for name in transforms:
1300 lhs = self.try_transform(lhs, name)
1301 # First try get_lookup() so that the lookup takes precedence if the lhs
1302 # supports both transform and lookup for the name.
1303 lookup_class = lhs.get_lookup(lookup_name)
1304 if not lookup_class:
1305 # A lookup wasn't found. Try to interpret the name as a transform
1306 # and do an Exact lookup against it.
1307 lhs = self.try_transform(lhs, lookup_name)
1308 lookup_name = "exact"
1309 lookup_class = lhs.get_lookup(lookup_name)
1310 if not lookup_class:
1311 return
1312
1313 lookup = lookup_class(lhs, rhs)
1314 # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
1315 # uses of None as a query value unless the lookup supports it.
1316 if lookup.rhs is None and not lookup.can_use_none_as_rhs:
1317 if lookup_name not in ("exact", "iexact"):
1318 raise ValueError("Cannot use None as a query value")
1319 isnull_lookup = lhs.get_lookup("isnull")
1320 assert isnull_lookup is not None
1321 return isnull_lookup(lhs, True)
1322
1323 return lookup
1324
1325 def try_transform(self, lhs: BaseExpression, name: str) -> BaseExpression:
1326 """
1327 Helper method for build_lookup(). Try to fetch and initialize
1328 a transform for name parameter from lhs.
1329 """
1330 transform_class = lhs.get_transform(name)
1331 if transform_class:
1332 return transform_class(lhs)
1333 else:
1334 output_field = lhs.output_field.__class__
1335 suggested_lookups = difflib.get_close_matches(
1336 name, output_field.get_lookups()
1337 )
1338 if suggested_lookups:
1339 suggestion = ", perhaps you meant {}?".format(
1340 " or ".join(suggested_lookups)
1341 )
1342 else:
1343 suggestion = "."
1344 raise FieldError(
1345 f"Unsupported lookup '{name}' for {output_field.__name__} or join on the field not "
1346 f"permitted{suggestion}"
1347 )
1348
1349 def build_filter(
1350 self,
1351 filter_expr: tuple[str, Any] | Q | BaseExpression,
1352 branch_negated: bool = False,
1353 current_negated: bool = False,
1354 can_reuse: set[str] | None = None,
1355 allow_joins: bool = True,
1356 split_subq: bool = True,
1357 reuse_with_filtered_relation: bool = False,
1358 check_filterable: bool = True,
1359 summarize: bool = False,
1360 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1361 from plain.postgres.fields.related import RelatedField
1362
1363 """
1364 Build a WhereNode for a single filter clause but don't add it
1365 to this Query. Query.add_q() will then add this filter to the where
1366 Node.
1367
1368 The 'branch_negated' tells us if the current branch contains any
1369 negations. This will be used to determine if subqueries are needed.
1370
1371 The 'current_negated' is used to determine if the current filter is
1372 negated or not and this will be used to determine if IS NULL filtering
1373 is needed.
1374
1375 The difference between current_negated and branch_negated is that
1376 branch_negated is set on first negation, but current_negated is
1377 flipped for each negation.
1378
1379 Note that add_filter will not do any negating itself, that is done
1380 upper in the code by add_q().
1381
1382 The 'can_reuse' is a set of reusable joins for multijoins.
1383
1384 If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
1385 will be reused.
1386
1387 The method will create a filter clause that can be added to the current
1388 query. However, if the filter isn't added to the query then the caller
1389 is responsible for unreffing the joins used.
1390 """
1391 if isinstance(filter_expr, dict):
1392 raise FieldError("Cannot parse keyword query as dict")
1393 if isinstance(filter_expr, Q):
1394 return self._add_q(
1395 filter_expr,
1396 branch_negated=branch_negated,
1397 current_negated=current_negated,
1398 used_aliases=can_reuse,
1399 allow_joins=allow_joins,
1400 split_subq=split_subq,
1401 check_filterable=check_filterable,
1402 summarize=summarize,
1403 )
1404 if isinstance(filter_expr, ResolvableExpression):
1405 if not getattr(filter_expr, "conditional", False):
1406 raise TypeError("Cannot filter against a non-conditional expression.")
1407 condition = filter_expr.resolve_expression(
1408 self, allow_joins=allow_joins, summarize=summarize
1409 )
1410 if not isinstance(condition, Lookup):
1411 condition = self.build_lookup(["exact"], condition, True)
1412 return WhereNode([condition], connector=AND), set()
1413 if isinstance(filter_expr, BaseExpression):
1414 raise TypeError(f"Unexpected BaseExpression type: {type(filter_expr)}")
1415 arg, value = filter_expr
1416 if not arg:
1417 raise FieldError(f"Cannot parse keyword query {arg!r}")
1418 lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
1419
1420 if check_filterable:
1421 self.check_filterable(reffed_expression)
1422
1423 if not allow_joins and len(parts) > 1:
1424 raise FieldError("Joined field references are not permitted in this query")
1425
1426 pre_joins = self.alias_refcount.copy()
1427 value = self.resolve_lookup_value(value, can_reuse, allow_joins)
1428 used_joins = {
1429 k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
1430 }
1431
1432 if check_filterable:
1433 self.check_filterable(value)
1434
1435 if reffed_expression:
1436 condition = self.build_lookup(list(lookups), reffed_expression, value)
1437 return WhereNode([condition], connector=AND), set()
1438
1439 assert self.model is not None, "Building filters requires a model"
1440 meta = self.model._model_meta
1441 alias = self.get_initial_alias()
1442 assert alias is not None
1443 allow_many = not branch_negated or not split_subq
1444
1445 try:
1446 join_info = self.setup_joins(
1447 list(parts),
1448 meta,
1449 alias,
1450 can_reuse=can_reuse,
1451 allow_many=allow_many,
1452 reuse_with_filtered_relation=reuse_with_filtered_relation,
1453 )
1454
1455 # Prevent iterator from being consumed by check_related_objects()
1456 if isinstance(value, Iterator):
1457 value = list(value)
1458 from plain.postgres.fields.related import RelatedField
1459 from plain.postgres.fields.reverse_related import ForeignObjectRel
1460
1461 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1462 self.check_related_objects(join_info.final_field, value, join_info.meta)
1463
1464 # split_exclude() needs to know which joins were generated for the
1465 # lookup parts
1466 self._lookup_joins = join_info.joins
1467 except MultiJoin as e:
1468 return self.split_exclude(
1469 filter_expr,
1470 can_reuse or set(),
1471 e.names_with_path,
1472 )
1473
1474 # Update used_joins before trimming since they are reused to determine
1475 # which joins could be later promoted to INNER.
1476 used_joins.update(join_info.joins)
1477 targets, alias, join_list = self.trim_joins(
1478 join_info.targets, join_info.joins, join_info.path
1479 )
1480 if can_reuse is not None:
1481 can_reuse.update(join_list)
1482
1483 col = self._get_col(targets[0], join_info.final_field, alias)
1484
1485 condition = self.build_lookup(list(lookups), col, value)
1486 assert condition is not None
1487 lookup_type = condition.lookup_name
1488 clause = WhereNode([condition], connector=AND)
1489
1490 require_outer = (
1491 lookup_type == "isnull" and condition.rhs is True and not current_negated
1492 )
1493 if (
1494 current_negated
1495 and (lookup_type != "isnull" or condition.rhs is False)
1496 and condition.rhs is not None
1497 ):
1498 require_outer = True
1499 if lookup_type != "isnull":
1500 # The condition added here will be SQL like this:
1501 # NOT (col IS NOT NULL), where the first NOT is added in
1502 # upper layers of code. The reason for addition is that if col
1503 # is null, then col != someval will result in SQL "unknown"
1504 # which isn't the same as in Python. The Python None handling
1505 # is wanted, and it can be gotten by
1506 # (col IS NULL OR col != someval)
1507 # <=>
1508 # NOT (col IS NOT NULL AND col = someval).
1509 if (
1510 self.is_nullable(targets[0])
1511 or self.alias_map[join_list[-1]].join_type == LOUTER
1512 ):
1513 lookup_class = targets[0].get_lookup("isnull")
1514 assert lookup_class is not None
1515 col = self._get_col(targets[0], join_info.targets[0], alias)
1516 clause.add(lookup_class(col, False), AND)
1517 # If someval is a nullable column, someval IS NOT NULL is
1518 # added.
1519 if isinstance(value, Col) and self.is_nullable(value.target):
1520 lookup_class = value.target.get_lookup("isnull")
1521 assert lookup_class is not None
1522 clause.add(lookup_class(value, False), AND)
1523 return clause, used_joins if not require_outer else ()
1524
1525 def add_filter(self, filter_lhs: str, filter_rhs: Any) -> None:
1526 self.add_q(Q((filter_lhs, filter_rhs)))
1527
1528 def add_q(self, q_object: Q) -> None:
1529 """
1530 A preprocessor for the internal _add_q(). Responsible for doing final
1531 join promotion.
1532 """
1533 # For join promotion this case is doing an AND for the added q_object
1534 # and existing conditions. So, any existing inner join forces the join
1535 # type to remain inner. Existing outer joins can however be demoted.
1536 # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
1537 # rel_a doesn't produce any rows, then the whole condition must fail.
1538 # So, demotion is OK.
1539 existing_inner = {
1540 a for a in self.alias_map if self.alias_map[a].join_type == INNER
1541 }
1542 clause, _ = self._add_q(q_object, self.used_aliases)
1543 if clause:
1544 self.where.add(clause, AND)
1545 self.demote_joins(existing_inner)
1546
1547 def build_where(
1548 self, filter_expr: tuple[str, Any] | Q | BaseExpression
1549 ) -> WhereNode:
1550 return self.build_filter(filter_expr, allow_joins=False)[0]
1551
1552 def clear_where(self) -> None:
1553 self.where = WhereNode()
1554
1555 def _add_q(
1556 self,
1557 q_object: Q,
1558 used_aliases: set[str] | None,
1559 branch_negated: bool = False,
1560 current_negated: bool = False,
1561 allow_joins: bool = True,
1562 split_subq: bool = True,
1563 check_filterable: bool = True,
1564 summarize: bool = False,
1565 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1566 """Add a Q-object to the current filter."""
1567 connector = q_object.connector
1568 current_negated ^= q_object.negated
1569 branch_negated = branch_negated or q_object.negated
1570 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1571 joinpromoter = JoinPromoter(
1572 q_object.connector, len(q_object.children), current_negated
1573 )
1574 for child in q_object.children:
1575 child_clause, needed_inner = self.build_filter(
1576 child,
1577 can_reuse=used_aliases,
1578 branch_negated=branch_negated,
1579 current_negated=current_negated,
1580 allow_joins=allow_joins,
1581 split_subq=split_subq,
1582 check_filterable=check_filterable,
1583 summarize=summarize,
1584 )
1585 joinpromoter.add_votes(needed_inner)
1586 if child_clause:
1587 target_clause.add(child_clause, connector)
1588 needed_inner = joinpromoter.update_join_types(self)
1589 return target_clause, needed_inner
1590
1591 def build_filtered_relation_q(
1592 self,
1593 q_object: Q,
1594 reuse: set[str],
1595 branch_negated: bool = False,
1596 current_negated: bool = False,
1597 ) -> WhereNode:
1598 """Add a FilteredRelation object to the current filter."""
1599 connector = q_object.connector
1600 current_negated ^= q_object.negated
1601 branch_negated = branch_negated or q_object.negated
1602 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1603 for child in q_object.children:
1604 if isinstance(child, Node):
1605 child_clause = self.build_filtered_relation_q(
1606 child,
1607 reuse=reuse,
1608 branch_negated=branch_negated,
1609 current_negated=current_negated,
1610 )
1611 else:
1612 child_clause, _ = self.build_filter(
1613 child,
1614 can_reuse=reuse,
1615 branch_negated=branch_negated,
1616 current_negated=current_negated,
1617 allow_joins=True,
1618 split_subq=False,
1619 reuse_with_filtered_relation=True,
1620 )
1621 target_clause.add(child_clause, connector)
1622 return target_clause
1623
1624 def add_filtered_relation(self, filtered_relation: Any, alias: str) -> None:
1625 filtered_relation.alias = alias
1626 lookups = dict(get_children_from_q(filtered_relation.condition))
1627 relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
1628 filtered_relation.relation_name
1629 )
1630 if relation_lookup_parts:
1631 raise ValueError(
1632 "FilteredRelation's relation_name cannot contain lookups "
1633 f"(got {filtered_relation.relation_name!r})."
1634 )
1635 for lookup in chain(lookups):
1636 lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
1637 shift = 2 if not lookup_parts else 1
1638 lookup_field_path = lookup_field_parts[:-shift]
1639 for idx, lookup_field_part in enumerate(lookup_field_path):
1640 if len(relation_field_parts) > idx:
1641 if relation_field_parts[idx] != lookup_field_part:
1642 raise ValueError(
1643 "FilteredRelation's condition doesn't support "
1644 f"relations outside the {filtered_relation.relation_name!r} (got {lookup!r})."
1645 )
1646 else:
1647 raise ValueError(
1648 "FilteredRelation's condition doesn't support nested "
1649 f"relations deeper than the relation_name (got {lookup!r} for "
1650 f"{filtered_relation.relation_name!r})."
1651 )
1652 self._filtered_relations[filtered_relation.alias] = filtered_relation
1653
    def names_to_path(
        self,
        names: list[str],
        meta: Meta,
        allow_many: bool = True,
        fail_on_missing: bool = False,
    ) -> tuple[list[Any], Field | ForeignObjectRel, tuple[Field, ...], list[str]]:
        """
        Walk the list of names and turn them into PathInfo tuples. A single
        name in 'names' can generate multiple PathInfos (m2m, for example).

        'names' is the path of names to travel, 'meta' is the Meta we
        start the name resolving from, 'allow_many' is as for setup_joins().
        If fail_on_missing is set to True, then a name that can't be resolved
        will generate a FieldError.

        Each name is tried as a model field first, then as an annotation,
        and then (only for the first name) as a filtered relation.

        Return a tuple ``(path, final_field, targets, remaining_names)``:

        - ``path``: the list of PathInfo tuples.
        - ``final_field``: the final field (the last used join field).
        - ``targets``: fields guaranteed to contain the same value as the
          final field.
        - ``remaining_names``: the names that weren't resolved (likely
          transforms and the final lookup).

        Raise FieldError when the first name cannot be resolved. When
        fail_on_missing is set, also raise FieldError for any unresolved
        name, or for a name that follows a non-relational field. Raise
        MultiJoin when allow_many is False and an m2m PathInfo is met.
        """
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            cur_names_with_path = (name, [])

            field = None
            filtered_relation = None
            try:
                if meta is None:
                    raise FieldDoesNotExist
                field = meta.get_field(name)
            except FieldDoesNotExist:
                if name in self.annotation_select:
                    field = self.annotation_select[name].output_field
                elif name in self._filtered_relations and pos == 0:
                    filtered_relation = self._filtered_relations[name]
                    if LOOKUP_SEP in filtered_relation.relation_name:
                        parts = filtered_relation.relation_name.split(LOOKUP_SEP)
                        filtered_relation_path, field, _, _ = self.names_to_path(
                            parts,
                            meta,
                            allow_many,
                            fail_on_missing,
                        )
                        # Keep all but the last hop; it is re-derived below
                        # via field.get_path_info(filtered_relation).
                        path.extend(filtered_relation_path[:-1])
                    else:
                        field = meta.get_field(filtered_relation.relation_name)
            if field is None:
                # We didn't find the current field, so move position back
                # one step.
                pos -= 1
                if pos == -1 or fail_on_missing:
                    available = sorted(
                        [
                            *get_field_names_from_opts(meta),
                            *self.annotation_select,
                            *self._filtered_relations,
                        ]
                    )
                    raise FieldError(
                        "Cannot resolve keyword '{}' into field. "
                        "Choices are: {}".format(name, ", ".join(available))
                    )
                break

            # Lazy import to avoid circular dependency (runs on every iteration).
            from plain.postgres.fields.related import ForeignKeyField as FK
            from plain.postgres.fields.related import ManyToManyField as M2M
            from plain.postgres.fields.reverse_related import ForeignObjectRel as FORel

            if isinstance(field, FK | M2M | FORel):
                pathinfos: list[PathInfo]
                if filtered_relation:
                    pathinfos = field.get_path_info(filtered_relation)
                else:
                    pathinfos = field.path_infos
                if not allow_many:
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
                            names_with_path.append(cur_names_with_path)
                            raise MultiJoin(pos + 1, names_with_path)
                last = pathinfos[-1]
                path.extend(pathinfos)
                final_field = last.join_field
                # Continue resolving the next name from the related model.
                meta = last.to_meta
                targets = last.target_fields
                cur_names_with_path[1].extend(pathinfos)
                names_with_path.append(cur_names_with_path)
            else:
                # Local non-relational field.
                final_field = field
                targets = (field,)
                if fail_on_missing and pos + 1 != len(names):
                    raise FieldError(
                        f"Cannot resolve keyword {names[pos + 1]!r} into field. Join on '{name}'"
                        " not permitted."
                    )
                break
        # NOTE(review): with an empty `names` list the loop never runs, so
        # `pos`, `final_field` and `targets` are unbound here
        # (UnboundLocalError). Callers presumably never pass an empty list.
        return path, final_field, targets, names[pos + 1 :]
1754
    def setup_joins(
        self,
        names: list[str],
        meta: Meta,
        alias: str,
        can_reuse: set[str] | None = None,
        allow_many: bool = True,
        reuse_with_filtered_relation: bool = False,
    ) -> JoinInfo:
        """
        Compute the necessary table joins for the passage through the fields
        given in 'names'. 'meta' is the Meta for the current model
        (which gives the table we are starting from), 'alias' is the alias for
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
        parameter and force the relation on the given connections.

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.

        Return a JoinInfo built from (final_field, targets, meta, joins, path,
        transform_function):
        - the final field involved in the joins,
        - the target fields (used for any 'where' constraint),
        - the Meta of the last model reached (the input 'meta' when no join
          was needed),
        - the list of join aliases, starting with 'alias',
        - the PathInfos traveled to generate the joins,
        - a transform function that takes a field and alias and is
          equivalent to `field.get_col(alias)` in the simple case but wraps
          field transforms if they were included in names.

        The target field is the field containing the concrete value. Final
        field can be something different, for example foreign key pointing to
        that value. Final field is needed for example in some value
        conversions (convert 'obj' in fk__id=obj to pk val using the foreign
        key field for example).

        Raise FieldError when not even the first name resolves to a field.
        """
        joins = [alias]
        # The transform can't be applied yet, as joins must be trimmed later.
        # To avoid making every caller of this method look up transforms
        # directly, compute transforms here and build a chain of
        # TransformWrapper objects that converts fields to the appropriate
        # wrapped version.

        def _base_transformer(field: Field, alias: str | None) -> Col:
            # Produce a bare column; drop the table alias when this query
            # doesn't qualify columns with aliases.
            if not self.alias_cols:
                alias = None
            return field.get_col(alias)

        final_transformer: TransformWrapper | Callable[[Field, str | None], Col] = (
            _base_transformer
        )

        # Try resolving all the names as fields first. If there's an error,
        # treat trailing names as lookups until a field can be resolved.
        # NOTE(review): if 'names' is empty this loop never runs, leaving
        # 'transforms', 'path', 'final_field' and 'targets' unbound below.
        last_field_exception = None
        for pivot in range(len(names), 0, -1):
            try:
                path, final_field, targets, rest = self.names_to_path(
                    names[:pivot],
                    meta,
                    allow_many,
                    fail_on_missing=True,
                )
            except FieldError as exc:
                if pivot == 1:
                    # The first item cannot be a lookup, so it's safe
                    # to raise the field error here.
                    raise
                else:
                    # Remember the failure for the longer prefix: if the
                    # leftover names later turn out not to be valid
                    # transforms either, this more descriptive error is
                    # re-raised instead (see transform() below).
                    last_field_exception = exc
            else:
                # The transforms are the remaining items that couldn't be
                # resolved into fields.
                transforms = names[pivot:]
                break
        for name in transforms:

            def transform(
                field: Field, alias: str | None, *, name: str, previous: Any
            ) -> BaseExpression:
                # Apply the previous transformer in the chain, then the
                # transform called 'name' on top of it.
                try:
                    wrapped = previous(field, alias)
                    return self.try_transform(wrapped, name)
                except FieldError:
                    # FieldError is raised if the transform doesn't exist.
                    if isinstance(final_field, Field) and last_field_exception:
                        raise last_field_exception
                    else:
                        raise

            # 'name' and 'previous' are bound per iteration through the
            # wrapper, so each closure sees its own transform name.
            final_transformer = TransformWrapper(
                transform, name=name, previous=final_transformer
            )
            final_transformer.has_transforms = True
        # Then, add the path to the query's joins. Note that we can't trim
        # joins at this stage - we will need the information about join type
        # of the trimmed joins.
        for join in path:
            if join.filtered_relation:
                # Clone so recording the join path below doesn't mutate the
                # filtered relation stored on the PathInfo.
                filtered_relation = join.filtered_relation.clone()
                table_alias = filtered_relation.alias
            else:
                filtered_relation = None
                table_alias = None
            meta = join.to_meta
            if join.direct:
                nullable = self.is_nullable(join.join_field)
            else:
                # Reverse joins may have no matching rows on the other side.
                nullable = True
            connection = self.join_class(
                meta.model.model_options.db_table,
                alias,
                table_alias,  # ty: ignore[invalid-argument-type]
                INNER,
                join.join_field,
                nullable,
                filtered_relation=filtered_relation,
            )
            # None lets self.join() reuse any equivalent join; only m2m joins
            # (or forced filtered-relation reuse) are restricted to can_reuse.
            reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
            alias = self.join(
                connection,
                reuse=reuse,
                reuse_with_filtered_relation=reuse_with_filtered_relation,
            )
            joins.append(alias)
            if filtered_relation:
                # Snapshot of the aliases leading up to (and including) this join.
                filtered_relation.path = joins[:]
        return JoinInfo(final_field, targets, meta, joins, path, final_transformer)  # ty: ignore[invalid-argument-type]
1885
1886 def trim_joins(
1887 self, targets: tuple[Field, ...], joins: list[str], path: list[Any]
1888 ) -> tuple[tuple[Field, ...], str, list[str]]:
1889 """
1890 The 'target' parameter is the final field being joined to, 'joins'
1891 is the full list of join aliases. The 'path' contain the PathInfos
1892 used to create the joins.
1893
1894 Return the final target field and table alias and the new active
1895 joins.
1896
1897 Always trim any direct join if the target column is already in the
1898 previous table. Can't trim reverse joins as it's unknown if there's
1899 anything on the other side of the join.
1900 """
1901 joins = joins[:]
1902 for pos, info in enumerate(reversed(path)):
1903 if len(joins) == 1 or not info.direct:
1904 break
1905 if info.filtered_relation:
1906 break
1907 join_targets = {t.column for t in info.join_field.foreign_related_fields}
1908 cur_targets = {t.column for t in targets}
1909 if not cur_targets.issubset(join_targets):
1910 break
1911 targets_dict = {
1912 r[1].column: r[0]
1913 for r in info.join_field.related_fields
1914 if r[1].column in cur_targets
1915 }
1916 targets = tuple(targets_dict[t.column] for t in targets)
1917 self.unref_alias(joins.pop())
1918 return targets, joins[-1], joins
1919
1920 @classmethod
1921 def _gen_cols(
1922 cls,
1923 exprs: Iterable[Any],
1924 include_external: bool = False,
1925 resolve_refs: bool = True,
1926 ) -> TypingIterator[Col]:
1927 for expr in exprs:
1928 if isinstance(expr, Col):
1929 yield expr
1930 elif include_external and callable(
1931 getattr(expr, "get_external_cols", None)
1932 ):
1933 yield from expr.get_external_cols()
1934 elif hasattr(expr, "get_source_expressions"):
1935 if not resolve_refs and isinstance(expr, Ref):
1936 continue
1937 yield from cls._gen_cols(
1938 expr.get_source_expressions(),
1939 include_external=include_external,
1940 resolve_refs=resolve_refs,
1941 )
1942
1943 @classmethod
1944 def _gen_col_aliases(cls, exprs: Iterable[Any]) -> TypingIterator[str | None]:
1945 yield from (expr.alias for expr in cls._gen_cols(exprs))
1946
    def resolve_ref(
        self,
        name: str,
        allow_joins: bool = True,
        reuse: set[str] | None = None,
        summarize: bool = False,
    ) -> BaseExpression:
        """
        Resolve 'name' into an expression usable in this query.

        Resolution order:
        1. If 'name' is exactly an annotation, return that annotation, or a
           Ref to it when 'summarize' is true. A FieldError is raised if the
           annotation is not in annotation_select in that case.
        2. Otherwise, if the first LOOKUP_SEP-separated segment is an
           annotation, apply the remaining segments to it as transforms.
        3. Otherwise resolve the path against the model's fields with
           setup_joins()/trim_joins(), and return the (possibly transformed)
           column of the first remaining target.

        When 'allow_joins' is false, raise FieldError if an exactly-named
        annotation references a joined table (case 1) or if the field path
        still needs a join after trimming (case 3). When 'reuse' is a set,
        the aliases used for case 3 are added to it.
        """
        annotation = self.annotations.get(name)
        if annotation is not None:
            if not allow_joins:
                # Reject annotations whose columns come from a joined table.
                for alias in self._gen_col_aliases([annotation]):
                    if isinstance(self.alias_map[alias], Join):
                        raise FieldError(
                            "Joined field references are not permitted in this query"
                        )
            if summarize:
                # Summarize currently means we are doing an aggregate() query
                # which is executed as a wrapped subquery if any of the
                # aggregate() elements reference an existing annotation. In
                # that case we need to return a Ref to the subquery's annotation.
                if name not in self.annotation_select:
                    raise FieldError(
                        f"Cannot aggregate over the '{name}' alias. Use annotate() "
                        "to promote it."
                    )
                return Ref(name, self.annotation_select[name])
            else:
                return annotation
        else:
            field_list = name.split(LOOKUP_SEP)
            annotation = self.annotations.get(field_list[0])
            if annotation is not None:
                # Transforms applied on top of an annotation
                # (e.g. "<annotation>__<transform>").
                # NOTE(review): allow_joins is not checked on this path.
                for transform in field_list[1:]:
                    annotation = self.try_transform(annotation, transform)
                return annotation
            initial_alias = self.get_initial_alias()
            assert initial_alias is not None
            assert self.model is not None, "Resolving field references requires a model"
            meta = self.model._model_meta
            join_info = self.setup_joins(
                field_list,
                meta,
                initial_alias,
                can_reuse=reuse,
            )
            # Drop trailing joins whose target column already exists on the
            # previous table.
            targets, final_alias, join_list = self.trim_joins(
                join_info.targets, join_info.joins, join_info.path
            )
            if not allow_joins and len(join_list) > 1:
                raise FieldError(
                    "Joined field references are not permitted in this query"
                )
            # Verify that the last lookup in name is a field or a transform:
            # transform_function() raises FieldError if not.
            transform = join_info.transform_function(targets[0], final_alias)
            if reuse is not None:
                reuse.update(join_list)
            return transform
2005
    def split_exclude(
        self,
        filter_expr: tuple[str, Any],
        can_reuse: set[str],
        names_with_path: list[tuple[str, list[Any]]],
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
        original exclude filter (filter_expr) and the portion up to the first
        N-to-many relation field.

        For example, if the origin filter is ~Q(child__name='foo'), filter_expr
        is ('child__name', 'foo') and can_reuse is a set of joins usable for
        filters in the original query.

        We will turn this into equivalent of:
            WHERE NOT EXISTS(
                SELECT 1
                FROM child
                WHERE name = 'foo' AND child.parent_id = parent.id
                LIMIT 1
            )

        Return the condition built by build_filter() for the EXISTS subquery
        (OR-ed with an "__isnull" condition when the trimmed prefix contained
        a LEFT OUTER join) and the inner joins that build_filter() reported
        as needed.
        """
        # Generate the inner query.
        query = self.__class__(self.model)
        query._filtered_relations = self._filtered_relations
        filter_lhs, filter_rhs = filter_expr
        # References in the filter value point at the outer query once the
        # filter is moved into the subquery, so wrap them in OuterRef.
        if isinstance(filter_rhs, OuterRef):
            filter_rhs = OuterRef(filter_rhs)
        elif isinstance(filter_rhs, F):
            filter_rhs = OuterRef(filter_rhs.name)
        query.add_filter(filter_lhs, filter_rhs)
        query.clear_ordering(force=True)
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_prefix, contains_louter = query.trim_start(names_with_path)

        # trim_start() seeds query.select with Col instances built from the
        # join field, so .target/.alias access below is sound.
        col = query.select[0]
        assert isinstance(col, Col)
        select_field = col.target
        alias = col.alias
        if alias in can_reuse:
            id_field = select_field.model._model_meta.get_forward_field("id")
            # Need to add a restriction so that outer query's filters are in effect for
            # the subquery, too.
            query.bump_prefix(self)
            lookup_class = select_field.get_lookup("exact")
            # Note that the query.select[0].alias is different from alias
            # due to bump_prefix above.
            bumped_col = query.select[0]
            assert isinstance(bumped_col, Col)
            lookup = lookup_class(
                id_field.get_col(bumped_col.alias), id_field.get_col(alias)
            )
            query.where.add(lookup, AND)
            query.external_aliases[alias] = True

        # Correlate the subquery's selected column with the outer query's
        # trimmed prefix.
        # NOTE(review): 'col' was captured before bump_prefix() may have
        # relabeled the subquery's aliases; confirm its alias is still valid.
        lookup_class = select_field.get_lookup("exact")
        lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
        query.where.add(lookup, AND)
        condition, needed_inner = self.build_filter(Exists(query))

        if contains_louter:
            or_null_condition, _ = self.build_filter(
                (f"{trimmed_prefix}__isnull", True),
                current_negated=True,
                branch_negated=True,
                can_reuse=can_reuse,
            )
            condition.add(or_null_condition, OR)
            # NOTE(review): this comment originally described a NOT IN
            # subquery: "(outercol NOT IN innerq AND outercol IS NOT NULL) OR
            # outercol IS NULL", where both NULL checks were needed because
            # NOT IN yields UNKNOWN for NULL. The inner query is now wrapped
            # in EXISTS (build_filter(Exists(query)) above), and this
            # "<prefix>__isnull" branch is OR-ed in because the trimmed
            # prefix contained a LEFT OUTER join. Confirm the final SQL once
            # the caller's negation is applied.
        return condition, needed_inner
2086
2087 def set_empty(self) -> None:
2088 self.where.add(NothingNode(), AND)
2089
2090 def is_empty(self) -> bool:
2091 return any(isinstance(c, NothingNode) for c in self.where.children)
2092
2093 def set_limits(self, low: int | None = None, high: int | None = None) -> None:
2094 """
2095 Adjust the limits on the rows retrieved. Use low/high to set these,
2096 as it makes it more Pythonic to read and write. When the SQL query is
2097 created, convert them to the appropriate offset and limit values.
2098
2099 Apply any limits passed in here to the existing constraints. Add low
2100 to the current low value and clamp both to any existing high value.
2101 """
2102 if high is not None:
2103 if self.high_mark is not None:
2104 self.high_mark = min(self.high_mark, self.low_mark + high)
2105 else:
2106 self.high_mark = self.low_mark + high
2107 if low is not None:
2108 if self.high_mark is not None:
2109 self.low_mark = min(self.high_mark, self.low_mark + low)
2110 else:
2111 self.low_mark = self.low_mark + low
2112
2113 if self.low_mark == self.high_mark:
2114 self.set_empty()
2115
2116 def clear_limits(self) -> None:
2117 """Clear any existing limits."""
2118 self.low_mark, self.high_mark = 0, None
2119
2120 @property
2121 def is_sliced(self) -> bool:
2122 return self.low_mark != 0 or self.high_mark is not None
2123
2124 def has_limit_one(self) -> bool:
2125 return self.high_mark is not None and (self.high_mark - self.low_mark) == 1
2126
2127 def can_filter(self) -> bool:
2128 """
2129 Return True if adding filters to this instance is still possible.
2130
2131 Typically, this means no limits or offsets have been put on the results.
2132 """
2133 return not self.is_sliced
2134
2135 def clear_select_clause(self) -> None:
2136 """Remove all fields from SELECT clause."""
2137 self.select = ()
2138 self.default_cols = False
2139 self.select_related = False
2140 self.set_extra_mask(())
2141 self.set_annotation_mask(())
2142
2143 def clear_select_fields(self) -> None:
2144 """
2145 Clear the list of fields to select (but not extra_select columns).
2146 Some queryset types completely replace any existing list of select
2147 columns.
2148 """
2149 self.select = ()
2150 self.values_select = ()
2151
2152 def add_select_col(self, col: BaseExpression, name: str) -> None:
2153 self.select += (col,)
2154 self.values_select += (name,)
2155
2156 def set_select(
2157 self, cols: list[BaseExpression] | tuple[BaseExpression, ...]
2158 ) -> None:
2159 self.default_cols = False
2160 self.select = tuple(cols)
2161
2162 def add_distinct_fields(self, *field_names: str) -> None:
2163 """
2164 Add and resolve the given fields to the query's "distinct on" clause.
2165 """
2166 self.distinct_fields = field_names
2167 self.distinct = True
2168
    def add_fields(
        self, field_names: list[str] | TypingIterator[str], allow_m2m: bool = True
    ) -> None:
        """
        Add the given (model) fields to the select set. Add the field names in
        the order specified.

        Each name is a LOOKUP_SEP-separated field path resolved from the
        query's initial alias via setup_joins() and trim_joins(). When at
        least one column results, the SELECT clause is replaced with them
        through set_select().

        Raise FieldError when a name can't be resolved, or when it crosses a
        multi-valued relation while 'allow_m2m' is False.
        """
        alias = self.get_initial_alias()
        assert alias is not None
        assert self.model is not None, "add_fields() requires a model"
        meta = self.model._model_meta

        try:
            cols = []
            for name in field_names:
                # Join promotion note - we must not remove any rows here, so
                # if there is no existing joins, use outer join.
                join_info = self.setup_joins(
                    name.split(LOOKUP_SEP), meta, alias, allow_many=allow_m2m
                )
                targets, final_alias, joins = self.trim_joins(
                    join_info.targets,
                    join_info.joins,
                    join_info.path,
                )
                for target in targets:
                    cols.append(join_info.transform_function(target, final_alias))
            if cols:
                self.set_select(cols)
        except MultiJoin:
            # Raised by path resolution when allow_m2m is False and the path
            # crosses a multi-valued relation. 'name' is the name being
            # resolved when the error occurred.
            raise FieldError(f"Invalid field name: '{name}'")
        except FieldError:
            if LOOKUP_SEP in name:
                # For lookups spanning over relationships, show the error
                # from the model on which the lookup failed.
                raise
            elif name in self.annotations:
                # An alias()-style annotation that isn't selected can't be
                # used here until it is promoted with annotate().
                raise FieldError(
                    f"Cannot select the '{name}' alias. Use annotate() to promote it."
                )
            else:
                names = sorted(
                    [
                        *get_field_names_from_opts(meta),
                        *self.extra,
                        *self.annotation_select,
                        *self._filtered_relations,
                    ]
                )
                raise FieldError(
                    "Cannot resolve keyword {!r} into field. Choices are: {}".format(
                        name, ", ".join(names)
                    )
                )
2223
2224 def add_ordering(self, *ordering: str | BaseExpression) -> None:
2225 """
2226 Add items from the 'ordering' sequence to the query's "order by"
2227 clause. These items are either field names (not column names) --
2228 possibly with a direction prefix ('-' or '?') -- or OrderBy
2229 expressions.
2230
2231 If 'ordering' is empty, clear all ordering from the query.
2232 """
2233 errors = []
2234 for item in ordering:
2235 if isinstance(item, str):
2236 if item == "?":
2237 continue
2238 item = item.removeprefix("-")
2239 if item in self.annotations:
2240 continue
2241 if self.extra and item in self.extra:
2242 continue
2243 # names_to_path() validates the lookup. A descriptive
2244 # FieldError will be raise if it's not.
2245 assert self.model is not None, "ORDER BY field names require a model"
2246 self.names_to_path(item.split(LOOKUP_SEP), self.model._model_meta)
2247 elif not isinstance(item, ResolvableExpression):
2248 errors.append(item)
2249 if getattr(item, "contains_aggregate", False):
2250 raise FieldError(
2251 "Using an aggregate in order_by() without also including "
2252 f"it in annotate() is not allowed: {item}"
2253 )
2254 if errors:
2255 raise FieldError(f"Invalid order_by arguments: {errors}")
2256 if ordering:
2257 self.order_by += ordering
2258 else:
2259 self.default_ordering = False
2260
2261 def clear_ordering(self, force: bool = False, clear_default: bool = True) -> None:
2262 """
2263 Remove any ordering settings if the current query allows it without
2264 side effects, set 'force' to True to clear the ordering regardless.
2265 If 'clear_default' is True, there will be no ordering in the resulting
2266 query (not even the model's default).
2267 """
2268 if not force and (
2269 self.is_sliced or self.distinct_fields or self.select_for_update
2270 ):
2271 return
2272 self.order_by = ()
2273 self.extra_order_by = ()
2274 if clear_default:
2275 self.default_ordering = False
2276
2277 def set_group_by(self, allow_aliases: bool = True) -> None:
2278 """
2279 Expand the GROUP BY clause required by the query.
2280
2281 This will usually be the set of all non-aggregate fields in the
2282 return data. If the database backend supports grouping by the
2283 primary key, and the query would be equivalent, the optimization
2284 will be made automatically.
2285 """
2286 if allow_aliases and self.values_select:
2287 # If grouping by aliases is allowed assign selected value aliases
2288 # by moving them to annotations.
2289 group_by_annotations = {}
2290 values_select = {}
2291 for alias, expr in zip(self.values_select, self.select):
2292 if isinstance(expr, Col):
2293 values_select[alias] = expr
2294 else:
2295 group_by_annotations[alias] = expr
2296 self.annotations = {**group_by_annotations, **self.annotations}
2297 self.append_annotation_mask(group_by_annotations)
2298 self.select = tuple(values_select.values())
2299 self.values_select = tuple(values_select)
2300 group_by = list(self.select)
2301 for alias, annotation in self.annotation_select.items():
2302 if not (group_by_cols := annotation.get_group_by_cols()):
2303 continue
2304 if allow_aliases and not annotation.contains_aggregate:
2305 group_by.append(Ref(alias, annotation))
2306 else:
2307 group_by.extend(group_by_cols)
2308 self.group_by = tuple(group_by)
2309
2310 def add_select_related(self, fields: list[str]) -> None:
2311 """
2312 Set up the select_related data structure so that we only select
2313 certain related models (as opposed to all models, when
2314 self.select_related=True).
2315 """
2316 if isinstance(self.select_related, bool):
2317 field_dict: dict[str, Any] = {}
2318 else:
2319 field_dict = self.select_related
2320 for field in fields:
2321 d = field_dict
2322 for part in field.split(LOOKUP_SEP):
2323 d = d.setdefault(part, {})
2324 self.select_related = field_dict
2325
2326 def add_extra(
2327 self,
2328 select: dict[str, str],
2329 select_params: list[Any] | None,
2330 where: list[str],
2331 params: list[Any],
2332 tables: list[str],
2333 order_by: tuple[str, ...],
2334 ) -> None:
2335 """
2336 Add data to the various extra_* attributes for user-created additions
2337 to the query.
2338 """
2339 if select:
2340 # We need to pair any placeholder markers in the 'select'
2341 # dictionary with their parameters in 'select_params' so that
2342 # subsequent updates to the select dictionary also adjust the
2343 # parameters appropriately.
2344 select_pairs = {}
2345 if select_params:
2346 param_iter = iter(select_params)
2347 else:
2348 param_iter = iter([])
2349 for name, entry in select.items():
2350 self.check_alias(name)
2351 entry = str(entry)
2352 entry_params = []
2353 pos = entry.find("%s")
2354 while pos != -1:
2355 if pos == 0 or entry[pos - 1] != "%":
2356 entry_params.append(next(param_iter))
2357 pos = entry.find("%s", pos + 2)
2358 select_pairs[name] = (entry, entry_params)
2359 self.extra.update(select_pairs)
2360 if where or params:
2361 self.where.add(ExtraWhere(where, params), AND)
2362 if tables:
2363 self.extra_tables += tuple(tables)
2364 if order_by:
2365 self.extra_order_by = order_by
2366
2367 def clear_deferred_loading(self) -> None:
2368 """Remove any fields from the deferred loading set."""
2369 self.deferred_loading = (frozenset(), True)
2370
2371 def add_deferred_loading(self, field_names: frozenset[str]) -> None:
2372 """
2373 Add the given list of model field names to the set of fields to
2374 exclude from loading from the database when automatic column selection
2375 is done. Add the new field names to any existing field names that
2376 are deferred (or removed from any existing field names that are marked
2377 as the only ones for immediate loading).
2378 """
2379 # Fields on related models are stored in the literal double-underscore
2380 # format, so that we can use a set datastructure. We do the foo__bar
2381 # splitting and handling when computing the SQL column names (as part of
2382 # get_columns()).
2383 existing, defer = self.deferred_loading
2384 existing_set = set(existing)
2385 if defer:
2386 # Add to existing deferred names.
2387 self.deferred_loading = frozenset(existing_set.union(field_names)), True
2388 else:
2389 # Remove names from the set of any existing "immediate load" names.
2390 if new_existing := existing_set.difference(field_names):
2391 self.deferred_loading = frozenset(new_existing), False
2392 else:
2393 self.clear_deferred_loading()
2394 if new_only := set(field_names).difference(existing_set):
2395 self.deferred_loading = frozenset(new_only), True
2396
2397 def add_immediate_loading(self, field_names: list[str] | set[str]) -> None:
2398 """
2399 Add the given list of model field names to the set of fields to
2400 retrieve when the SQL is executed ("immediate loading" fields). The
2401 field names replace any existing immediate loading field names. If
2402 there are field names already specified for deferred loading, remove
2403 those names from the new field_names before storing the new names
2404 for immediate loading. (That is, immediate loading overrides any
2405 existing immediate values, but respects existing deferrals.)
2406 """
2407 existing, defer = self.deferred_loading
2408 field_names_set = set(field_names)
2409
2410 if defer:
2411 # Remove any existing deferred names from the current set before
2412 # setting the new names.
2413 self.deferred_loading = (
2414 frozenset(field_names_set.difference(existing)),
2415 False,
2416 )
2417 else:
2418 # Replace any existing "immediate load" field names.
2419 self.deferred_loading = frozenset(field_names_set), False
2420
2421 def set_annotation_mask(
2422 self,
2423 names: set[str]
2424 | frozenset[str]
2425 | list[str]
2426 | tuple[str, ...]
2427 | dict[str, Any]
2428 | None,
2429 ) -> None:
2430 """Set the mask of annotations that will be returned by the SELECT."""
2431 if names is None:
2432 self.annotation_select_mask = None
2433 else:
2434 self.annotation_select_mask = set(names)
2435 self._annotation_select_cache = None
2436
2437 def append_annotation_mask(self, names: list[str] | dict[str, Any]) -> None:
2438 if self.annotation_select_mask is not None:
2439 self.set_annotation_mask(self.annotation_select_mask.union(names))
2440
2441 def set_extra_mask(
2442 self, names: set[str] | list[str] | tuple[str, ...] | None
2443 ) -> None:
2444 """
2445 Set the mask of extra select items that will be returned by SELECT.
2446 Don't remove them from the Query since they might be used later.
2447 """
2448 if names is None:
2449 self.extra_select_mask = None
2450 else:
2451 self.extra_select_mask = set(names)
2452 self._extra_select_cache = None
2453
    def set_values(self, fields: list[str]) -> None:
        """
        Replace the SELECT clause with exactly the given names.

        When 'fields' is empty, every concrete field's attname is selected.
        Otherwise each name is classified as an extra select, a selected
        annotation, or a model field path (all names count as field paths
        when the query has no extras and no annotations). The extra and
        annotation select masks are then set to exactly the requested extra
        and annotation names. Model field paths are recorded in
        'values_select' and resolved through add_fields().

        Select-related and deferred loading are cleared first. An existing
        GROUP BY is adjusted so it stays valid once the old SELECT clause is
        replaced.
        """
        self.select_related = False
        self.clear_deferred_loading()
        self.clear_select_fields()
        self.has_select_fields = True

        if fields:
            field_names = []
            extra_names = []
            annotation_names = []
            if not self.extra and not self.annotations:
                # Shortcut - if there are no extra or annotations, then
                # the values() clause must be just field names.
                field_names = list(fields)
            else:
                self.default_cols = False
                for f in fields:
                    if f in self.extra_select:
                        extra_names.append(f)
                    elif f in self.annotation_select:
                        annotation_names.append(f)
                    else:
                        field_names.append(f)
            # Masks are narrowed to the requested names (both empty when the
            # shortcut above was taken).
            self.set_extra_mask(extra_names)
            self.set_annotation_mask(annotation_names)
            selected = frozenset(field_names + extra_names + annotation_names)
        else:
            assert self.model is not None, "Default values query requires a model"
            field_names = [f.attname for f in self.model._model_meta.concrete_fields]
            selected = frozenset(field_names)
        # Selected annotations must be known before setting the GROUP BY
        # clause.
        if self.group_by is True:
            # group_by is True: group by every concrete field of the model.
            # Temporarily select them so set_group_by() can copy them.
            assert self.model is not None, "GROUP BY True requires a model"
            self.add_fields(
                (f.attname for f in self.model._model_meta.concrete_fields), False
            )
            # Disable GROUP BY aliases to avoid orphaning references to the
            # SELECT clause which is about to be cleared.
            self.set_group_by(allow_aliases=False)
            self.clear_select_fields()
        elif self.group_by:
            # Resolve GROUP BY annotation references if they are not part of
            # the selected fields anymore.
            group_by = []
            for expr in self.group_by:
                if isinstance(expr, Ref) and expr.refs not in selected:
                    expr = self.annotations[expr.refs]
                group_by.append(expr)
            self.group_by = tuple(group_by)

        self.values_select = tuple(field_names)
        self.add_fields(field_names, True)
2507
2508 @property
2509 def annotation_select(self) -> dict[str, BaseExpression]:
2510 """
2511 Return the dictionary of aggregate columns that are not masked and
2512 should be used in the SELECT clause. Cache this result for performance.
2513 """
2514 if self._annotation_select_cache is not None:
2515 return self._annotation_select_cache
2516 elif not self.annotations:
2517 return {}
2518 elif self.annotation_select_mask is not None:
2519 self._annotation_select_cache = {
2520 k: v
2521 for k, v in self.annotations.items()
2522 if k in self.annotation_select_mask
2523 }
2524 return self._annotation_select_cache
2525 else:
2526 return self.annotations
2527
2528 @property
2529 def extra_select(self) -> dict[str, tuple[str, list[Any]]]:
2530 if self._extra_select_cache is not None:
2531 return self._extra_select_cache
2532 if not self.extra:
2533 return {}
2534 elif self.extra_select_mask is not None:
2535 self._extra_select_cache = {
2536 k: v for k, v in self.extra.items() if k in self.extra_select_mask
2537 }
2538 return self._extra_select_cache
2539 else:
2540 return self.extra
2541
    def trim_start(
        self, names_with_path: list[tuple[str, list[Any]]]
    ) -> tuple[str, bool]:
        """
        Trim joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure that are m2m joins.

        Also set the select column so the start matches the join.

        This method is meant to be used for generating the subquery joins &
        cols in split_exclude().

        Return a lookup usable for doing outerq.filter(lookup=self) and a
        boolean indicating if the joins in the prefix contain a LEFT OUTER join.
        """
        # Flatten the PathInfos of every name into a single list.
        all_paths = []
        for _, paths in names_with_path:
            all_paths.extend(paths)
        contains_louter = False
        # Trim and operate only on tables that were generated for
        # the lookup part of the query. That is, avoid trimming
        # joins generated for F() expressions.
        lookup_tables = [
            t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
        ]
        # Unreference every table before the first m2m PathInfo. The join
        # at index i + 1 in lookup_tables is the one created for all_paths[i].
        # NOTE(review): if all_paths is empty, 'trimmed_paths' and 'path' are
        # unbound below.
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # The path.join_field is a Rel, let's get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix: the names whose paths were trimmed in
        # full, followed by the name of the join field's first foreign
        # related field.
        paths_in_prefix = trimmed_paths
        trimmed_prefix = []
        for name, path in names_with_path:
            if paths_in_prefix - len(path) < 0:
                break
            trimmed_prefix.append(name)
            paths_in_prefix -= len(path)
        trimmed_prefix.append(join_field.foreign_related_fields[0].name)
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        # Let's still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for:
        # - LEFT JOINs because we would miss those rows that have nothing on
        #   the outer side,
        # - INNER JOINs from filtered relations because we would miss their
        #   filters.
        first_join = self.alias_map[lookup_tables[trimmed_paths + 1]]
        if first_join.join_type != LOUTER and not first_join.filtered_relation:
            # Select the local side of the related field pairs on the first
            # join's table and drop the table before it.
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
        else:
            # TODO: It might be possible to trim more joins from the start of the
            # inner query if it happens to have a longer join chain containing the
            # values in select_fields. Let's punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a join_class instead of a
        # base_table_class reference. But the first entry in the query's FROM
        # clause must not be a JOIN.
        for table in self.alias_map:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = self.base_table_class(
                    self.alias_map[table].table_name,
                    table,
                )
                break
        self.set_select([f.get_col(select_alias) for f in select_fields])
        return trimmed_prefix, contains_louter
2615
2616 def is_nullable(self, field: Field) -> bool:
2617 """Check if the given field should be treated as nullable."""
2618 # QuerySet does not have knowledge of which connection is going to be
2619 # used. For the single-database setup we always reference the default
2620 # connection here.
2621 if not isinstance(field, ColumnField):
2622 return False
2623 return field.allow_null
2624
2625
def get_order_dir(field: str, default: str = "ASC") -> tuple[str, str]:
    """
    Split an ordering specification into ``(name, direction)``.

    ``default`` selects an entry of ``ORDER_DIR``. The first direction of that
    entry applies to a name without a ``-`` prefix. The second applies to a
    ``-``-prefixed name, whose ``-`` is stripped. For example, ``'-foo'`` is
    returned as ``('foo', 'DESC')``.

    A ``'+'`` prefix also sorts in the default direction, but it is not
    stripped from the returned name.
    """
    directions = ORDER_DIR[default]
    if field[0] != "-":
        return field, directions[0]
    return field[1:], directions[1]
2638
2639
class JoinPromoter:
    """
    Decide join types (INNER vs LEFT OUTER) for a compound filter condition.

    Votes are counted per table alias through add_votes().
    update_join_types() compares each alias's count with the number of
    children of the condition and promotes or demotes the query's joins
    accordingly.
    """

    def __init__(self, connector: str, num_children: int, negated: bool):
        self.connector = connector
        self.negated = negated
        # Under negation AND and OR swap roles for join promotion:
        # NOT (a AND b) behaves like (a OR b), and NOT (a OR b) like (a AND b).
        if not negated:
            self.effective_connector = connector
        else:
            self.effective_connector = OR if connector == AND else AND
        self.num_children = num_children
        # Table alias -> how many times it was seen as required for inner
        # and/or outer joins.
        self.votes = Counter()

    def __repr__(self) -> str:
        cls_name = type(self).__qualname__
        return (
            f"{cls_name}(connector={self.connector!r}, "
            f"num_children={self.num_children!r}, negated={self.negated!r})"
        )

    def add_votes(self, votes: Any) -> None:
        """
        Count one vote for each item of ``votes`` (any iterable) in
        ``self.votes``.
        """
        self.votes.update(votes)

    def update_join_types(self, query: Query) -> set[str]:
        """
        Adjust the join types of ``query`` so the SQL is as efficient as
        possible while staying correct. Joins are made INNER wherever
        possible, but an OUTER join is never made INNER if that could remove
        rows from the result.

        Return the set of aliases passed to ``query.demote_joins()``.
        """
        # The effective connector makes NOT (a AND b) behave like (a OR b).
        is_or = self.effective_connector == OR
        is_and = self.effective_connector == AND
        needed = self.num_children

        # OR case: an alias that not every child voted for must be LOUTER.
        # Otherwise the INNER JOIN itself could drop valid rows. Take
        # rel_a__col=1 | rel_b__col=2: if the rel_a join yields nothing (a
        # reverse FK with no rows, or a NULL direct FK) while rel_b has
        # col=2, an INNER join to rel_a would lose that match. Any existing
        # INNER join is therefore promoted to LOUTER (a later AND may demote
        # it again).
        to_promote = {
            alias
            for alias, n_votes in self.votes.items()
            if is_or and n_votes < needed
        }

        # AND case: any filter on the join can only match when a joined row
        # exists (NULL=anything is never true), so INNER is safe.
        # OR case where every child voted for the alias, e.g.
        # (rel_a__col__icontains=Alex | rel_a__col__icontains=Russell):
        # with no rel_a rows the whole condition fails anyway, so INNER is
        # safe too.
        to_demote = {
            alias
            for alias, n_votes in self.votes.items()
            if is_and or (is_or and n_votes == needed)
        }

        # Consider (rel_a__col=1|rel_b__col=2) & rel_a__col__gte=0.
        # 1. The OR clause is processed first, and both rel_a and rel_b are
        #    promoted to LOUTER above.
        # 2. The AND then votes rel_a back to INNER (in the AND case a single
        #    vote is enough).
        # The demotion is correct: if rel_a yields no rows, the __gte clause
        # is false, so the whole condition is false. The same holds if the
        # two clauses are swapped, or if __gte is replaced with
        # rel_a__col=1|rel_a__col=2.
        query.promote_joins(to_promote)
        query.demote_joins(to_demote)
        return to_demote
2730
2731
2732# ##### Query subclasses (merged from subqueries.py) #####
2733
2734
class DeleteQuery(Query):
    """A DELETE SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLDeleteCompiler:
        from plain.postgres.sql.compiler import SQLDeleteCompiler

        connection = get_connection()
        return SQLDeleteCompiler(self, connection, elide_empty)

    def do_query(self, table: str, where: Any) -> int:
        """
        Run a DELETE against ``table``, restricted by ``where``.

        Return the cursor's ``rowcount``, or 0 when the compiler returns no
        cursor.
        """
        from plain.postgres.sql.constants import CURSOR

        # Narrow the alias map down to just the target table.
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        with cursor:
            return cursor.rowcount

    def delete_batch(self, id_list: list[Any]) -> int:
        """
        Delete every row whose id appears in ``id_list`` and return the total
        number of rows deleted.

        ``id_list`` is processed in chunks of ``GET_ITERATOR_CHUNK_SIZE``, one
        DELETE statement per chunk.
        """
        from plain.postgres.sql.constants import GET_ITERATOR_CHUNK_SIZE

        total_deleted = 0
        assert self.model is not None, "DELETE requires a model"
        id_field = self.model._model_meta.get_forward_field("id")
        lookup = f"{id_field.attname}__in"
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = id_list[start : start + GET_ITERATOR_CHUNK_SIZE]
            self.clear_where()
            self.add_filter(lookup, chunk)
            total_deleted += self.do_query(
                self.model.model_options.db_table, self.where
            )
        return total_deleted
2776
2777
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLUpdateCompiler:
        from plain.postgres.sql.compiler import SQLUpdateCompiler

        return SQLUpdateCompiler(self, get_connection(), elide_empty)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self) -> None:
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, value) pairs appended by add_update_fields(); they are used
        # to generate the UPDATE query.
        self.values: list[tuple[Any, Any]] = []

    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
        """
        Apply ``values`` (field name -> value) to every row whose id is in
        ``id_list``.

        ``id_list`` is processed in chunks of ``GET_ITERATOR_CHUNK_SIZE``, one
        UPDATE statement per chunk. The update values are registered once,
        before the first chunk.
        """
        from plain.postgres.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS

        self.add_update_values(values)
        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter(
                "id__in", id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values: dict[str, Any]) -> None:
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        Raises FieldError if any named field is a ManyToManyField.
        """
        # Imported once, outside the loop below; the lookup is loop-invariant.
        from plain.postgres.fields.related import ManyToManyField

        assert self.model is not None, "UPDATE requires model metadata"
        meta = self.model._model_meta
        values_seq = []
        for name, val in values.items():
            field = meta.get_field(name)
            if isinstance(field, ManyToManyField):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            values_seq.append((field, val))
        self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq: list[tuple[Any, Any]]) -> None:
        """
        Append a sequence of (field, value) pairs to the internal list that
        will be used to generate the UPDATE query.

        Values that are ResolvableExpressions are resolved against this query
        immediately, with joins disallowed and ``for_save=True``.
        """
        for field, val in values_seq:
            if isinstance(val, ResolvableExpression):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, val))
2840
2841
class InsertQuery(Query):
    """
    An INSERT SQL query.

    It compiles to a list of SQL statements via ``get_compiler().as_sql()``,
    so the single-statement helpers ``__str__()`` and ``sql_with_params()``
    are deliberately unsupported.
    """

    def get_compiler(self, *, elide_empty: bool = True) -> SQLInsertCompiler:
        from plain.postgres.sql.compiler import SQLInsertCompiler

        return SQLInsertCompiler(self, get_connection(), elide_empty)

    def __str__(self) -> str:
        raise NotImplementedError(
            "InsertQuery does not support __str__(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def sql_with_params(self) -> Any:
        raise NotImplementedError(
            "InsertQuery does not support sql_with_params(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def __init__(
        self,
        *args: Any,
        on_conflict: OnConflict | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Populated by insert_values().
        self.fields: list[Field] = []
        self.objs: list[Any] = []
        # Default matches insert_values()'s ``raw=False``. Initializing it here
        # means reading ``raw`` before insert_values() is called no longer
        # raises AttributeError.
        self.raw: bool = False
        self.on_conflict = on_conflict
        self.update_fields: list[Field] = update_fields or []
        self.unique_fields: list[Field] = unique_fields or []

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """Record the fields and objects to insert, along with the ``raw`` flag."""
        self.fields = fields
        self.objs = objs
        self.raw = raw
2881
2882
class AggregateQuery(Query):
    """
    A query whose FROM clause is another query (``inner_query``). Only the
    elements in the provided list are selected from it.
    """

    def __init__(self, model: Any, inner_query: Any) -> None:
        # Assigned before Query.__init__() runs. NOTE(review): it is unclear
        # from here whether the base initializer depends on this ordering, so
        # keep it.
        self.inner_query = inner_query
        super().__init__(model)

    def get_compiler(self, *, elide_empty: bool = True) -> SQLAggregateCompiler:
        from plain.postgres.sql.compiler import SQLAggregateCompiler

        connection = get_connection()
        return SQLAggregateCompiler(self, connection, elide_empty)