1import collections
2import json
3import re
4from functools import partial
5from itertools import chain
6
7from plain.exceptions import EmptyResultSet, FieldError, FullResultSet
8from plain.models.constants import LOOKUP_SEP
9from plain.models.db import DatabaseError, NotSupportedError
10from plain.models.expressions import F, OrderBy, RawSQL, Ref, Value
11from plain.models.functions import Cast, Random
12from plain.models.lookups import Lookup
13from plain.models.query_utils import select_related_descend
14from plain.models.sql.constants import (
15 CURSOR,
16 GET_ITERATOR_CHUNK_SIZE,
17 MULTI,
18 NO_RESULTS,
19 ORDER_DIR,
20 SINGLE,
21)
22from plain.models.sql.query import Query, get_order_dir
23from plain.models.sql.where import AND
24from plain.models.transaction import TransactionManagementError
25from plain.utils.functional import cached_property
26from plain.utils.hashable import make_hashable
27from plain.utils.regex_helper import _lazy_re_compile
28
29
class PositionRef(Ref):
    """A Ref that renders as its 1-based ordinal position in the SELECT list."""

    def __init__(self, ordinal, refs, source):
        self.ordinal = ordinal
        super().__init__(refs, source)

    def as_sql(self, compiler, connection):
        # Emit the positional index instead of the referenced expression.
        return f"{self.ordinal}", ()
37
38
class SQLCompiler:
    """Compile a Query into executable SQL for a specific database connection."""

    # Regex used to strip a trailing ASC/DESC (and anything following it)
    # from a compiled ORDER BY term, so duplicate ordering expressions can
    # be detected regardless of direction.
    # Multiline ordering SQL clause may appear from RawSQL, hence
    # MULTILINE | DOTALL.
    ordering_parts = _lazy_re_compile(
        r"^(.*)\s(?:ASC|DESC).*",
        re.MULTILINE | re.DOTALL,
    )
45
46 def __init__(self, query, connection, using, elide_empty=True):
47 self.query = query
48 self.connection = connection
49 self.using = using
50 # Some queries, e.g. coalesced aggregation, need to be executed even if
51 # they would return an empty result set.
52 self.elide_empty = elide_empty
53 self.quote_cache = {"*": "*"}
54 # The select, klass_info, and annotations are needed by QuerySet.iterator()
55 # these are set as a side-effect of executing the query. Note that we calculate
56 # separately a list of extra select columns needed for grammatical correctness
57 # of the query, but these columns are not included in self.select.
58 self.select = None
59 self.annotation_col_map = None
60 self.klass_info = None
61 self._meta_ordering = None
62
63 def __repr__(self):
64 return (
65 f"<{self.__class__.__qualname__} "
66 f"model={self.query.model.__qualname__} "
67 f"connection={self.connection!r} using={self.using!r}>"
68 )
69
70 def setup_query(self, with_col_aliases=False):
71 if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map):
72 self.query.get_initial_alias()
73 self.select, self.klass_info, self.annotation_col_map = self.get_select(
74 with_col_aliases=with_col_aliases,
75 )
76 self.col_count = len(self.select)
77
    def pre_sql_setup(self, with_col_aliases=False):
        """
        Do any necessary class setup immediately prior to producing SQL. This
        is for things that can't necessarily be done in __init__ because we
        might not have all the pieces in place at that time.

        Return a (extra_select, order_by, group_by) triple. Also sets
        self.where, self.having, self.qualify, and self.has_extra_select
        as side effects.
        """
        self.setup_query(with_col_aliases=with_col_aliases)
        # Ordering must be resolved before the WHERE tree is split so that
        # _meta_ordering is known when building GROUP BY below.
        order_by = self.get_order_by()
        # Partition the single WHERE tree into WHERE/HAVING/QUALIFY parts.
        self.where, self.having, self.qualify = self.query.where.split_having_qualify(
            must_group_by=self.query.group_by is not None
        )
        extra_select = self.get_extra_select(order_by, self.select)
        self.has_extra_select = bool(extra_select)
        group_by = self.get_group_by(self.select + extra_select, order_by)
        return extra_select, order_by, group_by
93
    def get_group_by(self, select, order_by):
        """
        Return a list of 2-tuples of form (sql, params) for the GROUP BY
        clause.

        The logic of what exactly the GROUP BY clause contains is hard
        to describe in other words than "if it passes the test suite,
        then it is correct".
        """
        # Some examples:
        #     SomeModel.objects.annotate(Count('somecol'))
        #     GROUP BY: all fields of the model
        #
        #    SomeModel.objects.values('name').annotate(Count('somecol'))
        #    GROUP BY: name
        #
        #    SomeModel.objects.annotate(Count('somecol')).values('name')
        #    GROUP BY: all cols of the model
        #
        #    SomeModel.objects.values('name', 'pk')
        #    .annotate(Count('somecol')).values('pk')
        #    GROUP BY: name, pk
        #
        #    SomeModel.objects.values('name').annotate(Count('somecol')).values('pk')
        #    GROUP BY: name, pk
        #
        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
        # can't be ever restricted to a smaller set, but additional columns in
        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
        # the end result is that it is impossible to force the query to have
        # a chosen GROUP BY clause - you can almost do this by using the form:
        #     .values(*wanted_cols).annotate(AnAggregate())
        # but any later annotations, extra selects, values calls that
        # refer some column outside of the wanted_cols, order_by, or even
        # filter calls can alter the GROUP BY clause.

        # The query.group_by is either None (no GROUP BY at all), True
        # (group by select fields), or a list of expressions to be added
        # to the group by.
        if self.query.group_by is None:
            return []
        expressions = []
        group_by_refs = set()
        if self.query.group_by is not True:
            # If the group by is set to a list (by .values() call most likely),
            # then we need to add everything in it to the GROUP BY clause.
            # Backwards compatibility hack for setting query.group_by. Remove
            # when we have public API way of forcing the GROUP BY clause.
            # Converts string references to expressions.
            for expr in self.query.group_by:
                if not hasattr(expr, "as_sql"):
                    expr = self.query.resolve_ref(expr)
                if isinstance(expr, Ref):
                    if expr.refs not in group_by_refs:
                        group_by_refs.add(expr.refs)
                        expressions.append(expr.source)
                else:
                    expressions.append(expr)
        # Note that even if the group_by is set, it is only the minimal
        # set to group by. So, we need to add cols in select, order_by, and
        # having into the select in any case.
        selected_expr_positions = {}
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            if alias:
                selected_expr_positions[expr] = ordinal
            # Skip members of the select clause that are already explicitly
            # grouped against.
            if alias in group_by_refs:
                continue
            expressions.extend(expr.get_group_by_cols())
        # ORDER BY columns must be grouped on as well, unless the ordering
        # came from Meta.ordering (in which case it is dropped from grouped
        # queries later on).
        if not self._meta_ordering:
            for expr, (sql, params, is_ref) in order_by:
                # Skip references to the SELECT clause, as all expressions in
                # the SELECT clause are already part of the GROUP BY.
                if not is_ref:
                    expressions.extend(expr.get_group_by_cols())
        having_group_by = self.having.get_group_by_cols() if self.having else ()
        for expr in having_group_by:
            expressions.append(expr)
        result = []
        seen = set()
        # Drop columns that are functionally dependent on an already-grouped
        # primary key, where the backend allows it.
        expressions = self.collapse_group_by(expressions, having_group_by)

        allows_group_by_select_index = (
            self.connection.features.allows_group_by_select_index
        )
        for expr in expressions:
            try:
                sql, params = self.compile(expr)
            except (EmptyResultSet, FullResultSet):
                continue
            # Prefer grouping by the SELECT-clause ordinal when the backend
            # supports it; otherwise emit the full expression SQL.
            if (
                allows_group_by_select_index
                and (position := selected_expr_positions.get(expr)) is not None
            ):
                sql, params = str(position), ()
            else:
                sql, params = expr.select_format(self, sql, params)
            # Deduplicate on (sql, hashable params) to avoid repeated terms.
            params_hash = make_hashable(params)
            if (sql, params_hash) not in seen:
                result.append((sql, params))
                seen.add((sql, params_hash))
        return result
196
197 def collapse_group_by(self, expressions, having):
198 # If the database supports group by functional dependence reduction,
199 # then the expressions can be reduced to the set of selected table
200 # primary keys as all other columns are functionally dependent on them.
201 if self.connection.features.allows_group_by_selected_pks:
202 # Filter out all expressions associated with a table's primary key
203 # present in the grouped columns. This is done by identifying all
204 # tables that have their primary key included in the grouped
205 # columns and removing non-primary key columns referring to them.
206 # Unmanaged models are excluded because they could be representing
207 # database views on which the optimization might not be allowed.
208 pks = {
209 expr
210 for expr in expressions
211 if (
212 hasattr(expr, "target")
213 and expr.target.primary_key
214 and self.connection.features.allows_group_by_selected_pks_on_model(
215 expr.target.model
216 )
217 )
218 }
219 aliases = {expr.alias for expr in pks}
220 expressions = [
221 expr
222 for expr in expressions
223 if expr in pks
224 or expr in having
225 or getattr(expr, "alias", None) not in aliases
226 ]
227 return expressions
228
    def get_select(self, with_col_aliases=False):
        """
        Return three values:
        - a list of 3-tuples of (expression, (sql, params), alias)
        - a klass_info structure,
        - a dictionary of annotations

        The (sql, params) is what the expression will produce, and alias is the
        "AS alias" for the column (possibly None).

        The klass_info structure contains the following information:
        - The base model of the query.
        - Which columns for that model are present in the query (by
          position of the select clause).
        - related_klass_infos: [f, klass_info] to descent into

        The annotations is a dictionary of {'attname': column position} values.
        """
        select = []
        klass_info = None
        annotations = {}
        select_idx = 0
        # extra(select=...) columns come first, addressed by alias position.
        for alias, (sql, params) in self.query.extra_select.items():
            annotations[alias] = select_idx
            select.append((RawSQL(sql, params), alias))
            select_idx += 1
        # A query selects either explicit expressions or the model's default
        # columns — never both at once.
        assert not (self.query.select and self.query.default_cols)
        select_mask = self.query.get_select_mask()
        if self.query.default_cols:
            cols = self.get_default_columns(select_mask)
        else:
            # self.query.select is a special case. These columns never go to
            # any model.
            cols = self.query.select
        if cols:
            select_list = []
            for col in cols:
                select_list.append(select_idx)
                select.append((col, None))
                select_idx += 1
            klass_info = {
                "model": self.query.model,
                "select_fields": select_list,
            }
        # Annotations are appended after the model columns.
        for alias, annotation in self.query.annotation_select.items():
            annotations[alias] = select_idx
            select.append((annotation, alias))
            select_idx += 1

        if self.query.select_related:
            related_klass_infos = self.get_related_selections(select, select_mask)
            klass_info["related_klass_infos"] = related_klass_infos

            def get_select_from_parent(klass_info):
                # Prepend the parent's select fields to children loaded via a
                # parent link so they can reuse already-selected columns.
                for ki in klass_info["related_klass_infos"]:
                    if ki["from_parent"]:
                        ki["select_fields"] = (
                            klass_info["select_fields"] + ki["select_fields"]
                        )
                    get_select_from_parent(ki)

            get_select_from_parent(klass_info)

        ret = []
        col_idx = 1
        for col, alias in select:
            try:
                sql, params = self.compile(col)
            except EmptyResultSet:
                empty_result_set_value = getattr(
                    col, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    # Select a predicate that's always False.
                    sql, params = "0", ()
                else:
                    sql, params = self.compile(Value(empty_result_set_value))
            except FullResultSet:
                sql, params = self.compile(Value(True))
            else:
                sql, params = col.select_format(self, sql, params)
            # Synthesize col1, col2, ... aliases when callers require every
            # output column to be named.
            if alias is None and with_col_aliases:
                alias = f"col{col_idx}"
                col_idx += 1
            ret.append((col, (sql, params), alias))
        return ret, klass_info, annotations
315
    def _order_by_pairs(self):
        """
        Yield (OrderBy expression, is_ref) pairs for the query's ordering.

        `is_ref` is True when the ordering term references the SELECT clause
        (by alias or by ordinal position) rather than being an independent
        expression.
        """
        # Ordering precedence: extra(order_by=...) > explicit order_by() >
        # Meta.ordering (recorded in _meta_ordering for later handling).
        if self.query.extra_order_by:
            ordering = self.query.extra_order_by
        elif not self.query.default_ordering:
            ordering = self.query.order_by
        elif self.query.order_by:
            ordering = self.query.order_by
        elif (meta := self.query.get_meta()) and meta.ordering:
            ordering = meta.ordering
            self._meta_ordering = ordering
        else:
            ordering = []
        if self.query.standard_ordering:
            default_order, _ = ORDER_DIR["ASC"]
        else:
            default_order, _ = ORDER_DIR["DESC"]

        # Map selected expressions (and their aliases) to positional refs so
        # ORDER BY can reference the SELECT clause by ordinal.
        selected_exprs = {}
        if select := self.select:
            for ordinal, (expr, _, alias) in enumerate(select, start=1):
                pos_expr = PositionRef(ordinal, alias, expr)
                if alias:
                    selected_exprs[alias] = pos_expr
                selected_exprs[expr] = pos_expr

        for field in ordering:
            if hasattr(field, "resolve_expression"):
                if isinstance(field, Value):
                    # output_field must be resolved for constants.
                    field = Cast(field, field.output_field)
                if not isinstance(field, OrderBy):
                    field = field.asc()
                if not self.query.standard_ordering:
                    field = field.copy()
                    field.reverse_ordering()
                select_ref = selected_exprs.get(field.expression)
                if select_ref or (
                    isinstance(field.expression, F)
                    and (select_ref := selected_exprs.get(field.expression.name))
                ):
                    # Emulation of NULLS (FIRST|LAST) cannot be combined with
                    # the usage of ordering by position.
                    if (
                        field.nulls_first is None and field.nulls_last is None
                    ) or self.connection.features.supports_order_by_nulls_modifier:
                        field = field.copy()
                        field.expression = select_ref
                    # Alias collisions are not possible when dealing with
                    # combined queries so fallback to it if emulation of NULLS
                    # handling is required.
                    elif self.query.combinator:
                        field = field.copy()
                        field.expression = Ref(select_ref.refs, select_ref.source)
                yield field, select_ref is not None
                continue
            if field == "?":  # random
                yield OrderBy(Random()), False
                continue

            col, order = get_order_dir(field, default_order)
            descending = order == "DESC"

            if select_ref := selected_exprs.get(col):
                # Reference to expression in SELECT clause
                yield (
                    OrderBy(
                        select_ref,
                        descending=descending,
                    ),
                    True,
                )
                continue
            if col in self.query.annotations:
                # References to an expression which is masked out of the SELECT
                # clause.
                if self.query.combinator and self.select:
                    # Don't use the resolved annotation because other
                    # combinated queries might define it differently.
                    expr = F(col)
                else:
                    expr = self.query.annotations[col]
                if isinstance(expr, Value):
                    # output_field must be resolved for constants.
                    expr = Cast(expr, expr.output_field)
                yield OrderBy(expr, descending=descending), False
                continue

            if "." in field:
                # This came in through an extra(order_by=...) addition. Pass it
                # on verbatim.
                table, col = col.split(".", 1)
                yield (
                    OrderBy(
                        RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
                        descending=descending,
                    ),
                    False,
                )
                continue

            if self.query.extra and col in self.query.extra:
                if col in self.query.extra_select:
                    yield (
                        OrderBy(
                            Ref(col, RawSQL(*self.query.extra[col])),
                            descending=descending,
                        ),
                        True,
                    )
                else:
                    yield (
                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
                        False,
                    )
            else:
                if self.query.combinator and self.select:
                    # Don't use the first model's field because other
                    # combinated queries might define it differently.
                    yield OrderBy(F(col), descending=descending), False
                else:
                    # 'col' is of the form 'field' or 'field1__field2' or
                    # '-field1__field2__field', etc.
                    yield from self.find_ordering_name(
                        field,
                        self.query.get_meta(),
                        default_order=default_order,
                    )
443
    def get_order_by(self):
        """
        Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
        the ORDER BY clause.

        The order_by clause can alter the select clause (for example it can add
        aliases to clauses that do not yet have one, or it can add totally new
        select clauses).
        """
        result = []
        seen = set()
        for expr, is_ref in self._order_by_pairs():
            resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
            if not is_ref and self.query.combinator and self.select:
                # For combined queries the ORDER BY term must reference a
                # column of the combined result set, not of a single part.
                src = resolved.expression
                expr_src = expr.expression
                for sel_expr, _, col_alias in self.select:
                    if src == sel_expr:
                        # When values() is used the exact alias must be used to
                        # reference annotations.
                        if (
                            self.query.has_select_fields
                            and col_alias in self.query.annotation_select
                            and not (
                                isinstance(expr_src, F) and col_alias == expr_src.name
                            )
                        ):
                            continue
                        resolved.set_source_expressions(
                            [Ref(col_alias if col_alias else src.target.column, src)]
                        )
                        break
                else:
                    # Add column used in ORDER BY clause to the selected
                    # columns and to each combined query.
                    order_by_idx = len(self.query.select) + 1
                    col_alias = f"__orderbycol{order_by_idx}"
                    for q in self.query.combined_queries:
                        # If fields were explicitly selected through values()
                        # combined queries cannot be augmented.
                        if q.has_select_fields:
                            raise DatabaseError(
                                "ORDER BY term does not match any column in "
                                "the result set."
                            )
                        q.add_annotation(expr_src, col_alias)
                    self.query.add_select_col(resolved, col_alias)
                    resolved.set_source_expressions([Ref(col_alias, src)])
            sql, params = self.compile(resolved)
            # Don't add the same column twice, but the order direction is
            # not taken into account so we strip it. When this entire method
            # is refactored into expressions, then we can check each part as we
            # generate it.
            without_ordering = self.ordering_parts.search(sql)[1]
            params_hash = make_hashable(params)
            if (without_ordering, params_hash) in seen:
                continue
            seen.add((without_ordering, params_hash))
            result.append((resolved, (sql, params, is_ref)))
        return result
504
505 def get_extra_select(self, order_by, select):
506 extra_select = []
507 if self.query.distinct and not self.query.distinct_fields:
508 select_sql = [t[1] for t in select]
509 for expr, (sql, params, is_ref) in order_by:
510 without_ordering = self.ordering_parts.search(sql)[1]
511 if not is_ref and (without_ordering, params) not in select_sql:
512 extra_select.append((expr, (without_ordering, params), None))
513 return extra_select
514
515 def quote_name_unless_alias(self, name):
516 """
517 A wrapper around connection.ops.quote_name that doesn't quote aliases
518 for table names. This avoids problems with some SQL dialects that treat
519 quoted strings specially (e.g. PostgreSQL).
520 """
521 if name in self.quote_cache:
522 return self.quote_cache[name]
523 if (
524 (name in self.query.alias_map and name not in self.query.table_map)
525 or name in self.query.extra_select
526 or (
527 self.query.external_aliases.get(name)
528 and name not in self.query.table_map
529 )
530 ):
531 self.quote_cache[name] = name
532 return name
533 r = self.connection.ops.quote_name(name)
534 self.quote_cache[name] = r
535 return r
536
537 def compile(self, node):
538 vendor_impl = getattr(node, "as_" + self.connection.vendor, None)
539 if vendor_impl:
540 sql, params = vendor_impl(self, self.connection)
541 else:
542 sql, params = node.as_sql(self, self.connection)
543 return sql, params
544
545 def get_combinator_sql(self, combinator, all):
546 features = self.connection.features
547 compilers = [
548 query.get_compiler(self.using, self.connection, self.elide_empty)
549 for query in self.query.combined_queries
550 ]
551 if not features.supports_slicing_ordering_in_compound:
552 for compiler in compilers:
553 if compiler.query.is_sliced:
554 raise DatabaseError(
555 "LIMIT/OFFSET not allowed in subqueries of compound statements."
556 )
557 if compiler.get_order_by():
558 raise DatabaseError(
559 "ORDER BY not allowed in subqueries of compound statements."
560 )
561 elif self.query.is_sliced and combinator == "union":
562 for compiler in compilers:
563 # A sliced union cannot have its parts elided as some of them
564 # might be sliced as well and in the event where only a single
565 # part produces a non-empty resultset it might be impossible to
566 # generate valid SQL.
567 compiler.elide_empty = False
568 parts = ()
569 for compiler in compilers:
570 try:
571 # If the columns list is limited, then all combined queries
572 # must have the same columns list. Set the selects defined on
573 # the query on all combined queries, if not already set.
574 if not compiler.query.values_select and self.query.values_select:
575 compiler.query = compiler.query.clone()
576 compiler.query.set_values(
577 (
578 *self.query.extra_select,
579 *self.query.values_select,
580 *self.query.annotation_select,
581 )
582 )
583 part_sql, part_args = compiler.as_sql(with_col_aliases=True)
584 if compiler.query.combinator:
585 # Wrap in a subquery if wrapping in parentheses isn't
586 # supported.
587 if not features.supports_parentheses_in_compound:
588 part_sql = f"SELECT * FROM ({part_sql})"
589 # Add parentheses when combining with compound query if not
590 # already added for all compound queries.
591 elif (
592 self.query.subquery
593 or not features.supports_slicing_ordering_in_compound
594 ):
595 part_sql = f"({part_sql})"
596 elif (
597 self.query.subquery
598 and features.supports_slicing_ordering_in_compound
599 ):
600 part_sql = f"({part_sql})"
601 parts += ((part_sql, part_args),)
602 except EmptyResultSet:
603 # Omit the empty queryset with UNION and with DIFFERENCE if the
604 # first queryset is nonempty.
605 if combinator == "union" or (combinator == "difference" and parts):
606 continue
607 raise
608 if not parts:
609 raise EmptyResultSet
610 combinator_sql = self.connection.ops.set_operators[combinator]
611 if all and combinator == "union":
612 combinator_sql += " ALL"
613 braces = "{}"
614 if not self.query.subquery and features.supports_slicing_ordering_in_compound:
615 braces = "({})"
616 sql_parts, args_parts = zip(
617 *((braces.format(sql), args) for sql, args in parts)
618 )
619 result = [f" {combinator_sql} ".join(sql_parts)]
620 params = []
621 for part in args_parts:
622 params.extend(part)
623 return result, params
624
    def get_qualify_sql(self):
        """
        Build SQL that applies window-function predicates (QUALIFY-style
        filtering) by wrapping the query in a subquery and filtering on its
        result in an outer WHERE clause.

        Return (result, params) where result is a list of SQL fragments.
        """
        # WHERE and HAVING move into the inner query; only the qualify
        # predicate remains on the outer one.
        where_parts = []
        if self.where:
            where_parts.append(self.where)
        if self.having:
            where_parts.append(self.having)
        inner_query = self.query.clone()
        inner_query.subquery = True
        inner_query.where = inner_query.where.__class__(where_parts)
        # Augment the inner query with any window function references that
        # might have been masked via values() and alias(). If any masked
        # aliases are added they'll be masked again to avoid fetching
        # the data in the `if qual_aliases` branch below.
        select = {
            expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
        }
        select_aliases = set(select.values())
        qual_aliases = set()
        replacements = {}

        def collect_replacements(expressions):
            # Map each referenced expression to the alias it is (or becomes)
            # selected under, adding masked expressions as new annotations.
            while expressions:
                expr = expressions.pop()
                if expr in replacements:
                    continue
                elif select_alias := select.get(expr):
                    replacements[expr] = select_alias
                elif isinstance(expr, Lookup):
                    expressions.extend(expr.get_source_expressions())
                elif isinstance(expr, Ref):
                    if expr.refs not in select_aliases:
                        expressions.extend(expr.get_source_expressions())
                else:
                    num_qual_alias = len(qual_aliases)
                    select_alias = f"qual{num_qual_alias}"
                    qual_aliases.add(select_alias)
                    inner_query.add_annotation(expr, select_alias)
                    replacements[expr] = select_alias

        collect_replacements(list(self.qualify.leaves()))
        self.qualify = self.qualify.replace_expressions(
            {expr: Ref(alias, expr) for expr, alias in replacements.items()}
        )
        order_by = []
        for order_by_expr, *_ in self.get_order_by():
            collect_replacements(order_by_expr.get_source_expressions())
            order_by.append(
                order_by_expr.replace_expressions(
                    {expr: Ref(alias, expr) for expr, alias in replacements.items()}
                )
            )
        inner_query_compiler = inner_query.get_compiler(
            self.using, connection=self.connection, elide_empty=self.elide_empty
        )
        inner_sql, inner_params = inner_query_compiler.as_sql(
            # The limits must be applied to the outer query to avoid pruning
            # results too eagerly.
            with_limits=False,
            # Force unique aliasing of selected columns to avoid collisions
            # and make rhs predicates referencing easier.
            with_col_aliases=True,
        )
        qualify_sql, qualify_params = self.compile(self.qualify)
        result = [
            "SELECT * FROM (",
            inner_sql,
            ")",
            self.connection.ops.quote_name("qualify"),
            "WHERE",
            qualify_sql,
        ]
        if qual_aliases:
            # If some select aliases were unmasked for filtering purposes they
            # must be masked back.
            cols = [self.connection.ops.quote_name(alias) for alias in select.values()]
            result = [
                "SELECT",
                ", ".join(cols),
                "FROM (",
                *result,
                ")",
                self.connection.ops.quote_name("qualify_mask"),
            ]
        params = list(inner_params) + qualify_params
        # As the SQL spec is unclear on whether or not derived tables
        # ordering must propagate it has to be explicitly repeated on the
        # outer-most query to ensure it's preserved.
        if order_by:
            ordering_sqls = []
            for ordering in order_by:
                ordering_sql, ordering_params = self.compile(ordering)
                ordering_sqls.append(ordering_sql)
                params.extend(ordering_params)
            result.extend(["ORDER BY", ", ".join(ordering_sqls)])
        return result, params
720
    def as_sql(self, with_limits=True, with_col_aliases=False):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.

        If 'with_limits' is False, any limit/offset information is not included
        in the query. If 'with_col_aliases' is True, every output column gets
        an explicit alias.
        """
        refcounts_before = self.query.alias_refcount.copy()
        try:
            combinator = self.query.combinator
            extra_select, order_by, group_by = self.pre_sql_setup(
                with_col_aliases=with_col_aliases or bool(combinator),
            )
            for_update_part = None
            # Is a LIMIT/OFFSET clause needed?
            with_limit_offset = with_limits and self.query.is_sliced
            # NOTE(review): combinator was already read above; this second
            # assignment is redundant but harmless.
            combinator = self.query.combinator
            features = self.connection.features
            if combinator:
                if not getattr(features, f"supports_select_{combinator}"):
                    raise NotSupportedError(
                        f"{combinator} is not supported on this database backend."
                    )
                result, params = self.get_combinator_sql(
                    combinator, self.query.combinator_all
                )
            elif self.qualify:
                # Window-function filtering wraps the query in a subquery;
                # get_qualify_sql() emits the ordering itself, so drop it here.
                result, params = self.get_qualify_sql()
                order_by = None
            else:
                distinct_fields, distinct_params = self.get_distinct()
                # This must come after 'select', 'ordering', and 'distinct'
                # (see docstring of get_from_clause() for details).
                from_, f_params = self.get_from_clause()
                try:
                    where, w_params = (
                        self.compile(self.where) if self.where is not None else ("", [])
                    )
                except EmptyResultSet:
                    if self.elide_empty:
                        raise
                    # Use a predicate that's always False.
                    where, w_params = "0 = 1", []
                except FullResultSet:
                    where, w_params = "", []
                try:
                    having, h_params = (
                        self.compile(self.having)
                        if self.having is not None
                        else ("", [])
                    )
                except FullResultSet:
                    having, h_params = "", []
                result = ["SELECT"]
                params = []

                if self.query.distinct:
                    distinct_result, distinct_params = self.connection.ops.distinct_sql(
                        distinct_fields,
                        distinct_params,
                    )
                    result += distinct_result
                    params += distinct_params

                out_cols = []
                for _, (s_sql, s_params), alias in self.select + extra_select:
                    if alias:
                        s_sql = f"{s_sql} AS {self.connection.ops.quote_name(alias)}"
                    params.extend(s_params)
                    out_cols.append(s_sql)

                result += [", ".join(out_cols)]
                if from_:
                    result += ["FROM", *from_]
                elif self.connection.features.bare_select_suffix:
                    result += [self.connection.features.bare_select_suffix]
                params.extend(f_params)

                if self.query.select_for_update and features.has_select_for_update:
                    if (
                        self.connection.get_autocommit()
                        # Don't raise an exception when database doesn't
                        # support transactions, as it's a noop.
                        and features.supports_transactions
                    ):
                        raise TransactionManagementError(
                            "select_for_update cannot be used outside of a transaction."
                        )

                    if (
                        with_limit_offset
                        and not features.supports_select_for_update_with_limit
                    ):
                        raise NotSupportedError(
                            "LIMIT/OFFSET is not supported with "
                            "select_for_update on this database backend."
                        )
                    nowait = self.query.select_for_update_nowait
                    skip_locked = self.query.select_for_update_skip_locked
                    of = self.query.select_for_update_of
                    no_key = self.query.select_for_no_key_update
                    # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the
                    # backend doesn't support it, raise NotSupportedError to
                    # prevent a possible deadlock.
                    if nowait and not features.has_select_for_update_nowait:
                        raise NotSupportedError(
                            "NOWAIT is not supported on this database backend."
                        )
                    elif skip_locked and not features.has_select_for_update_skip_locked:
                        raise NotSupportedError(
                            "SKIP LOCKED is not supported on this database backend."
                        )
                    elif of and not features.has_select_for_update_of:
                        raise NotSupportedError(
                            "FOR UPDATE OF is not supported on this database backend."
                        )
                    elif no_key and not features.has_select_for_no_key_update:
                        raise NotSupportedError(
                            "FOR NO KEY UPDATE is not supported on this "
                            "database backend."
                        )
                    for_update_part = self.connection.ops.for_update_sql(
                        nowait=nowait,
                        skip_locked=skip_locked,
                        of=self.get_select_for_update_of_arguments(),
                        no_key=no_key,
                    )

                if for_update_part and features.for_update_after_from:
                    result.append(for_update_part)

                if where:
                    result.append("WHERE %s" % where)
                    params.extend(w_params)

                grouping = []
                for g_sql, g_params in group_by:
                    grouping.append(g_sql)
                    params.extend(g_params)
                if grouping:
                    if distinct_fields:
                        raise NotImplementedError(
                            "annotate() + distinct(fields) is not implemented."
                        )
                    order_by = order_by or self.connection.ops.force_no_ordering()
                    result.append("GROUP BY %s" % ", ".join(grouping))
                    # Meta.ordering must not leak into grouped queries.
                    if self._meta_ordering:
                        order_by = None
                if having:
                    result.append("HAVING %s" % having)
                    params.extend(h_params)

            if self.query.explain_info:
                result.insert(
                    0,
                    self.connection.ops.explain_query_prefix(
                        self.query.explain_info.format,
                        **self.query.explain_info.options,
                    ),
                )

            if order_by:
                ordering = []
                for _, (o_sql, o_params, _) in order_by:
                    ordering.append(o_sql)
                    params.extend(o_params)
                order_by_sql = "ORDER BY %s" % ", ".join(ordering)
                if combinator and features.requires_compound_order_by_subquery:
                    result = ["SELECT * FROM (", *result, ")", order_by_sql]
                else:
                    result.append(order_by_sql)

            if with_limit_offset:
                result.append(
                    self.connection.ops.limit_offset_sql(
                        self.query.low_mark, self.query.high_mark
                    )
                )

            if for_update_part and not features.for_update_after_from:
                result.append(for_update_part)

            if self.query.subquery and extra_select:
                # If the query is used as a subquery, the extra selects would
                # result in more columns than the left-hand side expression is
                # expecting. This can happen when a subquery uses a combination
                # of order_by() and distinct(), forcing the ordering expressions
                # to be selected as well. Wrap the query in another subquery
                # to exclude extraneous selects.
                sub_selects = []
                sub_params = []
                for index, (select, _, alias) in enumerate(self.select, start=1):
                    if alias:
                        sub_selects.append(
                            "{}.{}".format(
                                self.connection.ops.quote_name("subquery"),
                                self.connection.ops.quote_name(alias),
                            )
                        )
                    else:
                        select_clone = select.relabeled_clone(
                            {select.alias: "subquery"}
                        )
                        subselect, subparams = select_clone.as_sql(
                            self, self.connection
                        )
                        sub_selects.append(subselect)
                        sub_params.extend(subparams)
                return "SELECT {} FROM ({}) subquery".format(
                    ", ".join(sub_selects),
                    " ".join(result),
                ), tuple(sub_params + params)

            return " ".join(result), tuple(params)
        finally:
            # Finally do cleanup - get rid of the joins we created above.
            self.query.reset_refcounts(refcounts_before)
939
    def get_default_columns(
        self, select_mask, start_alias=None, opts=None, from_parent=None
    ):
        """
        Compute the default columns for selecting every field in the base
        model. Will sometimes be called to pull in related models (e.g. via
        select_related), in which case "opts" and "start_alias" will be given
        to provide a starting point for the traversal.

        Return a list of column expressions (one Col per concrete field,
        built via field.get_col()) ready to be added to the SELECT clause.
        Fields not present in a non-empty ``select_mask`` are skipped, as
        are fields that belong to a parent model already covered by
        ``from_parent`` (to avoid re-selecting parent data).
        """
        result = []
        if opts is None:
            # No metadata supplied: derive it from the query; bail out if the
            # query has no model attached.
            if (opts := self.query.get_meta()) is None:
                return result
        start_alias = start_alias or self.query.get_initial_alias()
        # The 'seen_models' is used to optimize checking the needed parent
        # alias for a given field. This also includes None -> start_alias to
        # be used by local fields.
        seen_models = {None: start_alias}

        for field in opts.concrete_fields:
            model = field.model._meta.concrete_model
            # A proxy model will have a different model and concrete_model. We
            # will assign None if the field belongs to this model.
            if model == opts.model:
                model = None
            if (
                from_parent
                and model is not None
                and issubclass(
                    from_parent._meta.concrete_model, model._meta.concrete_model
                )
            ):
                # Avoid loading data for already loaded parents.
                # We end up here in the case select_related() resolution
                # proceeds from parent model to child model. In that case the
                # parent model data is already present in the SELECT clause,
                # and we want to avoid reloading the same data again.
                continue
            if select_mask and field not in select_mask:
                continue
            # Resolve which table alias holds this field (may join up to a
            # parent model's table) and build the column expression for it.
            alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
            column = field.get_col(alias)
            result.append(column)
        return result
989
    def get_distinct(self):
        """
        Return a quoted list of fields to use in DISTINCT ON part of the query.

        This method can alter the tables in the query, and thus it must be
        called before get_from_clause().

        Return a (result, params) pair. NOTE(review): ``params`` is a list of
        per-column parameter sequences (appended, not extended) — presumably
        the backend's distinct_sql() flattens it; confirm against callers
        before changing the nesting.
        """
        result = []
        params = []
        opts = self.query.get_meta()

        for name in self.query.distinct_fields:
            parts = name.split(LOOKUP_SEP)
            _, targets, alias, joins, path, _, transform_function = self._setup_joins(
                parts, opts, None
            )
            # Drop joins that aren't needed so DISTINCT ON columns match the
            # ones produced for ORDER BY on the same input.
            targets, alias, _ = self.query.trim_joins(targets, joins, path)
            for target in targets:
                if name in self.query.annotation_select:
                    # Annotations are referenced by their (quoted) alias.
                    result.append(self.connection.ops.quote_name(name))
                else:
                    r, p = self.compile(transform_function(target, alias))
                    result.append(r)
                    params.append(p)
        return result, params
1015
    def find_ordering_name(
        self, name, opts, alias=None, default_order="ASC", already_seen=None
    ):
        """
        Return the table alias (the name might be ambiguous, the alias will
        not be) and column name for ordering by the given 'name' parameter.
        The 'name' is of the form 'field1__field2__...__fieldN'.

        Return a list of (OrderBy expression, is_ref) pairs. When ``name``
        points at a relation whose model declares a default ordering, that
        ordering is expanded recursively; ``already_seen`` accumulates the
        join paths visited so that ordering cycles raise FieldError instead
        of recursing forever.
        """
        name, order = get_order_dir(name, default_order)
        descending = order == "DESC"
        pieces = name.split(LOOKUP_SEP)
        (
            field,
            targets,
            alias,
            joins,
            path,
            opts,
            transform_function,
        ) = self._setup_joins(pieces, opts, alias)

        # If we get to this point and the field is a relation to another model,
        # append the default ordering for that model unless it is the pk
        # shortcut or the attribute name of the field that is specified or
        # there are transforms to process.
        if (
            field.is_relation
            and opts.ordering
            and getattr(field, "attname", None) != pieces[-1]
            and name != "pk"
            and not getattr(transform_function, "has_transforms", False)
        ):
            # Firstly, avoid infinite loops.
            already_seen = already_seen or set()
            join_tuple = tuple(
                getattr(self.query.alias_map[j], "join_cols", None) for j in joins
            )
            if join_tuple in already_seen:
                raise FieldError("Infinite loop caused by ordering.")
            already_seen.add(join_tuple)

            results = []
            for item in opts.ordering:
                if hasattr(item, "resolve_expression") and not isinstance(
                    item, OrderBy
                ):
                    # Wrap bare expressions in OrderBy, inheriting this
                    # query's direction.
                    item = item.desc() if descending else item.asc()
                if isinstance(item, OrderBy):
                    # Re-root the expression's references onto this lookup path.
                    results.append(
                        (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                    )
                    continue
                results.extend(
                    (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                    for expr, is_ref in self.find_ordering_name(
                        item, opts, alias, order, already_seen
                    )
                )
            return results
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [
            (OrderBy(transform_function(t, alias), descending=descending), False)
            for t in targets
        ]
1080
1081 def _setup_joins(self, pieces, opts, alias):
1082 """
1083 Helper method for get_order_by() and get_distinct().
1084
1085 get_ordering() and get_distinct() must produce same target columns on
1086 same input, as the prefixes of get_ordering() and get_distinct() must
1087 match. Executing SQL where this is not true is an error.
1088 """
1089 alias = alias or self.query.get_initial_alias()
1090 field, targets, opts, joins, path, transform_function = self.query.setup_joins(
1091 pieces, opts, alias
1092 )
1093 alias = joins[-1]
1094 return field, targets, alias, joins, path, opts, transform_function
1095
1096 def get_from_clause(self):
1097 """
1098 Return a list of strings that are joined together to go after the
1099 "FROM" part of the query, as well as a list any extra parameters that
1100 need to be included. Subclasses, can override this to create a
1101 from-clause via a "select".
1102
1103 This should only be called after any SQL construction methods that
1104 might change the tables that are needed. This means the select columns,
1105 ordering, and distinct must be done first.
1106 """
1107 result = []
1108 params = []
1109 for alias in tuple(self.query.alias_map):
1110 if not self.query.alias_refcount[alias]:
1111 continue
1112 try:
1113 from_clause = self.query.alias_map[alias]
1114 except KeyError:
1115 # Extra tables can end up in self.tables, but not in the
1116 # alias_map if they aren't in a join. That's OK. We skip them.
1117 continue
1118 clause_sql, clause_params = self.compile(from_clause)
1119 result.append(clause_sql)
1120 params.extend(clause_params)
1121 for t in self.query.extra_tables:
1122 alias, _ = self.query.table_alias(t)
1123 # Only add the alias if it's not already present (the table_alias()
1124 # call increments the refcount, so an alias refcount of one means
1125 # this is the only reference).
1126 if (
1127 alias not in self.query.alias_map
1128 or self.query.alias_refcount[alias] == 1
1129 ):
1130 result.append(", %s" % self.quote_name_unless_alias(alias))
1131 return result, params
1132
    def get_related_selections(
        self,
        select,
        select_mask,
        opts=None,
        root_alias=None,
        cur_depth=1,
        requested=None,
        restricted=None,
    ):
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).

        Appends (column, None) pairs to ``select`` in place and returns a list
        of klass_info dicts describing how to instantiate the related objects
        (model, field, reverse flag, setters, and select_fields indexes).
        """

        def _get_field_choices():
            # All valid select_related() targets, for error messages.
            direct_choices = (f.name for f in opts.fields if f.is_relation)
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.related_objects
                if f.field.unique
            )
            return chain(
                direct_choices, reverse_choices, self.query._filtered_relations
            )

        related_klass_infos = []
        if not restricted and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return related_klass_infos

        if not opts:
            # Top-level call: start from the query's own model and base alias.
            opts = self.query.get_meta()
            root_alias = self.query.get_initial_alias()

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            restricted = isinstance(self.query.select_related, dict)
            if restricted:
                requested = self.query.select_related

        def get_related_klass_infos(klass_info, related_klass_infos):
            klass_info["related_klass_infos"] = related_klass_infos

        # Forward relations declared on this model.
        for f in opts.fields:
            fields_found.add(f.name)

            if restricted:
                next = requested.get(f.name, {})
                if not f.is_relation:
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or f.name in requested:
                        raise FieldError(
                            "Non-relational field given in select_related: '{}'. "
                            "Choices are: {}".format(
                                f.name,
                                ", ".join(_get_field_choices()) or "(none)",
                            )
                        )
            else:
                next = False

            if not select_related_descend(f, restricted, requested, select_mask):
                continue
            related_select_mask = select_mask.get(f) or {}
            klass_info = {
                "model": f.remote_field.model,
                "field": f,
                "reverse": False,
                "local_setter": f.set_cached_value,
                "remote_setter": f.remote_field.set_cached_value
                if f.unique
                else lambda x, y: None,
                "from_parent": False,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
            alias = joins[-1]
            columns = self.get_default_columns(
                related_select_mask, start_alias=alias, opts=f.remote_field.model._meta
            )
            # Record the indexes the related columns occupy in `select`.
            for col in columns:
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                f.remote_field.model._meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        if restricted:
            # Reverse one-to-one relations are only followed when explicitly
            # requested via a select_related(...) dict.
            related_fields = [
                (o.field, o.related_model)
                for o in opts.related_objects
                if o.field.unique and not o.many_to_many
            ]
            for related_field, model in related_fields:
                related_select_mask = select_mask.get(related_field) or {}
                if not select_related_descend(
                    related_field,
                    restricted,
                    requested,
                    related_select_mask,
                    reverse=True,
                ):
                    continue

                related_field_name = related_field.related_query_name()
                fields_found.add(related_field_name)

                join_info = self.query.setup_joins(
                    [related_field_name], opts, root_alias
                )
                alias = join_info.joins[-1]
                from_parent = issubclass(model, opts.model) and model is not opts.model
                klass_info = {
                    "model": model,
                    "field": related_field,
                    "reverse": True,
                    "local_setter": related_field.remote_field.set_cached_value,
                    "remote_setter": related_field.set_cached_value,
                    "from_parent": from_parent,
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                columns = self.get_default_columns(
                    related_select_mask,
                    start_alias=alias,
                    opts=model._meta,
                    from_parent=opts.model,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                next = requested.get(related_field.related_query_name(), {})
                next_klass_infos = self.get_related_selections(
                    select,
                    related_select_mask,
                    model._meta,
                    alias,
                    cur_depth + 1,
                    next,
                    restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)

            def local_setter(final_field, obj, from_obj):
                # Set a reverse fk object when relation is non-empty.
                if from_obj:
                    final_field.remote_field.set_cached_value(from_obj, obj)

            def local_setter_noop(obj, from_obj):
                pass

            def remote_setter(name, obj, from_obj):
                setattr(from_obj, name, obj)

            for name in list(requested):
                # Filtered relations work only on the topmost level.
                if cur_depth > 1:
                    break
                if name in self.query._filtered_relations:
                    fields_found.add(name)
                    final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                        [name], opts, root_alias
                    )
                    model = join_opts.model
                    alias = joins[-1]
                    from_parent = (
                        issubclass(model, opts.model) and model is not opts.model
                    )
                    klass_info = {
                        "model": model,
                        "field": final_field,
                        "reverse": True,
                        "local_setter": (
                            partial(local_setter, final_field)
                            if len(joins) <= 2
                            else local_setter_noop
                        ),
                        "remote_setter": partial(remote_setter, name),
                        "from_parent": from_parent,
                    }
                    related_klass_infos.append(klass_info)
                    select_fields = []
                    field_select_mask = select_mask.get((name, final_field)) or {}
                    columns = self.get_default_columns(
                        field_select_mask,
                        start_alias=alias,
                        opts=model._meta,
                        from_parent=opts.model,
                    )
                    for col in columns:
                        select_fields.append(len(select))
                        select.append((col, None))
                    klass_info["select_fields"] = select_fields
                    next_requested = requested.get(name, {})
                    next_klass_infos = self.get_related_selections(
                        select,
                        field_select_mask,
                        opts=model._meta,
                        root_alias=alias,
                        cur_depth=cur_depth + 1,
                        requested=next_requested,
                        restricted=restricted,
                    )
                    get_related_klass_infos(klass_info, next_klass_infos)
            # Anything requested but never matched is a user error.
            fields_not_found = set(requested).difference(fields_found)
            if fields_not_found:
                invalid_fields = ("'%s'" % s for s in fields_not_found)
                raise FieldError(
                    "Invalid field name(s) given in select_related: {}. "
                    "Choices are: {}".format(
                        ", ".join(invalid_fields),
                        ", ".join(_get_field_choices()) or "(none)",
                    )
                )
        return related_klass_infos
1364
    def get_select_for_update_of_arguments(self):
        """
        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
        the query.

        Each name in select_for_update(of=...) is resolved against the
        klass_info tree built by select_related; the first selected column of
        the resolved model determines which table (or column, on backends
        that support it) gets locked.
        """

        def _get_parent_klass_info(klass_info):
            # Synthesize klass_info dicts for a model's parents so lock
            # targets can traverse multi-table inheritance links.
            concrete_model = klass_info["model"]._meta.concrete_model
            for parent_model, parent_link in concrete_model._meta.parents.items():
                parent_list = parent_model._meta.get_parent_list()
                yield {
                    "model": parent_model,
                    "field": parent_link,
                    "reverse": False,
                    "select_fields": [
                        select_index
                        for select_index in klass_info["select_fields"]
                        # Selected columns from a model or its parents.
                        if (
                            self.select[select_index][0].target.model == parent_model
                            or self.select[select_index][0].target.model in parent_list
                        )
                    ],
                }

        def _get_first_selected_col_from_model(klass_info):
            """
            Find the first selected column from a model. If it doesn't exist,
            don't lock a model.

            select_fields is filled recursively, so it also contains fields
            from the parent models.
            """
            concrete_model = klass_info["model"]._meta.concrete_model
            for select_index in klass_info["select_fields"]:
                if self.select[select_index][0].target.model == concrete_model:
                    return self.select[select_index][0]

        def _get_field_choices():
            """Yield all allowed field paths in breadth-first search order."""
            queue = collections.deque([(None, self.klass_info)])
            while queue:
                parent_path, klass_info = queue.popleft()
                if parent_path is None:
                    path = []
                    yield "self"
                else:
                    field = klass_info["field"]
                    if klass_info["reverse"]:
                        field = field.remote_field
                    path = parent_path + [field.name]
                    yield LOOKUP_SEP.join(path)
                queue.extend(
                    (path, klass_info)
                    for klass_info in _get_parent_klass_info(klass_info)
                )
                queue.extend(
                    (path, klass_info)
                    for klass_info in klass_info.get("related_klass_infos", [])
                )

        if not self.klass_info:
            return []
        result = []
        invalid_names = []
        for name in self.query.select_for_update_of:
            klass_info = self.klass_info
            if name == "self":
                col = _get_first_selected_col_from_model(klass_info)
            else:
                # Walk each lookup part through related and parent klass
                # infos; a part with no match invalidates the whole name.
                for part in name.split(LOOKUP_SEP):
                    klass_infos = (
                        *klass_info.get("related_klass_infos", []),
                        *_get_parent_klass_info(klass_info),
                    )
                    for related_klass_info in klass_infos:
                        field = related_klass_info["field"]
                        if related_klass_info["reverse"]:
                            field = field.remote_field
                        if field.name == part:
                            klass_info = related_klass_info
                            break
                    else:
                        klass_info = None
                        break
                if klass_info is None:
                    invalid_names.append(name)
                    continue
                col = _get_first_selected_col_from_model(klass_info)
            if col is not None:
                if self.connection.features.select_for_update_of_column:
                    result.append(self.compile(col)[0])
                else:
                    result.append(self.quote_name_unless_alias(col.alias))
        if invalid_names:
            raise FieldError(
                "Invalid field name(s) given in select_for_update(of=(...)): {}. "
                "Only relational fields followed in the query are allowed. "
                "Choices are: {}.".format(
                    ", ".join(invalid_names),
                    ", ".join(_get_field_choices()),
                )
            )
        return result
1469
1470 def get_converters(self, expressions):
1471 converters = {}
1472 for i, expression in enumerate(expressions):
1473 if expression:
1474 backend_converters = self.connection.ops.get_db_converters(expression)
1475 field_converters = expression.get_db_converters(self.connection)
1476 if backend_converters or field_converters:
1477 converters[i] = (backend_converters + field_converters, expression)
1478 return converters
1479
1480 def apply_converters(self, rows, converters):
1481 connection = self.connection
1482 converters = list(converters.items())
1483 for row in map(list, rows):
1484 for pos, (convs, expression) in converters:
1485 value = row[pos]
1486 for converter in convs:
1487 value = converter(value, expression, connection)
1488 row[pos] = value
1489 yield row
1490
1491 def results_iter(
1492 self,
1493 results=None,
1494 tuple_expected=False,
1495 chunked_fetch=False,
1496 chunk_size=GET_ITERATOR_CHUNK_SIZE,
1497 ):
1498 """Return an iterator over the results from executing this query."""
1499 if results is None:
1500 results = self.execute_sql(
1501 MULTI, chunked_fetch=chunked_fetch, chunk_size=chunk_size
1502 )
1503 fields = [s[0] for s in self.select[0 : self.col_count]]
1504 converters = self.get_converters(fields)
1505 rows = chain.from_iterable(results)
1506 if converters:
1507 rows = self.apply_converters(rows, converters)
1508 if tuple_expected:
1509 rows = map(tuple, rows)
1510 return rows
1511
1512 def has_results(self):
1513 """
1514 Backends (e.g. NoSQL) can override this in order to use optimized
1515 versions of "query has any results."
1516 """
1517 return bool(self.execute_sql(SINGLE))
1518
    def execute_sql(
        self, result_type=MULTI, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
    ):
        """
        Run the query against the database and return the result(s). The
        return value is a single data item if result_type is SINGLE, or an
        iterator over the results if the result_type is MULTI.

        result_type is either MULTI (use fetchmany() to retrieve all rows),
        SINGLE (only retrieve a single row), or None. In this last case, the
        cursor is returned if any query is executed, since it's used by
        subclasses such as InsertQuery). It's possible, however, that no query
        is needed, as the filters describe an empty set. In that case, None is
        returned, to avoid any unnecessary database interaction.
        """
        result_type = result_type or NO_RESULTS
        try:
            sql, params = self.as_sql()
            if not sql:
                raise EmptyResultSet
        except EmptyResultSet:
            # The filters exclude every row: skip the database round trip.
            if result_type == MULTI:
                return iter([])
            else:
                return
        if chunked_fetch:
            cursor = self.connection.chunked_cursor()
        else:
            cursor = self.connection.cursor()
        try:
            cursor.execute(sql, params)
        except Exception:
            # Might fail for server-side cursors (e.g. connection closed)
            cursor.close()
            raise

        if result_type == CURSOR:
            # Give the caller the cursor to process and close.
            return cursor
        if result_type == SINGLE:
            try:
                val = cursor.fetchone()
                if val:
                    # Trim any extra trailing columns (e.g. ordering helpers).
                    return val[0 : self.col_count]
                return val
            finally:
                # done with the cursor
                cursor.close()
        if result_type == NO_RESULTS:
            cursor.close()
            return

        # MULTI: lazily fetch rows in chunks; cursor_iter closes the cursor
        # when exhausted.
        result = cursor_iter(
            cursor,
            self.connection.features.empty_fetchmany_value,
            self.col_count if self.has_extra_select else None,
            chunk_size,
        )
        if not chunked_fetch or not self.connection.features.can_use_chunked_reads:
            # If we are using non-chunked reads, we return the same data
            # structure as normally, but ensure it is all read into memory
            # before going any further. Use chunked_fetch if requested,
            # unless the database doesn't support it.
            return list(result)
        return result
1584
1585 def as_subquery_condition(self, alias, columns, compiler):
1586 qn = compiler.quote_name_unless_alias
1587 qn2 = self.connection.ops.quote_name
1588
1589 for index, select_col in enumerate(self.query.select):
1590 lhs_sql, lhs_params = self.compile(select_col)
1591 rhs = f"{qn(alias)}.{qn2(columns[index])}"
1592 self.query.where.add(RawSQL(f"{lhs_sql} = {rhs}", lhs_params), AND)
1593
1594 sql, params = self.as_sql()
1595 return "EXISTS (%s)" % sql, params
1596
1597 def explain_query(self):
1598 result = list(self.execute_sql())
1599 # Some backends return 1 item tuples with strings, and others return
1600 # tuples with integers and strings. Flatten them out into strings.
1601 format_ = self.query.explain_info.format
1602 output_formatter = json.dumps if format_ and format_.lower() == "json" else str
1603 for row in result[0]:
1604 if not isinstance(row, str):
1605 yield " ".join(output_formatter(c) for c in row)
1606 else:
1607 yield row
1608
1609
class SQLInsertCompiler(SQLCompiler):
    """
    Compile INSERT statements, including bulk inserts, ON CONFLICT handling,
    and RETURNING-style retrieval of inserted column values.
    """

    # Fields whose post-insert values the caller wants back (set by
    # execute_sql); () params accompany the backend's RETURNING clause.
    returning_fields = None
    returning_params = ()

    def field_as_sql(self, field, val):
        """
        Take a field and a value intended to be saved on that field, and
        return placeholder SQL and accompanying params. Check for raw values,
        expressions, and fields with get_placeholder() defined in that order.

        When field is None, consider the value raw and use it as the
        placeholder, with no corresponding parameters returned.
        """
        if field is None:
            # A field value of None means the value is raw.
            sql, params = val, []
        elif hasattr(val, "as_sql"):
            # This is an expression, let's compile it.
            sql, params = self.compile(val)
        elif hasattr(field, "get_placeholder"):
            # Some fields (e.g. geo fields) need special munging before
            # they can be inserted.
            sql, params = field.get_placeholder(val, self, self.connection), [val]
        else:
            # Return the common case for the placeholder
            sql, params = "%s", [val]

        # The following hook is only used by Oracle Spatial, which sometimes
        # needs to yield 'NULL' and [] as its placeholder and params instead
        # of '%s' and [None]. The 'NULL' placeholder is produced earlier by
        # OracleOperations.get_geom_placeholder(). The following line removes
        # the corresponding None parameter. See ticket #10888.
        params = self.connection.ops.modify_insert_params(sql, params)

        return sql, params

    def prepare_value(self, field, value):
        """
        Prepare a value to be used in a query by resolving it if it is an
        expression and otherwise calling the field's get_db_prep_save().

        Raises ValueError for column references (the row doesn't exist yet)
        and FieldError for aggregates or window expressions.
        """
        if hasattr(value, "resolve_expression"):
            value = value.resolve_expression(
                self.query, allow_joins=False, for_save=True
            )
            # Don't allow values containing Col expressions. They refer to
            # existing columns on a row, but in the case of insert the row
            # doesn't exist yet.
            if value.contains_column_references:
                raise ValueError(
                    'Failed to insert expression "{}" on {}. F() expressions '
                    "can only be used to update, not to insert.".format(value, field)
                )
            if value.contains_aggregate:
                raise FieldError(
                    "Aggregate functions are not allowed in this query "
                    f"({field.name}={value!r})."
                )
            if value.contains_over_clause:
                raise FieldError(
                    "Window expressions are not allowed in this query ({}={!r}).".format(
                        field.name, value
                    )
                )
        return field.get_db_prep_save(value, connection=self.connection)

    def pre_save_val(self, field, obj):
        """
        Get the given field's value off the given obj. pre_save() is used for
        things like auto_now on DateTimeField. Skip it if this is a raw query.
        """
        if self.query.raw:
            return getattr(obj, field.attname)
        return field.pre_save(obj, add=True)

    def assemble_as_sql(self, fields, value_rows):
        """
        Take a sequence of N fields and a sequence of M rows of values, and
        generate placeholder SQL and parameters for each field and value.
        Return a pair containing:
         * a sequence of M rows of N SQL placeholder strings, and
         * a sequence of M rows of corresponding parameter values.

        Each placeholder string may contain any number of '%s' interpolation
        strings, and each parameter row will contain exactly as many params
        as the total number of '%s's in the corresponding placeholder row.
        """
        if not value_rows:
            return [], []

        # list of (sql, [params]) tuples for each object to be saved
        # Shape: [n_objs][n_fields][2]
        rows_of_fields_as_sql = (
            (self.field_as_sql(field, v) for field, v in zip(fields, row))
            for row in value_rows
        )

        # tuple like ([sqls], [[params]s]) for each object to be saved
        # Shape: [n_objs][2][n_fields]
        sql_and_param_pair_rows = (zip(*row) for row in rows_of_fields_as_sql)

        # Extract separate lists for placeholders and params.
        # Each of these has shape [n_objs][n_fields]
        placeholder_rows, param_rows = zip(*sql_and_param_pair_rows)

        # Params for each field are still lists, and need to be flattened.
        param_rows = [[p for ps in row for p in ps] for row in param_rows]

        return placeholder_rows, param_rows

    def as_sql(self):
        """
        Build the INSERT statement(s). Return a list of (sql, params) pairs —
        one per statement (bulk inserts produce a single pair; otherwise one
        pair per object when the backend lacks bulk support).
        """
        # We don't need quote_name_unless_alias() here, since these are all
        # going to be column names (so we can avoid the extra overhead).
        qn = self.connection.ops.quote_name
        opts = self.query.get_meta()
        insert_statement = self.connection.ops.insert_statement(
            on_conflict=self.query.on_conflict,
        )
        result = [f"{insert_statement} {qn(opts.db_table)}"]
        fields = self.query.fields or [opts.pk]
        result.append("(%s)" % ", ".join(qn(f.column) for f in fields))

        if self.query.fields:
            value_rows = [
                [
                    self.prepare_value(field, self.pre_save_val(field, obj))
                    for field in fields
                ]
                for obj in self.query.objs
            ]
        else:
            # An empty object.
            value_rows = [
                [self.connection.ops.pk_default_value()] for _ in self.query.objs
            ]
            fields = [None]

        # Currently the backends just accept values when generating bulk
        # queries and generate their own placeholders. Doing that isn't
        # necessary and it should be possible to use placeholders and
        # expressions in bulk inserts too.
        can_bulk = (
            not self.returning_fields and self.connection.features.has_bulk_insert
        )

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql(
            fields,
            self.query.on_conflict,
            (f.column for f in self.query.update_fields),
            (f.column for f in self.query.unique_fields),
        )
        if (
            self.returning_fields
            and self.connection.features.can_return_columns_from_insert
        ):
            if self.connection.features.can_return_rows_from_bulk_insert:
                result.append(
                    self.connection.ops.bulk_insert_sql(fields, placeholder_rows)
                )
                params = param_rows
            else:
                # Single-row insert: only the first row can be returned.
                result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
                params = [param_rows[0]]
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            # Skip empty r_sql to allow subclasses to customize behavior for
            # 3rd party backends. Refs #19096.
            r_sql, self.returning_params = self.connection.ops.return_insert_columns(
                self.returning_fields
            )
            if r_sql:
                result.append(r_sql)
                params += [self.returning_params]
            return [(" ".join(result), tuple(chain.from_iterable(params)))]

        if can_bulk:
            result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
        else:
            if on_conflict_suffix_sql:
                result.append(on_conflict_suffix_sql)
            # One statement per object row.
            return [
                (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
                for p, vals in zip(placeholder_rows, param_rows)
            ]

    def execute_sql(self, returning_fields=None):
        """
        Execute the INSERT and return a list of rows of the requested
        returning-field values (one row per inserted object), or [] when
        no returning fields were requested.
        """
        # Multiple objects with returning fields require backend support
        # for returning rows from bulk inserts.
        assert not (
            returning_fields
            and len(self.query.objs) != 1
            and not self.connection.features.can_return_rows_from_bulk_insert
        )
        opts = self.query.get_meta()
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql():
                cursor.execute(sql, params)
            if not self.returning_fields:
                return []
            if (
                self.connection.features.can_return_rows_from_bulk_insert
                and len(self.query.objs) > 1
            ):
                rows = self.connection.ops.fetch_returned_insert_rows(cursor)
            elif self.connection.features.can_return_columns_from_insert:
                assert len(self.query.objs) == 1
                rows = [
                    self.connection.ops.fetch_returned_insert_columns(
                        cursor,
                        self.returning_params,
                    )
                ]
            else:
                # Backend can't return columns: fall back to last_insert_id
                # for the primary key only.
                rows = [
                    (
                        self.connection.ops.last_insert_id(
                            cursor,
                            opts.db_table,
                            opts.pk.column,
                        ),
                    )
                ]
        # Run any registered DB converters over the returned values.
        cols = [field.get_col(opts.db_table) for field in self.returning_fields]
        converters = self.get_converters(cols)
        if converters:
            rows = list(self.apply_converters(rows, converters))
        return rows
1841
1842
class SQLDeleteCompiler(SQLCompiler):
    """
    Compile DELETE statements. Deletes touching more than one table, or
    containing a subquery that references the deleted model itself, are
    rewritten as ``DELETE ... WHERE pk IN (subquery)``.
    """

    @cached_property
    def single_alias(self):
        # True when only the base table is referenced, so a plain
        # DELETE FROM <table> WHERE ... is possible.
        # Ensure base table is in aliases.
        self.query.get_initial_alias()
        return sum(self.query.alias_refcount[t] > 0 for t in self.query.alias_map) == 1

    @classmethod
    def _expr_refs_base_model(cls, expr, base_model):
        # Recursively check whether any (sub)expression is a Query against
        # base_model itself.
        if isinstance(expr, Query):
            return expr.model == base_model
        if not hasattr(expr, "get_source_expressions"):
            return False
        return any(
            cls._expr_refs_base_model(source_expr, base_model)
            for source_expr in expr.get_source_expressions()
        )

    @cached_property
    def contains_self_reference_subquery(self):
        # True when an annotation or filter embeds a subquery over the same
        # model being deleted from.
        return any(
            self._expr_refs_base_model(expr, self.query.model)
            for expr in chain(
                self.query.annotations.values(), self.query.where.children
            )
        )

    def _as_sql(self, query):
        # Render a simple single-table DELETE for the given query;
        # FullResultSet means an unconditioned delete (no WHERE clause).
        delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
        try:
            where, params = self.compile(query.where)
        except FullResultSet:
            return delete, ()
        return f"{delete} WHERE {where}", tuple(params)

    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        if self.single_alias and not self.contains_self_reference_subquery:
            return self._as_sql(self.query)
        # Fall back to DELETE ... WHERE pk IN (SELECT pk ...): clone the
        # query as the inner pk-select and wrap it in a fresh outer delete.
        innerq = self.query.clone()
        innerq.__class__ = Query
        innerq.clear_select_clause()
        pk = self.query.model._meta.pk
        innerq.select = [pk.get_col(self.query.get_initial_alias())]
        outerq = Query(self.query.model)
        if not self.connection.features.update_can_self_select:
            # Force the materialization of the inner query to allow reference
            # to the target table on MySQL.
            sql, params = innerq.get_compiler(connection=self.connection).as_sql()
            innerq = RawSQL("SELECT * FROM (%s) subquery" % sql, params)
        outerq.add_filter("pk__in", innerq)
        return self._as_sql(outerq)
1898
1899
class SQLUpdateCompiler(SQLCompiler):
    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        self.pre_sql_setup()
        if not self.query.values:
            return "", ()
        qn = self.quote_name_unless_alias
        assignments, update_params = [], []
        for field, model, val in self.query.values:
            if hasattr(val, "resolve_expression"):
                # Expressions are resolved against this query; joins,
                # aggregates, and window functions are forbidden in UPDATE.
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                # A model instance was assigned; only valid for relations.
                if field.remote_field:
                    val = val.prepare_database_save(field)
                else:
                    raise TypeError(
                        "Tried to update field {} with a model instance, {!r}. "
                        "Use a value compatible with {}.".format(
                            field, val, field.__class__.__name__
                        )
                    )
            val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, "get_placeholder"):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = "%s"
            column = field.column
            if hasattr(val, "as_sql"):
                sql, params = self.compile(val)
                assignments.append(f"{qn(column)} = {placeholder % sql}")
                update_params.extend(params)
            elif val is not None:
                assignments.append(f"{qn(column)} = {placeholder}")
                update_params.append(val)
            else:
                assignments.append("%s = NULL" % qn(column))
        table = self.query.base_table
        result = ["UPDATE %s SET" % qn(table), ", ".join(assignments)]
        try:
            where, params = self.compile(self.query.where)
        except FullResultSet:
            params = []
        else:
            result.append("WHERE %s" % where)
        return " ".join(result), tuple(update_params + params)

    def execute_sql(self, result_type):
        """
        Execute the specified update. Return the number of rows affected by
        the primary update query. The "primary update query" is the first
        non-empty query that is executed. Row counts for any subsequent,
        related queries are not available.
        """
        cursor = super().execute_sql(result_type)
        try:
            row_count = cursor.rowcount if cursor else 0
            is_empty = cursor is None
        finally:
            if cursor:
                cursor.close()
        # Run any related-table updates; only the first non-empty query's
        # row count is reported back to the caller.
        for related_query in self.query.get_related_updates():
            aux_rows = related_query.get_compiler(self.using).execute_sql(
                result_type
            )
            if is_empty and aux_rows:
                row_count = aux_rows
                is_empty = False
        return row_count

    def pre_sql_setup(self):
        """
        If the update depends on results from other tables, munge the "where"
        conditions to match the format required for (portable) SQL updates.

        If multiple updates are required, pull out the id values to update at
        this point so that they don't change as a result of the progressive
        updates.
        """
        refcounts_before = self.query.alias_refcount.copy()
        # Ensure base table is in the query
        self.query.get_initial_alias()
        table_count = self.query.count_active_tables()
        if not self.query.related_updates and table_count == 1:
            # Single-table update with no related updates: nothing to do.
            return
        # Build a plain SELECT of pks that match the update's conditions.
        select_query = self.query.chain(klass=Query)
        select_query.select_related = False
        select_query.clear_ordering(force=True)
        select_query.extra = {}
        select_query.select = []
        meta = select_query.get_meta()
        select_fields = [meta.pk.name]
        related_ids_index = []
        for related in self.query.related_updates:
            if all(
                path.join_field.primary_key for path in meta.get_path_to_parent(related)
            ):
                # If a primary key chain exists to the targeted related update,
                # then the meta.pk value can be used for it.
                related_ids_index.append((related, 0))
            else:
                # This branch will only be reached when updating a field of an
                # ancestor that is not part of the primary key chain of a MTI
                # tree.
                related_ids_index.append((related, len(select_fields)))
                select_fields.append(related._meta.pk.name)
        select_query.add_fields(select_fields)
        super().pre_sql_setup()

        must_pre_select = (
            table_count > 1 and not self.connection.features.update_can_self_select
        )

        # Now we adjust the current query: reset the where clause and get rid
        # of all the tables we don't need (since they're in the sub-select).
        self.query.clear_where()
        if self.query.related_updates or must_pre_select:
            # Either we're using the idents in multiple update queries (so
            # don't want them to change), or the db backend doesn't support
            # selecting from the updating table (e.g. MySQL).
            idents = []
            related_ids = collections.defaultdict(list)
            for rows in select_query.get_compiler(self.using).execute_sql(MULTI):
                idents.extend(r[0] for r in rows)
                for parent, index in related_ids_index:
                    related_ids[parent].extend(r[index] for r in rows)
            self.query.add_filter("pk__in", idents)
            self.query.related_ids = related_ids
        else:
            # The fast path. Filters and updates in one query.
            self.query.add_filter("pk__in", select_query)
        self.query.reset_refcounts(refcounts_before)
2049
2050
class SQLAggregateCompiler(SQLCompiler):
    def as_sql(self):
        """
        Create the SQL for this query. Return the SQL string and list of
        parameters.
        """
        # Compile every aggregate annotation into the outer SELECT list.
        select_parts = []
        select_params = []
        for annotation in self.query.annotation_select.values():
            ann_sql, ann_params = self.compile(annotation)
            ann_sql, ann_params = annotation.select_format(self, ann_sql, ann_params)
            select_parts.append(ann_sql)
            select_params.extend(ann_params)
        self.col_count = len(self.query.annotation_select)

        # Aggregate over the source query wrapped as a subquery.
        inner_sql, inner_params = self.query.inner_query.get_compiler(
            self.using,
            elide_empty=self.elide_empty,
        ).as_sql(with_col_aliases=True)
        select_list = ", ".join(select_parts)
        sql = f"SELECT {select_list} FROM ({inner_sql}) subquery"
        return sql, tuple(select_params) + tuple(inner_params)
2074
2075
def cursor_iter(cursor, sentinel, col_count, itersize):
    """
    Yield blocks of rows from a cursor and ensure the cursor is closed when
    done.
    """
    try:
        # Stop as soon as fetchmany() returns the backend's sentinel value.
        while (rows := cursor.fetchmany(itersize)) != sentinel:
            if col_count is None:
                yield rows
            else:
                # Trim trailing extra-select columns from each row.
                yield [row[:col_count] for row in rows]
    finally:
        cursor.close()