1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5
6from __future__ import annotations
7
8from typing import TYPE_CHECKING, Any
9
10from plain.models.exceptions import FullResultSet
11from plain.models.sql.constants import INNER, LOUTER
12
13if TYPE_CHECKING:
14 from plain.models.backends.base.base import BaseDatabaseWrapper
15 from plain.models.fields.related import ForeignKeyField
16 from plain.models.fields.reverse_related import ForeignObjectRel
17 from plain.models.sql.compiler import SQLCompiler
18
19
class MultiJoin(Exception):
    """
    Raised by the join construction code at the point where a multi-valued
    join was attempted, so that callers who want to treat that situation
    specially can catch it.

    Attributes:
        level: the ``names_pos`` value given to the constructor.
        names_with_path: the ``path_with_names`` value given to the
            constructor, i.e. the path travelled, including the path to the
            multijoin itself.
    """

    def __init__(
        self, names_pos: int, path_with_names: list[tuple[str, list[Any]]]
    ) -> None:
        # Both arguments are stored as-is (no copy is made).
        self.names_with_path = path_with_names
        self.level = names_pos
33
34
class Empty:
    """
    Bare placeholder class: defines no attributes and no behaviour of its own.

    NOTE(review): presumably used elsewhere as a blank object to be populated
    after creation - confirm against the callers.
    """
37
38
class Join:
    """
    Describe one JOIN entry of a query's FROM clause.

    sql.Query and sql.SQLCompiler use this class to generate JOIN clauses,
    for example
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Instances are primarily stored in Query.alias_map. Every entry of
    alias_map must be Join compatible, meaning it provides:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for entries that aren't joined from
          anything)
        - parent_alias (the table this join is attached to, can be None in
          the same cases as join_type)
        - as_sql()
        - relabeled_clone()
    """

    def __init__(
        self,
        table_name: str,
        parent_alias: str,
        table_alias: str,
        join_type: str,
        join_field: ForeignKeyField | ForeignObjectRel,
        nullable: bool,
        filtered_relation: Any = None,
    ) -> None:
        # The table being joined and the alias of the table it hangs off.
        self.table_name = table_name
        self.parent_alias = parent_alias
        # Note: the alias is not necessarily known at instantiation time.
        self.table_alias = table_alias
        # Either LOUTER or INNER.
        self.join_type = join_type
        # Sequence of 2-tuples; each pair becomes one equality condition in
        # the ON clause of the JOIN.
        self.join_cols = join_field.get_joining_columns()
        # The field the join travels along (a ForeignObjectRel when the join
        # goes in the reverse direction).
        self.join_field = join_field
        # Is this join nullable?
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Build the complete clause for this join, e.g.
            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol
        and return it together with its parameters as ``(sql, params)``.

        Raises ValueError when neither the joining columns nor the filtered
        relation contributed a condition, i.e. the ON clause would be empty.
        """
        quote_alias = compiler.quote_name_unless_alias
        quote_column = connection.ops.quote_name

        # One equality per pair of joining columns: the first column of each
        # pair is qualified with the parent alias, the second with this
        # join's own alias.
        on_conditions = [
            f"{quote_alias(self.parent_alias)}.{quote_column(parent_col)}"
            f" = {quote_alias(self.table_alias)}.{quote_column(joined_col)}"
            for parent_col, joined_col in self.join_cols
        ]
        on_params = []

        if self.filtered_relation:
            try:
                relation_sql, relation_params = compiler.compile(
                    self.filtered_relation
                )
            except FullResultSet:
                # The compiler signalled FullResultSet: no extra condition is
                # added for the filtered relation.
                pass
            else:
                on_conditions.append(f"({relation_sql})")
                on_params.extend(relation_params)

        if not on_conditions:
            # This might be a rel on the other end of an actual declared
            # field, in which case the declared field is the one to report.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
                "joining columns or extra restrictions."
            )

        # The alias is only spelled out when it differs from the table name.
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        on_clause = " AND ".join(on_conditions)
        return (
            f"{self.join_type} {quote_alias(self.table_name)}{alias_suffix} ON ({on_clause})",
            on_params,
        )

    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
        """
        Return a new join of the same class in which the parent alias, the
        table alias and every entry of the filtered relation's path (if any)
        have been translated through ``change_map``. Names missing from the
        map are kept unchanged; the filtered relation is cloned before its
        path is rewritten, so the original one is left untouched.
        """
        relabeled_parent = change_map.get(self.parent_alias, self.parent_alias)
        relabeled_alias = change_map.get(self.table_alias, self.table_alias)
        if self.filtered_relation is None:
            relabeled_relation = None
        else:
            relabeled_relation = self.filtered_relation.clone()
            relabeled_relation.path = [
                change_map.get(step, step) for step in self.filtered_relation.path
            ]
        return self.__class__(
            self.table_name,
            relabeled_parent,
            relabeled_alias,
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=relabeled_relation,
        )

    @property
    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
        """
        Tuple used for equality and hashing. The table alias, join type and
        nullability are not part of it.
        """
        return (
            self.__class__,
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: Join) -> bool:
        """Compare identities while ignoring filtered_relation."""
        # filtered_relation is the last item of the identity tuple.
        own_key = self.identity[:-1]
        other_key = other.identity[:-1]
        return own_key == other_key

    def _with_join_type(self, join_type: str) -> Join:
        # Copy this join (no aliases relabeled) and give the copy a new type.
        duplicate = self.relabeled_clone({})
        duplicate.join_type = join_type
        return duplicate

    def demote(self) -> Join:
        """Return a copy of this join whose join_type is INNER."""
        return self._with_join_type(INNER)

    def promote(self) -> Join:
        """Return a copy of this join whose join_type is LOUTER."""
        return self._with_join_type(LOUTER)
176
177
class BaseTable:
    """
    Reference to a base table in the FROM clause. For instance, this class
    could generate the "foo" part of
        SELECT * FROM "foo" WHERE somecond
    """

    # These join-related attributes are fixed to None at class level for
    # every base table.
    filtered_relation = None
    parent_alias = None
    join_type = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_alias = alias
        self.table_name = table_name

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, params)`` for this table reference: the table name
        quoted by the compiler, followed by the alias when the alias differs
        from the table name. The parameter list is always empty.
        """
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        quoted_table = compiler.quote_name_unless_alias(self.table_name)
        return quoted_table + alias_suffix, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """
        Return a new instance of the same class for the same table, with the
        alias translated through ``change_map`` (kept as-is when the alias is
        not a key of the map).
        """
        relabeled_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, relabeled_alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """Tuple of (class, table name, table alias) used for equality and hashing."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Return whether both objects have the same identity tuple."""
        own_key = self.identity
        other_key = other.identity
        return own_key == other_key