1"""
  2Useful auxiliary data structures for query construction. Not useful outside
  3the SQL domain.
  4"""
  5
  6from __future__ import annotations
  7
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models.exceptions import FullResultSet
 11from plain.models.postgres.sql import quote_name
 12from plain.models.sql.constants import INNER, LOUTER
 13
 14if TYPE_CHECKING:
 15    from plain.models.fields.related import ForeignKeyField
 16    from plain.models.fields.reverse_related import ForeignObjectRel
 17    from plain.models.postgres.connection import DatabaseConnection
 18    from plain.models.sql.compiler import SQLCompiler
 19
 20
 21class MultiJoin(Exception):
 22    """
 23    Used by join construction code to indicate the point at which a
 24    multi-valued join was attempted (if the caller wants to treat that
 25    exceptionally).
 26    """
 27
 28    def __init__(
 29        self, names_pos: int, path_with_names: list[tuple[str, list[Any]]]
 30    ) -> None:
 31        self.level = names_pos
 32        # The path travelled, this includes the path to the multijoin.
 33        self.names_with_path = path_with_names
 34
 35
 36class Empty:
 37    pass
 38
 39
 40class Join:
 41    """
 42    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
 43    FROM entry. For example, the SQL generated could be
 44        LEFT OUTER JOIN "sometable" T1
 45        ON ("othertable"."sometable_id" = "sometable"."id")
 46
 47    This class is primarily used in Query.alias_map. All entries in alias_map
 48    must be Join compatible by providing the following attributes and methods:
 49        - table_name (string)
 50        - table_alias (possible alias for the table, can be None)
 51        - join_type (can be None for those entries that aren't joined from
 52          anything)
 53        - parent_alias (which table is this join's parent, can be None similarly
 54          to join_type)
 55        - as_sql()
 56        - relabeled_clone()
 57    """
 58
 59    def __init__(
 60        self,
 61        table_name: str,
 62        parent_alias: str,
 63        table_alias: str,
 64        join_type: str,
 65        join_field: ForeignKeyField | ForeignObjectRel,
 66        nullable: bool,
 67        filtered_relation: Any = None,
 68    ) -> None:
 69        # Join table
 70        self.table_name = table_name
 71        self.parent_alias = parent_alias
 72        # Note: table_alias is not necessarily known at instantiation time.
 73        self.table_alias = table_alias
 74        # LOUTER or INNER
 75        self.join_type = join_type
 76        # A list of 2-tuples to use in the ON clause of the JOIN.
 77        # Each 2-tuple will create one join condition in the ON clause.
 78        self.join_cols = join_field.get_joining_columns()
 79        # Along which field (or ForeignObjectRel in the reverse join case)
 80        self.join_field = join_field
 81        # Is this join nullabled?
 82        self.nullable = nullable
 83        self.filtered_relation = filtered_relation
 84
 85    def as_sql(
 86        self, compiler: SQLCompiler, connection: DatabaseConnection
 87    ) -> tuple[str, list[Any]]:
 88        """
 89        Generate the full
 90           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
 91        clause for this join.
 92        """
 93        join_conditions = []
 94        params = []
 95        qn = compiler.quote_name_unless_alias
 96        qn2 = quote_name
 97
 98        # Add a join condition for each pair of joining columns.
 99        for lhs_col, rhs_col in self.join_cols:
100            join_conditions.append(
101                f"{qn(self.parent_alias)}.{qn2(lhs_col)} = {qn(self.table_alias)}.{qn2(rhs_col)}"
102            )
103
104        if self.filtered_relation:
105            try:
106                extra_sql, extra_params = compiler.compile(self.filtered_relation)
107            except FullResultSet:
108                pass
109            else:
110                join_conditions.append(f"({extra_sql})")
111                params.extend(extra_params)
112        if not join_conditions:
113            # This might be a rel on the other end of an actual declared field.
114            declared_field = getattr(self.join_field, "field", self.join_field)
115            raise ValueError(
116                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
117                "joining columns or extra restrictions."
118            )
119        on_clause_sql = " AND ".join(join_conditions)
120        alias_str = (
121            "" if self.table_alias == self.table_name else (f" {self.table_alias}")
122        )
123        sql = f"{self.join_type} {qn(self.table_name)}{alias_str} ON ({on_clause_sql})"
124        return sql, params
125
126    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
127        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
128        new_table_alias = change_map.get(self.table_alias, self.table_alias)
129        if self.filtered_relation is not None:
130            filtered_relation = self.filtered_relation.clone()
131            filtered_relation.path = [
132                change_map.get(p, p) for p in self.filtered_relation.path
133            ]
134        else:
135            filtered_relation = None
136        return self.__class__(
137            self.table_name,
138            new_parent_alias,
139            new_table_alias,
140            self.join_type,
141            self.join_field,
142            self.nullable,
143            filtered_relation=filtered_relation,
144        )
145
146    @property
147    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
148        return (
149            self.__class__,
150            self.table_name,
151            self.parent_alias,
152            self.join_field,
153            self.filtered_relation,
154        )
155
156    def __eq__(self, other: object) -> bool:
157        if not isinstance(other, Join):
158            return NotImplemented
159        return self.identity == other.identity
160
161    def __hash__(self) -> int:
162        return hash(self.identity)
163
164    def equals(self, other: Join) -> bool:
165        # Ignore filtered_relation in equality check.
166        return self.identity[:-1] == other.identity[:-1]
167
168    def demote(self) -> Join:
169        new = self.relabeled_clone({})
170        new.join_type = INNER
171        return new
172
173    def promote(self) -> Join:
174        new = self.relabeled_clone({})
175        new.join_type = LOUTER
176        return new
177
178
class BaseTable:
    """
    A plain table reference in a FROM clause. For example, "foo" in
        SELECT * FROM "foo" WHERE somecond
    could be rendered by this class.

    It provides the same attribute and method names as Join, with join_type,
    parent_alias and filtered_relation fixed to None.
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Return the table name quoted via the compiler, followed by a space and
        the alias when the alias differs from the table name. Params are
        always empty.
        """
        sql = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            sql = f"{sql} {self.table_alias}"
        return sql, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """Return a copy whose alias is renamed through change_map, if mapped."""
        alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """The values that define equality and hashing: class, table, alias."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Compare full identities; same result as ``==`` between BaseTables."""
        return self.identity == other.identity