1"""
2Various data structures used in query construction.
3
4Factored out from plain.postgres.query to avoid making the main module very
5large and/or so that they can be used by other modules without getting into
6circular import difficulties.
7"""
8
9from __future__ import annotations
10
11import functools
12import inspect
13import logging
14from collections.abc import Callable, Generator
15from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Self
16
17from plain.postgres.constants import LOOKUP_SEP
18from plain.postgres.db import DatabaseError
19from plain.postgres.exceptions import FieldError
20from plain.utils import tree
21
22if TYPE_CHECKING:
23 from plain.postgres.base import Model
24 from plain.postgres.connection import DatabaseConnection
25 from plain.postgres.fields import Field
26 from plain.postgres.fields.related import ForeignKeyField
27 from plain.postgres.fields.reverse_related import ForeignObjectRel
28 from plain.postgres.lookups import Lookup, Transform
29 from plain.postgres.meta import Meta
30 from plain.postgres.sql.compiler import SQLCompiler
31 from plain.postgres.sql.where import WhereNode
32
33logger = logging.getLogger("plain.postgres")
34
35
class PathInfo(NamedTuple):
    """Information about one step of a relation path, used when converting
    lookups that traverse relations (e.g. ``fk__somecol``).

    Describes the relation in model terms (Meta and Fields for both sides of
    the relation). ``join_field`` is the field backing the relation.
    """

    # Meta of the model this step starts from.
    from_meta: Meta
    # Meta of the model this step leads to.
    to_meta: Meta
    # Fields the join targets. NOTE(review): presumably on the to_meta side — confirm.
    target_fields: tuple[Field, ...]
    # Relation object backing the join: a forward ForeignKeyField or a reverse
    # ForeignObjectRel.
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True when this step can yield multiple rows
    # (many-valued relation) — confirm against callers.
    m2m: bool
    # NOTE(review): presumably True when traversing the relation in its
    # declared (forward) direction — confirm against callers.
    direct: bool
    # FilteredRelation applied to this step, or None when unfiltered.
    filtered_relation: FilteredRelation | None
50
51
def subclasses(cls: type) -> Generator[type]:
    """
    Yield ``cls`` and then every transitive subclass, depth-first in pre-order.

    Siblings are visited in the order returned by ``__subclasses__()``. A class
    reachable through several parents (diamond inheritance) is yielded once per
    path.
    """
    # Explicit stack instead of recursion; children are pushed reversed so the
    # first child is popped (and therefore yielded) first.
    stack = [cls]
    while stack:
        node = stack.pop()
        yield node
        stack.extend(reversed(node.__subclasses__()))
56
57
class Q(tree.Node):
    """
    A composable filter condition.

    Positional arguments are stored as children unchanged; keyword arguments
    become ``(lookup, value)`` children sorted by lookup name. Instances can be
    combined with ``&``, ``|`` and ``^`` and negated with ``~``.
    """

    # Connection types
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    conditional = True

    def __init__(
        self,
        *args: Any,
        _connector: str | None = None,
        _negated: bool = False,
        **kwargs: Any,
    ) -> None:
        keyword_children = sorted(kwargs.items())
        super().__init__(
            children=list(args) + keyword_children,
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other: Any, conn: str) -> Q:
        """
        Join ``self`` and ``other`` under connector ``conn``.

        Raises TypeError when ``other`` is not flagged as conditional. An
        empty side is dropped and a copy of the other side is returned.
        """
        # Only an explicit ``conditional = False`` (or a missing attribute)
        # is rejected; this is an identity check, not a truthiness check.
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()

        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other: Any) -> Q:
        return self._combine(other, self.OR)

    def __and__(self, other: Any) -> Q:
        return self._combine(other, self.AND)

    def __xor__(self, other: Any) -> Q:
        return self._combine(other, self.XOR)

    def __invert__(self) -> Q:
        inverted = self.copy()
        inverted.negate()
        return inverted

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> WhereNode:
        """Compile this Q into a WhereNode on ``query`` for use as an expression."""
        where_clause, new_joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        # Used as an expression, the Q must not drop rows through INNER joins,
        # so every join it introduced is promoted to a LEFT OUTER join.
        query.promote_joins(new_joins)
        return where_clause

    def flatten(self) -> Generator[Any]:
        """
        Yield this Q followed by every nested subexpression, depth-first.

        For ``(lookup, value)`` children the value is what gets visited.
        """
        yield self
        for child in self.children:
            node = child[1] if isinstance(child, tuple) else child
            if hasattr(node, "flatten"):
                yield from node.flatten()
            else:
                yield node

    def check(self, against: dict[str, Any]) -> bool:
        """
        Evaluate this Q in the database against the values in ``against``.

        Each entry of ``against`` is exposed as an annotation (plain values are
        wrapped in Value). A NULL result of the condition counts as a match.
        Returns True when the query yields a row, and also when the database
        raises DatabaseError (logged as a warning).
        """
        # Imported here to avoid circular imports.
        from plain.postgres.expressions import ResolvableExpression, Value
        from plain.postgres.fields import BooleanField
        from plain.postgres.functions import Coalesce
        from plain.postgres.sql import SINGLE, Query

        query = Query(None)
        for alias, raw in against.items():
            expression = raw if isinstance(raw, ResolvableExpression) else Value(raw)
            query.add_annotation(expression, alias, select=False)
        query.add_annotation(Value(1), "_check")
        # A field missing from ``against`` surfaces here as a FieldError.
        condition = Coalesce(self, True, output_field=BooleanField())
        query.add_q(Q(condition))
        compiler = query.get_compiler()
        try:
            row = compiler.execute_sql(SINGLE)
        except DatabaseError as e:
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True
        return row is not None

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import_path, args, kwargs)`` that rebuild this Q."""
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}"
        # Report the public import location rather than the defining module.
        internal_prefix = "plain.postgres.query_utils"
        if path.startswith(internal_prefix):
            path = path.replace(internal_prefix, "plain.postgres")
        positional = tuple(self.children)
        options: dict[str, Any] = {}
        if self.connector != self.default:
            options["_connector"] = self.connector
        if self.negated:
            options["_negated"] = True
        return path, positional, options
184
185
class class_or_instance_method:
    """
    Descriptor that binds to a different function depending on access.

    Accessed on a class, it returns ``class_method`` partially applied to the
    class; accessed on an instance, it returns ``instance_method`` partially
    applied to that instance. RegisterLookupMixin uses it to expose the same
    name at both levels.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        accessed_on_class = instance is None
        bound_to, func = (
            (owner, self.class_method)
            if accessed_on_class
            else (instance, self.instance_method)
        )
        return functools.partial(func, bound_to)
200
201
class RegisterLookupMixin:
    """
    Mixin that gives a class, and its instances, a registry of named lookups
    and transforms.

    Lookups registered on a class live in that class's own ``class_lookups``
    dict and are inherited along the MRO. Lookups registered on an instance
    live in ``instance_lookups`` on that instance. ``get_lookups``,
    ``register_lookup`` and ``_unregister_lookup`` dispatch to the class-level
    or instance-level implementation depending on whether they are accessed on
    the class or on an instance.
    """

    # Lookups registered directly on this class, keyed by lookup name. A class
    # gets its own dict the first time something is registered on it.
    class_lookups: ClassVar[dict[str, type[Lookup | Transform]]]

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        """Return the lookup/transform registered as ``lookup_name``, or None."""
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """
        Return the merged ``class_lookups`` of ``cls`` and all of its MRO
        ancestors.

        Entries from classes earlier in the MRO (more derived) win on name
        clashes. The result is memoized by ``functools.cache`` and invalidated
        through ``_clear_cached_class_lookups()``. The returned dict is the
        cached object itself, so callers should not mutate it.
        """
        # Read each class's own __dict__ so inherited dicts are not picked up
        # repeatedly. Classes without registrations contribute an empty dict.
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """
        Return the class-level lookups overlaid with this instance's
        ``instance_lookups``; instance entries win on name clashes.
        """
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            # Build a new dict so the cached class-level dict stays untouched.
            return {**class_lookups, **instance_lookups}
        # No instance registrations: return the cached, shared class dict.
        return class_lookups

    # Statement order matters here. ``get_lookups`` is built from the *plain*
    # functions defined above. Accessed on a class, it calls
    # get_class_lookups(cls); accessed on an instance, it calls
    # get_instance_lookups(self). Only afterwards is ``get_class_lookups``
    # rebound as a classmethod wrapping the same cached function.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        get_class_lookups
    )

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup class registered as ``lookup_name``.

        If nothing is registered under that name and this object has a truthy
        ``output_field``, the search is delegated to that field. Returns None
        when the name is registered to a class that is not a Lookup subclass.
        """
        from plain.postgres.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform class registered as ``lookup_name``.

        If nothing is registered under that name and this object has a truthy
        ``output_field``, the search is delegated to that field. Returns None
        when the name is registered to a class that is not a Transform
        subclass.
        """
        from plain.postgres.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge ``dicts`` into a new dict, giving precedence to earlier entries.
        For example, merge_dicts([a, b]) prefers the keys in 'a' over those
        in 'b'.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        # Apply later dicts first so that earlier ones overwrite them.
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """
        Invalidate the memoized get_class_lookups() results for ``cls`` and
        every subclass. Subclasses are included because their merged lookups
        include entries inherited from ``cls``.
        """
        for subclass in subclasses(cls):
            # The bound classmethod forwards attribute access to the cached
            # function, which exposes ``cache_clear``.
            # NOTE(review): unless a subclass overrides get_class_lookups, all
            # classes share one functools.cache, so the first call likely
            # clears every entry already.
            if cached := getattr(subclass, "get_class_lookups", None):
                cached.cache_clear()

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on ``cls`` under ``lookup_name``. When the name is
        omitted, ``lookup.lookup_name`` is used.

        Returns ``lookup`` unchanged. NOTE(review): this presumably lets
        register_lookup be used as a decorator — confirm.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        # Give this class its own dict instead of mutating one inherited from
        # a parent class.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        # Merged lookups of cls and its subclasses are now stale.
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on this instance only, under ``lookup_name``. When
        the name is omitted, ``lookup.lookup_name`` is used. Returns
        ``lookup``.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        # NOTE(review): unlike register_class_lookup, no check that
        # lookup_name ends up non-None.
        # Give this instance its own dict rather than sharing one.
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Same order-dependent pattern as get_lookups: build the dispatcher from
    # the plain functions, then rebind the class-level one as a classmethod.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        register_class_lookup
    )

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove the given lookup from cls's own lookups. For use in tests only,
        as it is not thread-safe.

        Raises KeyError if the lookup is not registered in ``cls``'s own
        ``class_lookups``.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        assert lookup_name is not None, "lookup_name must be set on the lookup class"
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove the given lookup from this instance's lookups. For use in tests
        only, as it is not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del self.instance_lookups[lookup_name]

    # Same order-dependent dispatcher/classmethod rebinding as above.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup: ClassVar[classmethod[Any, ..., Any]] = classmethod(
        _unregister_class_lookup
    )
334
335
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should follow ``field`` one level deeper.

    Shared by query construction (compiler.get_related_selections()) and model
    instance creation (compiler.klass_info).

    Arguments:
     * field - the field under consideration
     * restricted - truthy when select_related() was given an explicit list
       of relations to follow
     * requested - the select_related() dictionary
     * select_mask - the dictionary of selected fields
     * reverse - True when checking a reverse select_related relation

    Raises FieldError when a requested relation is also deferred by
    ``select_mask``.
    """
    from plain.postgres.fields.related import RelatedField

    # Only relation fields can be traversed at all.
    if not isinstance(field, RelatedField):
        return False

    # Without an explicit list, only non-nullable relations are followed.
    if not restricted:
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    # Reverse relations are requested by their related query name.
    requested_key = field.related_query_name() if reverse else field.name
    if requested_key not in requested:
        return False

    deferred_but_traversed = (
        select_mask and field.name in requested and field not in select_mask
    )
    if deferred_but_traversed:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
380
381
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest prefix of ``lookup_parts`` that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so every prefix is
    re-joined and tested in order of increasing length. Returns
    ``(annotation_name, remaining_parts)`` for the first prefix whose
    annotation is truthy, otherwise ``(None, ())``.
    """
    candidates = (
        (cut, LOOKUP_SEP.join(lookup_parts[:cut]))
        for cut in range(1, len(lookup_parts) + 1)
    )
    for cut, candidate in candidates:
        if annotations.get(candidate):
            return candidate, tuple(lookup_parts[cut:])
    return None, ()
395
396
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Return True if ``model`` may be used in a relational lookup whose target
    is ``target_meta``.

    Compatibility holds when either:
      1) ``model`` equals ``target_meta.model``; or
      2) ``field`` is a primary key and ``model`` equals the model that owns
         ``field`` (``field.model._model_meta.model``).

    No inheritance-based (parent/child) matching is performed. The result is
    always a real ``bool``.
    """

    def matches(meta: Meta) -> bool:
        return model == meta.model

    if matches(target_meta):
        return True

    # If the field is a primary key, then doing a query against the field's
    # model is ok, too. Consider the case:
    #     class Restaurant(models.Model):
    #         place = OneToOneField(Place, primary_key=True)
    #     Restaurant.query.filter(id__in=Restaurant.query.all())
    # Without the primary key check, id__in (== place__in) would give Place's
    # meta as the target meta, and Restaurant isn't compatible with that. This
    # applies only to primary keys, because __in=qs is later turned into
    # __in=qs.values('id').
    # bool() keeps a falsy non-bool primary_key (e.g. None) from leaking out
    # as the return value.
    is_primary_key = bool(getattr(field, "primary_key", False))
    return is_primary_key and bool(matches(field.model._model_meta))
422
423
class FilteredRelation:
    """
    A relation join that carries an extra ``condition`` in its SQL ON clause.
    """

    def __init__(self, relation_name: str, *, condition: Q = Q()) -> None:
        # Both arguments are validated before any state is stored.
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.relation_name = relation_name
        # Starts unset.
        self.alias: str | None = None
        self.condition = condition
        # Used as the join-reuse set when the condition is compiled (as_sql).
        self.path: list[str] = []

    def __eq__(self, other: object) -> bool:
        """Equal when relation name, alias and condition all match."""
        if isinstance(other, self.__class__):
            return (
                self.relation_name == other.relation_name
                and self.alias == other.alias
                and self.condition == other.condition
            )
        return NotImplemented

    def clone(self) -> FilteredRelation:
        """Return a copy with the same condition object and an independent path list."""
        duplicate = FilteredRelation(self.relation_name, condition=self.condition)
        duplicate.alias, duplicate.path = self.alias, self.path[:]
        return duplicate

    def as_sql(self, compiler: SQLCompiler, connection: DatabaseConnection) -> Any:
        """Compile the ON-clause condition using the compiler's query."""
        owning_query = compiler.query
        reusable = set(self.path)
        where_node = owning_query.build_filtered_relation_q(
            self.condition, reuse=reusable
        )
        return compiler.compile(where_node)