Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Managers for related objects.
  3
  4These managers provide the API for working with collections of related objects
  5through foreign key and many-to-many relationships.
  6"""
  7
  8from plain.models import transaction
  9from plain.models.db import NotSupportedError, db_connection
 10from plain.models.expressions import Window
 11from plain.models.functions import RowNumber
 12from plain.models.lookups import GreaterThan, LessThanOrEqual
 13from plain.models.query import QuerySet
 14from plain.models.query_utils import Q
 15from plain.models.utils import resolve_callables
 16
 17
 18def _filter_prefetch_queryset(queryset, field_name, instances):
 19    predicate = Q(**{f"{field_name}__in": instances})
 20    if queryset.query.is_sliced:
 21        if not db_connection.features.supports_over_clause:
 22            raise NotSupportedError(
 23                "Prefetching from a limited queryset is only supported on backends "
 24                "that support window functions."
 25            )
 26        low_mark, high_mark = queryset.query.low_mark, queryset.query.high_mark
 27        order_by = [expr for expr, _ in queryset.query.get_compiler().get_order_by()]
 28        window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
 29        predicate &= GreaterThan(window, low_mark)
 30        if high_mark is not None:
 31            predicate &= LessThanOrEqual(window, high_mark)
 32        queryset.query.clear_limits()
 33    return queryset.filter(predicate)
 34
 35
 36class BaseRelatedManager:
 37    """
 38    Base class for all related object managers.
 39
 40    All related managers should have a 'query' property that returns a QuerySet.
 41    """
 42
 43    @property
 44    def query(self) -> QuerySet:
 45        """Access the QuerySet for this relationship."""
 46        return self.get_queryset()
 47
 48    def get_queryset(self) -> QuerySet:
 49        """Return the QuerySet for this relationship."""
 50        raise NotImplementedError("Subclasses must implement get_queryset()")
 51
 52
 53class ReverseManyToOneManager(BaseRelatedManager):
 54    """
 55    Manager for the reverse side of a many-to-one relation.
 56
 57    This manager adds behaviors specific to many-to-one relations.
 58    """
 59
 60    def __init__(self, instance, rel):
 61        self.model = rel.related_model
 62        self.instance = instance
 63        self.field = rel.field
 64        self.core_filters = {self.field.name: instance}
 65        # Store the base queryset class for this model
 66        self.base_queryset_class = rel.related_model._meta.queryset.__class__
 67        self.allow_null = rel.field.allow_null
 68
 69    def _check_fk_val(self):
 70        for field in self.field.foreign_related_fields:
 71            if getattr(self.instance, field.attname) is None:
 72                raise ValueError(
 73                    f'"{self.instance!r}" needs to have a value for field '
 74                    f'"{field.attname}" before this relationship can be used.'
 75                )
 76
 77    def _apply_rel_filters(self, queryset):
 78        """
 79        Filter the queryset for the instance this manager is bound to.
 80        """
 81        from plain.exceptions import FieldError
 82
 83        queryset._defer_next_filter = True
 84        queryset = queryset.filter(**self.core_filters)
 85        for field in self.field.foreign_related_fields:
 86            val = getattr(self.instance, field.attname)
 87            if val is None:
 88                return queryset.none()
 89        if self.field.many_to_one:
 90            # Guard against field-like objects such as GenericRelation
 91            # that abuse create_reverse_many_to_one_manager() with reverse
 92            # one-to-many relationships instead and break known related
 93            # objects assignment.
 94            try:
 95                target_field = self.field.target_field
 96            except FieldError:
 97                # The relationship has multiple target fields. Use a tuple
 98                # for related object id.
 99                rel_obj_id = tuple(
100                    [
101                        getattr(self.instance, target_field.attname)
102                        for target_field in self.field.path_infos[-1].target_fields
103                    ]
104                )
105            else:
106                rel_obj_id = getattr(self.instance, target_field.attname)
107            queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
108        return queryset
109
110    def _remove_prefetched_objects(self):
111        try:
112            self.instance._prefetched_objects_cache.pop(
113                self.field.remote_field.get_cache_name()
114            )
115        except (AttributeError, KeyError):
116            pass  # nothing to clear from cache
117
118    def get_queryset(self):
119        # Even if this relation is not to primary key, we require still primary key value.
120        # The wish is that the instance has been already saved to DB,
121        # although having a primary key value isn't a guarantee of that.
122        if self.instance.id is None:
123            raise ValueError(
124                f"{self.instance.__class__.__name__!r} instance needs to have a "
125                f"primary key value before this relationship can be used."
126            )
127        try:
128            return self.instance._prefetched_objects_cache[
129                self.field.remote_field.get_cache_name()
130            ]
131        except (AttributeError, KeyError):
132            # Use the base queryset class for this model
133            queryset = self.base_queryset_class(model=self.model)
134            return self._apply_rel_filters(queryset)
135
136    def get_prefetch_queryset(self, instances, queryset=None):
137        if queryset is None:
138            queryset = self.base_queryset_class(model=self.model)
139
140        rel_obj_attr = self.field.get_local_related_value
141        instance_attr = self.field.get_foreign_related_value
142        instances_dict = {instance_attr(inst): inst for inst in instances}
143        queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)
144
145        # Since we just bypassed this class' get_queryset(), we must manage
146        # the reverse relation manually.
147        for rel_obj in queryset:
148            if not self.field.is_cached(rel_obj):
149                instance = instances_dict[rel_obj_attr(rel_obj)]
150                setattr(rel_obj, self.field.name, instance)
151        cache_name = self.field.remote_field.get_cache_name()
152        return queryset, rel_obj_attr, instance_attr, False, cache_name, False
153
154    def add(self, *objs, bulk=True):
155        self._check_fk_val()
156        self._remove_prefetched_objects()
157
158        def check_and_update_obj(obj):
159            if not isinstance(obj, self.model):
160                raise TypeError(
161                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
162                )
163            setattr(obj, self.field.name, self.instance)
164
165        if bulk:
166            ids = []
167            for obj in objs:
168                check_and_update_obj(obj)
169                if obj._state.adding:
170                    raise ValueError(
171                        f"{obj!r} instance isn't saved. Use bulk=False or save "
172                        "the object first."
173                    )
174                ids.append(obj.id)
175            self.model._meta.base_queryset.filter(id__in=ids).update(
176                **{
177                    self.field.name: self.instance,
178                }
179            )
180        else:
181            with transaction.atomic(savepoint=False):
182                for obj in objs:
183                    check_and_update_obj(obj)
184                    obj.save()
185
186    def create(self, **kwargs):
187        self._check_fk_val()
188        kwargs[self.field.name] = self.instance
189        return self.base_queryset_class(model=self.model).create(**kwargs)
190
191    def get_or_create(self, **kwargs):
192        self._check_fk_val()
193        kwargs[self.field.name] = self.instance
194        return self.base_queryset_class(model=self.model).get_or_create(**kwargs)
195
196    def update_or_create(self, **kwargs):
197        self._check_fk_val()
198        kwargs[self.field.name] = self.instance
199        return self.base_queryset_class(model=self.model).update_or_create(**kwargs)
200
201    def remove(self, *objs, bulk=True):
202        # remove() is only provided if the ForeignKey can have a value of null
203        if not self.allow_null:
204            raise AttributeError(
205                f"Cannot call remove() on a related manager for field "
206                f"{self.field.name} where null=False."
207            )
208        if not objs:
209            return
210        self._check_fk_val()
211        val = self.field.get_foreign_related_value(self.instance)
212        old_ids = set()
213        for obj in objs:
214            if not isinstance(obj, self.model):
215                raise TypeError(
216                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
217                )
218            # Is obj actually part of this descriptor set?
219            if self.field.get_local_related_value(obj) == val:
220                old_ids.add(obj.id)
221            else:
222                raise self.field.remote_field.model.DoesNotExist(
223                    f"{obj!r} is not related to {self.instance!r}."
224                )
225        self._clear(self.query.filter(id__in=old_ids), bulk)
226
227    def clear(self, *, bulk=True):
228        # clear() is only provided if the ForeignKey can have a value of null
229        if not self.allow_null:
230            raise AttributeError(
231                f"Cannot call clear() on a related manager for field "
232                f"{self.field.name} where null=False."
233            )
234        self._check_fk_val()
235        self._clear(self.query, bulk)
236
237    def _clear(self, queryset, bulk):
238        self._remove_prefetched_objects()
239        if bulk:
240            # `QuerySet.update()` is intrinsically atomic.
241            queryset.update(**{self.field.name: None})
242        else:
243            with transaction.atomic(savepoint=False):
244                for obj in queryset:
245                    setattr(obj, self.field.name, None)
246                    obj.save(update_fields=[self.field.name])
247
248    def set(self, objs, *, bulk=True, clear=False):
249        self._check_fk_val()
250        # Force evaluation of `objs` in case it's a queryset whose value
251        # could be affected by `manager.clear()`. Refs #19816.
252        objs = tuple(objs)
253
254        if self.field.allow_null:
255            with transaction.atomic(savepoint=False):
256                if clear:
257                    self.clear(bulk=bulk)
258                    self.add(*objs, bulk=bulk)
259                else:
260                    old_objs = set(self.query.all())
261                    new_objs = []
262                    for obj in objs:
263                        if obj in old_objs:
264                            old_objs.remove(obj)
265                        else:
266                            new_objs.append(obj)
267
268                    self.remove(*old_objs, bulk=bulk)
269                    self.add(*new_objs, bulk=bulk)
270        else:
271            self.add(*objs, bulk=bulk)
272
273
class BaseManyToManyManager(BaseRelatedManager):
    """
    Base class for many-to-many managers with common functionality.

    Subclasses must set these attributes before calling ``super().__init__``:

    - ``model``: the model whose objects this manager returns
    - ``query_field_name``: lookup name on ``model`` leading back to the
      instance side (used to build ``core_filters``)
    - ``prefetch_cache_name``: key in ``instance._prefetched_objects_cache``
    - ``source_field_name``: through-model field referring to the instance
    - ``target_field_name``: through-model field referring to ``model`` objects
    - ``symmetrical`` (for forward relations)

    Subclasses must also implement ``_build_remove_filters``, ``add`` and
    ``remove``.
    """

    def __init__(self, instance, rel):
        self.instance = instance
        self.through = rel.through
        # Subclasses must set model before calling super().__init__
        self.base_queryset_class = self.model._meta.queryset.__class__

        self.source_field = self.through._meta.get_field(self.source_field_name)
        self.target_field = self.through._meta.get_field(self.target_field_name)

        # Filters (by the instance's referenced values) applied to querysets
        # of ``model``, plus a map from through-field name to referenced name.
        self.core_filters = {}
        self.id_field_names = {}
        for lh_field, rh_field in self.source_field.related_fields:
            core_filter_key = f"{self.query_field_name}__{rh_field.name}"
            self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
            self.id_field_names[lh_field.name] = rh_field.name

        self.related_val = self.source_field.get_foreign_related_value(instance)
        if None in self.related_val:
            raise ValueError(
                f'"{instance!r}" needs to have a value for field "{self.id_field_names[self.source_field_name]}" before '
                "this many-to-many relationship can be used."
            )
        # Even if this relation is not to primary key, we require still primary key value.
        if instance.id is None:
            raise ValueError(
                f"{instance.__class__.__name__!r} instance needs to have a primary key value before "
                "a many-to-many relationship can be used."
            )

    def _apply_rel_filters(self, queryset):
        """Filter the queryset for the instance this manager is bound to."""
        queryset._defer_next_filter = True
        # Sticky so that a following filter() reuses the joins added here.
        return queryset._next_is_sticky().filter(**self.core_filters)

    def _remove_prefetched_objects(self):
        """Drop this relationship's entry from the instance's prefetch cache, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self) -> QuerySet:
        """Return the related objects, preferring the prefetch cache when present."""
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            queryset = self.base_queryset_class(model=self.model)
            return self._apply_rel_filters(queryset)

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Build the prefetch data for ``instances``.

        Returns a 6-tuple:

        - the filtered queryset, with the join-table source column selected
          as ``_prefetch_related_val_<attname>``
        - a key function over results
        - a key function over ``instances``
        - ``False``
        - ``prefetch_cache_name``
        - ``False``
        """
        if queryset is None:
            queryset = self.base_queryset_class(model=self.model)

        queryset = _filter_prefetch_queryset(
            queryset._next_is_sticky(), self.query_field_name, instances
        )

        # M2M: need to annotate the query in order to get the primary model
        # that the secondary model was actually related to.
        fk = self.through._meta.get_field(self.source_field_name)
        join_table = fk.model._meta.db_table
        qn = db_connection.ops.quote_name
        queryset = queryset.extra(
            select={
                f"_prefetch_related_val_{f.attname}": f"{qn(join_table)}.{qn(f.column)}"
                for f in fk.local_related_fields
            }
        )
        return (
            queryset,
            lambda result: tuple(
                getattr(result, f"_prefetch_related_val_{f.attname}")
                for f in fk.local_related_fields
            ),
            lambda inst: tuple(
                f.get_db_prep_value(getattr(inst, f.attname), db_connection)
                for f in fk.foreign_related_fields
            ),
            False,
            self.prefetch_cache_name,
            False,
        )

    def clear(self):
        """Delete every through row linking ``instance`` to ``model`` objects."""
        with transaction.atomic(savepoint=False):
            self._remove_prefetched_objects()
            filters = self._build_remove_filters(
                self.base_queryset_class(model=self.model)
            )
            self.through.query.filter(filters).delete()

    def set(self, objs, *, clear=False, through_defaults=None):
        """
        Make ``objs`` (model instances or target ids) exactly the related set.

        With ``clear=True`` all links are deleted first. Otherwise only the
        difference is removed or added.
        """
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear()
                self.add(*objs, through_defaults=through_defaults)
            else:
                old_ids = set(
                    self.query.values_list(
                        self.target_field.target_field.attname, flat=True
                    )
                )

                new_objs = []
                for obj in objs:
                    fk_val = (
                        self.target_field.get_foreign_related_value(obj)[0]
                        if isinstance(obj, self.model)
                        else self.target_field.get_prep_value(obj)
                    )
                    if fk_val in old_ids:
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)

                self.remove(*old_ids)
                self.add(*new_objs, through_defaults=through_defaults)

    def create(self, *, through_defaults=None, **kwargs):
        """Create a new ``model`` object, link it to ``instance`` and return it."""
        new_obj = self.base_queryset_class(model=self.model).create(**kwargs)
        self.add(new_obj, through_defaults=through_defaults)
        return new_obj

    def get_or_create(self, *, through_defaults=None, **kwargs):
        """
        Return ``(obj, created)`` for an object matching ``kwargs`` within this
        relationship.

        The lookup is scoped to the objects already related to ``instance``.
        When none matches, a new object is created and linked.
        """
        # Search within the relationship (not the whole table). Otherwise an
        # existing but unrelated match would be returned with created=False
        # and never linked.
        obj, created = self.query.get_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return obj, created

    def update_or_create(self, *, through_defaults=None, **kwargs):
        """
        Update an object matching ``kwargs`` within this relationship, or
        create and link a new one. Returns ``(obj, created)``.
        """
        # Scope to the relationship so that an unrelated row is never updated
        # and then left unlinked.
        obj, created = self.query.update_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return obj, created

    def _get_target_ids(self, target_field_name, objs):
        """
        Return the set of ids of `objs` that the target field references.

        ``objs`` may mix ``model`` instances and raw id values.

        Raises:
            ValueError: if a ``model`` instance has no value for the field.
            TypeError: if an object is a model instance of another model.
        """
        from plain.models import Model

        target_ids = set()
        target_field = self.through._meta.get_field(target_field_name)
        for obj in objs:
            if isinstance(obj, self.model):
                target_id = target_field.get_foreign_related_value(obj)[0]
                if target_id is None:
                    raise ValueError(
                        f'Cannot add "{obj!r}": the value for field "{target_field_name}" is None'
                    )
                target_ids.add(target_id)
            elif isinstance(obj, Model):
                raise TypeError(
                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
                )
            else:
                target_ids.add(target_field.get_prep_value(obj))
        return target_ids

    def _get_missing_target_ids(self, source_field_name, target_field_name, target_ids):
        """Return the subset of ids of `objs` that aren't already assigned to this relationship."""
        vals = self.through.query.values_list(target_field_name, flat=True).filter(
            **{
                source_field_name: self.related_val[0],
                f"{target_field_name}__in": target_ids,
            }
        )
        return target_ids.difference(vals)

    def _add_items(
        self, source_field_name, target_field_name, *objs, through_defaults=None
    ):
        """Bulk-create through rows for the ``objs`` not yet linked."""
        if not objs:
            return

        through_defaults = dict(resolve_callables(through_defaults or {}))
        target_ids = self._get_target_ids(target_field_name, objs)

        missing_target_ids = self._get_missing_target_ids(
            source_field_name, target_field_name, target_ids
        )
        with transaction.atomic(savepoint=False):
            # Add the ones that aren't there already.
            self.through.query.bulk_create(
                [
                    self.through(
                        **through_defaults,
                        **{
                            f"{source_field_name}_id": self.related_val[0],
                            f"{target_field_name}_id": target_id,
                        },
                    )
                    for target_id in missing_target_ids
                ],
            )

    def _remove_items(self, source_field_name, target_field_name, *objs):
        """Delete the through rows linking ``instance`` to ``objs`` (instances or ids)."""
        if not objs:
            return

        # Check that all the objects are of the right type
        old_ids = set()
        for obj in objs:
            if isinstance(obj, self.model):
                fk_val = self.target_field.get_foreign_related_value(obj)[0]
                old_ids.add(fk_val)
            else:
                old_ids.add(obj)

        with transaction.atomic(savepoint=False):
            target_model_qs = self.base_queryset_class(model=self.model)
            if target_model_qs._has_filters():
                old_vals = target_model_qs.filter(
                    **{f"{self.target_field.target_field.attname}__in": old_ids}
                )
            else:
                old_vals = old_ids
            filters = self._build_remove_filters(old_vals)
            self.through.query.filter(filters).delete()

    # Subclasses must implement these methods:
    def _build_remove_filters(self, removed_vals):
        raise NotImplementedError

    def add(self, *objs, through_defaults=None):
        raise NotImplementedError

    def remove(self, *objs):
        raise NotImplementedError
524
525
class ForwardManyToManyManager(BaseManyToManyManager):
    """
    Manager for the forward side of a many-to-many relation.

    For symmetrical relations, every link is mirrored. ``add()`` writes the
    reverse row too, and removal filters match rows in both directions.
    """

    def __init__(self, instance, rel):
        field = rel.field
        # BaseManyToManyManager.__init__ reads these, so they must be set first.
        self.model = rel.model
        self.query_field_name = field.related_query_name()
        self.prefetch_cache_name = field.name
        self.source_field_name = field.m2m_field_name()
        self.target_field_name = field.m2m_reverse_field_name()
        self.symmetrical = rel.symmetrical

        super().__init__(instance, rel)

    def _build_remove_filters(self, removed_vals):
        """
        Build the Q selecting through rows to delete for ``removed_vals``.

        ``removed_vals`` is either an iterable of target ids or a QuerySet.
        A QuerySet without filters adds no target restriction.
        """
        narrow = not (
            isinstance(removed_vals, QuerySet) and not removed_vals._has_filters()
        )

        outgoing = Q.create([(self.source_field_name, self.related_val)])
        if narrow:
            outgoing = outgoing & Q.create(
                [(f"{self.target_field_name}__in", removed_vals)]
            )
        if not self.symmetrical:
            return outgoing

        # Symmetrical: also match the mirrored rows pointing back at us.
        incoming = Q.create([(self.target_field_name, self.related_val)])
        if narrow:
            incoming = incoming & Q.create(
                [(f"{self.source_field_name}__in", removed_vals)]
            )
        return outgoing | incoming

    def add(self, *objs, through_defaults=None):
        """Link ``objs`` to the instance, mirroring the rows when symmetrical."""
        self._remove_prefetched_objects()
        directions = [(self.source_field_name, self.target_field_name)]
        if self.symmetrical:
            directions.append((self.target_field_name, self.source_field_name))
        with transaction.atomic(savepoint=False):
            for from_name, to_name in directions:
                self._add_items(
                    from_name,
                    to_name,
                    *objs,
                    through_defaults=through_defaults,
                )

    def remove(self, *objs):
        """Unlink ``objs`` (model instances or target ids) from the instance."""
        self._remove_prefetched_objects()
        source, target = self.source_field_name, self.target_field_name
        self._remove_items(source, target, *objs)
584
585
class ReverseManyToManyManager(BaseManyToManyManager):
    """
    Manager for the reverse side of a many-to-many relation.

    The through-field roles are swapped relative to the forward manager, and
    the relation is never symmetrical, so no mirror rows are involved.
    """

    def __init__(self, instance, rel):
        field = rel.field
        # BaseManyToManyManager.__init__ reads these, so they must be set first.
        self.model = rel.related_model
        self.query_field_name = field.name
        self.prefetch_cache_name = field.related_query_name()
        self.source_field_name = field.m2m_reverse_field_name()
        self.target_field_name = field.m2m_field_name()
        self.symmetrical = False  # Reverse relations are never symmetrical

        super().__init__(instance, rel)

    def _build_remove_filters(self, removed_vals):
        """
        Build the Q selecting through rows to delete for ``removed_vals``.

        A QuerySet without filters stands for every target, so the ``__in``
        restriction is skipped in that case.
        """
        unrestricted = (
            isinstance(removed_vals, QuerySet) and not removed_vals._has_filters()
        )
        ours = Q.create([(self.source_field_name, self.related_val)])
        if unrestricted:
            return ours
        return ours & Q.create([(f"{self.target_field_name}__in", removed_vals)])

    def add(self, *objs, through_defaults=None):
        """Link ``objs`` to the instance (no mirror rows on the reverse side)."""
        self._remove_prefetched_objects()
        source, target = self.source_field_name, self.target_field_name
        with transaction.atomic(savepoint=False):
            self._add_items(source, target, *objs, through_defaults=through_defaults)

    def remove(self, *objs):
        """Unlink ``objs`` (model instances or target ids) from the instance."""
        self._remove_prefetched_objects()
        source, target = self.source_field_name, self.target_field_name
        self._remove_items(source, target, *objs)