Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Useful auxiliary data structures for query construction. Not useful outside
  3the SQL domain.
  4"""
  5
  6from __future__ import annotations
  7
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models.exceptions import FullResultSet
 11from plain.models.sql.constants import INNER, LOUTER
 12
 13if TYPE_CHECKING:
 14    from plain.models.backends.base.base import BaseDatabaseWrapper
 15    from plain.models.fields.related import ForeignKeyField
 16    from plain.models.fields.reverse_related import ForeignObjectRel
 17    from plain.models.sql.compiler import SQLCompiler
 18
 19
 20class MultiJoin(Exception):
 21    """
 22    Used by join construction code to indicate the point at which a
 23    multi-valued join was attempted (if the caller wants to treat that
 24    exceptionally).
 25    """
 26
 27    def __init__(
 28        self, names_pos: int, path_with_names: list[tuple[str, list[Any]]]
 29    ) -> None:
 30        self.level = names_pos
 31        # The path travelled, this includes the path to the multijoin.
 32        self.names_with_path = path_with_names
 33
 34
 35class Empty:
 36    pass
 37
 38
 39class Join:
 40    """
 41    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
 42    FROM entry. For example, the SQL generated could be
 43        LEFT OUTER JOIN "sometable" T1
 44        ON ("othertable"."sometable_id" = "sometable"."id")
 45
 46    This class is primarily used in Query.alias_map. All entries in alias_map
 47    must be Join compatible by providing the following attributes and methods:
 48        - table_name (string)
 49        - table_alias (possible alias for the table, can be None)
 50        - join_type (can be None for those entries that aren't joined from
 51          anything)
 52        - parent_alias (which table is this join's parent, can be None similarly
 53          to join_type)
 54        - as_sql()
 55        - relabeled_clone()
 56    """
 57
 58    def __init__(
 59        self,
 60        table_name: str,
 61        parent_alias: str,
 62        table_alias: str,
 63        join_type: str,
 64        join_field: ForeignKeyField | ForeignObjectRel,
 65        nullable: bool,
 66        filtered_relation: Any = None,
 67    ) -> None:
 68        # Join table
 69        self.table_name = table_name
 70        self.parent_alias = parent_alias
 71        # Note: table_alias is not necessarily known at instantiation time.
 72        self.table_alias = table_alias
 73        # LOUTER or INNER
 74        self.join_type = join_type
 75        # A list of 2-tuples to use in the ON clause of the JOIN.
 76        # Each 2-tuple will create one join condition in the ON clause.
 77        self.join_cols = join_field.get_joining_columns()
 78        # Along which field (or ForeignObjectRel in the reverse join case)
 79        self.join_field = join_field
 80        # Is this join nullabled?
 81        self.nullable = nullable
 82        self.filtered_relation = filtered_relation
 83
 84    def as_sql(
 85        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 86    ) -> tuple[str, list[Any]]:
 87        """
 88        Generate the full
 89           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
 90        clause for this join.
 91        """
 92        join_conditions = []
 93        params = []
 94        qn = compiler.quote_name_unless_alias
 95        qn2 = connection.ops.quote_name
 96
 97        # Add a join condition for each pair of joining columns.
 98        for lhs_col, rhs_col in self.join_cols:
 99            join_conditions.append(
100                f"{qn(self.parent_alias)}.{qn2(lhs_col)} = {qn(self.table_alias)}.{qn2(rhs_col)}"
101            )
102
103        if self.filtered_relation:
104            try:
105                extra_sql, extra_params = compiler.compile(self.filtered_relation)
106            except FullResultSet:
107                pass
108            else:
109                join_conditions.append(f"({extra_sql})")
110                params.extend(extra_params)
111        if not join_conditions:
112            # This might be a rel on the other end of an actual declared field.
113            declared_field = getattr(self.join_field, "field", self.join_field)
114            raise ValueError(
115                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
116                "joining columns or extra restrictions."
117            )
118        on_clause_sql = " AND ".join(join_conditions)
119        alias_str = (
120            "" if self.table_alias == self.table_name else (f" {self.table_alias}")
121        )
122        sql = f"{self.join_type} {qn(self.table_name)}{alias_str} ON ({on_clause_sql})"
123        return sql, params
124
125    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
126        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
127        new_table_alias = change_map.get(self.table_alias, self.table_alias)
128        if self.filtered_relation is not None:
129            filtered_relation = self.filtered_relation.clone()
130            filtered_relation.path = [
131                change_map.get(p, p) for p in self.filtered_relation.path
132            ]
133        else:
134            filtered_relation = None
135        return self.__class__(
136            self.table_name,
137            new_parent_alias,
138            new_table_alias,
139            self.join_type,
140            self.join_field,
141            self.nullable,
142            filtered_relation=filtered_relation,
143        )
144
145    @property
146    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
147        return (
148            self.__class__,
149            self.table_name,
150            self.parent_alias,
151            self.join_field,
152            self.filtered_relation,
153        )
154
155    def __eq__(self, other: object) -> bool:
156        if not isinstance(other, Join):
157            return NotImplemented
158        return self.identity == other.identity
159
160    def __hash__(self) -> int:
161        return hash(self.identity)
162
163    def equals(self, other: Join) -> bool:
164        # Ignore filtered_relation in equality check.
165        return self.identity[:-1] == other.identity[:-1]
166
167    def demote(self) -> Join:
168        new = self.relabeled_clone({})
169        new.join_type = INNER
170        return new
171
172    def promote(self) -> Join:
173        new = self.relabeled_clone({})
174        new.join_type = LOUTER
175        return new
176
177
class BaseTable:
    """
    A base (non-joined) table reference in the FROM clause.

    For example, the ``"foo"`` in
        SELECT * FROM "foo" WHERE somecond
    could be rendered by an instance of this class.

    join_type, parent_alias and filtered_relation are fixed to None at class
    level, matching the Join-compatible attribute set expected of alias_map
    entries that aren't joined from anything.
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, params)`` for this FROM entry. The alias is appended
        only when it differs from the table name; params is always empty.
        """
        rendered = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            rendered = f"{rendered} {self.table_alias}"
        return rendered, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """Return a copy whose alias is remapped via ``change_map`` (if present)."""
        remapped_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, remapped_alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """Equality/hash key: (class, table_name, table_alias)."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Compare identities directly; ``other`` is assumed to be a BaseTable."""
        return self.identity == other.identity