Plain is headed towards 1.0! Subscribe for development updates →

  1from collections import Counter, defaultdict
  2from functools import partial, reduce
  3from itertools import chain
  4from operator import attrgetter, or_
  5
  6from plain.models import (
  7    query_utils,
  8    sql,
  9    transaction,
 10)
 11from plain.models.db import IntegrityError, db_connection
 12from plain.models.query import QuerySet
 13
 14
 15class ProtectedError(IntegrityError):
 16    def __init__(self, msg, protected_objects):
 17        self.protected_objects = protected_objects
 18        super().__init__(msg, protected_objects)
 19
 20
 21class RestrictedError(IntegrityError):
 22    def __init__(self, msg, restricted_objects):
 23        self.restricted_objects = restricted_objects
 24        super().__init__(msg, restricted_objects)
 25
 26
 27def CASCADE(collector, field, sub_objs):
 28    collector.collect(
 29        sub_objs,
 30        source=field.remote_field.model,
 31        nullable=field.allow_null,
 32        fail_on_restricted=False,
 33    )
 34    if field.allow_null and not db_connection.features.can_defer_constraint_checks:
 35        collector.add_field_update(field, None, sub_objs)
 36
 37
 38def PROTECT(collector, field, sub_objs):
 39    raise ProtectedError(
 40        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "
 41        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'",
 42        sub_objs,
 43    )
 44
 45
 46def RESTRICT(collector, field, sub_objs):
 47    collector.add_restricted_objects(field, sub_objs)
 48    collector.add_dependency(field.remote_field.model, field.model)
 49
 50
 51def SET(value):
 52    if callable(value):
 53
 54        def set_on_delete(collector, field, sub_objs):
 55            collector.add_field_update(field, value(), sub_objs)
 56
 57    else:
 58
 59        def set_on_delete(collector, field, sub_objs):
 60            collector.add_field_update(field, value, sub_objs)
 61
 62    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})
 63    set_on_delete.lazy_sub_objs = True
 64    return set_on_delete
 65
 66
 67def SET_NULL(collector, field, sub_objs):
 68    collector.add_field_update(field, None, sub_objs)
 69
 70
 71SET_NULL.lazy_sub_objs = True
 72
 73
 74def SET_DEFAULT(collector, field, sub_objs):
 75    collector.add_field_update(field, field.get_default(), sub_objs)
 76
 77
 78SET_DEFAULT.lazy_sub_objs = True
 79
 80
 81def DO_NOTHING(collector, field, sub_objs):
 82    pass
 83
 84
 85def get_candidate_relations_to_delete(opts):
 86    # The candidate relations are the ones that come from N-1 and 1-1 relations.
 87    # N-N  (i.e., many-to-many) relations aren't candidates for deletion.
 88    return (
 89        f
 90        for f in opts.get_fields(include_hidden=True)
 91        if f.auto_created and not f.concrete and f.one_to_many
 92    )
 93
 94
 95class Collector:
 96    def __init__(self, origin=None):
 97        # A Model or QuerySet object.
 98        self.origin = origin
 99        # Initially, {model: {instances}}, later values become lists.
100        self.data = defaultdict(set)
101        # {(field, value): [instances, …]}
102        self.field_updates = defaultdict(list)
103        # {model: {field: {instances}}}
104        self.restricted_objects = defaultdict(partial(defaultdict, set))
105        # fast_deletes is a list of queryset-likes that can be deleted without
106        # fetching the objects into memory.
107        self.fast_deletes = []
108
109        # Tracks deletion-order dependency for databases without transactions
110        # or ability to defer constraint checks. Only concrete model classes
111        # should be included, as the dependencies exist only between actual
112        # database tables.
113        self.dependencies = defaultdict(set)  # {model: {models}}
114
115    def add(self, objs, source=None, nullable=False, reverse_dependency=False):
116        """
117        Add 'objs' to the collection of objects to be deleted.  If the call is
118        the result of a cascade, 'source' should be the model that caused it,
119        and 'nullable' should be set to True if the relation can be null.
120
121        Return a list of all objects that were not already collected.
122        """
123        if not objs:
124            return []
125        new_objs = []
126        model = objs[0].__class__
127        instances = self.data[model]
128        for obj in objs:
129            if obj not in instances:
130                new_objs.append(obj)
131        instances.update(new_objs)
132        # Nullable relationships can be ignored -- they are nulled out before
133        # deleting, and therefore do not affect the order in which objects have
134        # to be deleted.
135        if source is not None and not nullable:
136            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
137        return new_objs
138
139    def add_dependency(self, model, dependency, reverse_dependency=False):
140        if reverse_dependency:
141            model, dependency = dependency, model
142        self.dependencies[model._meta.concrete_model].add(
143            dependency._meta.concrete_model
144        )
145        self.data.setdefault(dependency, self.data.default_factory())
146
147    def add_field_update(self, field, value, objs):
148        """
149        Schedule a field update. 'objs' must be a homogeneous iterable
150        collection of model instances (e.g. a QuerySet).
151        """
152        self.field_updates[field, value].append(objs)
153
154    def add_restricted_objects(self, field, objs):
155        if objs:
156            model = objs[0].__class__
157            self.restricted_objects[model][field].update(objs)
158
159    def clear_restricted_objects_from_set(self, model, objs):
160        if model in self.restricted_objects:
161            self.restricted_objects[model] = {
162                field: items - objs
163                for field, items in self.restricted_objects[model].items()
164            }
165
166    def clear_restricted_objects_from_queryset(self, model, qs):
167        if model in self.restricted_objects:
168            objs = set(
169                qs.filter(
170                    pk__in=[
171                        obj.pk
172                        for objs in self.restricted_objects[model].values()
173                        for obj in objs
174                    ]
175                )
176            )
177            self.clear_restricted_objects_from_set(model, objs)
178
179    def can_fast_delete(self, objs, from_field=None):
180        """
181        Determine if the objects in the given queryset-like or single object
182        can be fast-deleted. This can be done if there are no cascades, no
183        parents and no signal listeners for the object class.
184
185        The 'from_field' tells where we are coming from - we need this to
186        determine if the objects are in fact to be deleted. Allow also
187        skipping parent -> child -> parent chain preventing fast delete of
188        the child.
189        """
190        if from_field and from_field.remote_field.on_delete is not CASCADE:
191            return False
192        if hasattr(objs, "_meta"):
193            model = objs._meta.model
194        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
195            model = objs.model
196        else:
197            return False
198
199        # The use of from_field comes from the need to avoid cascade back to
200        # parent when parent delete is cascading to child.
201        opts = model._meta
202        return (
203            # Foreign keys pointing to this model.
204            all(
205                related.field.remote_field.on_delete is DO_NOTHING
206                for related in get_candidate_relations_to_delete(opts)
207            )
208        )
209
210    def get_del_batches(self, objs, fields):
211        """
212        Return the objs in suitably sized batches for the used db_connection.
213        """
214        field_names = [field.name for field in fields]
215        conn_batch_size = max(
216            db_connection.ops.bulk_batch_size(field_names, objs),
217            1,
218        )
219        if len(objs) > conn_batch_size:
220            return [
221                objs[i : i + conn_batch_size]
222                for i in range(0, len(objs), conn_batch_size)
223            ]
224        else:
225            return [objs]
226
227    def collect(
228        self,
229        objs,
230        source=None,
231        nullable=False,
232        collect_related=True,
233        reverse_dependency=False,
234        fail_on_restricted=True,
235    ):
236        """
237        Add 'objs' to the collection of objects to be deleted as well as all
238        parent instances.  'objs' must be a homogeneous iterable collection of
239        model instances (e.g. a QuerySet).  If 'collect_related' is True,
240        related objects will be handled by their respective on_delete handler.
241
242        If the call is the result of a cascade, 'source' should be the model
243        that caused it and 'nullable' should be set to True, if the relation
244        can be null.
245
246        If 'reverse_dependency' is True, 'source' will be deleted before the
247        current model, rather than after. (Needed for cascading to parent
248        models, the one case in which the cascade follows the forwards
249        direction of an FK rather than the reverse direction.)
250
251        If 'fail_on_restricted' is False, error won't be raised even if it's
252        prohibited to delete such objects due to RESTRICT, that defers
253        restricted object checking in recursive calls where the top-level call
254        may need to collect more objects to determine whether restricted ones
255        can be deleted.
256        """
257        if self.can_fast_delete(objs):
258            self.fast_deletes.append(objs)
259            return
260        new_objs = self.add(
261            objs, source, nullable, reverse_dependency=reverse_dependency
262        )
263        if not new_objs:
264            return
265
266        model = new_objs[0].__class__
267
268        if not collect_related:
269            return
270
271        model_fast_deletes = defaultdict(list)
272        protected_objects = defaultdict(list)
273        for related in get_candidate_relations_to_delete(model._meta):
274            field = related.field
275            on_delete = field.remote_field.on_delete
276            if on_delete == DO_NOTHING:
277                continue
278            related_model = related.related_model
279            if self.can_fast_delete(related_model, from_field=field):
280                model_fast_deletes[related_model].append(field)
281                continue
282            batches = self.get_del_batches(new_objs, [field])
283            for batch in batches:
284                sub_objs = self.related_objects(related_model, [field], batch)
285                # Non-referenced fields can be deferred if no signal receivers
286                # are connected for the related model as they'll never be
287                # exposed to the user. Skip field deferring when some
288                # relationships are select_related as interactions between both
289                # features are hard to get right. This should only happen in
290                # the rare cases where .related_objects is overridden anyway.
291                if not sub_objs.query.select_related:
292                    referenced_fields = set(
293                        chain.from_iterable(
294                            (rf.attname for rf in rel.field.foreign_related_fields)
295                            for rel in get_candidate_relations_to_delete(
296                                related_model._meta
297                            )
298                        )
299                    )
300                    sub_objs = sub_objs.only(*tuple(referenced_fields))
301                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
302                    try:
303                        on_delete(self, field, sub_objs)
304                    except ProtectedError as error:
305                        key = f"'{field.model.__name__}.{field.name}'"
306                        protected_objects[key] += error.protected_objects
307        if protected_objects:
308            raise ProtectedError(
309                "Cannot delete some instances of model {!r} because they are "
310                "referenced through protected foreign keys: {}.".format(
311                    model.__name__,
312                    ", ".join(protected_objects),
313                ),
314                set(chain.from_iterable(protected_objects.values())),
315            )
316        for related_model, related_fields in model_fast_deletes.items():
317            batches = self.get_del_batches(new_objs, related_fields)
318            for batch in batches:
319                sub_objs = self.related_objects(related_model, related_fields, batch)
320                self.fast_deletes.append(sub_objs)
321
322        if fail_on_restricted:
323            # Raise an error if collected restricted objects (RESTRICT) aren't
324            # candidates for deletion also collected via CASCADE.
325            for related_model, instances in self.data.items():
326                self.clear_restricted_objects_from_set(related_model, instances)
327            for qs in self.fast_deletes:
328                self.clear_restricted_objects_from_queryset(qs.model, qs)
329            if self.restricted_objects.values():
330                restricted_objects = defaultdict(list)
331                for related_model, fields in self.restricted_objects.items():
332                    for field, objs in fields.items():
333                        if objs:
334                            key = f"'{related_model.__name__}.{field.name}'"
335                            restricted_objects[key] += objs
336                if restricted_objects:
337                    raise RestrictedError(
338                        "Cannot delete some instances of model {!r} because "
339                        "they are referenced through restricted foreign keys: "
340                        "{}.".format(
341                            model.__name__,
342                            ", ".join(restricted_objects),
343                        ),
344                        set(chain.from_iterable(restricted_objects.values())),
345                    )
346
347    def related_objects(self, related_model, related_fields, objs):
348        """
349        Get a QuerySet of the related model to objs via related fields.
350        """
351        predicate = query_utils.Q.create(
352            [(f"{related_field.name}__in", objs) for related_field in related_fields],
353            connector=query_utils.Q.OR,
354        )
355        return related_model._base_manager.filter(predicate)
356
357    def sort(self):
358        sorted_models = []
359        concrete_models = set()
360        models = list(self.data)
361        while len(sorted_models) < len(models):
362            found = False
363            for model in models:
364                if model in sorted_models:
365                    continue
366                dependencies = self.dependencies.get(model._meta.concrete_model)
367                if not (dependencies and dependencies.difference(concrete_models)):
368                    sorted_models.append(model)
369                    concrete_models.add(model._meta.concrete_model)
370                    found = True
371            if not found:
372                return
373        self.data = {model: self.data[model] for model in sorted_models}
374
    def delete(self):
        """
        Delete everything that has been collected.

        Return a 2-tuple: the sum of the counts reported by the individual
        delete operations, and a dict mapping each model label to its count.
        The primary key attribute of every deleted in-memory instance is set
        to None afterwards.
        """
        # sort instance collections
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("pk"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter = Counter()

        # Optimize for the case with a single obj and no dependencies
        # NOTE: 'model' and 'instances' below are the variables left over from
        # the loop above; when self.data has exactly one entry they refer to
        # that entry. The length check on self.data comes first, so they are
        # never read when the loop did not run.
        if len(self.data) == 1 and len(instances) == 1:
            instance = list(instances)[0]
            if self.can_fast_delete(instance):
                with transaction.mark_for_rollback_on_error():
                    count = sql.DeleteQuery(model).delete_batch([instance.pk])
                setattr(instance, model._meta.pk.attname, None)
                return count, {model._meta.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model._meta.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                # QuerySets that have not been evaluated yet are combined and
                # updated in a single query; everything else is updated by
                # primary key.
                updates = []
                objs = []
                for instances in instances_list:
                    if (
                        isinstance(instances, QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    model = objs[0].__class__
                    query = sql.UpdateQuery(model)
                    # The set comprehension de-duplicates primary keys.
                    query.update_batch(
                        list({obj.pk for obj in objs}), {field.name: value}
                    )

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = sql.DeleteQuery(model)
                pk_list = [obj.pk for obj in instances]
                count = query.delete_batch(pk_list)
                if count:
                    deleted_counter[model._meta.label] += count

        # Clear the primary key on the in-memory instances once the atomic
        # block has finished.
        for model, instances in self.data.items():
            for instance in instances:
                setattr(instance, model._meta.pk.attname, None)
        return sum(deleted_counter.values()), dict(deleted_counter)