1"""
2Various data structures used in query construction.
3
4Factored out from plain.models.query to avoid making the main module very
5large and/or so that they can be used by other modules without getting into
6circular import difficulties.
7"""
8
9from __future__ import annotations
10
11import functools
12import inspect
13import logging
14from collections.abc import Callable, Generator
15from typing import TYPE_CHECKING, Any, NamedTuple, Self
16
17from plain.models.constants import LOOKUP_SEP
18from plain.models.db import DatabaseError, db_connection
19from plain.models.exceptions import FieldError
20from plain.utils import tree
21
22if TYPE_CHECKING:
23 from plain.models.backends.base.base import BaseDatabaseWrapper
24 from plain.models.base import Model
25 from plain.models.fields import Field
26 from plain.models.fields.related import ForeignKeyField
27 from plain.models.fields.reverse_related import ForeignObjectRel
28 from plain.models.lookups import Lookup, Transform
29 from plain.models.meta import Meta
30 from plain.models.sql.compiler import SQLCompiler
31 from plain.models.sql.where import WhereNode
32
# Module logger; Q.check() reports database errors through it.
logger = logging.getLogger("plain.models")
34
35
class PathInfo(NamedTuple):
    """Information about a relation path when converting lookups (fk__somecol).

    Describes one step of the relation in Model terms (Meta and Fields for
    both sides of the relation). The join_field is the field backing the
    relation.
    """

    # Meta of the model this path step starts from.
    from_meta: Meta
    # Meta of the model this path step leads to.
    to_meta: Meta
    # Fields on the target side of the relation (presumably the columns the
    # join resolves to — confirm against PathInfo producers).
    target_fields: tuple[Field, ...]
    # Field backing the relation: a ForeignKeyField, or a ForeignObjectRel
    # (from plain.models.fields.reverse_related).
    join_field: ForeignKeyField | ForeignObjectRel
    # NOTE(review): presumably True for many-to-many steps — confirm.
    m2m: bool
    # NOTE(review): presumably True when the relation is traversed forward
    # rather than in reverse — confirm against producers.
    direct: bool
    # FilteredRelation attached to this step, if any.
    filtered_relation: FilteredRelation | None
50
51
def subclasses(cls: type) -> Generator[type, None, None]:
    """
    Yield ``cls`` followed by every class that inherits from it.

    Classes come out in depth-first pre-order. A class reachable through
    several parents (e.g. a diamond) is yielded once per path.
    """
    pending = [cls]
    while pending:
        current = pending.pop()
        yield current
        # Push direct subclasses reversed so the first one is visited first.
        pending.extend(reversed(current.__subclasses__()))
56
57
class Q(tree.Node):
    """
    Filter conditions packaged as a composable object.

    Q objects combine with ``&`` (AND), ``|`` (OR) and ``^`` (XOR), and are
    negated with ``~``. Keyword arguments become ``(lookup, value)`` children;
    positional arguments are added as children unchanged.
    """

    # Connector names.
    AND = "AND"
    OR = "OR"
    XOR = "XOR"
    default = AND
    # Lets Q be combined with anything else marked ``conditional``.
    conditional = True

    def __init__(
        self,
        *args: Any,
        _connector: str | None = None,
        _negated: bool = False,
        **kwargs: Any,
    ) -> None:
        # Keyword lookups are sorted so the children order does not depend on
        # the order in which the caller spelled the keyword arguments.
        keyword_children = sorted(kwargs.items())
        super().__init__(
            children=list(args) + keyword_children,
            connector=_connector,
            negated=_negated,
        )

    def _combine(self, other: Any, conn: str) -> Q:
        """
        Join ``self`` and ``other`` under the connector ``conn``.

        Raise TypeError unless ``other`` has a ``conditional`` attribute that
        is not ``False``. An empty side is dropped. An empty ``self`` yields a
        copy of ``other``, and an empty Q ``other`` yields a copy of ``self``.
        """
        if getattr(other, "conditional", False) is False:
            raise TypeError(other)
        if not self:
            return other.copy()
        if not other and isinstance(other, Q):
            return self.copy()

        combined = self.create(connector=conn)
        for operand in (self, other):
            combined.add(operand, conn)
        return combined

    def __or__(self, other: Any) -> Q:
        return self._combine(other, self.OR)

    def __and__(self, other: Any) -> Q:
        return self._combine(other, self.AND)

    def __xor__(self, other: Any) -> Q:
        return self._combine(other, self.XOR)

    def __invert__(self) -> Q:
        inverted = self.copy()
        inverted.negate()
        return inverted

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> WhereNode:
        """Resolve this Q against ``query`` and return the resulting WhereNode."""
        # Any joins created here are promoted to LEFT OUTER joins so that a Q
        # used as an expression never filters rows out through its joins.
        where_clause, new_joins = query._add_q(
            self,
            reuse,
            allow_joins=allow_joins,
            split_subq=False,
            check_filterable=False,
            summarize=summarize,
        )
        query.promote_joins(new_joins)
        return where_clause

    def flatten(self) -> Generator[Any, None, None]:
        """
        Yield this Q object and, depth-first, everything nested inside it.

        For ``(lookup, value)`` children only the value is considered.
        Objects with a ``flatten`` method are expanded recursively.
        """
        yield self
        for node in self.children:
            candidate = node[1] if isinstance(node, tuple) else node
            if hasattr(candidate, "flatten"):
                yield from candidate.flatten()
            else:
                yield candidate

    def check(self, against: dict[str, Any]) -> bool:
        """
        Evaluate this Q in the database against the values in ``against``.

        Each entry of ``against`` is added as an annotation (plain values are
        wrapped in Value()). Return whether the query yields a row. On a
        DatabaseError, log a warning and return True.
        """
        # Imported here to avoid circular imports.
        from plain.models.expressions import ResolvableExpression, Value
        from plain.models.fields import BooleanField
        from plain.models.functions import Coalesce
        from plain.models.sql import Query
        from plain.models.sql.constants import SINGLE

        query = Query(None)
        for name, value in against.items():
            annotation = (
                value if isinstance(value, ResolvableExpression) else Value(value)
            )
            query.add_annotation(annotation, name, select=False)
        query.add_annotation(Value(1), "_check")

        # Adding the filter raises FieldError if "against" lacks a field.
        if db_connection.features.supports_comparing_boolean_expr:
            condition = Q(Coalesce(self, True, output_field=BooleanField()))
        else:
            condition = self
        query.add_q(condition)

        compiler = query.get_compiler()
        try:
            return compiler.execute_sql(SINGLE) is not None
        except DatabaseError as e:
            logger.warning("Got a database error calling check() on %r: %s", self, e)
            return True

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import path, args, kwargs)`` that rebuild this Q."""
        klass = self.__class__
        path = f"{klass.__module__}.{klass.__name__}"
        # Report the public location rather than the defining module.
        if path.startswith("plain.models.query_utils"):
            path = path.replace("plain.models.query_utils", "plain.models")

        kwargs: dict[str, Any] = {}
        if self.connector != self.default:
            kwargs["_connector"] = self.connector
        if self.negated:
            kwargs["_negated"] = True
        return path, tuple(self.children), kwargs
188
189
class class_or_instance_method:
    """
    Descriptor that binds to one of two callables depending on access.

    Accessed on the class, it returns ``class_method`` partially applied to
    the owner class. Accessed on an instance, it returns ``instance_method``
    partially applied to that instance. RegisterLookupMixin uses it so the
    same name works on both a Field class and a Field instance.
    """

    def __init__(self, class_method: Any, instance_method: Any) -> None:
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is not None:
            return functools.partial(self.instance_method, instance)
        return functools.partial(self.class_method, owner)
204
205
class RegisterLookupMixin:
    """
    Mixin for registering and resolving Lookup/Transform classes by name.

    Lookups can be registered on a class, where they are stored in that
    class's own ``class_lookups`` dict and inherited along the MRO. They can
    also be registered on a single instance, where they are stored in
    ``instance_lookups``. ``get_lookups``, ``register_lookup`` and
    ``_unregister_lookup`` dispatch to the class or instance variant depending
    on whether they are accessed on the class or on an instance.
    """

    def _get_lookup(self, lookup_name: str) -> type[Lookup | Transform] | None:
        # All lookups visible from here: class-level, plus instance-level when
        # called on an instance. None if lookup_name is not registered.
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls: type[Self]) -> dict[str, type[Lookup | Transform]]:
        """Return the merged class-level lookups for ``cls`` (memoized)."""
        # Read each class's *own* class_lookups through __dict__ (not getattr),
        # in MRO order. merge_dicts gives earlier dicts precedence, so a
        # subclass's registrations override its parents'.
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self) -> dict[str, type[Lookup | Transform]]:
        """Return class lookups overlaid with this instance's own lookups."""
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            # Instance registrations win over class registrations.
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # Order matters: get_lookups must capture the raw (cached) function above.
    # Only afterwards is get_class_lookups rebound as a classmethod.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)  # type: ignore[assignment]

    def get_lookup(self, lookup_name: str) -> type[Lookup] | None:
        """
        Return the Lookup class registered as ``lookup_name``, or None.

        If nothing is registered under that name and ``self`` has a truthy
        ``output_field`` attribute, the search is delegated to
        ``output_field.get_lookup()``. An entry that is not a Lookup subclass
        (e.g. a Transform) yields None.
        """
        from plain.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """
        Return the Transform class registered as ``lookup_name``, or None.

        This mirrors get_lookup(). It delegates to
        ``output_field.get_transform()`` when nothing is registered here. An
        entry that is not a Transform subclass yields None.
        """
        from plain.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None:
            # output_field is a Field which inherits from RegisterLookupMixin
            if output_field := getattr(self, "output_field", None):
                return output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(
        dicts: list[dict[str, type[Lookup | Transform]]],
    ) -> dict[str, type[Lookup | Transform]]:
        """
        Merge dicts in reverse so that earlier dicts in the list take
        precedence. For example, merge_dicts([a, b]) prefers the keys in 'a'
        over those in 'b'.
        """
        merged: dict[str, type[Lookup | Transform]] = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls: type[Self]) -> None:
        """Invalidate memoized get_class_lookups() results after a change."""
        # NOTE: get_class_lookups is a single functools.cache-wrapped function,
        # so each cache_clear() empties the cache for every class. Calling it
        # once per subclass is redundant but harmless.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()  # type: ignore[attr-defined]

    def register_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on ``cls`` under ``lookup_name``.

        ``lookup_name`` defaults to ``lookup.lookup_name``. Return ``lookup``
        unchanged, which lets this be used as a decorator.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        # Give cls its own dict so a parent's class_lookups isn't mutated.
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}  # type: ignore[misc]
        cls.class_lookups[lookup_name] = lookup  # type: ignore[attr-defined]
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> type[Lookup | Transform]:
        """
        Register ``lookup`` on this instance only, under ``lookup_name``.

        ``lookup_name`` defaults to ``lookup.lookup_name``. Return ``lookup``.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # Same ordering constraint as get_lookups: capture the raw function first.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)  # type: ignore[assignment]

    def _unregister_class_lookup(
        cls: type[Self],
        lookup: type[Lookup | Transform],
        lookup_name: str | None = None,
    ) -> None:
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.
        """
        # NOTE(review): if cls has no class_lookups of its own, attribute
        # lookup falls through to a parent's dict. The entry is then removed
        # from the parent, or KeyError/AttributeError is raised.
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        del cls.class_lookups[lookup_name]  # type: ignore[attr-defined]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(
        self, lookup: type[Lookup | Transform], lookup_name: str | None = None
    ) -> None:
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name  # type: ignore[attr-defined]
        del self.instance_lookups[lookup_name]

    # Same ordering constraint as get_lookups: capture the raw function first.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)  # type: ignore[assignment]
327
328
def select_related_descend(
    field: Any,
    restricted: bool | None,
    requested: dict[str, Any] | None,
    select_mask: Any,
    reverse: bool = False,
) -> bool:
    """
    Decide whether select_related() should follow ``field`` one level deeper.

    Shared by the query construction code (compiler.get_related_selections())
    and the model instance creation code (compiler.klass_info).

    Arguments:
    * field - the field being considered.
    * restricted - truthy when the caller named the relations to follow
      explicitly (via ``requested``).
    * requested - the select_related() dictionary; required when restricted.
    * select_mask - the dictionary of selected fields.
    * reverse - True when checking a reverse relation.

    Non-relation fields are never followed. Without an explicit list, nullable
    relations are skipped. Raises FieldError when a requested relation is
    missing from a non-empty ``select_mask``, i.e. it is deferred.
    """
    from plain.models.fields.related import RelatedField

    if not isinstance(field, RelatedField):
        return False

    if not restricted:
        # Unrestricted traversal only follows non-nullable relations.
        return not field.allow_null

    assert requested is not None, "requested must be provided when restricted=True"
    relation_key = field.related_query_name() if reverse else field.name
    if relation_key not in requested:
        return False

    deferred = select_mask and field.name in requested and field not in select_mask
    if deferred:
        raise FieldError(
            f"Field {field.model.model_options.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
373
374
def refs_expression(
    lookup_parts: list[str], annotations: dict[str, Any]
) -> tuple[str | None, tuple[str, ...]]:
    """
    Find the shortest prefix of ``lookup_parts`` that names an annotation.

    Annotation names may themselves contain LOOKUP_SEP, so every prefix is
    tried, shortest first. Return ``(annotation_name, remaining_parts)`` for
    the first match, or ``(None, ())`` if there is none. An annotation whose
    value is falsy is treated as absent.
    """
    for end in range(1, len(lookup_parts) + 1):
        candidate = LOOKUP_SEP.join(lookup_parts[:end])
        if annotations.get(candidate):
            return candidate, tuple(lookup_parts[end:])
    return None, ()
388
389
def check_rel_lookup_compatibility(
    model: type[Model], target_meta: Meta, field: Field | ForeignObjectRel
) -> bool:
    """
    Check that ``model`` is compatible with ``target_meta``.

    Compatibility is OK if either:
    1) ``model == target_meta.model``, or
    2) ``field`` has a truthy ``primary_key`` attribute and
       ``model == field.model._model_meta.model``.

    Only model equality is compared; no parent/child relationship between
    models is considered. Always returns a bool.
    """

    def check(meta: Meta) -> bool:
        return model == meta.model

    # If the field is a primary key, querying against the field's own model
    # is OK too. Example: a relation field that is also its model's primary
    # key makes ``id__in=<queryset of that same model>`` resolve to the
    # *related* model's meta, which would not match ``model``. This applies
    # only to primary keys, because ``__in=qs`` is later rewritten as
    # ``__in=qs.values('id')``.
    return check(target_meta) or (
        bool(getattr(field, "primary_key", False))
        and check(field.model._model_meta)
    )
415
416
class FilteredRelation:
    """Specify custom filtering in the ON clause of SQL joins."""

    def __init__(self, relation_name: str, *, condition: Q | None = None) -> None:
        """
        Args:
            relation_name: Name of the relation to join; must be non-empty.
            condition: Q object used as extra join condition. When omitted, a
                fresh empty Q() is created for this instance.

        Raises:
            ValueError: If relation_name is empty or condition is not a Q.
        """
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        if condition is None:
            # Build a new Q() per instance instead of a shared default object.
            # Default values are evaluated once, so in-place changes to one
            # relation's condition would leak into every other relation.
            condition = Q()
        self.relation_name = relation_name
        self.alias: str | None = None
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.condition = condition
        # Aliases that build_filtered_relation_q() may reuse (see as_sql()).
        self.path: list[str] = []

    # NOTE: defining __eq__ without __hash__ makes instances unhashable
    # (Python sets __hash__ to None). Equality also depends on the mutable
    # ``alias`` attribute.
    def __eq__(self, other: object) -> bool:
        """Compare relation_name, alias and condition; ``path`` is ignored."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.relation_name == other.relation_name
            and self.alias == other.alias
            and self.condition == other.condition
        )

    def clone(self) -> FilteredRelation:
        """Return a copy that shares ``condition`` but copies ``path``."""
        clone = FilteredRelation(self.relation_name, condition=self.condition)
        clone.alias = self.alias
        clone.path = self.path[:]
        return clone

    def resolve_expression(self, *args: Any, **kwargs: Any) -> Any:
        """
        QuerySet.annotate() only accepts expression-like arguments
        (with a resolve_expression() method).
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any:
        """
        Compile ``condition`` for the join's ON clause.

        The condition is built through ``query.build_filtered_relation_q``,
        which may reuse the aliases in ``self.path``. The result is whatever
        ``compiler.compile`` returns (presumably an ``(sql, params)`` pair).
        """
        # Resolve the condition in Join.filtered_relation.
        query = compiler.query
        where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
        return compiler.compile(where)