1"""
2Create SQL statements for QuerySets.
3
4The code in here encapsulates all of the SQL construction so that QuerySets
5themselves do not have to. This module has to know all about the internals of
6models in order to get the information it needs.
7"""
8
9from __future__ import annotations
10
11import copy
12import difflib
13import functools
14import sys
15from collections import Counter
16from collections.abc import Callable, Iterable, Iterator, Mapping
17from collections.abc import Iterator as TypingIterator
18from functools import cached_property
19from itertools import chain, count, product
20from string import ascii_uppercase
21from typing import (
22 TYPE_CHECKING,
23 Any,
24 Literal,
25 NamedTuple,
26 Self,
27 TypeVar,
28 cast,
29 overload,
30)
31
32from plain.postgres.aggregates import Count
33from plain.postgres.constants import LOOKUP_SEP, OnConflict
34from plain.postgres.db import NotSupportedError, get_connection
35from plain.postgres.exceptions import FieldDoesNotExist, FieldError
36from plain.postgres.expressions import (
37 BaseExpression,
38 Col,
39 Exists,
40 F,
41 OuterRef,
42 Ref,
43 ResolvableExpression,
44 ResolvedOuterRef,
45 Value,
46)
47from plain.postgres.fields import Field
48from plain.postgres.fields.related_lookups import MultiColSource
49from plain.postgres.lookups import Lookup
50from plain.postgres.query_utils import (
51 PathInfo,
52 Q,
53 check_rel_lookup_compatibility,
54 refs_expression,
55)
56from plain.postgres.sql.constants import INNER, LOUTER, ORDER_DIR, SINGLE
57from plain.postgres.sql.datastructures import BaseTable, Empty, Join, MultiJoin
58from plain.postgres.sql.where import AND, OR, ExtraWhere, NothingNode, WhereNode
59from plain.utils.regex_helper import _lazy_re_compile
60from plain.utils.tree import Node
61
62if TYPE_CHECKING:
63 from plain.postgres import Model
64 from plain.postgres.connection import DatabaseConnection
65 from plain.postgres.fields.related import RelatedField
66 from plain.postgres.fields.reverse_related import ForeignObjectRel
67 from plain.postgres.meta import Meta
68 from plain.postgres.sql.compiler import (
69 SQLAggregateCompiler,
70 SQLCompiler,
71 SQLDeleteCompiler,
72 SQLInsertCompiler,
73 SQLUpdateCompiler,
74 SqlWithParams,
75 )
76
# Public interface of this module; DeleteQuery/UpdateQuery/InsertQuery/
# AggregateQuery are defined later in the file.
__all__ = [
    "Query",
    "RawQuery",
    "DeleteQuery",
    "UpdateQuery",
    "InsertQuery",
    "AggregateQuery",
]


# Quotation marks ('"`[]), whitespace characters, semicolons, or inline
# SQL comments are forbidden in column aliases.
FORBIDDEN_ALIAS_PATTERN = _lazy_re_compile(r"['`\"\]\[;\s]|--|/\*|\*/")

# Inspired from
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
# Matches a run of word characters and dashes; used to validate EXPLAIN
# option names in Query.explain().
EXPLAIN_OPTIONS_PATTERN = _lazy_re_compile(r"[\w\-]+")
94
95
def get_field_names_from_opts(meta: Meta | None) -> set[str]:
    """Collect every referenceable field name from *meta*.

    Concrete fields contribute both their ``name`` and ``attname``; other
    fields contribute only their ``name``. Returns an empty set when no
    meta is given.
    """
    if meta is None:
        return set()
    names: set[str] = set()
    for field in meta.get_fields():
        names.add(field.name)
        if field.concrete:
            names.add(field.attname)
    return names
104
105
def get_children_from_q(q: Q) -> TypingIterator[tuple[str, Any]]:
    """Flatten a Q tree depth-first, yielding its leaf (lookup, value) pairs."""
    for child in q.children:
        if not isinstance(child, Node):
            yield child
        else:
            yield from get_children_from_q(child)
112
113
class JoinInfo(NamedTuple):
    """Information about a join operation in a query.

    NOTE(review): these tuples are constructed by join-resolution code
    elsewhere in this module; the per-field notes below are inferred from
    the type annotations — confirm against the producing code.
    """

    # The final field reached by the resolved lookup path.
    final_field: Field[Any]
    # The concrete target field(s) the lookup resolves to.
    targets: tuple[Field[Any], ...]
    # Meta of the model at the end of the path.
    meta: Meta
    # Join aliases involved (presumably in traversal order).
    joins: list[str]
    # Relational hops that make up the path.
    path: list[PathInfo]
    # Produces the final expression for a (field, alias) pair; see
    # TransformWrapper for the callable shape used here.
    transform_function: Callable[[Field[Any], str | None], BaseExpression]
123
124
class RawQuery:
    """A single raw SQL query, executed as-is against the connection."""

    def __init__(self, sql: str, params: tuple[Any, ...] | dict[str, Any] = ()):
        self.sql = sql
        self.params = params
        self.cursor: Any = None

        # Mirror the result-shaping attributes of a normal Query so the SQL
        # compiler machinery can process our results unchanged.
        self.low_mark = 0  # Used for offset/limit.
        self.high_mark = None  # Used for offset/limit.
        self.extra_select = {}
        self.annotation_select = {}

    def chain(self) -> RawQuery:
        """Raw queries carry no chaining state, so chaining is just cloning."""
        return self.clone()

    def clone(self) -> RawQuery:
        """Return a fresh RawQuery with the same SQL and parameters."""
        return RawQuery(self.sql, params=self.params)

    def get_columns(self) -> list[str]:
        """Return the result-set column names, executing the query lazily."""
        if self.cursor is None:
            self._execute_query()
        return [description[0] for description in self.cursor.description]

    def __iter__(self) -> TypingIterator[Any]:
        # Re-run the SQL for every new iterator; caching the rows would
        # trade RAM for the repeated round-trip.
        self._execute_query()
        return iter(self.cursor)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self}>"

    @property
    def params_type(self) -> type[dict] | type[tuple] | None:
        """The container type to coerce params to for %-interpolation."""
        if self.params is None:
            return None
        if isinstance(self.params, Mapping):
            return dict
        return tuple

    def __str__(self) -> str:
        kind = self.params_type
        if kind is None:
            return self.sql
        return self.sql % kind(self.params)

    def _execute_query(self) -> None:
        """Open a cursor on the current connection and run the SQL."""
        self.cursor = get_connection().cursor()
        self.cursor.execute(self.sql, self.params)
173
174
class ExplainInfo(NamedTuple):
    """Information about an EXPLAIN query (attached by Query.explain())."""

    # Requested EXPLAIN output format, or None when unspecified.
    format: str | None
    # Extra EXPLAIN options; names are validated in Query.explain().
    options: dict[str, Any]
180
181
class TransformWrapper:
    """Wrapper for transform functions that supports the has_transforms attribute.

    Plays the role functools.partial would, but as a real class so that the
    ``has_transforms`` flag can be set on instances and type checkers stay
    happy.
    """

    def __init__(
        self,
        func: Callable[..., BaseExpression],
        **kwargs: Any,
    ):
        self._func = func
        self._bound_kwargs = kwargs
        self.has_transforms: bool = False

    def __call__(self, field: Field[Any], alias: str | None) -> BaseExpression:
        return self._func(field, alias, **self._bound_kwargs)
199
200
# TypeVar used by Query.chain() so that passing a Query subclass as `klass`
# is typed as returning that same subclass.
QueryType = TypeVar("QueryType", bound="Query")
202
203
class Query(BaseExpression):
    """A single SQL query.

    Holds all state needed to build a SELECT (or, via subclasses, other
    statements): joins in ``alias_map``, filters in ``where``, annotations,
    extra() additions, ordering and slicing.
    """

    # Prefix used when generating table aliases (T1, T2, ...).
    alias_prefix = "T"
    empty_result_set_value = None
    # Alias prefixes belonging to subqueries; merged in combine().
    subq_aliases = frozenset([alias_prefix])

    base_table_class = BaseTable
    join_class = Join

    default_cols = True
    default_ordering = True
    standard_ordering = True

    filter_is_sticky = False
    subquery = False

    # SQL-related attributes.
    # Select and related select clauses are expressions to use in the SELECT
    # clause of the query. The select is used for cases where we want to set up
    # the select clause to contain other than default fields (values(),
    # subqueries...). Note that annotations go to annotations dictionary.
    select = ()
    # The group_by attribute can have one of the following forms:
    #  - None: no group by at all in the query
    #  - A tuple of expressions: group by (at least) those expressions.
    #    String refs are also allowed for now.
    #  - True: group by all select fields of the model
    # See compiler.get_group_by() for details.
    group_by = None
    order_by = ()
    low_mark = 0  # Used for offset/limit.
    high_mark = None  # Used for offset/limit.
    distinct = False
    distinct_fields = ()
    select_for_update = False
    select_for_update_nowait = False
    select_for_update_skip_locked = False
    select_for_update_of = ()
    select_for_no_key_update = False
    select_related: bool | dict[str, Any] = False
    has_select_fields = False
    # Arbitrary limit for select_related to prevent infinite recursion.
    max_depth = 5
    # Holds the selects defined by a call to values() or values_list()
    # excluding annotation_select and extra_select.
    values_select = ()

    # SQL annotation-related attributes.
    annotation_select_mask = None
    _annotation_select_cache = None

    # These are for extensions. The contents are more or less appended verbatim
    # to the appropriate clause.
    extra_select_mask = None
    _extra_select_cache = None

    extra_tables = ()
    extra_order_by = ()

    # A tuple that is a set of model field names and either True, if these are
    # the fields to defer, or False if these are the only fields to load.
    deferred_loading = (frozenset(), True)

    explain_info = None
269
    def __init__(self, model: type[Model] | None, alias_cols: bool = True):
        """Initialize the per-instance query state.

        ``model`` may be None; several operations later assert a model is
        present. ``alias_cols`` controls whether resolved column references
        carry a table alias.
        """
        self.model = model
        self.alias_refcount = {}
        # alias_map is the most important data structure regarding joins.
        # It's used for recording which joins exist in the query and what
        # types they are. The key is the alias of the joined table (possibly
        # the table name) and the value is a Join-like object (see
        # sql.datastructures.Join for more information).
        self.alias_map = {}
        # Whether to provide alias to columns during reference resolving.
        self.alias_cols = alias_cols
        # Sometimes the query contains references to aliases in outer queries (as
        # a result of split_exclude). Correct alias quoting needs to know these
        # aliases too.
        # Map external tables to whether they are aliased.
        self.external_aliases = {}
        self.table_map = {}  # Maps table names to list of aliases.
        self.used_aliases = set()

        self.where = WhereNode()
        # Maps alias -> Annotation Expression.
        self.annotations = {}
        # These are for extensions. The contents are more or less appended
        # verbatim to the appropriate clause.
        self.extra = {}  # Maps col_alias -> (col_sql, params).

        self._filtered_relations = {}
297
298 @property
299 def output_field(self) -> Field | None:
300 if len(self.select) == 1:
301 select = self.select[0]
302 return getattr(select, "target", None) or select.field
303 elif len(self.annotation_select) == 1:
304 return next(iter(self.annotation_select.values())).output_field
305
306 @cached_property
307 def base_table(self) -> str | None:
308 for alias in self.alias_map:
309 return alias
310
311 def __str__(self) -> str:
312 """
313 Return the query as a string of SQL with the parameter values
314 substituted in (use sql_with_params() to see the unsubstituted string).
315
316 Parameter values won't necessarily be quoted correctly, since that is
317 done by the database interface at execution time.
318 """
319 sql, params = self.sql_with_params()
320 return sql % params
321
322 def sql_with_params(self) -> SqlWithParams:
323 """
324 Return the query as an SQL string and the parameters that will be
325 substituted into the query.
326 """
327 return self.get_compiler().as_sql()
328
329 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
330 """Limit the amount of work when a Query is deepcopied."""
331 result = self.clone()
332 memo[id(self)] = result
333 return result
334
335 def get_compiler(self, *, elide_empty: bool = True) -> SQLCompiler:
336 """Return a compiler instance for this query."""
337 # Import compilers here to avoid circular imports at module load time
338 from plain.postgres.sql.compiler import SQLCompiler as Compiler
339
340 return Compiler(self, get_connection(), elide_empty)
341
    def clone(self) -> Self:
        """
        Return a copy of the current Query. A lightweight alternative to
        deepcopy(): mutable containers are shallow-copied, the where tree is
        cloned, and caches are reset.
        """
        obj = Empty()
        obj.__class__ = self.__class__
        obj = cast(Self, obj)  # Type checker doesn't understand __class__ reassignment
        # Copy references to everything.
        obj.__dict__ = self.__dict__.copy()
        # Clone attributes that can't use shallow copy.
        obj.alias_refcount = self.alias_refcount.copy()
        obj.alias_map = self.alias_map.copy()
        obj.external_aliases = self.external_aliases.copy()
        obj.table_map = self.table_map.copy()
        obj.where = self.where.clone()
        obj.annotations = self.annotations.copy()
        if self.annotation_select_mask is not None:
            obj.annotation_select_mask = self.annotation_select_mask.copy()
        # _annotation_select_cache cannot be copied, as doing so breaks the
        # (necessary) state in which both annotations and
        # _annotation_select_cache point to the same underlying objects.
        # It will get re-populated in the cloned queryset the next time it's
        # used.
        obj._annotation_select_cache = None
        obj.extra = self.extra.copy()
        if self.extra_select_mask is not None:
            obj.extra_select_mask = self.extra_select_mask.copy()
        if self._extra_select_cache is not None:
            obj._extra_select_cache = self._extra_select_cache.copy()
        if self.select_related is not False:
            # Use deepcopy because select_related stores fields in nested
            # dicts.
            obj.select_related = copy.deepcopy(obj.select_related)
        if "subq_aliases" in self.__dict__:
            obj.subq_aliases = self.subq_aliases.copy()
        obj.used_aliases = self.used_aliases.copy()
        obj._filtered_relations = self._filtered_relations.copy()
        # Clear the cached_property, if it exists.
        obj.__dict__.pop("base_table", None)
        return obj
383
384 @overload
385 def chain(self, klass: None = None) -> Self: ...
386
387 @overload
388 def chain(self, klass: type[QueryType]) -> QueryType: ...
389
390 def chain(self, klass: type[Query] | None = None) -> Query:
391 """
392 Return a copy of the current Query that's ready for another operation.
393 The klass argument changes the type of the Query, e.g. UpdateQuery.
394 """
395 obj = self.clone()
396 if klass and obj.__class__ != klass:
397 obj.__class__ = klass
398 if not obj.filter_is_sticky:
399 obj.used_aliases = set()
400 obj.filter_is_sticky = False
401 if hasattr(obj, "_setup_query"):
402 obj._setup_query() # type: ignore[operator]
403 return obj
404
405 def relabeled_clone(self, change_map: dict[str, str]) -> Self:
406 clone = self.clone()
407 clone.change_aliases(change_map)
408 return clone
409
410 def _get_col(self, target: Any, field: Field, alias: str | None) -> Col:
411 if not self.alias_cols:
412 alias = None
413 return target.get_col(alias, field)
414
    def get_aggregation(self, aggregate_exprs: dict[str, Any]) -> dict[str, Any]:
        """
        Return the dictionary with the values of the existing aggregations.

        ``aggregate_exprs`` maps output alias -> unresolved aggregate
        expression; each must resolve to an expression with
        ``contains_aggregate`` set or TypeError is raised. The query is
        executed immediately and a dict mapping each alias to its computed
        value is returned.
        """
        if not aggregate_exprs:
            return {}
        aggregates = {}
        for alias, aggregate_expr in aggregate_exprs.items():
            self.check_alias(alias)
            aggregate = aggregate_expr.resolve_expression(
                self, allow_joins=True, reuse=None, summarize=True
            )
            if not aggregate.contains_aggregate:
                raise TypeError(f"{alias} is not an aggregate expression")
            aggregates[alias] = aggregate
        # Existing usage of aggregation can be determined by the presence of
        # selected aggregates but also by filters against aliased aggregates.
        _, having, qualify = self.where.split_having_qualify()
        has_existing_aggregation = (
            any(
                getattr(annotation, "contains_aggregate", True)
                for annotation in self.annotations.values()
            )
            or having
        )
        # Decide if we need to use a subquery.
        #
        # Existing aggregations would cause incorrect results as
        # get_aggregation() must produce just one result and thus must not use
        # GROUP BY.
        #
        # If the query has limit or distinct, or uses set operations, then
        # those operations must be done in a subquery so that the query
        # aggregates on the limit and/or distinct results instead of applying
        # the distinct and limit after the aggregation.
        if (
            isinstance(self.group_by, tuple)
            or self.is_sliced
            or has_existing_aggregation
            or qualify
            or self.distinct
        ):
            inner_query = self.clone()
            inner_query.subquery = True
            outer_query = AggregateQuery(self.model, inner_query)
            inner_query.select_for_update = False
            inner_query.select_related = False
            inner_query.set_annotation_mask(self.annotation_select)
            # Queries with distinct_fields need ordering and when a limit is
            # applied we must take the slice from the ordered query. Otherwise
            # no need for ordering.
            inner_query.clear_ordering(force=False)
            if not inner_query.distinct:
                # If the inner query uses default select and it has some
                # aggregate annotations, then we must make sure the inner
                # query is grouped by the main model's primary key. However,
                # clearing the select clause can alter results if distinct is
                # used.
                if inner_query.default_cols and has_existing_aggregation:
                    assert self.model is not None, "Aggregation requires a model"
                    inner_query.group_by = (
                        self.model._model_meta.get_forward_field("id").get_col(
                            inner_query.get_initial_alias()
                        ),
                    )
                inner_query.default_cols = False
                if not qualify:
                    # Mask existing annotations that are not referenced by
                    # aggregates to be pushed to the outer query unless
                    # filtering against window functions is involved as it
                    # requires complex realising.
                    annotation_mask = set()
                    for aggregate in aggregates.values():
                        annotation_mask |= aggregate.get_refs()
                    inner_query.set_annotation_mask(annotation_mask)

            # Add aggregates to the outer AggregateQuery. This requires making
            # sure all columns referenced by the aggregates are selected in the
            # inner query. It is achieved by retrieving all column references
            # by the aggregates, explicitly selecting them in the inner query,
            # and making sure the aggregates are repointed to them.
            col_refs = {}
            for alias, aggregate in aggregates.items():
                replacements = {}
                for col in self._gen_cols([aggregate], resolve_refs=False):
                    if not (col_ref := col_refs.get(col)):
                        index = len(col_refs) + 1
                        col_alias = f"__col{index}"
                        col_ref = Ref(col_alias, col)
                        col_refs[col] = col_ref
                        inner_query.annotations[col_alias] = col
                        inner_query.append_annotation_mask([col_alias])
                    replacements[col] = col_ref
                outer_query.annotations[alias] = aggregate.replace_expressions(
                    replacements
                )
            if (
                inner_query.select == ()
                and not inner_query.default_cols
                and not inner_query.annotation_select_mask
            ):
                # In case of Model.objects[0:3].count(), there would be no
                # field selected in the inner query, yet we must use a subquery.
                # So, make sure at least one field is selected.
                assert self.model is not None, "Count with slicing requires a model"
                inner_query.select = (
                    self.model._model_meta.get_forward_field("id").get_col(
                        inner_query.get_initial_alias()
                    ),
                )
        else:
            # No subquery needed: aggregate directly over this query.
            outer_query = self
            self.select = ()
            self.default_cols = False
            self.extra = {}
            if self.annotations:
                # Inline reference to existing annotations and mask them as
                # they are unnecessary given only the summarized aggregations
                # are requested.
                replacements = {
                    Ref(alias, annotation): annotation
                    for alias, annotation in self.annotations.items()
                }
                self.annotations = {
                    alias: aggregate.replace_expressions(replacements)
                    for alias, aggregate in aggregates.items()
                }
            else:
                self.annotations = aggregates
            self.set_annotation_mask(aggregates)

        empty_set_result = [
            expression.empty_result_set_value
            for expression in outer_query.annotation_select.values()
        ]
        elide_empty = not any(result is NotImplemented for result in empty_set_result)
        outer_query.clear_ordering(force=True)
        outer_query.clear_limits()
        outer_query.select_for_update = False
        outer_query.select_related = False
        compiler = outer_query.get_compiler(elide_empty=elide_empty)
        result = compiler.execute_sql(SINGLE)
        if result is None:
            # No row came back; fall back to each expression's empty value.
            result = empty_set_result
        else:
            from plain.postgres.sql.compiler import apply_converters, get_converters

            converters = get_converters(
                outer_query.annotation_select.values(), compiler.connection
            )
            result = next(apply_converters((result,), converters, compiler.connection))

        return dict(zip(outer_query.annotation_select, result))
568
569 def get_count(self) -> int:
570 """
571 Perform a COUNT() query using the current filter constraints.
572 """
573 obj = self.clone()
574 return obj.get_aggregation({"__count": Count("*")})["__count"]
575
576 def has_filters(self) -> bool:
577 return bool(self.where)
578
579 def exists(self, limit: bool = True) -> Self:
580 q = self.clone()
581 if not (q.distinct and q.is_sliced):
582 if q.group_by is True:
583 assert self.model is not None, "GROUP BY requires a model"
584 q.add_fields(
585 (f.attname for f in self.model._model_meta.concrete_fields), False
586 )
587 # Disable GROUP BY aliases to avoid orphaning references to the
588 # SELECT clause which is about to be cleared.
589 q.set_group_by(allow_aliases=False)
590 q.clear_select_clause()
591 q.clear_ordering(force=True)
592 if limit:
593 q.set_limits(high=1)
594 q.add_annotation(Value(1), "a")
595 return q
596
597 def has_results(self) -> bool:
598 q = self.exists()
599 compiler = q.get_compiler()
600 return compiler.has_results()
601
602 def explain(self, format: str | None = None, **options: Any) -> str:
603 q = self.clone()
604 for option_name in options:
605 if (
606 not EXPLAIN_OPTIONS_PATTERN.fullmatch(option_name)
607 or "--" in option_name
608 ):
609 raise ValueError(f"Invalid option name: {option_name!r}.")
610 q.explain_info = ExplainInfo(format, options)
611 compiler = q.get_compiler()
612 return "\n".join(compiler.explain_query())
613
    def combine(self, rhs: Query, connector: str) -> None:
        """
        Merge the 'rhs' query into the current one (with any 'rhs' effects
        being applied *after* (that is, "to the right of") anything in the
        current query. 'rhs' is not modified during a call to this function.

        The 'connector' parameter describes how to connect filters from the
        'rhs' query.

        Raises TypeError when the two queries are structurally incompatible
        (different models, slicing taken, or mismatched distinct settings)
        and ValueError when OR-merging queries that both define
        extra(select=...).
        """
        if self.model != rhs.model:
            raise TypeError("Cannot combine queries on two different base models.")
        if self.is_sliced:
            raise TypeError("Cannot combine queries once a slice has been taken.")
        if self.distinct != rhs.distinct:
            raise TypeError("Cannot combine a unique query with a non-unique query.")
        if self.distinct_fields != rhs.distinct_fields:
            raise TypeError("Cannot combine queries with different distinct fields.")

        # If lhs and rhs shares the same alias prefix, it is possible to have
        # conflicting alias changes like T4 -> T5, T5 -> T6, which might end up
        # as T4 -> T6 while combining two querysets. To prevent this, change an
        # alias prefix of the rhs and update current aliases accordingly,
        # except if the alias is the base table since it must be present in the
        # query on both sides.
        initial_alias = self.get_initial_alias()
        assert initial_alias is not None
        rhs.bump_prefix(self, exclude={initial_alias})

        # Work out how to relabel the rhs aliases, if necessary.
        change_map = {}
        conjunction = connector == AND

        # Determine which existing joins can be reused. When combining the
        # query with AND we must recreate all joins for m2m filters. When
        # combining with OR we can reuse joins. The reason is that in AND
        # case a single row can't fulfill a condition like:
        #     revrel__col=1 & revrel__col=2
        # But, there might be two different related rows matching this
        # condition. In OR case a single True is enough, so single row is
        # enough, too.
        #
        # Note that we will be creating duplicate joins for non-m2m joins in
        # the AND case. The results will be correct but this creates too many
        # joins. This is something that could be fixed later on.
        reuse = set() if conjunction else set(self.alias_map)
        joinpromoter = JoinPromoter(connector, 2, False)
        joinpromoter.add_votes(
            j for j in self.alias_map if self.alias_map[j].join_type == INNER
        )
        rhs_votes = set()
        # Now, add the joins from rhs query into the new query (skipping base
        # table).
        rhs_tables = list(rhs.alias_map)[1:]
        for alias in rhs_tables:
            join = rhs.alias_map[alias]
            # If the left side of the join was already relabeled, use the
            # updated alias.
            join = join.relabeled_clone(change_map)
            new_alias = self.join(join, reuse=reuse)
            if join.join_type == INNER:
                rhs_votes.add(new_alias)
            # We can't reuse the same join again in the query. If we have two
            # distinct joins for the same connection in rhs query, then the
            # combined query must have two joins, too.
            reuse.discard(new_alias)
            if alias != new_alias:
                change_map[alias] = new_alias
            if not rhs.alias_refcount[alias]:
                # The alias was unused in the rhs query. Unref it so that it
                # will be unused in the new query, too. We have to add and
                # unref the alias so that join promotion has information of
                # the join type for the unused alias.
                self.unref_alias(new_alias)
        joinpromoter.add_votes(rhs_votes)
        joinpromoter.update_join_types(self)

        # Combine subqueries aliases to ensure aliases relabelling properly
        # handle subqueries when combining where and select clauses.
        self.subq_aliases |= rhs.subq_aliases

        # Now relabel a copy of the rhs where-clause and add it to the current
        # one.
        w = rhs.where.clone()
        w.relabel_aliases(change_map)
        self.where.add(w, connector)

        # Selection columns and extra extensions are those provided by 'rhs'.
        if rhs.select:
            self.set_select([col.relabeled_clone(change_map) for col in rhs.select])
        else:
            self.select = ()

        if connector == OR:
            # It would be nice to be able to handle this, but the queries don't
            # really make sense (or return consistent value sets). Not worth
            # the extra complexity when you can write a real query instead.
            if self.extra and rhs.extra:
                raise ValueError(
                    "When merging querysets using 'or', you cannot have "
                    "extra(select=...) on both sides."
                )
        self.extra.update(rhs.extra)
        extra_select_mask = set()
        if self.extra_select_mask is not None:
            extra_select_mask.update(self.extra_select_mask)
        if rhs.extra_select_mask is not None:
            extra_select_mask.update(rhs.extra_select_mask)
        if extra_select_mask:
            self.set_extra_mask(extra_select_mask)
        self.extra_tables += rhs.extra_tables

        # Ordering uses the 'rhs' ordering, unless it has none, in which case
        # the current ordering is used.
        self.order_by = rhs.order_by or self.order_by
        self.extra_order_by = rhs.extra_order_by or self.extra_order_by
729
    def _get_defer_select_mask(
        self,
        meta: Meta,
        mask: dict[str, Any],
        select_mask: dict[Any, Any] | None = None,
    ) -> dict[Any, Any]:
        """
        Build the select mask for defer(): every concrete field NOT named in
        ``mask`` is loaded, masked relational fields recurse into their
        related model, and leftover mask entries must name reverse (or
        filtered) relations. Mutates ``mask`` (entries are popped) and
        ``select_mask`` (when given); returns the populated select mask.
        """
        from plain.postgres.fields.related import RelatedField

        if select_mask is None:
            select_mask = {}
        # The primary key is always loaded.
        select_mask[meta.get_forward_field("id")] = {}
        # All concrete fields that are not part of the defer mask must be
        # loaded. If a relational field is encountered it gets added to the
        # mask for it be considered if `select_related` and the cycle continues
        # by recursively calling this function.
        for field in meta.concrete_fields:
            field_mask = mask.pop(field.name, None)
            if field_mask is None:
                select_mask.setdefault(field, {})
            elif field_mask:
                if not isinstance(field, RelatedField):
                    raise FieldError(next(iter(field_mask)))
                field_select_mask = select_mask.setdefault(field, {})
                related_model = field.remote_field.model
                self._get_defer_select_mask(
                    related_model._model_meta, field_mask, field_select_mask
                )
        # Remaining defer entries must be references to reverse relationships.
        # The following code is expected to raise FieldError if it encounters
        # a malformed defer entry.
        for field_name, field_mask in mask.items():
            if filtered_relation := self._filtered_relations.get(field_name):
                relation = meta.get_reverse_relation(filtered_relation.relation_name)
                field_select_mask = select_mask.setdefault((field_name, relation), {})
                field = relation.field
            else:
                field = meta.get_reverse_relation(field_name).field
                field_select_mask = select_mask.setdefault(field, {})
            related_model = field.model
            self._get_defer_select_mask(
                related_model._model_meta, field_mask, field_select_mask
            )
        return select_mask
773
774 def _get_only_select_mask(
775 self,
776 meta: Meta,
777 mask: dict[str, Any],
778 select_mask: dict[Any, Any] | None = None,
779 ) -> dict[Any, Any]:
780 from plain.postgres.fields.related import RelatedField
781
782 if select_mask is None:
783 select_mask = {}
784 select_mask[meta.get_forward_field("id")] = {}
785 # Only include fields mentioned in the mask.
786 for field_name, field_mask in mask.items():
787 field = meta.get_field(field_name)
788 field_select_mask = select_mask.setdefault(field, {})
789 if field_mask:
790 if not isinstance(field, RelatedField):
791 raise FieldError(next(iter(field_mask)))
792 related_model = field.remote_field.model
793 self._get_only_select_mask(
794 related_model._model_meta, field_mask, field_select_mask
795 )
796 return select_mask
797
798 def get_select_mask(self) -> dict[Any, Any]:
799 """
800 Convert the self.deferred_loading data structure to an alternate data
801 structure, describing the field that *will* be loaded. This is used to
802 compute the columns to select from the database and also by the
803 QuerySet class to work out which fields are being initialized on each
804 model. Models that have all their fields included aren't mentioned in
805 the result, only those that have field restrictions in place.
806 """
807 field_names, defer = self.deferred_loading
808 if not field_names:
809 return {}
810 mask = {}
811 for field_name in field_names:
812 part_mask = mask
813 for part in field_name.split(LOOKUP_SEP):
814 part_mask = part_mask.setdefault(part, {})
815 assert self.model is not None, "Deferred/only field loading requires a model"
816 meta = self.model._model_meta
817 if defer:
818 return self._get_defer_select_mask(meta, mask)
819 return self._get_only_select_mask(meta, mask)
820
821 def table_alias(
822 self, table_name: str, create: bool = False, filtered_relation: Any = None
823 ) -> tuple[str, bool]:
824 """
825 Return a table alias for the given table_name and whether this is a
826 new alias or not.
827
828 If 'create' is true, a new alias is always created. Otherwise, the
829 most recently created alias for the table (if one exists) is reused.
830 """
831 alias_list = self.table_map.get(table_name)
832 if not create and alias_list:
833 alias = alias_list[0]
834 self.alias_refcount[alias] += 1
835 return alias, False
836
837 # Create a new alias for this table.
838 if alias_list:
839 alias = "%s%d" % (self.alias_prefix, len(self.alias_map) + 1) # noqa: UP031
840 alias_list.append(alias)
841 else:
842 # The first occurrence of a table uses the table name directly.
843 alias = (
844 filtered_relation.alias if filtered_relation is not None else table_name
845 )
846 self.table_map[table_name] = [alias]
847 self.alias_refcount[alias] = 1
848 return alias, True
849
    def ref_alias(self, alias: str) -> None:
        """Increase the reference count for this alias by one."""
        self.alias_refcount[alias] += 1
853
    def unref_alias(self, alias: str, amount: int = 1) -> None:
        """Decrease the reference count for this alias by *amount* (default 1)."""
        self.alias_refcount[alias] -= amount
857
858 def promote_joins(self, aliases: set[str] | list[str]) -> None:
859 """
860 Promote recursively the join type of given aliases and its children to
861 an outer join. If 'unconditional' is False, only promote the join if
862 it is nullable or the parent join is an outer join.
863
864 The children promotion is done to avoid join chains that contain a LOUTER
865 b INNER c. So, if we have currently a INNER b INNER c and a->b is promoted,
866 then we must also promote b->c automatically, or otherwise the promotion
867 of a->b doesn't actually change anything in the query results.
868 """
869 aliases = list(aliases)
870 while aliases:
871 alias = aliases.pop(0)
872 if self.alias_map[alias].join_type is None:
873 # This is the base table (first FROM entry) - this table
874 # isn't really joined at all in the query, so we should not
875 # alter its join type.
876 continue
877 # Only the first alias (skipped above) should have None join_type
878 assert self.alias_map[alias].join_type is not None
879 parent_alias = self.alias_map[alias].parent_alias
880 parent_louter = (
881 parent_alias and self.alias_map[parent_alias].join_type == LOUTER
882 )
883 already_louter = self.alias_map[alias].join_type == LOUTER
884 if (self.alias_map[alias].nullable or parent_louter) and not already_louter:
885 self.alias_map[alias] = self.alias_map[alias].promote()
886 # Join type of 'alias' changed, so re-examine all aliases that
887 # refer to this one.
888 aliases.extend(
889 join
890 for join in self.alias_map
891 if self.alias_map[join].parent_alias == alias
892 and join not in aliases
893 )
894
895 def demote_joins(self, aliases: set[str] | list[str]) -> None:
896 """
897 Change join type from LOUTER to INNER for all joins in aliases.
898
899 Similarly to promote_joins(), this method must ensure no join chains
900 containing first an outer, then an inner join are generated. If we
901 are demoting b->c join in chain a LOUTER b LOUTER c then we must
902 demote a->b automatically, or otherwise the demotion of b->c doesn't
903 actually change anything in the query results. .
904 """
905 aliases = list(aliases)
906 while aliases:
907 alias = aliases.pop(0)
908 if self.alias_map[alias].join_type == LOUTER:
909 self.alias_map[alias] = self.alias_map[alias].demote()
910 parent_alias = self.alias_map[alias].parent_alias
911 if self.alias_map[parent_alias].join_type == INNER:
912 aliases.append(parent_alias)
913
914 def reset_refcounts(self, to_counts: dict[str, int]) -> None:
915 """
916 Reset reference counts for aliases so that they match the value passed
917 in `to_counts`.
918 """
919 for alias, cur_refcount in self.alias_refcount.copy().items():
920 unref_amount = cur_refcount - to_counts.get(alias, 0)
921 self.unref_alias(alias, unref_amount)
922
    def change_aliases(self, change_map: dict[str, str]) -> None:
        """
        Change the aliases in change_map (which maps old-alias -> new-alias),
        relabelling any references to them in select columns and the where
        clause.
        """
        # If keys and values of change_map were to intersect, an alias might be
        # updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
        # on their order in change_map.
        assert set(change_map).isdisjoint(change_map.values())

        # 1. Update references in "select" (normal columns plus aliases),
        # "group by" and "where".
        self.where.relabel_aliases(change_map)
        if isinstance(self.group_by, tuple):
            self.group_by = tuple(
                [col.relabeled_clone(change_map) for col in self.group_by]
            )
        self.select = tuple([col.relabeled_clone(change_map) for col in self.select])
        # self.annotations may be falsy (e.g. empty); only rebuild when present.
        self.annotations = self.annotations and {
            key: col.relabeled_clone(change_map)
            for key, col in self.annotations.items()
        }

        # 2. Rename the alias in the internal table/alias datastructures.
        for old_alias, new_alias in change_map.items():
            if old_alias not in self.alias_map:
                continue
            alias_data = self.alias_map[old_alias].relabeled_clone(change_map)
            self.alias_map[new_alias] = alias_data
            self.alias_refcount[new_alias] = self.alias_refcount[old_alias]
            del self.alias_refcount[old_alias]
            del self.alias_map[old_alias]

            # Keep table_map in sync: swap the old alias for the new one in
            # the list of aliases recorded for this table.
            table_aliases = self.table_map[alias_data.table_name]
            for pos, alias in enumerate(table_aliases):
                if alias == old_alias:
                    table_aliases[pos] = new_alias
                    break
        # 3. Rename external aliases, marking a renamed alias as aliased.
        self.external_aliases = {
            # Table is aliased or it's being changed and thus is aliased.
            change_map.get(alias, alias): (aliased or alias in change_map)
            for alias, aliased in self.external_aliases.items()
        }
967
    def bump_prefix(
        self, other_query: Query, exclude: set[str] | dict[str, str] | None = None
    ) -> None:
        """
        Change the alias prefix to the next letter in the alphabet in a way
        that the other query's aliases and this query's aliases will not
        conflict. Even tables that previously had no alias will get an alias
        after this call. To prevent changing aliases use the exclude parameter.
        """

        def prefix_gen() -> TypingIterator[str]:
            """
            Generate a sequence of characters in alphabetical order:
            -> 'A', 'B', 'C', ...

            When the alphabet is finished, the sequence will continue with the
            Cartesian product:
            -> 'AA', 'AB', 'AC', ...
            """
            alphabet = ascii_uppercase
            # Start one letter past the current prefix so the first candidate
            # is already distinct from self.alias_prefix.
            prefix = chr(ord(self.alias_prefix) + 1)
            yield prefix
            for n in count(1):
                seq = alphabet[alphabet.index(prefix) :] if prefix else alphabet
                for s in product(seq, repeat=n):
                    yield "".join(s)
                prefix = None

        if self.alias_prefix != other_query.alias_prefix:
            # No clashes between self and outer query should be possible.
            return

        # Explicitly avoid infinite loop. The constant divider is based on how
        # much depth recursive subquery references add to the stack. This value
        # might need to be adjusted when adding or removing function calls from
        # the code path in charge of performing these operations.
        local_recursion_limit = sys.getrecursionlimit() // 16
        for pos, prefix in enumerate(prefix_gen()):
            if prefix not in self.subq_aliases:
                self.alias_prefix = prefix
                break
            if pos > local_recursion_limit:
                raise RecursionError(
                    "Maximum recursion depth exceeded: too many subqueries."
                )
        # Record the chosen prefix on both queries so neither reuses it.
        self.subq_aliases = self.subq_aliases.union([self.alias_prefix])
        other_query.subq_aliases = other_query.subq_aliases.union(self.subq_aliases)
        if exclude is None:
            exclude = {}
        # Renumber every alias under the new prefix (membership test works
        # for both the set and dict forms of `exclude`).
        self.change_aliases(
            {
                alias: "%s%d" % (self.alias_prefix, pos)  # noqa: UP031
                for pos, alias in enumerate(self.alias_map)
                if alias not in exclude
            }
        )
1024
1025 def get_initial_alias(self) -> str | None:
1026 """
1027 Return the first alias for this query, after increasing its reference
1028 count.
1029 """
1030 if self.alias_map:
1031 alias = self.base_table
1032 self.ref_alias(alias) # type: ignore[invalid-argument-type]
1033 elif self.model:
1034 alias = self.join(
1035 self.base_table_class(self.model.model_options.db_table, None) # type: ignore[invalid-argument-type]
1036 )
1037 else:
1038 alias = None
1039 return alias
1040
1041 def count_active_tables(self) -> int:
1042 """
1043 Return the number of tables in this query with a non-zero reference
1044 count. After execution, the reference counts are zeroed, so tables
1045 added in compiler will not be seen by this method.
1046 """
1047 return len([1 for count in self.alias_refcount.values() if count])
1048
    def join(
        self,
        join: BaseTable | Join,
        reuse: set[str] | None = None,
        reuse_with_filtered_relation: bool = False,
    ) -> str:
        """
        Return an alias for the 'join', either reusing an existing alias for
        that join or creating a new one. 'join' is either a base_table_class or
        join_class.

        The 'reuse' parameter can be either None which means all joins are
        reusable, or it can be a set containing the aliases that can be reused.

        The 'reuse_with_filtered_relation' parameter is used when computing
        FilteredRelation instances.

        A join is always created as LOUTER if the lhs alias is LOUTER to make
        sure chains like t1 LOUTER t2 INNER t3 aren't generated. All new
        joins are created as LOUTER if the join is nullable.
        """
        # Filtered relations compare joins with equals() (which ignores the
        # filtered condition); normal joins use full equality.
        if reuse_with_filtered_relation and reuse:
            reuse_aliases = [
                a for a, j in self.alias_map.items() if a in reuse and j.equals(join)
            ]
        else:
            reuse_aliases = [
                a
                for a, j in self.alias_map.items()
                if (reuse is None or a in reuse) and j == join
            ]
        if reuse_aliases:
            if join.table_alias in reuse_aliases:
                reuse_alias = join.table_alias
            else:
                # Reuse the most recent alias of the joined table
                # (a many-to-many relation may be joined multiple times).
                reuse_alias = reuse_aliases[-1]
            self.ref_alias(reuse_alias)
            return reuse_alias

        # No reuse is possible, so we need a new alias.
        alias, _ = self.table_alias(
            join.table_name, create=True, filtered_relation=join.filtered_relation
        )
        if isinstance(join, Join):
            # Propagate LOUTER from the parent, and force LOUTER for nullable
            # joins; everything else stays INNER.
            if self.alias_map[join.parent_alias].join_type == LOUTER or join.nullable:
                join_type = LOUTER
            else:
                join_type = INNER
            join.join_type = join_type
        join.table_alias = alias
        self.alias_map[alias] = join
        return alias
1103
1104 def check_alias(self, alias: str) -> None:
1105 if FORBIDDEN_ALIAS_PATTERN.search(alias):
1106 raise ValueError(
1107 "Column aliases cannot contain whitespace characters, quotation marks, "
1108 "semicolons, or SQL comments."
1109 )
1110
1111 def add_annotation(
1112 self, annotation: BaseExpression, alias: str, select: bool = True
1113 ) -> None:
1114 """Add a single annotation expression to the Query."""
1115 self.check_alias(alias)
1116 annotation = annotation.resolve_expression(self, allow_joins=True, reuse=None)
1117 if select:
1118 self.append_annotation_mask([alias])
1119 else:
1120 self.set_annotation_mask(set(self.annotation_select).difference({alias}))
1121 self.annotations[alias] = annotation
1122
    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Resolve this query for use as a subquery expression of ``query``.

        Return a clone marked as a subquery, with its where clause and
        annotations resolved against the outer query and its external_aliases
        recording which outer-query aliases it references.
        """
        clone = self.clone()
        # Subqueries need to use a different set of aliases than the outer query.
        clone.bump_prefix(query)
        clone.subquery = True
        clone.where.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        for key, value in clone.annotations.items():
            resolved = value.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
            # Resolved expressions that track aliases themselves must also
            # know about this clone's external aliases.
            if hasattr(resolved, "external_aliases"):
                resolved.external_aliases.update(clone.external_aliases)
            clone.annotations[key] = resolved
        # Outer query's aliases are considered external.
        for alias, table in query.alias_map.items():
            clone.external_aliases[alias] = (
                isinstance(table, Join)
                and table.join_field.related_model.model_options.db_table != alias
            ) or (
                isinstance(table, BaseTable) and table.table_name != table.table_alias
            )
        return clone
1152
1153 def get_external_cols(self) -> list[Col]:
1154 exprs = chain(self.annotations.values(), self.where.children)
1155 return [
1156 col
1157 for col in self._gen_cols(exprs, include_external=True)
1158 if col.alias in self.external_aliases
1159 ]
1160
1161 def get_group_by_cols(
1162 self, wrapper: BaseExpression | None = None
1163 ) -> list[BaseExpression]:
1164 # If wrapper is referenced by an alias for an explicit GROUP BY through
1165 # values() a reference to this expression and not the self must be
1166 # returned to ensure external column references are not grouped against
1167 # as well.
1168 external_cols = self.get_external_cols()
1169 if any(col.possibly_multivalued for col in external_cols):
1170 return [wrapper or self]
1171 # Cast needed because list is invariant: list[Col] is not list[BaseExpression]
1172 return cast(list[BaseExpression], external_cols)
1173
1174 def as_sql(
1175 self, compiler: SQLCompiler, connection: DatabaseConnection
1176 ) -> SqlWithParams:
1177 sql, params = self.get_compiler().as_sql()
1178 if self.subquery:
1179 sql = f"({sql})"
1180 return sql, params
1181
1182 def resolve_lookup_value(
1183 self, value: Any, can_reuse: set[str] | None, allow_joins: bool
1184 ) -> Any:
1185 if isinstance(value, ResolvableExpression):
1186 value = value.resolve_expression(
1187 self,
1188 reuse=can_reuse,
1189 allow_joins=allow_joins,
1190 )
1191 elif isinstance(value, list | tuple):
1192 # The items of the iterable may be expressions and therefore need
1193 # to be resolved independently.
1194 values = (
1195 self.resolve_lookup_value(sub_value, can_reuse, allow_joins)
1196 for sub_value in value
1197 )
1198 type_ = type(value)
1199 if hasattr(type_, "_make"): # namedtuple
1200 return type_(*values)
1201 return type_(values)
1202 return value
1203
1204 def solve_lookup_type(
1205 self, lookup: str, summarize: bool = False
1206 ) -> tuple[
1207 list[str] | tuple[str, ...], tuple[str, ...], BaseExpression | Literal[False]
1208 ]:
1209 """
1210 Solve the lookup type from the lookup (e.g.: 'foobar__id__icontains').
1211 """
1212 lookup_splitted = lookup.split(LOOKUP_SEP)
1213 if self.annotations:
1214 annotation, expression_lookups = refs_expression(
1215 lookup_splitted, self.annotations
1216 )
1217 if annotation:
1218 expression = self.annotations[annotation]
1219 if summarize:
1220 expression = Ref(annotation, expression)
1221 return expression_lookups, (), expression
1222 assert self.model is not None, "Field lookups require a model"
1223 meta = self.model._model_meta
1224 _, field, _, lookup_parts = self.names_to_path(lookup_splitted, meta)
1225 field_parts = lookup_splitted[0 : len(lookup_splitted) - len(lookup_parts)]
1226 if len(lookup_parts) > 1 and not field_parts:
1227 raise FieldError(
1228 f'Invalid lookup "{lookup}" for model {meta.model.__name__}".'
1229 )
1230 return lookup_parts, tuple(field_parts), False
1231
1232 def check_query_object_type(
1233 self, value: Any, meta: Meta, field: Field | ForeignObjectRel
1234 ) -> None:
1235 """
1236 Check whether the object passed while querying is of the correct type.
1237 If not, raise a ValueError specifying the wrong object.
1238 """
1239 from plain.postgres import Model
1240
1241 if isinstance(value, Model):
1242 if not check_rel_lookup_compatibility(value._model_meta.model, meta, field):
1243 raise ValueError(
1244 f'Cannot query "{value}": Must be "{meta.model.model_options.object_name}" instance.'
1245 )
1246
1247 def check_related_objects(
1248 self, field: RelatedField | ForeignObjectRel, value: Any, meta: Meta
1249 ) -> None:
1250 """Check the type of object passed to query relations."""
1251 from plain.postgres import Model
1252
1253 # Check that the field and the queryset use the same model in a
1254 # query like .filter(author=Author.query.all()). For example, the
1255 # meta would be Author's (from the author field) and value.model
1256 # would be Author.query.all() queryset's .model (Author also).
1257 # The field is the related field on the lhs side.
1258 if (
1259 isinstance(value, Query)
1260 and not value.has_select_fields
1261 and not check_rel_lookup_compatibility(value.model, meta, field)
1262 ):
1263 raise ValueError(
1264 f'Cannot use QuerySet for "{value.model.model_options.object_name}": Use a QuerySet for "{meta.model.model_options.object_name}".'
1265 )
1266 elif isinstance(value, Model):
1267 self.check_query_object_type(value, meta, field)
1268 elif isinstance(value, Iterable):
1269 for v in value:
1270 self.check_query_object_type(v, meta, field)
1271
1272 def check_filterable(self, expression: Any) -> None:
1273 """Raise an error if expression cannot be used in a WHERE clause."""
1274 if isinstance(expression, ResolvableExpression) and not getattr(
1275 expression, "filterable", True
1276 ):
1277 raise NotSupportedError(
1278 expression.__class__.__name__ + " is disallowed in the filter clause."
1279 )
1280 if hasattr(expression, "get_source_expressions"):
1281 for expr in expression.get_source_expressions():
1282 self.check_filterable(expr)
1283
    def build_lookup(
        self, lookups: list[str], lhs: BaseExpression | MultiColSource, rhs: Any
    ) -> Lookup | None:
        """
        Try to extract transforms and lookup from given lhs.

        The lhs value is something that works like SQLExpression.
        The rhs value is what the lookup is going to compare against.
        The lookups is a list of names to extract using get_lookup()
        and get_transform().

        Return None when no lookup class can be resolved for the final name.
        """
        # __exact is the default lookup if one isn't given.
        *transforms, lookup_name = lookups or ["exact"]
        if transforms:
            if isinstance(lhs, MultiColSource):
                raise FieldError(
                    "Transforms are not supported on multi-column relations."
                )
            # At this point, lhs must be BaseExpression
            for name in transforms:
                lhs = self.try_transform(lhs, name)
        # First try get_lookup() so that the lookup takes precedence if the lhs
        # supports both transform and lookup for the name.
        lookup_class = lhs.get_lookup(lookup_name)
        if not lookup_class:
            # A lookup wasn't found. Try to interpret the name as a transform
            # and do an Exact lookup against it.
            if isinstance(lhs, MultiColSource):
                raise FieldError(
                    "Transforms are not supported on multi-column relations."
                )
            lhs = self.try_transform(lhs, lookup_name)
            lookup_name = "exact"
            lookup_class = lhs.get_lookup(lookup_name)
            if not lookup_class:
                return

        lookup = lookup_class(lhs, rhs)
        # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
        # uses of None as a query value unless the lookup supports it.
        if lookup.rhs is None and not lookup.can_use_none_as_rhs:
            if lookup_name not in ("exact", "iexact"):
                raise ValueError("Cannot use None as a query value")
            # Translate col=None into col IS NULL.
            isnull_lookup = lhs.get_lookup("isnull")
            assert isnull_lookup is not None
            return isnull_lookup(lhs, True)

        return lookup
1332
1333 def try_transform(self, lhs: BaseExpression, name: str) -> BaseExpression:
1334 """
1335 Helper method for build_lookup(). Try to fetch and initialize
1336 a transform for name parameter from lhs.
1337 """
1338 transform_class = lhs.get_transform(name)
1339 if transform_class:
1340 return transform_class(lhs)
1341 else:
1342 output_field = lhs.output_field.__class__
1343 suggested_lookups = difflib.get_close_matches(
1344 name, output_field.get_lookups()
1345 )
1346 if suggested_lookups:
1347 suggestion = ", perhaps you meant {}?".format(
1348 " or ".join(suggested_lookups)
1349 )
1350 else:
1351 suggestion = "."
1352 raise FieldError(
1353 f"Unsupported lookup '{name}' for {output_field.__name__} or join on the field not "
1354 f"permitted{suggestion}"
1355 )
1356
1357 def build_filter(
1358 self,
1359 filter_expr: tuple[str, Any] | Q | BaseExpression,
1360 branch_negated: bool = False,
1361 current_negated: bool = False,
1362 can_reuse: set[str] | None = None,
1363 allow_joins: bool = True,
1364 split_subq: bool = True,
1365 reuse_with_filtered_relation: bool = False,
1366 check_filterable: bool = True,
1367 summarize: bool = False,
1368 ) -> tuple[WhereNode, set[str] | tuple[()]]:
1369 from plain.postgres.fields.related import RelatedField
1370
1371 """
1372 Build a WhereNode for a single filter clause but don't add it
1373 to this Query. Query.add_q() will then add this filter to the where
1374 Node.
1375
1376 The 'branch_negated' tells us if the current branch contains any
1377 negations. This will be used to determine if subqueries are needed.
1378
1379 The 'current_negated' is used to determine if the current filter is
1380 negated or not and this will be used to determine if IS NULL filtering
1381 is needed.
1382
1383 The difference between current_negated and branch_negated is that
1384 branch_negated is set on first negation, but current_negated is
1385 flipped for each negation.
1386
1387 Note that add_filter will not do any negating itself, that is done
1388 upper in the code by add_q().
1389
1390 The 'can_reuse' is a set of reusable joins for multijoins.
1391
1392 If 'reuse_with_filtered_relation' is True, then only joins in can_reuse
1393 will be reused.
1394
1395 The method will create a filter clause that can be added to the current
1396 query. However, if the filter isn't added to the query then the caller
1397 is responsible for unreffing the joins used.
1398 """
1399 if isinstance(filter_expr, dict):
1400 raise FieldError("Cannot parse keyword query as dict")
1401 if isinstance(filter_expr, Q):
1402 return self._add_q(
1403 filter_expr,
1404 branch_negated=branch_negated,
1405 current_negated=current_negated,
1406 used_aliases=can_reuse,
1407 allow_joins=allow_joins,
1408 split_subq=split_subq,
1409 check_filterable=check_filterable,
1410 summarize=summarize,
1411 )
1412 if isinstance(filter_expr, ResolvableExpression):
1413 if not getattr(filter_expr, "conditional", False):
1414 raise TypeError("Cannot filter against a non-conditional expression.")
1415 condition = filter_expr.resolve_expression(
1416 self, allow_joins=allow_joins, summarize=summarize
1417 )
1418 if not isinstance(condition, Lookup):
1419 condition = self.build_lookup(["exact"], condition, True)
1420 return WhereNode([condition], connector=AND), set()
1421 if isinstance(filter_expr, BaseExpression):
1422 raise TypeError(f"Unexpected BaseExpression type: {type(filter_expr)}")
1423 arg, value = filter_expr
1424 if not arg:
1425 raise FieldError(f"Cannot parse keyword query {arg!r}")
1426 lookups, parts, reffed_expression = self.solve_lookup_type(arg, summarize)
1427
1428 if check_filterable:
1429 self.check_filterable(reffed_expression)
1430
1431 if not allow_joins and len(parts) > 1:
1432 raise FieldError("Joined field references are not permitted in this query")
1433
1434 pre_joins = self.alias_refcount.copy()
1435 value = self.resolve_lookup_value(value, can_reuse, allow_joins)
1436 used_joins = {
1437 k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)
1438 }
1439
1440 if check_filterable:
1441 self.check_filterable(value)
1442
1443 if reffed_expression:
1444 condition = self.build_lookup(list(lookups), reffed_expression, value)
1445 return WhereNode([condition], connector=AND), set()
1446
1447 assert self.model is not None, "Building filters requires a model"
1448 meta = self.model._model_meta
1449 alias = self.get_initial_alias()
1450 assert alias is not None
1451 allow_many = not branch_negated or not split_subq
1452
1453 try:
1454 join_info = self.setup_joins(
1455 list(parts),
1456 meta,
1457 alias,
1458 can_reuse=can_reuse,
1459 allow_many=allow_many,
1460 reuse_with_filtered_relation=reuse_with_filtered_relation,
1461 )
1462
1463 # Prevent iterator from being consumed by check_related_objects()
1464 if isinstance(value, Iterator):
1465 value = list(value)
1466 from plain.postgres.fields.related import RelatedField
1467 from plain.postgres.fields.reverse_related import ForeignObjectRel
1468
1469 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1470 self.check_related_objects(join_info.final_field, value, join_info.meta)
1471
1472 # split_exclude() needs to know which joins were generated for the
1473 # lookup parts
1474 self._lookup_joins = join_info.joins
1475 except MultiJoin as e:
1476 return self.split_exclude(
1477 filter_expr,
1478 can_reuse or set(),
1479 e.names_with_path,
1480 )
1481
1482 # Update used_joins before trimming since they are reused to determine
1483 # which joins could be later promoted to INNER.
1484 used_joins.update(join_info.joins)
1485 targets, alias, join_list = self.trim_joins(
1486 join_info.targets, join_info.joins, join_info.path
1487 )
1488 if can_reuse is not None:
1489 can_reuse.update(join_list)
1490
1491 if isinstance(join_info.final_field, RelatedField | ForeignObjectRel):
1492 if len(targets) == 1:
1493 col = self._get_col(targets[0], join_info.final_field, alias)
1494 else:
1495 col = MultiColSource(
1496 alias, targets, join_info.targets, join_info.final_field
1497 )
1498 else:
1499 col = self._get_col(targets[0], join_info.final_field, alias)
1500
1501 condition = self.build_lookup(list(lookups), col, value)
1502 assert condition is not None
1503 lookup_type = condition.lookup_name
1504 clause = WhereNode([condition], connector=AND)
1505
1506 require_outer = (
1507 lookup_type == "isnull" and condition.rhs is True and not current_negated
1508 )
1509 if (
1510 current_negated
1511 and (lookup_type != "isnull" or condition.rhs is False)
1512 and condition.rhs is not None
1513 ):
1514 require_outer = True
1515 if lookup_type != "isnull":
1516 # The condition added here will be SQL like this:
1517 # NOT (col IS NOT NULL), where the first NOT is added in
1518 # upper layers of code. The reason for addition is that if col
1519 # is null, then col != someval will result in SQL "unknown"
1520 # which isn't the same as in Python. The Python None handling
1521 # is wanted, and it can be gotten by
1522 # (col IS NULL OR col != someval)
1523 # <=>
1524 # NOT (col IS NOT NULL AND col = someval).
1525 if (
1526 self.is_nullable(targets[0])
1527 or self.alias_map[join_list[-1]].join_type == LOUTER
1528 ):
1529 lookup_class = targets[0].get_lookup("isnull")
1530 assert lookup_class is not None
1531 col = self._get_col(targets[0], join_info.targets[0], alias)
1532 clause.add(lookup_class(col, False), AND)
1533 # If someval is a nullable column, someval IS NOT NULL is
1534 # added.
1535 if isinstance(value, Col) and self.is_nullable(value.target):
1536 lookup_class = value.target.get_lookup("isnull")
1537 assert lookup_class is not None
1538 clause.add(lookup_class(value, False), AND)
1539 return clause, used_joins if not require_outer else ()
1540
1541 def add_filter(self, filter_lhs: str, filter_rhs: Any) -> None:
1542 self.add_q(Q((filter_lhs, filter_rhs)))
1543
1544 def add_q(self, q_object: Q) -> None:
1545 """
1546 A preprocessor for the internal _add_q(). Responsible for doing final
1547 join promotion.
1548 """
1549 # For join promotion this case is doing an AND for the added q_object
1550 # and existing conditions. So, any existing inner join forces the join
1551 # type to remain inner. Existing outer joins can however be demoted.
1552 # (Consider case where rel_a is LOUTER and rel_a__col=1 is added - if
1553 # rel_a doesn't produce any rows, then the whole condition must fail.
1554 # So, demotion is OK.
1555 existing_inner = {
1556 a for a in self.alias_map if self.alias_map[a].join_type == INNER
1557 }
1558 clause, _ = self._add_q(q_object, self.used_aliases)
1559 if clause:
1560 self.where.add(clause, AND)
1561 self.demote_joins(existing_inner)
1562
1563 def build_where(
1564 self, filter_expr: tuple[str, Any] | Q | BaseExpression
1565 ) -> WhereNode:
1566 return self.build_filter(filter_expr, allow_joins=False)[0]
1567
1568 def clear_where(self) -> None:
1569 self.where = WhereNode()
1570
    def _add_q(
        self,
        q_object: Q,
        used_aliases: set[str] | None,
        branch_negated: bool = False,
        current_negated: bool = False,
        allow_joins: bool = True,
        split_subq: bool = True,
        check_filterable: bool = True,
        summarize: bool = False,
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        Add a Q-object to the current filter.

        Return the built WhereNode and the joins that must stay INNER for the
        clause to be correct (as collected by the JoinPromoter).
        """
        connector = q_object.connector
        # branch_negated latches on the first negation; current_negated flips
        # at every negation level.
        current_negated ^= q_object.negated
        branch_negated = branch_negated or q_object.negated
        target_clause = WhereNode(connector=connector, negated=q_object.negated)
        joinpromoter = JoinPromoter(
            q_object.connector, len(q_object.children), current_negated
        )
        for child in q_object.children:
            child_clause, needed_inner = self.build_filter(
                child,
                can_reuse=used_aliases,
                branch_negated=branch_negated,
                current_negated=current_negated,
                allow_joins=allow_joins,
                split_subq=split_subq,
                check_filterable=check_filterable,
                summarize=summarize,
            )
            joinpromoter.add_votes(needed_inner)
            if child_clause:
                target_clause.add(child_clause, connector)
        # Resolve the collected votes into actual join-type changes.
        needed_inner = joinpromoter.update_join_types(self)
        return target_clause, needed_inner
1606
1607 def build_filtered_relation_q(
1608 self,
1609 q_object: Q,
1610 reuse: set[str],
1611 branch_negated: bool = False,
1612 current_negated: bool = False,
1613 ) -> WhereNode:
1614 """Add a FilteredRelation object to the current filter."""
1615 connector = q_object.connector
1616 current_negated ^= q_object.negated
1617 branch_negated = branch_negated or q_object.negated
1618 target_clause = WhereNode(connector=connector, negated=q_object.negated)
1619 for child in q_object.children:
1620 if isinstance(child, Node):
1621 child_clause = self.build_filtered_relation_q(
1622 child,
1623 reuse=reuse,
1624 branch_negated=branch_negated,
1625 current_negated=current_negated,
1626 )
1627 else:
1628 child_clause, _ = self.build_filter(
1629 child,
1630 can_reuse=reuse,
1631 branch_negated=branch_negated,
1632 current_negated=current_negated,
1633 allow_joins=True,
1634 split_subq=False,
1635 reuse_with_filtered_relation=True,
1636 )
1637 target_clause.add(child_clause, connector)
1638 return target_clause
1639
1640 def add_filtered_relation(self, filtered_relation: Any, alias: str) -> None:
1641 filtered_relation.alias = alias
1642 lookups = dict(get_children_from_q(filtered_relation.condition))
1643 relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(
1644 filtered_relation.relation_name
1645 )
1646 if relation_lookup_parts:
1647 raise ValueError(
1648 "FilteredRelation's relation_name cannot contain lookups "
1649 f"(got {filtered_relation.relation_name!r})."
1650 )
1651 for lookup in chain(lookups):
1652 lookup_parts, lookup_field_parts, _ = self.solve_lookup_type(lookup)
1653 shift = 2 if not lookup_parts else 1
1654 lookup_field_path = lookup_field_parts[:-shift]
1655 for idx, lookup_field_part in enumerate(lookup_field_path):
1656 if len(relation_field_parts) > idx:
1657 if relation_field_parts[idx] != lookup_field_part:
1658 raise ValueError(
1659 "FilteredRelation's condition doesn't support "
1660 f"relations outside the {filtered_relation.relation_name!r} (got {lookup!r})."
1661 )
1662 else:
1663 raise ValueError(
1664 "FilteredRelation's condition doesn't support nested "
1665 f"relations deeper than the relation_name (got {lookup!r} for "
1666 f"{filtered_relation.relation_name!r})."
1667 )
1668 self._filtered_relations[filtered_relation.alias] = filtered_relation
1669
    def names_to_path(
        self,
        names: list[str],
        meta: Meta,
        allow_many: bool = True,
        fail_on_missing: bool = False,
    ) -> tuple[list[Any], Field | ForeignObjectRel, tuple[Field, ...], list[str]]:
        """
        Walk the list of names and turns them into PathInfo tuples. A single
        name in 'names' can generate multiple PathInfos (m2m, for example).

        'names' is the path of names to travel, 'meta' is the Meta we
        start the name resolving from, 'allow_many' is as for setup_joins().
        If fail_on_missing is set to True, then a name that can't be resolved
        will generate a FieldError.

        Return a list of PathInfo tuples. In addition return the final field
        (the last used join field) and target (which is a field guaranteed to
        contain the same value as the final field). Finally, return those names
        that weren't found (which are likely transforms and the final lookup).
        """
        path, names_with_path = [], []
        for pos, name in enumerate(names):
            cur_names_with_path = (name, [])

            field = None
            filtered_relation = None
            try:
                if meta is None:
                    raise FieldDoesNotExist
                field = meta.get_field(name)
            except FieldDoesNotExist:
                # Not a model field: it may be an annotation, or (for the
                # first name only) a filtered relation alias.
                if name in self.annotation_select:
                    field = self.annotation_select[name].output_field
                elif name in self._filtered_relations and pos == 0:
                    filtered_relation = self._filtered_relations[name]
                    if LOOKUP_SEP in filtered_relation.relation_name:
                        # Multi-part relation_name: resolve all but the last
                        # hop here; the final hop's field is returned below.
                        parts = filtered_relation.relation_name.split(LOOKUP_SEP)
                        filtered_relation_path, field, _, _ = self.names_to_path(
                            parts,
                            meta,
                            allow_many,
                            fail_on_missing,
                        )
                        path.extend(filtered_relation_path[:-1])
                    else:
                        field = meta.get_field(filtered_relation.relation_name)
            if field is None:
                # We didn't find the current field, so move position back
                # one step.
                pos -= 1
                if pos == -1 or fail_on_missing:
                    available = sorted(
                        [
                            *get_field_names_from_opts(meta),
                            *self.annotation_select,
                            *self._filtered_relations,
                        ]
                    )
                    raise FieldError(
                        "Cannot resolve keyword '{}' into field. "
                        "Choices are: {}".format(name, ", ".join(available))
                    )
                # Remaining names are returned as the unresolved tail (likely
                # transforms and the final lookup).
                break

            # Lazy import to avoid circular dependency
            from plain.postgres.fields.related import ForeignKeyField as FK
            from plain.postgres.fields.related import ManyToManyField as M2M
            from plain.postgres.fields.reverse_related import ForeignObjectRel as FORel

            if isinstance(field, FK | M2M | FORel):
                # Relational field: it may expand into several PathInfos.
                pathinfos: list[PathInfo]
                if filtered_relation:
                    pathinfos = field.get_path_info(filtered_relation)
                else:
                    pathinfos = field.path_infos
                if not allow_many:
                    # Multi-valued hops are forbidden here; signal the caller
                    # with the path walked so far.
                    for inner_pos, p in enumerate(pathinfos):
                        if p.m2m:
                            cur_names_with_path[1].extend(pathinfos[0 : inner_pos + 1])
                            names_with_path.append(cur_names_with_path)
                            raise MultiJoin(pos + 1, names_with_path)
                last = pathinfos[-1]
                path.extend(pathinfos)
                final_field = last.join_field
                meta = last.to_meta
                targets = last.target_fields
                cur_names_with_path[1].extend(pathinfos)
                names_with_path.append(cur_names_with_path)
            else:
                # Local non-relational field.
                final_field = field
                targets = (field,)
                if fail_on_missing and pos + 1 != len(names):
                    raise FieldError(
                        f"Cannot resolve keyword {names[pos + 1]!r} into field. Join on '{name}'"
                        " not permitted."
                    )
                break
        return path, final_field, targets, names[pos + 1 :]
1770
    def setup_joins(
        self,
        names: list[str],
        meta: Meta,
        alias: str,
        can_reuse: set[str] | None = None,
        allow_many: bool = True,
        reuse_with_filtered_relation: bool = False,
    ) -> JoinInfo:
        """
        Compute the necessary table joins for the passage through the fields
        given in 'names'. 'meta' is the Meta for the current model
        (which gives the table we are starting from), 'alias' is the alias for
        the table to start the joining from.

        The 'can_reuse' defines the reverse foreign key joins we can reuse. It
        can be None in which case all joins are reusable or a set of aliases
        that can be reused. Note that non-reverse foreign keys are always
        reusable when using setup_joins().

        The 'reuse_with_filtered_relation' can be used to force 'can_reuse'
        parameter and force the relation on the given connections.

        If 'allow_many' is False, then any reverse foreign key seen will
        generate a MultiJoin exception.

        Return the final field involved in the joins, the target field (used
        for any 'where' constraint), the final 'opts' value, the joins, the
        field path traveled to generate the joins, and a transform function
        that takes a field and alias and is equivalent to `field.get_col(alias)`
        in the simple case but wraps field transforms if they were included in
        names.

        The target field is the field containing the concrete value. Final
        field can be something different, for example foreign key pointing to
        that value. Final field is needed for example in some value
        conversions (convert 'obj' in fk__id=obj to pk val using the foreign
        key field for example).
        """
        joins = [alias]
        # The transform can't be applied yet, as joins must be trimmed later.
        # To avoid making every caller of this method look up transforms
        # directly, compute transforms here and create a partial that converts
        # fields to the appropriate wrapped version.

        def _base_transformer(field: Field, alias: str | None) -> Col:
            # Identity transform: resolve the field to a column, dropping the
            # alias when this query does not qualify columns with aliases.
            if not self.alias_cols:
                alias = None
            return field.get_col(alias)

        final_transformer: TransformWrapper | Callable[[Field, str | None], Col] = (
            _base_transformer
        )

        # Try resolving all the names as fields first. If there's an error,
        # treat trailing names as lookups until a field can be resolved.
        last_field_exception = None
        # Shrink the candidate field path one trailing name at a time.
        for pivot in range(len(names), 0, -1):
            try:
                path, final_field, targets, rest = self.names_to_path(
                    names[:pivot],
                    meta,
                    allow_many,
                    fail_on_missing=True,
                )
            except FieldError as exc:
                if pivot == 1:
                    # The first item cannot be a lookup, so it's safe
                    # to raise the field error here.
                    raise
                else:
                    last_field_exception = exc
            else:
                # The transforms are the remaining items that couldn't be
                # resolved into fields.
                transforms = names[pivot:]
                break
        for name in transforms:
            # 'name' and 'previous' are bound as keyword-only defaults below,
            # so each closure captures the current loop values rather than the
            # final ones (the classic late-binding-closure pitfall).

            def transform(
                field: Field, alias: str | None, *, name: str, previous: Any
            ) -> BaseExpression:
                try:
                    wrapped = previous(field, alias)
                    return self.try_transform(wrapped, name)
                except FieldError:
                    # FieldError is raised if the transform doesn't exist.
                    if isinstance(final_field, Field) and last_field_exception:
                        raise last_field_exception
                    else:
                        raise

            final_transformer = TransformWrapper(
                transform, name=name, previous=final_transformer
            )
            final_transformer.has_transforms = True
        # Then, add the path to the query's joins. Note that we can't trim
        # joins at this stage - we will need the information about join type
        # of the trimmed joins.
        for join in path:
            if join.filtered_relation:
                filtered_relation = join.filtered_relation.clone()
                table_alias = filtered_relation.alias
            else:
                filtered_relation = None
                table_alias = None
            meta = join.to_meta
            if join.direct:
                nullable = self.is_nullable(join.join_field)
            else:
                # Reverse relations may have no matching rows, so they must
                # be treated as nullable.
                nullable = True
            connection = self.join_class(
                meta.model.model_options.db_table,
                alias,
                table_alias,  # type: ignore[invalid-argument-type]
                INNER,
                join.join_field,
                nullable,
                filtered_relation=filtered_relation,
            )
            reuse = can_reuse if join.m2m or reuse_with_filtered_relation else None
            alias = self.join(
                connection,
                reuse=reuse,
                reuse_with_filtered_relation=reuse_with_filtered_relation,
            )
            joins.append(alias)
            if filtered_relation:
                # Record the alias chain so the filtered relation can later be
                # resolved against the correct joins.
                filtered_relation.path = joins[:]
        return JoinInfo(final_field, targets, meta, joins, path, final_transformer)  # type: ignore[invalid-argument-type]
1901
1902 def trim_joins(
1903 self, targets: tuple[Field, ...], joins: list[str], path: list[Any]
1904 ) -> tuple[tuple[Field, ...], str, list[str]]:
1905 """
1906 The 'target' parameter is the final field being joined to, 'joins'
1907 is the full list of join aliases. The 'path' contain the PathInfos
1908 used to create the joins.
1909
1910 Return the final target field and table alias and the new active
1911 joins.
1912
1913 Always trim any direct join if the target column is already in the
1914 previous table. Can't trim reverse joins as it's unknown if there's
1915 anything on the other side of the join.
1916 """
1917 joins = joins[:]
1918 for pos, info in enumerate(reversed(path)):
1919 if len(joins) == 1 or not info.direct:
1920 break
1921 if info.filtered_relation:
1922 break
1923 join_targets = {t.column for t in info.join_field.foreign_related_fields}
1924 cur_targets = {t.column for t in targets}
1925 if not cur_targets.issubset(join_targets):
1926 break
1927 targets_dict = {
1928 r[1].column: r[0]
1929 for r in info.join_field.related_fields
1930 if r[1].column in cur_targets
1931 }
1932 targets = tuple(targets_dict[t.column] for t in targets)
1933 self.unref_alias(joins.pop())
1934 return targets, joins[-1], joins
1935
1936 @classmethod
1937 def _gen_cols(
1938 cls,
1939 exprs: Iterable[Any],
1940 include_external: bool = False,
1941 resolve_refs: bool = True,
1942 ) -> TypingIterator[Col]:
1943 for expr in exprs:
1944 if isinstance(expr, Col):
1945 yield expr
1946 elif include_external and callable(
1947 getattr(expr, "get_external_cols", None)
1948 ):
1949 yield from expr.get_external_cols()
1950 elif hasattr(expr, "get_source_expressions"):
1951 if not resolve_refs and isinstance(expr, Ref):
1952 continue
1953 yield from cls._gen_cols(
1954 expr.get_source_expressions(),
1955 include_external=include_external,
1956 resolve_refs=resolve_refs,
1957 )
1958
1959 @classmethod
1960 def _gen_col_aliases(cls, exprs: Iterable[Any]) -> TypingIterator[str]:
1961 yield from (expr.alias for expr in cls._gen_cols(exprs))
1962
    def resolve_ref(
        self,
        name: str,
        allow_joins: bool = True,
        reuse: set[str] | None = None,
        summarize: bool = False,
    ) -> BaseExpression:
        """
        Resolve a lookup-separated 'name' (as used by F() references) into an
        expression: an existing annotation, a transform chain on top of an
        annotation, or a column reached by joining through related fields.

        'allow_joins' forbids references that require joined tables when
        False. 'reuse' is the set of reusable join aliases; it is updated in
        place with the joins used. 'summarize' makes annotation references
        resolve to Ref() objects for use inside a wrapping aggregate query.

        Raise FieldError for disallowed joins, unknown names, or multicolumn
        references.
        """
        annotation = self.annotations.get(name)
        if annotation is not None:
            if not allow_joins:
                # Reject annotations that are built on joined columns.
                for alias in self._gen_col_aliases([annotation]):
                    if isinstance(self.alias_map[alias], Join):
                        raise FieldError(
                            "Joined field references are not permitted in this query"
                        )
            if summarize:
                # Summarize currently means we are doing an aggregate() query
                # which is executed as a wrapped subquery if any of the
                # aggregate() elements reference an existing annotation. In
                # that case we need to return a Ref to the subquery's annotation.
                if name not in self.annotation_select:
                    raise FieldError(
                        f"Cannot aggregate over the '{name}' alias. Use annotate() "
                        "to promote it."
                    )
                return Ref(name, self.annotation_select[name])
            else:
                return annotation
        else:
            field_list = name.split(LOOKUP_SEP)
            annotation = self.annotations.get(field_list[0])
            if annotation is not None:
                # The first segment names an annotation: treat the remaining
                # segments as transforms applied on top of it.
                for transform in field_list[1:]:
                    annotation = self.try_transform(annotation, transform)
                return annotation
            initial_alias = self.get_initial_alias()
            assert initial_alias is not None
            assert self.model is not None, "Resolving field references requires a model"
            meta = self.model._model_meta
            join_info = self.setup_joins(
                field_list,
                meta,
                initial_alias,
                can_reuse=reuse,
            )
            targets, final_alias, join_list = self.trim_joins(
                join_info.targets, join_info.joins, join_info.path
            )
            if not allow_joins and len(join_list) > 1:
                raise FieldError(
                    "Joined field references are not permitted in this query"
                )
            if len(targets) > 1:
                raise FieldError(
                    "Referencing multicolumn fields with F() objects isn't supported"
                )
            # Verify that the last lookup in name is a field or a transform:
            # transform_function() raises FieldError if not.
            transform = join_info.transform_function(targets[0], final_alias)
            if reuse is not None:
                # Allow the caller to reuse the joins this reference created.
                reuse.update(join_list)
            return transform
2025
    def split_exclude(
        self,
        filter_expr: tuple[str, Any],
        can_reuse: set[str],
        names_with_path: list[tuple[str, list[Any]]],
    ) -> tuple[WhereNode, set[str] | tuple[()]]:
        """
        When doing an exclude against any kind of N-to-many relation, we need
        to use a subquery. This method constructs the nested query, given the
        original exclude filter (filter_expr) and the portion up to the first
        N-to-many relation field.

        For example, if the origin filter is ~Q(child__name='foo'), filter_expr
        is ('child__name', 'foo') and can_reuse is a set of joins usable for
        filters in the original query.

        We will turn this into equivalent of:
            WHERE NOT EXISTS(
                SELECT 1
                FROM child
                WHERE name = 'foo' AND child.parent_id = parent.id
                LIMIT 1
            )
        """
        # Generate the inner query.
        query = self.__class__(self.model)
        query._filtered_relations = self._filtered_relations
        filter_lhs, filter_rhs = filter_expr
        if isinstance(filter_rhs, OuterRef):
            # Re-wrap: the filter now lives one subquery level deeper, so the
            # outer reference must be pushed out one extra level.
            filter_rhs = OuterRef(filter_rhs)
        elif isinstance(filter_rhs, F):
            # An F() in the outer query becomes an outer reference here.
            filter_rhs = OuterRef(filter_rhs.name)
        query.add_filter(filter_lhs, filter_rhs)
        query.clear_ordering(force=True)
        # Try to have as simple as possible subquery -> trim leading joins from
        # the subquery.
        trimmed_prefix, contains_louter = query.trim_start(names_with_path)

        col = query.select[0]
        select_field = col.target
        alias = col.alias
        if alias in can_reuse:
            id_field = select_field.model._model_meta.get_forward_field("id")
            # Need to add a restriction so that outer query's filters are in effect for
            # the subquery, too.
            query.bump_prefix(self)
            lookup_class = select_field.get_lookup("exact")
            # Note that the query.select[0].alias is different from alias
            # due to bump_prefix above.
            lookup = lookup_class(
                id_field.get_col(query.select[0].alias), id_field.get_col(alias)
            )
            query.where.add(lookup, AND)
            query.external_aliases[alias] = True

        # Correlate the subquery's selected column with the trimmed prefix of
        # the outer query.
        lookup_class = select_field.get_lookup("exact")
        lookup = lookup_class(col, ResolvedOuterRef(trimmed_prefix))
        query.where.add(lookup, AND)
        condition, needed_inner = self.build_filter(Exists(query))

        if contains_louter:
            or_null_condition, _ = self.build_filter(
                (f"{trimmed_prefix}__isnull", True),
                current_negated=True,
                branch_negated=True,
                can_reuse=can_reuse,
            )
            condition.add(or_null_condition, OR)
            # Note that the end result will be:
            # (outercol NOT IN innerq AND outercol IS NOT NULL) OR outercol IS NULL.
            # This might look crazy but due to how IN works, this seems to be
            # correct. If the IS NOT NULL check is removed then outercol NOT
            # IN will return UNKNOWN. If the IS NULL check is removed, then if
            # outercol IS NULL we will not match the row.
        return condition, needed_inner
2101
    def set_empty(self) -> None:
        """Force the query to match no rows by adding a NothingNode to WHERE."""
        self.where.add(NothingNode(), AND)
2104
2105 def is_empty(self) -> bool:
2106 return any(isinstance(c, NothingNode) for c in self.where.children)
2107
2108 def set_limits(self, low: int | None = None, high: int | None = None) -> None:
2109 """
2110 Adjust the limits on the rows retrieved. Use low/high to set these,
2111 as it makes it more Pythonic to read and write. When the SQL query is
2112 created, convert them to the appropriate offset and limit values.
2113
2114 Apply any limits passed in here to the existing constraints. Add low
2115 to the current low value and clamp both to any existing high value.
2116 """
2117 if high is not None:
2118 if self.high_mark is not None:
2119 self.high_mark = min(self.high_mark, self.low_mark + high)
2120 else:
2121 self.high_mark = self.low_mark + high
2122 if low is not None:
2123 if self.high_mark is not None:
2124 self.low_mark = min(self.high_mark, self.low_mark + low)
2125 else:
2126 self.low_mark = self.low_mark + low
2127
2128 if self.low_mark == self.high_mark:
2129 self.set_empty()
2130
2131 def clear_limits(self) -> None:
2132 """Clear any existing limits."""
2133 self.low_mark, self.high_mark = 0, None
2134
2135 @property
2136 def is_sliced(self) -> bool:
2137 return self.low_mark != 0 or self.high_mark is not None
2138
2139 def has_limit_one(self) -> bool:
2140 return self.high_mark is not None and (self.high_mark - self.low_mark) == 1
2141
2142 def can_filter(self) -> bool:
2143 """
2144 Return True if adding filters to this instance is still possible.
2145
2146 Typically, this means no limits or offsets have been put on the results.
2147 """
2148 return not self.is_sliced
2149
2150 def clear_select_clause(self) -> None:
2151 """Remove all fields from SELECT clause."""
2152 self.select = ()
2153 self.default_cols = False
2154 self.select_related = False
2155 self.set_extra_mask(())
2156 self.set_annotation_mask(())
2157
2158 def clear_select_fields(self) -> None:
2159 """
2160 Clear the list of fields to select (but not extra_select columns).
2161 Some queryset types completely replace any existing list of select
2162 columns.
2163 """
2164 self.select = ()
2165 self.values_select = ()
2166
2167 def add_select_col(self, col: BaseExpression, name: str) -> None:
2168 self.select += (col,)
2169 self.values_select += (name,)
2170
2171 def set_select(self, cols: list[Col] | tuple[Col, ...]) -> None:
2172 self.default_cols = False
2173 self.select = tuple(cols)
2174
2175 def add_distinct_fields(self, *field_names: str) -> None:
2176 """
2177 Add and resolve the given fields to the query's "distinct on" clause.
2178 """
2179 self.distinct_fields = field_names
2180 self.distinct = True
2181
    def add_fields(
        self, field_names: list[str] | TypingIterator[str], allow_m2m: bool = True
    ) -> None:
        """
        Add the given (model) fields to the select set. Add the field names in
        the order specified.

        Raise FieldError for names that cannot be resolved, listing the valid
        choices for simple (non-relational) names.
        """
        alias = self.get_initial_alias()
        assert alias is not None
        assert self.model is not None, "add_fields() requires a model"
        meta = self.model._model_meta

        try:
            cols = []
            for name in field_names:
                # Join promotion note - we must not remove any rows here, so
                # if there is no existing joins, use outer join.
                join_info = self.setup_joins(
                    name.split(LOOKUP_SEP), meta, alias, allow_many=allow_m2m
                )
                targets, final_alias, joins = self.trim_joins(
                    join_info.targets,
                    join_info.joins,
                    join_info.path,
                )
                for target in targets:
                    cols.append(join_info.transform_function(target, final_alias))
            if cols:
                self.set_select(cols)
        except MultiJoin:
            # 'name' is the loop variable that triggered the failure; Python
            # keeps it bound after the loop body raises.
            raise FieldError(f"Invalid field name: '{name}'")
        except FieldError:
            if LOOKUP_SEP in name:
                # For lookups spanning over relationships, show the error
                # from the model on which the lookup failed.
                raise
            elif name in self.annotations:
                raise FieldError(
                    f"Cannot select the '{name}' alias. Use annotate() to promote it."
                )
            else:
                # Plain name that is no field, extra, annotation, or filtered
                # relation: report every valid choice.
                names = sorted(
                    [
                        *get_field_names_from_opts(meta),
                        *self.extra,
                        *self.annotation_select,
                        *self._filtered_relations,
                    ]
                )
                raise FieldError(
                    "Cannot resolve keyword {!r} into field. Choices are: {}".format(
                        name, ", ".join(names)
                    )
                )
2236
    def add_ordering(self, *ordering: str | BaseExpression) -> None:
        """
        Add items from the 'ordering' sequence to the query's "order by"
        clause. These items are either field names (not column names) --
        possibly with a direction prefix ('-' or '?') -- or OrderBy
        expressions.

        If 'ordering' is empty, clear all ordering from the query.

        Raise FieldError for names that resolve to nothing, non-expression
        items, or bare aggregates not promoted via annotate().
        """
        errors = []
        for item in ordering:
            if isinstance(item, str):
                if item == "?":
                    # Random ordering needs no validation.
                    continue
                item = item.removeprefix("-")
                if item in self.annotations:
                    continue
                if self.extra and item in self.extra:
                    continue
                # names_to_path() validates the lookup. A descriptive
                # FieldError will be raise if it's not.
                assert self.model is not None, "ORDER BY field names require a model"
                self.names_to_path(item.split(LOOKUP_SEP), self.model._model_meta)
            elif not isinstance(item, ResolvableExpression):
                errors.append(item)
            if getattr(item, "contains_aggregate", False):
                raise FieldError(
                    "Using an aggregate in order_by() without also including "
                    f"it in annotate() is not allowed: {item}"
                )
        if errors:
            raise FieldError(f"Invalid order_by arguments: {errors}")
        if ordering:
            self.order_by += ordering
        else:
            # Called with no arguments: suppress the model's default ordering.
            self.default_ordering = False
2273
2274 def clear_ordering(self, force: bool = False, clear_default: bool = True) -> None:
2275 """
2276 Remove any ordering settings if the current query allows it without
2277 side effects, set 'force' to True to clear the ordering regardless.
2278 If 'clear_default' is True, there will be no ordering in the resulting
2279 query (not even the model's default).
2280 """
2281 if not force and (
2282 self.is_sliced or self.distinct_fields or self.select_for_update
2283 ):
2284 return
2285 self.order_by = ()
2286 self.extra_order_by = ()
2287 if clear_default:
2288 self.default_ordering = False
2289
    def set_group_by(self, allow_aliases: bool = True) -> None:
        """
        Expand the GROUP BY clause required by the query.

        This will usually be the set of all non-aggregate fields in the
        return data. If the database backend supports grouping by the
        primary key, and the query would be equivalent, the optimization
        will be made automatically.
        """
        if allow_aliases and self.values_select:
            # If grouping by aliases is allowed assign selected value aliases
            # by moving them to annotations.
            group_by_annotations = {}
            values_select = {}
            for alias, expr in zip(self.values_select, self.select):
                if isinstance(expr, Col):
                    # Plain columns stay in the select list.
                    values_select[alias] = expr
                else:
                    # Computed expressions are promoted to annotations so they
                    # can be referenced by alias in GROUP BY.
                    group_by_annotations[alias] = expr
            self.annotations = {**group_by_annotations, **self.annotations}
            self.append_annotation_mask(group_by_annotations)
            self.select = tuple(values_select.values())
            self.values_select = tuple(values_select)
        group_by = list(self.select)
        for alias, annotation in self.annotation_select.items():
            if not (group_by_cols := annotation.get_group_by_cols()):
                # Annotation contributes no columns to GROUP BY.
                continue
            if allow_aliases and not annotation.contains_aggregate:
                # Group by the annotation's alias rather than its columns.
                group_by.append(Ref(alias, annotation))
            else:
                group_by.extend(group_by_cols)
        self.group_by = tuple(group_by)
2322
2323 def add_select_related(self, fields: list[str]) -> None:
2324 """
2325 Set up the select_related data structure so that we only select
2326 certain related models (as opposed to all models, when
2327 self.select_related=True).
2328 """
2329 if isinstance(self.select_related, bool):
2330 field_dict: dict[str, Any] = {}
2331 else:
2332 field_dict = self.select_related
2333 for field in fields:
2334 d = field_dict
2335 for part in field.split(LOOKUP_SEP):
2336 d = d.setdefault(part, {})
2337 self.select_related = field_dict
2338
    def add_extra(
        self,
        select: dict[str, str],
        select_params: list[Any] | None,
        where: list[str],
        params: list[Any],
        tables: list[str],
        order_by: tuple[str, ...],
    ) -> None:
        """
        Add data to the various extra_* attributes for user-created additions
        to the query.

        'select' maps output names to raw SQL snippets whose '%s'
        placeholders consume values from 'select_params'; 'where'/'params'
        add raw WHERE conditions; 'tables' extends the FROM clause;
        'order_by' replaces the extra ordering.
        """
        if select:
            # We need to pair any placeholder markers in the 'select'
            # dictionary with their parameters in 'select_params' so that
            # subsequent updates to the select dictionary also adjust the
            # parameters appropriately.
            select_pairs = {}
            if select_params:
                param_iter = iter(select_params)
            else:
                param_iter = iter([])
            for name, entry in select.items():
                self.check_alias(name)
                entry = str(entry)
                entry_params = []
                pos = entry.find("%s")
                while pos != -1:
                    # '%%s' is an escaped literal percent-s, not a placeholder.
                    if pos == 0 or entry[pos - 1] != "%":
                        entry_params.append(next(param_iter))
                    pos = entry.find("%s", pos + 2)
                select_pairs[name] = (entry, entry_params)
            self.extra.update(select_pairs)
        if where or params:
            self.where.add(ExtraWhere(where, params), AND)
        if tables:
            self.extra_tables += tuple(tables)
        if order_by:
            self.extra_order_by = order_by
2379
2380 def clear_deferred_loading(self) -> None:
2381 """Remove any fields from the deferred loading set."""
2382 self.deferred_loading = (frozenset(), True)
2383
2384 def add_deferred_loading(self, field_names: frozenset[str]) -> None:
2385 """
2386 Add the given list of model field names to the set of fields to
2387 exclude from loading from the database when automatic column selection
2388 is done. Add the new field names to any existing field names that
2389 are deferred (or removed from any existing field names that are marked
2390 as the only ones for immediate loading).
2391 """
2392 # Fields on related models are stored in the literal double-underscore
2393 # format, so that we can use a set datastructure. We do the foo__bar
2394 # splitting and handling when computing the SQL column names (as part of
2395 # get_columns()).
2396 existing, defer = self.deferred_loading
2397 existing_set = set(existing)
2398 if defer:
2399 # Add to existing deferred names.
2400 self.deferred_loading = frozenset(existing_set.union(field_names)), True
2401 else:
2402 # Remove names from the set of any existing "immediate load" names.
2403 if new_existing := existing_set.difference(field_names):
2404 self.deferred_loading = frozenset(new_existing), False
2405 else:
2406 self.clear_deferred_loading()
2407 if new_only := set(field_names).difference(existing_set):
2408 self.deferred_loading = frozenset(new_only), True
2409
2410 def add_immediate_loading(self, field_names: list[str] | set[str]) -> None:
2411 """
2412 Add the given list of model field names to the set of fields to
2413 retrieve when the SQL is executed ("immediate loading" fields). The
2414 field names replace any existing immediate loading field names. If
2415 there are field names already specified for deferred loading, remove
2416 those names from the new field_names before storing the new names
2417 for immediate loading. (That is, immediate loading overrides any
2418 existing immediate values, but respects existing deferrals.)
2419 """
2420 existing, defer = self.deferred_loading
2421 field_names_set = set(field_names)
2422
2423 if defer:
2424 # Remove any existing deferred names from the current set before
2425 # setting the new names.
2426 self.deferred_loading = (
2427 frozenset(field_names_set.difference(existing)),
2428 False,
2429 )
2430 else:
2431 # Replace any existing "immediate load" field names.
2432 self.deferred_loading = frozenset(field_names_set), False
2433
2434 def set_annotation_mask(
2435 self,
2436 names: set[str]
2437 | frozenset[str]
2438 | list[str]
2439 | tuple[str, ...]
2440 | dict[str, Any]
2441 | None,
2442 ) -> None:
2443 """Set the mask of annotations that will be returned by the SELECT."""
2444 if names is None:
2445 self.annotation_select_mask = None
2446 else:
2447 self.annotation_select_mask = set(names)
2448 self._annotation_select_cache = None
2449
2450 def append_annotation_mask(self, names: list[str] | dict[str, Any]) -> None:
2451 if self.annotation_select_mask is not None:
2452 self.set_annotation_mask(self.annotation_select_mask.union(names))
2453
2454 def set_extra_mask(
2455 self, names: set[str] | list[str] | tuple[str, ...] | None
2456 ) -> None:
2457 """
2458 Set the mask of extra select items that will be returned by SELECT.
2459 Don't remove them from the Query since they might be used later.
2460 """
2461 if names is None:
2462 self.extra_select_mask = None
2463 else:
2464 self.extra_select_mask = set(names)
2465 self._extra_select_cache = None
2466
    def set_values(self, fields: list[str]) -> None:
        """
        Configure the query for a values()/values_list() style selection of
        the given field, extra, and annotation names (or all concrete model
        fields when 'fields' is empty), resetting prior select state and
        reconciling the GROUP BY clause with the new selection.
        """
        self.select_related = False
        self.clear_deferred_loading()
        self.clear_select_fields()
        self.has_select_fields = True

        if fields:
            field_names = []
            extra_names = []
            annotation_names = []
            if not self.extra and not self.annotations:
                # Shortcut - if there are no extra or annotations, then
                # the values() clause must be just field names.
                field_names = list(fields)
            else:
                self.default_cols = False
                # Partition the requested names into extras, annotations, and
                # plain model field names.
                for f in fields:
                    if f in self.extra_select:
                        extra_names.append(f)
                    elif f in self.annotation_select:
                        annotation_names.append(f)
                    else:
                        field_names.append(f)
            self.set_extra_mask(extra_names)
            self.set_annotation_mask(annotation_names)
            selected = frozenset(field_names + extra_names + annotation_names)
        else:
            assert self.model is not None, "Default values query requires a model"
            field_names = [f.attname for f in self.model._model_meta.concrete_fields]
            selected = frozenset(field_names)
        # Selected annotations must be known before setting the GROUP BY
        # clause.
        if self.group_by is True:
            assert self.model is not None, "GROUP BY True requires a model"
            self.add_fields(
                (f.attname for f in self.model._model_meta.concrete_fields), False
            )
            # Disable GROUP BY aliases to avoid orphaning references to the
            # SELECT clause which is about to be cleared.
            self.set_group_by(allow_aliases=False)
            self.clear_select_fields()
        elif self.group_by:
            # Resolve GROUP BY annotation references if they are not part of
            # the selected fields anymore.
            group_by = []
            for expr in self.group_by:
                if isinstance(expr, Ref) and expr.refs not in selected:
                    expr = self.annotations[expr.refs]
                group_by.append(expr)
            self.group_by = tuple(group_by)

        self.values_select = tuple(field_names)
        self.add_fields(field_names, True)
2520
2521 @property
2522 def annotation_select(self) -> dict[str, BaseExpression]:
2523 """
2524 Return the dictionary of aggregate columns that are not masked and
2525 should be used in the SELECT clause. Cache this result for performance.
2526 """
2527 if self._annotation_select_cache is not None:
2528 return self._annotation_select_cache
2529 elif not self.annotations:
2530 return {}
2531 elif self.annotation_select_mask is not None:
2532 self._annotation_select_cache = {
2533 k: v
2534 for k, v in self.annotations.items()
2535 if k in self.annotation_select_mask
2536 }
2537 return self._annotation_select_cache
2538 else:
2539 return self.annotations
2540
2541 @property
2542 def extra_select(self) -> dict[str, tuple[str, list[Any]]]:
2543 if self._extra_select_cache is not None:
2544 return self._extra_select_cache
2545 if not self.extra:
2546 return {}
2547 elif self.extra_select_mask is not None:
2548 self._extra_select_cache = {
2549 k: v for k, v in self.extra.items() if k in self.extra_select_mask
2550 }
2551 return self._extra_select_cache
2552 else:
2553 return self.extra
2554
    def trim_start(
        self, names_with_path: list[tuple[str, list[Any]]]
    ) -> tuple[str, bool]:
        """
        Trim joins from the start of the join path. The candidates for trim
        are the PathInfos in names_with_path structure that are m2m joins.

        Also set the select column so the start matches the join.

        This method is meant to be used for generating the subquery joins &
        cols in split_exclude().

        Return a lookup usable for doing outerq.filter(lookup=self) and a
        boolean indicating if the joins in the prefix contain a LEFT OUTER join.
        """
        all_paths = []
        for _, paths in names_with_path:
            all_paths.extend(paths)
        contains_louter = False
        # Trim and operate only on tables that were generated for
        # the lookup part of the query. That is, avoid trimming
        # joins generated for F() expressions.
        lookup_tables = [
            t for t in self.alias_map if t in self._lookup_joins or t == self.base_table
        ]
        # Walk forward until the first m2m join; everything before it can be
        # dropped from the subquery. Note that 'trimmed_paths' and 'path'
        # deliberately leak out of the loop and are used below.
        for trimmed_paths, path in enumerate(all_paths):
            if path.m2m:
                break
            if self.alias_map[lookup_tables[trimmed_paths + 1]].join_type == LOUTER:
                contains_louter = True
            alias = lookup_tables[trimmed_paths]
            self.unref_alias(alias)
        # The path.join_field is a Rel, lets get the other side's field
        join_field = path.join_field.field
        # Build the filter prefix.
        paths_in_prefix = trimmed_paths
        trimmed_prefix = []
        for name, path in names_with_path:
            if paths_in_prefix - len(path) < 0:
                break
            trimmed_prefix.append(name)
            paths_in_prefix -= len(path)
        trimmed_prefix.append(join_field.foreign_related_fields[0].name)
        trimmed_prefix = LOOKUP_SEP.join(trimmed_prefix)
        # Lets still see if we can trim the first join from the inner query
        # (that is, self). We can't do this for:
        # - LEFT JOINs because we would miss those rows that have nothing on
        #   the outer side,
        # - INNER JOINs from filtered relations because we would miss their
        #   filters.
        first_join = self.alias_map[lookup_tables[trimmed_paths + 1]]
        if first_join.join_type != LOUTER and not first_join.filtered_relation:
            select_fields = [r[0] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths + 1]
            self.unref_alias(lookup_tables[trimmed_paths])
        else:
            # TODO: It might be possible to trim more joins from the start of the
            # inner query if it happens to have a longer join chain containing the
            # values in select_fields. Lets punt this one for now.
            select_fields = [r[1] for r in join_field.related_fields]
            select_alias = lookup_tables[trimmed_paths]
        # The found starting point is likely a join_class instead of a
        # base_table_class reference. But the first entry in the query's FROM
        # clause must not be a JOIN.
        for table in self.alias_map:
            if self.alias_refcount[table] > 0:
                self.alias_map[table] = self.base_table_class(
                    self.alias_map[table].table_name,
                    table,
                )
                break
        self.set_select([f.get_col(select_alias) for f in select_fields])
        return trimmed_prefix, contains_louter
2628
2629 def is_nullable(self, field: Field) -> bool:
2630 """Check if the given field should be treated as nullable."""
2631 # QuerySet does not have knowledge of which connection is going to be
2632 # used. For the single-database setup we always reference the default
2633 # connection here.
2634 return field.allow_null
2635
2636
def get_order_dir(field: str, default: str = "ASC") -> tuple[str, str]:
    """
    Return the field name and direction for an order specification. For
    example, '-foo' is returned as ('foo', 'DESC').

    The 'default' param is used to indicate which way no prefix (or a '+'
    prefix) should sort. The '-' prefix always sorts the opposite way.
    """
    dirn = ORDER_DIR[default]
    # startswith() avoids an IndexError on an empty ordering string (the
    # original `field[0] == "-"` crashed on "") and is the idiomatic prefix
    # test. An empty name falls through to the default direction.
    if field.startswith("-"):
        return field[1:], dirn[1]
    return field, dirn[0]
2649
2650
class JoinPromoter:
    """
    A class to abstract away join promotion problems for complex filter
    conditions.
    """

    def __init__(self, connector: str, num_children: int, negated: bool):
        self.connector = connector
        self.negated = negated
        # Under negation the connector effectively flips (De Morgan):
        # NOT (a AND b) behaves like (NOT a) OR (NOT b) for join promotion.
        if negated:
            self.effective_connector = OR if connector == AND else AND
        else:
            self.effective_connector = self.connector
        self.num_children = num_children
        # Per table alias: how many child conditions require the join to
        # produce a row (i.e. voted for an INNER join).
        self.votes = Counter()

    def __repr__(self) -> str:
        cls_name = self.__class__.__qualname__
        return (
            f"{cls_name}(connector={self.connector!r}, "
            f"num_children={self.num_children!r}, negated={self.negated!r})"
        )

    def add_votes(self, votes: Any) -> None:
        """Register one vote per item in ``votes`` (any iterable of aliases)."""
        self.votes.update(votes)

    def update_join_types(self, query: Query) -> set[str]:
        """
        Change join types so that the generated query is as efficient as
        possible, but still correct. Demote as many joins as possible to
        INNER, but never make an OUTER join INNER when that could remove
        results from the query.
        """
        promote: set[str] = set()
        demote: set[str] = set()
        # effective_connector already accounts for negation, so a negated
        # AND is handled exactly like an OR here.
        for alias, inner_votes in self.votes.items():
            # OR case: if some child did NOT vote for this join, an INNER
            # join could drop valid rows. E.g. for rel_a__col=1 | rel_b__col=2,
            # a row with no rel_a match but a rel_b match with col=2 must
            # survive, so any existing INNER must become LOUTER (it may be
            # demoted again later on).
            if self.effective_connector == OR and inner_votes < self.num_children:
                promote.add(alias)
            # AND case: one vote suffices — if the join yields NULL, the
            # voting condition can't match (NULL=anything is false), so
            # INNER is safe. OR case: when *every* child voted inner
            # (e.g. rel_a__col__icontains=Alex | rel_a__col__icontains=Russell),
            # a missing rel_a row fails the whole condition, so INNER is
            # safe there as well.
            if self.effective_connector == AND or (
                self.effective_connector == OR and inner_votes == self.num_children
            ):
                demote.add(alias)
            # Mixed case, e.g. (rel_a__col=1 | rel_b__col=2) & rel_a__col__gte=0:
            # the OR clause first promotes rel_a and rel_b to LOUTER, then the
            # AND-ed __gte filter votes rel_a back to INNER. That demotion is
            # safe: without a rel_a row the __gte clause is false and the whole
            # conjunction fails anyway. The same holds with the clauses
            # swapped, or with the __gte replaced by an OR over rel_a only.
        query.promote_joins(promote)
        query.demote_joins(demote)
        return demote
2741
2742
2743# ##### Query subclasses (merged from subqueries.py) #####
2744
2745
class DeleteQuery(Query):
    """A DELETE SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLDeleteCompiler:
        """Return the compiler that renders this query as SQL."""
        # Imported lazily to avoid a circular dependency on the compiler module.
        from plain.postgres.sql.compiler import SQLDeleteCompiler

        return SQLDeleteCompiler(self, get_connection(), elide_empty)

    def do_query(self, table: str, where: Any) -> int:
        """Execute a DELETE on ``table`` with ``where``; return rows deleted."""
        from plain.postgres.sql.constants import CURSOR

        # Restrict the query to the single target table before compiling.
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        with cursor:
            return cursor.rowcount

    def delete_batch(self, id_list: list[Any]) -> int:
        """
        Set up and execute delete queries for all the objects in id_list.

        More than one physical query may be executed if there are a
        lot of values in id_list.
        """
        from plain.postgres.sql.constants import GET_ITERATOR_CHUNK_SIZE

        assert self.model is not None, "DELETE requires a model"
        id_field = self.model._model_meta.get_forward_field("id")
        table_name = self.model.model_options.db_table
        deleted_total = 0  # running count of rows removed
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = id_list[start : start + GET_ITERATOR_CHUNK_SIZE]
            # Each chunk gets a fresh WHERE clause.
            self.clear_where()
            self.add_filter(f"{id_field.attname}__in", chunk)
            deleted_total += self.do_query(table_name, self.where)
        return deleted_total
2787
2788
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLUpdateCompiler:
        """Return the compiler that renders this query as SQL."""
        # Imported lazily to avoid a circular dependency on the compiler module.
        from plain.postgres.sql.compiler import SQLUpdateCompiler

        return SQLUpdateCompiler(self, get_connection(), elide_empty)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self) -> None:
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples used for the SET clause.
        self.values: list[tuple[Any, Any, Any]] = []
        # Optional mapping of ancestor model -> ids restricting its update.
        self.related_ids: dict[Any, list[Any]] | None = None
        # Pending per-ancestor updates, coalesced by model.
        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}

    def clone(self) -> UpdateQuery:
        """Return a copy that owns its own related_updates mapping."""
        cloned = super().clone()
        cloned.related_updates = self.related_updates.copy()
        return cloned

    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
        """Apply ``values`` to the rows in ``id_list``, chunking large lists."""
        from plain.postgres.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS

        self.add_update_values(values)
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter("id__in", id_list[start : start + GET_ITERATOR_CHUNK_SIZE])
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values: dict[str, Any]) -> None:
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.
        """
        # Imported lazily to avoid a circular dependency on related fields.
        from plain.postgres.fields.related import ManyToManyField

        assert self.model is not None, "UPDATE requires model metadata"
        meta = self.model._model_meta
        local_values = []
        for field_name, value in values.items():
            field = meta.get_field(field_name)
            direct = (
                not (field.auto_created and not field.concrete) or not field.concrete
            )
            model = field.model
            if not direct or isinstance(field, ManyToManyField):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            if model is not meta.model:
                # The field lives on an ancestor model; queue it separately.
                self.add_related_update(model, field, value)
                continue
            local_values.append((field, model, value))
        return self.add_update_fields(local_values)

    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, value in values_seq:
            if isinstance(value, ResolvableExpression):
                # Resolve expressions here so that annotations are no longer needed
                value = value.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, value))

    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
        """
        Add (name, value) to an update query for an ancestor model.

        Updates are coalesced so that only one update query per ancestor is run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self) -> list[UpdateQuery]:
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.
        """
        if not self.related_updates:
            return []
        queries = []
        for model, values in self.related_updates.items():
            related_query = UpdateQuery(model)
            related_query.values = values
            if self.related_ids is not None:
                related_query.add_filter("id__in", self.related_ids[model])
            queries.append(related_query)
        return queries
2891
2892
class InsertQuery(Query):
    """An INSERT SQL query."""

    def get_compiler(self, *, elide_empty: bool = True) -> SQLInsertCompiler:
        """Return the compiler that renders this query as SQL."""
        # Imported lazily to avoid a circular dependency on the compiler module.
        from plain.postgres.sql.compiler import SQLInsertCompiler

        return SQLInsertCompiler(self, get_connection(), elide_empty)

    def __str__(self) -> str:
        raise NotImplementedError(
            "InsertQuery does not support __str__(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def sql_with_params(self) -> Any:
        raise NotImplementedError(
            "InsertQuery does not support sql_with_params(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def __init__(
        self,
        *args: Any,
        on_conflict: OnConflict | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            on_conflict: Conflict-resolution strategy (see OnConflict), or
                None for a plain INSERT.
            update_fields: Fields to set when conflict handling updates an
                existing row.
            unique_fields: Fields whose uniqueness triggers conflict handling.
        """
        super().__init__(*args, **kwargs)
        self.fields: list[Field] = []
        self.objs: list[Any] = []
        # Fix: default `raw` here so the attribute always exists. Previously
        # it was created only by insert_values(), so reading it before that
        # call raised AttributeError.
        self.raw: bool = False
        self.on_conflict = on_conflict
        self.update_fields: list[Field] = update_fields or []
        self.unique_fields: list[Field] = unique_fields or []

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """Record the fields and objects to insert.

        ``raw`` is stored for the compiler; presumably it means the values
        bypass field preparation — TODO confirm against SQLInsertCompiler.
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
2932
2933
2934class AggregateQuery(Query):
2935 """
2936 Take another query as a parameter to the FROM clause and only select the
2937 elements in the provided list.
2938 """
2939
2940 def get_compiler(self, *, elide_empty: bool = True) -> SQLAggregateCompiler:
2941 from plain.postgres.sql.compiler import SQLAggregateCompiler
2942
2943 return SQLAggregateCompiler(self, get_connection(), elide_empty)
2944
    def __init__(self, model: Any, inner_query: Any) -> None:
        # Store the wrapped query (used as the FROM clause per the class
        # docstring) before running the base Query initialization.
        self.inner_query = inner_query
        super().__init__(model)