1"""
2Useful auxiliary data structures for query construction. Not useful outside
3the SQL domain.
4"""
5
6from __future__ import annotations
7
8from typing import TYPE_CHECKING, Any
9
10from plain.models.exceptions import FullResultSet
11from plain.models.postgres.sql import quote_name
12from plain.models.sql.constants import INNER, LOUTER
13
14if TYPE_CHECKING:
15 from plain.models.fields.related import ForeignKeyField
16 from plain.models.fields.reverse_related import ForeignObjectRel
17 from plain.models.postgres.connection import DatabaseConnection
18 from plain.models.sql.compiler import SQLCompiler
19
20
class MultiJoin(Exception):
    """
    Raised by the join-construction code to mark where a multi-valued join
    was attempted, so a caller that wants to handle that case specially can
    catch it.

    Attributes:
        level: the names_pos value given at construction.
        names_with_path: the path travelled so far, including the path to
            the multi-valued join itself.
    """

    level: int
    names_with_path: list[tuple[str, list[Any]]]

    def __init__(
        self, names_pos: int, path_with_names: list[tuple[str, list[Any]]]
    ) -> None:
        self.names_with_path = path_with_names
        self.level = names_pos
34
35
class Empty:
    """
    Bare placeholder class with no attributes or behavior of its own.

    NOTE(review): presumably instantiated elsewhere as a blank object whose
    class/attributes are filled in afterwards — confirm against callers.
    """
38
39
class Join:
    """
    One JOIN entry of a FROM clause, produced for sql.Query and rendered by
    sql.SQLCompiler. A rendered join looks like
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Joins are mostly stored in Query.alias_map. Every value in alias_map has
    to be Join-compatible, which means exposing:
    - table_name (string)
    - table_alias (possible alias for the table, can be None)
    - join_type (can be None for entries that aren't joined from anything)
    - parent_alias (the alias this join hangs off; can be None in the same
      cases as join_type)
    - as_sql()
    - relabeled_clone()
    """

    def __init__(
        self,
        table_name: str,
        parent_alias: str,
        table_alias: str,
        join_type: str,
        join_field: ForeignKeyField | ForeignObjectRel,
        nullable: bool,
        filtered_relation: Any = None,
    ) -> None:
        # The table being joined in, and the alias of the table it joins from.
        self.table_name = table_name
        self.parent_alias = parent_alias
        # The alias may not be known yet when the Join is created.
        self.table_alias = table_alias
        # Expected to be LOUTER or INNER.
        self.join_type = join_type
        # Column pairs (parent-side column, joined-side column). as_sql()
        # turns each pair into one equality condition inside ON (...).
        self.join_cols = join_field.get_joining_columns()
        # The field followed by this join (a ForeignObjectRel for reverse joins).
        self.join_field = join_field
        # Whether this join is nullable.
        self.nullable = nullable
        # Optional extra restriction; when truthy it is compiled into ON (...).
        self.filtered_relation = filtered_relation

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Render this join and return ``(sql, params)``, where sql looks like
            LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol

        Raises:
            ValueError: if neither join_cols nor the filtered relation
                contributes any condition to the ON clause.
        """
        quote_alias = compiler.quote_name_unless_alias

        # One "parent.col = joined.col" equality per joining column pair.
        conditions = [
            f"{quote_alias(self.parent_alias)}.{quote_name(lhs)} = "
            f"{quote_alias(self.table_alias)}.{quote_name(rhs)}"
            for lhs, rhs in self.join_cols
        ]
        params = []

        if self.filtered_relation:
            try:
                filter_sql, filter_params = compiler.compile(self.filtered_relation)
            except FullResultSet:
                # compile() raised FullResultSet: leave the restriction out of ON.
                pass
            else:
                conditions.append(f"({filter_sql})")
                params.extend(filter_params)

        if not conditions:
            # join_field may be a reverse relation; report the field it wraps.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                f"Join generated an empty ON clause. {declared_field.__class__} did not yield either "
                "joining columns or extra restrictions."
            )

        on_clause = " AND ".join(conditions)
        # Only emit the alias when it differs from the table name.
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        quoted_table = quote_alias(self.table_name)
        return f"{self.join_type} {quoted_table}{alias_suffix} ON ({on_clause})", params

    def relabeled_clone(self, change_map: dict[str, str]) -> Join:
        """
        Return a new Join whose parent_alias, table_alias and (if present) the
        filtered relation's path are remapped through change_map. Aliases
        missing from change_map are kept unchanged.
        """
        filtered_relation = None
        if self.filtered_relation is not None:
            filtered_relation = self.filtered_relation.clone()
            filtered_relation.path = [
                change_map.get(step, step) for step in self.filtered_relation.path
            ]
        return type(self)(
            self.table_name,
            change_map.get(self.parent_alias, self.parent_alias),
            change_map.get(self.table_alias, self.table_alias),
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=filtered_relation,
        )

    @property
    def identity(self) -> tuple[type[Join], str, str, Any, Any]:
        """
        Tuple used for equality and hashing. table_alias, join_type and
        nullable are not part of it; filtered_relation is the last item.
        """
        return (
            type(self),
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: Join) -> bool:
        """Like ==, but ignore filtered_relation (the last identity item)."""
        mine = self.identity[:-1]
        theirs = other.identity[:-1]
        return mine == theirs

    def demote(self) -> Join:
        """Return a copy of this join with join_type set to INNER."""
        inner = self.relabeled_clone({})
        inner.join_type = INNER
        return inner

    def promote(self) -> Join:
        """Return a copy of this join with join_type set to LOUTER."""
        outer = self.relabeled_clone({})
        outer.join_type = LOUTER
        return outer
177
178
class BaseTable:
    """
    A base-table reference in the FROM clause. For example, the "foo" in
        SELECT * FROM "foo" WHERE somecond
    could be rendered by this class.

    It exposes the same surface as Join (join_type, parent_alias,
    filtered_relation, as_sql(), relabeled_clone()), with the join-only
    attributes fixed to None at class level.
    """

    # Not joined from anything, so these stay None.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name: str, alias: str) -> None:
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]]:
        """
        Return the quoted table name, followed by the alias when it differs
        from the table name, together with an empty parameter list.
        """
        needs_alias = self.table_alias != self.table_name
        quoted_table = compiler.quote_name_unless_alias(self.table_name)
        suffix = f" {self.table_alias}" if needs_alias else ""
        return quoted_table + suffix, []

    def relabeled_clone(self, change_map: dict[str, str]) -> BaseTable:
        """Return a copy whose alias is remapped through change_map, if present."""
        new_alias = change_map.get(self.table_alias, self.table_alias)
        return type(self)(self.table_name, new_alias)

    @property
    def identity(self) -> tuple[type[BaseTable], str, str]:
        """Tuple of (class, table_name, table_alias) used for equality and hashing."""
        return (type(self), self.table_name, self.table_alias)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.identity)

    def equals(self, other: BaseTable) -> bool:
        """Compare full identities (there is no filtered relation to ignore)."""
        return self.identity == other.identity