1from __future__ import annotations
2
3import collections
4import json
5import re
6from collections.abc import Generator, Iterable, Sequence
7from functools import cached_property, partial
8from itertools import chain
9from typing import TYPE_CHECKING, Any, Protocol, cast
10
11from plain.postgres.constants import LOOKUP_SEP
12from plain.postgres.dialect import (
13 PK_DEFAULT_VALUE,
14 bulk_insert_sql,
15 distinct_sql,
16 explain_query_prefix,
17 for_update_sql,
18 limit_offset_sql,
19 on_conflict_suffix_sql,
20 quote_name,
21 return_insert_columns,
22)
23from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
24from plain.postgres.expressions import (
25 F,
26 OrderBy,
27 RawSQL,
28 Ref,
29 ResolvableExpression,
30 Value,
31)
32from plain.postgres.fields import DATABASE_DEFAULT
33from plain.postgres.fields.related import RelatedField
34from plain.postgres.functions import Cast, Random
35from plain.postgres.lookups import Lookup
36from plain.postgres.meta import Meta
37from plain.postgres.query_utils import select_related_descend
38from plain.postgres.sql.constants import (
39 CURSOR,
40 MULTI,
41 NO_RESULTS,
42 ORDER_DIR,
43 SINGLE,
44)
45from plain.postgres.sql.query import Query, get_order_dir
46from plain.postgres.transaction import TransactionManagementError
47from plain.utils.hashable import make_hashable
48from plain.utils.regex_helper import _lazy_re_compile
49
50if TYPE_CHECKING:
51 from plain.postgres.connection import DatabaseConnection
52 from plain.postgres.expressions import BaseExpression
53 from plain.postgres.sql.query import AggregateQuery, InsertQuery
54
# Type aliases for SQL compilation results.
# SqlParams: the positional parameters of a compiled fragment, as a tuple.
SqlParams = tuple[Any, ...]
# SqlWithParams: a compiled fragment as an (sql_text, params) pair; this is
# the shape returned by SQLCompiler.compile().
SqlWithParams = tuple[str, SqlParams]
58
59
class SQLCompilable(Protocol):
    """Structural type for any node that a compiler can render as SQL."""

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Render this node and return its SQL text plus its parameters."""
        ...
68
69
class PositionRef(Ref):
    """A ``Ref`` that compiles to an ordinal position instead of a name."""

    def __init__(self, ordinal: int, refs: str, source: Any):
        # The ordinal is recorded before delegating to Ref's initializer.
        self.ordinal = ordinal
        super().__init__(refs, source)

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """Return the ordinal rendered as text, with no parameters."""
        position_sql = str(self.ordinal)
        return position_sql, []
79
80
def get_converters(
    expressions: Iterable[Any], connection: DatabaseConnection
) -> dict[int, tuple[list[Any], Any]]:
    """
    Map each position in ``expressions`` to ``(converters, expression)``.

    Falsy entries are skipped, and so are expressions whose
    ``get_db_converters(connection)`` returns nothing truthy.
    """
    by_position = {}
    for position, expr in enumerate(expressions):
        if not expr:
            continue
        if db_converters := expr.get_db_converters(connection):
            by_position[position] = (db_converters, expr)
    return by_position
91
92
def apply_converters(
    rows: Iterable, converters: dict, connection: DatabaseConnection
) -> Generator[list]:
    """
    Yield each row as a list with the registered converters applied.

    ``converters`` maps a column position to ``(converter_funcs, expression)``;
    the functions are applied in order, each receiving
    ``(value, expression, connection)``.
    """
    plan = list(converters.items())
    for raw_row in rows:
        row = list(raw_row)
        for position, (funcs, expr) in plan:
            converted = row[position]
            for func in funcs:
                converted = func(converted, expr, connection)
            row[position] = converted
        yield row
104
105
106class SQLCompiler:
# Multiline ordering SQL clause may appear from RawSQL.
# Group 1 captures the text that precedes the ASC/DESC keyword; the methods
# get_order_by() and get_extra_select() read it to obtain the ordering SQL
# without its direction.
ordering_parts = _lazy_re_compile(
    r"^(.*)\s(?:ASC|DESC).*",
    re.MULTILINE | re.DOTALL,
)
112
113 def __init__(
114 self, query: Query, connection: DatabaseConnection, elide_empty: bool = True
115 ):
116 self.query = query
117 self.connection = connection
118 # Some queries, e.g. coalesced aggregation, need to be executed even if
119 # they would return an empty result set.
120 self.elide_empty = elide_empty
121 self.quote_cache: dict[str, str] = {"*": "*"}
122 # The select, klass_info, and annotations are needed by QuerySet.iterator()
123 # these are set as a side-effect of executing the query. Note that we calculate
124 # separately a list of extra select columns needed for grammatical correctness
125 # of the query, but these columns are not included in self.select.
126 self.select: list[tuple[Any, SqlWithParams, str | None]] | None = None
127 self.annotation_col_map: dict[str, int] | None = None
128 self.klass_info: dict[str, Any] | None = None
129 self._meta_ordering: list[str] | None = None
130
131 def __repr__(self) -> str:
132 model_name = self.query.model.__qualname__ if self.query.model else "None"
133 return (
134 f"<{self.__class__.__qualname__} "
135 f"model={model_name} "
136 f"connection={self.connection!r}>"
137 )
138
139 def setup_query(self, with_col_aliases: bool = False) -> None:
140 if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map):
141 self.query.get_initial_alias()
142 self.select, self.klass_info, self.annotation_col_map = self.get_select(
143 with_col_aliases=with_col_aliases,
144 )
145 self.col_count = len(self.select)
146
147 def pre_sql_setup(
148 self, with_col_aliases: bool = False
149 ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
150 """
151 Do any necessary class setup immediately prior to producing SQL. This
152 is for things that can't necessarily be done in __init__ because we
153 might not have all the pieces in place at that time.
154 """
155 self.setup_query(with_col_aliases=with_col_aliases)
156 assert self.select is not None # Set by setup_query()
157 order_by = self.get_order_by()
158 self.where, self.having, self.qualify = self.query.where.split_having_qualify(
159 must_group_by=self.query.group_by is not None
160 )
161 extra_select = self.get_extra_select(order_by, self.select)
162 self.has_extra_select = bool(extra_select)
163 group_by = self.get_group_by(self.select + extra_select, order_by)
164 return extra_select, order_by, group_by
165
166 def get_group_by(
167 self, select: list[Any], order_by: list[Any]
168 ) -> list[SqlWithParams]:
169 """
170 Return a list of 2-tuples of form (sql, params).
171
172 The logic of what exactly the GROUP BY clause contains is hard
173 to describe in other words than "if it passes the test suite,
174 then it is correct".
175 """
176 # Some examples:
177 # SomeModel.query.annotate(Count('somecol'))
178 # GROUP BY: all fields of the model
179 #
180 # SomeModel.query.values('name').annotate(Count('somecol'))
181 # GROUP BY: name
182 #
183 # SomeModel.query.annotate(Count('somecol')).values('name')
184 # GROUP BY: all cols of the model
185 #
186 # SomeModel.query.values('name', 'id')
187 # .annotate(Count('somecol')).values('id')
188 # GROUP BY: name, id
189 #
190 # SomeModel.query.values('name').annotate(Count('somecol')).values('id')
191 # GROUP BY: name, id
192 #
193 # In fact, the self.query.group_by is the minimal set to GROUP BY. It
194 # can't be ever restricted to a smaller set, but additional columns in
195 # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
196 # the end result is that it is impossible to force the query to have
197 # a chosen GROUP BY clause - you can almost do this by using the form:
198 # .values(*wanted_cols).annotate(AnAggregate())
199 # but any later annotations, extra selects, values calls that
200 # refer some column outside of the wanted_cols, order_by, or even
201 # filter calls can alter the GROUP BY clause.
202
203 # The query.group_by is either None (no GROUP BY at all), True
204 # (group by select fields), or a list of expressions to be added
205 # to the group by.
206 if self.query.group_by is None:
207 return []
208 expressions = []
209 group_by_refs = set()
210 if self.query.group_by is not True:
211 # If the group by is set to a list (by .values() call most likely),
212 # then we need to add everything in it to the GROUP BY clause.
213 # Backwards compatibility hack for setting query.group_by. Remove
214 # when we have public API way of forcing the GROUP BY clause.
215 # Converts string references to expressions.
216 for expr in self.query.group_by:
217 if not hasattr(expr, "as_sql"):
218 expr = self.query.resolve_ref(expr)
219 if isinstance(expr, Ref):
220 if expr.refs not in group_by_refs:
221 group_by_refs.add(expr.refs)
222 expressions.append(expr.source)
223 else:
224 expressions.append(expr)
225 # Note that even if the group_by is set, it is only the minimal
226 # set to group by. So, we need to add cols in select, order_by, and
227 # having into the select in any case.
228 selected_expr_positions = {}
229 for ordinal, (expr, _, alias) in enumerate(select, start=1):
230 if alias:
231 selected_expr_positions[expr] = ordinal
232 # Skip members of the select clause that are already explicitly
233 # grouped against.
234 if alias in group_by_refs:
235 continue
236 expressions.extend(expr.get_group_by_cols())
237 if not self._meta_ordering:
238 for expr, (sql, params, is_ref) in order_by:
239 # Skip references to the SELECT clause, as all expressions in
240 # the SELECT clause are already part of the GROUP BY.
241 if not is_ref:
242 expressions.extend(expr.get_group_by_cols())
243 having_group_by = self.having.get_group_by_cols() if self.having else []
244 for expr in having_group_by:
245 expressions.append(expr)
246 result = []
247 seen = set()
248 expressions = self.collapse_group_by(expressions, having_group_by)
249
250 for expr in expressions:
251 try:
252 sql, params = self.compile(expr)
253 except (EmptyResultSet, FullResultSet):
254 continue
255 # Use select index for GROUP BY when possible
256 if (position := selected_expr_positions.get(expr)) is not None:
257 sql, params = str(position), ()
258 else:
259 sql, params = expr.select_format(self, sql, params)
260 params_hash = make_hashable(params)
261 if (sql, params_hash) not in seen:
262 result.append((sql, params))
263 seen.add((sql, params_hash))
264 return result
265
266 def collapse_group_by(self, expressions: list[Any], having: list[Any]) -> list[Any]:
267 # Use group by functional dependence reduction:
268 # expressions can be reduced to the set of selected table
269 # primary keys as all other columns are functionally dependent on them.
270 # Filter out all expressions associated with a table's primary key
271 # present in the grouped columns. This is done by identifying all
272 # tables that have their primary key included in the grouped
273 # columns and removing non-primary key columns referring to them.
274 pks = {
275 expr
276 for expr in expressions
277 if hasattr(expr, "target") and expr.target.primary_key
278 }
279 aliases = {expr.alias for expr in pks}
280 return [
281 expr
282 for expr in expressions
283 if expr in pks
284 or expr in having
285 or getattr(expr, "alias", None) not in aliases
286 ]
287
288 def get_select(
289 self, with_col_aliases: bool = False
290 ) -> tuple[
291 list[tuple[Any, SqlWithParams, str | None]],
292 dict[str, Any] | None,
293 dict[str, int],
294 ]:
295 """
296 Return three values:
297 - a list of 3-tuples of (expression, (sql, params), alias)
298 - a klass_info structure,
299 - a dictionary of annotations
300
301 The (sql, params) is what the expression will produce, and alias is the
302 "AS alias" for the column (possibly None).
303
304 The klass_info structure contains the following information:
305 - The base model of the query.
306 - Which columns for that model are present in the query (by
307 position of the select clause).
308 - related_klass_infos: [f, klass_info] to descent into
309
310 The annotations is a dictionary of {'attname': column position} values.
311 """
312 select = []
313 klass_info = None
314 annotations = {}
315 select_idx = 0
316 for alias, (sql, params) in self.query.extra_select.items():
317 annotations[alias] = select_idx
318 select.append((RawSQL(sql, params), alias))
319 select_idx += 1
320 assert not (self.query.select and self.query.default_cols)
321 select_mask = self.query.get_select_mask()
322 if self.query.default_cols:
323 cols = self.get_default_columns(select_mask)
324 else:
325 # self.query.select is a special case. These columns never go to
326 # any model.
327 cols = self.query.select
328 if cols:
329 select_list = []
330 for col in cols:
331 select_list.append(select_idx)
332 select.append((col, None))
333 select_idx += 1
334 klass_info = {
335 "model": self.query.model,
336 "select_fields": select_list,
337 }
338 for alias, annotation in self.query.annotation_select.items():
339 annotations[alias] = select_idx
340 select.append((annotation, alias))
341 select_idx += 1
342
343 if self.query.select_related:
344 related_klass_infos = self.get_related_selections(select, select_mask)
345 if klass_info is not None:
346 klass_info["related_klass_infos"] = related_klass_infos
347
348 ret = []
349 col_idx = 1
350 for col, alias in select:
351 try:
352 sql, params = self.compile(col)
353 except EmptyResultSet:
354 empty_result_set_value = getattr(
355 col, "empty_result_set_value", NotImplemented
356 )
357 if empty_result_set_value is NotImplemented:
358 # Select a predicate that's always False.
359 sql, params = "0", ()
360 else:
361 sql, params = self.compile(Value(empty_result_set_value))
362 except FullResultSet:
363 sql, params = self.compile(Value(True))
364 else:
365 sql, params = col.select_format(self, sql, params)
366 if alias is None and with_col_aliases:
367 alias = f"col{col_idx}"
368 col_idx += 1
369 ret.append((col, (sql, params), alias))
370 return ret, klass_info, annotations
371
def _order_by_pairs(self) -> Generator[tuple[OrderBy, bool]]:
    """
    Yield an (OrderBy, is_ref) pair for every ordering term of the query.

    The ordering comes from, in order of precedence: ``extra_order_by``,
    the query's ``order_by``, or the model's default ordering (which is
    also recorded in ``self._meta_ordering``). ``is_ref`` is true when the
    term refers to something already present in the SELECT clause.
    """
    if self.query.extra_order_by:
        ordering = self.query.extra_order_by
    elif not self.query.default_ordering:
        ordering = self.query.order_by
    elif self.query.order_by:
        ordering = self.query.order_by
    elif (
        self.query.model
        and (options := self.query.model.model_options)
        and options.ordering
    ):
        ordering = options.ordering
        self._meta_ordering = list(ordering)
    else:
        ordering = []
    if self.query.standard_ordering:
        default_order, _ = ORDER_DIR["ASC"]
    else:
        default_order, _ = ORDER_DIR["DESC"]

    # Index the SELECT clause by alias and by expression so that ordering
    # terms can refer to a selected column by its 1-based position.
    selected_exprs = {}
    if select := self.select:
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            pos_expr = PositionRef(ordinal, alias, expr)  # ty: ignore[invalid-argument-type]
            if alias:
                selected_exprs[alias] = pos_expr
            selected_exprs[expr] = pos_expr

    for field in ordering:
        if isinstance(field, ResolvableExpression):
            # field is a BaseExpression (has asc/desc/copy methods).
            # BaseExpression is imported only under TYPE_CHECKING, so it is
            # named as a string: cast() evaluates its first argument at
            # runtime and the bare name would raise NameError.
            field_expr = cast("BaseExpression", field)
            if isinstance(field_expr, Value):
                # output_field must be resolved for constants.
                field_expr = Cast(field_expr, field_expr.output_field)
            if not isinstance(field_expr, OrderBy):
                field_expr = field_expr.asc()
            if not self.query.standard_ordering:
                field_expr = field_expr.copy()
                field_expr.reverse_ordering()
            field = field_expr
            select_ref = selected_exprs.get(field.expression)
            if select_ref or (
                isinstance(field.expression, F)
                and (select_ref := selected_exprs.get(field.expression.name))
            ):
                # Copy before mutating so the caller's expression is untouched.
                field = field.copy()
                field.expression = select_ref
            yield field, select_ref is not None
            continue
        if field == "?":  # random
            yield OrderBy(Random()), False
            continue

        col, order = get_order_dir(field, default_order)
        descending = order == "DESC"

        if select_ref := selected_exprs.get(col):
            # Reference to expression in SELECT clause
            yield (
                OrderBy(
                    select_ref,
                    descending=descending,
                ),
                True,
            )
            continue
        if col in self.query.annotations:
            # References to an expression which is masked out of the SELECT
            # clause.
            expr = self.query.annotations[col]
            if isinstance(expr, Value):
                # output_field must be resolved for constants.
                expr = Cast(expr, expr.output_field)
            yield OrderBy(expr, descending=descending), False
            continue

        if "." in field:
            # This came in through an extra(order_by=...) addition. Pass it
            # on verbatim.
            table, col = col.split(".", 1)
            yield (
                OrderBy(
                    RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
                    descending=descending,
                ),
                False,
            )
            continue

        if self.query.extra and col in self.query.extra:
            if col in self.query.extra_select:
                yield (
                    OrderBy(
                        Ref(col, RawSQL(*self.query.extra[col])),
                        descending=descending,
                    ),
                    True,
                )
            else:
                yield (
                    OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                    False,
                )
        else:
            # 'col' is of the form 'field' or 'field1__field2' or
            # '-field1__field2__field', etc.
            assert self.query.model is not None, (
                "Ordering by fields requires a model"
            )
            meta = self.query.model._model_meta
            yield from self.find_ordering_name(
                field,
                meta,
                default_order=default_order,
            )
489
490 def get_order_by(self) -> list[tuple[Any, tuple[str, tuple, bool]]]:
491 """
492 Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
493 the ORDER BY clause.
494
495 The order_by clause can alter the select clause (for example it can add
496 aliases to clauses that do not yet have one, or it can add totally new
497 select clauses).
498 """
499 result = []
500 seen = set()
501 for expr, is_ref in self._order_by_pairs():
502 resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
503 sql, params = self.compile(resolved)
504 # Don't add the same column twice, but the order direction is
505 # not taken into account so we strip it. When this entire method
506 # is refactored into expressions, then we can check each part as we
507 # generate it.
508 without_ordering = self.ordering_parts.search(sql)[1]
509 params_hash = make_hashable(params)
510 if (without_ordering, params_hash) in seen:
511 continue
512 seen.add((without_ordering, params_hash))
513 result.append((resolved, (sql, params, is_ref)))
514 return result
515
516 def get_extra_select(
517 self, order_by: list[Any], select: list[Any]
518 ) -> list[tuple[Any, SqlWithParams, None]]:
519 extra_select = []
520 if self.query.distinct and not self.query.distinct_fields:
521 select_sql = [t[1] for t in select]
522 for expr, (sql, params, is_ref) in order_by:
523 without_ordering = self.ordering_parts.search(sql)[1]
524 if not is_ref and (without_ordering, params) not in select_sql:
525 extra_select.append((expr, (without_ordering, params), None))
526 return extra_select
527
528 def quote_name_unless_alias(self, name: str) -> str:
529 """
530 A wrapper around quote_name() that doesn't quote aliases for table
531 names. This avoids problems with some SQL dialects that treat quoted
532 strings specially (e.g. PostgreSQL).
533 """
534 if name in self.quote_cache:
535 return self.quote_cache[name]
536 if (
537 (name in self.query.alias_map and name not in self.query.table_map)
538 or name in self.query.extra_select
539 or (
540 self.query.external_aliases.get(name)
541 and name not in self.query.table_map
542 )
543 ):
544 self.quote_cache[name] = name
545 return name
546 r = quote_name(name)
547 self.quote_cache[name] = r
548 return r
549
550 def compile(self, node: SQLCompilable) -> SqlWithParams:
551 sql, params = node.as_sql(self, self.connection)
552 return sql, tuple(params)
553
def get_qualify_sql(self) -> tuple[list[str], list[Any]]:
    """
    Build the statement that applies ``self.qualify`` on top of a subquery.

    A clone of the query, filtered by the current ``where`` and ``having``,
    is compiled as an inner subquery, and the qualify predicate becomes the
    WHERE clause of the wrapping SELECT. Expressions the predicate or the
    ordering needs are referenced through inner select aliases.

    Returns a list of SQL fragments (not yet joined) and a list of
    parameters. Raises ValueError when ``self.qualify`` is None.
    ``self.qualify`` is replaced by its alias-referencing form.
    """
    where_parts = []
    if self.where:
        where_parts.append(self.where)
    if self.having:
        where_parts.append(self.having)
    inner_query = self.query.clone()
    inner_query.subquery = True
    inner_query.where = inner_query.where.__class__(where_parts)
    # Augment the inner query with any window function references that
    # might have been masked via values() and alias(). If any masked
    # aliases are added they'll be masked again to avoid fetching
    # the data in the `if qual_aliases` branch below.
    select = {
        expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
    }
    select_aliases = set(select.values())
    qual_aliases = set()
    # Maps an expression to the inner select alias that stands in for it.
    replacements = {}

    def collect_replacements(expressions: list[Any]) -> None:
        # Consumes ``expressions`` as a work stack (the list is emptied).
        while expressions:
            expr = expressions.pop()
            if expr in replacements:
                continue
            elif select_alias := select.get(expr):
                # Already selected by the inner query: reuse its alias.
                replacements[expr] = select_alias
            elif isinstance(expr, Lookup):
                expressions.extend(expr.get_source_expressions())
            elif isinstance(expr, Ref):
                if expr.refs not in select_aliases:
                    expressions.extend(expr.get_source_expressions())
            else:
                # Not selected yet: annotate the inner query under a fresh
                # "qualN" alias.
                num_qual_alias = len(qual_aliases)
                select_alias = f"qual{num_qual_alias}"
                qual_aliases.add(select_alias)
                inner_query.add_annotation(expr, select_alias)
                replacements[expr] = select_alias

    qualify = self.qualify
    if qualify is None:
        raise ValueError("QUALIFY clause expected but not provided")
    collect_replacements(list(qualify.leaves()))
    qualify = qualify.replace_expressions(
        {expr: Ref(alias, expr) for expr, alias in replacements.items()}
    )
    self.qualify = qualify
    order_by = []
    for order_by_expr, *_ in self.get_order_by():
        collect_replacements(order_by_expr.get_source_expressions())
        order_by.append(
            order_by_expr.replace_expressions(
                {expr: Ref(alias, expr) for expr, alias in replacements.items()}
            )
        )
    inner_query_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
    inner_sql, inner_params = inner_query_compiler.as_sql(
        # The limits must be applied to the outer query to avoid pruning
        # results too eagerly.
        with_limits=False,
        # Force unique aliasing of selected columns to avoid collisions
        # and make rhs predicates referencing easier.
        with_col_aliases=True,
    )
    qualify_sql, qualify_params = self.compile(qualify)
    result = [
        "SELECT * FROM (",
        inner_sql,
        ")",
        quote_name("qualify"),
        "WHERE",
        qualify_sql,
    ]
    if qual_aliases:
        # If some select aliases were unmasked for filtering purposes they
        # must be masked back.
        cols = [quote_name(alias) for alias in select.values() if alias is not None]
        result = [
            "SELECT",
            ", ".join(cols),
            "FROM (",
            *result,
            ")",
            quote_name("qualify_mask"),
        ]
    params = list(inner_params) + list(qualify_params)
    # As the SQL spec is unclear on whether or not derived tables
    # ordering must propagate it has to be explicitly repeated on the
    # outer-most query to ensure it's preserved.
    if order_by:
        ordering_sqls = []
        for ordering in order_by:
            ordering_sql, ordering_params = self.compile(ordering)
            ordering_sqls.append(ordering_sql)
            params.extend(ordering_params)
        result.extend(["ORDER BY", ", ".join(ordering_sqls)])
    return result, params
651
652 def as_sql(
653 self, with_limits: bool = True, with_col_aliases: bool = False
654 ) -> SqlWithParams:
655 """
656 Create the SQL for this query. Return the SQL string and list of
657 parameters.
658
659 If 'with_limits' is False, any limit/offset information is not included
660 in the query.
661 """
662 refcounts_before = self.query.alias_refcount.copy()
663 try:
664 result = self.pre_sql_setup(with_col_aliases=with_col_aliases)
665 assert result is not None # SQLCompiler.pre_sql_setup always returns tuple
666 extra_select, order_by, group_by = result
667 assert self.select is not None # Set by pre_sql_setup()
668 for_update_part = None
669 # Is a LIMIT/OFFSET clause needed?
670 with_limit_offset = with_limits and self.query.is_sliced
671 if self.qualify:
672 result, params = self.get_qualify_sql()
673 order_by = None
674 else:
675 distinct_fields, distinct_params = self.get_distinct()
676 # This must come after 'select', 'ordering', and 'distinct'
677 # (see docstring of get_from_clause() for details).
678 from_, f_params = self.get_from_clause()
679 try:
680 where, w_params = (
681 self.compile(self.where) if self.where is not None else ("", [])
682 )
683 except EmptyResultSet:
684 if self.elide_empty:
685 raise
686 # Use a predicate that's always False.
687 where, w_params = "0 = 1", []
688 except FullResultSet:
689 where, w_params = "", []
690 try:
691 having, h_params = (
692 self.compile(self.having)
693 if self.having is not None
694 else ("", [])
695 )
696 except FullResultSet:
697 having, h_params = "", []
698 result = ["SELECT"]
699 params = []
700
701 if self.query.distinct:
702 distinct_result, distinct_params = distinct_sql(
703 distinct_fields,
704 distinct_params,
705 )
706 result += distinct_result
707 params += distinct_params
708
709 out_cols = []
710 for _, (s_sql, s_params), alias in self.select + extra_select:
711 if alias:
712 s_sql = f"{s_sql} AS {quote_name(alias)}"
713 params.extend(s_params)
714 out_cols.append(s_sql)
715
716 result += [", ".join(out_cols)]
717 if from_:
718 result += ["FROM", *from_]
719 params.extend(f_params)
720
721 if self.query.select_for_update:
722 if self.connection.get_autocommit():
723 raise TransactionManagementError(
724 "select_for_update cannot be used outside of a transaction."
725 )
726
727 for_update_part = for_update_sql(
728 nowait=self.query.select_for_update_nowait,
729 skip_locked=self.query.select_for_update_skip_locked,
730 of=tuple(self.get_select_for_update_of_arguments()),
731 no_key=self.query.select_for_no_key_update,
732 )
733
734 if where:
735 result.append(f"WHERE {where}")
736 params.extend(w_params)
737
738 grouping = []
739 for g_sql, g_params in group_by:
740 grouping.append(g_sql)
741 params.extend(g_params)
742 if grouping:
743 if distinct_fields:
744 raise NotImplementedError(
745 "annotate() + distinct(fields) is not implemented."
746 )
747 order_by = order_by or []
748 result.append("GROUP BY {}".format(", ".join(grouping)))
749 if self._meta_ordering:
750 order_by = None
751 if having:
752 result.append(f"HAVING {having}")
753 params.extend(h_params)
754
755 if self.query.explain_info:
756 result.insert(
757 0,
758 explain_query_prefix(
759 self.query.explain_info.format,
760 **self.query.explain_info.options,
761 ),
762 )
763
764 if order_by:
765 ordering = []
766 for _, (o_sql, o_params, _) in order_by:
767 ordering.append(o_sql)
768 params.extend(o_params)
769 result.append("ORDER BY {}".format(", ".join(ordering)))
770
771 if with_limit_offset:
772 result.append(
773 limit_offset_sql(self.query.low_mark, self.query.high_mark)
774 )
775
776 if for_update_part:
777 result.append(for_update_part)
778
779 if self.query.subquery and extra_select:
780 # If the query is used as a subquery, the extra selects would
781 # result in more columns than the left-hand side expression is
782 # expecting. This can happen when a subquery uses a combination
783 # of order_by() and distinct(), forcing the ordering expressions
784 # to be selected as well. Wrap the query in another subquery
785 # to exclude extraneous selects.
786 sub_selects = []
787 sub_params = []
788 for index, (select, _, alias) in enumerate(self.select, start=1):
789 if alias:
790 sub_selects.append(
791 "{}.{}".format(
792 quote_name("subquery"),
793 quote_name(alias),
794 )
795 )
796 else:
797 select_clone = select.relabeled_clone(
798 {select.alias: "subquery"}
799 )
800 subselect, subparams = select_clone.as_sql(
801 self, self.connection
802 )
803 sub_selects.append(subselect)
804 sub_params.extend(subparams)
805 return "SELECT {} FROM ({}) subquery".format(
806 ", ".join(sub_selects),
807 " ".join(result),
808 ), tuple(sub_params + params)
809
810 return " ".join(result), tuple(params)
811 finally:
812 # Finally do cleanup - get rid of the joins we created above.
813 self.query.reset_refcounts(refcounts_before)
814
815 def get_default_columns(
816 self,
817 select_mask: Any,
818 start_alias: str | None = None,
819 opts: Meta | None = None,
820 ) -> list[Any]:
821 """
822 Return Col expressions for every concrete field on the model. When
823 pulling in a related model (e.g. via select_related), the caller
824 passes ``opts`` and ``start_alias`` to traverse from that join.
825 """
826 result = []
827 if opts is None:
828 if self.query.model is None:
829 return result
830 opts = self.query.model._model_meta
831 start_alias = start_alias or self.query.get_initial_alias()
832
833 for field in opts.concrete_fields:
834 if select_mask and field not in select_mask:
835 continue
836 result.append(field.get_col(start_alias))
837 return result
838
839 def get_distinct(self) -> tuple[list[str], list]:
840 """
841 Return a quoted list of fields to use in DISTINCT ON part of the query.
842
843 This method can alter the tables in the query, and thus it must be
844 called before get_from_clause().
845 """
846 result = []
847 params = []
848 if not self.query.distinct_fields:
849 return result, params
850
851 if self.query.model is None:
852 return result, params
853 opts = self.query.model._model_meta
854
855 for name in self.query.distinct_fields:
856 parts = name.split(LOOKUP_SEP)
857 _, targets, alias, joins, path, _, transform_function = self._setup_joins(
858 parts, opts, None
859 )
860 targets, alias, _ = self.query.trim_joins(targets, joins, path)
861 for target in targets:
862 if name in self.query.annotation_select:
863 result.append(quote_name(name))
864 else:
865 r, p = self.compile(transform_function(target, alias))
866 result.append(r)
867 params.append(p)
868 return result, params
869
def find_ordering_name(
    self,
    name: str,
    meta: Meta,
    alias: str | None = None,
    default_order: str = "ASC",
    already_seen: set | None = None,
) -> list[tuple[OrderBy, bool]]:
    """
    Return the (OrderBy, is_ref) pairs needed to order by 'name'.

    The 'name' is of the form 'field1__field2__...__fieldN', optionally
    carrying a direction prefix that get_order_dir() understands. Joins are
    set up from ``alias`` (or the query's initial alias). When the final
    field is a relation whose model defines a default ordering, that
    ordering is expanded recursively and prefixed with 'name';
    ``already_seen`` carries the join columns visited so far.

    Raises FieldError if the default orderings form an infinite loop.
    """
    name, order = get_order_dir(name, default_order)
    descending = order == "DESC"
    pieces = name.split(LOOKUP_SEP)
    (
        field,
        targets,
        alias,
        joins,
        path,
        meta,
        transform_function,
    ) = self._setup_joins(pieces, meta, alias)

    # If we get to this point and the field is a relation to another model,
    # append the default ordering for that model unless it is the
    # attribute name of the field that is specified or
    # there are transforms to process.
    if (
        isinstance(field, RelatedField)
        and meta.model.model_options.ordering
        and getattr(field, "attname", None) != pieces[-1]
        and not getattr(transform_function, "has_transforms", False)
    ):
        # Firstly, avoid infinite loops.
        already_seen = already_seen or set()
        join_tuple = tuple(
            getattr(self.query.alias_map[j], "join_cols", None) for j in joins
        )
        if join_tuple in already_seen:
            raise FieldError("Infinite loop caused by ordering.")
        already_seen.add(join_tuple)

        results = []
        for item in meta.model.model_options.ordering:
            if isinstance(item, ResolvableExpression) and not isinstance(
                item, OrderBy
            ):
                # BaseExpression is imported only under TYPE_CHECKING, so
                # it is named as a string: cast() evaluates its first
                # argument at runtime and the bare name would raise
                # NameError.
                item_expr: BaseExpression = cast("BaseExpression", item)
                item = item_expr.desc() if descending else item_expr.asc()
            if isinstance(item, OrderBy):
                results.append(
                    (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                )
                continue
            results.extend(
                (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                for expr, is_ref in self.find_ordering_name(
                    item, meta, alias, order, already_seen
                )
            )
        return results
    targets, alias, _ = self.query.trim_joins(targets, joins, path)
    return [
        (OrderBy(transform_function(t, alias), descending=descending), False)
        for t in targets
    ]
939
940 def _setup_joins(
941 self, pieces: list[str], meta: Meta, alias: str | None
942 ) -> tuple[Any, Any, str, list, Any, Meta, Any]:
943 """
944 Helper method for get_order_by() and get_distinct().
945
946 get_ordering() and get_distinct() must produce same target columns on
947 same input, as the prefixes of get_ordering() and get_distinct() must
948 match. Executing SQL where this is not true is an error.
949 """
950 alias = alias or self.query.get_initial_alias()
951 assert alias is not None
952 field, targets, meta, joins, path, transform_function = self.query.setup_joins(
953 pieces, meta, alias
954 )
955 alias = joins[-1]
956 return field, targets, alias, joins, path, meta, transform_function
957
def get_from_clause(self) -> tuple[list[str], list]:
    """
    Build the pieces that follow "FROM" in the statement.

    Returns a pair: the list of SQL fragments to be joined together after
    "FROM", and the list of any extra parameters those fragments need.
    Subclasses can override this to produce a from-clause via a "select".

    Call this only after every SQL construction step that may change which
    tables are needed, i.e. after the select columns, ordering and distinct
    have been processed.
    """
    query = self.query
    fragments: list[str] = []
    fragment_params: list = []

    for alias in tuple(query.alias_map):
        if query.alias_refcount[alias]:
            try:
                source = query.alias_map[alias]
            except KeyError:
                # Extra tables can end up in self.tables, but not in the
                # alias_map if they aren't in a join. That's OK. We skip them.
                continue
            source_sql, source_params = self.compile(source)
            fragments.append(source_sql)
            fragment_params.extend(source_params)

    for table in query.extra_tables:
        alias, _ = query.table_alias(table)
        # table_alias() bumps the refcount, so a count of exactly one means
        # this call is the only reference and the alias still has to be added.
        already_present = (
            alias in query.alias_map and query.alias_refcount[alias] != 1
        )
        if not already_present:
            fragments.append(f", {self.quote_name_unless_alias(alias)}")

    return fragments, fragment_params
994
def get_related_selections(
    self,
    select: list[Any],
    select_mask: Any,
    opts: Meta | None = None,
    root_alias: str | None = None,
    cur_depth: int = 1,
    requested: dict | None = None,
    restricted: bool | None = None,
) -> list[dict[str, Any]]:
    """
    Fill in the information needed for a select_related query. The current
    depth is measured as the number of connections away from the root model
    (for example, cur_depth=1 means we are looking at models with direct
    connections to the root model).

    The method recurses once per followed relation, appending the related
    model's columns to `select` and recording their positions.

    Args:
        select: Select list being built; extended in place with
            `(col, None)` entries for every related column pulled in.
        select_mask: Mapping consulted with `.get(field)` to obtain the
            nested mask for each followed relation.
        opts: Meta for the model being queried (internal metadata).
            Defaults to the query model's meta on the first call.
        root_alias: Alias the joins start from. Defaults to the query's
            initial alias when `opts` is not given.
        cur_depth: Current recursion depth (1 for the root call).
        requested: Nested dict of relation names still to follow. When
            None it is derived from `self.query.select_related`.
        restricted: True when only the relations named in `requested`
            should be followed; derived when `requested` is None.

    Returns:
        A list of klass_info dicts (keys: "model", "field", "reverse",
        "local_setter", "remote_setter", "select_fields",
        "related_klass_infos"), one per followed relation.

    Raises:
        FieldError: If a non-relational or unknown name was requested.
    """

    related_klass_infos = []
    if not restricted and cur_depth > self.query.max_depth:
        # We've recursed far enough; bail out.
        return related_klass_infos

    if not opts:
        # Root call: start from the query's own model and base alias.
        assert self.query.model is not None, "select_related requires a model"
        opts = self.query.model._model_meta
        root_alias = self.query.get_initial_alias()

    assert root_alias is not None  # Must be provided or set above
    assert opts is not None

    def _get_field_choices() -> chain:
        # Names offered in error messages: forward relations, reverse
        # relations whose field is a primary key, and filtered relations.
        direct_choices = (
            f.name for f in opts.fields if isinstance(f, RelatedField)
        )
        reverse_choices = (
            f.field.related_query_name()
            for f in opts.related_objects
            if f.field.primary_key
        )
        return chain(
            direct_choices, reverse_choices, self.query._filtered_relations
        )

    # Setup for the case when only particular related fields should be
    # included in the related selection.
    fields_found = set()
    if requested is None:
        restricted = isinstance(self.query.select_related, dict)
        if restricted:
            requested = cast(dict, self.query.select_related)

    def get_related_klass_infos(
        klass_info: dict, related_klass_infos: list
    ) -> None:
        # Attach the nested klass_infos produced by the recursive call.
        klass_info["related_klass_infos"] = related_klass_infos

    # Forward relations declared on this model.
    for f in opts.fields:
        fields_found.add(f.name)

        if restricted:
            assert requested is not None
            next = requested.get(f.name, {})
            if not isinstance(f, RelatedField):
                # If a non-related field is used like a relation,
                # or if a single non-relational field is given.
                if next or f.name in requested:
                    raise FieldError(
                        "Non-relational field given in select_related: '{}'. "
                        "Choices are: {}".format(
                            f.name,
                            ", ".join(_get_field_choices()) or "(none)",
                        )
                    )
        else:
            next = None

        if not select_related_descend(f, restricted, requested, select_mask):
            continue
        related_select_mask = select_mask.get(f) or {}
        klass_info: dict[str, Any] = {
            "model": f.remote_field.model,
            "field": f,
            "reverse": False,
            "local_setter": f.set_cached_value,
            # The remote side is only cached when the relation field is a
            # primary key; otherwise a no-op setter is used.
            "remote_setter": f.remote_field.set_cached_value
            if f.primary_key
            else lambda x, y: None,
        }
        related_klass_infos.append(klass_info)
        select_fields = []
        _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
        alias = joins[-1]
        columns = self.get_default_columns(
            related_select_mask,
            start_alias=alias,
            opts=f.remote_field.model._model_meta,
        )
        for col in columns:
            # Record the position of each column within the select list.
            select_fields.append(len(select))
            select.append((col, None))
        klass_info["select_fields"] = select_fields
        next_klass_infos = self.get_related_selections(
            select,
            related_select_mask,
            f.remote_field.model._model_meta,
            alias,
            cur_depth + 1,
            next,
            restricted,
        )
        get_related_klass_infos(klass_info, next_klass_infos)

    if restricted:
        from plain.postgres.fields.reverse_related import ManyToManyRel

        # Reverse relations: only those whose field is a primary key and
        # that are not many-to-many.
        related_fields = [
            (o.field, o.related_model)
            for o in opts.related_objects
            if o.field.primary_key and not isinstance(o, ManyToManyRel)
        ]
        for related_field, model in related_fields:
            related_select_mask = select_mask.get(related_field) or {}

            if not select_related_descend(
                related_field,
                restricted,
                requested,
                related_select_mask,
                reverse=True,
            ):
                continue

            related_field_name = related_field.related_query_name()
            fields_found.add(related_field_name)

            join_info = self.query.setup_joins(
                [related_field_name], opts, root_alias
            )
            alias = join_info.joins[-1]
            klass_info: dict[str, Any] = {
                "model": model,
                "field": related_field,
                "reverse": True,
                "local_setter": related_field.remote_field.set_cached_value,
                "remote_setter": related_field.set_cached_value,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            columns = self.get_default_columns(
                related_select_mask,
                start_alias=alias,
                opts=model._model_meta,
            )
            for col in columns:
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            assert requested is not None
            next = requested.get(related_field.related_query_name(), {})
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                model._model_meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        # Setters used for filtered relations; they are bound to a field or
        # a name with functools.partial below.
        def local_setter(final_field: Any, obj: Any, from_obj: Any) -> None:
            # Set a reverse fk object when relation is non-empty.
            if from_obj:
                final_field.remote_field.set_cached_value(from_obj, obj)

        def local_setter_noop(obj: Any, from_obj: Any) -> None:
            pass

        def remote_setter(name: str, obj: Any, from_obj: Any) -> None:
            setattr(from_obj, name, obj)

        assert requested is not None
        for name in list(requested):
            # Filtered relations work only on the topmost level.
            if cur_depth > 1:
                break
            if name in self.query._filtered_relations:
                fields_found.add(name)
                final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                    [name], opts, root_alias
                )
                model = join_opts.model
                alias = joins[-1]
                klass_info: dict[str, Any] = {
                    "model": model,
                    "field": final_field,
                    "reverse": True,
                    # The local setter is only wired up when the relation
                    # spans at most two joins; longer chains get a no-op.
                    "local_setter": (
                        partial(local_setter, final_field)
                        if len(joins) <= 2
                        else local_setter_noop
                    ),
                    "remote_setter": partial(remote_setter, name),
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                field_select_mask = select_mask.get((name, final_field)) or {}
                columns = self.get_default_columns(
                    field_select_mask,
                    start_alias=alias,
                    opts=model._model_meta,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                next_requested = requested.get(name, {})
                next_klass_infos = self.get_related_selections(
                    select,
                    field_select_mask,
                    opts=model._model_meta,
                    root_alias=alias,
                    cur_depth=cur_depth + 1,
                    requested=next_requested,
                    restricted=restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)
        # Any requested name that matched nothing above is an error.
        fields_not_found = set(requested).difference(fields_found)
        if fields_not_found:
            invalid_fields = (f"'{s}'" for s in fields_not_found)
            raise FieldError(
                "Invalid field name(s) given in select_related: {}. "
                "Choices are: {}".format(
                    ", ".join(invalid_fields),
                    ", ".join(_get_field_choices()) or "(none)",
                )
            )
    return related_klass_infos
1236
def get_select_for_update_of_arguments(self) -> list[str]:
    """
    Build the quoted argument list for the OF part of SELECT FOR UPDATE.

    Each name in `self.query.select_for_update_of` is either "self" or a
    LOOKUP_SEP-separated path through the followed relations. A name whose
    model has no column in the select list is silently skipped.

    Raises:
        FieldError: If any name does not match a followed relation.
    """

    def _first_col_for(info: dict) -> Any | None:
        """
        Return the first selected column whose target field belongs to the
        model of `info`, or None when that model has no selected column
        (in which case the row is not locked).
        """
        assert self.select is not None
        owner = info["model"]
        return next(
            (
                self.select[position][0]
                for position in info["select_fields"]
                if self.select[position][0].target.model == owner
            ),
            None,
        )

    def _iter_valid_paths() -> Generator[str]:
        """Yield all allowed field paths in breadth-first search order."""
        pending = collections.deque([(None, self.klass_info)])
        while pending:
            prefix, info = pending.popleft()
            if prefix is None:
                current = []
                yield "self"
            else:
                assert info is not None  # Only first iteration has None
                relation = info["field"]
                if info["reverse"]:
                    relation = relation.remote_field
                current = prefix + [relation.name]
                yield LOOKUP_SEP.join(current)
            if info is not None:
                pending.extend(
                    (current, child)  # type: ignore[invalid-argument-type]
                    for child in info.get("related_klass_infos", [])
                )

    def _follow(name: str) -> dict | None:
        """Walk `name` through the nested klass_infos; None if it breaks."""
        info = self.klass_info
        for part in name.split(LOOKUP_SEP):
            if info is None:
                break
            matched = None
            for child in (*info.get("related_klass_infos", []),):
                relation = child["field"]
                if child["reverse"]:
                    relation = relation.remote_field
                if relation.name == part:
                    matched = child
                    break
            if matched is None:
                return None
            info = matched
        return info

    if not self.klass_info:
        return []

    locked: list[str] = []
    unknown: list[str] = []
    for name in self.query.select_for_update_of:
        target_info = self.klass_info if name == "self" else _follow(name)
        if target_info is None:
            unknown.append(name)
            continue
        col = _first_col_for(target_info)
        if col is not None:
            locked.append(self.quote_name_unless_alias(col.alias))

    if unknown:
        raise FieldError(
            "Invalid field name(s) given in select_for_update(of=(...)): {}. "
            "Only relational fields followed in the query are allowed. "
            "Choices are: {}.".format(
                ", ".join(unknown),
                ", ".join(_iter_valid_paths()),
            )
        )
    return locked
1318
def results_iter(
    self,
    results: Any = None,
    tuple_expected: bool = False,
    chunked_fetch: bool = False,
) -> Iterable[Any]:
    """
    Return an iterator over the rows produced by this query.

    When `results` is None the query is executed first (MULTI mode). Field
    converters are applied if any exist for the selected columns, and rows
    are turned into tuples when `tuple_expected` is true.
    """
    if results is None:
        results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch)
    assert self.select is not None  # Set during query execution
    selected_exprs = [entry[0] for entry in self.select[: self.col_count]]
    converters = get_converters(selected_exprs, self.connection)
    if converters:
        results = apply_converters(results, converters, self.connection)
    return map(tuple, results) if tuple_expected else results
1337
def has_results(self) -> bool:
    """Return True when executing the query yields at least one row."""
    first_row = self.execute_sql(SINGLE)
    return bool(first_row)
1341
def execute_sql(
    self,
    result_type: str = MULTI,
    chunked_fetch: bool = False,
) -> Any:
    """
    Run the query against the database and return the result(s).

    What comes back depends on `result_type`:

    * MULTI: a list from fetchall(), or a streaming generator from
      cursor.stream() when chunked_fetch=True.
    * SINGLE: a single row (or whatever fetchone() returned when there is
      no row).
    * CURSOR: the open cursor; the caller must process and close it.
    * NO_RESULTS (also used when result_type is falsy): None.

    If the filters describe an empty set no query is run at all: an empty
    iterator is returned for MULTI and None otherwise.
    """
    result_type = result_type or NO_RESULTS
    try:
        compiled = self.as_sql()
        # SQLCompiler.as_sql returns SqlWithParams, subclasses may differ
        assert isinstance(compiled, tuple)
        assert isinstance(compiled[0], str)
        sql, params = compiled
        if not sql:
            raise EmptyResultSet
    except EmptyResultSet:
        return iter([]) if result_type == MULTI else None

    cursor = self.connection.cursor()
    if chunked_fetch:
        # Use psycopg3's cursor.stream() for server-side cursor iteration.
        streamed = cursor.stream(sql, params)
        if not self.has_extra_select:
            return streamed
        # Trim the extra-select columns off every streamed row.
        width = self.col_count
        return (row[:width] for row in streamed)

    try:
        cursor.execute(sql, params)
    except Exception:
        cursor.close()
        raise

    if result_type == CURSOR:
        # Give the caller the cursor to process and close.
        return cursor

    if result_type == SINGLE:
        try:
            row = cursor.fetchone()
            return row[0 : self.col_count] if row else row
        finally:
            # done with the cursor
            cursor.close()

    if result_type == NO_RESULTS:
        cursor.close()
        return None

    try:
        fetched = cursor.fetchall()
    finally:
        cursor.close()
    if self.has_extra_select:
        fetched = [row[: self.col_count] for row in fetched]
    return fetched
1412
def explain_query(self) -> Generator[str]:
    """
    Execute the query and yield each EXPLAIN output row as a string.

    Rows that are already strings are yielded unchanged. Other rows are
    flattened by joining their columns with a space, rendering each column
    with json.dumps when the explain format is "json" (case-insensitive)
    and with str otherwise.
    """
    rows = self.execute_sql()
    info = self.query.explain_info
    # PostgreSQL may return tuples with integers and strings depending on
    # the EXPLAIN format. Flatten them out into strings.
    requested_format = None if info is None else info.format
    if requested_format and requested_format.lower() == "json":
        render = json.dumps
    else:
        render = str
    for row in rows:
        if isinstance(row, str):
            yield row
            continue
        yield " ".join(map(render, row))
1425
1426
class SQLInsertCompiler(SQLCompiler):
    """Compiler that turns an InsertQuery into INSERT statements."""

    query: InsertQuery
    # Fields requested via RETURNING; set by execute_sql().
    returning_fields: list | None = None
    # Parameters that belong to the RETURNING clause, if any.
    returning_params: tuple = ()

    def field_as_sql(self, field: Any, val: Any) -> tuple[str, list]:
        """
        Return the placeholder SQL and parameter list for saving `val` on
        `field`.

        Checked in this order: the DATABASE_DEFAULT sentinel, a raw value
        (`field` is None, the value itself is the placeholder and carries no
        parameters), an expression, a field defining get_placeholder(), and
        finally the plain "%s" placeholder.
        """
        if val is DATABASE_DEFAULT:
            # Emit the literal DEFAULT keyword so Postgres uses the column's
            # persistent DEFAULT (e.g. `gen_random_uuid()`). RETURNING then
            # populates the real value back onto the instance.
            return "DEFAULT", []
        if field is None:
            # A field value of None means the value is raw.
            return val, []
        if hasattr(val, "as_sql"):
            # This is an expression, let's compile it.
            expr_sql, expr_params = self.compile(val)
            return expr_sql, list(expr_params)
        if hasattr(field, "get_placeholder"):
            # Some fields (e.g. geo fields) need special munging before
            # they can be inserted.
            return field.get_placeholder(val, self, self.connection), [val]
        # The common case: a single positional placeholder.
        return "%s", [val]

    def prepare_value(self, field: Any, value: Any) -> Any:
        """
        Make `value` ready for use in the statement: resolve it when it is
        an expression, then run it through the field's get_db_prep_save().

        Raises:
            ValueError: If the expression references existing columns.
            FieldError: If it contains an aggregate or a window expression.
        """
        if value is DATABASE_DEFAULT:
            # Carry the sentinel through untouched — field_as_sql will emit
            # the literal DEFAULT keyword.
            return value
        if isinstance(value, ResolvableExpression):
            resolved = value.resolve_expression(
                self.query, allow_joins=False, for_save=True
            )
            # Don't allow values containing Col expressions. They refer to
            # existing columns on a row, but in the case of insert the row
            # doesn't exist yet.
            if resolved.contains_column_references:
                raise ValueError(
                    f'Failed to insert expression "{resolved}" on {field}. F() expressions '
                    "can only be used to update, not to insert."
                )
            if resolved.contains_aggregate:
                raise FieldError(
                    "Aggregate functions are not allowed in this query "
                    f"({field.name}={resolved!r})."
                )
            if resolved.contains_over_clause:
                raise FieldError(
                    f"Window expressions are not allowed in this query ({field.name}={resolved!r})."
                )
            value = resolved
        return field.get_db_prep_save(value, connection=self.connection)

    def pre_save_val(self, field: Any, obj: Any) -> Any:
        """
        Read the value of `field` from `obj`.

        Normally goes through field.pre_save() (used for things like
        update_now on DateTimeField); for a raw query the attribute is read
        directly instead.
        """
        if not self.query.raw:
            return field.pre_save(obj, add=True)
        return getattr(obj, field.attname)

    def assemble_as_sql(
        self, fields: list[Any], value_rows: list[list[Any]]
    ) -> tuple[Any, list[list[Any]]]:
        """
        Produce placeholder SQL and parameters for N fields and M value rows.

        Returns a pair containing:
         * a sequence of M rows of N SQL placeholder strings, and
         * a sequence of M rows of corresponding parameter values.

        A placeholder string may hold any number of '%s' markers; each
        parameter row holds exactly as many values as its placeholder row
        has markers in total.
        """
        if not value_rows:
            return [], []

        # One (sql, [params]) pair per field, per object.
        # Shape: [n_objs][n_fields][2]
        compiled_rows = [
            [self.field_as_sql(field, value) for field, value in zip(fields, row)]
            for row in value_rows
        ]

        # Transpose twice: first each row into (sqls, param lists), then the
        # rows themselves into (all placeholder rows, all param-list rows).
        placeholder_rows, nested_param_rows = zip(
            *(zip(*compiled) for compiled in compiled_rows)
        )

        # Params for each field are still lists, and need to be flattened.
        param_rows = [
            list(chain.from_iterable(per_field)) for per_field in nested_param_rows
        ]

        return placeholder_rows, param_rows

    def as_sql(  # ty: ignore[invalid-method-override]  # Returns list for internal iteration in execute_sql
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> list[SqlWithParams]:
        """
        Build the INSERT statement.

        Returns a one-element list holding the (sql, params) pair; the list
        form is what execute_sql() iterates over.
        """
        # quote_name_unless_alias() isn't needed here: everything quoted is
        # a plain table or column name, so the cheaper helper is enough.
        assert self.query.model is not None, "INSERT requires a model"
        model = self.query.model
        meta = model._model_meta
        parts = [f"INSERT INTO {quote_name(model.model_options.db_table)}"]

        if self.query.fields:
            fields = self.query.fields
            column_names = [quote_name(f.column) for f in fields]
            value_rows = [
                [
                    self.prepare_value(field, self.pre_save_val(field, obj))
                    for field in fields
                ]
                for obj in self.query.objs
            ]
        else:
            # An empty object: insert only the id column with its default.
            column_names = [quote_name(meta.get_forward_field("id").column)]
            value_rows = [[PK_DEFAULT_VALUE] for _ in self.query.objs]
            fields = [None]
        parts.append("({})".format(", ".join(column_names)))

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        conflict_sql = on_conflict_suffix_sql(
            fields,  # ty: ignore[invalid-argument-type]
            self.query.on_conflict,
            (f.column for f in self.query.update_fields),
            (f.column for f in self.query.unique_fields),
        )

        parts.append(
            bulk_insert_sql(fields, placeholder_rows)  # ty: ignore[invalid-argument-type]
        )
        if conflict_sql:
            parts.append(conflict_sql)

        if not self.returning_fields:
            # Bulk insert without returning fields
            return [(" ".join(parts), tuple(p for ps in param_rows for p in ps))]

        # Use RETURNING clause to get inserted values. An empty RETURNING
        # fragment is skipped.
        returning = return_insert_columns(self.returning_fields)
        if returning:
            returning_sql, self.returning_params = returning
            if returning_sql:
                parts.append(returning_sql)
                param_rows += [list(self.returning_params)]
        return [(" ".join(parts), tuple(chain.from_iterable(param_rows)))]

    def execute_sql(  # ty: ignore[invalid-method-override]
        self, returning_fields: list | None = None
    ) -> list:
        """
        Run the INSERT and return the RETURNING rows.

        Returns an empty list when no returning fields were requested;
        otherwise the fetched rows, passed through the field converters when
        any apply.
        """
        assert self.query.model is not None, "INSERT execution requires a model"
        options = self.query.model.model_options
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for statement, statement_params in self.as_sql():
                cursor.execute(statement, statement_params)
            if not self.returning_fields:
                return []
            # Use RETURNING clause for both single and bulk inserts
            if len(self.query.objs) > 1:
                rows = cursor.fetchall()
            else:
                rows = [cursor.fetchone()]
        returned_cols = [
            field.get_col(options.db_table) for field in self.returning_fields
        ]
        converters = get_converters(returned_cols, self.connection)
        if not converters:
            return rows
        return list(apply_converters(rows, converters, self.connection))
1622
1623
class SQLDeleteCompiler(SQLCompiler):
    """Compiler that turns a delete query into a DELETE statement."""

    @cached_property
    def single_alias(self) -> bool:
        """True when exactly one alias in the query is referenced."""
        # Ensure base table is in aliases.
        self.query.get_initial_alias()
        refcounts = self.query.alias_refcount
        active = [alias for alias in self.query.alias_map if refcounts[alias] > 0]
        return len(active) == 1

    @classmethod
    def _expr_refs_base_model(cls, expr: Any, base_model: Any) -> bool:
        """
        Return True when `expr`, or any expression nested in it, is a Query
        on `base_model`.
        """
        if isinstance(expr, Query):
            return expr.model == base_model
        if hasattr(expr, "get_source_expressions"):
            for child in expr.get_source_expressions():
                if cls._expr_refs_base_model(child, base_model):
                    return True
        return False

    @cached_property
    def contains_self_reference_subquery(self) -> bool:
        """
        True when an annotation or a top-level where child contains a
        subquery on the model being deleted from.
        """
        candidates = chain(
            self.query.annotations.values(), self.query.where.children
        )
        for expr in candidates:
            if self._expr_refs_base_model(expr, self.query.model):
                return True
        return False

    def _as_sql(self, query: Query) -> SqlWithParams:
        """Render `DELETE FROM <base table>` plus the WHERE clause, if any."""
        statement = f"DELETE FROM {self.quote_name_unless_alias(query.base_table)}"  # ty: ignore[invalid-argument-type]
        try:
            where_sql, where_params = self.compile(query.where)
        except FullResultSet:
            # The filter matches everything, so no WHERE clause is needed.
            return statement, ()
        else:
            return f"{statement} WHERE {where_sql}", tuple(where_params)

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        A query that touches a single alias and has no self-referencing
        subquery is rendered directly. Anything else is rewritten as a
        delete filtered by `id IN (<subquery selecting the ids>)`.
        """
        needs_subquery = (
            not self.single_alias or self.contains_self_reference_subquery
        )
        if not needs_subquery:
            return self._as_sql(self.query)

        # Inner query: a plain Query selecting only the id column.
        id_subquery = self.query.clone()
        id_subquery.__class__ = Query
        id_subquery.clear_select_clause()
        assert self.query.model is not None, "DELETE requires a model"
        id_field = self.query.model._model_meta.get_forward_field("id")
        id_subquery.select = (id_field.get_col(self.query.get_initial_alias()),)

        # Outer query: delete from the model where the id is in that set.
        wrapper = Query(self.query.model)
        wrapper.add_filter("id__in", id_subquery)
        return self._as_sql(wrapper)
1677
1678
class SQLUpdateCompiler(SQLCompiler):
    """Compiler that turns an update query into an UPDATE statement."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        Returns ("", ()) when the query carries no values to set.

        Raises:
            FieldError: If a value contains an aggregate or window expression.
            TypeError: If a model instance is assigned to a non-related field.
        """
        self.pre_sql_setup()
        assignments = getattr(self.query, "values", None)
        if not assignments:
            return "", ()

        quote = self.quote_name_unless_alias
        set_clauses: list[str] = []
        set_params: list = []
        for field, val in assignments:
            if isinstance(val, ResolvableExpression):
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                # Model instances may only be assigned to related fields.
                if not isinstance(field, RelatedField):
                    raise TypeError(
                        f"Tried to update field {field} with a model instance, {val!r}. "
                        f"Use a value compatible with {field.__class__.__name__}."
                    )
                val = val.prepare_database_save(field)
            val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            placeholder = (
                field.get_placeholder(val, self, self.connection)
                if hasattr(field, "get_placeholder")
                else "%s"
            )
            column = field.column
            if hasattr(val, "as_sql"):
                # Expression: splice its SQL into the placeholder.
                expr_sql, expr_params = self.compile(val)
                set_clauses.append(f"{quote(column)} = {placeholder % expr_sql}")
                set_params.extend(expr_params)
            elif val is None:
                set_clauses.append(f"{quote(column)} = NULL")
            else:
                set_clauses.append(f"{quote(column)} = {placeholder}")
                set_params.append(val)

        statement = [
            f"UPDATE {quote(self.query.base_table)} SET",  # ty: ignore[invalid-argument-type]
            ", ".join(set_clauses),
        ]
        try:
            where_sql, where_params = self.compile(self.query.where)
        except FullResultSet:
            # The filter matches everything: no WHERE clause, no params.
            where_params = []
        else:
            statement.append(f"WHERE {where_sql}")
        return " ".join(statement), tuple(set_params + list(where_params))

    def execute_sql(self, result_type: str) -> int:  # ty: ignore[invalid-method-override]
        """Execute the update and return the number of rows affected."""
        cursor = super().execute_sql(result_type)
        if not cursor:
            return 0
        try:
            return cursor.rowcount
        finally:
            cursor.close()

    def pre_sql_setup(
        self, with_col_aliases: bool = False
    ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
        """
        If the update depends on other tables (JOINs in the WHERE clause),
        rewrite the query so the current table is filtered by `id IN (subquery)`.

        Nothing is changed when only one table is active.
        """
        saved_refcounts = self.query.alias_refcount.copy()
        # Ensure base table is in the query
        self.query.get_initial_alias()
        if self.query.count_active_tables() == 1:
            return None

        # Build a plain Query that selects only the ids of the rows to update.
        id_subquery = self.query.chain(klass=Query)
        id_subquery.select_related = False
        id_subquery.clear_ordering(force=True)
        id_subquery.extra = {}
        id_subquery.select = ()
        id_subquery.add_fields(["id"])
        super().pre_sql_setup()

        # Reset the where clause and drop the tables we no longer need (they
        # live in the sub-select now).
        self.query.clear_where()
        self.query.add_filter("id__in", id_subquery)
        self.query.reset_refcounts(saved_refcounts)
        return None
1781
1782
1783class SQLAggregateCompiler(SQLCompiler):
def as_sql(
    self, with_limits: bool = True, with_col_aliases: bool = False
) -> SqlWithParams:
    """
    Create the SQL for this query. Return the SQL string and list of
    parameters.

    The annotations are compiled into the outer select list and the inner
    query is embedded as a subquery aliased `subquery`. As a side effect
    `self.col_count` is set to the number of selected annotations.
    """
    fragments = []
    collected = []
    for annotation in self.query.annotation_select.values():
        compiled_sql, compiled_params = self.compile(annotation)
        formatted_sql, formatted_params = annotation.select_format(
            self, compiled_sql, compiled_params
        )
        fragments.append(formatted_sql)
        collected.extend(formatted_params)
    self.col_count = len(self.query.annotation_select)
    select_list = ", ".join(fragments)
    outer_params = tuple(collected)

    inner_query = cast("AggregateQuery", self.query).inner_query
    inner_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
    inner_sql, inner_params = inner_compiler.as_sql(with_col_aliases=True)
    # Outer (annotation) params come first, matching their position in the SQL.
    return (
        f"SELECT {select_list} FROM ({inner_sql}) subquery",
        outer_params + inner_params,
    )