Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Useful auxiliary data structures for query construction. Not useful outside
  3the SQL domain.
  4"""
  5
  6from plain.exceptions import FullResultSet
  7from plain.models.sql.constants import INNER, LOUTER
  8
  9
 10class MultiJoin(Exception):
 11    """
 12    Used by join construction code to indicate the point at which a
 13    multi-valued join was attempted (if the caller wants to treat that
 14    exceptionally).
 15    """
 16
 17    def __init__(self, names_pos, path_with_names):
 18        self.level = names_pos
 19        # The path travelled, this includes the path to the multijoin.
 20        self.names_with_path = path_with_names
 21
 22
 23class Empty:
 24    pass
 25
 26
 27class Join:
 28    """
 29    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
 30    FROM entry. For example, the SQL generated could be
 31        LEFT OUTER JOIN "sometable" T1
 32        ON ("othertable"."sometable_id" = "sometable"."id")
 33
 34    This class is primarily used in Query.alias_map. All entries in alias_map
 35    must be Join compatible by providing the following attributes and methods:
 36        - table_name (string)
 37        - table_alias (possible alias for the table, can be None)
 38        - join_type (can be None for those entries that aren't joined from
 39          anything)
 40        - parent_alias (which table is this join's parent, can be None similarly
 41          to join_type)
 42        - as_sql()
 43        - relabeled_clone()
 44    """
 45
 46    def __init__(
 47        self,
 48        table_name,
 49        parent_alias,
 50        table_alias,
 51        join_type,
 52        join_field,
 53        nullable,
 54        filtered_relation=None,
 55    ):
 56        # Join table
 57        self.table_name = table_name
 58        self.parent_alias = parent_alias
 59        # Note: table_alias is not necessarily known at instantiation time.
 60        self.table_alias = table_alias
 61        # LOUTER or INNER
 62        self.join_type = join_type
 63        # A list of 2-tuples to use in the ON clause of the JOIN.
 64        # Each 2-tuple will create one join condition in the ON clause.
 65        self.join_cols = join_field.get_joining_columns()
 66        # Along which field (or ForeignObjectRel in the reverse join case)
 67        self.join_field = join_field
 68        # Is this join nullabled?
 69        self.nullable = nullable
 70        self.filtered_relation = filtered_relation
 71
 72    def as_sql(self, compiler, connection):
 73        """
 74        Generate the full
 75           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
 76        clause for this join.
 77        """
 78        join_conditions = []
 79        params = []
 80        qn = compiler.quote_name_unless_alias
 81        qn2 = connection.ops.quote_name
 82
 83        # Add a join condition for each pair of joining columns.
 84        for lhs_col, rhs_col in self.join_cols:
 85            join_conditions.append(
 86                f"{qn(self.parent_alias)}.{qn2(lhs_col)} = {qn(self.table_alias)}.{qn2(rhs_col)}"
 87            )
 88
 89        # Add a single condition inside parentheses for whatever
 90        # get_extra_restriction() returns.
 91        extra_cond = self.join_field.get_extra_restriction(
 92            self.table_alias, self.parent_alias
 93        )
 94        if extra_cond:
 95            extra_sql, extra_params = compiler.compile(extra_cond)
 96            join_conditions.append(f"({extra_sql})")
 97            params.extend(extra_params)
 98        if self.filtered_relation:
 99            try:
100                extra_sql, extra_params = compiler.compile(self.filtered_relation)
101            except FullResultSet:
102                pass
103            else:
104                join_conditions.append(f"({extra_sql})")
105                params.extend(extra_params)
106        if not join_conditions:
107            # This might be a rel on the other end of an actual declared field.
108            declared_field = getattr(self.join_field, "field", self.join_field)
109            raise ValueError(
110                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
111                "joining columns or extra restrictions."
112            )
113        on_clause_sql = " AND ".join(join_conditions)
114        alias_str = (
115            "" if self.table_alias == self.table_name else (f" {self.table_alias}")
116        )
117        sql = f"{self.join_type} {qn(self.table_name)}{alias_str} ON ({on_clause_sql})"
118        return sql, params
119
120    def relabeled_clone(self, change_map):
121        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
122        new_table_alias = change_map.get(self.table_alias, self.table_alias)
123        if self.filtered_relation is not None:
124            filtered_relation = self.filtered_relation.clone()
125            filtered_relation.path = [
126                change_map.get(p, p) for p in self.filtered_relation.path
127            ]
128        else:
129            filtered_relation = None
130        return self.__class__(
131            self.table_name,
132            new_parent_alias,
133            new_table_alias,
134            self.join_type,
135            self.join_field,
136            self.nullable,
137            filtered_relation=filtered_relation,
138        )
139
140    @property
141    def identity(self):
142        return (
143            self.__class__,
144            self.table_name,
145            self.parent_alias,
146            self.join_field,
147            self.filtered_relation,
148        )
149
150    def __eq__(self, other):
151        if not isinstance(other, Join):
152            return NotImplemented
153        return self.identity == other.identity
154
155    def __hash__(self):
156        return hash(self.identity)
157
158    def equals(self, other):
159        # Ignore filtered_relation in equality check.
160        return self.identity[:-1] == other.identity[:-1]
161
162    def demote(self):
163        new = self.relabeled_clone({})
164        new.join_type = INNER
165        return new
166
167    def promote(self):
168        new = self.relabeled_clone({})
169        new.join_type = LOUTER
170        return new
171
172
class BaseTable:
    """
    A table referenced directly in the FROM clause rather than joined in,
    e.g. the "foo" in

        SELECT * FROM "foo" WHERE somecond

    It offers the same attribute/method surface as Join, so it can stand in
    wherever a Join-compatible entry is expected. The join-only attributes
    (join_type, parent_alias, filtered_relation) are fixed at None.
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for this table reference.

        The SQL is the quoted table name, followed by the alias when it
        differs from the table name. params is always an empty list.
        """
        has_distinct_alias = self.table_alias != self.table_name
        quoted_table = compiler.quote_name_unless_alias(self.table_name)
        if has_distinct_alias:
            return quoted_table + f" {self.table_alias}", []
        return quoted_table + "", []

    def relabeled_clone(self, change_map):
        """Return a copy whose alias is renamed via change_map (kept if absent)."""
        renamed_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, renamed_alias)

    @property
    def identity(self):
        """Tuple of (class, table name, alias) used for equality and hashing."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare full identities (nothing is ignored for base tables)."""
        return self.identity == other.identity
213    def equals(self, other):
214        return self.identity == other.identity