Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Useful auxiliary data structures for query construction. Not useful outside
  3the SQL domain.
  4"""
  5
  6from __future__ import annotations
  7
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models.exceptions import FullResultSet
 11from plain.models.sql.constants import INNER, LOUTER
 12
 13if TYPE_CHECKING:
 14    from plain.models.backends.base.base import BaseDatabaseWrapper
 15    from plain.models.fields import Field
 16    from plain.models.sql.compiler import SQLCompiler
 17
 18
 19class MultiJoin(Exception):
 20    """
 21    Used by join construction code to indicate the point at which a
 22    multi-valued join was attempted (if the caller wants to treat that
 23    exceptionally).
 24    """
 25
 26    def __init__(self, names_pos: int, path_with_names: list[str]) -> None:
 27        self.level = names_pos
 28        # The path travelled, this includes the path to the multijoin.
 29        self.names_with_path = path_with_names
 30
 31
 32class Empty:
 33    pass
 34
 35
 36class Join:
 37    """
 38    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
 39    FROM entry. For example, the SQL generated could be
 40        LEFT OUTER JOIN "sometable" T1
 41        ON ("othertable"."sometable_id" = "sometable"."id")
 42
 43    This class is primarily used in Query.alias_map. All entries in alias_map
 44    must be Join compatible by providing the following attributes and methods:
 45        - table_name (string)
 46        - table_alias (possible alias for the table, can be None)
 47        - join_type (can be None for those entries that aren't joined from
 48          anything)
 49        - parent_alias (which table is this join's parent, can be None similarly
 50          to join_type)
 51        - as_sql()
 52        - relabeled_clone()
 53    """
 54
 55    def __init__(
 56        self,
 57        table_name: str,
 58        parent_alias: str,
 59        table_alias: str,
 60        join_type: str,
 61        join_field: Field,
 62        nullable: bool,
 63        filtered_relation: Any = None,
 64    ) -> None:
 65        # Join table
 66        self.table_name = table_name
 67        self.parent_alias = parent_alias
 68        # Note: table_alias is not necessarily known at instantiation time.
 69        self.table_alias = table_alias
 70        # LOUTER or INNER
 71        self.join_type = join_type
 72        # A list of 2-tuples to use in the ON clause of the JOIN.
 73        # Each 2-tuple will create one join condition in the ON clause.
 74        self.join_cols = join_field.get_joining_columns()  # type: ignore[attr-defined]
 75        # Along which field (or ForeignObjectRel in the reverse join case)
 76        self.join_field = join_field
 77        # Is this join nullabled?
 78        self.nullable = nullable
 79        self.filtered_relation = filtered_relation
 80
 81    def as_sql(
 82        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 83    ) -> tuple[str, list[Any]]:
 84        """
 85        Generate the full
 86           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
 87        clause for this join.
 88        """
 89        join_conditions = []
 90        params = []
 91        qn = compiler.quote_name_unless_alias
 92        qn2 = connection.ops.quote_name
 93
 94        # Add a join condition for each pair of joining columns.
 95        for lhs_col, rhs_col in self.join_cols:
 96            join_conditions.append(
 97                f"{qn(self.parent_alias)}.{qn2(lhs_col)} = {qn(self.table_alias)}.{qn2(rhs_col)}"
 98            )
 99
100        # Add a single condition inside parentheses for whatever
101        # get_extra_restriction() returns.
102        extra_cond = self.join_field.get_extra_restriction(
103            self.table_alias, self.parent_alias
104        )
105        if extra_cond:
106            extra_sql, extra_params = compiler.compile(extra_cond)
107            join_conditions.append(f"({extra_sql})")
108            params.extend(extra_params)
109        if self.filtered_relation:
110            try:
111                extra_sql, extra_params = compiler.compile(self.filtered_relation)
112            except FullResultSet:
113                pass
114            else:
115                join_conditions.append(f"({extra_sql})")
116                params.extend(extra_params)
117        if not join_conditions:
118            # This might be a rel on the other end of an actual declared field.
119            declared_field = getattr(self.join_field, "field", self.join_field)
120            raise ValueError(
121                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
122                "joining columns or extra restrictions."
123            )
124        on_clause_sql = " AND ".join(join_conditions)
125        alias_str = (
126            "" if self.table_alias == self.table_name else (f" {self.table_alias}")
127        )
128        sql = f"{self.join_type} {qn(self.table_name)}{alias_str} ON ({on_clause_sql})"
129        return sql, params
130
131    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
132        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
133        new_table_alias = change_map.get(self.table_alias, self.table_alias)
134        if self.filtered_relation is not None:
135            filtered_relation = self.filtered_relation.clone()
136            filtered_relation.path = [
137                change_map.get(p, p) for p in self.filtered_relation.path
138            ]
139        else:
140            filtered_relation = None
141        return self.__class__(
142            self.table_name,
143            new_parent_alias,
144            new_table_alias,
145            self.join_type,
146            self.join_field,
147            self.nullable,
148            filtered_relation=filtered_relation,
149        )
150
151    @property
152    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
153        return (
154            self.__class__,
155            self.table_name,
156            self.parent_alias,
157            self.join_field,
158            self.filtered_relation,
159        )
160
161    def __eq__(self, other: object) -> bool:
162        if not isinstance(other, Join):
163            return NotImplemented
164        return self.identity == other.identity
165
166    def __hash__(self) -> int:
167        return hash(self.identity)
168
169    def equals(self, other: Join) -> bool:
170        # Ignore filtered_relation in equality check.
171        return self.identity[:-1] == other.identity[:-1]
172
173    def demote(self) -> Join:
174        new = self.relabeled_clone({})
175        new.join_type = INNER
176        return new
177
178    def promote(self) -> Join:
179        new = self.relabeled_clone({})
180        new.join_type = LOUTER
181        return new
182
183
class BaseTable:
    """
    A base table reference in the FROM clause. For instance, the "foo" in
        SELECT * FROM "foo" WHERE somecond
    could be produced by this class.

    It offers the same attributes and methods as Join so it can sit in the
    same alias map; join_type, parent_alias and filtered_relation are always
    None because a base table isn't joined from anything.
    """

    # Join-compatible attributes that never apply to a base table.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return the quoted table name, followed by the alias when it differs
        from the table name, and an empty params list.
        """
        table_sql = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            table_sql = table_sql + f" {self.table_alias}"
        return table_sql, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """Return a copy whose alias is renamed through change_map, if listed there."""
        renamed_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, renamed_alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """The tuple behind == and hash(): class, table name and alias."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Compare identities directly; unlike __eq__, other's type is not checked first."""
        return self.identity == other.identity