1"""
  2Managers for related objects.
  3
  4These managers provide the API for working with collections of related objects
  5through foreign key and many-to-many relationships.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
 11
 12if TYPE_CHECKING:
 13    from collections.abc import Callable, Iterable
 14
 15    from plain.postgres.base import Model
 16    from plain.postgres.fields.related import ForeignKeyField, ManyToManyField
 17
 18import builtins
 19
 20from plain.postgres import transaction
 21from plain.postgres.db import get_connection
 22from plain.postgres.dialect import quote_name
 23from plain.postgres.expressions import Window
 24from plain.postgres.functions import RowNumber
 25from plain.postgres.lookups import GreaterThan, LessThanOrEqual
 26from plain.postgres.query import QuerySet
 27from plain.postgres.query_utils import Q
 28from plain.postgres.utils import resolve_callables
 29
# TypeVar for generic manager support: the model class whose instances a
# related manager returns (e.g. from create()).
T = TypeVar("T", bound="Model")
# TypeVar for custom QuerySet types (defaults to QuerySet[Any] when not specified).
# NOTE(review): the ``default=`` argument of ``typing.TypeVar`` exists only on
# Python 3.13+ (PEP 696) — confirm this matches the project's minimum version.
QS = TypeVar("QS", bound="QuerySet[Any]", default="QuerySet[Any]")
 34
 35
 36def _filter_prefetch_queryset(
 37    queryset: QuerySet, field_name: str, instances: Iterable[Model]
 38) -> QuerySet:
 39    filter_kwargs: dict[str, Any] = {f"{field_name}__in": instances}
 40    predicate = Q(**filter_kwargs)
 41    if queryset.sql_query.is_sliced:
 42        # Use window functions for limited queryset prefetching
 43        low_mark, high_mark = queryset.sql_query.low_mark, queryset.sql_query.high_mark
 44        order_by = [
 45            expr for expr, _ in queryset.sql_query.get_compiler().get_order_by()
 46        ]
 47        window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
 48        predicate &= GreaterThan(window, low_mark)
 49        if high_mark is not None:
 50            predicate &= LessThanOrEqual(window, high_mark)
 51        queryset.sql_query.clear_limits()
 52    return queryset.filter(predicate)
 53
 54
 55class BaseRelatedManager(Generic[T, QS]):
 56    """
 57    Base class for all related object managers.
 58
 59    All related managers should have a 'query' property that returns a QuerySet.
 60    """
 61
 62    @property
 63    def query(self) -> QS:
 64        """Access the QuerySet for this relationship."""
 65        return self.get_queryset()
 66
 67    def get_queryset(self) -> QS:
 68        """Return the QuerySet for this relationship."""
 69        raise NotImplementedError("Subclasses must implement get_queryset()")
 70
 71
 72class ReverseForeignKeyManager(BaseRelatedManager[T, QS]):
 73    """
 74    Manager for the reverse side of a foreign key relation.
 75
 76    This manager adds behaviors specific to foreign key relations.
 77    """
 78
 79    # Type hints for attributes
 80    model: type[T]
 81    instance: Model
 82    field: ForeignKeyField
 83    core_filters: dict[str, Model]
 84    allow_null: bool
 85
 86    def __init__(
 87        self, instance: Model, field: ForeignKeyField, related_model: type[Model]
 88    ):
 89        assert field.name is not None, "Field must have a name"
 90        self.model = cast(type[T], related_model)
 91        self.instance = instance
 92        self.field = field
 93        self.core_filters = {self.field.name: instance}
 94        self.allow_null = self.field.allow_null
 95
 96    def _check_fk_val(self) -> None:
 97        for field in self.field.foreign_related_fields:
 98            if getattr(self.instance, field.attname) is None:
 99                raise ValueError(
100                    f'"{self.instance!r}" needs to have a value for field '
101                    f'"{field.attname}" before this relationship can be used.'
102                )
103
104    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
105        """
106        Filter the queryset for the instance this manager is bound to.
107        """
108        from plain.postgres.exceptions import FieldError
109
110        queryset._defer_next_filter = True
111        queryset = queryset.filter(**self.core_filters)
112        for field in self.field.foreign_related_fields:
113            val = getattr(self.instance, field.attname)
114            if val is None:
115                return queryset.none()
116
117        try:
118            target_field = self.field.target_field
119        except FieldError:
120            # The relationship has multiple target fields. Use a tuple
121            # for related object id.
122            rel_obj_id = tuple(
123                [
124                    getattr(self.instance, target_field.attname)
125                    for target_field in self.field.path_infos[-1].target_fields
126                ]
127            )
128        else:
129            rel_obj_id = getattr(self.instance, target_field.attname)
130        queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
131        return queryset
132
133    def _remove_prefetched_objects(self) -> None:
134        try:
135            self.instance._prefetched_objects_cache.pop(
136                self.field.remote_field.get_cache_name()
137            )
138        except (AttributeError, KeyError):
139            pass  # nothing to clear from cache
140
141    def get_queryset(self) -> QS:
142        # Even if this relation is not to primary key, we require still primary key value.
143        # The wish is that the instance has been already saved to DB,
144        # although having a primary key value isn't a guarantee of that.
145        if self.instance.id is None:
146            raise ValueError(
147                f"{self.instance.__class__.__name__!r} instance needs to have a "
148                f"primary key value before this relationship can be used."
149            )
150        try:
151            return self.instance._prefetched_objects_cache[
152                self.field.remote_field.get_cache_name()
153            ]
154        except (AttributeError, KeyError):
155            queryset = self.model.query
156            return cast(QS, self._apply_rel_filters(queryset))
157
158    def get_prefetch_queryset(
159        self, instances: Iterable[Model], queryset: QuerySet | None = None
160    ) -> tuple[
161        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
162    ]:
163        if queryset is None:
164            queryset = self.model.query
165
166        rel_obj_attr = self.field.get_local_related_value
167        instance_attr = self.field.get_foreign_related_value
168        instances_dict = {instance_attr(inst): inst for inst in instances}
169        queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)
170
171        # Since we just bypassed this class' get_queryset(), we must manage
172        # the reverse relation manually.
173        for rel_obj in queryset:
174            if not self.field.is_cached(rel_obj):
175                instance = instances_dict[rel_obj_attr(rel_obj)]
176                setattr(rel_obj, self.field.name, instance)
177        cache_name = self.field.remote_field.get_cache_name()
178        return queryset, rel_obj_attr, instance_attr, False, cache_name, False
179
180    def add(self, *objs: T, bulk: bool = True) -> None:
181        self._check_fk_val()
182        self._remove_prefetched_objects()
183
184        def check_and_update_obj(obj: Any) -> None:
185            if not isinstance(obj, self.model):
186                raise TypeError(
187                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
188                )
189            setattr(obj, self.field.name, self.instance)
190
191        if bulk:
192            ids = []
193            for obj in objs:
194                check_and_update_obj(obj)
195                if obj._state.adding:
196                    raise ValueError(
197                        f"{obj!r} instance isn't saved. Use bulk=False or save "
198                        "the object first."
199                    )
200                ids.append(obj.id)
201            self.model._model_meta.base_queryset.filter(id__in=ids).update(
202                **{
203                    self.field.name: self.instance,
204                }
205            )
206        else:
207            with transaction.atomic(savepoint=False):
208                for obj in objs:
209                    check_and_update_obj(obj)
210                    obj.save()
211
212    def create(self, **kwargs: Any) -> T:
213        self._check_fk_val()
214        kwargs[self.field.name] = self.instance
215        return cast(T, self.model.query.create(**kwargs))
216
217    def get_or_create(self, **kwargs: Any) -> tuple[T, bool]:
218        self._check_fk_val()
219        kwargs[self.field.name] = self.instance
220        return cast(tuple[T, bool], self.model.query.get_or_create(**kwargs))
221
222    def update_or_create(self, **kwargs: Any) -> tuple[T, bool]:
223        self._check_fk_val()
224        kwargs[self.field.name] = self.instance
225        return cast(tuple[T, bool], self.model.query.update_or_create(**kwargs))
226
227    def remove(self, *objs: T, bulk: bool = True) -> None:
228        # remove() is only provided if the ForeignKeyField can have a value of null
229        if not self.allow_null:
230            raise AttributeError(
231                f"Cannot call remove() on a related manager for field "
232                f"{self.field.name} where null=False."
233            )
234        if not objs:
235            return
236        self._check_fk_val()
237        val = self.field.get_foreign_related_value(self.instance)
238        old_ids = set()
239        for obj in objs:
240            if not isinstance(obj, self.model):
241                raise TypeError(
242                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
243                )
244            # Is obj actually part of this descriptor set?
245            if self.field.get_local_related_value(obj) == val:
246                old_ids.add(obj.id)
247            else:
248                raise self.field.remote_field.model.DoesNotExist(
249                    f"{obj!r} is not related to {self.instance!r}."
250                )
251        self._clear(self.query.filter(id__in=old_ids), bulk)
252
253    def clear(self, *, bulk: bool = True) -> None:
254        # clear() is only provided if the ForeignKeyField can have a value of null
255        if not self.allow_null:
256            raise AttributeError(
257                f"Cannot call clear() on a related manager for field "
258                f"{self.field.name} where null=False."
259            )
260        self._check_fk_val()
261        self._clear(self.query, bulk)
262
263    def _clear(self, queryset: QuerySet, bulk: bool) -> None:
264        self._remove_prefetched_objects()
265        if bulk:
266            # `QuerySet.update()` is intrinsically atomic.
267            queryset.update(**{self.field.name: None})
268        else:
269            with transaction.atomic(savepoint=False):
270                for obj in queryset:
271                    setattr(obj, self.field.name, None)
272                    obj.save(update_fields=[self.field.name])
273
274    def set(self, objs: Any, *, bulk: bool = True, clear: bool = False) -> None:
275        self._check_fk_val()
276        # Force evaluation of `objs` in case it's a queryset whose value
277        # could be affected by `manager.clear()`. Refs #19816.
278        objs = tuple(objs)
279
280        if self.field.allow_null:
281            with transaction.atomic(savepoint=False):
282                if clear:
283                    self.clear(bulk=bulk)
284                    self.add(*objs, bulk=bulk)
285                else:
286                    old_objs = set(self.query.all())
287                    new_objs = []
288                    for obj in objs:
289                        if obj in old_objs:
290                            old_objs.remove(obj)
291                        else:
292                            new_objs.append(obj)
293
294                    self.remove(*old_objs, bulk=bulk)
295                    self.add(*new_objs, bulk=bulk)
296        else:
297            self.add(*objs, bulk=bulk)
298
299
class ManyToManyManager(BaseRelatedManager[T, QS]):
    """
    Manager for both forward and reverse sides of a many-to-many relation.

    Links between ``instance`` and objects of ``model`` are rows of the
    ``through`` model. ``source_field`` is the through-model foreign key
    pointing at ``instance``'s side, and ``target_field`` points at
    ``model``. The forward and reverse directions only swap which through
    field plays which role.

    Symmetrical relations apply only to the forward direction. For them,
    ``add()`` also writes the mirrored through row, and removals delete links
    stored in either direction.
    """

    # Type hints for attributes
    model: type[T]  # class of the objects returned by this manager
    instance: Model  # object whose links are managed
    field: ManyToManyField  # the relation definition
    through: type[Model]  # model whose rows store the links
    query_field_name: str  # lookup on ``model`` leading back to ``instance``
    prefetch_cache_name: str  # key in ``instance._prefetched_objects_cache``
    source_field_name: str  # through-model FK on ``instance``'s side
    target_field_name: str  # through-model FK pointing at ``model``
    source_field: ForeignKeyField
    target_field: ForeignKeyField
    symmetrical: bool
    core_filters: dict[str, Any]  # lookups restricting ``model`` rows to this relation
    id_field_names: dict[str, str]  # through field name -> referenced field name
    related_val: tuple[Any, ...]  # ``instance`` value(s) referenced by ``source_field``

    def __init__(
        self,
        instance: Model,
        field: ManyToManyField,
        through: type[Model],
        related_model: type[Model],
        is_reverse: bool,
        symmetrical: bool = False,
    ):
        """
        Bind the manager to ``instance`` for one direction of ``field``.

        Args:
            instance: Object whose links are managed.
            field: Named ``ManyToManyField`` describing the relation.
            through: Model whose rows store the links.
            related_model: Model class of the objects returned.
            is_reverse: True when accessing from the target model back to the
                source model.
            symmetrical: Whether links are mirrored. Forced to False for the
                reverse direction.

        Raises:
            ValueError: If ``instance`` has no value for a field referenced by
                ``source_field``, or has no primary key.
        """
        assert field.name is not None, "Field must have a name"
        self.model = cast(type[T], related_model)
        self.field = field
        # Set direction-specific attributes
        if is_reverse:
            # Reverse: accessing from the target model back to the source
            self.query_field_name = field.name
            self.prefetch_cache_name = field.related_query_name()
            self.source_field_name = field.m2m_reverse_field_name()
            self.target_field_name = field.m2m_field_name()
            self.symmetrical = False  # Reverse relations are never symmetrical
        else:
            # Forward: accessing from the source model to the target
            self.query_field_name = field.related_query_name()
            self.prefetch_cache_name = field.name
            self.source_field_name = field.m2m_field_name()
            self.target_field_name = field.m2m_reverse_field_name()
            self.symmetrical = symmetrical

        # Initialize common M2M attributes
        self.instance = instance
        self.through = through

        # M2M through model fields are always ForeignKey
        self.source_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.source_field_name),
        )
        self.target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(self.target_field_name),
        )

        self.core_filters = {}
        self.id_field_names = {}
        for lh_field, rh_field in self.source_field.related_fields:
            core_filter_key = f"{self.query_field_name}__{rh_field.name}"
            self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
            self.id_field_names[lh_field.name] = rh_field.name  # type: ignore[assignment]

        self.related_val = self.source_field.get_foreign_related_value(instance)
        if None in self.related_val:
            raise ValueError(
                f'"{instance!r}" needs to have a value for field "{self.id_field_names[self.source_field_name]}" before '
                "this many-to-many relationship can be used."
            )
        # Even if this relation is not to primary key, we require still primary key value.
        if instance.id is None:
            raise ValueError(
                f"{instance.__class__.__name__!r} instance needs to have a primary key value before "
                "a many-to-many relationship can be used."
            )

    def _apply_rel_filters(self, queryset: QuerySet) -> QuerySet:
        """Filter the queryset for the instance this manager is bound to."""
        # NOTE(review): both flags are consumed by QuerySet; presumably they
        # defer the filter and make its joins reusable by the next filter()
        # call — confirm in QuerySet.
        queryset._defer_next_filter = True
        return queryset._next_is_sticky().filter(**self.core_filters)

    def _remove_prefetched_objects(self) -> None:
        """Drop this relation's entry from ``instance``'s prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QS:
        """
        Return the objects linked to ``instance``.

        Uses the cached prefetch queryset when present. Otherwise it builds a
        queryset on ``model`` filtered by ``core_filters``.
        """
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            queryset = self.model.query
            return cast(QS, self._apply_rel_filters(queryset))

    def get_prefetch_queryset(
        self, instances: Iterable[Model], queryset: QuerySet | None = None
    ) -> tuple[
        QuerySet, Callable[[Model], Any], Callable[[Model], Any], bool, str, bool
    ]:
        """
        Fetch the linked objects for all ``instances`` at once.

        Returns:
            ``(queryset, rel_obj_attr, instance_attr, False,
            prefetch_cache_name, False)``.

            - ``rel_obj_attr`` reads, from each fetched object, the
              through-row source value(s) selected as
              ``_prefetch_related_val_<attname>``.
            - ``instance_attr`` computes the matching DB-prepared key from an
              instance.
            - The two boolean flags are interpreted by the prefetch caller.
        """
        if queryset is None:
            queryset = self.model.query

        queryset = _filter_prefetch_queryset(
            queryset._next_is_sticky(), self.query_field_name, instances
        )

        # M2M: need to annotate the query in order to get the primary model
        # that the secondary model was actually related to. The extra select
        # below names the through table directly, relying on the ``__in``
        # filter above having joined it.
        fk = self.source_field  # M2M through model fields are always ForeignKey
        join_table = fk.model.model_options.db_table
        qn = quote_name
        queryset = queryset.extra(
            select={
                f"_prefetch_related_val_{f.attname}": f"{qn(join_table)}.{qn(f.column)}"
                for f in fk.local_related_fields
            }
        )
        conn = get_connection()
        return (
            queryset,
            lambda result: tuple(
                getattr(result, f"_prefetch_related_val_{f.attname}")
                for f in fk.local_related_fields
            ),
            lambda inst: tuple(
                f.get_db_prep_value(getattr(inst, f.attname), conn)
                for f in fk.foreign_related_fields
            ),
            False,
            self.prefetch_cache_name,
            False,
        )

    def clear(self) -> None:
        """
        Delete the through rows linking ``instance``.

        For symmetrical relations, mirrored rows are deleted too. If
        ``model.query`` carries filters, only links to objects matching them
        are deleted.
        """
        with transaction.atomic(savepoint=False):
            self._remove_prefetched_objects()
            filters = self._build_remove_filters(self.model.query)
            self.through.query.filter(filters).delete()

    def set(
        self,
        objs: Any,
        *,
        clear: bool = False,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """
        Make ``objs`` the linked objects of ``instance``, in one transaction.

        ``objs`` may contain ``model`` instances or raw target key values.
        With ``clear=True`` every link is deleted and ``objs`` are re-added.
        Otherwise only missing links are added and links to objects absent
        from ``objs`` are removed.
        """
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear()
                self.add(*objs, through_defaults=through_defaults)
            else:
                old_ids = set(
                    self.query.values_list(
                        self.target_field.target_field.attname, flat=True
                    )
                )

                new_objs = []
                for obj in objs:
                    fk_val = (
                        self.target_field.get_foreign_related_value(obj)[0]
                        if isinstance(obj, self.model)
                        else self.target_field.get_prep_value(obj)
                    )
                    if fk_val in old_ids:
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)

                self.remove(*old_ids)
                self.add(*new_objs, through_defaults=through_defaults)

    def create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> T:
        """Create a ``model`` object from ``kwargs`` and link it to ``instance``."""
        new_obj = self.model.query.create(**kwargs)
        self.add(new_obj, through_defaults=through_defaults)
        return cast(T, new_obj)

    def get_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Get an object matching ``kwargs`` among the linked objects, or create
        and link one.

        The lookup is restricted to this relation, so a returned existing
        object is guaranteed to be linked already. A matching object that
        exists but is *not* linked is not returned. A new object is created
        instead, which may fail on unique constraints.
        """
        # Query the related queryset (not the whole model table) so that a
        # found object really is part of this relationship.
        obj, created = self.query.get_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def update_or_create(
        self, *, through_defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Update an object matching ``kwargs`` among the linked objects, or
        create and link one.

        As with ``get_or_create()``, the lookup is restricted to this
        relation. Objects not linked to ``instance`` are never updated.
        """
        # Query the related queryset (not the whole model table) so that an
        # unrelated object is never updated and left unlinked.
        obj, created = self.query.update_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return cast(T, obj), created

    def _get_target_ids(self, target_field_name: str, objs: Any) -> builtins.set[Any]:
        """
        Return the set of ids of `objs` that the target field references.

        ``model`` instances are reduced to their referenced value. Any other
        non-model value is passed through ``get_prep_value``.

        Raises:
            ValueError: If a ``model`` instance has no referenced value.
            TypeError: If an object is a model instance of another class.
        """
        # Imported here: ``Model`` is needed at runtime for isinstance().
        from plain.postgres import Model

        target_ids: set[Any] = set()
        target_field = cast(
            "ForeignKeyField",
            self.through._model_meta.get_forward_field(target_field_name),
        )  # M2M through model fields are always ForeignKey
        for obj in objs:
            if isinstance(obj, self.model):
                target_id = target_field.get_foreign_related_value(obj)[0]
                if target_id is None:
                    raise ValueError(
                        f'Cannot add "{obj!r}": the value for field "{target_field_name}" is None'
                    )
                target_ids.add(target_id)
            elif isinstance(obj, Model):
                raise TypeError(
                    f"'{self.model.model_options.object_name}' instance expected, got {obj!r}"
                )
            else:
                target_ids.add(target_field.get_prep_value(obj))
        return target_ids

    def _get_missing_target_ids(
        self,
        source_field_name: str,
        target_field_name: str,
        target_ids: builtins.set[Any],
    ) -> builtins.set[Any]:
        """Return the subset of ids of `objs` that aren't already assigned to this relationship."""
        vals = self.through.query.values_list(target_field_name, flat=True).filter(
            **{
                source_field_name: self.related_val[0],
                f"{target_field_name}__in": target_ids,
            }
        )
        return target_ids.difference(vals)

    def _add_items(
        self,
        source_field_name: str,
        target_field_name: str,
        *objs: Any,
        through_defaults: dict[str, Any] | None = None,
    ) -> None:
        """
        Insert through rows linking ``related_val[0]`` to each target in
        ``objs`` that is not linked yet.

        ``through_defaults`` is passed through ``resolve_callables`` once, and
        the result is applied to every new row.
        """
        if not objs:
            return

        through_defaults = dict(resolve_callables(through_defaults or {}))
        target_ids = self._get_target_ids(target_field_name, objs)

        missing_target_ids = self._get_missing_target_ids(
            source_field_name, target_field_name, target_ids
        )
        with transaction.atomic(savepoint=False):
            # Add the ones that aren't there already.
            self.through.query.bulk_create(
                [
                    self.through(
                        **through_defaults,
                        **{
                            f"{source_field_name}_id": self.related_val[0],
                            f"{target_field_name}_id": target_id,
                        },
                    )
                    for target_id in missing_target_ids
                ],
            )

    def _remove_items(
        self, source_field_name: str, target_field_name: str, *objs: Any
    ) -> None:
        """Delete the through rows linking ``instance`` to ``objs``."""
        if not objs:
            return

        # Normalize to target key values: ``model`` instances are reduced to
        # the value the target FK stores; anything else is used as-is.
        old_ids = set()
        for obj in objs:
            if isinstance(obj, self.model):
                fk_val = self.target_field.get_foreign_related_value(obj)[0]
                old_ids.add(fk_val)
            else:
                old_ids.add(obj)

        with transaction.atomic(savepoint=False):
            target_model_qs = self.model.query
            # If ``model.query`` carries filters, restrict removal to the
            # objects it would return.
            if target_model_qs._has_filters():
                old_vals = target_model_qs.filter(
                    **{f"{self.target_field.target_field.attname}__in": old_ids}
                )
            else:
                old_vals = old_ids
            filters = self._build_remove_filters(old_vals)
            self.through.query.filter(filters).delete()

    def _build_remove_filters(self, removed_vals: Any) -> Any:
        """
        Build the ``Q`` selecting through rows to delete.

        The result matches rows whose source is ``related_val``. When
        ``removed_vals`` is not an unfiltered QuerySet, rows are further
        limited to targets in ``removed_vals``. For symmetrical relations it
        is OR-ed with the mirrored condition.
        """
        filters = Q.create([(self.source_field_name, self.related_val)])
        # No need to add a subquery condition if removed_vals is a QuerySet without
        # filters.
        removed_vals_filters = (
            not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
        )
        if removed_vals_filters:
            filters = filters & Q.create(
                [(f"{self.target_field_name}__in", removed_vals)]
            )
        # Add symmetrical filters for forward symmetrical relations
        if self.symmetrical:
            symmetrical_filters = Q.create([(self.target_field_name, self.related_val)])
            if removed_vals_filters:
                symmetrical_filters = symmetrical_filters & Q.create(
                    [(f"{self.source_field_name}__in", removed_vals)]
                )
            filters = filters | symmetrical_filters
        return filters

    def add(self, *objs: T, through_defaults: dict[str, Any] | None = None) -> None:
        """
        Link ``objs`` (``model`` instances or target key values) to
        ``instance``, skipping links that already exist.
        """
        self._remove_prefetched_objects()
        with transaction.atomic(savepoint=False):
            self._add_items(
                self.source_field_name,
                self.target_field_name,
                *objs,
                through_defaults=through_defaults,
            )
            # If this is a symmetrical m2m relation to self, add the mirror
            # entry in the m2m table.
            if self.symmetrical:
                self._add_items(
                    self.target_field_name,
                    self.source_field_name,
                    *objs,
                    through_defaults=through_defaults,
                )

    def remove(self, *objs: T) -> None:
        """
        Unlink ``objs`` (``model`` instances or target key values) from
        ``instance``. Mirrored links are removed too for symmetrical
        relations.
        """
        self._remove_prefetched_objects()
        self._remove_items(self.source_field_name, self.target_field_name, *objs)