1"""
2Managers for related objects.
3
4These managers provide the API for working with collections of related objects
5through foreign key and many-to-many relationships.
6"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
11
12if TYPE_CHECKING:
13 from collections.abc import Callable, Iterable
14
15 from plain.postgres.base import Model
16 from plain.postgres.fields.related import ForeignKeyField, ManyToManyField
17
18import builtins
19
20from plain.postgres import transaction
21from plain.postgres.db import get_connection
22from plain.postgres.dialect import quote_name
23from plain.postgres.expressions import Window
24from plain.postgres.functions import RowNumber
25from plain.postgres.lookups import GreaterThan, LessThanOrEqual
26from plain.postgres.query import QuerySet
27from plain.postgres.query_utils import Q
28from plain.postgres.utils import resolve_callables
29
30# TypeVar for generic manager support
31T = TypeVar("T", bound="Model")
32# TypeVar for custom QuerySet types (defaults to QuerySet[Any] when not specified)
33QS = TypeVar("QS", bound="QuerySet[Any]", default="QuerySet[Any]")
34
35
def _filter_prefetch_queryset(
    queryset: QuerySet, field_name: str, instances: Iterable[Model]
) -> QuerySet:
    """
    Restrict a prefetch queryset to the rows related to ``instances``.

    The queryset is filtered with ``<field_name>__in=instances``. When the
    queryset is sliced, the slice is re-expressed as a ``RowNumber()`` window
    partitioned by ``field_name`` (ordered by the queryset's own ordering),
    so the bounds are applied per partition rather than to the overall
    result. The limits are then cleared on the queryset's underlying query
    in place.
    """
    condition = Q(**{f"{field_name}__in": instances})
    sql_query = queryset.sql_query
    if not sql_query.is_sliced:
        return queryset.filter(condition)

    # Read the slice bounds and ordering before the limits are dropped below.
    start, stop = sql_query.low_mark, sql_query.high_mark
    ordering = [expr for expr, _ in sql_query.get_compiler().get_order_by()]
    row_number = Window(RowNumber(), partition_by=field_name, order_by=ordering)
    condition &= GreaterThan(row_number, start)
    if stop is not None:
        condition &= LessThanOrEqual(row_number, stop)
    sql_query.clear_limits()
    return queryset.filter(condition)
53
54
class BaseRelatedManager(Generic[T, QS]):
    """
    Common base for the related-object managers defined in this module.

    Subclasses implement ``get_queryset()``; the ``query`` property exposes
    whatever that method returns.
    """

    def get_queryset(self) -> QS:
        """Build the QuerySet for this relationship (must be overridden)."""
        raise NotImplementedError("Subclasses must implement get_queryset()")

    @property
    def query(self) -> QS:
        """The QuerySet for this relationship, as built by ``get_queryset()``."""
        return self.get_queryset()
70
71
class ReverseForeignKeyManager(BaseRelatedManager[T, QS]):
    """
    Manager for the reverse side of a foreign key relation.

    A manager is bound to one ``instance`` and works with the ``model`` rows
    whose foreign key ``field`` points at that instance: querying them,
    attaching or creating new ones and, when the foreign key allows null,
    detaching them again.
    """

    # Type hints for attributes
    model: type[T]  # model whose rows this manager returns
    instance: Model  # object the manager is bound to
    field: ForeignKeyField  # foreign key on `model` that references `instance`
    core_filters: dict[str, Model]  # {field.name: instance}
    allow_null: bool  # copy of field.allow_null; gates remove()/clear()

    def __init__(
        self, instance: Model, field: ForeignKeyField, related_model: type[Model]
    ):
        """Bind the manager to ``instance`` for the foreign key ``field``."""
        assert field.name is not None, "Field must have a name"
        self.model = cast(type[T], related_model)
        self.instance = instance
        self.field = field
        self.core_filters = {self.field.name: instance}
        self.allow_null = self.field.allow_null

    def _check_fk_val(self) -> None:
        """
        Ensure the instance has a value for every field the relation targets.

        Raises:
            ValueError: if any of ``field.foreign_related_fields`` is None on
                the bound instance.
        """
        for field in self.field.foreign_related_fields:
            if getattr(self.instance, field.attname) is None:
                raise ValueError(
                    f'"{self.instance!r}" needs to have a value for field '
                    f'"{field.attname}" before this relationship can be used.'
                )

    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
        """
        Filter the queryset for the instance this manager is bound to.

        Returns an empty queryset when the instance has no value for one of
        the related fields. Otherwise the bound instance is registered as a
        known related object on the returned queryset.
        """
        from plain.postgres.exceptions import FieldError

        queryset._defer_next_filter = True
        queryset = queryset.filter(**self.core_filters)
        for field in self.field.foreign_related_fields:
            if getattr(self.instance, field.attname) is None:
                return queryset.none()

        try:
            target_field = self.field.target_field
        except FieldError:
            # The relationship has multiple target fields. Use a tuple
            # for related object id.
            rel_obj_id = tuple(
                getattr(self.instance, f.attname)
                for f in self.field.path_infos[-1].target_fields
            )
        else:
            rel_obj_id = getattr(self.instance, target_field.attname)
        queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
        return queryset

    def _remove_prefetched_objects(self) -> None:
        """Drop this relation's entry from the instance's prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(
                self.field.remote_field.get_cache_name()
            )
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QS:
        """
        Return the related objects for the bound instance.

        A prefetched result stored on the instance is returned as-is;
        otherwise ``model.query`` is filtered down to this relation.

        Raises:
            ValueError: if the instance has no primary key value.
        """
        # Even if this relation is not to primary key, we require still primary key value.
        # The wish is that the instance has been already saved to DB,
        # although having a primary key value isn't a guarantee of that.
        if self.instance.id is None:
            raise ValueError(
                f"{self.instance.__class__.__name__!r} instance needs to have a "
                f"primary key value before this relationship can be used."
            )
        try:
            return self.instance._prefetched_objects_cache[
                self.field.remote_field.get_cache_name()
            ]
        except (AttributeError, KeyError):
            queryset = self.model.query
            return cast(QS, self._apply_rel_filters(queryset))

    def get_prefetch_queryset(
        self, instances: Iterable[Model], queryset: QuerySet | None = None
    ) -> tuple[
        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
    ]:
        """
        Build the queryset used to prefetch this relation for ``instances``.

        Returns a 6-tuple ``(queryset, rel_obj_attr, instance_attr, False,
        cache_name, False)``. The two boolean slots are always False here;
        their meaning is defined by the code that consumes this tuple, which
        is not part of this module.
        """
        if queryset is None:
            queryset = self.model.query

        # `instances` is walked twice below (lookup table and `__in` filter),
        # so materialize it in case the caller passed a one-shot iterable.
        instances = list(instances)

        rel_obj_attr = self.field.get_local_related_value
        instance_attr = self.field.get_foreign_related_value
        instances_dict = {instance_attr(inst): inst for inst in instances}
        queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)

        # Since we just bypassed this class' get_queryset(), we must manage
        # the reverse relation manually.
        for rel_obj in queryset:
            if not self.field.is_cached(rel_obj):
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, self.field.name, instance)
        cache_name = self.field.remote_field.get_cache_name()
        return queryset, rel_obj_attr, instance_attr, False, cache_name, False

    def add(self, *objs: T, bulk: bool = True) -> None:
        """
        Point the foreign key of each of ``objs`` at the bound instance.

        With ``bulk=True`` the objects must already be saved and are updated
        with a single ``update()`` on the model's base queryset. With
        ``bulk=False`` each object is saved individually inside a transaction.

        Raises:
            TypeError: if an object is not an instance of ``self.model``.
            ValueError: if the bound instance lacks a related value, or an
                unsaved object is passed with ``bulk=True``.
        """
        self._check_fk_val()
        self._remove_prefetched_objects()

        def check_and_update_obj(obj: Any) -> None:
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            setattr(obj, self.field.name, self.instance)

        if bulk:
            ids = []
            for obj in objs:
                check_and_update_obj(obj)
                if obj._state.adding:
                    raise ValueError(
                        f"{obj!r} instance isn't saved. Use bulk=False or save "
                        "the object first."
                    )
                ids.append(obj.id)
            self.model._model_meta.base_queryset.filter(id__in=ids).update(
                **{
                    self.field.name: self.instance,
                }
            )
        else:
            with transaction.atomic(savepoint=False):
                for obj in objs:
                    check_and_update_obj(obj)
                    obj.save()

    def create(self, **kwargs: Any) -> T:
        """Create a ``model`` row whose foreign key is the bound instance."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(T, self.model.query.create(**kwargs))

    def get_or_create(self, **kwargs: Any) -> tuple[T, bool]:
        """``get_or_create`` on ``model`` with the foreign key forced to the bound instance."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(tuple[T, bool], self.model.query.get_or_create(**kwargs))

    def update_or_create(self, **kwargs: Any) -> tuple[T, bool]:
        """``update_or_create`` on ``model`` with the foreign key forced to the bound instance."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return cast(tuple[T, bool], self.model.query.update_or_create(**kwargs))

    def remove(self, *objs: T, bulk: bool = True) -> None:
        """
        Detach ``objs`` from the bound instance by nulling their foreign key.

        Raises:
            AttributeError: if the foreign key does not allow null.
            TypeError: if an object is not an instance of ``self.model``.
            DoesNotExist: if an object is not currently related to the instance.
        """
        # remove() is only provided if the ForeignKeyField can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call remove() on a related manager for field "
                f"{self.field.name} where allow_null=False."
            )
        if not objs:
            return
        self._check_fk_val()
        val = self.field.get_foreign_related_value(self.instance)
        old_ids = set()
        for obj in objs:
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            # Is obj actually part of this descriptor set?
            if self.field.get_local_related_value(obj) == val:
                old_ids.add(obj.id)
            else:
                raise self.field.remote_field.model.DoesNotExist(
                    f"{obj!r} is not related to {self.instance!r}."
                )
        self._clear(self.query.filter(id__in=old_ids), bulk)

    def clear(self, *, bulk: bool = True) -> None:
        """
        Detach every related object by nulling its foreign key.

        Raises:
            AttributeError: if the foreign key does not allow null.
        """
        # clear() is only provided if the ForeignKeyField can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call clear() on a related manager for field "
                f"{self.field.name} where allow_null=False."
            )
        self._check_fk_val()
        self._clear(self.query, bulk)

    def _clear(self, queryset: QuerySet, bulk: bool) -> None:
        """Null the foreign key on every row of ``queryset`` (bulk or one by one)."""
        self._remove_prefetched_objects()
        if bulk:
            # `QuerySet.update()` is intrinsically atomic.
            queryset.update(**{self.field.name: None})
        else:
            with transaction.atomic(savepoint=False):
                for obj in queryset:
                    setattr(obj, self.field.name, None)
                    obj.save(update_fields=[self.field.name])

    def set(self, objs: Any, *, bulk: bool = True, clear: bool = False) -> None:
        """
        Make ``objs`` the related set for the bound instance.

        When the foreign key allows null, objects no longer wanted are
        detached: either everything first (``clear=True``) or only the
        difference. When it does not allow null, nothing can be detached, so
        ``objs`` are simply added.
        """
        self._check_fk_val()
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)

        if not self.allow_null:
            self.add(*objs, bulk=bulk)
            return

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear(bulk=bulk)
                self.add(*objs, bulk=bulk)
            else:
                old_objs = set(self.query.all())
                new_objs = []
                for obj in objs:
                    if obj in old_objs:
                        # Already related: keep it out of the removal set.
                        old_objs.remove(obj)
                    else:
                        new_objs.append(obj)

                self.remove(*old_objs, bulk=bulk)
                self.add(*new_objs, bulk=bulk)
298
299
class ManyToManyManager(BaseRelatedManager[T, QS]):
    """
    Manager for both forward and reverse sides of a many-to-many relation.

    This manager handles both directions of many-to-many relations with
    conditional logic for symmetrical relationships (which only apply to
    forward relations). Rows of the ``through`` model record the links; the
    "source" field of that model points at the bound instance and the
    "target" field at the related ``model``.
    """

    # Type hints for attributes
    model: type[T]
    instance: Model
    field: ManyToManyField
    through: type[Model]
    query_field_name: str
    prefetch_cache_name: str
    source_field_name: str
    target_field_name: str
    source_field: ForeignKeyField
    target_field: ForeignKeyField
    symmetrical: bool
    core_filters: dict[str, Any]
    id_field_names: dict[str, str]
    related_val: tuple[Any, ...]

    def __init__(
        self,
        instance: Model,
        field: ManyToManyField,
        through: type[Model],
        related_model: type[Model],
        is_reverse: bool,
        symmetrical: bool = False,
    ):
        """
        Bind the manager to ``instance``.

        Args:
            instance: object whose related set is managed.
            field: the many-to-many field describing the relation.
            through: model holding the link rows.
            related_model: model of the objects in the related set.
            is_reverse: True when accessed from the target model back to the
                source; swaps the source/target through-field names.
            symmetrical: honoured for forward relations only.

        Raises:
            ValueError: if the instance lacks a value for the related
                field(s) or has no primary key.
        """
        assert field.name is not None, "Field must have a name"
        # The managed model is the same for both directions; only the names
        # used to traverse the through model differ.
        self.model = cast(type[T], related_model)
        if is_reverse:
            # Reverse: accessing from the target model back to the source
            self.query_field_name = field.name
            self.prefetch_cache_name = field.related_query_name()
            self.source_field_name = field.m2m_reverse_field_name()
            self.target_field_name = field.m2m_field_name()
            self.symmetrical = False  # Reverse relations are never symmetrical
        else:
            # Forward: accessing from the source model to the target
            self.query_field_name = field.related_query_name()
            self.prefetch_cache_name = field.name
            self.source_field_name = field.m2m_field_name()
            self.target_field_name = field.m2m_reverse_field_name()
            self.symmetrical = symmetrical

        # Initialize common M2M attributes
        self.instance = instance
        self.through = through

        # M2M through model fields are always ForeignKey
        self.source_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.source_field_name),
        )
        self.target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.target_field_name),
        )

        self.core_filters = {}
        self.id_field_names = {}
        for lh_field, rh_field in self.source_field.related_fields:
            core_filter_key = f"{self.query_field_name}__{rh_field.name}"
            self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
            self.id_field_names[lh_field.name] = rh_field.name  # type: ignore[assignment]

        self.related_val = self.source_field.get_foreign_related_value(instance)
        if None in self.related_val:
            raise ValueError(
                f'"{instance!r}" needs to have a value for field "{self.id_field_names[self.source_field_name]}" before '
                "this many-to-many relationship can be used."
            )
        # Even if this relation is not to primary key, we require still primary key value.
        if instance.id is None:
            raise ValueError(
                f"{instance.__class__.__name__!r} instance needs to have a primary key value before "
                "a many-to-many relationship can be used."
            )

    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
        """Filter the queryset for the instance this manager is bound to."""
        queryset._defer_next_filter = True
        return queryset._next_is_sticky().filter(**self.core_filters)

    def _remove_prefetched_objects(self) -> None:
        """Drop this relation's entry from the instance's prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QS:
        """
        Return the related objects for the bound instance.

        A prefetched result stored on the instance is returned as-is;
        otherwise ``model.query`` is filtered down to this relation.
        """
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            queryset = self.model.query
            return cast(QS, self._apply_rel_filters(queryset))

    def get_prefetch_queryset(
        self, instances: Iterable[Model], queryset: QuerySet | None = None
    ) -> tuple[
        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
    ]:
        """
        Build the queryset used to prefetch this relation for ``instances``.

        Each result row is annotated with ``_prefetch_related_val_<attname>``
        columns read from the through table. Returns a 6-tuple
        ``(queryset, result_key, instance_key, False, prefetch_cache_name,
        False)``, where the two callables produce comparable tuples for a
        fetched row and for an instance. The meaning of the two boolean
        slots is defined by the code that consumes this tuple, which is not
        part of this module.
        """
        if queryset is None:
            queryset = self.model.query

        queryset = _filter_prefetch_queryset(
            queryset._next_is_sticky(), self.query_field_name, instances
        )

        # M2M: need to annotate the query in order to get the primary model
        # that the secondary model was actually related to.
        # The through-model source field was already resolved in __init__.
        fk = self.source_field
        join_table = fk.model.model_options.db_table
        qn = quote_name
        queryset = queryset.extra(
            select={
                f"_prefetch_related_val_{f.attname}": f"{qn(join_table)}.{qn(f.column)}"
                for f in fk.local_related_fields
            }
        )
        conn = get_connection()
        return (
            queryset,
            lambda result: tuple(
                getattr(result, f"_prefetch_related_val_{f.attname}")
                for f in fk.local_related_fields
            ),
            lambda inst: tuple(
                f.get_db_prep_value(getattr(inst, f.attname), conn)
                for f in fk.foreign_related_fields
            ),
            False,
            self.prefetch_cache_name,
            False,
        )

    def clear(self) -> None:
        """Delete every through row linking the bound instance."""
        with transaction.atomic(savepoint=False):
            self._remove_prefetched_objects()
            filters = self._build_remove_filters(self.model.query)
            self.through.query.filter(filters).delete()

    def set(
        self,
        objs: Any,
        *,
        clear: bool = False,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """
        Make ``objs`` (model instances or raw ids) the related set.

        With ``clear=True`` all links are deleted and ``objs`` re-added;
        otherwise only the difference against the current set is removed and
        added.
        """
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear()
                self.add(*objs, through_defaults=through_defaults)
            else:
                old_ids = set(
                    self.query.values_list(
                        self.target_field.target_field.attname, flat=True
                    )
                )

                new_objs = []
                for obj in objs:
                    fk_val = (
                        self.target_field.get_foreign_related_value(obj)[0]
                        if isinstance(obj, self.model)
                        else self.target_field.get_prep_value(obj)
                    )
                    if fk_val in old_ids:
                        # Already linked: keep it out of the removal set.
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)

                self.remove(*old_ids)
                self.add(*new_objs, through_defaults=through_defaults)

    def create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> T:
        """Create a ``model`` row from ``kwargs`` and link it to the bound instance."""
        new_obj = self.model.query.create(**kwargs)
        self.add(new_obj, through_defaults=through_defaults)
        return cast(T, new_obj)

    def get_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Get or create a ``model`` row and make sure it is linked.

        Returns the ``(obj, created)`` pair from ``model.query.get_or_create``.
        """
        obj, created = self.model.query.get_or_create(**kwargs)
        # The lookup runs against `self.model.query`, which is not restricted
        # to this relation, so an object that already existed is not
        # necessarily linked to the bound instance yet. add() only inserts
        # through rows that are missing, so calling it unconditionally is safe.
        self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def update_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Update or create a ``model`` row and make sure it is linked.

        Returns the ``(obj, created)`` pair from
        ``model.query.update_or_create``.
        """
        obj, created = self.model.query.update_or_create(**kwargs)
        # As in get_or_create(): the lookup is not restricted to this
        # relation, and add() is a no-op for links that already exist.
        self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def _get_target_ids(self, target_field_name: str, objs: Any) -> builtins.set[Any]:
        """
        Return the set of ids of `objs` that the target field references.

        Raises:
            ValueError: if a model instance has no value for the referenced field.
            TypeError: if a model instance of the wrong class is passed.
        """
        # Needed at runtime for isinstance(); the module-level import is
        # only present under TYPE_CHECKING.
        from plain.postgres import Model

        target_ids: set[Any] = set()
        target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(target_field_name),
        )  # M2M through model fields are always ForeignKey
        for obj in objs:
            if isinstance(obj, self.model):
                target_id = target_field.get_foreign_related_value(obj)[0]
                if target_id is None:
                    raise ValueError(
                        f'Cannot add "{obj!r}": the value for field "{target_field_name}" is None'
                    )
                target_ids.add(target_id)
            elif isinstance(obj, Model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            else:
                # Anything else is treated as a raw id value.
                target_ids.add(target_field.get_prep_value(obj))
        return target_ids

    def _get_missing_target_ids(
        self,
        source_field_name: str,
        target_field_name: str,
        target_ids: builtins.set[Any],
    ) -> builtins.set[Any]:
        """Return the subset of ids of `objs` that aren't already assigned to this relationship."""
        vals = self.through.query.values_list(target_field_name, flat=True).filter(
            **{
                source_field_name: self.related_val[0],
                f"{target_field_name}__in": target_ids,
            }
        )
        return target_ids.difference(vals)

    def _add_items(
        self,
        source_field_name: str,
        target_field_name: str,
        *objs: Any,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """
        Insert through rows linking the bound instance to ``objs``.

        Only links that do not exist yet are inserted. Callable values in
        ``through_defaults`` are resolved once and shared by all new rows.
        """
        if not objs:
            return

        through_defaults = dict(resolve_callables(through_defaults or {}))
        target_ids = self._get_target_ids(target_field_name, objs)

        missing_target_ids = self._get_missing_target_ids(
            source_field_name, target_field_name, target_ids
        )
        if not missing_target_ids:
            # Every requested link already exists; nothing to insert.
            return

        with transaction.atomic(savepoint=False):
            # Add the ones that aren't there already.
            self.through.query.bulk_create(
                [
                    self.through(
                        **through_defaults,
                        **{
                            f"{source_field_name}_id": self.related_val[0],
                            f"{target_field_name}_id": target_id,
                        },
                    )
                    for target_id in missing_target_ids
                ],
            )

    def _remove_items(
        self, source_field_name: str, target_field_name: str, *objs: Any
    ) -> None:
        """
        Delete the through rows linking the bound instance to ``objs``.

        ``objs`` may mix model instances and raw ids. The two field-name
        parameters are not read in the body: the filters are built from the
        manager's own source/target fields.
        """
        if not objs:
            return

        # Check that all the objects are of the right type
        old_ids = set()
        for obj in objs:
            if isinstance(obj, self.model):
                fk_val = self.target_field.get_foreign_related_value(obj)[0]
                old_ids.add(fk_val)
            else:
                old_ids.add(obj)

        with transaction.atomic(savepoint=False):
            target_model_qs = self.model.query
            if target_model_qs._has_filters():
                old_vals = target_model_qs.filter(
                    **{f"{self.target_field.target_field.attname}__in": old_ids}
                )
            else:
                old_vals = old_ids
            filters = self._build_remove_filters(old_vals)
            self.through.query.filter(filters).delete()

    def _build_remove_filters(self, removed_vals: Any) -> Any:
        """
        Build the Q filter selecting through rows to delete.

        ``removed_vals`` is either a collection of target ids or a QuerySet.
        For symmetrical relations the mirrored rows (source and target
        swapped) are matched as well.
        """
        filters = Q.create([(self.source_field_name, self.related_val)])
        # No need to add a subquery condition if removed_vals is a QuerySet without
        # filters.
        removed_vals_filters = (
            not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
        )
        if removed_vals_filters:
            filters = filters & Q.create(
                [(f"{self.target_field_name}__in", removed_vals)]
            )
        # Add symmetrical filters for forward symmetrical relations
        if self.symmetrical:
            symmetrical_filters = Q.create([(self.target_field_name, self.related_val)])
            if removed_vals_filters:
                symmetrical_filters = symmetrical_filters & Q.create(
                    [(f"{self.source_field_name}__in", removed_vals)]
                )
            filters = filters | symmetrical_filters
        return filters

    def add(self, *objs: T, through_defaults: dict[str, Any] | None = None) -> None:
        """
        Link ``objs`` (model instances or raw ids) to the bound instance.

        Links that already exist are left untouched.
        """
        self._remove_prefetched_objects()
        with transaction.atomic(savepoint=False):
            self._add_items(
                self.source_field_name,
                self.target_field_name,
                *objs,
                through_defaults=through_defaults,
            )
            # If this is a symmetrical m2m relation to self, add the mirror
            # entry in the m2m table.
            if self.symmetrical:
                self._add_items(
                    self.target_field_name,
                    self.source_field_name,
                    *objs,
                    through_defaults=through_defaults,
                )

    def remove(self, *objs: T) -> None:
        """Unlink ``objs`` (model instances or raw ids) from the bound instance."""
        self._remove_prefetched_objects()
        self._remove_items(self.source_field_name, self.target_field_name, *objs)