1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5
6from plain.exceptions import FullResultSet
7from plain.models.sql.constants import INNER, LOUTER
8
9
class MultiJoin(Exception):
    """
    Raised by the join construction code at the point where a multi-valued
    join is attempted, so that callers who want to treat that situation
    specially can catch it.

    Attributes:
        level: the ``names_pos`` value passed to the constructor.
        names_with_path: the path travelled so far, including the path to
            the multijoin itself.
    """

    def __init__(self, names_pos, path_with_names):
        # Expose both constructor arguments under the attribute names that
        # handlers of this exception read.
        self.level, self.names_with_path = names_pos, path_with_names
21
22
class Empty:
    """Bare class that defines no attributes or behaviour of its own."""
25
26
class Join:
    """
    Describe a single JOIN entry of a FROM clause, as used by sql.Query and
    sql.SQLCompiler. The SQL rendered for an instance looks like
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Instances mainly live in Query.alias_map. Anything stored in alias_map
    has to be Join compatible, i.e. offer these attributes and methods:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for entries that aren't joined from
          anything)
        - parent_alias (the table this join hangs off, can be None in the
          same cases as join_type)
        - as_sql()
        - relabeled_clone()
    """

    def __init__(
        self,
        table_name,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        """
        Store the description of the join.

        table_name: name of the table being joined in.
        parent_alias: alias of the table this join is attached to.
        table_alias: alias of the joined table; it is not necessarily known
            yet when the instance is created.
        join_type: LOUTER or INNER.
        join_field: the field (or ForeignObjectRel for a reverse join) the
            join travels along; it supplies the joining columns.
        nullable: whether this join is nullable.
        filtered_relation: optional extra condition object compiled into
            the ON clause.
        """
        self.table_name = table_name
        self.parent_alias = parent_alias
        self.table_alias = table_alias
        self.join_type = join_type
        # Sequence of (lhs_col, rhs_col) 2-tuples; every pair becomes one
        # equality condition in the ON clause.
        self.join_cols = join_field.get_joining_columns()
        self.join_field = join_field
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(self, compiler, connection):
        """
        Render this join as
            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol
        and return it as an ``(sql, params)`` pair.

        Raises ValueError when neither the joining columns nor the extra
        restrictions produce any condition for the ON clause.
        """
        quote_alias = compiler.quote_name_unless_alias
        quote_column = connection.ops.quote_name

        # One equality condition per pair of joining columns.
        conditions = [
            f"{quote_alias(self.parent_alias)}.{quote_column(lhs)} = {quote_alias(self.table_alias)}.{quote_column(rhs)}"
            for lhs, rhs in self.join_cols
        ]
        sql_params = []

        # Whatever get_extra_restriction() hands back is compiled and added
        # as one parenthesised condition.
        restriction = self.join_field.get_extra_restriction(
            self.table_alias, self.parent_alias
        )
        if restriction:
            restriction_sql, restriction_params = compiler.compile(restriction)
            conditions.append(f"({restriction_sql})")
            sql_params.extend(restriction_params)

        if self.filtered_relation:
            try:
                relation_sql, relation_params = compiler.compile(
                    self.filtered_relation
                )
            except FullResultSet:
                # The filtered relation contributes no condition in this case.
                pass
            else:
                conditions.append(f"({relation_sql})")
                sql_params.extend(relation_params)

        if not conditions:
            # join_field may be the rel on the far side of the field that was
            # actually declared; report the declared field when available.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
                "joining columns or extra restrictions."
            )

        on_clause = " AND ".join(conditions)
        # The alias is only spelled out when it differs from the table name.
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        return (
            f"{self.join_type} {quote_alias(self.table_name)}{alias_suffix} ON ({on_clause})",
            sql_params,
        )

    def relabeled_clone(self, change_map):
        """
        Return a new instance of the same class whose parent alias, table
        alias and filtered relation path are translated through
        ``change_map``; names missing from the map are kept unchanged.
        """
        remap = change_map.get
        parent_alias = remap(self.parent_alias, self.parent_alias)
        table_alias = remap(self.table_alias, self.table_alias)

        relation = None
        if self.filtered_relation is not None:
            # Work on a clone so the relation held by this join keeps its
            # original path.
            relation = self.filtered_relation.clone()
            relation.path = [
                remap(step, step) for step in self.filtered_relation.path
            ]

        return self.__class__(
            self.table_name,
            parent_alias,
            table_alias,
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=relation,
        )

    @property
    def identity(self):
        """Tuple of the values that equality and hashing are based on."""
        cls = self.__class__
        return (
            cls,
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other):
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare identities while leaving filtered_relation out."""
        # filtered_relation is the last element of this join's identity.
        own_key = self.identity[:-1]
        other_key = other.identity[:-1]
        return own_key == other_key

    def _clone_with_join_type(self, join_type):
        # Copy this join via an empty relabelling, then switch the copy's
        # join type.
        clone = self.relabeled_clone({})
        clone.join_type = join_type
        return clone

    def demote(self):
        """Return a copy of this join whose join_type is INNER."""
        return self._clone_with_join_type(INNER)

    def promote(self):
        """Return a copy of this join whose join_type is LOUTER."""
        return self._clone_with_join_type(LOUTER)
171
172
class BaseTable:
    """
    Represent a base table reference in a FROM clause. For instance, this
    class could produce the "foo" part of
        SELECT * FROM "foo" WHERE somecond
    """

    # A base table is not joined from anything, so these are fixed to None.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for this table reference; params is always
        an empty list. The alias is appended only when it differs from the
        table name.
        """
        needs_alias = self.table_alias != self.table_name
        suffix = f" {self.table_alias}" if needs_alias else ""
        return compiler.quote_name_unless_alias(self.table_name) + suffix, []

    def relabeled_clone(self, change_map):
        """
        Return a new instance of the same class whose alias is looked up in
        ``change_map``; an alias missing from the map is kept unchanged.
        """
        alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, alias)

    @property
    def identity(self):
        """Tuple of the values that equality and hashing are based on."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Return whether ``other`` has the same identity as this table."""
        own_key = self.identity
        return own_key == other.identity