1from __future__ import annotations
  2
  3from collections import Counter, defaultdict
  4from collections.abc import Callable, Generator, Iterable
  5from functools import partial, reduce
  6from itertools import chain
  7from operator import attrgetter, or_
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models import query_utils, transaction
 11from plain.models.db import IntegrityError, db_connection
 12from plain.models.meta import Meta
 13from plain.models.query import QuerySet
 14from plain.models.sql.subqueries import DeleteQuery, UpdateQuery
 15
 16if TYPE_CHECKING:
 17    from plain.models.fields import Field
 18    from plain.models.fields.related import RelatedField
 19    from plain.models.fields.reverse_related import ForeignKeyRel
 20
 21
 22class ProtectedError(IntegrityError):
 23    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
 24        self.protected_objects = protected_objects
 25        super().__init__(msg, protected_objects)
 26
 27
 28class RestrictedError(IntegrityError):
 29    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
 30        self.restricted_objects = restricted_objects
 31        super().__init__(msg, restricted_objects)
 32
 33
 34def CASCADE(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 35    collector.collect(
 36        sub_objs,
 37        source=field.remote_field.model,
 38        nullable=field.allow_null,
 39        fail_on_restricted=False,
 40    )
 41    if field.allow_null and not db_connection.features.can_defer_constraint_checks:
 42        collector.add_field_update(field, None, sub_objs)
 43
 44
 45def PROTECT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 46    raise ProtectedError(
 47        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "
 48        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'",
 49        sub_objs,
 50    )
 51
 52
 53def RESTRICT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 54    collector.add_restricted_objects(field, sub_objs)
 55    collector.add_dependency(field.remote_field.model, field.model)
 56
 57
 58def SET(value: Any) -> Callable[[Collector, RelatedField, Any], None]:
 59    if callable(value):
 60
 61        def set_on_delete(
 62            collector: Collector, field: RelatedField, sub_objs: Any
 63        ) -> None:
 64            collector.add_field_update(field, value(), sub_objs)
 65
 66    else:
 67
 68        def set_on_delete(
 69            collector: Collector, field: RelatedField, sub_objs: Any
 70        ) -> None:
 71            collector.add_field_update(field, value, sub_objs)
 72
 73    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})  # type: ignore[attr-defined]
 74    set_on_delete.lazy_sub_objs = True  # type: ignore[attr-defined]
 75    return set_on_delete
 76
 77
 78def SET_NULL(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 79    collector.add_field_update(field, None, sub_objs)
 80
 81
 82SET_NULL.lazy_sub_objs = True  # type: ignore[attr-defined]
 83
 84
 85def SET_DEFAULT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 86    collector.add_field_update(field, field.get_default(), sub_objs)
 87
 88
 89SET_DEFAULT.lazy_sub_objs = True  # type: ignore[attr-defined]
 90
 91
 92def DO_NOTHING(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 93    pass
 94
 95
 96def get_candidate_relations_to_delete(
 97    meta: Meta,
 98) -> Generator[ForeignKeyRel, None, None]:
 99    from plain.models.fields.reverse_related import ForeignKeyRel
100
101    # The candidate relations are the ones that come from N-1 and 1-1 relations.
102    # N-N  (i.e., many-to-many) relations aren't candidates for deletion.
103    return (
104        f
105        for f in meta.get_fields(include_reverse=True)
106        if f.auto_created and not f.concrete and isinstance(f, ForeignKeyRel)
107    )
108
109
110class Collector:
111    def __init__(self, origin: Any = None) -> None:
112        # A Model or QuerySet object.
113        self.origin = origin
114        # Initially, {model: {instances}}, later values become lists.
115        self.data: defaultdict[Any, Any] = defaultdict(set)
116        # {(field, value): [instances, โ€ฆ]}
117        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
118            list
119        )
120        # {model: {field: {instances}}}
121        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
122            partial(defaultdict, set)
123        )
124        # fast_deletes is a list of queryset-likes that can be deleted without
125        # fetching the objects into memory.
126        self.fast_deletes: list[Any] = []
127
128        # Tracks deletion-order dependency for databases without transactions
129        # or ability to defer constraint checks. Only concrete model classes
130        # should be included, as the dependencies exist only between actual
131        # database tables.
132        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(
133            set
134        )  # {model: {models}}
135
136    def add(
137        self,
138        objs: Iterable[Any],
139        source: Any = None,
140        nullable: bool = False,
141        reverse_dependency: bool = False,
142    ) -> list[Any]:
143        """
144        Add 'objs' to the collection of objects to be deleted.  If the call is
145        the result of a cascade, 'source' should be the model that caused it,
146        and 'nullable' should be set to True if the relation can be null.
147
148        Return a list of all objects that were not already collected.
149        """
150        if not objs:
151            return []
152        new_objs = []
153        model = objs[0].__class__
154        instances = self.data[model]
155        for obj in objs:
156            if obj not in instances:
157                new_objs.append(obj)
158        instances.update(new_objs)
159        # Nullable relationships can be ignored -- they are nulled out before
160        # deleting, and therefore do not affect the order in which objects have
161        # to be deleted.
162        if source is not None and not nullable:
163            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
164        return new_objs
165
166    def add_dependency(
167        self, model: Any, dependency: Any, reverse_dependency: bool = False
168    ) -> None:
169        if reverse_dependency:
170            model, dependency = dependency, model
171        self.dependencies[model].add(dependency)
172        self.data.setdefault(dependency, set())
173
174    def add_field_update(
175        self, field: RelatedField, value: Any, objs: Iterable[Any]
176    ) -> None:
177        """
178        Schedule a field update. 'objs' must be a homogeneous iterable
179        collection of model instances (e.g. a QuerySet).
180        """
181        self.field_updates[field, value].append(objs)
182
183    def add_restricted_objects(self, field: RelatedField, objs: Iterable[Any]) -> None:
184        if objs:
185            model = objs[0].__class__
186            self.restricted_objects[model][field].update(objs)
187
188    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
189        if model in self.restricted_objects:
190            self.restricted_objects[model] = {
191                field: items - objs
192                for field, items in self.restricted_objects[model].items()
193            }
194
195    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
196        if model in self.restricted_objects:
197            objs = set(
198                qs.filter(
199                    id__in=[
200                        obj.id
201                        for objs in self.restricted_objects[model].values()
202                        for obj in objs
203                    ]
204                )
205            )
206            self.clear_restricted_objects_from_set(model, objs)
207
208    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
209        """
210        Determine if the objects in the given queryset-like or single object
211        can be fast-deleted. This can be done if there are no cascades, no
212        parents and no signal listeners for the object class.
213
214        The 'from_field' tells where we are coming from - we need this to
215        determine if the objects are in fact to be deleted. Allow also
216        skipping parent -> child -> parent chain preventing fast delete of
217        the child.
218        """
219        from plain.models.fields.related import RelatedField
220
221        if (
222            isinstance(from_field, RelatedField)
223            and from_field.remote_field.on_delete is not CASCADE
224        ):
225            return False
226        if hasattr(objs, "_model_meta"):
227            model = objs._model_meta.model
228        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
229            model = objs.model
230        else:
231            return False
232
233        # The use of from_field comes from the need to avoid cascade back to
234        # parent when parent delete is cascading to child.
235        meta = model._model_meta
236        return (
237            # Foreign keys pointing to this model.
238            all(
239                related.field.remote_field.on_delete is DO_NOTHING
240                for related in get_candidate_relations_to_delete(meta)
241            )
242        )
243
244    def get_del_batches(self, objs: list[Any], fields: list[Field]) -> list[list[Any]]:
245        """
246        Return the objs in suitably sized batches for the used db_connection.
247        """
248        field_names = [field.name for field in fields]
249        conn_batch_size = max(
250            db_connection.ops.bulk_batch_size(field_names, objs),
251            1,
252        )
253        if len(objs) > conn_batch_size:
254            return [
255                objs[i : i + conn_batch_size]
256                for i in range(0, len(objs), conn_batch_size)
257            ]
258        else:
259            return [objs]
260
    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted.  'objs' must be
        a homogeneous iterable collection of model instances (e.g. a
        QuerySet).  If 'collect_related' is True, the objects referencing the
        newly collected ones are handed to the on_delete handler of the
        referencing field.

        If can_fast_delete() accepts 'objs', it is only appended to
        'fast_deletes' and nothing else happens.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, the dependency between 'source' and
        the current model is recorded the other way around (it is forwarded to
        add(), which passes it on to add_dependency()).

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raises ProtectedError if any on_delete handler raised it (the errors
        of all relations are gathered into a single one), and RestrictedError
        if 'fail_on_restricted' is True and restricted objects remain that
        were not themselves collected for deletion.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        # Everything was collected before already: nothing left to follow.
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not collect_related:
            return

        # {related_model: [fields]} for related models that can be fast-deleted.
        model_fast_deletes = defaultdict(list)
        # {"'Model.field'": [objects]} gathered from ProtectedError instances.
        protected_objects = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete == DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                # Handled further down, once per related model.
                model_fast_deletes[related_model].append(field)
                continue
            batches = self.get_del_batches(new_objs, [field])
            for batch in batches:
                sub_objs = self.related_objects(related_model, [field], batch)
                # Only load the columns that relations pointing at
                # related_model refer to (the attnames of their
                # foreign_related_fields). This is skipped when the query has
                # select_related set, i.e. field deferring and select_related
                # are never combined here.
                if not sub_objs.sql_query.select_related:
                    referenced_fields = set(
                        chain.from_iterable(
                            (rf.attname for rf in rel.field.foreign_related_fields)
                            for rel in get_candidate_relations_to_delete(
                                related_model._model_meta
                            )
                        )
                    )
                    sub_objs = sub_objs.only(*tuple(referenced_fields))
                # Handlers flagged with lazy_sub_objs are called without
                # testing sub_objs for truth first; for the others the handler
                # only runs if sub_objs is truthy.
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs)
                    except ProtectedError as error:
                        # Keep going so that all protected relations end up in
                        # one error below.
                        key = f"'{field.model.__name__}.{field.name}'"
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        # Queue one related-objects queryset per batch for every related model
        # that can be fast-deleted, covering all of its fields at once.
        for related_model, related_fields in model_fast_deletes.items():
            batches = self.get_del_batches(new_objs, related_fields)
            for batch in batches:
                sub_objs = self.related_objects(related_model, related_fields, batch)
                self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects.values():
                # {"'Model.field'": [objects]} for the sets that are still
                # non-empty after the clearing above.
                restricted_objects = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = f"'{related_model.__name__}.{field.name}'"
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model {!r} because "
                        "they are referenced through restricted foreign keys: "
                        "{}.".format(
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )
380
381    def related_objects(
382        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
383    ) -> QuerySet:
384        """
385        Get a QuerySet of the related model to objs via related fields.
386        """
387        predicate = query_utils.Q.create(
388            [(f"{related_field.name}__in", objs) for related_field in related_fields],
389            connector=query_utils.Q.OR,
390        )
391        return related_model._model_meta.base_queryset.filter(predicate)
392
393    def sort(self) -> None:
394        sorted_models = []
395        concrete_models = set()
396        models = list(self.data)
397        while len(sorted_models) < len(models):
398            found = False
399            for model in models:
400                if model in sorted_models:
401                    continue
402                dependencies = self.dependencies.get(model)
403                if not (dependencies and dependencies.difference(concrete_models)):
404                    sorted_models.append(model)
405                    concrete_models.add(model)
406                    found = True
407            if not found:
408                return
409        self.data = defaultdict(
410            set, {model: self.data[model] for model in sorted_models}
411        )
412
    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Delete everything that was collected and apply the scheduled field
        updates.

        Return a tuple of the total number of deleted objects and a dict
        mapping each model label (model.model_options.label) to the number of
        objects deleted for it. Afterwards the 'id' attribute of every
        collected instance is set to None.
        """
        # sort instance collections
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("id"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter = Counter()

        # Optimize for the case with a single obj and no dependencies
        # NOTE: 'model' and 'instances' are the variables left over from the
        # sorting loop above. The len(self.data) == 1 check guarantees they
        # were bound and refer to the only entry ('instances' being the
        # collection as it was before sorting).
        if len(self.data) == 1 and len(instances) == 1:
            instance = list(instances)[0]
            if self.can_fast_delete(instance):
                with transaction.mark_for_rollback_on_error():
                    count = DeleteQuery(model).delete_batch([instance.id])
                # Clear the id on the in-memory instance now that its row has
                # been deleted.
                setattr(
                    instance, model._model_meta.get_forward_field("id").attname, None
                )
                return count, {model.model_options.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model.model_options.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                assert field.name is not None
                # QuerySets that have not been evaluated yet are updated with
                # a single combined query; everything else is updated by id.
                updates = []
                objs = []
                for instances in instances_list:
                    if (
                        isinstance(instances, QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    # OR the pending querysets together and update them at once.
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    model = objs[0].__class__
                    query = UpdateQuery(model)
                    # The set comprehension drops duplicate ids.
                    query.update_batch(
                        list({obj.id for obj in objs}), {field.name: value}
                    )

            # reverse instance collections
            # (the values are the lists sorted by id above, so each model's
            # instances are now in descending id order)
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = DeleteQuery(model)
                id_list = [obj.id for obj in instances]
                count = query.delete_batch(id_list)
                if count:
                    deleted_counter[model.model_options.label] += count

        # Clear the id on every collected in-memory instance.
        for model, instances in self.data.items():
            for instance in instances:
                setattr(
                    instance, model._model_meta.get_forward_field("id").attname, None
                )
        return sum(deleted_counter.values()), dict(deleted_counter)