1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5from plain.exceptions import FullResultSet
6from plain.models.sql.constants import INNER, LOUTER
7
8
class MultiJoin(Exception):
    """
    Signal raised by join construction code at the point where a
    multi-valued join was attempted, for callers that want to handle that
    case specially.

    Attributes:
        level: the ``names_pos`` value given to the constructor.
        names_with_path: the path travelled so far, including the path to
            the multi-valued join.
    """

    def __init__(self, names_pos, path_with_names):
        # Same args tuple that Exception would record on its own.
        super().__init__(names_pos, path_with_names)
        self.level = names_pos
        self.names_with_path = path_with_names
20
21
class Empty:
    """Bare class with no attributes or behavior of its own."""
24
25
class Join:
    """
    A JOIN clause in a query's FROM entry, generated for sql.Query and
    sql.SQLCompiler. For example, the SQL produced could be

        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Joins are mainly stored in Query.alias_map. Every alias_map entry has to
    be Join-compatible, i.e. provide these attributes and methods:
      - table_name (string)
      - table_alias (possible alias for the table, can be None)
      - join_type (can be None for entries that aren't joined from anything)
      - parent_alias (the table this join hangs off; can be None, like
        join_type)
      - as_sql()
      - relabeled_clone()
    """

    def __init__(
        self,
        table_name,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        # Table being joined and the alias of the table it is joined from.
        self.table_name = table_name
        self.parent_alias = parent_alias
        # The alias may still be unknown when the Join is created.
        self.table_alias = table_alias
        # LOUTER or INNER.
        self.join_type = join_type
        # (lhs_col, rhs_col) 2-tuples; each becomes one equality condition
        # in the ON clause.
        self.join_cols = join_field.get_joining_columns()
        # The field followed by this join (a ForeignObjectRel for reverse
        # joins).
        self.join_field = join_field
        # Whether this join is nullable.
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(self, compiler, connection):
        """
        Build the full
            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol
        clause for this join and return it as a (sql, params) pair.

        Raises ValueError if no ON condition could be produced at all.
        """
        quote_alias = compiler.quote_name_unless_alias
        quote_column = connection.ops.quote_name
        params = []

        # One equality condition per pair of joining columns.
        conditions = [
            f"{quote_alias(self.parent_alias)}.{quote_column(lhs)} = "
            f"{quote_alias(self.table_alias)}.{quote_column(rhs)}"
            for lhs, rhs in self.join_cols
        ]

        # Whatever the join field adds as an extra restriction goes in as a
        # single parenthesized condition.
        restriction = self.join_field.get_extra_restriction(
            self.table_alias, self.parent_alias
        )
        if restriction:
            restriction_sql, restriction_params = compiler.compile(restriction)
            conditions.append(f"({restriction_sql})")
            params.extend(restriction_params)

        if self.filtered_relation:
            try:
                relation_sql, relation_params = compiler.compile(
                    self.filtered_relation
                )
            except FullResultSet:
                # No condition is added to the ON clause in this case
                # (presumably the filter restricts nothing).
                pass
            else:
                conditions.append(f"({relation_sql})")
                params.extend(relation_params)

        if not conditions:
            # join_field may be the reverse side of a declared field; report
            # the declared field's class when one is available.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                "Join generated an empty ON clause. %s did not yield either "
                "joining columns or extra restrictions." % declared_field.__class__
            )

        on_clause = " AND ".join(conditions)
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        sql = (
            f"{self.join_type} {quote_alias(self.table_name)}{alias_suffix} "
            f"ON ({on_clause})"
        )
        return sql, params

    def relabeled_clone(self, change_map):
        """
        Return a copy of this join with parent_alias, table_alias and the
        filtered relation's path entries remapped through change_map (aliases
        missing from change_map are kept as-is).
        """
        parent_alias = change_map.get(self.parent_alias, self.parent_alias)
        table_alias = change_map.get(self.table_alias, self.table_alias)
        filtered_relation = None
        if self.filtered_relation is not None:
            filtered_relation = self.filtered_relation.clone()
            filtered_relation.path = [
                change_map.get(step, step) for step in self.filtered_relation.path
            ]
        return type(self)(
            self.table_name,
            parent_alias,
            table_alias,
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=filtered_relation,
        )

    @property
    def identity(self):
        """Tuple used for equality and hashing; filtered_relation is last."""
        return (
            type(self),
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other):
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare identities like __eq__, but ignore filtered_relation."""
        # filtered_relation is the final element of identity, so drop it.
        return self.identity[:-1] == other.identity[:-1]

    def _with_join_type(self, join_type):
        # relabeled_clone with an empty map yields a straight copy.
        copy = self.relabeled_clone({})
        copy.join_type = join_type
        return copy

    def demote(self):
        """Return a copy of this join with join_type set to INNER."""
        return self._with_join_type(INNER)

    def promote(self):
        """Return a copy of this join with join_type set to LOUTER."""
        return self._with_join_type(LOUTER)
175
176
class BaseTable:
    """
    A base table reference in the FROM clause. For example, the "foo" in

        SELECT * FROM "foo" WHERE somecond

    could be generated by this class. It offers the same attribute names as
    Join, but join_type, parent_alias and filtered_relation are always None.
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """Return (sql, params) for this table reference; params is always []."""
        sql = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            # The alias is emitted only when it differs from the table name.
            sql = f"{sql} {self.table_alias}"
        return sql, []

    def relabeled_clone(self, change_map):
        """Return a copy whose alias is remapped through change_map, if present."""
        alias = change_map.get(self.table_alias, self.table_alias)
        return type(self)(self.table_name, alias)

    @property
    def identity(self):
        """Tuple used for equality and hashing."""
        return (type(self), self.table_name, self.table_alias)

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare identities; unlike Join.equals nothing is ignored here."""
        return self.identity == other.identity