Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Useful auxiliary data structures for query construction. Not useful outside
  3the SQL domain.
  4"""
  5from plain.exceptions import FullResultSet
  6from plain.models.sql.constants import INNER, LOUTER
  7
  8
  9class MultiJoin(Exception):
 10    """
 11    Used by join construction code to indicate the point at which a
 12    multi-valued join was attempted (if the caller wants to treat that
 13    exceptionally).
 14    """
 15
 16    def __init__(self, names_pos, path_with_names):
 17        self.level = names_pos
 18        # The path travelled, this includes the path to the multijoin.
 19        self.names_with_path = path_with_names
 20
 21
 22class Empty:
 23    pass
 24
 25
 26class Join:
 27    """
 28    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
 29    FROM entry. For example, the SQL generated could be
 30        LEFT OUTER JOIN "sometable" T1
 31        ON ("othertable"."sometable_id" = "sometable"."id")
 32
 33    This class is primarily used in Query.alias_map. All entries in alias_map
 34    must be Join compatible by providing the following attributes and methods:
 35        - table_name (string)
 36        - table_alias (possible alias for the table, can be None)
 37        - join_type (can be None for those entries that aren't joined from
 38          anything)
 39        - parent_alias (which table is this join's parent, can be None similarly
 40          to join_type)
 41        - as_sql()
 42        - relabeled_clone()
 43    """
 44
 45    def __init__(
 46        self,
 47        table_name,
 48        parent_alias,
 49        table_alias,
 50        join_type,
 51        join_field,
 52        nullable,
 53        filtered_relation=None,
 54    ):
 55        # Join table
 56        self.table_name = table_name
 57        self.parent_alias = parent_alias
 58        # Note: table_alias is not necessarily known at instantiation time.
 59        self.table_alias = table_alias
 60        # LOUTER or INNER
 61        self.join_type = join_type
 62        # A list of 2-tuples to use in the ON clause of the JOIN.
 63        # Each 2-tuple will create one join condition in the ON clause.
 64        self.join_cols = join_field.get_joining_columns()
 65        # Along which field (or ForeignObjectRel in the reverse join case)
 66        self.join_field = join_field
 67        # Is this join nullabled?
 68        self.nullable = nullable
 69        self.filtered_relation = filtered_relation
 70
 71    def as_sql(self, compiler, connection):
 72        """
 73        Generate the full
 74           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
 75        clause for this join.
 76        """
 77        join_conditions = []
 78        params = []
 79        qn = compiler.quote_name_unless_alias
 80        qn2 = connection.ops.quote_name
 81
 82        # Add a join condition for each pair of joining columns.
 83        for lhs_col, rhs_col in self.join_cols:
 84            join_conditions.append(
 85                "{}.{} = {}.{}".format(
 86                    qn(self.parent_alias),
 87                    qn2(lhs_col),
 88                    qn(self.table_alias),
 89                    qn2(rhs_col),
 90                )
 91            )
 92
 93        # Add a single condition inside parentheses for whatever
 94        # get_extra_restriction() returns.
 95        extra_cond = self.join_field.get_extra_restriction(
 96            self.table_alias, self.parent_alias
 97        )
 98        if extra_cond:
 99            extra_sql, extra_params = compiler.compile(extra_cond)
100            join_conditions.append("(%s)" % extra_sql)
101            params.extend(extra_params)
102        if self.filtered_relation:
103            try:
104                extra_sql, extra_params = compiler.compile(self.filtered_relation)
105            except FullResultSet:
106                pass
107            else:
108                join_conditions.append("(%s)" % extra_sql)
109                params.extend(extra_params)
110        if not join_conditions:
111            # This might be a rel on the other end of an actual declared field.
112            declared_field = getattr(self.join_field, "field", self.join_field)
113            raise ValueError(
114                "Join generated an empty ON clause. %s did not yield either "
115                "joining columns or extra restrictions." % declared_field.__class__
116            )
117        on_clause_sql = " AND ".join(join_conditions)
118        alias_str = (
119            "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
120        )
121        sql = f"{self.join_type} {qn(self.table_name)}{alias_str} ON ({on_clause_sql})"
122        return sql, params
123
124    def relabeled_clone(self, change_map):
125        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
126        new_table_alias = change_map.get(self.table_alias, self.table_alias)
127        if self.filtered_relation is not None:
128            filtered_relation = self.filtered_relation.clone()
129            filtered_relation.path = [
130                change_map.get(p, p) for p in self.filtered_relation.path
131            ]
132        else:
133            filtered_relation = None
134        return self.__class__(
135            self.table_name,
136            new_parent_alias,
137            new_table_alias,
138            self.join_type,
139            self.join_field,
140            self.nullable,
141            filtered_relation=filtered_relation,
142        )
143
144    @property
145    def identity(self):
146        return (
147            self.__class__,
148            self.table_name,
149            self.parent_alias,
150            self.join_field,
151            self.filtered_relation,
152        )
153
154    def __eq__(self, other):
155        if not isinstance(other, Join):
156            return NotImplemented
157        return self.identity == other.identity
158
159    def __hash__(self):
160        return hash(self.identity)
161
162    def equals(self, other):
163        # Ignore filtered_relation in equality check.
164        return self.identity[:-1] == other.identity[:-1]
165
166    def demote(self):
167        new = self.relabeled_clone({})
168        new.join_type = INNER
169        return new
170
171    def promote(self):
172        new = self.relabeled_clone({})
173        new.join_type = LOUTER
174        return new
175
176
class BaseTable:
    """
    A plain table reference in the FROM clause, e.g. the ``"foo"`` in

        SELECT * FROM "foo" WHERE somecond

    It exposes the same attributes and methods Join documents for alias_map
    entries (table_name, table_alias, join_type, parent_alias, as_sql(),
    relabeled_clone()); join_type, parent_alias and filtered_relation are
    class-level None defaults because a base table is not joined from
    anything.
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_alias = alias
        self.table_name = table_name

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)``: the quoted table name, followed by the
        alias when it differs from the table name. There are never params.
        """
        sql = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            sql += " %s" % self.table_alias
        return sql, []

    def relabeled_clone(self, change_map):
        """Return a copy whose alias is renamed via change_map (kept if absent)."""
        new_alias = change_map.get(self.table_alias, self.table_alias)
        return type(self)(self.table_name, new_alias)

    @property
    def identity(self):
        """Tuple of the attributes that define equality and hashing."""
        return type(self), self.table_name, self.table_alias

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare full identities (a base table has nothing to ignore)."""
        return self.identity == other.identity