1from __future__ import annotations
2
3import collections
4import json
5import re
6from collections.abc import Generator, Iterable, Sequence
7from functools import cached_property, partial
8from itertools import chain
9from typing import TYPE_CHECKING, Any, Protocol, cast
10
11from plain.models.constants import LOOKUP_SEP
12from plain.models.exceptions import EmptyResultSet, FieldError, FullResultSet
13from plain.models.expressions import (
14 F,
15 OrderBy,
16 RawSQL,
17 Ref,
18 ResolvableExpression,
19 Value,
20)
21from plain.models.fields.related import RelatedField
22from plain.models.functions import Cast, Random
23from plain.models.lookups import Lookup
24from plain.models.meta import Meta
25from plain.models.postgres.sql import (
26 PK_DEFAULT_VALUE,
27 bulk_insert_sql,
28 distinct_sql,
29 explain_query_prefix,
30 for_update_sql,
31 limit_offset_sql,
32 on_conflict_suffix_sql,
33 quote_name,
34 return_insert_columns,
35)
36from plain.models.query_utils import select_related_descend
37from plain.models.sql.constants import (
38 CURSOR,
39 GET_ITERATOR_CHUNK_SIZE,
40 MULTI,
41 NO_RESULTS,
42 ORDER_DIR,
43 SINGLE,
44)
45from plain.models.sql.query import Query, get_order_dir
46from plain.models.sql.where import AND
47from plain.models.transaction import TransactionManagementError
48from plain.utils.hashable import make_hashable
49from plain.utils.regex_helper import _lazy_re_compile
50
51if TYPE_CHECKING:
52 from plain.models.expressions import BaseExpression
53 from plain.models.postgres.wrapper import DatabaseWrapper
54 from plain.models.sql.query import InsertQuery
55
# Type aliases for SQL compilation results.
# SqlParams: positional bind parameters for a SQL fragment.
# SqlWithParams: a SQL fragment string paired with its bind parameters.
SqlParams = tuple[Any, ...]
SqlWithParams = tuple[str, SqlParams]
59
60
class SQLCompilable(Protocol):
    """Protocol for objects that can be compiled to SQL.

    Satisfied structurally by expressions, lookups, where-clause nodes,
    and joins -- anything SQLCompiler.compile() accepts.
    """

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Return the SQL string and bind parameters for this object."""
        ...
69
70
class PositionRef(Ref):
    """A Ref rendered as a 1-based ordinal position in the SELECT list.

    Lets ORDER BY / GROUP BY reference a selected column by position
    instead of repeating the full expression.
    """

    def __init__(self, ordinal: int, refs: str, source: Any):
        super().__init__(refs, source)
        self.ordinal = ordinal

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]]:
        # The SQL is just the ordinal itself; no bind parameters are needed.
        position_sql = str(self.ordinal)
        return position_sql, []
80
81
class SQLCompiler:
    """Compile a Query object into an executable SQL string plus parameters."""

    # Multiline ordering SQL clause may appear from RawSQL. The capture group
    # holds the expression with its trailing ASC/DESC stripped, so duplicate
    # ORDER BY entries can be detected regardless of direction.
    ordering_parts = _lazy_re_compile(
        r"^(.*)\s(?:ASC|DESC).*",
        re.MULTILINE | re.DOTALL,
    )
88
    def __init__(
        self, query: Query, connection: DatabaseWrapper, elide_empty: bool = True
    ):
        """
        Args:
            query: The Query to compile.
            connection: Database connection used for execution.
            elide_empty: When True, let EmptyResultSet short-circuit queries
                known to return no rows.
        """
        self.query = query
        self.connection = connection
        # Some queries, e.g. coalesced aggregation, need to be executed even if
        # they would return an empty result set.
        self.elide_empty = elide_empty
        # Cache of quoted names; "*" must never be quoted.
        self.quote_cache: dict[str, str] = {"*": "*"}
        # The select, klass_info, and annotations are needed by QuerySet.iterator()
        # these are set as a side-effect of executing the query. Note that we calculate
        # separately a list of extra select columns needed for grammatical correctness
        # of the query, but these columns are not included in self.select.
        self.select: list[tuple[Any, SqlWithParams, str | None]] | None = None
        self.annotation_col_map: dict[str, int] | None = None
        self.klass_info: dict[str, Any] | None = None
        # Set when the ordering came from Model.model_options.ordering; used by
        # as_sql()/get_group_by() to suppress ORDER BY for grouped queries.
        self._meta_ordering: list[str] | None = None
106
107 def __repr__(self) -> str:
108 model_name = self.query.model.__qualname__ if self.query.model else "None"
109 return (
110 f"<{self.__class__.__qualname__} "
111 f"model={model_name} "
112 f"connection={self.connection!r}>"
113 )
114
    def setup_query(self, with_col_aliases: bool = False) -> None:
        """Populate self.select, self.klass_info, and self.annotation_col_map.

        Ensures the query has at least an initial table alias before computing
        the select columns. Also sets self.col_count to the number of selected
        columns.
        """
        if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map):
            self.query.get_initial_alias()
        self.select, self.klass_info, self.annotation_col_map = self.get_select(
            with_col_aliases=with_col_aliases,
        )
        self.col_count = len(self.select)
122
    def pre_sql_setup(
        self, with_col_aliases: bool = False
    ) -> tuple[list[Any], list[Any], list[SqlWithParams]]:
        """
        Do any necessary class setup immediately prior to producing SQL. This
        is for things that can't necessarily be done in __init__ because we
        might not have all the pieces in place at that time.

        Returns (extra_select, order_by, group_by). As side effects, sets
        self.where, self.having, self.qualify, and self.has_extra_select.
        (This method always returns a tuple, hence the annotation no longer
        includes None.)
        """
        self.setup_query(with_col_aliases=with_col_aliases)
        assert self.select is not None  # Set by setup_query()
        order_by = self.get_order_by()
        self.where, self.having, self.qualify = self.query.where.split_having_qualify(
            must_group_by=self.query.group_by is not None
        )
        extra_select = self.get_extra_select(order_by, self.select)
        self.has_extra_select = bool(extra_select)
        group_by = self.get_group_by(self.select + extra_select, order_by)
        return extra_select, order_by, group_by
141
    def get_group_by(
        self, select: list[Any], order_by: list[Any]
    ) -> list[SqlWithParams]:
        """
        Return a list of 2-tuples of form (sql, params) for the GROUP BY clause.

        The logic of what exactly the GROUP BY clause contains is hard
        to describe in other words than "if it passes the test suite,
        then it is correct".
        """
        # Some examples:
        #     SomeModel.query.annotate(Count('somecol'))
        #     GROUP BY: all fields of the model
        #
        #     SomeModel.query.values('name').annotate(Count('somecol'))
        #     GROUP BY: name
        #
        #     SomeModel.query.annotate(Count('somecol')).values('name')
        #     GROUP BY: all cols of the model
        #
        #     SomeModel.query.values('name', 'id')
        #         .annotate(Count('somecol')).values('id')
        #     GROUP BY: name, id
        #
        #     SomeModel.query.values('name').annotate(Count('somecol')).values('id')
        #     GROUP BY: name, id
        #
        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
        # can't be ever restricted to a smaller set, but additional columns in
        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
        # the end result is that it is impossible to force the query to have
        # a chosen GROUP BY clause - you can almost do this by using the form:
        #     .values(*wanted_cols).annotate(AnAggregate())
        # but any later annotations, extra selects, values calls that
        # refer some column outside of the wanted_cols, order_by, or even
        # filter calls can alter the GROUP BY clause.

        # The query.group_by is either None (no GROUP BY at all), True
        # (group by select fields), or a list of expressions to be added
        # to the group by.
        if self.query.group_by is None:
            return []
        expressions = []
        group_by_refs = set()
        if self.query.group_by is not True:
            # If the group by is set to a list (by .values() call most likely),
            # then we need to add everything in it to the GROUP BY clause.
            # Backwards compatibility hack for setting query.group_by. Remove
            # when we have public API way of forcing the GROUP BY clause.
            # Converts string references to expressions.
            for expr in self.query.group_by:
                if not hasattr(expr, "as_sql"):
                    expr = self.query.resolve_ref(expr)
                if isinstance(expr, Ref):
                    # Group by the underlying expression once per alias.
                    if expr.refs not in group_by_refs:
                        group_by_refs.add(expr.refs)
                        expressions.append(expr.source)
                else:
                    expressions.append(expr)
        # Note that even if the group_by is set, it is only the minimal
        # set to group by. So, we need to add cols in select, order_by, and
        # having into the select in any case.
        selected_expr_positions = {}
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            if alias:
                selected_expr_positions[expr] = ordinal
                # Skip members of the select clause that are already explicitly
                # grouped against.
                if alias in group_by_refs:
                    continue
            expressions.extend(expr.get_group_by_cols())
        if not self._meta_ordering:
            for expr, (sql, params, is_ref) in order_by:
                # Skip references to the SELECT clause, as all expressions in
                # the SELECT clause are already part of the GROUP BY.
                if not is_ref:
                    expressions.extend(expr.get_group_by_cols())
        having_group_by = self.having.get_group_by_cols() if self.having else []
        for expr in having_group_by:
            expressions.append(expr)
        result = []
        seen = set()
        expressions = self.collapse_group_by(expressions, having_group_by)

        for expr in expressions:
            try:
                sql, params = self.compile(expr)
            except (EmptyResultSet, FullResultSet):
                continue
            # Use select index for GROUP BY when possible: shorter SQL, and
            # the expression isn't recomputed by the database.
            if (position := selected_expr_positions.get(expr)) is not None:
                sql, params = str(position), ()
            else:
                sql, params = expr.select_format(self, sql, params)
            params_hash = make_hashable(params)
            if (sql, params_hash) not in seen:
                result.append((sql, params))
                seen.add((sql, params_hash))
        return result
241
242 def collapse_group_by(self, expressions: list[Any], having: list[Any]) -> list[Any]:
243 # Use group by functional dependence reduction:
244 # expressions can be reduced to the set of selected table
245 # primary keys as all other columns are functionally dependent on them.
246 # Filter out all expressions associated with a table's primary key
247 # present in the grouped columns. This is done by identifying all
248 # tables that have their primary key included in the grouped
249 # columns and removing non-primary key columns referring to them.
250 pks = {
251 expr
252 for expr in expressions
253 if hasattr(expr, "target") and expr.target.primary_key
254 }
255 aliases = {expr.alias for expr in pks}
256 return [
257 expr
258 for expr in expressions
259 if expr in pks
260 or expr in having
261 or getattr(expr, "alias", None) not in aliases
262 ]
263
    def get_select(
        self, with_col_aliases: bool = False
    ) -> tuple[
        list[tuple[Any, SqlWithParams, str | None]],
        dict[str, Any] | None,
        dict[str, int],
    ]:
        """
        Return three values:
        - a list of 3-tuples of (expression, (sql, params), alias)
        - a klass_info structure,
        - a dictionary of annotations

        The (sql, params) is what the expression will produce, and alias is the
        "AS alias" for the column (possibly None).

        The klass_info structure contains the following information:
        - The base model of the query.
        - Which columns for that model are present in the query (by
          position of the select clause).
        - related_klass_infos: [f, klass_info] to descent into

        The annotations is a dictionary of {'attname': column position} values.
        """
        # Built as (expression, alias) 2-tuples first; compiled into the final
        # 3-tuples at the end of this method.
        select = []
        klass_info = None
        annotations = {}
        select_idx = 0
        for alias, (sql, params) in self.query.extra_select.items():
            annotations[alias] = select_idx
            select.append((RawSQL(sql, params), alias))
            select_idx += 1
        # Explicit select columns and default model columns are mutually
        # exclusive ways of populating the select clause.
        assert not (self.query.select and self.query.default_cols)
        select_mask = self.query.get_select_mask()
        if self.query.default_cols:
            cols = self.get_default_columns(select_mask)
        else:
            # self.query.select is a special case. These columns never go to
            # any model.
            cols = self.query.select
        if cols:
            select_list = []
            for col in cols:
                select_list.append(select_idx)
                select.append((col, None))
                select_idx += 1
            klass_info = {
                "model": self.query.model,
                "select_fields": select_list,
            }
        for alias, annotation in self.query.annotation_select.items():
            annotations[alias] = select_idx
            select.append((annotation, alias))
            select_idx += 1

        if self.query.select_related:
            related_klass_infos = self.get_related_selections(select, select_mask)
            if klass_info is not None:
                klass_info["related_klass_infos"] = related_klass_infos

        # Compile each pending select entry into (expression, (sql, params), alias).
        ret = []
        col_idx = 1
        for col, alias in select:
            try:
                sql, params = self.compile(col)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    col, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    # Select a predicate that's always False.
                    sql, params = "0", ()
                else:
                    sql, params = self.compile(Value(empty_result_set_value))
            except FullResultSet:
                sql, params = self.compile(Value(True))
            else:
                sql, params = col.select_format(self, sql, params)
            if alias is None and with_col_aliases:
                # Assign synthetic "colN" aliases when the caller requires
                # every column to be addressable by name.
                alias = f"col{col_idx}"
                col_idx += 1
            ret.append((col, (sql, params), alias))
        return ret, klass_info, annotations
347
    def _order_by_pairs(self) -> Generator[tuple[OrderBy, bool], None, None]:
        """Yield (OrderBy expression, is_ref) pairs for the query's ordering.

        is_ref is True when the expression references the SELECT clause (by
        alias or position) rather than introducing a new expression. Ordering
        sources are consulted in priority order: extra_order_by, explicit
        order_by, then the model's default ordering.
        """
        if self.query.extra_order_by:
            ordering = self.query.extra_order_by
        elif not self.query.default_ordering:
            ordering = self.query.order_by
        elif self.query.order_by:
            ordering = self.query.order_by
        elif (
            self.query.model
            and (options := self.query.model.model_options)
            and options.ordering
        ):
            ordering = options.ordering
            # Remember the ordering came from model options; as_sql() uses
            # this to suppress ORDER BY on grouped queries.
            self._meta_ordering = list(ordering)
        else:
            ordering = []
        if self.query.standard_ordering:
            default_order, _ = ORDER_DIR["ASC"]
        else:
            default_order, _ = ORDER_DIR["DESC"]

        # Map select aliases and expressions to 1-based select positions so
        # ORDER BY can reference selected columns by position.
        selected_exprs = {}
        if select := self.select:
            for ordinal, (expr, _, alias) in enumerate(select, start=1):
                pos_expr = PositionRef(ordinal, alias, expr)
                if alias:
                    selected_exprs[alias] = pos_expr
                selected_exprs[expr] = pos_expr

        for field in ordering:
            if isinstance(field, ResolvableExpression):
                # field is a BaseExpression (has asc/desc/copy methods)
                field_expr = cast(BaseExpression, field)
                if isinstance(field_expr, Value):
                    # output_field must be resolved for constants.
                    field_expr = Cast(field_expr, field_expr.output_field)
                if not isinstance(field_expr, OrderBy):
                    field_expr = field_expr.asc()
                if not self.query.standard_ordering:
                    field_expr = field_expr.copy()
                    field_expr.reverse_ordering()
                field = field_expr
                select_ref = selected_exprs.get(field.expression)
                if select_ref or (
                    isinstance(field.expression, F)
                    and (select_ref := selected_exprs.get(field.expression.name))
                ):
                    # Reference the select position instead of repeating the
                    # full expression.
                    field = field.copy()
                    field.expression = select_ref
                yield field, select_ref is not None
                continue
            if field == "?":  # random
                yield OrderBy(Random()), False
                continue

            col, order = get_order_dir(field, default_order)
            descending = order == "DESC"

            if select_ref := selected_exprs.get(col):
                # Reference to expression in SELECT clause
                yield (
                    OrderBy(
                        select_ref,
                        descending=descending,
                    ),
                    True,
                )
                continue
            if col in self.query.annotations:
                # References to an expression which is masked out of the SELECT
                # clause.
                expr = self.query.annotations[col]
                if isinstance(expr, Value):
                    # output_field must be resolved for constants.
                    expr = Cast(expr, expr.output_field)
                yield OrderBy(expr, descending=descending), False
                continue

            if "." in field:
                # This came in through an extra(order_by=...) addition. Pass it
                # on verbatim.
                table, col = col.split(".", 1)
                yield (
                    OrderBy(
                        RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
                        descending=descending,
                    ),
                    False,
                )
                continue

            if self.query.extra and col in self.query.extra:
                if col in self.query.extra_select:
                    yield (
                        OrderBy(
                            Ref(col, RawSQL(*self.query.extra[col])),
                            descending=descending,
                        ),
                        True,
                    )
                else:
                    yield (
                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                        False,
                    )
            else:
                # 'col' is of the form 'field' or 'field1__field2' or
                # '-field1__field2__field', etc.
                assert self.query.model is not None, (
                    "Ordering by fields requires a model"
                )
                meta = self.query.model._model_meta
                yield from self.find_ordering_name(
                    field,
                    meta,
                    default_order=default_order,
                )
465
466 def get_order_by(self) -> list[tuple[Any, tuple[str, tuple, bool]]]:
467 """
468 Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
469 the ORDER BY clause.
470
471 The order_by clause can alter the select clause (for example it can add
472 aliases to clauses that do not yet have one, or it can add totally new
473 select clauses).
474 """
475 result = []
476 seen = set()
477 for expr, is_ref in self._order_by_pairs():
478 resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
479 sql, params = self.compile(resolved)
480 # Don't add the same column twice, but the order direction is
481 # not taken into account so we strip it. When this entire method
482 # is refactored into expressions, then we can check each part as we
483 # generate it.
484 without_ordering = self.ordering_parts.search(sql)[1]
485 params_hash = make_hashable(params)
486 if (without_ordering, params_hash) in seen:
487 continue
488 seen.add((without_ordering, params_hash))
489 result.append((resolved, (sql, params, is_ref)))
490 return result
491
492 def get_extra_select(
493 self, order_by: list[Any], select: list[Any]
494 ) -> list[tuple[Any, SqlWithParams, None]]:
495 extra_select = []
496 if self.query.distinct and not self.query.distinct_fields:
497 select_sql = [t[1] for t in select]
498 for expr, (sql, params, is_ref) in order_by:
499 without_ordering = self.ordering_parts.search(sql)[1]
500 if not is_ref and (without_ordering, params) not in select_sql:
501 extra_select.append((expr, (without_ordering, params), None))
502 return extra_select
503
504 def quote_name_unless_alias(self, name: str) -> str:
505 """
506 A wrapper around quote_name() that doesn't quote aliases for table
507 names. This avoids problems with some SQL dialects that treat quoted
508 strings specially (e.g. PostgreSQL).
509 """
510 if name in self.quote_cache:
511 return self.quote_cache[name]
512 if (
513 (name in self.query.alias_map and name not in self.query.table_map)
514 or name in self.query.extra_select
515 or (
516 self.query.external_aliases.get(name)
517 and name not in self.query.table_map
518 )
519 ):
520 self.quote_cache[name] = name
521 return name
522 r = quote_name(name)
523 self.quote_cache[name] = r
524 return r
525
526 def compile(self, node: SQLCompilable) -> SqlWithParams:
527 sql, params = node.as_sql(self, self.connection)
528 return sql, tuple(params)
529
    def get_qualify_sql(self) -> tuple[list[str], list[Any]]:
        """Compile a query whose filter references window functions (QUALIFY).

        Window functions cannot appear in WHERE/HAVING directly, so the base
        query is wrapped as a derived table and the qualify predicate is
        applied in an outer WHERE clause. Returns (sql parts, params).
        """
        where_parts = []
        if self.where:
            where_parts.append(self.where)
        if self.having:
            where_parts.append(self.having)
        inner_query = self.query.clone()
        inner_query.subquery = True
        inner_query.where = inner_query.where.__class__(where_parts)
        # Augment the inner query with any window function references that
        # might have been masked via values() and alias(). If any masked
        # aliases are added they'll be masked again to avoid fetching
        # the data in the `if qual_aliases` branch below.
        select = {
            expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
        }
        select_aliases = set(select.values())
        qual_aliases = set()
        replacements = {}

        def collect_replacements(expressions: list[Any]) -> None:
            # Map each expression to the select alias it can be referenced by,
            # adding hidden "qualN" annotations for expressions not selected.
            while expressions:
                expr = expressions.pop()
                if expr in replacements:
                    continue
                elif select_alias := select.get(expr):
                    replacements[expr] = select_alias
                elif isinstance(expr, Lookup):
                    expressions.extend(expr.get_source_expressions())
                elif isinstance(expr, Ref):
                    if expr.refs not in select_aliases:
                        expressions.extend(expr.get_source_expressions())
                else:
                    num_qual_alias = len(qual_aliases)
                    select_alias = f"qual{num_qual_alias}"
                    qual_aliases.add(select_alias)
                    inner_query.add_annotation(expr, select_alias)
                    replacements[expr] = select_alias

        qualify = self.qualify
        if qualify is None:
            raise ValueError("QUALIFY clause expected but not provided")
        collect_replacements(list(qualify.leaves()))
        qualify = qualify.replace_expressions(
            {expr: Ref(alias, expr) for expr, alias in replacements.items()}
        )
        self.qualify = qualify
        order_by = []
        for order_by_expr, *_ in self.get_order_by():
            collect_replacements(order_by_expr.get_source_expressions())
            order_by.append(
                order_by_expr.replace_expressions(
                    {expr: Ref(alias, expr) for expr, alias in replacements.items()}
                )
            )
        inner_query_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
        inner_sql, inner_params = inner_query_compiler.as_sql(
            # The limits must be applied to the outer query to avoid pruning
            # results too eagerly.
            with_limits=False,
            # Force unique aliasing of selected columns to avoid collisions
            # and make rhs predicates referencing easier.
            with_col_aliases=True,
        )
        qualify_sql, qualify_params = self.compile(qualify)
        result = [
            "SELECT * FROM (",
            inner_sql,
            ")",
            quote_name("qualify"),
            "WHERE",
            qualify_sql,
        ]
        if qual_aliases:
            # If some select aliases were unmasked for filtering purposes they
            # must be masked back.
            cols = [quote_name(alias) for alias in select.values() if alias is not None]
            result = [
                "SELECT",
                ", ".join(cols),
                "FROM (",
                *result,
                ")",
                quote_name("qualify_mask"),
            ]
        params = list(inner_params) + list(qualify_params)
        # As the SQL spec is unclear on whether or not derived tables
        # ordering must propagate it has to be explicitly repeated on the
        # outer-most query to ensure it's preserved.
        if order_by:
            ordering_sqls = []
            for ordering in order_by:
                ordering_sql, ordering_params = self.compile(ordering)
                ordering_sqls.append(ordering_sql)
                params.extend(ordering_params)
            result.extend(["ORDER BY", ", ".join(ordering_sqls)])
        return result, params
627
    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and tuple of
        parameters.

        If 'with_limits' is False, any limit/offset information is not included
        in the query. If 'with_col_aliases' is True, every selected column gets
        an explicit alias.
        """
        # Snapshot join refcounts so any joins promoted during compilation can
        # be rolled back in the finally block.
        refcounts_before = self.query.alias_refcount.copy()
        try:
            result = self.pre_sql_setup(with_col_aliases=with_col_aliases)
            assert result is not None  # SQLCompiler.pre_sql_setup always returns tuple
            extra_select, order_by, group_by = result
            assert self.select is not None  # Set by pre_sql_setup()
            for_update_part = None
            # Is a LIMIT/OFFSET clause needed?
            with_limit_offset = with_limits and self.query.is_sliced
            if self.qualify:
                # Window-function filtering requires wrapping the query;
                # get_qualify_sql() handles ordering itself.
                result, params = self.get_qualify_sql()
                order_by = None
            else:
                distinct_fields, distinct_params = self.get_distinct()
                # This must come after 'select', 'ordering', and 'distinct'
                # (see docstring of get_from_clause() for details).
                from_, f_params = self.get_from_clause()
                try:
                    where, w_params = (
                        self.compile(self.where) if self.where is not None else ("", [])
                    )
                except EmptyResultSet:
                    if self.elide_empty:
                        raise
                    # Use a predicate that's always False.
                    where, w_params = "0 = 1", []
                except FullResultSet:
                    where, w_params = "", []
                try:
                    having, h_params = (
                        self.compile(self.having)
                        if self.having is not None
                        else ("", [])
                    )
                except FullResultSet:
                    having, h_params = "", []
                result = ["SELECT"]
                params = []

                if self.query.distinct:
                    distinct_result, distinct_params = distinct_sql(
                        distinct_fields,
                        distinct_params,
                    )
                    result += distinct_result
                    params += distinct_params

                out_cols = []
                for _, (s_sql, s_params), alias in self.select + extra_select:
                    if alias:
                        s_sql = f"{s_sql} AS {quote_name(alias)}"
                    params.extend(s_params)
                    out_cols.append(s_sql)

                result += [", ".join(out_cols)]
                if from_:
                    result += ["FROM", *from_]
                params.extend(f_params)

                if self.query.select_for_update:
                    if self.connection.get_autocommit():
                        raise TransactionManagementError(
                            "select_for_update cannot be used outside of a transaction."
                        )

                    for_update_part = for_update_sql(
                        nowait=self.query.select_for_update_nowait,
                        skip_locked=self.query.select_for_update_skip_locked,
                        of=tuple(self.get_select_for_update_of_arguments()),
                        no_key=self.query.select_for_no_key_update,
                    )

                if where:
                    result.append(f"WHERE {where}")
                    params.extend(w_params)

                grouping = []
                for g_sql, g_params in group_by:
                    grouping.append(g_sql)
                    params.extend(g_params)
                if grouping:
                    if distinct_fields:
                        raise NotImplementedError(
                            "annotate() + distinct(fields) is not implemented."
                        )
                    order_by = order_by or []
                    result.append("GROUP BY {}".format(", ".join(grouping)))
                    # Meta-derived ordering is dropped for grouped queries.
                    if self._meta_ordering:
                        order_by = None
                if having:
                    result.append(f"HAVING {having}")
                    params.extend(h_params)

            if self.query.explain_info:
                result.insert(
                    0,
                    explain_query_prefix(
                        self.query.explain_info.format,
                        **self.query.explain_info.options,
                    ),
                )

            if order_by:
                ordering = []
                for _, (o_sql, o_params, _) in order_by:
                    ordering.append(o_sql)
                    params.extend(o_params)
                result.append("ORDER BY {}".format(", ".join(ordering)))

            if with_limit_offset:
                result.append(
                    limit_offset_sql(self.query.low_mark, self.query.high_mark)
                )

            if for_update_part:
                result.append(for_update_part)

            if self.query.subquery and extra_select:
                # If the query is used as a subquery, the extra selects would
                # result in more columns than the left-hand side expression is
                # expecting. This can happen when a subquery uses a combination
                # of order_by() and distinct(), forcing the ordering expressions
                # to be selected as well. Wrap the query in another subquery
                # to exclude extraneous selects.
                sub_selects = []
                sub_params = []
                for index, (select, _, alias) in enumerate(self.select, start=1):
                    if alias:
                        sub_selects.append(
                            "{}.{}".format(
                                quote_name("subquery"),
                                quote_name(alias),
                            )
                        )
                    else:
                        select_clone = select.relabeled_clone(
                            {select.alias: "subquery"}
                        )
                        subselect, subparams = select_clone.as_sql(
                            self, self.connection
                        )
                        sub_selects.append(subselect)
                        sub_params.extend(subparams)
                return "SELECT {} FROM ({}) subquery".format(
                    ", ".join(sub_selects),
                    " ".join(result),
                ), tuple(sub_params + params)

            return " ".join(result), tuple(params)
        finally:
            # Finally do cleanup - get rid of the joins we created above.
            self.query.reset_refcounts(refcounts_before)
790
791 def get_default_columns(
792 self,
793 select_mask: Any,
794 start_alias: str | None = None,
795 opts: Meta | None = None,
796 ) -> list[Any]:
797 """
798 Compute the default columns for selecting every field in the base
799 model. Will sometimes be called to pull in related models (e.g. via
800 select_related), in which case "opts" (Meta) and "start_alias" will be given
801 to provide a starting point for the traversal.
802
803 Return a list of strings, quoted appropriately for use in SQL
804 directly, as well as a set of aliases used in the select statement (if
805 'as_pairs' is True, return a list of (alias, col_name) pairs instead
806 of strings as the first component and None as the second component).
807 """
808 result = []
809 if opts is None:
810 if self.query.model is None:
811 return result
812 opts = self.query.model._model_meta
813 start_alias = start_alias or self.query.get_initial_alias()
814
815 for field in opts.concrete_fields:
816 model = field.model
817 # Since we no longer have proxy models or inheritance,
818 # the field's model should always be the same as opts.model.
819 if model == opts.model:
820 model = None
821 if select_mask and field not in select_mask:
822 continue
823
824 column = field.get_col(start_alias)
825 result.append(column)
826 return result
827
    def get_distinct(self) -> tuple[list[str], list]:
        """
        Return a quoted list of fields to use in DISTINCT ON part of the query,
        together with any parameters they require.

        This method can alter the tables in the query, and thus it must be
        called before get_from_clause().
        """
        result = []
        params = []
        if not self.query.distinct_fields:
            return result, params

        if self.query.model is None:
            return result, params
        opts = self.query.model._model_meta

        for name in self.query.distinct_fields:
            parts = name.split(LOOKUP_SEP)
            _, targets, alias, joins, path, _, transform_function = self._setup_joins(
                parts, opts, None
            )
            targets, alias, _ = self.query.trim_joins(targets, joins, path)
            for target in targets:
                if name in self.query.annotation_select:
                    # Annotations are referenced by their (quoted) alias.
                    result.append(quote_name(name))
                else:
                    r, p = self.compile(transform_function(target, alias))
                    result.append(r)
                    # NOTE(review): p is itself a params sequence; append (not
                    # extend) mirrors upstream behavior -- confirm callers
                    # flatten, or that distinct expressions are parameterless.
                    params.append(p)
        return result, params
858
    def find_ordering_name(
        self,
        name: str,
        meta: Meta,
        alias: str | None = None,
        default_order: str = "ASC",
        already_seen: set | None = None,
    ) -> list[tuple[OrderBy, bool]]:
        """
        Return a list of (OrderBy, is_ref) pairs for ordering by 'name'.

        The 'name' is of the form 'field1__field2__...__fieldN', optionally
        prefixed with '-' for descending order. Joins needed to reach the final
        field are set up on the query. Ordering by a relation expands to the
        related model's default ordering, recursively; 'already_seen' tracks
        the join chains visited to detect ordering cycles.
        """
        name, order = get_order_dir(name, default_order)
        descending = order == "DESC"
        pieces = name.split(LOOKUP_SEP)
        (
            field,
            targets,
            alias,
            joins,
            path,
            meta,
            transform_function,
        ) = self._setup_joins(pieces, meta, alias)

        # If we get to this point and the field is a relation to another model,
        # append the default ordering for that model unless it is the
        # attribute name of the field that is specified or
        # there are transforms to process.
        if (
            isinstance(field, RelatedField)
            and meta.model.model_options.ordering
            and getattr(field, "attname", None) != pieces[-1]
            and not getattr(transform_function, "has_transforms", False)
        ):
            # Firstly, avoid infinite loops.
            already_seen = already_seen or set()
            join_tuple = tuple(
                getattr(self.query.alias_map[j], "join_cols", None) for j in joins
            )
            if join_tuple in already_seen:
                raise FieldError("Infinite loop caused by ordering.")
            already_seen.add(join_tuple)

            results = []
            for item in meta.model.model_options.ordering:
                if isinstance(item, ResolvableExpression) and not isinstance(
                    item, OrderBy
                ):
                    item_expr: BaseExpression = cast(BaseExpression, item)
                    item = item_expr.desc() if descending else item_expr.asc()
                if isinstance(item, OrderBy):
                    results.append(
                        (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                    )
                    continue
                # String items recurse; their references are re-anchored under
                # the current relation prefix.
                results.extend(
                    (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                    for expr, is_ref in self.find_ordering_name(
                        item, meta, alias, order, already_seen
                    )
                )
            return results
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [
            (OrderBy(transform_function(t, alias), descending=descending), False)
            for t in targets
        ]
928
929 def _setup_joins(
930 self, pieces: list[str], meta: Meta, alias: str | None
931 ) -> tuple[Any, Any, str, list, Any, Meta, Any]:
932 """
933 Helper method for get_order_by() and get_distinct().
934
935 get_ordering() and get_distinct() must produce same target columns on
936 same input, as the prefixes of get_ordering() and get_distinct() must
937 match. Executing SQL where this is not true is an error.
938 """
939 alias = alias or self.query.get_initial_alias()
940 assert alias is not None
941 field, targets, meta, joins, path, transform_function = self.query.setup_joins(
942 pieces, meta, alias
943 )
944 alias = joins[-1]
945 return field, targets, alias, joins, path, meta, transform_function
946
947 def get_from_clause(self) -> tuple[list[str], list]:
948 """
949 Return a list of strings that are joined together to go after the
950 "FROM" part of the query, as well as a list any extra parameters that
951 need to be included. Subclasses, can override this to create a
952 from-clause via a "select".
953
954 This should only be called after any SQL construction methods that
955 might change the tables that are needed. This means the select columns,
956 ordering, and distinct must be done first.
957 """
958 result = []
959 params = []
960 for alias in tuple(self.query.alias_map):
961 if not self.query.alias_refcount[alias]:
962 continue
963 try:
964 from_clause = self.query.alias_map[alias]
965 except KeyError:
966 # Extra tables can end up in self.tables, but not in the
967 # alias_map if they aren't in a join. That's OK. We skip them.
968 continue
969 clause_sql, clause_params = self.compile(from_clause)
970 result.append(clause_sql)
971 params.extend(clause_params)
972 for t in self.query.extra_tables:
973 alias, _ = self.query.table_alias(t)
974 # Only add the alias if it's not already present (the table_alias()
975 # call increments the refcount, so an alias refcount of one means
976 # this is the only reference).
977 if (
978 alias not in self.query.alias_map
979 or self.query.alias_refcount[alias] == 1
980 ):
981 result.append(f", {self.quote_name_unless_alias(alias)}")
982 return result, params
983
    def get_related_selections(
        self,
        select: list[Any],
        select_mask: Any,
        opts: Meta | None = None,
        root_alias: str | None = None,
        cur_depth: int = 1,
        requested: dict | None = None,
        restricted: bool | None = None,
    ) -> list[dict[str, Any]]:
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).

        Args:
            select: Mutable list of (column, alias) pairs; columns for the
                related models are appended to it in place.
            select_mask: Mapping restricting which fields get selected per
                relation.
            opts: Meta for the model being queried (internal metadata)
            root_alias: Table alias to join from; defaults to the query's
                initial alias.
            cur_depth: Current recursion depth (1 for direct relations).
            requested: Subtree of the select_related dict for this level.
            restricted: Whether select_related was given specific fields
                (a dict) rather than True.

        Returns:
            A list of klass_info dicts with keys "model", "field", "reverse",
            "local_setter", "remote_setter", "select_fields" and (nested)
            "related_klass_infos".
        """

        related_klass_infos = []
        if not restricted and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return related_klass_infos

        if not opts:
            # Top-level call: derive meta and root alias from the query.
            assert self.query.model is not None, "select_related requires a model"
            opts = self.query.model._model_meta
            root_alias = self.query.get_initial_alias()

        assert root_alias is not None  # Must be provided or set above
        assert opts is not None

        def _get_field_choices() -> chain:
            # Valid names: forward related fields, primary-key reverse
            # relations, and the query's filtered relations.
            direct_choices = (
                f.name for f in opts.fields if isinstance(f, RelatedField)
            )
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.related_objects
                if f.field.primary_key
            )
            return chain(
                direct_choices, reverse_choices, self.query._filtered_relations
            )

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            restricted = isinstance(self.query.select_related, dict)
            if restricted:
                requested = cast(dict, self.query.select_related)

        def get_related_klass_infos(
            klass_info: dict, related_klass_infos: list
        ) -> None:
            # Attach recursively-built child infos to their parent entry.
            klass_info["related_klass_infos"] = related_klass_infos

        # Forward relations: fields declared directly on this model.
        for f in opts.fields:
            fields_found.add(f.name)

            if restricted:
                assert requested is not None
                next = requested.get(f.name, {})
                if not isinstance(f, RelatedField):
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or f.name in requested:
                        raise FieldError(
                            "Non-relational field given in select_related: '{}'. "
                            "Choices are: {}".format(
                                f.name,
                                ", ".join(_get_field_choices()) or "(none)",
                            )
                        )
            else:
                next = None

            if not select_related_descend(f, restricted, requested, select_mask):
                continue
            related_select_mask = select_mask.get(f) or {}
            klass_info = {
                "model": f.remote_field.model,
                "field": f,
                "reverse": False,
                "local_setter": f.set_cached_value,
                # Only primary-key relations cache back on the remote side.
                "remote_setter": f.remote_field.set_cached_value
                if f.primary_key
                else lambda x, y: None,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
            alias = joins[-1]
            columns = self.get_default_columns(
                related_select_mask,
                start_alias=alias,
                opts=f.remote_field.model._model_meta,
            )
            # Record the positions this relation's columns occupy in select.
            for col in columns:
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                f.remote_field.model._model_meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        # Reverse relations and filtered relations are only followed when
        # select_related() named specific fields (restricted mode).
        if restricted:
            from plain.models.fields.reverse_related import ManyToManyRel

            related_fields = [
                (o.field, o.related_model)
                for o in opts.related_objects
                if o.field.primary_key and not isinstance(o, ManyToManyRel)
            ]
            for related_field, model in related_fields:
                related_select_mask = select_mask.get(related_field) or {}

                if not select_related_descend(
                    related_field,
                    restricted,
                    requested,
                    related_select_mask,
                    reverse=True,
                ):
                    continue

                related_field_name = related_field.related_query_name()
                fields_found.add(related_field_name)

                join_info = self.query.setup_joins(
                    [related_field_name], opts, root_alias
                )
                alias = join_info.joins[-1]
                klass_info = {
                    "model": model,
                    "field": related_field,
                    "reverse": True,
                    "local_setter": related_field.remote_field.set_cached_value,
                    "remote_setter": related_field.set_cached_value,
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                columns = self.get_default_columns(
                    related_select_mask,
                    start_alias=alias,
                    opts=model._model_meta,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                assert requested is not None
                next = requested.get(related_field.related_query_name(), {})
                next_klass_infos = self.get_related_selections(
                    select,
                    related_select_mask,
                    model._model_meta,
                    alias,
                    cur_depth + 1,
                    next,
                    restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)

            # Setter helpers for filtered relations, which behave like
            # reverse relations but only cache when the row is non-empty.
            def local_setter(final_field: Any, obj: Any, from_obj: Any) -> None:
                # Set a reverse fk object when relation is non-empty.
                if from_obj:
                    final_field.remote_field.set_cached_value(from_obj, obj)

            def local_setter_noop(obj: Any, from_obj: Any) -> None:
                pass

            def remote_setter(name: str, obj: Any, from_obj: Any) -> None:
                setattr(from_obj, name, obj)

            assert requested is not None
            for name in list(requested):
                # Filtered relations work only on the topmost level.
                if cur_depth > 1:
                    break
                if name in self.query._filtered_relations:
                    fields_found.add(name)
                    final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                        [name], opts, root_alias
                    )
                    model = join_opts.model
                    alias = joins[-1]
                    klass_info = {
                        "model": model,
                        "field": final_field,
                        "reverse": True,
                        # More than one intermediate join: use a noop setter.
                        "local_setter": (
                            partial(local_setter, final_field)
                            if len(joins) <= 2
                            else local_setter_noop
                        ),
                        "remote_setter": partial(remote_setter, name),
                    }
                    related_klass_infos.append(klass_info)
                    select_fields = []
                    field_select_mask = select_mask.get((name, final_field)) or {}
                    columns = self.get_default_columns(
                        field_select_mask,
                        start_alias=alias,
                        opts=model._model_meta,
                    )
                    for col in columns:
                        select_fields.append(len(select))
                        select.append((col, None))
                    klass_info["select_fields"] = select_fields
                    next_requested = requested.get(name, {})
                    next_klass_infos = self.get_related_selections(
                        select,
                        field_select_mask,
                        opts=model._model_meta,
                        root_alias=alias,
                        cur_depth=cur_depth + 1,
                        requested=next_requested,
                        restricted=restricted,
                    )
                    get_related_klass_infos(klass_info, next_klass_infos)
            # Anything requested but never matched to a relation is an error.
            fields_not_found = set(requested).difference(fields_found)
            if fields_not_found:
                invalid_fields = (f"'{s}'" for s in fields_not_found)
                raise FieldError(
                    "Invalid field name(s) given in select_related: {}. "
                    "Choices are: {}".format(
                        ", ".join(invalid_fields),
                        ", ".join(_get_field_choices()) or "(none)",
                    )
                )
        return related_klass_infos
1225
    def get_select_for_update_of_arguments(self) -> list[str]:
        """
        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
        the query.

        Each entry in self.query.select_for_update_of is either "self" or a
        LOOKUP_SEP-joined relation path; invalid paths raise FieldError.
        """

        def _get_first_selected_col_from_model(klass_info: dict) -> Any | None:
            """
            Find the first selected column from a model. If it doesn't exist,
            don't lock a model.

            select_fields is filled recursively, so it also contains fields
            from the parent models.
            """
            assert self.select is not None
            model = klass_info["model"]
            for select_index in klass_info["select_fields"]:
                if self.select[select_index][0].target.model == model:
                    return self.select[select_index][0]
            return None

        def _get_field_choices() -> Generator[str, None, None]:
            """Yield all allowed field paths in breadth-first search order."""
            queue = collections.deque([(None, self.klass_info)])
            while queue:
                parent_path, klass_info = queue.popleft()
                if parent_path is None:
                    # The root entry represents the query's own model.
                    path = []
                    yield "self"
                else:
                    assert klass_info is not None  # Only first iteration has None
                    field = klass_info["field"]
                    if klass_info["reverse"]:
                        # Reverse relations are named after the remote field.
                        field = field.remote_field
                    path = parent_path + [field.name]
                    yield LOOKUP_SEP.join(path)
                if klass_info is not None:
                    queue.extend(
                        (path, related_klass_info)
                        for related_klass_info in klass_info.get(
                            "related_klass_infos", []
                        )
                    )

        if not self.klass_info:
            return []
        result = []
        invalid_names = []
        for name in self.query.select_for_update_of:
            klass_info = self.klass_info
            if name == "self":
                col = _get_first_selected_col_from_model(klass_info)
            else:
                # Walk the relation path one component at a time, descending
                # into the matching related klass_info at each step.
                for part in name.split(LOOKUP_SEP):
                    if klass_info is None:
                        break
                    klass_infos = (*klass_info.get("related_klass_infos", []),)
                    for related_klass_info in klass_infos:
                        field = related_klass_info["field"]
                        if related_klass_info["reverse"]:
                            field = field.remote_field
                        if field.name == part:
                            klass_info = related_klass_info
                            break
                    else:
                        # No related model matched this path component.
                        klass_info = None
                        break
                if klass_info is None:
                    invalid_names.append(name)
                    continue
                col = _get_first_selected_col_from_model(klass_info)
            if col is not None:
                result.append(self.quote_name_unless_alias(col.alias))
        if invalid_names:
            raise FieldError(
                "Invalid field name(s) given in select_for_update(of=(...)): {}. "
                "Only relational fields followed in the query are allowed. "
                "Choices are: {}.".format(
                    ", ".join(invalid_names),
                    ", ".join(_get_field_choices()),
                )
            )
        return result
1309
1310 def get_converters(
1311 self, expressions: Iterable[Any]
1312 ) -> dict[int, tuple[list[Any], Any]]:
1313 converters = {}
1314 for i, expression in enumerate(expressions):
1315 if expression:
1316 field_converters = expression.get_db_converters(self.connection)
1317 if field_converters:
1318 converters[i] = (field_converters, expression)
1319 return converters
1320
1321 def apply_converters(
1322 self, rows: Iterable, converters: dict
1323 ) -> Generator[list, None, None]:
1324 connection = self.connection
1325 converters_list = list(converters.items())
1326 for row in map(list, rows):
1327 for pos, (convs, expression) in converters_list:
1328 value = row[pos]
1329 for converter in convs:
1330 value = converter(value, expression, connection)
1331 row[pos] = value
1332 yield row
1333
1334 def results_iter(
1335 self,
1336 results: Any = None,
1337 tuple_expected: bool = False,
1338 chunked_fetch: bool = False,
1339 chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
1340 ) -> Iterable[Any]:
1341 """Return an iterator over the results from executing this query."""
1342 if results is None:
1343 results = self.execute_sql(
1344 MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size
1345 )
1346 assert self.select is not None # Set during query execution
1347 fields = [s[0] for s in self.select[0 : self.col_count]]
1348 converters = self.get_converters(fields)
1349 rows = chain.from_iterable(results)
1350 if converters:
1351 rows = self.apply_converters(rows, converters)
1352 if tuple_expected:
1353 rows = map(tuple, rows)
1354 return rows
1355
1356 def has_results(self) -> bool:
1357 """Check if the query returns any results."""
1358 return bool(self.execute_sql(SINGLE))
1359
    def execute_sql(
        self,
        result_type: str = MULTI,
        chunked_fetch: bool = False,
        chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
    ) -> Any:
        """
        Run the query against the database and return the result(s). The
        return value is a single data item if result_type is SINGLE, or an
        iterator over the results if the result_type is MULTI.

        result_type is either MULTI (use fetchmany() to retrieve all rows),
        SINGLE (only retrieve a single row), or None. In this last case, the
        cursor is returned if any query is executed, since it's used by
        subclasses such as InsertQuery). It's possible, however, that no query
        is needed, as the filters describe an empty set. In that case, None is
        returned, to avoid any unnecessary database interaction.
        """
        # A falsy result_type means the caller wants no rows back.
        result_type = result_type or NO_RESULTS
        try:
            as_sql_result = self.as_sql()
            # SQLCompiler.as_sql returns SqlWithParams, subclasses may differ
            assert isinstance(as_sql_result, tuple)
            assert isinstance(as_sql_result[0], str)
            sql, params = as_sql_result
            if not sql:
                raise EmptyResultSet
        except EmptyResultSet:
            # The filters can never match anything: skip the database entirely.
            if result_type == MULTI:
                return iter([])
            else:
                return
        if chunked_fetch:
            # Server-side cursor for streaming large result sets.
            cursor = self.connection.chunked_cursor()
        else:
            cursor = self.connection.cursor()
        try:
            cursor.execute(sql, params)
        except Exception:
            # Might fail for server-side cursors (e.g. connection closed)
            cursor.close()
            raise

        if result_type == CURSOR:
            # Give the caller the cursor to process and close.
            return cursor
        if result_type == SINGLE:
            try:
                val = cursor.fetchone()
                if val:
                    # Trim any extra columns beyond the requested select list.
                    return val[0 : self.col_count]
                return val
            finally:
                # done with the cursor
                cursor.close()
        if result_type == NO_RESULTS:
            cursor.close()
            return

        # MULTI: stream rows in chunks; cursor_iter closes the cursor when
        # the iteration is exhausted.
        result = cursor_iter(
            cursor,
            [],  # empty_fetchmany_value for PostgreSQL
            self.col_count if self.has_extra_select else None,
            chunk_size,
        )
        if not chunked_fetch:
            # If we are using non-chunked reads, we return the same data
            # structure as normally, but ensure it is all read into memory
            # before going any further.
            return list(result)
        return result
1431
1432 def as_subquery_condition(
1433 self, alias: str, columns: list[str], compiler: SQLCompiler
1434 ) -> SqlWithParams:
1435 qn = compiler.quote_name_unless_alias
1436 qn2 = quote_name
1437
1438 for index, select_col in enumerate(self.query.select):
1439 lhs_sql, lhs_params = self.compile(select_col)
1440 rhs = f"{qn(alias)}.{qn2(columns[index])}"
1441 self.query.where.add(RawSQL(f"{lhs_sql} = {rhs}", lhs_params), AND)
1442
1443 sql, params = self.as_sql()
1444 return f"EXISTS ({sql})", params
1445
1446 def explain_query(self) -> Generator[str, None, None]:
1447 result = list(self.execute_sql())
1448 explain_info = self.query.explain_info
1449 # PostgreSQL may return tuples with integers and strings depending on
1450 # the EXPLAIN format. Flatten them out into strings.
1451 format_ = explain_info.format if explain_info is not None else None
1452 output_formatter = json.dumps if format_ and format_.lower() == "json" else str
1453 for row in result[0]:
1454 if not isinstance(row, str):
1455 yield " ".join(output_formatter(c) for c in row)
1456 else:
1457 yield row
1458
1459
class SQLInsertCompiler(SQLCompiler):
    """Compile an InsertQuery into a single INSERT statement."""

    query: InsertQuery
    # Fields whose database-generated values are fetched via RETURNING.
    returning_fields: list | None = None
    # Parameters that accompany the RETURNING clause (set by as_sql()).
    returning_params: tuple = ()

    def field_as_sql(self, field: Any, val: Any) -> tuple[str, list]:
        """
        Take a field and a value intended to be saved on that field, and
        return placeholder SQL and accompanying params. Check for raw values,
        expressions, and fields with get_placeholder() defined in that order.

        When field is None, consider the value raw and use it as the
        placeholder, with no corresponding parameters returned.
        """
        if field is None:
            # A field value of None means the value is raw.
            sql, params = val, []
        elif hasattr(val, "as_sql"):
            # This is an expression, let's compile it.
            sql, params_tuple = self.compile(val)
            params = list(params_tuple)
        elif hasattr(field, "get_placeholder"):
            # Some fields (e.g. geo fields) need special munging before
            # they can be inserted.
            sql, params = field.get_placeholder(val, self, self.connection), [val]
        else:
            # Return the common case for the placeholder
            sql, params = "%s", [val]

        return sql, list(params)  # Ensure params is a list

    def prepare_value(self, field: Any, value: Any) -> Any:
        """
        Prepare a value to be used in a query by resolving it if it is an
        expression and otherwise calling the field's get_db_prep_save().
        """
        if isinstance(value, ResolvableExpression):
            value = value.resolve_expression(
                self.query, allow_joins=False, for_save=True
            )
            # Don't allow values containing Col expressions. They refer to
            # existing columns on a row, but in the case of insert the row
            # doesn't exist yet.
            if value.contains_column_references:
                raise ValueError(
                    f'Failed to insert expression "{value}" on {field}. F() expressions '
                    "can only be used to update, not to insert."
                )
            if value.contains_aggregate:
                raise FieldError(
                    "Aggregate functions are not allowed in this query "
                    f"({field.name}={value!r})."
                )
            if value.contains_over_clause:
                raise FieldError(
                    f"Window expressions are not allowed in this query ({field.name}={value!r})."
                )
        return field.get_db_prep_save(value, connection=self.connection)

    def pre_save_val(self, field: Any, obj: Any) -> Any:
        """
        Get the given field's value off the given obj. pre_save() is used for
        things like auto_now on DateTimeField. Skip it if this is a raw query.
        """
        if self.query.raw:
            return getattr(obj, field.attname)
        return field.pre_save(obj, add=True)

    def assemble_as_sql(
        self, fields: list[Any], value_rows: list[list[Any]]
    ) -> tuple[list[list[str]], list[list]]:
        """
        Take a sequence of N fields and a sequence of M rows of values, and
        generate placeholder SQL and parameters for each field and value.
        Return a pair containing:
         * a sequence of M rows of N SQL placeholder strings, and
         * a sequence of M rows of corresponding parameter values.

        Each placeholder string may contain any number of '%s' interpolation
        strings, and each parameter row will contain exactly as many params
        as the total number of '%s's in the corresponding placeholder row.
        """
        if not value_rows:
            return [], []

        # list of (sql, [params]) tuples for each object to be saved
        # Shape: [n_objs][n_fields][2]
        rows_of_fields_as_sql = (
            (self.field_as_sql(field, v) for field, v in zip(fields, row))
            for row in value_rows
        )

        # tuple like ([sqls], [[params]s]) for each object to be saved
        # Shape: [n_objs][2][n_fields]
        sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)

        # Extract separate lists for placeholders and params.
        # Each of these has shape [n_objs][n_fields]
        placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)

        # Params for each field are still lists, and need to be flattened.
        param_rows = [[p for ps in row for p in ps] for row in param_rows]

        return placeholder_rows, param_rows

    def as_sql(  # type: ignore[override] # Returns list for internal iteration in execute_sql
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> list[SqlWithParams]:
        """
        Build the INSERT statement.

        Returns a list of (sql, params) pairs for execute_sql() to iterate;
        both branches below always produce a single-element list.
        """
        # We don't need quote_name_unless_alias() here, since these are all
        # going to be column names (so we can avoid the extra overhead).
        qn = quote_name
        assert self.query.model is not None, "INSERT requires a model"
        meta = self.query.model._model_meta
        options = self.query.model.model_options
        result = [f"INSERT INTO {qn(options.db_table)}"]
        if self.query.fields:
            fields = self.query.fields
        else:
            # No explicit fields: name only the id column (its value is
            # supplied by PK_DEFAULT_VALUE below).
            fields = [meta.get_forward_field("id")]
        result.append("({})".format(", ".join(qn(f.column) for f in fields)))

        if self.query.fields:
            value_rows = [
                [
                    self.prepare_value(field, self.pre_save_val(field, obj))
                    for field in fields
                ]
                for obj in self.query.objs
            ]
        else:
            # An empty object.
            value_rows = [[PK_DEFAULT_VALUE] for _ in self.query.objs]
            fields = [None]

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        conflict_suffix_sql = on_conflict_suffix_sql(
            fields,  # type: ignore[arg-type]
            self.query.on_conflict,
            (f.column for f in self.query.update_fields),
            (f.column for f in self.query.unique_fields),
        )
        if self.returning_fields:
            # Use RETURNING clause to get inserted values
            result.append(
                bulk_insert_sql(fields, placeholder_rows)  # type: ignore[arg-type]
            )
            params = param_rows
            if conflict_suffix_sql:
                result.append(conflict_suffix_sql)
            # Skip empty r_sql in case returning_cols returns an empty string.
            returning_cols = return_insert_columns(self.returning_fields)
            if returning_cols:
                r_sql, self.returning_params = returning_cols
                if r_sql:
                    result.append(r_sql)
                    params += [list(self.returning_params)]
            return [(" ".join(result), tuple(chain.from_iterable(params)))]

        # Bulk insert without returning fields
        result.append(bulk_insert_sql(fields, placeholder_rows))  # type: ignore[arg-type]
        if conflict_suffix_sql:
            result.append(conflict_suffix_sql)
        return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]

    def execute_sql(  # type: ignore[override]
        self, returning_fields: list | None = None
    ) -> list:
        """
        Execute the INSERT and return the rows produced by the RETURNING
        clause for returning_fields (with any field converters applied), or
        an empty list when no returning fields were requested.
        """
        assert self.query.model is not None, "INSERT execution requires a model"
        options = self.query.model.model_options
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                cursor.execute(sql, params)
            if not self.returning_fields:
                return []
            # Use RETURNING clause for both single and bulk inserts
            if len(self.query.objs) > 1:
                rows = cursor.fetchall()
            else:
                rows = [cursor.fetchone()]
        cols = [field.get_col(options.db_table) for field in self.returning_fields]
        converters = self.get_converters(cols)
        if converters:
            rows = list(self.apply_converters(rows, converters))
        return rows
1646
1647
class SQLDeleteCompiler(SQLCompiler):
    """Compile DELETE statements, using a subquery when joins are involved."""

    @cached_property
    def single_alias(self) -> bool:
        """True when exactly one table alias is actively referenced."""
        # Ensure base table is in aliases.
        self.query.get_initial_alias()
        active_count = sum(
            1
            for table_alias in self.query.alias_map
            if self.query.alias_refcount[table_alias] > 0
        )
        return active_count == 1

    @classmethod
    def _expr_refs_base_model(cls, expr: Any, base_model: Any) -> bool:
        """Recursively check whether expr embeds a subquery on base_model."""
        if isinstance(expr, Query):
            return expr.model == base_model
        if not hasattr(expr, "get_source_expressions"):
            return False
        return any(
            cls._expr_refs_base_model(child, base_model)
            for child in expr.get_source_expressions()
        )

    @cached_property
    def contains_self_reference_subquery(self) -> bool:
        """True when an annotation or WHERE subquery targets this model."""
        candidates = chain(
            self.query.annotations.values(), self.query.where.children
        )
        return any(
            self._expr_refs_base_model(expr, self.query.model)
            for expr in candidates
        )

    def _as_sql(self, query: Query) -> SqlWithParams:
        """Build a plain single-table DELETE for the given query."""
        delete = f"DELETE FROM {self.quote_name_unless_alias(query.base_table)}"
        try:
            where_sql, where_params = self.compile(query.where)
        except FullResultSet:
            # Every row matches: no WHERE clause needed.
            return delete, ()
        return f"{delete} WHERE {where_sql}", tuple(where_params)

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        # Fast path: a single table with no self-referencing subquery can be
        # deleted from directly.
        if self.single_alias and not self.contains_self_reference_subquery:
            return self._as_sql(self.query)
        # Otherwise delete the rows whose ids match an inner SELECT that
        # carries the original joins/conditions.
        innerq = self.query.clone()
        innerq.__class__ = Query
        innerq.clear_select_clause()
        assert self.query.model is not None, "DELETE requires a model"
        id_field = self.query.model._model_meta.get_forward_field("id")
        innerq.select = [id_field.get_col(self.query.get_initial_alias())]
        outerq = Query(self.query.model)
        outerq.add_filter("id__in", innerq)
        return self._as_sql(outerq)
1701
1702
class SQLUpdateCompiler(SQLCompiler):
    """Compile UPDATE statements, including multi-table related updates."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        self.pre_sql_setup()
        query_values = getattr(self.query, "values", None)
        if not query_values:
            # Nothing to update: callers treat an empty SQL string as a no-op.
            return "", ()
        qn = self.quote_name_unless_alias
        values, update_params = [], []
        for field, model, val in query_values:
            if isinstance(val, ResolvableExpression):
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                # Model instances are only valid values for related fields.
                if isinstance(field, RelatedField):
                    val = val.prepare_database_save(field)
                else:
                    raise TypeError(
                        f"Tried to update field {field} with a model instance, {val!r}. "
                        f"Use a value compatible with {field.__class__.__name__}."
                    )
            # NOTE(review): applied unconditionally, including to resolved
            # expressions — confirm get_db_prep_save is expression-safe.
            val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = "%s"
            name = field.column
            if hasattr(val, "as_sql"):
                # Expression: inline its SQL into the placeholder.
                sql, params = self.compile(val)
                values.append(f"{qn(name)} = {placeholder % sql}")
                update_params.extend(params)
            elif val is not None:
                values.append(f"{qn(name)} = {placeholder}")
                update_params.append(val)
            else:
                values.append(f"{qn(name)} = NULL")
        table = self.query.base_table
        result = [
            f"UPDATE {qn(table)} SET",
            ", ".join(values),
        ]
        try:
            where, params = self.compile(self.query.where)
        except FullResultSet:
            # Every row matches: omit the WHERE clause.
            params = []
        else:
            result.append(f"WHERE {where}")
        return " ".join(result), tuple(update_params + list(params))

    def execute_sql(self, result_type: str) -> int:  # type: ignore[override]
        """
        Execute the specified update. Return the number of rows affected by
        the primary update query. The "primary update query" is the first
        non-empty query that is executed. Row counts for any subsequent,
        related queries are not available.
        """
        cursor = super().execute_sql(result_type)
        try:
            # A None cursor means the query was elided (empty result set).
            rows = cursor.rowcount if cursor else 0
            is_empty = cursor is None
        finally:
            if cursor:
                cursor.close()
        related_update_queries = getattr(self.query, "get_related_updates", None)
        if related_update_queries is None:
            return rows
        for query in related_update_queries():
            aux_rows = query.get_compiler().execute_sql(result_type)
            # Only the first non-empty query's row count is reported.
            if is_empty and aux_rows:
                rows = aux_rows
                is_empty = False
        return rows

    def pre_sql_setup(
        self, with_col_aliases: bool = False
    ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
        """
        If the update depends on results from other tables, munge the "where"
        conditions to match the format required for (portable) SQL updates.

        If multiple updates are required, pull out the id values to update at
        this point so that they don't change as a result of the progressive
        updates.
        """
        refcounts_before = self.query.alias_refcount.copy()
        # Ensure base table is in the query
        self.query.get_initial_alias()
        count = self.query.count_active_tables()
        related_updates = getattr(self.query, "related_updates", [])
        if not related_updates and count == 1:
            # Single-table update with no related updates: nothing to do.
            return
        # Build an inner SELECT that captures the matching ids.
        query = self.query.chain(klass=Query)
        query.select_related = False
        query.clear_ordering(force=True)
        query.extra = {}
        query.select = []
        # The subselect only needs the id column.
        fields = ["id"]
        related_ids_index = []
        for related in related_updates:
            # If a primary key chain exists to the targeted related update,
            # then the primary key value can be used for it.
            related_ids_index.append((related, 0))

        query.add_fields(fields)
        super().pre_sql_setup()

        # Now we adjust the current query: reset the where clause and get rid
        # of all the tables we don't need (since they're in the sub-select).
        self.query.clear_where()
        if related_updates:
            # We're using the idents in multiple update queries (so
            # don't want them to change).
            idents = []
            related_ids = collections.defaultdict(list)
            for rows in query.get_compiler().execute_sql(MULTI):
                idents.extend(r[0] for r in rows)
                for parent, index in related_ids_index:
                    related_ids[parent].extend(r[index] for r in rows)
            self.query.add_filter("id__in", idents)
            self.query.related_ids = related_ids  # type: ignore[assignment]
        else:
            # The fast path. Filters and updates in one query using a subquery.
            self.query.add_filter("id__in", query)
        self.query.reset_refcounts(refcounts_before)
1845
1846
class SQLAggregateCompiler(SQLCompiler):
    """Compile aggregate queries that select over an inner subquery."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        The aggregate annotations are selected from the compiled inner query,
        wrapped as ``SELECT <annotations> FROM (<inner query>) subquery``.

        Raises:
            ValueError: if the query has no inner_query to aggregate over.
        """
        # Compile each aggregate annotation into its SELECT fragment.
        select_parts: list[str] = []
        select_params: list[Any] = []
        for annotation in self.query.annotation_select.values():
            ann_sql, ann_params = self.compile(annotation)
            ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
            select_parts.append(ann_sql)
            select_params.extend(ann_params)
        self.col_count = len(self.query.annotation_select)
        select_sql = ", ".join(select_parts)
        params: tuple = tuple(select_params)

        inner_query = getattr(self.query, "inner_query", None)
        if inner_query is None:
            # Previously this raised a copy-pasted message about "update with
            # related ids"; describe the actual requirement instead.
            raise ValueError(
                "SQLAggregateCompiler requires an inner query to aggregate over"
            )
        inner_query_sql, inner_query_params = inner_query.get_compiler(
            elide_empty=self.elide_empty,
        ).as_sql(with_col_aliases=True)
        sql = f"SELECT {select_sql} FROM ({inner_query_sql}) subquery"
        params += inner_query_params
        return sql, params
1874
1875
1876def cursor_iter(
1877 cursor: Any, sentinel: Any, col_count: int | None, itersize: int
1878) -> Generator[list, None, None]:
1879 """
1880 Yield blocks of rows from a cursor and ensure the cursor is closed when
1881 done.
1882 """
1883 try:
1884 for rows in iter((lambda: cursor.fetchmany(itersize)), sentinel):
1885 yield rows if col_count is None else [r[:col_count] for r in rows]
1886 finally:
1887 cursor.close()