1"""
2Managers for related objects.
3
4These managers provide the API for working with collections of related objects
5through foreign key and many-to-many relationships.
6"""
7
8from plain.models import transaction
9from plain.models.db import NotSupportedError, db_connection
10from plain.models.expressions import Window
11from plain.models.functions import RowNumber
12from plain.models.lookups import GreaterThan, LessThanOrEqual
13from plain.models.query import QuerySet
14from plain.models.query_utils import Q
15from plain.models.utils import resolve_callables
16
17
def _filter_prefetch_queryset(queryset, field_name, instances):
    """
    Return a copy of ``queryset`` restricted to rows related to ``instances``.

    ``field_name`` is the lookup on ``queryset.model`` that leads back to the
    model of ``instances``; the base condition is ``<field_name>__in=instances``.

    When ``queryset`` is sliced, a single LIMIT/OFFSET would slice the combined
    result for all instances. Instead the slice is re-expressed per instance:
    rows are numbered with ``ROW_NUMBER()`` partitioned by ``field_name`` (using
    the queryset's ordering) and only numbers in ``(low_mark, high_mark]`` are
    kept.

    The QuerySet passed in is left untouched; callers frequently hand over the
    QuerySet stored on a ``Prefetch`` object, which may be evaluated again.

    Raises:
        NotSupportedError: if ``queryset`` is sliced and the backend does not
            support ``OVER`` clauses (window functions).
    """
    predicate = Q(**{f"{field_name}__in": instances})
    if queryset.query.is_sliced:
        if not db_connection.features.supports_over_clause:
            raise NotSupportedError(
                "Prefetching from a limited queryset is only supported on backends "
                "that support window functions."
            )
        # Work on a clone: clear_limits() mutates the Query in place, and doing
        # that on the caller's QuerySet would permanently strip its slice.
        queryset = queryset.all()
        low_mark, high_mark = queryset.query.low_mark, queryset.query.high_mark
        order_by = [expr for expr, _ in queryset.query.get_compiler().get_order_by()]
        window = Window(RowNumber(), partition_by=field_name, order_by=order_by)
        # ROW_NUMBER() is 1-based, so "> low_mark" skips the first low_mark rows.
        predicate &= GreaterThan(window, low_mark)
        if high_mark is not None:
            predicate &= LessThanOrEqual(window, high_mark)
        # The window predicate replaces the LIMIT/OFFSET; filter() would also
        # refuse to run on a sliced query.
        queryset.query.clear_limits()
    return queryset.filter(predicate)
34
35
class BaseRelatedManager:
    """
    Common root of every related-object manager in this module.

    Concrete managers implement ``get_queryset()``; ``query`` is a read-only
    alias for it, re-evaluated on every access.
    """

    @property
    def query(self) -> QuerySet:
        """Return the QuerySet for this relationship (``self.get_queryset()``)."""
        relationship_qs = self.get_queryset()
        return relationship_qs

    def get_queryset(self) -> QuerySet:
        """Build the QuerySet for this relationship; must be overridden."""
        message = "Subclasses must implement get_queryset()"
        raise NotImplementedError(message)
51
52
class ReverseManyToOneManager(BaseRelatedManager):
    """
    Manager for the reverse side of a many-to-one relation.

    Bound to ``instance`` (the object a ForeignKey points at), it exposes the
    rows of ``rel.related_model`` whose ForeignKey ``rel.field`` references
    that instance. It adds helpers to create, attach, detach and replace them.
    ``remove()`` and ``clear()`` are only allowed when the ForeignKey is
    nullable (``allow_null``).
    """

    def __init__(self, instance, rel):
        self.model = rel.related_model
        self.instance = instance
        self.field = rel.field
        # Lookup kwargs restricting `self.model` to rows pointing at `instance`.
        self.core_filters = {self.field.name: instance}
        # Queryset class used to build fresh (non-prefetched) querysets.
        self.base_queryset_class = rel.related_model._meta.queryset.__class__
        # Whether the ForeignKey accepts NULL; gates remove()/clear()/set().
        self.allow_null = rel.field.allow_null

    def _check_fk_val(self):
        """
        Raise ``ValueError`` unless ``instance`` has a value for every field
        the ForeignKey targets; the relation can't be resolved otherwise.
        """
        for field in self.field.foreign_related_fields:
            if getattr(self.instance, field.attname) is None:
                raise ValueError(
                    f'"{self.instance!r}" needs to have a value for field '
                    f'"{field.attname}" before this relationship can be used.'
                )

    def _apply_rel_filters(self, queryset):
        """
        Filter the queryset for the instance this manager is bound to.

        Returns ``queryset.none()`` when a targeted field on ``instance`` is
        ``None``. For many-to-one fields ``instance`` is also registered in
        ``_known_related_objects`` (presumably so fetched rows can reuse it
        rather than loading their ForeignKey target again).
        """
        from plain.exceptions import FieldError

        queryset._defer_next_filter = True
        queryset = queryset.filter(**self.core_filters)
        for field in self.field.foreign_related_fields:
            val = getattr(self.instance, field.attname)
            if val is None:
                return queryset.none()
        if self.field.many_to_one:
            # Guard against field-like objects such as GenericRelation
            # that abuse create_reverse_many_to_one_manager() with reverse
            # one-to-many relationships instead and break known related
            # objects assignment.
            try:
                target_field = self.field.target_field
            except FieldError:
                # The relationship has multiple target fields. Use a tuple
                # for related object id.
                rel_obj_id = tuple(
                    [
                        getattr(self.instance, target_field.attname)
                        for target_field in self.field.path_infos[-1].target_fields
                    ]
                )
            else:
                rel_obj_id = getattr(self.instance, target_field.attname)
            queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
        return queryset

    def _remove_prefetched_objects(self):
        """Drop this relation's prefetched results from ``instance``, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(
                self.field.remote_field.get_cache_name()
            )
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def get_queryset(self):
        """
        Return the related rows.

        Prefetched results are returned as-is when present. Otherwise a fresh
        queryset filtered to ``instance`` is built.

        Raises:
            ValueError: if ``instance`` has no primary key value.
        """
        # Even if this relation is not to primary key, we require still primary key value.
        # The wish is that the instance has been already saved to DB,
        # although having a primary key value isn't a guarantee of that.
        if self.instance.id is None:
            raise ValueError(
                f"{self.instance.__class__.__name__!r} instance needs to have a "
                f"primary key value before this relationship can be used."
            )
        try:
            return self.instance._prefetched_objects_cache[
                self.field.remote_field.get_cache_name()
            ]
        except (AttributeError, KeyError):
            # Use the base queryset class for this model
            queryset = self.base_queryset_class(model=self.model)
            return self._apply_rel_filters(queryset)

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Build the data the prefetch machinery needs to load this relation for
        many ``instances`` at once.

        Returns a 6-tuple:
        ``(queryset, rel_obj_attr, instance_attr, False, cache_name, False)``.
        ``rel_obj_attr`` and ``instance_attr`` compute matching keys for related
        rows and for instances respectively. The two ``False`` flags are
        consumed by the prefetch caller (presumably "single" and "is
        descriptor" — confirm against the caller).
        """
        if queryset is None:
            queryset = self.base_queryset_class(model=self.model)

        rel_obj_attr = self.field.get_local_related_value
        instance_attr = self.field.get_foreign_related_value
        instances_dict = {instance_attr(inst): inst for inst in instances}
        queryset = _filter_prefetch_queryset(queryset, self.field.name, instances)

        # Since we just bypassed this class' get_queryset(), we must manage
        # the reverse relation manually.
        for rel_obj in queryset:
            if not self.field.is_cached(rel_obj):
                instance = instances_dict[rel_obj_attr(rel_obj)]
                setattr(rel_obj, self.field.name, instance)
        cache_name = self.field.remote_field.get_cache_name()
        return queryset, rel_obj_attr, instance_attr, False, cache_name, False

    def add(self, *objs, bulk=True):
        """
        Attach ``objs`` to ``instance`` by pointing their ForeignKey at it.

        With ``bulk=True`` every object must already be saved, and a single
        UPDATE is issued. With ``bulk=False`` each object is saved individually
        inside one transaction.

        Raises:
            TypeError: if an object is not a ``self.model`` instance.
            ValueError: with ``bulk=True``, if an object is unsaved.
        """
        self._check_fk_val()
        self._remove_prefetched_objects()

        def check_and_update_obj(obj):
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
                )
            setattr(obj, self.field.name, self.instance)

        if bulk:
            ids = []
            for obj in objs:
                check_and_update_obj(obj)
                if obj._state.adding:
                    raise ValueError(
                        f"{obj!r} instance isn't saved. Use bulk=False or save "
                        "the object first."
                    )
                ids.append(obj.id)
            self.model._meta.base_queryset.filter(id__in=ids).update(
                **{
                    self.field.name: self.instance,
                }
            )
        else:
            with transaction.atomic(savepoint=False):
                for obj in objs:
                    check_and_update_obj(obj)
                    obj.save()

    def create(self, **kwargs):
        """Create a ``self.model`` row whose ForeignKey points at ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return self.base_queryset_class(model=self.model).create(**kwargs)

    def get_or_create(self, **kwargs):
        """``get_or_create()`` with the ForeignKey fixed to ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return self.base_queryset_class(model=self.model).get_or_create(**kwargs)

    def update_or_create(self, **kwargs):
        """``update_or_create()`` with the ForeignKey fixed to ``instance``."""
        self._check_fk_val()
        kwargs[self.field.name] = self.instance
        return self.base_queryset_class(model=self.model).update_or_create(**kwargs)

    def remove(self, *objs, bulk=True):
        """
        Detach ``objs`` from ``instance`` by setting their ForeignKey to None.

        Raises:
            AttributeError: if the ForeignKey is not nullable.
            TypeError: if an object is not a ``self.model`` instance.
            self.field.remote_field.model.DoesNotExist: if an object is not
                currently related to ``instance``.
        """
        # remove() is only provided if the ForeignKey can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call remove() on a related manager for field "
                f"{self.field.name} where null=False."
            )
        if not objs:
            return
        self._check_fk_val()
        val = self.field.get_foreign_related_value(self.instance)
        old_ids = set()
        for obj in objs:
            if not isinstance(obj, self.model):
                raise TypeError(
                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
                )
            # Is obj actually part of this descriptor set?
            if self.field.get_local_related_value(obj) == val:
                old_ids.add(obj.id)
            else:
                raise self.field.remote_field.model.DoesNotExist(
                    f"{obj!r} is not related to {self.instance!r}."
                )
        # Drop prefetched results *before* reading self.query. Otherwise
        # get_queryset() returns the cached (possibly Prefetch-filtered)
        # QuerySet, and rows outside that filter would silently be skipped.
        self._remove_prefetched_objects()
        self._clear(self.query.filter(id__in=old_ids), bulk)

    def clear(self, *, bulk=True):
        """
        Detach every row related to ``instance`` (ForeignKey set to None).

        Raises:
            AttributeError: if the ForeignKey is not nullable.
        """
        # clear() is only provided if the ForeignKey can have a value of null
        if not self.allow_null:
            raise AttributeError(
                f"Cannot call clear() on a related manager for field "
                f"{self.field.name} where null=False."
            )
        self._check_fk_val()
        # Must happen before self.query: a prefetched QuerySet may be a filtered
        # subset, and with bulk=False its stale cached rows would be iterated.
        self._remove_prefetched_objects()
        self._clear(self.query, bulk)

    def _clear(self, queryset, bulk):
        """Set the ForeignKey to None on every row of ``queryset``."""
        self._remove_prefetched_objects()
        if bulk:
            # `QuerySet.update()` is intrinsically atomic.
            queryset.update(**{self.field.name: None})
        else:
            with transaction.atomic(savepoint=False):
                for obj in queryset:
                    setattr(obj, self.field.name, None)
                    obj.save(update_fields=[self.field.name])

    def set(self, objs, *, bulk=True, clear=False):
        """
        Make the rows related to ``instance`` exactly ``objs``.

        For a nullable ForeignKey, ``clear=True`` detaches everything and then
        adds ``objs``. Otherwise only rows missing from ``objs`` are detached,
        and only new ones are added. For a non-nullable ForeignKey, existing
        rows cannot be detached, so ``objs`` are only added.
        """
        self._check_fk_val()
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)
        # Ensure self.query below reflects the full relation, not a
        # (possibly filtered) prefetched QuerySet. add() drops it anyway.
        self._remove_prefetched_objects()

        if self.allow_null:
            with transaction.atomic(savepoint=False):
                if clear:
                    self.clear(bulk=bulk)
                    self.add(*objs, bulk=bulk)
                else:
                    old_objs = set(self.query.all())
                    new_objs = []
                    for obj in objs:
                        if obj in old_objs:
                            old_objs.remove(obj)
                        else:
                            new_objs.append(obj)

                    self.remove(*old_objs, bulk=bulk)
                    self.add(*new_objs, bulk=bulk)
        else:
            self.add(*objs, bulk=bulk)
272
273
class BaseManyToManyManager(BaseRelatedManager):
    """
    Base class for many-to-many managers with common functionality.

    Links live in ``rel.through``. Each through row holds a ForeignKey named
    ``source_field_name`` (storing the bound instance's value) and one named
    ``target_field_name`` (storing the value of a ``model`` row).

    Subclasses must set these attributes before calling ``super().__init__()``:
    - model: the model whose rows this manager returns
    - query_field_name: lookup on ``model`` that leads back to the instance
    - prefetch_cache_name: key in ``instance._prefetched_objects_cache``
    - source_field_name / target_field_name: ForeignKey names on the through model
    - symmetrical: whether mirrored through rows are maintained
    """

    def __init__(self, instance, rel):
        """
        Bind the manager to ``instance``.

        Raises:
            ValueError: if ``instance`` lacks the value the source ForeignKey
                stores, or has no primary key value.
        """
        self.instance = instance
        self.through = rel.through
        # Subclasses must set model before calling super().__init__
        self.base_queryset_class = self.model._meta.queryset.__class__

        # The two ForeignKeys on the through model.
        self.source_field = self.through._meta.get_field(self.source_field_name)
        self.target_field = self.through._meta.get_field(self.target_field_name)

        # core_filters: lookups restricting `model` to rows linked to `instance`.
        # id_field_names: through-FK name -> name of the instance field it
        # targets (used in the error message below).
        self.core_filters = {}
        self.id_field_names = {}
        for lh_field, rh_field in self.source_field.related_fields:
            core_filter_key = f"{self.query_field_name}__{rh_field.name}"
            self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
            self.id_field_names[lh_field.name] = rh_field.name

        # Value(s) of `instance` that through rows store in the source FK.
        self.related_val = self.source_field.get_foreign_related_value(instance)
        if None in self.related_val:
            raise ValueError(
                f'"{instance!r}" needs to have a value for field "{self.id_field_names[self.source_field_name]}" before '
                "this many-to-many relationship can be used."
            )
        # Even if this relation is not to primary key, we require still primary key value.
        if instance.id is None:
            raise ValueError(
                f"{instance.__class__.__name__!r} instance needs to have a primary key value before "
                "a many-to-many relationship can be used."
            )

    def _apply_rel_filters(self, queryset):
        """Filter the queryset for the instance this manager is bound to."""
        queryset._defer_next_filter = True
        return queryset._next_is_sticky().filter(**self.core_filters)

    def _remove_prefetched_objects(self):
        """Drop this relation's prefetched results from ``instance``, if any."""
        try:
            self.instance._prefetched_objects_cache.pop(self.prefetch_cache_name)
        except (AttributeError, KeyError):
            pass  # nothing to clear from cache

    def _related_queryset(self) -> QuerySet:
        """
        Return a fresh QuerySet of rows linked to ``instance``, ignoring any
        prefetched results (which may have been narrowed by a ``Prefetch``).
        """
        return self._apply_rel_filters(self.base_queryset_class(model=self.model))

    def get_queryset(self) -> QuerySet:
        """Return prefetched results if present, else a fresh related QuerySet."""
        try:
            return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
        except (AttributeError, KeyError):
            return self._related_queryset()

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Build the data the prefetch machinery needs to load this relation for
        many ``instances`` at once.

        The returned queryset carries extra ``_prefetch_related_val_<attname>``
        columns read from the through table. These let each result be matched
        to the instance it was fetched for. Returns a 6-tuple:
        ``(queryset, result_key, instance_key, False, prefetch_cache_name, False)``.
        """
        if queryset is None:
            queryset = self.base_queryset_class(model=self.model)

        queryset = _filter_prefetch_queryset(
            queryset._next_is_sticky(), self.query_field_name, instances
        )

        # M2M: need to annotate the query in order to get the primary model
        # that the secondary model was actually related to.
        fk = self.through._meta.get_field(self.source_field_name)
        join_table = fk.model._meta.db_table
        qn = db_connection.ops.quote_name
        queryset = queryset.extra(
            select={
                f"_prefetch_related_val_{f.attname}": f"{qn(join_table)}.{qn(f.column)}"
                for f in fk.local_related_fields
            }
        )
        return (
            queryset,
            lambda result: tuple(
                getattr(result, f"_prefetch_related_val_{f.attname}")
                for f in fk.local_related_fields
            ),
            lambda inst: tuple(
                f.get_db_prep_value(getattr(inst, f.attname), db_connection)
                for f in fk.foreign_related_fields
            ),
            False,
            self.prefetch_cache_name,
            False,
        )

    def clear(self):
        """
        Delete every through row linking ``instance`` (including mirrored rows
        for symmetrical relations). The related objects themselves are kept.
        """
        with transaction.atomic(savepoint=False):
            self._remove_prefetched_objects()
            filters = self._build_remove_filters(
                self.base_queryset_class(model=self.model)
            )
            self.through.query.filter(filters).delete()

    def set(self, objs, *, clear=False, through_defaults=None):
        """
        Make the objects linked to ``instance`` exactly ``objs``.

        ``objs`` may contain model instances or raw target values. With
        ``clear=True`` all links are removed and then re-added. Otherwise only
        the difference is applied. ``through_defaults`` is used for new rows.
        """
        # Force evaluation of `objs` in case it's a queryset whose value
        # could be affected by `manager.clear()`. Refs #19816.
        objs = tuple(objs)

        with transaction.atomic(savepoint=False):
            if clear:
                self.clear()
                self.add(*objs, through_defaults=through_defaults)
            else:
                # Read current links from the database, not from a prefetched
                # QuerySet that a Prefetch filter may have narrowed.
                old_ids = set(
                    self._related_queryset().values_list(
                        self.target_field.target_field.attname, flat=True
                    )
                )

                new_objs = []
                for obj in objs:
                    fk_val = (
                        self.target_field.get_foreign_related_value(obj)[0]
                        if isinstance(obj, self.model)
                        else self.target_field.get_prep_value(obj)
                    )
                    if fk_val in old_ids:
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)

                self.remove(*old_ids)
                self.add(*new_objs, through_defaults=through_defaults)

    def create(self, *, through_defaults=None, **kwargs):
        """Create a ``model`` row from ``kwargs`` and link it to ``instance``."""
        new_obj = self.base_queryset_class(model=self.model).create(**kwargs)
        self.add(new_obj, through_defaults=through_defaults)
        return new_obj

    def get_or_create(self, *, through_defaults=None, **kwargs):
        """
        Look up ``kwargs`` among objects already linked to ``instance``.
        Create and link a new object if none matches.

        Returns ``(obj, created)``.
        """
        # The lookup must be scoped to this relation. Otherwise an existing but
        # unlinked object would be returned with created=False and never added.
        obj, created = self._related_queryset().get_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return obj, created

    def update_or_create(self, *, through_defaults=None, **kwargs):
        """
        Update the matching object already linked to ``instance``. Create and
        link a new one if none matches.

        Returns ``(obj, created)``.
        """
        # Scoped to this relation for the same reason as get_or_create().
        obj, created = self._related_queryset().update_or_create(**kwargs)
        # We only need to add() if created because if we got an object back
        # from get() then the relationship already exists.
        if created:
            self.add(obj, through_defaults=through_defaults)
        return obj, created

    def _get_target_ids(self, target_field_name, objs):
        """
        Return the set of ids of `objs` that the target field references.

        Model instances contribute their target value. Other values are treated
        as raw ids and prepared with ``get_prep_value``.

        Raises:
            ValueError: if a ``model`` instance has no value for the target.
            TypeError: if a Model instance of another class is passed.
        """
        from plain.models import Model

        target_ids = set()
        target_field = self.through._meta.get_field(target_field_name)
        for obj in objs:
            if isinstance(obj, self.model):
                target_id = target_field.get_foreign_related_value(obj)[0]
                if target_id is None:
                    raise ValueError(
                        f'Cannot add "{obj!r}": the value for field "{target_field_name}" is None'
                    )
                target_ids.add(target_id)
            elif isinstance(obj, Model):
                raise TypeError(
                    f"'{self.model._meta.object_name}' instance expected, got {obj!r}"
                )
            else:
                target_ids.add(target_field.get_prep_value(obj))
        return target_ids

    def _get_missing_target_ids(self, source_field_name, target_field_name, target_ids):
        """Return the subset of ids of `objs` that aren't already assigned to this relationship."""
        vals = self.through.query.values_list(target_field_name, flat=True).filter(
            **{
                source_field_name: self.related_val[0],
                f"{target_field_name}__in": target_ids,
            }
        )
        return target_ids.difference(vals)

    def _add_items(
        self, source_field_name, target_field_name, *objs, through_defaults=None
    ):
        """
        Insert through rows linking ``related_val[0]`` to each target of
        ``objs`` that isn't linked yet. ``through_defaults`` (callables
        resolved) supplies extra field values for every new through row.
        """
        if not objs:
            return

        through_defaults = dict(resolve_callables(through_defaults or {}))
        target_ids = self._get_target_ids(target_field_name, objs)

        missing_target_ids = self._get_missing_target_ids(
            source_field_name, target_field_name, target_ids
        )
        with transaction.atomic(savepoint=False):
            # Add the ones that aren't there already.
            self.through.query.bulk_create(
                [
                    self.through(
                        **through_defaults,
                        **{
                            f"{source_field_name}_id": self.related_val[0],
                            f"{target_field_name}_id": target_id,
                        },
                    )
                    for target_id in missing_target_ids
                ],
            )

    def _remove_items(self, source_field_name, target_field_name, *objs):
        """
        Delete the through rows linking ``instance`` to ``objs``. These may be
        ``model`` instances or raw target values.
        """
        if not objs:
            return

        # Check that all the objects are of the right type
        old_ids = set()
        for obj in objs:
            if isinstance(obj, self.model):
                fk_val = self.target_field.get_foreign_related_value(obj)[0]
                old_ids.add(fk_val)
            else:
                old_ids.add(obj)

        with transaction.atomic(savepoint=False):
            target_model_qs = self.base_queryset_class(model=self.model)
            # If the model's default queryset is itself filtered, only ids it
            # can see are removed; otherwise the raw ids are used directly.
            if target_model_qs._has_filters():
                old_vals = target_model_qs.filter(
                    **{f"{self.target_field.target_field.attname}__in": old_ids}
                )
            else:
                old_vals = old_ids
            filters = self._build_remove_filters(old_vals)
            self.through.query.filter(filters).delete()

    # Subclasses must implement these methods:
    def _build_remove_filters(self, removed_vals):
        """Return a ``Q`` selecting the through rows to delete."""
        raise NotImplementedError

    def add(self, *objs, through_defaults=None):
        """Link ``objs`` to ``instance``."""
        raise NotImplementedError

    def remove(self, *objs):
        """Unlink ``objs`` from ``instance``."""
        raise NotImplementedError
524
525
class ForwardManyToManyManager(BaseManyToManyManager):
    """
    Many-to-many manager for the forward side of the relation.

    Through rows use ``m2m_field_name()`` as the source ForeignKey and
    ``m2m_reverse_field_name()`` as the target. For a symmetrical relation
    every link is stored in both directions, so additions and removals also
    handle the mirrored through row.
    """

    def __init__(self, instance, rel):
        m2m_field = rel.field
        # The base initializer reads all of these, so assign them first.
        self.model = rel.model
        self.query_field_name = m2m_field.related_query_name()
        self.prefetch_cache_name = m2m_field.name
        self.source_field_name = m2m_field.m2m_field_name()
        self.target_field_name = m2m_field.m2m_reverse_field_name()
        self.symmetrical = rel.symmetrical

        super().__init__(instance, rel)

    def _build_remove_filters(self, removed_vals):
        """
        Return a ``Q`` selecting the through rows to delete.

        ``removed_vals`` is either a collection of target values or a QuerySet
        of targets. An unfiltered QuerySet stands for "every target", so no
        ``__in`` condition is added for it.
        """
        source, target = self.source_field_name, self.target_field_name
        restrict_targets = not (
            isinstance(removed_vals, QuerySet) and not removed_vals._has_filters()
        )

        outgoing = Q.create([(source, self.related_val)])
        if restrict_targets:
            outgoing &= Q.create([(f"{target}__in", removed_vals)])
        if not self.symmetrical:
            return outgoing

        # Mirrored rows hold this instance in the target column instead.
        incoming = Q.create([(target, self.related_val)])
        if restrict_targets:
            incoming &= Q.create([(f"{source}__in", removed_vals)])
        return outgoing | incoming

    def add(self, *objs, through_defaults=None):
        """Link ``objs`` to ``instance`` (and back again when symmetrical)."""
        self._remove_prefetched_objects()
        directions = [(self.source_field_name, self.target_field_name)]
        if self.symmetrical:
            # A symmetrical self-relation also needs the mirrored through row.
            directions.append((self.target_field_name, self.source_field_name))
        with transaction.atomic(savepoint=False):
            for from_name, to_name in directions:
                self._add_items(
                    from_name,
                    to_name,
                    *objs,
                    through_defaults=through_defaults,
                )

    def remove(self, *objs):
        """Unlink ``objs`` from ``instance``."""
        self._remove_prefetched_objects()
        field_names = (self.source_field_name, self.target_field_name)
        self._remove_items(*field_names, *objs)
584
585
class ReverseManyToManyManager(BaseManyToManyManager):
    """
    Many-to-many manager for the reverse side of the relation.

    The roles of the through-model ForeignKeys are swapped relative to the
    forward manager. Reverse relations are never symmetrical, so no mirrored
    through rows are ever written or matched.
    """

    def __init__(self, instance, rel):
        m2m_field = rel.field
        # The base initializer reads all of these, so assign them first.
        self.model = rel.related_model
        self.query_field_name = m2m_field.name
        self.prefetch_cache_name = m2m_field.related_query_name()
        self.source_field_name = m2m_field.m2m_reverse_field_name()
        self.target_field_name = m2m_field.m2m_field_name()
        self.symmetrical = False  # Reverse relations are never symmetrical

        super().__init__(instance, rel)

    def _build_remove_filters(self, removed_vals):
        """
        Return a ``Q`` selecting the through rows to delete.

        An unfiltered QuerySet for ``removed_vals`` stands for "every target",
        so the ``__in`` condition is only added otherwise.
        """
        condition = Q.create([(self.source_field_name, self.related_val)])
        is_unfiltered_qs = (
            isinstance(removed_vals, QuerySet) and not removed_vals._has_filters()
        )
        if is_unfiltered_qs:
            return condition
        condition &= Q.create([(f"{self.target_field_name}__in", removed_vals)])
        return condition

    def add(self, *objs, through_defaults=None):
        """Link ``objs`` to ``instance``; no mirrored row is ever needed."""
        self._remove_prefetched_objects()
        source, target = self.source_field_name, self.target_field_name
        with transaction.atomic(savepoint=False):
            self._add_items(source, target, *objs, through_defaults=through_defaults)

    def remove(self, *objs):
        """Unlink ``objs`` from ``instance``."""
        self._remove_prefetched_objects()
        field_names = (self.source_field_name, self.target_field_name)
        self._remove_items(*field_names, *objs)