1from __future__ import annotations
2
3import collections
4import json
5import re
6from collections.abc import Generator, Iterable, Sequence
7from functools import cached_property, partial
8from itertools import chain
9from typing import TYPE_CHECKING, Any, Protocol, cast
10
11from plain.postgres.constants import LOOKUP_SEP
12from plain.postgres.dialect import (
13 PK_DEFAULT_VALUE,
14 bulk_insert_sql,
15 distinct_sql,
16 explain_query_prefix,
17 for_update_sql,
18 limit_offset_sql,
19 on_conflict_suffix_sql,
20 quote_name,
21 return_insert_columns,
22)
23from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
24from plain.postgres.expressions import (
25 F,
26 OrderBy,
27 RawSQL,
28 Ref,
29 ResolvableExpression,
30 Value,
31)
32from plain.postgres.fields.related import RelatedField
33from plain.postgres.functions import Cast, Random
34from plain.postgres.lookups import Lookup
35from plain.postgres.meta import Meta
36from plain.postgres.query_utils import select_related_descend
37from plain.postgres.sql.constants import (
38 CURSOR,
39 MULTI,
40 NO_RESULTS,
41 ORDER_DIR,
42 SINGLE,
43)
44from plain.postgres.sql.query import Query, get_order_dir
45from plain.postgres.sql.where import AND
46from plain.postgres.transaction import TransactionManagementError
47from plain.utils.hashable import make_hashable
48from plain.utils.regex_helper import _lazy_re_compile
49
50if TYPE_CHECKING:
51 from plain.postgres.connection import DatabaseConnection
52 from plain.postgres.expressions import BaseExpression
53 from plain.postgres.sql.query import InsertQuery
54
# Type aliases for SQL compilation results
SqlParams = tuple[Any, ...]  # positional query parameters bound to SQL placeholders
SqlWithParams = tuple[str, SqlParams]  # (sql_string, params) pair produced by compilation
58
59
class SQLCompilable(Protocol):
    """Protocol for objects that can be compiled to SQL.

    Anything passed to ``SQLCompiler.compile`` (expressions, lookups,
    where nodes, joins) satisfies this protocol.
    """

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Return SQL string and parameters for this object."""
        ...
68
69
class PositionRef(Ref):
    """A Ref that renders as its 1-based ordinal position in the SELECT list."""

    def __init__(self, ordinal: int, refs: str, source: Any):
        self.ordinal = ordinal
        super().__init__(refs, source)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        # A positional reference carries no bound parameters.
        return f"{self.ordinal}", []
79
80
81def get_converters(
82 expressions: Iterable[Any], connection: DatabaseConnection
83) -> dict[int, tuple[list[Any], Any]]:
84 converters = {}
85 for i, expression in enumerate(expressions):
86 if expression:
87 field_converters = expression.get_db_converters(connection)
88 if field_converters:
89 converters[i] = (field_converters, expression)
90 return converters
91
92
93def apply_converters(
94 rows: Iterable, converters: dict, connection: DatabaseConnection
95) -> Generator[list]:
96 converters_list = list(converters.items())
97 for row in map(list, rows):
98 for pos, (convs, expression) in converters_list:
99 value = row[pos]
100 for converter in convs:
101 value = converter(value, expression, connection)
102 row[pos] = value
103 yield row
104
105
class SQLCompiler:
    """Compiles a Query into an executable SQL string plus parameters."""

    # Multiline ordering SQL clause may appear from RawSQL.
    # Captures everything before the trailing ASC/DESC so ORDER BY entries
    # can be deduplicated regardless of direction (see get_order_by()).
    ordering_parts = _lazy_re_compile(
        r"^(.*)\s(?:ASC|DESC).*",
        re.MULTILINE | re.DOTALL,
    )
112
113 def __init__(
114 self, query: Query, connection: DatabaseConnection, elide_empty: bool = True
115 ):
116 self.query = query
117 self.connection = connection
118 # Some queries, e.g. coalesced aggregation, need to be executed even if
119 # they would return an empty result set.
120 self.elide_empty = elide_empty
121 self.quote_cache: dict[str, str] = {"*": "*"}
122 # The select, klass_info, and annotations are needed by QuerySet.iterator()
123 # these are set as a side-effect of executing the query. Note that we calculate
124 # separately a list of extra select columns needed for grammatical correctness
125 # of the query, but these columns are not included in self.select.
126 self.select: list[tuple[Any, SqlWithParams, str | None]] | None = None
127 self.annotation_col_map: dict[str, int] | None = None
128 self.klass_info: dict[str, Any] | None = None
129 self._meta_ordering: list[str] | None = None
130
131 def __repr__(self) -> str:
132 model_name = self.query.model.__qualname__ if self.query.model else "None"
133 return (
134 f"<{self.__class__.__qualname__} "
135 f"model={model_name} "
136 f"connection={self.connection!r}>"
137 )
138
139 def setup_query(self, with_col_aliases: bool = False) -> None:
140 if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map):
141 self.query.get_initial_alias()
142 self.select, self.klass_info, self.annotation_col_map = self.get_select(
143 with_col_aliases=with_col_aliases,
144 )
145 self.col_count = len(self.select)
146
147 def pre_sql_setup(
148 self, with_col_aliases: bool = False
149 ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
150 """
151 Do any necessary class setup immediately prior to producing SQL. This
152 is for things that can't necessarily be done in __init__ because we
153 might not have all the pieces in place at that time.
154 """
155 self.setup_query(with_col_aliases=with_col_aliases)
156 assert self.select is not None # Set by setup_query()
157 order_by = self.get_order_by()
158 self.where, self.having, self.qualify = self.query.where.split_having_qualify(
159 must_group_by=self.query.group_by is not None
160 )
161 extra_select = self.get_extra_select(order_by, self.select)
162 self.has_extra_select = bool(extra_select)
163 group_by = self.get_group_by(self.select + extra_select, order_by)
164 return extra_select, order_by, group_by
165
    def get_group_by(
        self, select: list[Any], order_by: list[Any]
    ) -> list[SqlWithParams]:
        """
        Return a list of 2-tuples of form (sql, params).

        The logic of what exactly the GROUP BY clause contains is hard
        to describe in other words than "if it passes the test suite,
        then it is correct".
        """
        # Some examples:
        #     SomeModel.query.annotate(Count('somecol'))
        #     GROUP BY: all fields of the model
        #
        #     SomeModel.query.values('name').annotate(Count('somecol'))
        #     GROUP BY: name
        #
        #     SomeModel.query.annotate(Count('somecol')).values('name')
        #     GROUP BY: all cols of the model
        #
        #     SomeModel.query.values('name', 'id')
        #         .annotate(Count('somecol')).values('id')
        #     GROUP BY: name, id
        #
        #     SomeModel.query.values('name').annotate(Count('somecol')).values('id')
        #     GROUP BY: name, id
        #
        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
        # can't be ever restricted to a smaller set, but additional columns in
        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
        # the end result is that it is impossible to force the query to have
        # a chosen GROUP BY clause - you can almost do this by using the form:
        #     .values(*wanted_cols).annotate(AnAggregate())
        # but any later annotations, extra selects, values calls that
        # refer some column outside of the wanted_cols, order_by, or even
        # filter calls can alter the GROUP BY clause.

        # The query.group_by is either None (no GROUP BY at all), True
        # (group by select fields), or a list of expressions to be added
        # to the group by.
        if self.query.group_by is None:
            return []
        expressions = []
        group_by_refs = set()
        if self.query.group_by is not True:
            # If the group by is set to a list (by .values() call most likely),
            # then we need to add everything in it to the GROUP BY clause.
            # Backwards compatibility hack for setting query.group_by. Remove
            # when we have public API way of forcing the GROUP BY clause.
            # Converts string references to expressions.
            for expr in self.query.group_by:
                if not hasattr(expr, "as_sql"):
                    expr = self.query.resolve_ref(expr)
                if isinstance(expr, Ref):
                    # Group on the referenced source expression, once per
                    # distinct alias.
                    if expr.refs not in group_by_refs:
                        group_by_refs.add(expr.refs)
                        expressions.append(expr.source)
                else:
                    expressions.append(expr)
        # Note that even if the group_by is set, it is only the minimal
        # set to group by. So, we need to add cols in select, order_by, and
        # having into the select in any case.
        # Record the 1-based position of each aliased select expression so
        # GROUP BY can reference it by ordinal instead of repeating the SQL.
        selected_expr_positions = {}
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            if alias:
                selected_expr_positions[expr] = ordinal
            # Skip members of the select clause that are already explicitly
            # grouped against.
            if alias in group_by_refs:
                continue
            expressions.extend(expr.get_group_by_cols())
        if not self._meta_ordering:
            # Ordering that came from Meta.ordering is dropped for grouped
            # queries later (see as_sql), so don't let it widen the GROUP BY.
            for expr, (sql, params, is_ref) in order_by:
                # Skip references to the SELECT clause, as all expressions in
                # the SELECT clause are already part of the GROUP BY.
                if not is_ref:
                    expressions.extend(expr.get_group_by_cols())
        having_group_by = self.having.get_group_by_cols() if self.having else []
        for expr in having_group_by:
            expressions.append(expr)
        result = []
        seen = set()
        # Drop columns that are functionally dependent on a grouped
        # primary key.
        expressions = self.collapse_group_by(expressions, having_group_by)

        for expr in expressions:
            try:
                sql, params = self.compile(expr)
            except (EmptyResultSet, FullResultSet):
                continue
            # Use select index for GROUP BY when possible
            if (position := selected_expr_positions.get(expr)) is not None:
                sql, params = str(position), ()
            else:
                sql, params = expr.select_format(self, sql, params)
            params_hash = make_hashable(params)
            if (sql, params_hash) not in seen:
                result.append((sql, params))
                seen.add((sql, params_hash))
        return result
265
266 def collapse_group_by(self, expressions: list[Any], having: list[Any]) -> list[Any]:
267 # Use group by functional dependence reduction:
268 # expressions can be reduced to the set of selected table
269 # primary keys as all other columns are functionally dependent on them.
270 # Filter out all expressions associated with a table's primary key
271 # present in the grouped columns. This is done by identifying all
272 # tables that have their primary key included in the grouped
273 # columns and removing non-primary key columns referring to them.
274 pks = {
275 expr
276 for expr in expressions
277 if hasattr(expr, "target") and expr.target.primary_key
278 }
279 aliases = {expr.alias for expr in pks}
280 return [
281 expr
282 for expr in expressions
283 if expr in pks
284 or expr in having
285 or getattr(expr, "alias", None) not in aliases
286 ]
287
    def get_select(
        self, with_col_aliases: bool = False
    ) -> tuple[
        list[tuple[Any, SqlWithParams, str | None]],
        dict[str, Any] | None,
        dict[str, int],
    ]:
        """
        Return three values:
        - a list of 3-tuples of (expression, (sql, params), alias)
        - a klass_info structure,
        - a dictionary of annotations

        The (sql, params) is what the expression will produce, and alias is the
        "AS alias" for the column (possibly None).

        The klass_info structure contains the following information:
        - The base model of the query.
        - Which columns for that model are present in the query (by
          position of the select clause).
        - related_klass_infos: [f, klass_info] to descent into

        The annotations is a dictionary of {'attname': column position} values.
        """
        select = []
        klass_info = None
        annotations = {}
        select_idx = 0
        # extra(select=...) columns come first and are exposed as annotations.
        for alias, (sql, params) in self.query.extra_select.items():
            annotations[alias] = select_idx
            select.append((RawSQL(sql, params), alias))
            select_idx += 1
        # A query carries either explicit select columns or the model's
        # default columns, never both.
        assert not (self.query.select and self.query.default_cols)
        select_mask = self.query.get_select_mask()
        if self.query.default_cols:
            cols = self.get_default_columns(select_mask)
        else:
            # self.query.select is a special case. These columns never go to
            # any model.
            cols = self.query.select
        if cols:
            select_list = []
            for col in cols:
                select_list.append(select_idx)
                select.append((col, None))
                select_idx += 1
            klass_info = {
                "model": self.query.model,
                "select_fields": select_list,
            }
        # Annotations are appended after the model columns.
        for alias, annotation in self.query.annotation_select.items():
            annotations[alias] = select_idx
            select.append((annotation, alias))
            select_idx += 1

        if self.query.select_related:
            related_klass_infos = self.get_related_selections(select, select_mask)
            if klass_info is not None:
                klass_info["related_klass_infos"] = related_klass_infos

        ret = []
        col_idx = 1
        for col, alias in select:
            try:
                sql, params = self.compile(col)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    col, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    # No declared fallback; select a constant 0 placeholder.
                    sql, params = "0", ()
                else:
                    sql, params = self.compile(Value(empty_result_set_value))
            except FullResultSet:
                sql, params = self.compile(Value(True))
            else:
                sql, params = col.select_format(self, sql, params)
            if alias is None and with_col_aliases:
                # Synthesize a unique alias when the caller requested
                # column aliasing (e.g. when wrapping as a subquery).
                alias = f"col{col_idx}"
                col_idx += 1
            ret.append((col, (sql, params), alias))
        return ret, klass_info, annotations
371
    def _order_by_pairs(self) -> Generator[tuple[OrderBy, bool]]:
        """
        Yield (OrderBy expression, is_ref) pairs for the requested ordering.

        is_ref is True when the expression references the SELECT clause
        (by alias or ordinal position) rather than standing on its own.

        Ordering sources, in priority order: extra(order_by=...),
        query.order_by (or nothing when default ordering is disabled),
        then the model's Meta ordering.
        """
        if self.query.extra_order_by:
            ordering = self.query.extra_order_by
        elif not self.query.default_ordering:
            ordering = self.query.order_by
        elif self.query.order_by:
            ordering = self.query.order_by
        elif (
            self.query.model
            and (options := self.query.model.model_options)
            and options.ordering
        ):
            ordering = options.ordering
            # Remember that Meta.ordering applied — get_group_by() and
            # as_sql() treat that case specially.
            self._meta_ordering = list(ordering)
        else:
            ordering = []
        if self.query.standard_ordering:
            default_order, _ = ORDER_DIR["ASC"]
        else:
            default_order, _ = ORDER_DIR["DESC"]

        # Map both aliases and the raw expressions to positional references
        # so ORDER BY can point at the SELECT list by ordinal.
        selected_exprs = {}
        if select := self.select:
            for ordinal, (expr, _, alias) in enumerate(select, start=1):
                pos_expr = PositionRef(ordinal, alias, expr)
                if alias:
                    selected_exprs[alias] = pos_expr
                selected_exprs[expr] = pos_expr

        for field in ordering:
            if isinstance(field, ResolvableExpression):
                # field is a BaseExpression (has asc/desc/copy methods)
                field_expr = cast(BaseExpression, field)
                if isinstance(field_expr, Value):
                    # output_field must be resolved for constants.
                    field_expr = Cast(field_expr, field_expr.output_field)
                if not isinstance(field_expr, OrderBy):
                    field_expr = field_expr.asc()
                if not self.query.standard_ordering:
                    field_expr = field_expr.copy()
                    field_expr.reverse_ordering()
                field = field_expr
                select_ref = selected_exprs.get(field.expression)
                if select_ref or (
                    isinstance(field.expression, F)
                    and (select_ref := selected_exprs.get(field.expression.name))
                ):
                    # Emit a positional reference instead of repeating the
                    # full expression SQL.
                    field = field.copy()
                    field.expression = select_ref
                yield field, select_ref is not None
                continue
            if field == "?":  # random
                yield OrderBy(Random()), False
                continue

            col, order = get_order_dir(field, default_order)
            descending = order == "DESC"

            if select_ref := selected_exprs.get(col):
                # Reference to expression in SELECT clause
                yield (
                    OrderBy(
                        select_ref,
                        descending=descending,
                    ),
                    True,
                )
                continue
            if col in self.query.annotations:
                # References to an expression which is masked out of the SELECT
                # clause.
                expr = self.query.annotations[col]
                if isinstance(expr, Value):
                    # output_field must be resolved for constants.
                    expr = Cast(expr, expr.output_field)
                yield OrderBy(expr, descending=descending), False
                continue

            if "." in field:
                # This came in through an extra(order_by=...) addition. Pass it
                # on verbatim.
                table, col = col.split(".", 1)
                yield (
                    OrderBy(
                        RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
                        descending=descending,
                    ),
                    False,
                )
                continue

            if self.query.extra and col in self.query.extra:
                if col in self.query.extra_select:
                    yield (
                        OrderBy(
                            Ref(col, RawSQL(*self.query.extra[col])),
                            descending=descending,
                        ),
                        True,
                    )
                else:
                    yield (
                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                        False,
                    )
            else:
                # 'col' is of the form 'field' or 'field1__field2' or
                # '-field1__field2__field', etc.
                assert self.query.model is not None, (
                    "Ordering by fields requires a model"
                )
                meta = self.query.model._model_meta
                yield from self.find_ordering_name(
                    field,
                    meta,
                    default_order=default_order,
                )
489
490 def get_order_by(self) -> list[tuple[Any, tuple[str, tuple, bool]]]:
491 """
492 Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
493 the ORDER BY clause.
494
495 The order_by clause can alter the select clause (for example it can add
496 aliases to clauses that do not yet have one, or it can add totally new
497 select clauses).
498 """
499 result = []
500 seen = set()
501 for expr, is_ref in self._order_by_pairs():
502 resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
503 sql, params = self.compile(resolved)
504 # Don't add the same column twice, but the order direction is
505 # not taken into account so we strip it. When this entire method
506 # is refactored into expressions, then we can check each part as we
507 # generate it.
508 without_ordering = self.ordering_parts.search(sql)[1]
509 params_hash = make_hashable(params)
510 if (without_ordering, params_hash) in seen:
511 continue
512 seen.add((without_ordering, params_hash))
513 result.append((resolved, (sql, params, is_ref)))
514 return result
515
516 def get_extra_select(
517 self, order_by: list[Any], select: list[Any]
518 ) -> list[tuple[Any, SqlWithParams, None]]:
519 extra_select = []
520 if self.query.distinct and not self.query.distinct_fields:
521 select_sql = [t[1] for t in select]
522 for expr, (sql, params, is_ref) in order_by:
523 without_ordering = self.ordering_parts.search(sql)[1]
524 if not is_ref and (without_ordering, params) not in select_sql:
525 extra_select.append((expr, (without_ordering, params), None))
526 return extra_select
527
528 def quote_name_unless_alias(self, name: str) -> str:
529 """
530 A wrapper around quote_name() that doesn't quote aliases for table
531 names. This avoids problems with some SQL dialects that treat quoted
532 strings specially (e.g. PostgreSQL).
533 """
534 if name in self.quote_cache:
535 return self.quote_cache[name]
536 if (
537 (name in self.query.alias_map and name not in self.query.table_map)
538 or name in self.query.extra_select
539 or (
540 self.query.external_aliases.get(name)
541 and name not in self.query.table_map
542 )
543 ):
544 self.quote_cache[name] = name
545 return name
546 r = quote_name(name)
547 self.quote_cache[name] = r
548 return r
549
550 def compile(self, node: SQLCompilable) -> SqlWithParams:
551 sql, params = node.as_sql(self, self.connection)
552 return sql, tuple(params)
553
    def get_qualify_sql(self) -> tuple[list[str], list[Any]]:
        """
        Build SQL for a query with window-function predicates (QUALIFY).

        Window expressions cannot appear in WHERE/HAVING, so the query is
        compiled as an inner subquery and the qualify condition becomes
        the WHERE clause of an outer wrapper query.

        Return (list of SQL fragments, params).
        """
        where_parts = []
        if self.where:
            where_parts.append(self.where)
        if self.having:
            where_parts.append(self.having)
        inner_query = self.query.clone()
        inner_query.subquery = True
        inner_query.where = inner_query.where.__class__(where_parts)
        # Augment the inner query with any window function references that
        # might have been masked via values() and alias(). If any masked
        # aliases are added they'll be masked again to avoid fetching
        # the data in the `if qual_aliases` branch below.
        select = {
            expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
        }
        select_aliases = set(select.values())
        qual_aliases = set()
        replacements = {}

        def collect_replacements(expressions: list[Any]) -> None:
            # Walk the given expression trees, mapping each expression the
            # outer query must reference to a select alias. Expressions not
            # already selected get a synthetic "qualN" annotation on the
            # inner query. Mutates replacements/qual_aliases/inner_query.
            while expressions:
                expr = expressions.pop()
                if expr in replacements:
                    continue
                elif select_alias := select.get(expr):
                    replacements[expr] = select_alias
                elif isinstance(expr, Lookup):
                    expressions.extend(expr.get_source_expressions())
                elif isinstance(expr, Ref):
                    if expr.refs not in select_aliases:
                        expressions.extend(expr.get_source_expressions())
                else:
                    num_qual_alias = len(qual_aliases)
                    select_alias = f"qual{num_qual_alias}"
                    qual_aliases.add(select_alias)
                    inner_query.add_annotation(expr, select_alias)
                    replacements[expr] = select_alias

        qualify = self.qualify
        if qualify is None:
            raise ValueError("QUALIFY clause expected but not provided")
        collect_replacements(list(qualify.leaves()))
        qualify = qualify.replace_expressions(
            {expr: Ref(alias, expr) for expr, alias in replacements.items()}
        )
        self.qualify = qualify
        # The ordering must also be rewritten in terms of the inner
        # query's select aliases.
        order_by = []
        for order_by_expr, *_ in self.get_order_by():
            collect_replacements(order_by_expr.get_source_expressions())
            order_by.append(
                order_by_expr.replace_expressions(
                    {expr: Ref(alias, expr) for expr, alias in replacements.items()}
                )
            )
        inner_query_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
        inner_sql, inner_params = inner_query_compiler.as_sql(
            # The limits must be applied to the outer query to avoid pruning
            # results too eagerly.
            with_limits=False,
            # Force unique aliasing of selected columns to avoid collisions
            # and make rhs predicates referencing easier.
            with_col_aliases=True,
        )
        qualify_sql, qualify_params = self.compile(qualify)
        result = [
            "SELECT * FROM (",
            inner_sql,
            ")",
            quote_name("qualify"),
            "WHERE",
            qualify_sql,
        ]
        if qual_aliases:
            # If some select aliases were unmasked for filtering purposes they
            # must be masked back.
            cols = [quote_name(alias) for alias in select.values() if alias is not None]
            result = [
                "SELECT",
                ", ".join(cols),
                "FROM (",
                *result,
                ")",
                quote_name("qualify_mask"),
            ]
        params = list(inner_params) + list(qualify_params)
        # As the SQL spec is unclear on whether or not derived tables
        # ordering must propagate it has to be explicitly repeated on the
        # outer-most query to ensure it's preserved.
        if order_by:
            ordering_sqls = []
            for ordering in order_by:
                ordering_sql, ordering_params = self.compile(ordering)
                ordering_sqls.append(ordering_sql)
                params.extend(ordering_params)
            result.extend(["ORDER BY", ", ".join(ordering_sqls)])
        return result, params
651
    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        If 'with_limits' is False, any limit/offset information is not included
        in the query.

        If 'with_col_aliases' is True, every select column gets a unique
        alias (used when this query is wrapped as a subquery).
        """
        refcounts_before = self.query.alias_refcount.copy()
        try:
            result = self.pre_sql_setup(with_col_aliases=with_col_aliases)
            assert result is not None  # SQLCompiler.pre_sql_setup always returns tuple
            extra_select, order_by, group_by = result
            assert self.select is not None  # Set by pre_sql_setup()
            for_update_part = None
            # Is a LIMIT/OFFSET clause needed?
            with_limit_offset = with_limits and self.query.is_sliced
            if self.qualify:
                # Window-function predicates require a wrapper query; the
                # ordering is re-emitted inside get_qualify_sql().
                result, params = self.get_qualify_sql()
                order_by = None
            else:
                distinct_fields, distinct_params = self.get_distinct()
                # This must come after 'select', 'ordering', and 'distinct'
                # (see docstring of get_from_clause() for details).
                from_, f_params = self.get_from_clause()
                try:
                    where, w_params = (
                        self.compile(self.where) if self.where is not None else ("", [])
                    )
                except EmptyResultSet:
                    if self.elide_empty:
                        raise
                    # Use a predicate that's always False.
                    where, w_params = "0 = 1", []
                except FullResultSet:
                    where, w_params = "", []
                try:
                    having, h_params = (
                        self.compile(self.having)
                        if self.having is not None
                        else ("", [])
                    )
                except FullResultSet:
                    having, h_params = "", []
                result = ["SELECT"]
                params = []

                if self.query.distinct:
                    distinct_result, distinct_params = distinct_sql(
                        distinct_fields,
                        distinct_params,
                    )
                    result += distinct_result
                    params += distinct_params

                out_cols = []
                for _, (s_sql, s_params), alias in self.select + extra_select:
                    if alias:
                        s_sql = f"{s_sql} AS {quote_name(alias)}"
                    params.extend(s_params)
                    out_cols.append(s_sql)

                result += [", ".join(out_cols)]
                if from_:
                    result += ["FROM", *from_]
                params.extend(f_params)

                if self.query.select_for_update:
                    # Row locks are meaningless outside a transaction.
                    if self.connection.get_autocommit():
                        raise TransactionManagementError(
                            "select_for_update cannot be used outside of a transaction."
                        )

                    for_update_part = for_update_sql(
                        nowait=self.query.select_for_update_nowait,
                        skip_locked=self.query.select_for_update_skip_locked,
                        of=tuple(self.get_select_for_update_of_arguments()),
                        no_key=self.query.select_for_no_key_update,
                    )

                if where:
                    result.append(f"WHERE {where}")
                    params.extend(w_params)

                grouping = []
                for g_sql, g_params in group_by:
                    grouping.append(g_sql)
                    params.extend(g_params)
                if grouping:
                    if distinct_fields:
                        raise NotImplementedError(
                            "annotate() + distinct(fields) is not implemented."
                        )
                    order_by = order_by or []
                    result.append("GROUP BY {}".format(", ".join(grouping)))
                    if self._meta_ordering:
                        # Ordering that only came from Meta.ordering is
                        # dropped for grouped queries.
                        order_by = None
                if having:
                    result.append(f"HAVING {having}")
                    params.extend(h_params)

            if self.query.explain_info:
                result.insert(
                    0,
                    explain_query_prefix(
                        self.query.explain_info.format,
                        **self.query.explain_info.options,
                    ),
                )

            if order_by:
                ordering = []
                for _, (o_sql, o_params, _) in order_by:
                    ordering.append(o_sql)
                    params.extend(o_params)
                result.append("ORDER BY {}".format(", ".join(ordering)))

            if with_limit_offset:
                result.append(
                    limit_offset_sql(self.query.low_mark, self.query.high_mark)
                )

            if for_update_part:
                result.append(for_update_part)

            if self.query.subquery and extra_select:
                # If the query is used as a subquery, the extra selects would
                # result in more columns than the left-hand side expression is
                # expecting. This can happen when a subquery uses a combination
                # of order_by() and distinct(), forcing the ordering expressions
                # to be selected as well. Wrap the query in another subquery
                # to exclude extraneous selects.
                sub_selects = []
                sub_params = []
                for index, (select, _, alias) in enumerate(self.select, start=1):
                    if alias:
                        sub_selects.append(
                            "{}.{}".format(
                                quote_name("subquery"),
                                quote_name(alias),
                            )
                        )
                    else:
                        select_clone = select.relabeled_clone(
                            {select.alias: "subquery"}
                        )
                        subselect, subparams = select_clone.as_sql(
                            self, self.connection
                        )
                        sub_selects.append(subselect)
                        sub_params.extend(subparams)
                return "SELECT {} FROM ({}) subquery".format(
                    ", ".join(sub_selects),
                    " ".join(result),
                ), tuple(sub_params + params)

            return " ".join(result), tuple(params)
        finally:
            # Finally do cleanup - get rid of the joins we created above.
            self.query.reset_refcounts(refcounts_before)
814
815 def get_default_columns(
816 self,
817 select_mask: Any,
818 start_alias: str | None = None,
819 opts: Meta | None = None,
820 ) -> list[Any]:
821 """
822 Compute the default columns for selecting every field in the base
823 model. Will sometimes be called to pull in related models (e.g. via
824 select_related), in which case "opts" (Meta) and "start_alias" will be given
825 to provide a starting point for the traversal.
826
827 Return a list of strings, quoted appropriately for use in SQL
828 directly, as well as a set of aliases used in the select statement (if
829 'as_pairs' is True, return a list of (alias, col_name) pairs instead
830 of strings as the first component and None as the second component).
831 """
832 result = []
833 if opts is None:
834 if self.query.model is None:
835 return result
836 opts = self.query.model._model_meta
837 start_alias = start_alias or self.query.get_initial_alias()
838
839 for field in opts.concrete_fields:
840 model = field.model
841 # Since we no longer have proxy models or inheritance,
842 # the field's model should always be the same as opts.model.
843 if model == opts.model:
844 model = None
845 if select_mask and field not in select_mask:
846 continue
847
848 column = field.get_col(start_alias)
849 result.append(column)
850 return result
851
852 def get_distinct(self) -> tuple[list[str], list]:
853 """
854 Return a quoted list of fields to use in DISTINCT ON part of the query.
855
856 This method can alter the tables in the query, and thus it must be
857 called before get_from_clause().
858 """
859 result = []
860 params = []
861 if not self.query.distinct_fields:
862 return result, params
863
864 if self.query.model is None:
865 return result, params
866 opts = self.query.model._model_meta
867
868 for name in self.query.distinct_fields:
869 parts = name.split(LOOKUP_SEP)
870 _, targets, alias, joins, path, _, transform_function = self._setup_joins(
871 parts, opts, None
872 )
873 targets, alias, _ = self.query.trim_joins(targets, joins, path)
874 for target in targets:
875 if name in self.query.annotation_select:
876 result.append(quote_name(name))
877 else:
878 r, p = self.compile(transform_function(target, alias))
879 result.append(r)
880 params.append(p)
881 return result, params
882
    def find_ordering_name(
        self,
        name: str,
        meta: Meta,
        alias: str | None = None,
        default_order: str = "ASC",
        already_seen: set | None = None,
    ) -> list[tuple[OrderBy, bool]]:
        """
        Return a list of (OrderBy expression, is_ref) pairs for ordering
        by 'name', which is of the form 'field1__field2__...__fieldN' and
        may carry a direction prefix handled by get_order_dir().

        Sets up any joins needed to reach the referenced column. When the
        name resolves to a relation, the related model's Meta ordering is
        expanded recursively instead (with cycle detection).
        """
        name, order = get_order_dir(name, default_order)
        descending = order == "DESC"
        pieces = name.split(LOOKUP_SEP)
        (
            field,
            targets,
            alias,
            joins,
            path,
            meta,
            transform_function,
        ) = self._setup_joins(pieces, meta, alias)

        # If we get to this point and the field is a relation to another model,
        # append the default ordering for that model unless it is the
        # attribute name of the field that is specified or
        # there are transforms to process.
        if (
            isinstance(field, RelatedField)
            and meta.model.model_options.ordering
            and getattr(field, "attname", None) != pieces[-1]
            and not getattr(transform_function, "has_transforms", False)
        ):
            # Firstly, avoid infinite loops.
            already_seen = already_seen or set()
            join_tuple = tuple(
                getattr(self.query.alias_map[j], "join_cols", None) for j in joins
            )
            if join_tuple in already_seen:
                raise FieldError("Infinite loop caused by ordering.")
            already_seen.add(join_tuple)

            results = []
            for item in meta.model.model_options.ordering:
                if isinstance(item, ResolvableExpression) and not isinstance(
                    item, OrderBy
                ):
                    item_expr: BaseExpression = cast(BaseExpression, item)
                    item = item_expr.desc() if descending else item_expr.asc()
                if isinstance(item, OrderBy):
                    # Prefix references so they resolve relative to the
                    # original lookup path.
                    results.append(
                        (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                    )
                    continue
                results.extend(
                    (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                    for expr, is_ref in self.find_ordering_name(
                        item, meta, alias, order, already_seen
                    )
                )
            return results
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [
            (OrderBy(transform_function(t, alias), descending=descending), False)
            for t in targets
        ]
952
953 def _setup_joins(
954 self, pieces: list[str], meta: Meta, alias: str | None
955 ) -> tuple[Any, Any, str, list, Any, Meta, Any]:
956 """
957 Helper method for get_order_by() and get_distinct().
958
959 get_ordering() and get_distinct() must produce same target columns on
960 same input, as the prefixes of get_ordering() and get_distinct() must
961 match. Executing SQL where this is not true is an error.
962 """
963 alias = alias or self.query.get_initial_alias()
964 assert alias is not None
965 field, targets, meta, joins, path, transform_function = self.query.setup_joins(
966 pieces, meta, alias
967 )
968 alias = joins[-1]
969 return field, targets, alias, joins, path, meta, transform_function
970
971 def get_from_clause(self) -> tuple[list[str], list]:
972 """
973 Return a list of strings that are joined together to go after the
974 "FROM" part of the query, as well as a list any extra parameters that
975 need to be included. Subclasses, can override this to create a
976 from-clause via a "select".
977
978 This should only be called after any SQL construction methods that
979 might change the tables that are needed. This means the select columns,
980 ordering, and distinct must be done first.
981 """
982 result = []
983 params = []
984 for alias in tuple(self.query.alias_map):
985 if not self.query.alias_refcount[alias]:
986 continue
987 try:
988 from_clause = self.query.alias_map[alias]
989 except KeyError:
990 # Extra tables can end up in self.tables, but not in the
991 # alias_map if they aren't in a join. That's OK. We skip them.
992 continue
993 clause_sql, clause_params = self.compile(from_clause)
994 result.append(clause_sql)
995 params.extend(clause_params)
996 for t in self.query.extra_tables:
997 alias, _ = self.query.table_alias(t)
998 # Only add the alias if it's not already present (the table_alias()
999 # call increments the refcount, so an alias refcount of one means
1000 # this is the only reference).
1001 if (
1002 alias not in self.query.alias_map
1003 or self.query.alias_refcount[alias] == 1
1004 ):
1005 result.append(f", {self.quote_name_unless_alias(alias)}")
1006 return result, params
1007
    def get_related_selections(
        self,
        select: list[Any],
        select_mask: Any,
        opts: Meta | None = None,
        root_alias: str | None = None,
        cur_depth: int = 1,
        requested: dict | None = None,
        restricted: bool | None = None,
    ) -> list[dict[str, Any]]:
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).

        Args:
            opts: Meta for the model being queried (internal metadata)
        """

        # Each entry is a dict describing one related model whose columns are
        # appended to `select` (which is mutated in place by this method).
        related_klass_infos = []
        if not restricted and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return related_klass_infos

        if not opts:
            # Top-level call: start from the query's own model and root alias.
            assert self.query.model is not None, "select_related requires a model"
            opts = self.query.model._model_meta
            root_alias = self.query.get_initial_alias()

        assert root_alias is not None  # Must be provided or set above
        assert opts is not None

        def _get_field_choices() -> chain:
            # All valid names usable in select_related(), for error messages.
            direct_choices = (
                f.name for f in opts.fields if isinstance(f, RelatedField)
            )
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.related_objects
                if f.field.primary_key
            )
            return chain(
                direct_choices, reverse_choices, self.query._filtered_relations
            )

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            restricted = isinstance(self.query.select_related, dict)
            if restricted:
                requested = cast(dict, self.query.select_related)

        def get_related_klass_infos(
            klass_info: dict, related_klass_infos: list
        ) -> None:
            # Attach the recursively-built child infos to their parent.
            klass_info["related_klass_infos"] = related_klass_infos

        # Pass 1: forward relations (FK-like fields declared on this model).
        for f in opts.fields:
            fields_found.add(f.name)

            if restricted:
                assert requested is not None
                next = requested.get(f.name, {})
                if not isinstance(f, RelatedField):
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or f.name in requested:
                        raise FieldError(
                            "Non-relational field given in select_related: '{}'. "
                            "Choices are: {}".format(
                                f.name,
                                ", ".join(_get_field_choices()) or "(none)",
                            )
                        )
            else:
                next = None

            if not select_related_descend(f, restricted, requested, select_mask):
                continue
            related_select_mask = select_mask.get(f) or {}
            klass_info = {
                "model": f.remote_field.model,
                "field": f,
                "reverse": False,
                "local_setter": f.set_cached_value,
                # Only cache the reverse side for one-to-one (PK) relations.
                "remote_setter": f.remote_field.set_cached_value
                if f.primary_key
                else lambda x, y: None,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
            alias = joins[-1]
            columns = self.get_default_columns(
                related_select_mask,
                start_alias=alias,
                opts=f.remote_field.model._model_meta,
            )
            # select_fields holds indexes into `select` for this model's columns.
            for col in columns:
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                f.remote_field.model._model_meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        if restricted:
            # Pass 2: reverse one-to-one relations (only followed when the
            # select_related set is explicitly restricted).
            from plain.postgres.fields.reverse_related import ManyToManyRel

            related_fields = [
                (o.field, o.related_model)
                for o in opts.related_objects
                if o.field.primary_key and not isinstance(o, ManyToManyRel)
            ]
            for related_field, model in related_fields:
                related_select_mask = select_mask.get(related_field) or {}

                if not select_related_descend(
                    related_field,
                    restricted,
                    requested,
                    related_select_mask,
                    reverse=True,
                ):
                    continue

                related_field_name = related_field.related_query_name()
                fields_found.add(related_field_name)

                join_info = self.query.setup_joins(
                    [related_field_name], opts, root_alias
                )
                alias = join_info.joins[-1]
                klass_info = {
                    "model": model,
                    "field": related_field,
                    "reverse": True,
                    "local_setter": related_field.remote_field.set_cached_value,
                    "remote_setter": related_field.set_cached_value,
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                columns = self.get_default_columns(
                    related_select_mask,
                    start_alias=alias,
                    opts=model._model_meta,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                assert requested is not None
                next = requested.get(related_field.related_query_name(), {})
                next_klass_infos = self.get_related_selections(
                    select,
                    related_select_mask,
                    model._model_meta,
                    alias,
                    cur_depth + 1,
                    next,
                    restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)

            def local_setter(final_field: Any, obj: Any, from_obj: Any) -> None:
                # Set a reverse fk object when relation is non-empty.
                if from_obj:
                    final_field.remote_field.set_cached_value(from_obj, obj)

            def local_setter_noop(obj: Any, from_obj: Any) -> None:
                pass

            def remote_setter(name: str, obj: Any, from_obj: Any) -> None:
                setattr(from_obj, name, obj)

            # Pass 3: filtered relations (FilteredRelation annotations).
            assert requested is not None
            for name in list(requested):
                # Filtered relations work only on the topmost level.
                if cur_depth > 1:
                    break
                if name in self.query._filtered_relations:
                    fields_found.add(name)
                    final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                        [name], opts, root_alias
                    )
                    model = join_opts.model
                    alias = joins[-1]
                    klass_info = {
                        "model": model,
                        "field": final_field,
                        "reverse": True,
                        # Multi-hop filtered relations can't set the local
                        # cache reliably, so use a no-op past two joins.
                        "local_setter": (
                            partial(local_setter, final_field)
                            if len(joins) <= 2
                            else local_setter_noop
                        ),
                        "remote_setter": partial(remote_setter, name),
                    }
                    related_klass_infos.append(klass_info)
                    select_fields = []
                    field_select_mask = select_mask.get((name, final_field)) or {}
                    columns = self.get_default_columns(
                        field_select_mask,
                        start_alias=alias,
                        opts=model._model_meta,
                    )
                    for col in columns:
                        select_fields.append(len(select))
                        select.append((col, None))
                    klass_info["select_fields"] = select_fields
                    next_requested = requested.get(name, {})
                    next_klass_infos = self.get_related_selections(
                        select,
                        field_select_mask,
                        opts=model._model_meta,
                        root_alias=alias,
                        cur_depth=cur_depth + 1,
                        requested=next_requested,
                        restricted=restricted,
                    )
                    get_related_klass_infos(klass_info, next_klass_infos)
            # Anything requested that no pass consumed is an invalid name.
            fields_not_found = set(requested).difference(fields_found)
            if fields_not_found:
                invalid_fields = (f"'{s}'" for s in fields_not_found)
                raise FieldError(
                    "Invalid field name(s) given in select_related: {}. "
                    "Choices are: {}".format(
                        ", ".join(invalid_fields),
                        ", ".join(_get_field_choices()) or "(none)",
                    )
                )
        return related_klass_infos
1249
    def get_select_for_update_of_arguments(self) -> list[str]:
        """
        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
        the query.
        """

        def _get_first_selected_col_from_model(klass_info: dict) -> Any | None:
            """
            Find the first selected column from a model. If it doesn't exist,
            don't lock a model.

            select_fields is filled recursively, so it also contains fields
            from the parent models.
            """
            assert self.select is not None
            model = klass_info["model"]
            for select_index in klass_info["select_fields"]:
                if self.select[select_index][0].target.model == model:
                    return self.select[select_index][0]
            return None

        def _get_field_choices() -> Generator[str]:
            """Yield all allowed field paths in breadth-first search order."""
            queue = collections.deque([(None, self.klass_info)])
            while queue:
                parent_path, klass_info = queue.popleft()
                if parent_path is None:
                    # Root of the traversal: the queried model itself.
                    path = []
                    yield "self"
                else:
                    assert klass_info is not None  # Only first iteration has None
                    field = klass_info["field"]
                    if klass_info["reverse"]:
                        # Reverse relations are named from the remote side.
                        field = field.remote_field
                    path = parent_path + [field.name]
                    yield LOOKUP_SEP.join(path)
                if klass_info is not None:
                    queue.extend(
                        (path, related_klass_info)  # type: ignore[invalid-argument-type]
                        for related_klass_info in klass_info.get(
                            "related_klass_infos", []
                        )
                    )

        if not self.klass_info:
            return []
        result = []
        invalid_names = []
        for name in self.query.select_for_update_of:
            klass_info = self.klass_info
            if name == "self":
                col = _get_first_selected_col_from_model(klass_info)
            else:
                # Walk the lookup path through the select_related tree.
                for part in name.split(LOOKUP_SEP):
                    if klass_info is None:
                        break
                    klass_infos = (*klass_info.get("related_klass_infos", []),)
                    for related_klass_info in klass_infos:
                        field = related_klass_info["field"]
                        if related_klass_info["reverse"]:
                            field = field.remote_field
                        if field.name == part:
                            klass_info = related_klass_info
                            break
                    else:
                        # No child matched this path component: invalid name.
                        klass_info = None
                        break
                if klass_info is None:
                    invalid_names.append(name)
                    continue
                col = _get_first_selected_col_from_model(klass_info)
            if col is not None:
                # Lock by table alias, as required by FOR UPDATE OF.
                result.append(self.quote_name_unless_alias(col.alias))
        if invalid_names:
            raise FieldError(
                "Invalid field name(s) given in select_for_update(of=(...)): {}. "
                "Only relational fields followed in the query are allowed. "
                "Choices are: {}.".format(
                    ", ".join(invalid_names),
                    ", ".join(_get_field_choices()),
                )
            )
        return result
1333
1334 def results_iter(
1335 self,
1336 results: Any = None,
1337 tuple_expected: bool = False,
1338 chunked_fetch: bool = False,
1339 ) -> Iterable[Any]:
1340 """Return an iterator over the results from executing this query."""
1341 if results is None:
1342 results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch)
1343 assert self.select is not None # Set during query execution
1344 fields = [s[0] for s in self.select[0 : self.col_count]]
1345 converters = get_converters(fields, self.connection)
1346 rows = results
1347 if converters:
1348 rows = apply_converters(rows, converters, self.connection)
1349 if tuple_expected:
1350 rows = map(tuple, rows)
1351 return rows
1352
1353 def has_results(self) -> bool:
1354 """Check if the query returns any results."""
1355 return bool(self.execute_sql(SINGLE))
1356
    def execute_sql(
        self,
        result_type: str = MULTI,
        chunked_fetch: bool = False,
    ) -> Any:
        """
        Run the query against the database and return the result(s). The
        return value is a single data item if result_type is SINGLE, or a
        flat iterable of rows if the result_type is MULTI.

        result_type is either MULTI (returns a list from fetchall(), or a
        streaming generator from cursor.stream() when chunked_fetch=True),
        SINGLE (only retrieve a single row), or None. In this last case, the
        cursor is returned if any query is executed, since it's used by
        subclasses such as InsertQuery). It's possible, however, that no query
        is needed, as the filters describe an empty set. In that case, None is
        returned, to avoid any unnecessary database interaction.
        """
        # A falsy result_type (None/"") means "execute but discard results".
        result_type = result_type or NO_RESULTS
        try:
            as_sql_result = self.as_sql()
            # SQLCompiler.as_sql returns SqlWithParams, subclasses may differ
            assert isinstance(as_sql_result, tuple)
            assert isinstance(as_sql_result[0], str)
            sql, params = as_sql_result
            if not sql:
                raise EmptyResultSet
        except EmptyResultSet:
            # Empty filters: skip the round-trip entirely.
            if result_type == MULTI:
                return iter([])
            else:
                return
        cursor = self.connection.cursor()
        if chunked_fetch:
            # Use psycopg3's cursor.stream() for server-side cursor iteration.
            result = cursor.stream(sql, params)
            if self.has_extra_select:
                # Trim extra-select columns off each streamed row.
                col_count = self.col_count
                result = (r[:col_count] for r in result)
            return result

        try:
            cursor.execute(sql, params)
        except Exception:
            # Don't leak the cursor if execution fails.
            cursor.close()
            raise

        if result_type == CURSOR:
            # Give the caller the cursor to process and close.
            return cursor
        if result_type == SINGLE:
            try:
                val = cursor.fetchone()
                if val:
                    # Trim to the selected column count.
                    return val[0 : self.col_count]
                return val
            finally:
                # done with the cursor
                cursor.close()
        if result_type == NO_RESULTS:
            cursor.close()
            return

        # MULTI: fetch everything eagerly, then release the cursor.
        try:
            rows = cursor.fetchall()
        finally:
            cursor.close()
        if self.has_extra_select:
            rows = [r[: self.col_count] for r in rows]
        return rows
1427
1428 def as_subquery_condition(
1429 self, alias: str, columns: list[str], compiler: SQLCompiler
1430 ) -> SqlWithParams:
1431 qn = compiler.quote_name_unless_alias
1432 qn2 = quote_name
1433
1434 for index, select_col in enumerate(self.query.select):
1435 lhs_sql, lhs_params = self.compile(select_col)
1436 rhs = f"{qn(alias)}.{qn2(columns[index])}"
1437 self.query.where.add(RawSQL(f"{lhs_sql} = {rhs}", lhs_params), AND)
1438
1439 sql, params = self.as_sql()
1440 return f"EXISTS ({sql})", params
1441
1442 def explain_query(self) -> Generator[str]:
1443 result = self.execute_sql()
1444 explain_info = self.query.explain_info
1445 # PostgreSQL may return tuples with integers and strings depending on
1446 # the EXPLAIN format. Flatten them out into strings.
1447 format_ = explain_info.format if explain_info is not None else None
1448 output_formatter = json.dumps if format_ and format_.lower() == "json" else str
1449 for row in result:
1450 if not isinstance(row, str):
1451 yield " ".join(output_formatter(c) for c in row)
1452 else:
1453 yield row
1454
1455
class SQLInsertCompiler(SQLCompiler):
    """Compile an InsertQuery into one INSERT statement with placeholders."""

    query: InsertQuery
    # Fields to emit in a RETURNING clause; set by execute_sql().
    returning_fields: list | None = None
    # Extra parameters produced by return_insert_columns(), if any.
    returning_params: tuple = ()

    def field_as_sql(self, field: Any, val: Any) -> tuple[str, list]:
        """
        Take a field and a value intended to be saved on that field, and
        return placeholder SQL and accompanying params. Check for raw values,
        expressions, and fields with get_placeholder() defined in that order.

        When field is None, consider the value raw and use it as the
        placeholder, with no corresponding parameters returned.
        """
        if field is None:
            # A field value of None means the value is raw.
            sql, params = val, []
        elif hasattr(val, "as_sql"):
            # This is an expression, let's compile it.
            sql, params_tuple = self.compile(val)
            params = list(params_tuple)
        elif hasattr(field, "get_placeholder"):
            # Some fields (e.g. geo fields) need special munging before
            # they can be inserted.
            sql, params = field.get_placeholder(val, self, self.connection), [val]
        else:
            # Return the common case for the placeholder
            sql, params = "%s", [val]

        return sql, list(params)  # Ensure params is a list

    def prepare_value(self, field: Any, value: Any) -> Any:
        """
        Prepare a value to be used in a query by resolving it if it is an
        expression and otherwise calling the field's get_db_prep_save().

        Raises ValueError for column references and FieldError for
        aggregates/window expressions, none of which make sense in an INSERT.
        """
        if isinstance(value, ResolvableExpression):
            value = value.resolve_expression(
                self.query, allow_joins=False, for_save=True
            )
            # Don't allow values containing Col expressions. They refer to
            # existing columns on a row, but in the case of insert the row
            # doesn't exist yet.
            if value.contains_column_references:
                raise ValueError(
                    f'Failed to insert expression "{value}" on {field}. F() expressions '
                    "can only be used to update, not to insert."
                )
            if value.contains_aggregate:
                raise FieldError(
                    "Aggregate functions are not allowed in this query "
                    f"({field.name}={value!r})."
                )
            if value.contains_over_clause:
                raise FieldError(
                    f"Window expressions are not allowed in this query ({field.name}={value!r})."
                )
        return field.get_db_prep_save(value, connection=self.connection)

    def pre_save_val(self, field: Any, obj: Any) -> Any:
        """
        Get the given field's value off the given obj. pre_save() is used for
        things like auto_now on DateTimeField. Skip it if this is a raw query.
        """
        if self.query.raw:
            return getattr(obj, field.attname)
        return field.pre_save(obj, add=True)

    def assemble_as_sql(
        self, fields: list[Any], value_rows: list[list[Any]]
    ) -> tuple[list[list[str]], list[list]]:
        """
        Take a sequence of N fields and a sequence of M rows of values, and
        generate placeholder SQL and parameters for each field and value.
        Return a pair containing:
         * a sequence of M rows of N SQL placeholder strings, and
         * a sequence of M rows of corresponding parameter values.

        Each placeholder string may contain any number of '%s' interpolation
        strings, and each parameter row will contain exactly as many params
        as the total number of '%s's in the corresponding placeholder row.
        """
        if not value_rows:
            return [], []

        # list of (sql, [params]) tuples for each object to be saved
        # Shape: [n_objs][n_fields][2]
        rows_of_fields_as_sql = (
            (self.field_as_sql(field, v) for field, v in zip(fields, row))
            for row in value_rows
        )

        # tuple like ([sqls], [[params]s]) for each object to be saved
        # Shape: [n_objs][2][n_fields]
        sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)

        # Extract separate lists for placeholders and params.
        # Each of these has shape [n_objs][n_fields]
        placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)

        # Params for each field are still lists, and need to be flattened.
        param_rows = [[p for ps in row for p in ps] for row in param_rows]

        return placeholder_rows, param_rows

    def as_sql(  # type: ignore[override] # Returns list for internal iteration in execute_sql
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> list[SqlWithParams]:
        """Build the INSERT statement; returns a one-element list of
        (sql, params) pairs consumed by execute_sql()."""
        # We don't need quote_name_unless_alias() here, since these are all
        # going to be column names (so we can avoid the extra overhead).
        qn = quote_name
        assert self.query.model is not None, "INSERT requires a model"
        meta = self.query.model._model_meta
        options = self.query.model.model_options
        result = [f"INSERT INTO {qn(options.db_table)}"]
        if self.query.fields:
            fields = self.query.fields
        else:
            # No explicit fields: insert only the primary key column.
            fields = [meta.get_forward_field("id")]
        result.append("({})".format(", ".join(qn(f.column) for f in fields)))

        if self.query.fields:
            value_rows = [
                [
                    self.prepare_value(field, self.pre_save_val(field, obj))
                    for field in fields
                ]
                for obj in self.query.objs
            ]
        else:
            # An empty object.
            value_rows = [[PK_DEFAULT_VALUE] for _ in self.query.objs]
            # field=None marks the value as raw SQL in field_as_sql().
            fields = [None]

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        conflict_suffix_sql = on_conflict_suffix_sql(
            fields,  # type: ignore[arg-type]
            self.query.on_conflict,
            (f.column for f in self.query.update_fields),
            (f.column for f in self.query.unique_fields),
        )
        if self.returning_fields:
            # Use RETURNING clause to get inserted values
            result.append(
                bulk_insert_sql(fields, placeholder_rows)  # type: ignore[arg-type]
            )
            params = param_rows
            if conflict_suffix_sql:
                result.append(conflict_suffix_sql)
            # Skip empty r_sql in case returning_cols returns an empty string.
            returning_cols = return_insert_columns(self.returning_fields)
            if returning_cols:
                r_sql, self.returning_params = returning_cols
                if r_sql:
                    result.append(r_sql)
                    params += [list(self.returning_params)]
            return [(" ".join(result), tuple(chain.from_iterable(params)))]

        # Bulk insert without returning fields
        result.append(bulk_insert_sql(fields, placeholder_rows))  # type: ignore[arg-type]
        if conflict_suffix_sql:
            result.append(conflict_suffix_sql)
        return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]

    def execute_sql(  # type: ignore[override]
        self, returning_fields: list | None = None
    ) -> list:
        """Execute the INSERT; when returning_fields is given, return the
        converted RETURNING rows, otherwise an empty list."""
        assert self.query.model is not None, "INSERT execution requires a model"
        options = self.query.model.model_options
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                cursor.execute(sql, params)
            if not self.returning_fields:
                return []
            # Use RETURNING clause for both single and bulk inserts
            if len(self.query.objs) > 1:
                rows = cursor.fetchall()
            else:
                rows = [cursor.fetchone()]
        cols = [field.get_col(options.db_table) for field in self.returning_fields]
        converters = get_converters(cols, self.connection)
        if converters:
            rows = list(apply_converters(rows, converters, self.connection))
        return rows
1642
1643
class SQLDeleteCompiler(SQLCompiler):
    """Compile a delete query into a DELETE statement."""

    @cached_property
    def single_alias(self) -> bool:
        """True when exactly one table alias is actively referenced."""
        # Ensure base table is in aliases.
        self.query.get_initial_alias()
        active_aliases = [
            table
            for table in self.query.alias_map
            if self.query.alias_refcount[table] > 0
        ]
        return len(active_aliases) == 1

    @classmethod
    def _expr_refs_base_model(cls, expr: Any, base_model: Any) -> bool:
        """Recursively check whether expr contains a subquery on base_model."""
        if isinstance(expr, Query):
            return expr.model == base_model
        if not hasattr(expr, "get_source_expressions"):
            return False
        for child in expr.get_source_expressions():
            if cls._expr_refs_base_model(child, base_model):
                return True
        return False

    @cached_property
    def contains_self_reference_subquery(self) -> bool:
        """True when an annotation or WHERE child subqueries the base model."""
        candidates = chain(
            self.query.annotations.values(), self.query.where.children
        )
        return any(
            self._expr_refs_base_model(expr, self.query.model)
            for expr in candidates
        )

    def _as_sql(self, query: Query) -> SqlWithParams:
        """Render a plain single-table DELETE for the given query."""
        delete_sql = f"DELETE FROM {self.quote_name_unless_alias(query.base_table)}"  # type: ignore[invalid-argument-type]
        try:
            where_sql, where_params = self.compile(query.where)
        except FullResultSet:
            # Every row matches: no WHERE clause needed.
            return delete_sql, ()
        return f"{delete_sql} WHERE {where_sql}", tuple(where_params)

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        # Fast path: single table, no self-referencing subqueries.
        if self.single_alias and not self.contains_self_reference_subquery:
            return self._as_sql(self.query)
        # Otherwise, delete via "id IN (<subquery selecting the ids>)".
        innerq = self.query.clone()
        innerq.__class__ = Query
        innerq.clear_select_clause()
        assert self.query.model is not None, "DELETE requires a model"
        id_field = self.query.model._model_meta.get_forward_field("id")
        innerq.select = [id_field.get_col(self.query.get_initial_alias())]
        outerq = Query(self.query.model)
        outerq.add_filter("id__in", innerq)
        return self._as_sql(outerq)
1697
1698
class SQLUpdateCompiler(SQLCompiler):
    """Compile an update query into an UPDATE statement."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        Returns ("", ()) when there is nothing to update.
        """
        self.pre_sql_setup()
        query_values = getattr(self.query, "values", None)
        if not query_values:
            return "", ()
        qn = self.quote_name_unless_alias
        values, update_params = [], []
        for field, model, val in query_values:
            if isinstance(val, ResolvableExpression):
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                # Aggregates and window functions are invalid in SET clauses.
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                # val is a model instance; only valid for relation fields.
                if isinstance(field, RelatedField):
                    val = val.prepare_database_save(field)
                else:
                    raise TypeError(
                        f"Tried to update field {field} with a model instance, {val!r}. "
                        f"Use a value compatible with {field.__class__.__name__}."
                    )
            val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = "%s"
            name = field.column
            if hasattr(val, "as_sql"):
                # Expression value: inline its SQL into the placeholder.
                sql, params = self.compile(val)
                values.append(f"{qn(name)} = {placeholder % sql}")
                update_params.extend(params)
            elif val is not None:
                values.append(f"{qn(name)} = {placeholder}")
                update_params.append(val)
            else:
                # NULL is written literally rather than parameterized.
                values.append(f"{qn(name)} = NULL")
        table = self.query.base_table
        result = [
            f"UPDATE {qn(table)} SET",  # type: ignore[invalid-argument-type]
            ", ".join(values),
        ]
        try:
            where, params = self.compile(self.query.where)
        except FullResultSet:
            # Every row matches: omit the WHERE clause entirely.
            params = []
        else:
            result.append(f"WHERE {where}")
        return " ".join(result), tuple(update_params + list(params))

    def execute_sql(self, result_type: str) -> int:  # type: ignore[override]
        """
        Execute the specified update. Return the number of rows affected by
        the primary update query. The "primary update query" is the first
        non-empty query that is executed. Row counts for any subsequent,
        related queries are not available.
        """
        cursor = super().execute_sql(result_type)
        try:
            rows = cursor.rowcount if cursor else 0
            is_empty = cursor is None
        finally:
            if cursor:
                cursor.close()
        related_update_queries = getattr(self.query, "get_related_updates", None)
        if related_update_queries is None:
            return rows
        # Run follow-up updates; use the first non-empty row count.
        for query in related_update_queries():
            aux_rows = query.get_compiler().execute_sql(result_type)
            if is_empty and aux_rows:
                rows = aux_rows
                is_empty = False
        return rows

    def pre_sql_setup(
        self, with_col_aliases: bool = False
    ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
        """
        If the update depends on results from other tables, munge the "where"
        conditions to match the format required for (portable) SQL updates.

        If multiple updates are required, pull out the id values to update at
        this point so that they don't change as a result of the progressive
        updates.
        """
        refcounts_before = self.query.alias_refcount.copy()
        # Ensure base table is in the query
        self.query.get_initial_alias()
        count = self.query.count_active_tables()
        related_updates = getattr(self.query, "related_updates", [])
        if not related_updates and count == 1:
            # Single-table update with no related updates: nothing to do.
            return
        # Build a SELECT clone to gather the ids of affected rows.
        query = self.query.chain(klass=Query)
        query.select_related = False
        query.clear_ordering(force=True)
        query.extra = {}
        query.select = []
        fields = ["id"]
        related_ids_index = []
        for related in related_updates:
            # If a primary key chain exists to the targeted related update,
            # then the primary key value can be used for it.
            related_ids_index.append((related, 0))

        query.add_fields(fields)
        super().pre_sql_setup()

        # Now we adjust the current query: reset the where clause and get rid
        # of all the tables we don't need (since they're in the sub-select).
        self.query.clear_where()
        if related_updates:
            # We're using the idents in multiple update queries (so
            # don't want them to change).
            idents = []
            related_ids = collections.defaultdict(list)
            for row in query.get_compiler().execute_sql(MULTI):
                idents.append(row[0])
                for parent, index in related_ids_index:
                    related_ids[parent].append(row[index])
            self.query.add_filter("id__in", idents)
            self.query.related_ids = related_ids  # type: ignore[assignment]
        else:
            # The fast path. Filters and updates in one query using a subquery.
            self.query.add_filter("id__in", query)
        self.query.reset_refcounts(refcounts_before)
1841
1842
1843class SQLAggregateCompiler(SQLCompiler):
1844 def as_sql(
1845 self, with_limits: bool = True, with_col_aliases: bool = False
1846 ) -> SqlWithParams:
1847 """
1848 Create the SQL for this query. Return the SQL string and list of
1849 parameters.
1850 """
1851 sql, params = [], []
1852 for annotation in self.query.annotation_select.values():
1853 ann_sql, ann_params = self.compile(annotation)
1854 ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
1855 sql.append(ann_sql)
1856 params.extend(ann_params)
1857 self.col_count = len(self.query.annotation_select)
1858 sql = ", ".join(sql)
1859 params = tuple(params)
1860
1861 inner_query = getattr(self.query, "inner_query", None)
1862 if inner_query is None:
1863 raise ValueError("Inner query expected for update with related ids")
1864 inner_query_sql, inner_query_params = inner_query.get_compiler(
1865 elide_empty=self.elide_empty,
1866 ).as_sql(with_col_aliases=True)
1867 sql = f"SELECT {sql} FROM ({inner_query_sql}) subquery"
1868 params += inner_query_params
1869 return sql, params