Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections import Counter, defaultdict
  4from collections.abc import Callable, Generator, Iterable
  5from functools import partial, reduce
  6from itertools import chain
  7from operator import attrgetter, or_
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models import (
 11    query_utils,
 12    sql,
 13    transaction,
 14)
 15from plain.models.db import IntegrityError, db_connection
 16from plain.models.meta import Meta
 17from plain.models.query import QuerySet
 18
 19if TYPE_CHECKING:
 20    from plain.models.fields import Field
 21    from plain.models.fields.related import RelatedField
 22    from plain.models.fields.reverse_related import ForeignKeyRel
 23
 24
 25class ProtectedError(IntegrityError):
 26    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
 27        self.protected_objects = protected_objects
 28        super().__init__(msg, protected_objects)
 29
 30
 31class RestrictedError(IntegrityError):
 32    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
 33        self.restricted_objects = restricted_objects
 34        super().__init__(msg, restricted_objects)
 35
 36
 37def CASCADE(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 38    collector.collect(
 39        sub_objs,
 40        source=field.remote_field.model,
 41        nullable=field.allow_null,
 42        fail_on_restricted=False,
 43    )
 44    if field.allow_null and not db_connection.features.can_defer_constraint_checks:
 45        collector.add_field_update(field, None, sub_objs)
 46
 47
 48def PROTECT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 49    raise ProtectedError(
 50        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "
 51        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'",
 52        sub_objs,
 53    )
 54
 55
 56def RESTRICT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 57    collector.add_restricted_objects(field, sub_objs)
 58    collector.add_dependency(field.remote_field.model, field.model)
 59
 60
 61def SET(value: Any) -> Callable[[Collector, RelatedField, Any], None]:
 62    if callable(value):
 63
 64        def set_on_delete(
 65            collector: Collector, field: RelatedField, sub_objs: Any
 66        ) -> None:
 67            collector.add_field_update(field, value(), sub_objs)
 68
 69    else:
 70
 71        def set_on_delete(
 72            collector: Collector, field: RelatedField, sub_objs: Any
 73        ) -> None:
 74            collector.add_field_update(field, value, sub_objs)
 75
 76    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})  # type: ignore[attr-defined]
 77    set_on_delete.lazy_sub_objs = True  # type: ignore[attr-defined]
 78    return set_on_delete
 79
 80
 81def SET_NULL(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 82    collector.add_field_update(field, None, sub_objs)
 83
 84
 85SET_NULL.lazy_sub_objs = True  # type: ignore[attr-defined]
 86
 87
 88def SET_DEFAULT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 89    collector.add_field_update(field, field.get_default(), sub_objs)
 90
 91
 92SET_DEFAULT.lazy_sub_objs = True  # type: ignore[attr-defined]
 93
 94
 95def DO_NOTHING(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 96    pass
 97
 98
 99def get_candidate_relations_to_delete(
100    meta: Meta,
101) -> Generator[ForeignKeyRel, None, None]:
102    from plain.models.fields.reverse_related import ForeignKeyRel
103
104    # The candidate relations are the ones that come from N-1 and 1-1 relations.
105    # N-N  (i.e., many-to-many) relations aren't candidates for deletion.
106    return (
107        f
108        for f in meta.get_fields(include_reverse=True)
109        if f.auto_created and not f.concrete and isinstance(f, ForeignKeyRel)
110    )
111
112
class Collector:
    """
    Gather everything affected by a deletion and then delete it.

    ``collect()`` walks the reverse foreign keys of the objects being deleted
    and runs each relation's ``on_delete`` handler, accumulating instances to
    delete (``data``), queryset-likes to delete without loading them
    (``fast_deletes``) and foreign-key updates (``field_updates``).
    ``delete()`` then performs the work inside one transaction.
    """

    def __init__(self, origin: Any = None) -> None:
        # A Model or QuerySet object.
        self.origin = origin
        # Initially, {model: {instances}}; delete() turns the values into
        # lists sorted by id.
        self.data: defaultdict[Any, Any] = defaultdict(set)
        # {(field, value): [querysets or iterables of instances, …]}
        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
            list
        )
        # {model: {field: {instances}}}
        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
            partial(defaultdict, set)
        )
        # fast_deletes is a list of queryset-likes that can be deleted without
        # fetching the objects into memory.
        self.fast_deletes: list[Any] = []

        # Tracks deletion-order dependency for databases without transactions
        # or ability to defer constraint checks: {model: {models}}, where the
        # models in the value set are deleted before the key model. Only
        # model classes backed by actual database tables belong here.
        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(set)

    def add(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        reverse_dependency: bool = False,
    ) -> list[Any]:
        """
        Add 'objs' to the collection of objects to be deleted.  If the call is
        the result of a cascade, 'source' should be the model that caused it,
        and 'nullable' should be set to True if the relation can be null.

        Return a list of all objects that were not already collected, in
        iteration order and without duplicates.
        """
        if not objs:
            return []
        new_objs = []
        model = objs[0].__class__
        instances = self.data[model]
        for obj in objs:
            if obj not in instances:
                # Register immediately so an object repeated within 'objs'
                # is only reported once.
                instances.add(obj)
                new_objs.append(obj)
        # Nullable relationships can be ignored -- they are nulled out before
        # deleting, and therefore do not affect the order in which objects have
        # to be deleted.
        if source is not None and not nullable:
            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
        return new_objs

    def add_dependency(
        self, model: Any, dependency: Any, reverse_dependency: bool = False
    ) -> None:
        """
        Record that 'dependency' must be deleted before 'model' (the roles are
        swapped when 'reverse_dependency' is True), and make sure the
        dependency has an entry in self.data.
        """
        if reverse_dependency:
            model, dependency = dependency, model
        self.dependencies[model].add(dependency)
        self.data.setdefault(dependency, set())

    def add_field_update(
        self, field: RelatedField, value: Any, objs: Iterable[Any]
    ) -> None:
        """
        Schedule a field update. 'objs' must be a homogeneous iterable
        collection of model instances (e.g. a QuerySet).
        """
        self.field_updates[field, value].append(objs)

    def add_restricted_objects(self, field: RelatedField, objs: Iterable[Any]) -> None:
        """Remember 'objs' as referenced through the RESTRICT relation 'field'."""
        if objs:
            model = objs[0].__class__
            self.restricted_objects[model][field].update(objs)

    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
        """Drop 'objs' from the restricted objects recorded for 'model'."""
        if model in self.restricted_objects:
            # Keep a defaultdict so later add_restricted_objects() calls for a
            # field not seen yet still work (a plain dict would raise KeyError).
            self.restricted_objects[model] = defaultdict(
                set,
                {
                    field: items - objs
                    for field, items in self.restricted_objects[model].items()
                },
            )

    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
        """
        Drop from the restricted objects recorded for 'model' those that are
        matched by 'qs' (i.e. will be deleted by it).
        """
        if model not in self.restricted_objects:
            return
        restricted_ids = [
            obj.id
            for restricted in self.restricted_objects[model].values()
            for obj in restricted
        ]
        deleted = set(qs.filter(id__in=restricted_ids))
        self.clear_restricted_objects_from_set(model, deleted)

    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
        """
        Determine if the objects in the given queryset-like, model class or
        single model instance can be fast-deleted, i.e. removed with a raw
        DELETE without fetching them into memory. This is the case when every
        foreign key pointing at the model uses DO_NOTHING, so deleting the
        rows cannot trigger a cascade, an update or a protection check.

        'from_field' is the relation being followed when the check is made on
        behalf of a cascade; any on_delete other than CASCADE on that relation
        rules out a fast delete.
        """
        from plain.models.fields.related import RelatedField

        if (
            isinstance(from_field, RelatedField)
            and from_field.remote_field.on_delete is not CASCADE
        ):
            return False
        if hasattr(objs, "_model_meta"):
            model = objs._model_meta.model
        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
            model = objs.model
        else:
            return False

        # Every foreign key pointing to this model must be DO_NOTHING.
        return all(
            related.field.remote_field.on_delete is DO_NOTHING
            for related in get_candidate_relations_to_delete(model._model_meta)
        )

    def get_del_batches(self, objs: list[Any], fields: list[Field]) -> list[list[Any]]:
        """
        Return the objs in suitably sized batches for the used db_connection.

        The batch size comes from the connection's bulk_batch_size() for the
        given fields and is never smaller than 1.
        """
        field_names = [field.name for field in fields]
        conn_batch_size = max(
            db_connection.ops.bulk_batch_size(field_names, objs),
            1,
        )
        if len(objs) <= conn_batch_size:
            return [objs]
        return [
            objs[i : i + conn_batch_size] for i in range(0, len(objs), conn_batch_size)
        ]

    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted as well as all
        parent instances.  'objs' must be a homogeneous iterable collection of
        model instances (e.g. a QuerySet).  If 'collect_related' is True,
        related objects will be handled by their respective on_delete handler.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after. (Needed for cascading to parent
        models, the one case in which the cascade follows the forwards
        direction of an FK rather than the reverse direction.)

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raises ProtectedError when a PROTECT relation blocks the deletion and,
        when 'fail_on_restricted' is True, RestrictedError when a RESTRICT
        relation does.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        if not new_objs or not collect_related:
            return

        model = new_objs[0].__class__

        model_fast_deletes: defaultdict[Any, list[Any]] = defaultdict(list)
        protected_objects: defaultdict[str, list[Any]] = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete is DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            for batch in self.get_del_batches(new_objs, [field]):
                sub_objs = self.related_objects(related_model, [field], batch)
                # Only the fields that relations pointing at related_model
                # reference are needed to keep cascading, so defer the rest.
                # Skip field deferring when some relationships are
                # select_related as interactions between both features are
                # hard to get right. This should only happen in the rare cases
                # where .related_objects is overridden anyway.
                if not sub_objs.sql_query.select_related:
                    referenced_fields = set(
                        chain.from_iterable(
                            (rf.attname for rf in rel.field.foreign_related_fields)
                            for rel in get_candidate_relations_to_delete(
                                related_model._model_meta
                            )
                        )
                    )
                    sub_objs = sub_objs.only(*tuple(referenced_fields))
                # Lazy handlers get the queryset unevaluated; others are only
                # called when there actually are related objects.
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs)
                    except ProtectedError as error:
                        key = f"'{field.model.__name__}.{field.name}'"
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        for related_model, related_fields in model_fast_deletes.items():
            for batch in self.get_del_batches(new_objs, related_fields):
                sub_objs = self.related_objects(related_model, related_fields, batch)
                self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            self._raise_if_restricted(model)

    def _raise_if_restricted(self, model: Any) -> None:
        """
        Raise RestrictedError if any collected restricted objects (RESTRICT)
        aren't also candidates for deletion (e.g. collected via CASCADE).
        'model' is only used to build the error message.
        """
        for related_model, instances in self.data.items():
            self.clear_restricted_objects_from_set(related_model, instances)
        for qs in self.fast_deletes:
            self.clear_restricted_objects_from_queryset(qs.model, qs)
        restricted_objects: defaultdict[str, list[Any]] = defaultdict(list)
        for related_model, fields in self.restricted_objects.items():
            for field, objs in fields.items():
                if objs:
                    key = f"'{related_model.__name__}.{field.name}'"
                    restricted_objects[key] += objs
        if restricted_objects:
            raise RestrictedError(
                "Cannot delete some instances of model {!r} because "
                "they are referenced through restricted foreign keys: "
                "{}.".format(
                    model.__name__,
                    ", ".join(restricted_objects),
                ),
                set(chain.from_iterable(restricted_objects.values())),
            )

    def related_objects(
        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
    ) -> QuerySet:
        """
        Get a QuerySet of the related model to objs via related fields.
        """
        predicate = query_utils.Q.create(
            [(f"{related_field.name}__in", objs) for related_field in related_fields],
            connector=query_utils.Q.OR,
        )
        return related_model._model_meta.base_queryset.filter(predicate)

    def sort(self) -> None:
        """
        Reorder self.data so that every model comes after all the models
        listed in self.dependencies for it. If no such order exists (e.g. the
        dependencies form a cycle), self.data is left unchanged.
        """
        sorted_models: list[Any] = []
        placed: set[Any] = set()  # same members as sorted_models, O(1) lookup
        models = list(self.data)
        while len(sorted_models) < len(models):
            found = False
            for model in models:
                if model in placed:
                    continue
                dependencies = self.dependencies.get(model)
                if not (dependencies and dependencies.difference(placed)):
                    sorted_models.append(model)
                    placed.add(model)
                    found = True
            if not found:
                return
        self.data = defaultdict(
            set, {model: self.data[model] for model in sorted_models}
        )

    def _apply_field_updates(self) -> None:
        """Run the UPDATE queries scheduled through add_field_update()."""
        for (field, value), instances_list in self.field_updates.items():
            assert field.name is not None
            # Unevaluated querysets are OR-ed into a single UPDATE; anything
            # already materialized is updated by id.
            lazy_querysets = []
            objs = []
            for instances in instances_list:
                if isinstance(instances, QuerySet) and instances._result_cache is None:
                    lazy_querysets.append(instances)
                else:
                    objs.extend(instances)
            if lazy_querysets:
                combined_updates = reduce(or_, lazy_querysets)
                combined_updates.update(**{field.name: value})
            if objs:
                model = objs[0].__class__
                query = sql.UpdateQuery(model)
                query.update_batch(list({obj.id for obj in objs}), {field.name: value})

    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Delete everything that was collected.

        Return the total number of rows deleted and a dict mapping each model
        label to its number of deleted rows. Afterwards the id attribute of
        every instance in self.data is set to None.
        """
        # sort instance collections; from here on the values are lists
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("id"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter: Counter[str] = Counter()

        # Optimize for the case with a single obj and no dependencies
        if len(self.data) == 1:
            [(model, instances)] = self.data.items()
            if len(instances) == 1:
                instance = instances[0]
                if self.can_fast_delete(instance):
                    with transaction.mark_for_rollback_on_error():
                        count = sql.DeleteQuery(model).delete_batch([instance.id])
                    setattr(
                        instance,
                        model._model_meta.get_forward_field("id").attname,
                        None,
                    )
                    return count, {model.model_options.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model.model_options.label] += count

            # update fields
            self._apply_field_updates()

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = sql.DeleteQuery(model)
                count = query.delete_batch([obj.id for obj in instances])
                if count:
                    deleted_counter[model.model_options.label] += count

        for model, instances in self.data.items():
            if not instances:
                continue
            id_attname = model._model_meta.get_forward_field("id").attname
            for instance in instances:
                setattr(instance, id_attname, None)
        return sum(deleted_counter.values()), dict(deleted_counter)