v0.146.0
   1from __future__ import annotations
   2
   3import collections
   4import json
   5import re
   6from collections.abc import Generator, Iterable, Sequence
   7from functools import cached_property, partial
   8from itertools import chain
   9from typing import TYPE_CHECKING, Any, Protocol, cast
  10
  11from plain.postgres.constants import LOOKUP_SEP
  12from plain.postgres.dialect import (
  13    PK_DEFAULT_VALUE,
  14    bulk_insert_sql,
  15    distinct_sql,
  16    explain_query_prefix,
  17    for_update_sql,
  18    limit_offset_sql,
  19    on_conflict_suffix_sql,
  20    quote_name,
  21    return_insert_columns,
  22)
  23from plain.postgres.exceptions import EmptyResultSet, FieldError, FullResultSet
  24from plain.postgres.expressions import (
  25    F,
  26    OrderBy,
  27    RawSQL,
  28    Ref,
  29    ResolvableExpression,
  30    Value,
  31)
  32from plain.postgres.fields import DATABASE_DEFAULT
  33from plain.postgres.fields.related import RelatedField
  34from plain.postgres.functions import Cast, Random
  35from plain.postgres.lookups import Lookup
  36from plain.postgres.meta import Meta
  37from plain.postgres.query_utils import select_related_descend
  38from plain.postgres.sql.constants import (
  39    CURSOR,
  40    MULTI,
  41    NO_RESULTS,
  42    ORDER_DIR,
  43    SINGLE,
  44)
  45from plain.postgres.sql.query import Query, get_order_dir
  46from plain.postgres.transaction import TransactionManagementError
  47from plain.utils.hashable import make_hashable
  48from plain.utils.regex_helper import _lazy_re_compile
  49
  50if TYPE_CHECKING:
  51    from plain.postgres.connection import DatabaseConnection
  52    from plain.postgres.expressions import BaseExpression
  53    from plain.postgres.sql.query import AggregateQuery, InsertQuery
  54
# Type aliases for SQL compilation results.
# SqlParams: the parameters that accompany a compiled SQL string, as a tuple.
SqlParams = tuple[Any, ...]
# SqlWithParams: a (sql, params) pair, the shape SQLCompiler.compile() returns.
SqlWithParams = tuple[str, SqlParams]
  58
  59
class SQLCompilable(Protocol):
    """Protocol for objects that can be compiled to SQL.

    Anything exposing a matching as_sql() satisfies this protocol and can be
    passed to SQLCompiler.compile().
    """

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, Sequence[Any]]:
        """Return SQL string and parameters for this object.

        SQLCompiler.compile() calls this with the compiler itself and the
        compiler's connection, and converts the returned parameters to a
        tuple.
        """
        ...
  68
  69
  70class PositionRef(Ref):
  71    def __init__(self, ordinal: int, refs: str, source: Any):
  72        self.ordinal = ordinal
  73        super().__init__(refs, source)
  74
  75    def as_sql(
  76        self, compiler: SQLCompiler, connection: DatabaseConnection
  77    ) -> tuple[str, list[Any]]:
  78        return str(self.ordinal), []
  79
  80
  81def get_converters(
  82    expressions: Iterable[Any], connection: DatabaseConnection
  83) -> dict[int, tuple[list[Any], Any]]:
  84    converters = {}
  85    for i, expression in enumerate(expressions):
  86        if expression:
  87            field_converters = expression.get_db_converters(connection)
  88            if field_converters:
  89                converters[i] = (field_converters, expression)
  90    return converters
  91
  92
  93def apply_converters(
  94    rows: Iterable, converters: dict, connection: DatabaseConnection
  95) -> Generator[list]:
  96    converters_list = list(converters.items())
  97    for row in map(list, rows):
  98        for pos, (convs, expression) in converters_list:
  99            value = row[pos]
 100            for converter in convs:
 101                value = converter(value, expression, connection)
 102            row[pos] = value
 103        yield row
 104
 105
 106class SQLCompiler:
    # Splits a compiled ORDER BY term at its ASC/DESC keyword: group 1 is the
    # SQL before the direction, which get_order_by() and get_extra_select()
    # use to compare terms regardless of direction.
    # Multiline ordering SQL clause may appear from RawSQL, hence the
    # MULTILINE | DOTALL flags.
    ordering_parts = _lazy_re_compile(
        r"^(.*)\s(?:ASC|DESC).*",
        re.MULTILINE | re.DOTALL,
    )
 112
 113    def __init__(
 114        self, query: Query, connection: DatabaseConnection, elide_empty: bool = True
 115    ):
 116        self.query = query
 117        self.connection = connection
 118        # Some queries, e.g. coalesced aggregation, need to be executed even if
 119        # they would return an empty result set.
 120        self.elide_empty = elide_empty
 121        self.quote_cache: dict[str, str] = {"*": "*"}
 122        # The select, klass_info, and annotations are needed by QuerySet.iterator()
 123        # these are set as a side-effect of executing the query. Note that we calculate
 124        # separately a list of extra select columns needed for grammatical correctness
 125        # of the query, but these columns are not included in self.select.
 126        self.select: list[tuple[Any, SqlWithParams, str | None]] | None = None
 127        self.annotation_col_map: dict[str, int] | None = None
 128        self.klass_info: dict[str, Any] | None = None
 129        self._meta_ordering: list[str] | None = None
 130
 131    def __repr__(self) -> str:
 132        model_name = self.query.model.__qualname__ if self.query.model else "None"
 133        return (
 134            f"<{self.__class__.__qualname__} "
 135            f"model={model_name} "
 136            f"connection={self.connection!r}>"
 137        )
 138
 139    def setup_query(self, with_col_aliases: bool = False) -> None:
 140        if all(self.query.alias_refcount[a] == 0 for a in self.query.alias_map):
 141            self.query.get_initial_alias()
 142        self.select, self.klass_info, self.annotation_col_map = self.get_select(
 143            with_col_aliases=with_col_aliases,
 144        )
 145        self.col_count = len(self.select)
 146
 147    def pre_sql_setup(
 148        self, with_col_aliases: bool = False
 149    ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
 150        """
 151        Do any necessary class setup immediately prior to producing SQL. This
 152        is for things that can't necessarily be done in __init__ because we
 153        might not have all the pieces in place at that time.
 154        """
 155        self.setup_query(with_col_aliases=with_col_aliases)
 156        assert self.select is not None  # Set by setup_query()
 157        order_by = self.get_order_by()
 158        self.where, self.having, self.qualify = self.query.where.split_having_qualify(
 159            must_group_by=self.query.group_by is not None
 160        )
 161        extra_select = self.get_extra_select(order_by, self.select)
 162        self.has_extra_select = bool(extra_select)
 163        group_by = self.get_group_by(self.select + extra_select, order_by)
 164        return extra_select, order_by, group_by
 165
    def get_group_by(
        self, select: list[Any], order_by: list[Any]
    ) -> list[SqlWithParams]:
        """
        Return the GROUP BY terms as a list of 2-tuples of form (sql, params).

        `select` holds (expression, (sql, params), alias) entries and
        `order_by` holds (expression, (sql, params, is_ref)) entries. The
        method also reads self.having, which pre_sql_setup() assigns before
        calling this method. An empty list is returned when
        query.group_by is None.

        The logic of what exactly the GROUP BY clause contains is hard
        to describe in other words than "if it passes the test suite,
        then it is correct".
        """
        # Some examples:
        #     SomeModel.query.annotate(Count('somecol'))
        #     GROUP BY: all fields of the model
        #
        #    SomeModel.query.values('name').annotate(Count('somecol'))
        #    GROUP BY: name
        #
        #    SomeModel.query.annotate(Count('somecol')).values('name')
        #    GROUP BY: all cols of the model
        #
        #    SomeModel.query.values('name', 'id')
        #    .annotate(Count('somecol')).values('id')
        #    GROUP BY: name, id
        #
        #    SomeModel.query.values('name').annotate(Count('somecol')).values('id')
        #    GROUP BY: name, id
        #
        # In fact, the self.query.group_by is the minimal set to GROUP BY. It
        # can't be ever restricted to a smaller set, but additional columns in
        # HAVING, ORDER BY, and SELECT clauses are added to it. Unfortunately
        # the end result is that it is impossible to force the query to have
        # a chosen GROUP BY clause - you can almost do this by using the form:
        #     .values(*wanted_cols).annotate(AnAggregate())
        # but any later annotations, extra selects, values calls that
        # refer to some column outside of the wanted_cols, order_by, or even
        # filter calls can alter the GROUP BY clause.

        # The query.group_by is either None (no GROUP BY at all), True
        # (group by select fields), or a list of expressions to be added
        # to the group by.
        if self.query.group_by is None:
            return []
        expressions = []
        # Aliases already grouped against explicitly through a Ref.
        group_by_refs = set()
        if self.query.group_by is not True:
            # If the group by is set to a list (by .values() call most likely),
            # then we need to add everything in it to the GROUP BY clause.
            # Backwards compatibility hack for setting query.group_by. Remove
            # when we have public API way of forcing the GROUP BY clause.
            # Converts string references to expressions.
            for expr in self.query.group_by:
                if not hasattr(expr, "as_sql"):
                    expr = self.query.resolve_ref(expr)
                if isinstance(expr, Ref):
                    # Group by the referenced source once per alias.
                    if expr.refs not in group_by_refs:
                        group_by_refs.add(expr.refs)
                        expressions.append(expr.source)
                else:
                    expressions.append(expr)
        # Note that even if the group_by is set, it is only the minimal
        # set to group by. So, we need to add cols in select, order_by, and
        # having into the GROUP BY in any case.
        # Maps aliased select expressions to their 1-based SELECT position.
        selected_expr_positions = {}
        for ordinal, (expr, _, alias) in enumerate(select, start=1):
            if alias:
                selected_expr_positions[expr] = ordinal
            # Skip members of the select clause that are already explicitly
            # grouped against.
            if alias in group_by_refs:
                continue
            expressions.extend(expr.get_group_by_cols())
        # ORDER BY columns are only added when the ordering did not come from
        # the model's default ordering (see _order_by_pairs()).
        if not self._meta_ordering:
            for expr, (sql, params, is_ref) in order_by:
                # Skip references to the SELECT clause, as all expressions in
                # the SELECT clause are already part of the GROUP BY.
                if not is_ref:
                    expressions.extend(expr.get_group_by_cols())
        having_group_by = self.having.get_group_by_cols() if self.having else []
        for expr in having_group_by:
            expressions.append(expr)
        result = []
        # (sql, hashable params) pairs already emitted, to drop duplicates.
        seen = set()
        expressions = self.collapse_group_by(expressions, having_group_by)

        for expr in expressions:
            try:
                sql, params = self.compile(expr)
            except (EmptyResultSet, FullResultSet):
                # Such expressions contribute no GROUP BY term.
                continue
            # Use select index for GROUP BY when possible
            if (position := selected_expr_positions.get(expr)) is not None:
                sql, params = str(position), ()
            else:
                sql, params = expr.select_format(self, sql, params)
            params_hash = make_hashable(params)
            if (sql, params_hash) not in seen:
                result.append((sql, params))
                seen.add((sql, params_hash))
        return result
 265
 266    def collapse_group_by(self, expressions: list[Any], having: list[Any]) -> list[Any]:
 267        # Use group by functional dependence reduction:
 268        # expressions can be reduced to the set of selected table
 269        # primary keys as all other columns are functionally dependent on them.
 270        # Filter out all expressions associated with a table's primary key
 271        # present in the grouped columns. This is done by identifying all
 272        # tables that have their primary key included in the grouped
 273        # columns and removing non-primary key columns referring to them.
 274        pks = {
 275            expr
 276            for expr in expressions
 277            if hasattr(expr, "target") and expr.target.primary_key
 278        }
 279        aliases = {expr.alias for expr in pks}
 280        return [
 281            expr
 282            for expr in expressions
 283            if expr in pks
 284            or expr in having
 285            or getattr(expr, "alias", None) not in aliases
 286        ]
 287
 288    def get_select(
 289        self, with_col_aliases: bool = False
 290    ) -> tuple[
 291        list[tuple[Any, SqlWithParams, str | None]],
 292        dict[str, Any] | None,
 293        dict[str, int],
 294    ]:
 295        """
 296        Return three values:
 297        - a list of 3-tuples of (expression, (sql, params), alias)
 298        - a klass_info structure,
 299        - a dictionary of annotations
 300
 301        The (sql, params) is what the expression will produce, and alias is the
 302        "AS alias" for the column (possibly None).
 303
 304        The klass_info structure contains the following information:
 305        - The base model of the query.
 306        - Which columns for that model are present in the query (by
 307          position of the select clause).
 308        - related_klass_infos: [f, klass_info] to descent into
 309
 310        The annotations is a dictionary of {'attname': column position} values.
 311        """
 312        select = []
 313        klass_info = None
 314        annotations = {}
 315        select_idx = 0
 316        for alias, (sql, params) in self.query.extra_select.items():
 317            annotations[alias] = select_idx
 318            select.append((RawSQL(sql, params), alias))
 319            select_idx += 1
 320        assert not (self.query.select and self.query.default_cols)
 321        select_mask = self.query.get_select_mask()
 322        if self.query.default_cols:
 323            cols = self.get_default_columns(select_mask)
 324        else:
 325            # self.query.select is a special case. These columns never go to
 326            # any model.
 327            cols = self.query.select
 328        if cols:
 329            select_list = []
 330            for col in cols:
 331                select_list.append(select_idx)
 332                select.append((col, None))
 333                select_idx += 1
 334            klass_info = {
 335                "model": self.query.model,
 336                "select_fields": select_list,
 337            }
 338        for alias, annotation in self.query.annotation_select.items():
 339            annotations[alias] = select_idx
 340            select.append((annotation, alias))
 341            select_idx += 1
 342
 343        if self.query.select_related:
 344            related_klass_infos = self.get_related_selections(select, select_mask)
 345            if klass_info is not None:
 346                klass_info["related_klass_infos"] = related_klass_infos
 347
 348        ret = []
 349        col_idx = 1
 350        for col, alias in select:
 351            try:
 352                sql, params = self.compile(col)
 353            except EmptyResultSet:
 354                empty_result_set_value = getattr(
 355                    col, "empty_result_set_value", NotImplemented
 356                )
 357                if empty_result_set_value is NotImplemented:
 358                    # Select a predicate that's always False.
 359                    sql, params = "0", ()
 360                else:
 361                    sql, params = self.compile(Value(empty_result_set_value))
 362            except FullResultSet:
 363                sql, params = self.compile(Value(True))
 364            else:
 365                sql, params = col.select_format(self, sql, params)
 366            if alias is None and with_col_aliases:
 367                alias = f"col{col_idx}"
 368                col_idx += 1
 369            ret.append((col, (sql, params), alias))
 370        return ret, klass_info, annotations
 371
 372    def _order_by_pairs(self) -> Generator[tuple[OrderBy, bool]]:
 373        if self.query.extra_order_by:
 374            ordering = self.query.extra_order_by
 375        elif not self.query.default_ordering:
 376            ordering = self.query.order_by
 377        elif self.query.order_by:
 378            ordering = self.query.order_by
 379        elif (
 380            self.query.model
 381            and (options := self.query.model.model_options)
 382            and options.ordering
 383        ):
 384            ordering = options.ordering
 385            self._meta_ordering = list(ordering)
 386        else:
 387            ordering = []
 388        if self.query.standard_ordering:
 389            default_order, _ = ORDER_DIR["ASC"]
 390        else:
 391            default_order, _ = ORDER_DIR["DESC"]
 392
 393        selected_exprs = {}
 394        if select := self.select:
 395            for ordinal, (expr, _, alias) in enumerate(select, start=1):
 396                pos_expr = PositionRef(ordinal, alias, expr)  # ty: ignore[invalid-argument-type]
 397                if alias:
 398                    selected_exprs[alias] = pos_expr
 399                selected_exprs[expr] = pos_expr
 400
 401        for field in ordering:
 402            if isinstance(field, ResolvableExpression):
 403                # field is a BaseExpression (has asc/desc/copy methods)
 404                field_expr = cast(BaseExpression, field)
 405                if isinstance(field_expr, Value):
 406                    # output_field must be resolved for constants.
 407                    field_expr = Cast(field_expr, field_expr.output_field)
 408                if not isinstance(field_expr, OrderBy):
 409                    field_expr = field_expr.asc()
 410                if not self.query.standard_ordering:
 411                    field_expr = field_expr.copy()
 412                    field_expr.reverse_ordering()
 413                field = field_expr
 414                select_ref = selected_exprs.get(field.expression)
 415                if select_ref or (
 416                    isinstance(field.expression, F)
 417                    and (select_ref := selected_exprs.get(field.expression.name))
 418                ):
 419                    field = field.copy()
 420                    field.expression = select_ref
 421                yield field, select_ref is not None
 422                continue
 423            if field == "?":  # random
 424                yield OrderBy(Random()), False
 425                continue
 426
 427            col, order = get_order_dir(field, default_order)
 428            descending = order == "DESC"
 429
 430            if select_ref := selected_exprs.get(col):
 431                # Reference to expression in SELECT clause
 432                yield (
 433                    OrderBy(
 434                        select_ref,
 435                        descending=descending,
 436                    ),
 437                    True,
 438                )
 439                continue
 440            if col in self.query.annotations:
 441                # References to an expression which is masked out of the SELECT
 442                # clause.
 443                expr = self.query.annotations[col]
 444                if isinstance(expr, Value):
 445                    # output_field must be resolved for constants.
 446                    expr = Cast(expr, expr.output_field)
 447                yield OrderBy(expr, descending=descending), False
 448                continue
 449
 450            if "." in field:
 451                # This came in through an extra(order_by=...) addition. Pass it
 452                # on verbatim.
 453                table, col = col.split(".", 1)
 454                yield (
 455                    OrderBy(
 456                        RawSQL(f"{self.quote_name_unless_alias(table)}.{col}", []),
 457                        descending=descending,
 458                    ),
 459                    False,
 460                )
 461                continue
 462
 463            if self.query.extra and col in self.query.extra:
 464                if col in self.query.extra_select:
 465                    yield (
 466                        OrderBy(
 467                            Ref(col, RawSQL(*self.query.extra[col])),
 468                            descending=descending,
 469                        ),
 470                        True,
 471                    )
 472                else:
 473                    yield (
 474                        OrderBy(RawSQL(*self.query.extra[col]), descending=descending),
 475                        False,
 476                    )
 477            else:
 478                # 'col' is of the form 'field' or 'field1__field2' or
 479                # '-field1__field2__field', etc.
 480                assert self.query.model is not None, (
 481                    "Ordering by fields requires a model"
 482                )
 483                meta = self.query.model._model_meta
 484                yield from self.find_ordering_name(
 485                    field,
 486                    meta,
 487                    default_order=default_order,
 488                )
 489
 490    def get_order_by(self) -> list[tuple[Any, tuple[str, tuple, bool]]]:
 491        """
 492        Return a list of 2-tuples of the form (expr, (sql, params, is_ref)) for
 493        the ORDER BY clause.
 494
 495        The order_by clause can alter the select clause (for example it can add
 496        aliases to clauses that do not yet have one, or it can add totally new
 497        select clauses).
 498        """
 499        result = []
 500        seen = set()
 501        for expr, is_ref in self._order_by_pairs():
 502            resolved = expr.resolve_expression(self.query, allow_joins=True, reuse=None)
 503            sql, params = self.compile(resolved)
 504            # Don't add the same column twice, but the order direction is
 505            # not taken into account so we strip it. When this entire method
 506            # is refactored into expressions, then we can check each part as we
 507            # generate it.
 508            without_ordering = self.ordering_parts.search(sql)[1]
 509            params_hash = make_hashable(params)
 510            if (without_ordering, params_hash) in seen:
 511                continue
 512            seen.add((without_ordering, params_hash))
 513            result.append((resolved, (sql, params, is_ref)))
 514        return result
 515
 516    def get_extra_select(
 517        self, order_by: list[Any], select: list[Any]
 518    ) -> list[tuple[Any, SqlWithParams, None]]:
 519        extra_select = []
 520        if self.query.distinct and not self.query.distinct_fields:
 521            select_sql = [t[1] for t in select]
 522            for expr, (sql, params, is_ref) in order_by:
 523                without_ordering = self.ordering_parts.search(sql)[1]
 524                if not is_ref and (without_ordering, params) not in select_sql:
 525                    extra_select.append((expr, (without_ordering, params), None))
 526        return extra_select
 527
 528    def quote_name_unless_alias(self, name: str) -> str:
 529        """
 530        A wrapper around quote_name() that doesn't quote aliases for table
 531        names. This avoids problems with some SQL dialects that treat quoted
 532        strings specially (e.g. PostgreSQL).
 533        """
 534        if name in self.quote_cache:
 535            return self.quote_cache[name]
 536        if (
 537            (name in self.query.alias_map and name not in self.query.table_map)
 538            or name in self.query.extra_select
 539            or (
 540                self.query.external_aliases.get(name)
 541                and name not in self.query.table_map
 542            )
 543        ):
 544            self.quote_cache[name] = name
 545            return name
 546        r = quote_name(name)
 547        self.quote_cache[name] = r
 548        return r
 549
 550    def compile(self, node: SQLCompilable) -> SqlWithParams:
 551        sql, params = node.as_sql(self, self.connection)
 552        return sql, tuple(params)
 553
    def get_qualify_sql(self) -> tuple[list[str], list[Any]]:
        """Build the SQL for a query whose filter has a QUALIFY part.

        The query is cloned into an inner subquery that keeps the WHERE and
        HAVING conditions; the qualify condition is then applied as the WHERE
        clause of an outer "SELECT * FROM (...)" wrapper. Expressions used by
        the qualify condition or the ordering are replaced by references to
        the inner query's column aliases, adding "qualN" annotations to the
        inner query when no selected column matches.

        Returns a (list of SQL fragments, list of parameters) pair. Raises
        ValueError when self.qualify is None. Relies on self.where,
        self.having and self.qualify, and reassigns self.qualify to the
        rewritten condition.
        """
        where_parts = []
        if self.where:
            where_parts.append(self.where)
        if self.having:
            where_parts.append(self.having)
        inner_query = self.query.clone()
        inner_query.subquery = True
        # Rebuild the inner filter from the non-qualify parts only.
        inner_query.where = inner_query.where.__class__(where_parts)
        # Augment the inner query with any window function references that
        # might have been masked via values() and alias(). If any masked
        # aliases are added they'll be masked again to avoid fetching
        # the data in the `if qual_aliases` branch below.
        select = {
            expr: alias for expr, _, alias in self.get_select(with_col_aliases=True)[0]
        }
        select_aliases = set(select.values())
        # Aliases of annotations added to the inner query by this method.
        qual_aliases = set()
        # Maps an expression to the inner-query alias that replaces it.
        replacements = {}

        def collect_replacements(expressions: list[Any]) -> None:
            """Fill `replacements` for `expressions`, consuming the list.

            Lookups and unselected Refs are expanded into their source
            expressions; anything else that is not already selected is
            annotated onto the inner query under a new "qualN" alias.
            """
            while expressions:
                expr = expressions.pop()
                if expr in replacements:
                    continue
                elif select_alias := select.get(expr):
                    replacements[expr] = select_alias
                elif isinstance(expr, Lookup):
                    expressions.extend(expr.get_source_expressions())
                elif isinstance(expr, Ref):
                    if expr.refs not in select_aliases:
                        expressions.extend(expr.get_source_expressions())
                else:
                    num_qual_alias = len(qual_aliases)
                    select_alias = f"qual{num_qual_alias}"
                    qual_aliases.add(select_alias)
                    inner_query.add_annotation(expr, select_alias)
                    replacements[expr] = select_alias

        qualify = self.qualify
        if qualify is None:
            raise ValueError("QUALIFY clause expected but not provided")
        collect_replacements(list(qualify.leaves()))
        qualify = qualify.replace_expressions(
            {expr: Ref(alias, expr) for expr, alias in replacements.items()}
        )
        self.qualify = qualify
        order_by = []
        for order_by_expr, *_ in self.get_order_by():
            collect_replacements(order_by_expr.get_source_expressions())
            order_by.append(
                order_by_expr.replace_expressions(
                    {expr: Ref(alias, expr) for expr, alias in replacements.items()}
                )
            )
        # Compile the inner query only after all annotations have been added.
        inner_query_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
        inner_sql, inner_params = inner_query_compiler.as_sql(
            # The limits must be applied to the outer query to avoid pruning
            # results too eagerly.
            with_limits=False,
            # Force unique aliasing of selected columns to avoid collisions
            # and make rhs predicates referencing easier.
            with_col_aliases=True,
        )
        qualify_sql, qualify_params = self.compile(qualify)
        result = [
            "SELECT * FROM (",
            inner_sql,
            ")",
            quote_name("qualify"),
            "WHERE",
            qualify_sql,
        ]
        if qual_aliases:
            # If some select aliases were unmasked for filtering purposes they
            # must be masked back.
            cols = [quote_name(alias) for alias in select.values() if alias is not None]
            result = [
                "SELECT",
                ", ".join(cols),
                "FROM (",
                *result,
                ")",
                quote_name("qualify_mask"),
            ]
        # Parameters follow the order of the fragments: inner query first,
        # then the qualify condition, then any ordering below.
        params = list(inner_params) + list(qualify_params)
        # As the SQL spec is unclear on whether or not derived tables
        # ordering must propagate it has to be explicitly repeated on the
        # outer-most query to ensure it's preserved.
        if order_by:
            ordering_sqls = []
            for ordering in order_by:
                ordering_sql, ordering_params = self.compile(ordering)
                ordering_sqls.append(ordering_sql)
                params.extend(ordering_params)
            result.extend(["ORDER BY", ", ".join(ordering_sqls)])
        return result, params
 651
 652    def as_sql(
 653        self, with_limits: bool = True, with_col_aliases: bool = False
 654    ) -> SqlWithParams:
 655        """
 656        Create the SQL for this query. Return the SQL string and list of
 657        parameters.
 658
 659        If 'with_limits' is False, any limit/offset information is not included
 660        in the query.
 661        """
 662        refcounts_before = self.query.alias_refcount.copy()
 663        try:
 664            result = self.pre_sql_setup(with_col_aliases=with_col_aliases)
 665            assert result is not None  # SQLCompiler.pre_sql_setup always returns tuple
 666            extra_select, order_by, group_by = result
 667            assert self.select is not None  # Set by pre_sql_setup()
 668            for_update_part = None
 669            # Is a LIMIT/OFFSET clause needed?
 670            with_limit_offset = with_limits and self.query.is_sliced
 671            if self.qualify:
 672                result, params = self.get_qualify_sql()
 673                order_by = None
 674            else:
 675                distinct_fields, distinct_params = self.get_distinct()
 676                # This must come after 'select', 'ordering', and 'distinct'
 677                # (see docstring of get_from_clause() for details).
 678                from_, f_params = self.get_from_clause()
 679                try:
 680                    where, w_params = (
 681                        self.compile(self.where) if self.where is not None else ("", [])
 682                    )
 683                except EmptyResultSet:
 684                    if self.elide_empty:
 685                        raise
 686                    # Use a predicate that's always False.
 687                    where, w_params = "0 = 1", []
 688                except FullResultSet:
 689                    where, w_params = "", []
 690                try:
 691                    having, h_params = (
 692                        self.compile(self.having)
 693                        if self.having is not None
 694                        else ("", [])
 695                    )
 696                except FullResultSet:
 697                    having, h_params = "", []
 698                result = ["SELECT"]
 699                params = []
 700
 701                if self.query.distinct:
 702                    distinct_result, distinct_params = distinct_sql(
 703                        distinct_fields,
 704                        distinct_params,
 705                    )
 706                    result += distinct_result
 707                    params += distinct_params
 708
 709                out_cols = []
 710                for _, (s_sql, s_params), alias in self.select + extra_select:
 711                    if alias:
 712                        s_sql = f"{s_sql} AS {quote_name(alias)}"
 713                    params.extend(s_params)
 714                    out_cols.append(s_sql)
 715
 716                result += [", ".join(out_cols)]
 717                if from_:
 718                    result += ["FROM", *from_]
 719                params.extend(f_params)
 720
 721                if self.query.select_for_update:
 722                    if self.connection.get_autocommit():
 723                        raise TransactionManagementError(
 724                            "select_for_update cannot be used outside of a transaction."
 725                        )
 726
 727                    for_update_part = for_update_sql(
 728                        nowait=self.query.select_for_update_nowait,
 729                        skip_locked=self.query.select_for_update_skip_locked,
 730                        of=tuple(self.get_select_for_update_of_arguments()),
 731                        no_key=self.query.select_for_no_key_update,
 732                    )
 733
 734                if where:
 735                    result.append(f"WHERE {where}")
 736                    params.extend(w_params)
 737
 738                grouping = []
 739                for g_sql, g_params in group_by:
 740                    grouping.append(g_sql)
 741                    params.extend(g_params)
 742                if grouping:
 743                    if distinct_fields:
 744                        raise NotImplementedError(
 745                            "annotate() + distinct(fields) is not implemented."
 746                        )
 747                    order_by = order_by or []
 748                    result.append("GROUP BY {}".format(", ".join(grouping)))
 749                    if self._meta_ordering:
 750                        order_by = None
 751                if having:
 752                    result.append(f"HAVING {having}")
 753                    params.extend(h_params)
 754
 755            if self.query.explain_info:
 756                result.insert(
 757                    0,
 758                    explain_query_prefix(
 759                        self.query.explain_info.format,
 760                        **self.query.explain_info.options,
 761                    ),
 762                )
 763
 764            if order_by:
 765                ordering = []
 766                for _, (o_sql, o_params, _) in order_by:
 767                    ordering.append(o_sql)
 768                    params.extend(o_params)
 769                result.append("ORDER BY {}".format(", ".join(ordering)))
 770
 771            if with_limit_offset:
 772                result.append(
 773                    limit_offset_sql(self.query.low_mark, self.query.high_mark)
 774                )
 775
 776            if for_update_part:
 777                result.append(for_update_part)
 778
 779            if self.query.subquery and extra_select:
 780                # If the query is used as a subquery, the extra selects would
 781                # result in more columns than the left-hand side expression is
 782                # expecting. This can happen when a subquery uses a combination
 783                # of order_by() and distinct(), forcing the ordering expressions
 784                # to be selected as well. Wrap the query in another subquery
 785                # to exclude extraneous selects.
 786                sub_selects = []
 787                sub_params = []
 788                for index, (select, _, alias) in enumerate(self.select, start=1):
 789                    if alias:
 790                        sub_selects.append(
 791                            "{}.{}".format(
 792                                quote_name("subquery"),
 793                                quote_name(alias),
 794                            )
 795                        )
 796                    else:
 797                        select_clone = select.relabeled_clone(
 798                            {select.alias: "subquery"}
 799                        )
 800                        subselect, subparams = select_clone.as_sql(
 801                            self, self.connection
 802                        )
 803                        sub_selects.append(subselect)
 804                        sub_params.extend(subparams)
 805                return "SELECT {} FROM ({}) subquery".format(
 806                    ", ".join(sub_selects),
 807                    " ".join(result),
 808                ), tuple(sub_params + params)
 809
 810            return " ".join(result), tuple(params)
 811        finally:
 812            # Finally do cleanup - get rid of the joins we created above.
 813            self.query.reset_refcounts(refcounts_before)
 814
 815    def get_default_columns(
 816        self,
 817        select_mask: Any,
 818        start_alias: str | None = None,
 819        opts: Meta | None = None,
 820    ) -> list[Any]:
 821        """
 822        Return Col expressions for every concrete field on the model. When
 823        pulling in a related model (e.g. via select_related), the caller
 824        passes ``opts`` and ``start_alias`` to traverse from that join.
 825        """
 826        result = []
 827        if opts is None:
 828            if self.query.model is None:
 829                return result
 830            opts = self.query.model._model_meta
 831        start_alias = start_alias or self.query.get_initial_alias()
 832
 833        for field in opts.concrete_fields:
 834            if select_mask and field not in select_mask:
 835                continue
 836            result.append(field.get_col(start_alias))
 837        return result
 838
 839    def get_distinct(self) -> tuple[list[str], list]:
 840        """
 841        Return a quoted list of fields to use in DISTINCT ON part of the query.
 842
 843        This method can alter the tables in the query, and thus it must be
 844        called before get_from_clause().
 845        """
 846        result = []
 847        params = []
 848        if not self.query.distinct_fields:
 849            return result, params
 850
 851        if self.query.model is None:
 852            return result, params
 853        opts = self.query.model._model_meta
 854
 855        for name in self.query.distinct_fields:
 856            parts = name.split(LOOKUP_SEP)
 857            _, targets, alias, joins, path, _, transform_function = self._setup_joins(
 858                parts, opts, None
 859            )
 860            targets, alias, _ = self.query.trim_joins(targets, joins, path)
 861            for target in targets:
 862                if name in self.query.annotation_select:
 863                    result.append(quote_name(name))
 864                else:
 865                    r, p = self.compile(transform_function(target, alias))
 866                    result.append(r)
 867                    params.append(p)
 868        return result, params
 869
    def find_ordering_name(
        self,
        name: str,
        meta: Meta,
        alias: str | None = None,
        default_order: str = "ASC",
        already_seen: set | None = None,
    ) -> list[tuple[OrderBy, bool]]:
        """
        Resolve one ordering name into OrderBy expressions.

        The 'name' is of the form 'field1__field2__...__fieldN'. Its
        direction comes from get_order_dir(name, default_order), and the
        path is joined starting from ``alias`` (see _setup_joins()).

        When the path ends on a relation whose model declares a default
        ordering, that ordering is expanded in place of the relation:
        expressions are used directly, plain names are resolved by recursing
        into this method, and every reference is prefixed with ``name``. The
        expansion is skipped when the last piece is the field's attname or
        when transforms were applied.

        ``already_seen`` holds the join columns visited by enclosing calls;
        meeting the same joins again raises FieldError instead of recursing
        forever.

        Returns a list of (OrderBy, is_ref) pairs. Every pair built directly
        in this method has is_ref set to False.
        """
        name, order = get_order_dir(name, default_order)
        descending = order == "DESC"
        pieces = name.split(LOOKUP_SEP)
        # Note: ``alias`` and ``meta`` are rebound to the values at the end of
        # the join chain from here on.
        (
            field,
            targets,
            alias,
            joins,
            path,
            meta,
            transform_function,
        ) = self._setup_joins(pieces, meta, alias)

        # If we get to this point and the field is a relation to another model,
        # append the default ordering for that model unless it is the
        # attribute name of the field that is specified or
        # there are transforms to process.
        if (
            isinstance(field, RelatedField)
            and meta.model.model_options.ordering
            and getattr(field, "attname", None) != pieces[-1]
            and not getattr(transform_function, "has_transforms", False)
        ):
            # Firstly, avoid infinite loops.
            already_seen = already_seen or set()
            # Identify this traversal by the join columns of every join used.
            join_tuple = tuple(
                getattr(self.query.alias_map[j], "join_cols", None) for j in joins
            )
            if join_tuple in already_seen:
                raise FieldError("Infinite loop caused by ordering.")
            already_seen.add(join_tuple)

            results = []
            for item in meta.model.model_options.ordering:
                # Bare expressions inherit the direction of the outer name.
                if isinstance(item, ResolvableExpression) and not isinstance(
                    item, OrderBy
                ):
                    item_expr: BaseExpression = cast(BaseExpression, item)
                    item = item_expr.desc() if descending else item_expr.asc()
                if isinstance(item, OrderBy):
                    results.append(
                        (item.prefix_references(f"{name}{LOOKUP_SEP}"), False)
                    )
                    continue
                # Anything else is resolved recursively, relative to the
                # related model, with the current direction as the default.
                results.extend(
                    (expr.prefix_references(f"{name}{LOOKUP_SEP}"), is_ref)
                    for expr, is_ref in self.find_ordering_name(
                        item, meta, alias, order, already_seen
                    )
                )
            return results
        # Plain case: order by each target column of the resolved path.
        targets, alias, _ = self.query.trim_joins(targets, joins, path)
        return [
            (OrderBy(transform_function(t, alias), descending=descending), False)
            for t in targets
        ]
 939
 940    def _setup_joins(
 941        self, pieces: list[str], meta: Meta, alias: str | None
 942    ) -> tuple[Any, Any, str, list, Any, Meta, Any]:
 943        """
 944        Helper method for get_order_by() and get_distinct().
 945
 946        get_ordering() and get_distinct() must produce same target columns on
 947        same input, as the prefixes of get_ordering() and get_distinct() must
 948        match. Executing SQL where this is not true is an error.
 949        """
 950        alias = alias or self.query.get_initial_alias()
 951        assert alias is not None
 952        field, targets, meta, joins, path, transform_function = self.query.setup_joins(
 953            pieces, meta, alias
 954        )
 955        alias = joins[-1]
 956        return field, targets, alias, joins, path, meta, transform_function
 957
 958    def get_from_clause(self) -> tuple[list[str], list]:
 959        """
 960        Return a list of strings that are joined together to go after the
 961        "FROM" part of the query, as well as a list any extra parameters that
 962        need to be included. Subclasses, can override this to create a
 963        from-clause via a "select".
 964
 965        This should only be called after any SQL construction methods that
 966        might change the tables that are needed. This means the select columns,
 967        ordering, and distinct must be done first.
 968        """
 969        result = []
 970        params = []
 971        for alias in tuple(self.query.alias_map):
 972            if not self.query.alias_refcount[alias]:
 973                continue
 974            try:
 975                from_clause = self.query.alias_map[alias]
 976            except KeyError:
 977                # Extra tables can end up in self.tables, but not in the
 978                # alias_map if they aren't in a join. That's OK. We skip them.
 979                continue
 980            clause_sql, clause_params = self.compile(from_clause)
 981            result.append(clause_sql)
 982            params.extend(clause_params)
 983        for t in self.query.extra_tables:
 984            alias, _ = self.query.table_alias(t)
 985            # Only add the alias if it's not already present (the table_alias()
 986            # call increments the refcount, so an alias refcount of one means
 987            # this is the only reference).
 988            if (
 989                alias not in self.query.alias_map
 990                or self.query.alias_refcount[alias] == 1
 991            ):
 992                result.append(f", {self.quote_name_unless_alias(alias)}")
 993        return result, params
 994
    def get_related_selections(
        self,
        select: list[Any],
        select_mask: Any,
        opts: Meta | None = None,
        root_alias: str | None = None,
        cur_depth: int = 1,
        requested: dict | None = None,
        restricted: bool | None = None,
    ) -> list[dict[str, Any]]:
        """
        Fill in the information needed for a select_related query. The current
        depth is measured as the number of connections away from the root model
        (for example, cur_depth=1 means we are looking at models with direct
        connections to the root model).

        Args:
            select: Running select list. A ``(col, None)`` tuple is appended
                for every column of each related model that is followed, and
                the positions are recorded in that model's klass_info.
            select_mask: Mapping looked up with ``.get(field)`` to obtain the
                mask that applies to a related model's own columns.
            opts: Meta for the model being queried (internal metadata). When
                falsy, the query's model and initial alias are used.
            root_alias: Alias of the table ``opts`` is read from.
            cur_depth: Current recursion depth (see above). The max_depth
                cut-off only applies when the selection is not restricted.
            requested: Nested dict of the relation names asked for at this
                level. When None it is taken from query.select_related if
                that is a dict.
            restricted: Whether only the relations in ``requested`` should be
                followed. Recomputed when ``requested`` is None.

        Returns:
            One klass_info dict per followed relation, with the keys
            "model", "field", "reverse", "local_setter", "remote_setter",
            "select_fields" and "related_klass_infos" (the nested result of
            the recursive call).

        Raises:
            FieldError: If a non-relational field is named in select_related,
                or a requested name matches nothing at this level.
        """

        related_klass_infos = []
        if not restricted and cur_depth > self.query.max_depth:
            # We've recursed far enough; bail out.
            return related_klass_infos

        if not opts:
            assert self.query.model is not None, "select_related requires a model"
            opts = self.query.model._model_meta
            root_alias = self.query.get_initial_alias()

        assert root_alias is not None  # Must be provided or set above
        assert opts is not None

        def _get_field_choices() -> chain:
            # Names listed as valid choices in the FieldError messages below.
            direct_choices = (
                f.name for f in opts.fields if isinstance(f, RelatedField)
            )
            reverse_choices = (
                f.field.related_query_name()
                for f in opts.related_objects
                if f.field.primary_key
            )
            return chain(
                direct_choices, reverse_choices, self.query._filtered_relations
            )

        # Setup for the case when only particular related fields should be
        # included in the related selection.
        fields_found = set()
        if requested is None:
            restricted = isinstance(self.query.select_related, dict)
            if restricted:
                requested = cast(dict, self.query.select_related)

        def get_related_klass_infos(
            klass_info: dict, related_klass_infos: list
        ) -> None:
            # Attach the nested klass_infos to their parent.
            klass_info["related_klass_infos"] = related_klass_infos

        # Pass 1: relations declared as fields on this model.
        for f in opts.fields:
            fields_found.add(f.name)

            if restricted:
                assert requested is not None
                next = requested.get(f.name, {})
                if not isinstance(f, RelatedField):
                    # If a non-related field is used like a relation,
                    # or if a single non-relational field is given.
                    if next or f.name in requested:
                        raise FieldError(
                            "Non-relational field given in select_related: '{}'. "
                            "Choices are: {}".format(
                                f.name,
                                ", ".join(_get_field_choices()) or "(none)",
                            )
                        )
            else:
                next = None

            if not select_related_descend(f, restricted, requested, select_mask):
                continue
            related_select_mask = select_mask.get(f) or {}
            klass_info: dict[str, Any] = {
                "model": f.remote_field.model,
                "field": f,
                "reverse": False,
                "local_setter": f.set_cached_value,
                # Only a primary-key field gets a real remote setter; every
                # other field gets a no-op.
                "remote_setter": f.remote_field.set_cached_value
                if f.primary_key
                else lambda x, y: None,
            }
            related_klass_infos.append(klass_info)
            select_fields = []
            _, _, _, joins, _, _ = self.query.setup_joins([f.name], opts, root_alias)
            alias = joins[-1]
            columns = self.get_default_columns(
                related_select_mask,
                start_alias=alias,
                opts=f.remote_field.model._model_meta,
            )
            for col in columns:
                # len(select) is the index the column is about to occupy.
                select_fields.append(len(select))
                select.append((col, None))
            klass_info["select_fields"] = select_fields
            next_klass_infos = self.get_related_selections(
                select,
                related_select_mask,
                f.remote_field.model._model_meta,
                alias,
                cur_depth + 1,
                next,
                restricted,
            )
            get_related_klass_infos(klass_info, next_klass_infos)

        if restricted:
            from plain.postgres.fields.reverse_related import ManyToManyRel

            # Pass 2: reverse relations whose field is a primary key
            # (many-to-many relations excluded).
            related_fields = [
                (o.field, o.related_model)
                for o in opts.related_objects
                if o.field.primary_key and not isinstance(o, ManyToManyRel)
            ]
            for related_field, model in related_fields:
                related_select_mask = select_mask.get(related_field) or {}

                if not select_related_descend(
                    related_field,
                    restricted,
                    requested,
                    related_select_mask,
                    reverse=True,
                ):
                    continue

                related_field_name = related_field.related_query_name()
                fields_found.add(related_field_name)

                join_info = self.query.setup_joins(
                    [related_field_name], opts, root_alias
                )
                alias = join_info.joins[-1]
                klass_info: dict[str, Any] = {
                    "model": model,
                    "field": related_field,
                    "reverse": True,
                    "local_setter": related_field.remote_field.set_cached_value,
                    "remote_setter": related_field.set_cached_value,
                }
                related_klass_infos.append(klass_info)
                select_fields = []
                columns = self.get_default_columns(
                    related_select_mask,
                    start_alias=alias,
                    opts=model._model_meta,
                )
                for col in columns:
                    select_fields.append(len(select))
                    select.append((col, None))
                klass_info["select_fields"] = select_fields
                assert requested is not None
                next = requested.get(related_field.related_query_name(), {})
                next_klass_infos = self.get_related_selections(
                    select,
                    related_select_mask,
                    model._model_meta,
                    alias,
                    cur_depth + 1,
                    next,
                    restricted,
                )
                get_related_klass_infos(klass_info, next_klass_infos)

            # Setters for filtered relations; bound with partial() below.
            def local_setter(final_field: Any, obj: Any, from_obj: Any) -> None:
                # Set a reverse fk object when relation is non-empty.
                if from_obj:
                    final_field.remote_field.set_cached_value(from_obj, obj)

            def local_setter_noop(obj: Any, from_obj: Any) -> None:
                # Used when the relation spans more than two joins.
                pass

            def remote_setter(name: str, obj: Any, from_obj: Any) -> None:
                setattr(from_obj, name, obj)

            # Pass 3: requested names that refer to filtered relations.
            assert requested is not None
            for name in list(requested):
                # Filtered relations work only on the topmost level.
                if cur_depth > 1:
                    break
                if name in self.query._filtered_relations:
                    fields_found.add(name)
                    final_field, _, join_opts, joins, _, _ = self.query.setup_joins(
                        [name], opts, root_alias
                    )
                    model = join_opts.model
                    alias = joins[-1]
                    klass_info: dict[str, Any] = {
                        "model": model,
                        "field": final_field,
                        "reverse": True,
                        "local_setter": (
                            partial(local_setter, final_field)
                            if len(joins) <= 2
                            else local_setter_noop
                        ),
                        "remote_setter": partial(remote_setter, name),
                    }
                    related_klass_infos.append(klass_info)
                    select_fields = []
                    # Filtered relations are keyed by (name, field) in the mask.
                    field_select_mask = select_mask.get((name, final_field)) or {}
                    columns = self.get_default_columns(
                        field_select_mask,
                        start_alias=alias,
                        opts=model._model_meta,
                    )
                    for col in columns:
                        select_fields.append(len(select))
                        select.append((col, None))
                    klass_info["select_fields"] = select_fields
                    next_requested = requested.get(name, {})
                    next_klass_infos = self.get_related_selections(
                        select,
                        field_select_mask,
                        opts=model._model_meta,
                        root_alias=alias,
                        cur_depth=cur_depth + 1,
                        requested=next_requested,
                        restricted=restricted,
                    )
                    get_related_klass_infos(klass_info, next_klass_infos)
            # Anything requested that none of the three passes recognised is
            # an error.
            fields_not_found = set(requested).difference(fields_found)
            if fields_not_found:
                invalid_fields = (f"'{s}'" for s in fields_not_found)
                raise FieldError(
                    "Invalid field name(s) given in select_related: {}. "
                    "Choices are: {}".format(
                        ", ".join(invalid_fields),
                        ", ".join(_get_field_choices()) or "(none)",
                    )
                )
        return related_klass_infos
1236
1237    def get_select_for_update_of_arguments(self) -> list[str]:
1238        """
1239        Return a quoted list of arguments for the SELECT FOR UPDATE OF part of
1240        the query.
1241        """
1242
1243        def _get_first_selected_col_from_model(klass_info: dict) -> Any | None:
1244            """
1245            Find the first selected column whose target field belongs to this
1246            klass_info's model. Returns None when the model isn't represented
1247            in the select list — callers use that to skip locking the row.
1248            """
1249            assert self.select is not None
1250            model = klass_info["model"]
1251            for select_index in klass_info["select_fields"]:
1252                if self.select[select_index][0].target.model == model:
1253                    return self.select[select_index][0]
1254            return None
1255
1256        def _get_field_choices() -> Generator[str]:
1257            """Yield all allowed field paths in breadth-first search order."""
1258            queue = collections.deque([(None, self.klass_info)])
1259            while queue:
1260                parent_path, klass_info = queue.popleft()
1261                if parent_path is None:
1262                    path = []
1263                    yield "self"
1264                else:
1265                    assert klass_info is not None  # Only first iteration has None
1266                    field = klass_info["field"]
1267                    if klass_info["reverse"]:
1268                        field = field.remote_field
1269                    path = parent_path + [field.name]
1270                    yield LOOKUP_SEP.join(path)
1271                if klass_info is not None:
1272                    queue.extend(
1273                        (path, related_klass_info)  # type: ignore[invalid-argument-type]
1274                        for related_klass_info in klass_info.get(
1275                            "related_klass_infos", []
1276                        )
1277                    )
1278
1279        if not self.klass_info:
1280            return []
1281        result = []
1282        invalid_names = []
1283        for name in self.query.select_for_update_of:
1284            klass_info = self.klass_info
1285            if name == "self":
1286                col = _get_first_selected_col_from_model(klass_info)
1287            else:
1288                for part in name.split(LOOKUP_SEP):
1289                    if klass_info is None:
1290                        break
1291                    klass_infos = (*klass_info.get("related_klass_infos", []),)
1292                    for related_klass_info in klass_infos:
1293                        field = related_klass_info["field"]
1294                        if related_klass_info["reverse"]:
1295                            field = field.remote_field
1296                        if field.name == part:
1297                            klass_info = related_klass_info
1298                            break
1299                    else:
1300                        klass_info = None
1301                        break
1302                if klass_info is None:
1303                    invalid_names.append(name)
1304                    continue
1305                col = _get_first_selected_col_from_model(klass_info)
1306            if col is not None:
1307                result.append(self.quote_name_unless_alias(col.alias))
1308        if invalid_names:
1309            raise FieldError(
1310                "Invalid field name(s) given in select_for_update(of=(...)): {}. "
1311                "Only relational fields followed in the query are allowed. "
1312                "Choices are: {}.".format(
1313                    ", ".join(invalid_names),
1314                    ", ".join(_get_field_choices()),
1315                )
1316            )
1317        return result
1318
1319    def results_iter(
1320        self,
1321        results: Any = None,
1322        tuple_expected: bool = False,
1323        chunked_fetch: bool = False,
1324    ) -> Iterable[Any]:
1325        """Return an iterator over the results from executing this query."""
1326        if results is None:
1327            results = self.execute_sql(MULTI, chunked_fetch=chunked_fetch)
1328        assert self.select is not None  # Set during query execution
1329        fields = [s[0] for s in self.select[0 : self.col_count]]
1330        converters = get_converters(fields, self.connection)
1331        rows = results
1332        if converters:
1333            rows = apply_converters(rows, converters, self.connection)
1334            if tuple_expected:
1335                rows = map(tuple, rows)
1336        return rows
1337
1338    def has_results(self) -> bool:
1339        """Check if the query returns any results."""
1340        return bool(self.execute_sql(SINGLE))
1341
1342    def execute_sql(
1343        self,
1344        result_type: str = MULTI,
1345        chunked_fetch: bool = False,
1346    ) -> Any:
1347        """
1348        Run the query against the database and return the result(s). The
1349        return value is a single data item if result_type is SINGLE, or a
1350        flat iterable of rows if the result_type is MULTI.
1351
1352        result_type is either MULTI (returns a list from fetchall(), or a
1353        streaming generator from cursor.stream() when chunked_fetch=True),
1354        SINGLE (only retrieve a single row), or None. In this last case, the
1355        cursor is returned if any query is executed, since it's used by
1356        subclasses such as InsertQuery). It's possible, however, that no query
1357        is needed, as the filters describe an empty set. In that case, None is
1358        returned, to avoid any unnecessary database interaction.
1359        """
1360        result_type = result_type or NO_RESULTS
1361        try:
1362            as_sql_result = self.as_sql()
1363            # SQLCompiler.as_sql returns SqlWithParams, subclasses may differ
1364            assert isinstance(as_sql_result, tuple)
1365            assert isinstance(as_sql_result[0], str)
1366            sql, params = as_sql_result
1367            if not sql:
1368                raise EmptyResultSet
1369        except EmptyResultSet:
1370            if result_type == MULTI:
1371                return iter([])
1372            else:
1373                return
1374        cursor = self.connection.cursor()
1375        if chunked_fetch:
1376            # Use psycopg3's cursor.stream() for server-side cursor iteration.
1377            result = cursor.stream(sql, params)
1378            if self.has_extra_select:
1379                col_count = self.col_count
1380                result = (r[:col_count] for r in result)
1381            return result
1382
1383        try:
1384            cursor.execute(sql, params)
1385        except Exception:
1386            cursor.close()
1387            raise
1388
1389        if result_type == CURSOR:
1390            # Give the caller the cursor to process and close.
1391            return cursor
1392        if result_type == SINGLE:
1393            try:
1394                val = cursor.fetchone()
1395                if val:
1396                    return val[0 : self.col_count]
1397                return val
1398            finally:
1399                # done with the cursor
1400                cursor.close()
1401        if result_type == NO_RESULTS:
1402            cursor.close()
1403            return
1404
1405        try:
1406            rows = cursor.fetchall()
1407        finally:
1408            cursor.close()
1409        if self.has_extra_select:
1410            rows = [r[: self.col_count] for r in rows]
1411        return rows
1412
1413    def explain_query(self) -> Generator[str]:
1414        result = self.execute_sql()
1415        explain_info = self.query.explain_info
1416        # PostgreSQL may return tuples with integers and strings depending on
1417        # the EXPLAIN format. Flatten them out into strings.
1418        format_ = explain_info.format if explain_info is not None else None
1419        output_formatter = json.dumps if format_ and format_.lower() == "json" else str
1420        for row in result:
1421            if not isinstance(row, str):
1422                yield " ".join(output_formatter(c) for c in row)
1423            else:
1424                yield row
1425
1426
class SQLInsertCompiler(SQLCompiler):
    """Compiler that builds and executes INSERT statements for an InsertQuery."""

    query: InsertQuery
    # Fields requested back through RETURNING; assigned by execute_sql().
    returning_fields: list | None = None
    # Params that accompany the RETURNING clause; assigned by as_sql().
    returning_params: tuple = ()

    def field_as_sql(self, field: Any, val: Any) -> tuple[str, list]:
        """
        Build the placeholder SQL and the parameter list for saving `val`
        on `field`.

        The checks run in this order: the DATABASE_DEFAULT sentinel, raw
        values (`field` is None), expressions exposing `as_sql`, fields that
        define get_placeholder(), and finally the plain "%s" placeholder.
        The params are always returned as a list.
        """
        if val is DATABASE_DEFAULT:
            # The literal DEFAULT keyword makes Postgres apply the column's
            # persistent DEFAULT (e.g. `gen_random_uuid()`). RETURNING then
            # brings the real value back onto the instance.
            return "DEFAULT", []
        if field is None:
            # No field means the value is raw: it serves as its own
            # placeholder and contributes no parameters.
            return val, []
        if hasattr(val, "as_sql"):
            # An expression: compile it and hand its params back as a list.
            expr_sql, expr_params = self.compile(val)
            return expr_sql, list(expr_params)
        if hasattr(field, "get_placeholder"):
            # Some fields (e.g. geo fields) need special munging before they
            # can be inserted.
            return field.get_placeholder(val, self, self.connection), [val]
        # The common case.
        return "%s", [val]

    def prepare_value(self, field: Any, value: Any) -> Any:
        """
        Make `value` ready to be saved on `field`.

        Expressions are resolved against the query and rejected when they
        contain column references, aggregates, or window clauses; the result
        is then passed through the field's get_db_prep_save(). Any other
        value goes straight to get_db_prep_save(). The DATABASE_DEFAULT
        sentinel is returned as-is.
        """
        if value is DATABASE_DEFAULT:
            # Passed through untouched so that field_as_sql can emit the
            # literal DEFAULT keyword for it.
            return value
        if not isinstance(value, ResolvableExpression):
            return field.get_db_prep_save(value, connection=self.connection)

        value = value.resolve_expression(self.query, allow_joins=False, for_save=True)
        if value.contains_column_references:
            # Col expressions point at columns of an existing row, and an
            # insert has no such row yet.
            raise ValueError(
                f'Failed to insert expression "{value}" on {field}. F() expressions '
                "can only be used to update, not to insert."
            )
        if value.contains_aggregate:
            raise FieldError(
                "Aggregate functions are not allowed in this query "
                f"({field.name}={value!r})."
            )
        if value.contains_over_clause:
            raise FieldError(
                f"Window expressions are not allowed in this query ({field.name}={value!r})."
            )
        return field.get_db_prep_save(value, connection=self.connection)

    def pre_save_val(self, field: Any, obj: Any) -> Any:
        """
        Read the value of `field` from `obj`.

        Raw queries take the attribute as stored on the object; every other
        query goes through the field's pre_save() with add=True (used for
        things like update_now on DateTimeField).
        """
        return (
            getattr(obj, field.attname)
            if self.query.raw
            else field.pre_save(obj, add=True)
        )

    def assemble_as_sql(
        self, fields: list[Any], value_rows: list[list[Any]]
    ) -> tuple[Any, list[list[Any]]]:
        """
        Given N fields and M rows of values, produce placeholder SQL and
        parameters for every field/value pair.

        Return a pair made of:
         * M rows of N placeholder strings, and
         * M rows of flattened parameter values.

        A placeholder string may hold any number of '%s' markers; each
        parameter row holds as many params as its placeholder row has '%s's.
        Without any value rows the result is a pair of empty lists.
        """
        if not value_rows:
            return [], []

        # One zip object per saved object, pairing up that row's
        # (sql, [params]) results into ([sqls], [[params]]).
        # Shape: [n_objs][2][n_fields]
        transposed_rows = (
            zip(*(self.field_as_sql(field, value) for field, value in zip(fields, row)))
            for row in value_rows
        )

        # Transpose once more to separate placeholders from params.
        # Both results have shape [n_objs][n_fields].
        placeholder_rows, nested_param_rows = zip(*transposed_rows)

        # The params of each field are still lists; flatten them per row.
        param_rows = [list(chain.from_iterable(row)) for row in nested_param_rows]

        return placeholder_rows, param_rows

    def as_sql(  # ty: ignore[invalid-method-override]  # Returns list for internal iteration in execute_sql
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> list[SqlWithParams]:
        """
        Build the INSERT statement for the query's objects.

        Return a one-element list holding the (sql, params) pair. When
        `self.returning_fields` is set, the RETURNING clause and its params
        are appended and `self.returning_params` is updated.
        """
        query = self.query
        assert query.model is not None, "INSERT requires a model"
        meta = query.model._model_meta
        options = query.model.model_options

        # quote_name_unless_alias() is not needed here: everything quoted
        # below is a plain table or column name, so the extra overhead is
        # avoided.
        statement = [f"INSERT INTO {quote_name(options.db_table)}"]
        has_fields = bool(query.fields)
        fields = query.fields if has_fields else [meta.get_forward_field("id")]
        column_list = ", ".join(quote_name(f.column) for f in fields)
        statement.append(f"({column_list})")

        if has_fields:
            value_rows = [
                [self.prepare_value(f, self.pre_save_val(f, obj)) for f in fields]
                for obj in query.objs
            ]
        else:
            # An empty object: the single value is PK_DEFAULT_VALUE, passed
            # as a raw value (field None).
            fields = [None]
            value_rows = [[PK_DEFAULT_VALUE] for _ in query.objs]

        placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

        conflict_suffix_sql = on_conflict_suffix_sql(
            fields,  # ty: ignore[invalid-argument-type]
            query.on_conflict,
            (f.column for f in query.update_fields),
            (f.column for f in query.unique_fields),
        )

        statement.append(
            bulk_insert_sql(fields, placeholder_rows)  # ty: ignore[invalid-argument-type]
        )
        if conflict_suffix_sql:
            statement.append(conflict_suffix_sql)

        if self.returning_fields:
            # Use a RETURNING clause to get the inserted values back. An
            # empty clause is skipped.
            returning_cols = return_insert_columns(self.returning_fields)
            if returning_cols:
                returning_sql, self.returning_params = returning_cols
                if returning_sql:
                    statement.append(returning_sql)
                    param_rows.append(list(self.returning_params))

        flat_params = tuple(chain.from_iterable(param_rows))
        return [(" ".join(statement), flat_params)]

    def execute_sql(  # ty: ignore[invalid-method-override]
        self, returning_fields: list | None = None
    ) -> list:
        """
        Run the INSERT and return the rows produced by RETURNING.

        Without `returning_fields` the result is an empty list. Otherwise all
        rows are fetched when more than one object was inserted, a single row
        is fetched for one object, and any converters found for the returned
        columns are applied.
        """
        assert self.query.model is not None, "INSERT execution requires a model"
        options = self.query.model.model_options
        self.returning_fields = returning_fields
        with self.connection.cursor() as cursor:
            for statement, statement_params in self.as_sql():
                cursor.execute(statement, statement_params)
            if not self.returning_fields:
                return []
            # RETURNING serves both single and bulk inserts.
            is_bulk = len(self.query.objs) > 1
            rows = cursor.fetchall() if is_bulk else [cursor.fetchone()]
        table = options.db_table
        returned_cols = [field.get_col(table) for field in self.returning_fields]
        converters = get_converters(returned_cols, self.connection)
        return (
            list(apply_converters(rows, converters, self.connection))
            if converters
            else rows
        )
1622
1623
class SQLDeleteCompiler(SQLCompiler):
    """Compiler that builds DELETE statements."""

    @cached_property
    def single_alias(self) -> bool:
        """
        True when exactly one alias in the query's alias_map has a positive
        reference count.
        """
        # Make sure the base table is registered among the aliases first.
        self.query.get_initial_alias()
        refcounts = self.query.alias_refcount
        in_use = [alias for alias in self.query.alias_map if refcounts[alias] > 0]
        return len(in_use) == 1

    @classmethod
    def _expr_refs_base_model(cls, expr: Any, base_model: Any) -> bool:
        """
        Report whether `expr`, or anything reachable through its source
        expressions, is a Query whose model equals `base_model`.
        """
        if isinstance(expr, Query):
            return expr.model == base_model
        if hasattr(expr, "get_source_expressions"):
            for source in expr.get_source_expressions():
                if cls._expr_refs_base_model(source, base_model):
                    return True
        return False

    @cached_property
    def contains_self_reference_subquery(self) -> bool:
        """
        True when any annotation or top-level where child contains a Query on
        this query's own model.
        """
        base_model = self.query.model
        candidates = chain(self.query.annotations.values(), self.query.where.children)
        return any(self._expr_refs_base_model(expr, base_model) for expr in candidates)

    def _as_sql(self, query: Query) -> SqlWithParams:
        """
        Render `DELETE FROM <base table>` for `query`, followed by its WHERE
        clause unless compiling the filter raises FullResultSet.
        """
        table = self.quote_name_unless_alias(query.base_table)  # ty: ignore[invalid-argument-type]
        statement = f"DELETE FROM {table}"
        try:
            condition, condition_params = self.compile(query.where)
        except FullResultSet:
            # No WHERE clause and no params in this case.
            return statement, ()
        return f"{statement} WHERE {condition}", tuple(condition_params)

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Build the SQL for this delete. Return the SQL string and a tuple of
        parameters.

        A query that uses a single alias and holds no subquery on its own
        model is rendered directly. Anything else is rewritten as a delete
        filtered by `id IN (subquery)`, where the subquery is a clone of this
        query selecting only the id column.
        """
        if not self.single_alias or self.contains_self_reference_subquery:
            subquery = self.query.clone()
            # The clone is turned into a plain Query before being reused as
            # the inner SELECT.
            subquery.__class__ = Query
            subquery.clear_select_clause()
            assert self.query.model is not None, "DELETE requires a model"
            id_field = self.query.model._model_meta.get_forward_field("id")
            initial_alias = self.query.get_initial_alias()
            subquery.select = (id_field.get_col(initial_alias),)
            wrapper = Query(self.query.model)
            wrapper.add_filter("id__in", subquery)
            return self._as_sql(wrapper)
        return self._as_sql(self.query)
1677
1678
class SQLUpdateCompiler(SQLCompiler):
    """Compiler that builds and executes UPDATE statements."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Build the SQL for this update. Return the SQL string and a tuple of
        parameters, or ("", ()) when the query has no values to set.

        Raises FieldError for aggregate or window expressions, and TypeError
        when a value exposing prepare_database_save() is assigned to a field
        that is not a RelatedField.
        """
        self.pre_sql_setup()
        pending = getattr(self.query, "values", None)
        if not pending:
            return "", ()

        qn = self.quote_name_unless_alias
        assignments: list[str] = []
        update_params: list[Any] = []
        for field, val in pending:
            if isinstance(val, ResolvableExpression):
                val = val.resolve_expression(
                    self.query, allow_joins=False, for_save=True
                )
                if val.contains_aggregate:
                    raise FieldError(
                        "Aggregate functions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
                if val.contains_over_clause:
                    raise FieldError(
                        "Window expressions are not allowed in this query "
                        f"({field.name}={val!r})."
                    )
            elif hasattr(val, "prepare_database_save"):
                if not isinstance(field, RelatedField):
                    raise TypeError(
                        f"Tried to update field {field} with a model instance, {val!r}. "
                        f"Use a value compatible with {field.__class__.__name__}."
                    )
                val = val.prepare_database_save(field)
            val = field.get_db_prep_save(val, connection=self.connection)

            # Fields may supply their own placeholder; "%s" is the fallback.
            placeholder = (
                field.get_placeholder(val, self, self.connection)
                if hasattr(field, "get_placeholder")
                else "%s"
            )
            column = field.column
            if hasattr(val, "as_sql"):
                # Expression: its compiled SQL is substituted into the
                # placeholder and its params are carried along.
                expr_sql, expr_params = self.compile(val)
                assignments.append(f"{qn(column)} = {placeholder % expr_sql}")
                update_params.extend(expr_params)
            elif val is None:
                # None is written as a literal NULL, with no parameter.
                assignments.append(f"{qn(column)} = NULL")
            else:
                assignments.append(f"{qn(column)} = {placeholder}")
                update_params.append(val)

        table = self.query.base_table
        statement = [
            f"UPDATE {qn(table)} SET",  # ty: ignore[invalid-argument-type]
            ", ".join(assignments),
        ]
        try:
            where_sql, where_params = self.compile(self.query.where)
        except FullResultSet:
            # No WHERE clause and no extra params in this case.
            where_params = ()
        else:
            statement.append(f"WHERE {where_sql}")
        return " ".join(statement), (*update_params, *where_params)

    def execute_sql(self, result_type: str) -> int:  # ty: ignore[invalid-method-override]
        """
        Execute the update and return the number of rows affected.

        Return 0 when the parent's execute_sql() gives back no cursor;
        otherwise read the cursor's rowcount and close the cursor.
        """
        cursor = super().execute_sql(result_type)
        if not cursor:
            return 0
        try:
            return cursor.rowcount
        finally:
            cursor.close()

    def pre_sql_setup(
        self, with_col_aliases: bool = False
    ) -> tuple[list[Any], list[Any], list[SqlWithParams]] | None:
        """
        If the update depends on other tables (JOINs in the WHERE clause),
        rewrite the query so the current table is filtered by
        `id IN (subquery)`. With a single active table nothing is changed.
        """
        saved_refcounts = self.query.alias_refcount.copy()
        # Make sure the base table is part of the query.
        self.query.get_initial_alias()
        if self.query.count_active_tables() == 1:
            return None

        # The sub-select is a plain Query chained from this one, stripped of
        # select_related, ordering and extra, and selecting only the id.
        id_query = self.query.chain(klass=Query)
        id_query.select_related = False
        id_query.clear_ordering(force=True)
        id_query.extra = {}
        id_query.select = ()
        id_query.add_fields(["id"])
        super().pre_sql_setup()

        # Reset the where clause and drop the tables that are no longer
        # needed here (they live in the sub-select now).
        self.query.clear_where()
        self.query.add_filter("id__in", id_query)
        self.query.reset_refcounts(saved_refcounts)
1781
1782
class SQLAggregateCompiler(SQLCompiler):
    """Compiler that selects a query's annotations from its inner query."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> SqlWithParams:
        """
        Build the SQL for this query. Return the SQL string and a tuple of
        parameters.

        The statement has the shape
        `SELECT <annotations> FROM (<inner query>) subquery`; the annotation
        params come first, followed by the inner query's params.
        `self.col_count` is set to the number of selected annotations.
        """
        annotations = self.query.annotation_select
        select_parts: list[str] = []
        select_params: list[Any] = []
        for annotation in annotations.values():
            compiled_sql, compiled_params = self.compile(annotation)
            part_sql, part_params = annotation.select_format(
                self, compiled_sql, compiled_params
            )
            select_parts.append(part_sql)
            select_params.extend(part_params)
        self.col_count = len(annotations)
        select_list = ", ".join(select_parts)

        inner_query = cast("AggregateQuery", self.query).inner_query
        inner_compiler = inner_query.get_compiler(elide_empty=self.elide_empty)
        inner_sql, inner_params = inner_compiler.as_sql(with_col_aliases=True)

        statement = f"SELECT {select_list} FROM ({inner_sql}) subquery"
        return statement, tuple(select_params) + inner_params