1"""
2Various data structures used in query construction.
3
4Factored out from plain.postgres.query to avoid making the main module very
5large and/or so that they can be used by other modules without getting into
6circular import difficulties.
7"""
8
9from __future__ import annotations
10
11import functools
12import inspect
13from collections.abc import Callable, Generator
14from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Self
15
16import psycopg
17
18from plain.logs import get_framework_logger
19from plain.postgres.constants import LOOKUP_SEP
20from plain.postgres.exceptions import FieldError
21from plain.utils import tree
22
23if TYPE_CHECKING:
24 from plain.postgres.base import Model
25 from plain.postgres.connection import DatabaseConnection
26 from plain.postgres.fields import Field
27 from plain.postgres.fields.related import ForeignKeyField
28 from plain.postgres.fields.reverse_related import ForeignObjectRel
29 from plain.postgres.lookups import Lookup, Transform
30 from plain.postgres.meta import Meta
31 from plain.postgres.sql.compiler import SQLCompiler
32 from plain.postgres.sql.where import WhereNode
33
# Framework logger for this module; Q.check() uses it to report database
# errors raised while evaluating a condition.
logger = get_framework_logger()
35
36
class PathInfo(NamedTuple):
    """Information about a relation path when converting lookups (fk__somecol).

    Describes the relation in Model terms (Meta and Fields for both
    sides of the relation). The join_field is the field backing the relation.
    """

    # Meta of the model on the starting side of this path step.
    from_meta: Meta
    # Meta of the model on the far side of this path step.
    to_meta: Meta
    # Field(s) the step resolves to; presumably on the to_meta side — confirm
    # against the code that builds PathInfo instances.
    target_fields: tuple[Field, ...]
    # Field backing the relation: a forward ForeignKeyField or a reverse
    # ForeignObjectRel.
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True when this step crosses a many-to-many
    # relation — confirm against callers.
    m2m: bool
    # NOTE(review): presumably True for a forward relation and False when the
    # step traverses a reverse relation — confirm against callers.
    direct: bool
    # FilteredRelation applied to this step, or None when unfiltered.
    filtered_relation: FilteredRelation | None
51
52
def subclasses(cls: type) -> Generator[type]:
    """
    Yield ``cls`` and then every class inheriting from it, directly or
    indirectly, in depth-first pre-order.

    A class reachable through several parents is yielded once per path.
    """
    pending = [cls]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so they pop in __subclasses__() order,
        # matching a recursive depth-first walk.
        pending.extend(reversed(current.__subclasses__()))
57
58
class Q(tree.Node):
    """
    Wrap lookup filters in an object so they can be combined with the logical
    operators `&` (AND), `|` (OR) and `^` (XOR), and negated with `~`.
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    conditional = True

    def __init__(
        self,
        *args: Any,
        _connector: str | None = None,
        _negated: bool = False,
        **kwargs: Any,
    ) -> None:
        # Positional children come first. Keyword lookups follow as
        # (name, value) pairs sorted by name, so their order does not depend on
        # the order in which the keywords were passed.
        keyword_lookups = sorted(kwargs.items())
        super().__init__(
            children=list(args) + keyword_lookups,
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other: Any, conn: str) -> Q:
        """
        Join ``self`` and ``other`` under the connector ``conn``.

        Raises TypeError when ``other`` is not flagged as conditional. If
        ``self`` is empty a copy of ``other`` is returned; if ``other`` is an
        empty Q a copy of ``self`` is returned.
        """
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()

        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other: Any) -> Q:
        return self._combine(other, self.OR)

    def __and__(self, other: Any) -> Q:
        return self._combine(other, self.AND)

    def __xor__(self, other: Any) -> Q:
        return self._combine(other, self.XOR)

    def __invert__(self) -> Q:
        inverted = self.copy()
        inverted.negate()
        return inverted

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> WhereNode:
        """
        Build the WhereNode for this Q against ``query``.

        ``for_save`` is accepted for interface compatibility and not used.
        """
        # We must promote any new joins to left outer joins so that when Q is
        # used as an expression, rows aren't filtered due to joins.
        where_clause, added_joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(added_joins)
        return where_clause

    def flatten(self) -> Generator[Any]:
        """
        Recursively yield this Q object and all subexpressions, in depth-first
        order.
        """
        yield self
        for child in self.children:
            # Tuple children are (lookup, value) pairs; descend into the value.
            node = child[1] if isinstance(child, tuple) else child
            if hasattr(node, "flatten"):
                yield from node.flatten()
            else:
                yield node

    def check(self, against: dict[str, Any]) -> bool:
        """
        Run a database query to evaluate this Q against ``against``, a mapping
        of names to values or expressions, and return whether it matches.

        Non-expression values are wrapped in Value(). A NULL condition result
        counts as a match (via Coalesce(..., True)). A database error is logged
        and also treated as a match.
        """
        # Avoid circular imports.
        from plain.postgres.expressions import ResolvableExpression, Value
        from plain.postgres.fields import BooleanField
        from plain.postgres.functions import Coalesce
        from plain.postgres.sql import SINGLE, Query

        query = Query(None)
        for name, value in against.items():
            annotation = (
                value if isinstance(value, ResolvableExpression) else Value(value)
            )
            query.add_annotation(annotation, name, select=False)
        query.add_annotation(Value(1), "_check")
        # This will raise a FieldError if a field is missing in "against".
        query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
        compiler = query.get_compiler()
        try:
            row = compiler.execute_sql(SINGLE)
        except psycopg.DatabaseError as e:
            logger.warning(
                "Got a database error calling check()",
                extra={"expression": repr(self), "error": str(e)},
            )
            return True
        return row is not None

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return ``(path, args, kwargs)`` describing how to rebuild this Q.

        ``args`` are the children. ``kwargs`` only carries a non-default
        connector and a True negation flag.
        """
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}"
        # Report classes defined in query_utils under plain.postgres instead.
        if path.startswith("plain.postgres.query_utils"):
            path = path.replace("plain.postgres.query_utils", "plain.postgres")
        kwargs: dict[str, Any] = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, tuple(self.children), kwargs
188
189
class class_or_instance_method:
    """
    Descriptor that dispatches on how it is accessed.

    Accessed on the class, it returns ``class_method`` partially applied to the
    owner class. Accessed on an instance, it returns ``instance_method``
    partially applied to that instance. Used by RegisterLookupMixin.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
204
205
class RegisterLookupMixin:
    """
    Mixin that lets lookups and transforms be registered either on a class
    (visible to that class and its subclasses through the MRO) or on a single
    instance.

    ``get_lookups``, ``register_lookup`` and ``_unregister_lookup`` are
    ``class_or_instance_method`` descriptors. Called on the class, they use the
    class-level variant. Called on an instance, they use the instance-level
    variant.
    """

    # Lookups registered on a class. register_class_lookup() gives each class
    # its own dict in its __dict__ instead of mutating an inherited one.
    class_lookups: ClassVar[dict[str, type[Lookup | Transform]]]

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        """Return whatever is registered under lookup_name, or None."""
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """
        Return the lookups registered on cls and its ancestors. Classes earlier
        in the MRO win on name clashes.

        The result is cached per class by functools.cache. Class-level
        registration and unregistration clear the cache.
        """
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class lookups merged with lookups registered on this
        instance. Instance lookups win on name clashes.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        # NOTE(review): this is the cached dict from get_class_lookups()
        # itself, so callers must not mutate it.
        return class_lookups

    # Built from the plain cached function (class_or_instance_method partially
    # applies it to the owner class). The next statement then rebinds the
    # name get_class_lookups to a classmethod, so this ordering matters.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        get_class_lookups
    )

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup subclass registered as lookup_name, or None.

        If nothing is registered here, defer to self.output_field when it is
        set and truthy. A registered class that is not a Lookup subclass
        yields None.
        """
        from plain.postgres.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform subclass registered as lookup_name, or None.

        If nothing is registered here, defer to self.output_field when it is
        set and truthy. A registered class that is not a Transform subclass
        yields None.
        """
        from plain.postgres.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge dicts so that earlier dicts in the list take precedence. For
        example, merge_dicts([a, b]) prefers the keys in 'a' over those in 'b'.
        They are applied in reverse so earlier entries overwrite later ones.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """Clear the get_class_lookups cache as reached from cls and each subclass."""
        for subclass in subclasses(cls):
            if cached := getattr(subclass, "get_class_lookups", None):
                cached.cache_clear()

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register lookup on cls under lookup_name, which defaults to
        lookup.lookup_name. Return lookup.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        # Give cls its own dict rather than mutating one inherited from a parent.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register lookup on this instance only, under lookup_name (defaults to
        lookup.lookup_name). Return lookup.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        # NOTE(review): unlike register_class_lookup there is no assert here, so
        # a lookup without lookup_name would be stored under None.
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Built from the plain functions before register_class_lookup is rebound
    # to a classmethod below.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        register_class_lookup
    )

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.

        Raises KeyError if the name is not registered.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        # NOTE(review): if cls has no class_lookups of its own, attribute lookup
        # finds an ancestor's dict and the entry is removed from that ancestor.
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del self.instance_lookups[lookup_name]

    # Built from the plain functions before _unregister_class_lookup is
    # rebound to a classmethod below.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        _unregister_class_lookup
    )
338
339
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Return True if select_related() should follow ``field`` to the related
    model. Used both by query construction (compiler.get_related_selections())
    and by model instance creation (compiler.klass_info).

    Arguments:
     * field - the field to be checked
     * restricted - whether the field list has been manually restricted by an
       explicit select_related() request
     * requested - the select_related() dictionary
     * select_mask - the dictionary of selected fields
     * reverse - True when checking a reverse select_related relation

    Raises FieldError when a requested relation is also excluded by a
    non-empty select mask (deferred and traversed at the same time).
    """
    from plain.postgres.fields.related import RelatedField

    # Only relation fields can be descended into.
    if not isinstance(field, RelatedField):
        return False

    if not restricted:
        # Without an explicit request, only non-nullable relations are followed.
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    requested_name = field.related_query_name() if reverse else field.name
    if requested_name not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
384
385
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest leading run of ``lookup_parts`` that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so each prefix of
    ``lookup_parts`` is joined with LOOKUP_SEP and tried, shortest first.
    Return the matching annotation name and the remaining parts as a tuple, or
    ``(None, ())`` if no prefix maps to a truthy entry in ``annotations``.
    """
    for split_at in range(1, len(lookup_parts) + 1):
        candidate = LOOKUP_SEP.join(lookup_parts[:split_at])
        if not annotations.get(candidate):
            continue
        remainder = tuple(lookup_parts[split_at:])
        return candidate, remainder
    return None, ()
399
400
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Return whether ``model`` may be used with ``target_meta``.

    ``model`` is compatible when it is ``target_meta.model``, or when ``field``
    is a primary key and ``model`` is the model that owns the field.
    """
    direct_match = model == target_meta.model
    if direct_match:
        return direct_match

    # Primary-key fields get a second chance: a queryset like
    # `Model.query.filter(id__in=Model.query.all())` resolves `id__in` through
    # the PK field, whose target meta is the remote model. Allow the match
    # against the field's own model so the subquery (later reduced to
    # `.values("id")`) is accepted.
    is_primary_key = getattr(field, "primary_key", False)
    if not is_primary_key:
        return is_primary_key
    return model == field.model._model_meta.model
420
421
class FilteredRelation:
    """
    Specify custom filtering in the ON clause of SQL joins.

    Attributes:
     * relation_name - name of the relation to join (must be non-empty)
     * condition - Q object compiled into the join condition
     * alias - alias assigned later; None until set
     * path - names passed as ``reuse`` when the condition is compiled
    """

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.relation_name = relation_name
        self.condition = condition
        self.alias: str | None = None
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        # path is intentionally not part of equality; only the name, alias and
        # condition are compared.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.relation_name == other.relation_name
            and self.alias == other.alias
            and self.condition == other.condition
        )

    def clone(self) -> FilteredRelation:
        """Return a copy sharing the condition, with its own copy of path."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias = self.alias
        duplicate.path = list(self.path)
        return duplicate

    def as_sql(self, compiler: SQLCompiler, connection: DatabaseConnection) -> Any:
        """Compile the condition for use in Join.filtered_relation."""
        # Resolve the condition in Join.filtered_relation.
        where = compiler.query.build_filtered_relation_q(
            self.condition, reuse=set(self.path)
        )
        return compiler.compile(where)