1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5
6from __future__ import annotations
7
8from typing import TYPE_CHECKING, Any
9
10from plain.models.exceptions import FullResultSet
11from plain.models.sql.constants import INNER, LOUTER
12
13if TYPE_CHECKING:
14 from plain.models.backends.base.base import BaseDatabaseWrapper
15 from plain.models.fields import Field
16 from plain.models.sql.compiler import SQLCompiler
17
18
class MultiJoin(Exception):
    """
    Raised by the join construction code at the point where a multi-valued
    join was attempted, so that callers who want to treat that case
    exceptionally can catch it.
    """

    def __init__(self, names_pos: int, path_with_names: list[str]) -> None:
        # names_with_path is the path travelled (it includes the path to the
        # multijoin itself); level is the names position that was passed in.
        self.names_with_path = path_with_names
        self.level = names_pos
30
31
class Empty:
    """Bare placeholder class: it defines no attributes or behavior of its own."""
34
35
class Join:
    """
    Describe one JOIN clause in the FROM entry of a query, as generated by
    sql.Query and sql.SQLCompiler. The resulting SQL looks like, for example:
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Instances mostly live in Query.alias_map. Everything stored in alias_map
    has to be Join compatible, i.e. expose these attributes and methods:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for entries that aren't joined from
          anything)
        - parent_alias (the table this join hangs off; can be None, just
          like join_type)
        - as_sql()
        - relabeled_clone()
    """

    def __init__(
        self,
        table_name: str,
        parent_alias: str,
        table_alias: str,
        join_type: str,
        join_field: Field,
        nullable: bool,
        filtered_relation: Any = None,
    ) -> None:
        # The table being joined and the alias of the table it is joined from.
        self.table_name = table_name
        self.parent_alias = parent_alias
        # The alias may not be known yet when the Join is instantiated.
        self.table_alias = table_alias
        # Either LOUTER or INNER.
        self.join_type = join_type
        # 2-tuples of columns; each one becomes a condition of the ON clause.
        self.join_cols = join_field.get_joining_columns()  # type: ignore[attr-defined]
        # The field (or ForeignObjectRel for a reverse join) being followed.
        self.join_field = join_field
        # Whether this join is nullable.
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Build the complete clause for this join, e.g.
            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol
        and return it together with its parameters as (sql, params).

        Raise ValueError when neither the joining columns nor the extra
        restrictions contribute a condition to the ON clause.
        """
        quote_alias = compiler.quote_name_unless_alias
        quote_column = connection.ops.quote_name
        conditions = []
        params = []

        # One equality condition per pair of joining columns.
        for parent_col, joined_col in self.join_cols:
            lhs = f"{quote_alias(self.parent_alias)}.{quote_column(parent_col)}"
            rhs = f"{quote_alias(self.table_alias)}.{quote_column(joined_col)}"
            conditions.append(f"{lhs} = {rhs}")

        # Whatever get_extra_restriction() returns is compiled into a single
        # parenthesized condition.
        restriction = self.join_field.get_extra_restriction(
            self.table_alias, self.parent_alias
        )
        if restriction:
            restriction_sql, restriction_params = compiler.compile(restriction)
            conditions.append(f"({restriction_sql})")
            params.extend(restriction_params)

        if self.filtered_relation:
            try:
                filter_sql, filter_params = compiler.compile(self.filtered_relation)
            except FullResultSet:
                # The filtered relation adds no condition in this case.
                pass
            else:
                conditions.append(f"({filter_sql})")
                params.extend(filter_params)

        if not conditions:
            # The join field might be a rel on the other end of an actual
            # declared field; report the declared one when it is available.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
                "joining columns or extra restrictions."
            )

        # The alias is only spelled out when it differs from the table name.
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        on_clause = " AND ".join(conditions)
        return (
            f"{self.join_type} {quote_alias(self.table_name)}{alias_suffix} ON ({on_clause})",
            params,
        )

    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
        """
        Return a new instance of the same class whose parent alias, table
        alias and filtered relation path entries are replaced according to
        change_map. Values missing from change_map are kept as they are.
        """

        def relabel(alias: Any) -> Any:
            return change_map.get(alias, alias)

        parent_alias = relabel(self.parent_alias)
        table_alias = relabel(self.table_alias)
        relation = None
        if self.filtered_relation is not None:
            # Relabel a clone so the relation held by this join is untouched.
            relation = self.filtered_relation.clone()
            relation.path = [relabel(entry) for entry in self.filtered_relation.path]
        return self.__class__(
            self.table_name,
            parent_alias,
            table_alias,
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=relation,
        )

    @property
    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
        """The tuple of values that __eq__() and __hash__() are based on."""
        cls = self.__class__
        return (
            cls,
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: Join) -> bool:
        """Compare identities while leaving filtered_relation out."""
        # filtered_relation is the last element of the identity tuple.
        own_key = self.identity[:-1]
        other_key = other.identity[:-1]
        return own_key == other_key

    def demote(self) -> Join:
        """Return a copy of this join whose join_type is INNER."""
        clone = self.relabeled_clone({})
        clone.join_type = INNER
        return clone

    def promote(self) -> Join:
        """Return a copy of this join whose join_type is LOUTER."""
        clone = self.relabeled_clone({})
        clone.join_type = LOUTER
        return clone
182
183
class BaseTable:
    """
    Represent a base table reference in the FROM clause. For instance, this
    class could generate the "foo" part of
        SELECT * FROM "foo" WHERE somecond
    """

    # A base table has no join type, no parent alias and no filtered relation.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return the quoted table name, followed by the alias when the alias
        differs from the table name, together with an empty params list.
        """
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        quoted_table = compiler.quote_name_unless_alias(self.table_name)
        return quoted_table + alias_suffix, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """
        Return a new instance of the same class whose alias is replaced
        according to change_map; an alias missing from it is kept as is.
        """
        relabeled_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, relabeled_alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """The tuple of values that __eq__() and __hash__() are based on."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Compare the full identities of both table references."""
        own_key = self.identity
        other_key = other.identity
        return own_key == other_key