Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections import Counter, defaultdict
  4from collections.abc import Callable, Generator, Iterable
  5from functools import partial, reduce
  6from itertools import chain
  7from operator import attrgetter, or_
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models import (
 11    query_utils,
 12    sql,
 13    transaction,
 14)
 15from plain.models.db import IntegrityError, db_connection
 16from plain.models.meta import Meta
 17from plain.models.query import QuerySet
 18
 19if TYPE_CHECKING:
 20    from plain.models.fields import Field
 21
 22
 23class ProtectedError(IntegrityError):
 24    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
 25        self.protected_objects = protected_objects
 26        super().__init__(msg, protected_objects)
 27
 28
 29class RestrictedError(IntegrityError):
 30    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
 31        self.restricted_objects = restricted_objects
 32        super().__init__(msg, restricted_objects)
 33
 34
 35def CASCADE(collector: Collector, field: Field, sub_objs: Any) -> None:
 36    collector.collect(
 37        sub_objs,
 38        source=field.remote_field.model,  # type: ignore[attr-defined]
 39        nullable=field.allow_null,
 40        fail_on_restricted=False,
 41    )
 42    if field.allow_null and not db_connection.features.can_defer_constraint_checks:
 43        collector.add_field_update(field, None, sub_objs)
 44
 45
 46def PROTECT(collector: Collector, field: Field, sub_objs: Any) -> None:
 47    raise ProtectedError(
 48        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "  # type: ignore[attr-defined]
 49        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'",
 50        sub_objs,
 51    )
 52
 53
 54def RESTRICT(collector: Collector, field: Field, sub_objs: Any) -> None:
 55    collector.add_restricted_objects(field, sub_objs)
 56    collector.add_dependency(field.remote_field.model, field.model)  # type: ignore[attr-defined]
 57
 58
 59def SET(value: Any) -> Callable[[Collector, Field, Any], None]:
 60    if callable(value):
 61
 62        def set_on_delete(collector: Collector, field: Field, sub_objs: Any) -> None:
 63            collector.add_field_update(field, value(), sub_objs)
 64
 65    else:
 66
 67        def set_on_delete(collector: Collector, field: Field, sub_objs: Any) -> None:
 68            collector.add_field_update(field, value, sub_objs)
 69
 70    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})  # type: ignore[attr-defined]
 71    set_on_delete.lazy_sub_objs = True  # type: ignore[attr-defined]
 72    return set_on_delete
 73
 74
 75def SET_NULL(collector: Collector, field: Field, sub_objs: Any) -> None:
 76    collector.add_field_update(field, None, sub_objs)
 77
 78
 79SET_NULL.lazy_sub_objs = True  # type: ignore[attr-defined]
 80
 81
 82def SET_DEFAULT(collector: Collector, field: Field, sub_objs: Any) -> None:
 83    collector.add_field_update(field, field.get_default(), sub_objs)
 84
 85
 86SET_DEFAULT.lazy_sub_objs = True  # type: ignore[attr-defined]
 87
 88
 89def DO_NOTHING(collector: Collector, field: Field, sub_objs: Any) -> None:
 90    pass
 91
 92
 93def get_candidate_relations_to_delete(meta: Meta) -> Generator[Any, None, None]:
 94    # The candidate relations are the ones that come from N-1 and 1-1 relations.
 95    # N-N  (i.e., many-to-many) relations aren't candidates for deletion.
 96    return (
 97        f
 98        for f in meta.get_fields(include_hidden=True)
 99        if f.auto_created and not f.concrete and f.one_to_many
100    )
101
102
class Collector:
    """
    Gather every object affected by deleting some root objects, then delete.

    ``collect()`` walks the reverse foreign keys of each collected model and
    dispatches to the field's ``on_delete`` handler (CASCADE, PROTECT,
    RESTRICT, SET, SET_NULL, SET_DEFAULT, DO_NOTHING). ``delete()`` then runs
    the fast deletes, applies the scheduled field updates and deletes the
    collected instances inside one transaction.
    """

    def __init__(self, origin: Any = None) -> None:
        # A Model or QuerySet object. Not read by the methods of this class.
        self.origin = origin
        # Initially, {model: {instances}}; delete() turns the values into
        # lists sorted by id.
        self.data: defaultdict[Any, Any] = defaultdict(set)
        # {(field, value): [instances, …]} -- each entry is a homogeneous
        # iterable (often a QuerySet) whose `field` must be set to `value`.
        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
            list
        )
        # {model: {field: {instances}}} -- objects referenced through RESTRICT.
        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
            partial(defaultdict, set)
        )
        # fast_deletes is a list of queryset-likes that can be deleted without
        # fetching the objects into memory.
        self.fast_deletes: list[Any] = []

        # Tracks deletion-order dependency for databases without transactions
        # or ability to defer constraint checks. Only concrete model classes
        # should be included, as the dependencies exist only between actual
        # database tables.
        # {model: {models whose rows must be deleted before model's rows}}
        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(set)

    def add(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        reverse_dependency: bool = False,
    ) -> list[Any]:
        """
        Add 'objs' to the collection of objects to be deleted.  If the call is
        the result of a cascade, 'source' should be the model that caused it,
        and 'nullable' should be set to True if the relation can be null.

        'objs' must be indexable and homogeneous: its first element decides
        which model the objects are filed under.

        Return a list of all objects that were not already collected.
        """
        if not objs:
            return []
        model = objs[0].__class__
        instances = self.data[model]
        new_objs = [obj for obj in objs if obj not in instances]
        instances.update(new_objs)
        # Nullable relationships can be ignored -- they are nulled out before
        # deleting, and therefore do not affect the order in which objects have
        # to be deleted.
        if source is not None and not nullable:
            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
        return new_objs

    def add_dependency(
        self, model: Any, dependency: Any, reverse_dependency: bool = False
    ) -> None:
        """
        Record that rows of 'dependency' must be deleted before rows of
        'model' (roles swapped when 'reverse_dependency' is True), and make
        sure 'dependency' has an entry in self.data so sort() places it.
        """
        if reverse_dependency:
            model, dependency = dependency, model
        self.dependencies[model].add(dependency)
        self.data.setdefault(dependency, self.data.default_factory())

    def add_field_update(self, field: Field, value: Any, objs: Iterable[Any]) -> None:
        """
        Schedule a field update. 'objs' must be a homogeneous iterable
        collection of model instances (e.g. a QuerySet).
        """
        self.field_updates[field, value].append(objs)

    def add_restricted_objects(self, field: Field, objs: Iterable[Any]) -> None:
        """Remember 'objs' as referenced through the RESTRICT field 'field'."""
        if objs:
            model = objs[0].__class__
            self.restricted_objects[model][field].update(objs)

    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
        """Drop 'objs' from the restricted objects recorded for 'model'."""
        if model in self.restricted_objects:
            # Keep the defaultdict(set) shape: a plain dict here would make a
            # later add_restricted_objects() for a new field raise KeyError.
            self.restricted_objects[model] = defaultdict(
                set,
                {
                    field: items - objs
                    for field, items in self.restricted_objects[model].items()
                },
            )

    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
        """
        Drop from the restricted objects of 'model' those that the queryset
        'qs' (queued for fast deletion) also matches.
        """
        if model in self.restricted_objects:
            restricted_ids = [
                obj.id
                for field_objs in self.restricted_objects[model].values()
                for obj in field_objs
            ]
            objs = set(qs.filter(id__in=restricted_ids))
            self.clear_restricted_objects_from_set(model, objs)

    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
        """
        Determine if the given queryset-like, model instance or model class
        can be fast-deleted, i.e. deleted without fetching the objects. This
        is the case when every foreign key pointing to the model uses
        DO_NOTHING, so no on_delete handler has to run.

        When 'from_field' (the relation being cascaded through) is given,
        fast deletion is only allowed if its on_delete is CASCADE.
        """
        if from_field and from_field.remote_field.on_delete is not CASCADE:
            return False
        if hasattr(objs, "_model_meta"):
            model = objs._model_meta.model
        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
            model = objs.model
        else:
            return False

        # Foreign keys pointing to this model must all be DO_NOTHING.
        return all(
            related.field.remote_field.on_delete is DO_NOTHING
            for related in get_candidate_relations_to_delete(model._model_meta)
        )

    def get_del_batches(self, objs: list[Any], fields: list[Field]) -> list[list[Any]]:
        """
        Return the objs in suitably sized batches for the used db_connection.
        """
        field_names = [field.name for field in fields]
        conn_batch_size = max(
            db_connection.ops.bulk_batch_size(field_names, objs),
            1,
        )
        if len(objs) > conn_batch_size:
            return [
                objs[i : i + conn_batch_size]
                for i in range(0, len(objs), conn_batch_size)
            ]
        return [objs]

    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted.  'objs' must be
        a homogeneous iterable collection of model instances (e.g. a
        QuerySet).  If 'collect_related' is True, related objects will be
        handled by their respective on_delete handler.

        If 'objs' can be fast-deleted (see can_fast_delete()), it is queued
        as a whole in self.fast_deletes instead of being loaded.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after.

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raise ProtectedError if related objects are protected by PROTECT
        foreign keys, and (when 'fail_on_restricted' is True) RestrictedError
        if RESTRICT-referenced objects are not themselves being deleted.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not collect_related:
            return

        model_fast_deletes: defaultdict[Any, list[Any]] = defaultdict(list)
        protected_objects: defaultdict[str, list[Any]] = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete is DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            batches = self.get_del_batches(new_objs, [field])
            for batch in batches:
                sub_objs = self.related_objects(related_model, [field], batch)
                # Load only the fields that foreign keys into related_model
                # reference, which is what further cascading needs. Skip field
                # deferring when some relationships are select_related as
                # interactions between both features are hard to get right.
                # This should only happen in the rare cases where
                # .related_objects is overridden anyway.
                if not sub_objs.sql_query.select_related:
                    referenced_fields = set(
                        chain.from_iterable(
                            (rf.attname for rf in rel.field.foreign_related_fields)
                            for rel in get_candidate_relations_to_delete(
                                related_model._model_meta
                            )
                        )
                    )
                    sub_objs = sub_objs.only(*tuple(referenced_fields))
                # Lazy handlers run without evaluating sub_objs first.
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs)
                    except ProtectedError as error:
                        key = f"'{field.model.__name__}.{field.name}'"
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        for related_model, related_fields in model_fast_deletes.items():
            batches = self.get_del_batches(new_objs, related_fields)
            for batch in batches:
                sub_objs = self.related_objects(related_model, related_fields, batch)
                self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects:
                restricted_objects: defaultdict[str, list[Any]] = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = f"'{related_model.__name__}.{field.name}'"
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model {!r} because "
                        "they are referenced through restricted foreign keys: "
                        "{}.".format(
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )

    def related_objects(
        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
    ) -> QuerySet:
        """
        Get a QuerySet of the related model to objs via related fields.
        """
        predicate = query_utils.Q.create(
            [(f"{related_field.name}__in", objs) for related_field in related_fields],
            connector=query_utils.Q.OR,
        )
        return related_model._model_meta.base_queryset.filter(predicate)

    def sort(self) -> None:
        """
        Reorder self.data so every model comes after the models it depends on
        (see add_dependency()). If some models can never be placed (e.g.
        circular dependencies), leave self.data untouched.
        """
        sorted_models: list[Any] = []
        placed: set[Any] = set()
        models = list(self.data)
        while len(sorted_models) < len(models):
            found = False
            for model in models:
                if model in placed:
                    continue
                dependencies = self.dependencies.get(model)
                if not (dependencies and dependencies.difference(placed)):
                    sorted_models.append(model)
                    placed.add(model)
                    found = True
            if not found:
                return
        self.data = {model: self.data[model] for model in sorted_models}

    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Run the fast deletes, apply the scheduled field updates and delete
        every collected instance.

        Return ``(total_deleted, {model label: count})``. Afterwards the id
        attribute of each collected instance is set to None.
        """
        # sort instance collections
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("id"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter: Counter[str] = Counter()

        # Optimize for the case with a single obj and no dependencies. Unpack
        # self.data explicitly rather than relying on the loop variables left
        # over from the sorting loop above.
        if len(self.data) == 1:
            [(model, instances)] = self.data.items()
            if len(instances) == 1:
                instance = instances[0]
                if self.can_fast_delete(instance):
                    with transaction.mark_for_rollback_on_error():
                        count = sql.DeleteQuery(model).delete_batch([instance.id])
                    setattr(instance, model._model_meta.get_field("id").attname, None)
                    return count, {model.model_options.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model.model_options.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                # Unevaluated querysets are combined with | and updated in one
                # .update() call; already-loaded objects are updated by id.
                updates = []
                objs = []
                for instances in instances_list:
                    if (
                        isinstance(instances, QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    model = objs[0].__class__
                    query = sql.UpdateQuery(model)
                    query.update_batch(
                        list({obj.id for obj in objs}), {field.name: value}
                    )

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = sql.DeleteQuery(model)
                id_list = [obj.id for obj in instances]
                count = query.delete_batch(id_list)
                if count:
                    deleted_counter[model.model_options.label] += count

        for model, instances in self.data.items():
            for instance in instances:
                setattr(instance, model._model_meta.get_field("id").attname, None)
        return sum(deleted_counter.values()), dict(deleted_counter)