1from __future__ import annotations
  2
  3from collections import Counter, defaultdict
  4from collections.abc import Callable, Generator, Iterable
  5from functools import partial, reduce
  6from itertools import chain
  7from operator import attrgetter, or_
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.postgres import (
 11    query_utils,
 12    transaction,
 13)
 14from plain.postgres.db import IntegrityError
 15from plain.postgres.meta import Meta
 16from plain.postgres.query import QuerySet
 17from plain.postgres.sql import DeleteQuery, UpdateQuery
 18
 19if TYPE_CHECKING:
 20    from plain.postgres.fields import Field
 21    from plain.postgres.fields.related import RelatedField
 22    from plain.postgres.fields.reverse_related import ForeignKeyRel
 23
# Handlers in this set skip the sub_objs emptiness check in Collector.collect,
# allowing the handler to run even when there are no related objects.  Because
# the check is skipped, the handler receives the related queryset as built by
# Collector.collect, and Collector.delete can then apply the change through
# QuerySet.update() when that queryset has no cached results.  External
# on_delete handlers can still opt in by setting ``lazy_sub_objs = True`` as an
# attribute — the Collector checks both this set and getattr as a fallback.
_LAZY_ON_DELETE: set[Callable[..., Any]] = set()
 29
 30
 31class ProtectedError(IntegrityError):
 32    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
 33        self.protected_objects = protected_objects
 34        super().__init__(msg, protected_objects)
 35
 36
 37class RestrictedError(IntegrityError):
 38    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
 39        self.restricted_objects = restricted_objects
 40        super().__init__(msg, restricted_objects)
 41
 42
 43def CASCADE(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 44    collector.collect(
 45        sub_objs,
 46        source=field.remote_field.model,
 47        nullable=field.allow_null,
 48        fail_on_restricted=False,
 49    )
 50
 51
 52def PROTECT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 53    raise ProtectedError(
 54        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "
 55        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'",
 56        sub_objs,
 57    )
 58
 59
 60def RESTRICT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 61    collector.add_restricted_objects(field, sub_objs)
 62    collector.add_dependency(field.remote_field.model, field.model)
 63
 64
 65class _SetOnDelete:
 66    """On-delete handler returned by :func:`SET`."""
 67
 68    lazy_sub_objs: bool = True
 69
 70    def __init__(self, value: Any) -> None:
 71        self._value = value
 72        self._resolve = callable(value)
 73
 74    def __call__(
 75        self, collector: Collector, field: RelatedField, sub_objs: Any
 76    ) -> None:
 77        resolved = self._value() if self._resolve else self._value
 78        collector.add_field_update(field, resolved, sub_objs)
 79
 80    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 81        return ("plain.postgres.SET", (self._value,), {})
 82
 83
 84def SET(value: Any) -> _SetOnDelete:
 85    return _SetOnDelete(value)
 86
 87
 88def SET_NULL(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 89    collector.add_field_update(field, None, sub_objs)
 90
 91
# SET_NULL only forwards sub_objs to add_field_update, so Collector.collect
# may call it without first testing sub_objs for emptiness.
_LAZY_ON_DELETE.add(SET_NULL)
 93
 94
 95def SET_DEFAULT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
 96    collector.add_field_update(field, field.get_default(), sub_objs)
 97
 98
# SET_DEFAULT only forwards sub_objs to add_field_update, so Collector.collect
# may call it without first testing sub_objs for emptiness.
_LAZY_ON_DELETE.add(SET_DEFAULT)
100
101
def DO_NOTHING(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """Take no action for ``sub_objs``; nothing is collected or scheduled."""
    return None
104
105
def get_candidate_relations_to_delete(
    meta: Meta,
) -> Generator[ForeignKeyRel]:
    """Return a generator over the reverse relations deletion must consider.

    The fields are fetched from ``meta`` (including reverse ones) as soon as
    this function is called; filtering happens lazily while iterating.  Only
    auto-created, non-concrete ``ForeignKeyRel`` entries qualify, so
    many-to-many relations are never candidates for deletion.
    """
    from plain.postgres.fields.reverse_related import ForeignKeyRel

    every_field = meta.get_fields(include_reverse=True)
    return (
        rel
        for rel in every_field
        if rel.auto_created and not rel.concrete and isinstance(rel, ForeignKeyRel)
    )
118
119
120class Collector:
121    def __init__(self, origin: Any = None) -> None:
122        # A Model or QuerySet object.
123        self.origin = origin
124        # Initially, {model: {instances}}, later values become lists.
125        self.data: defaultdict[Any, Any] = defaultdict(set)
126        # {(field, value): [instances, โ€ฆ]}
127        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
128            list
129        )
130        # {model: {field: {instances}}}
131        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
132            partial(defaultdict, set)
133        )
134        # fast_deletes is a list of queryset-likes that can be deleted without
135        # fetching the objects into memory.
136        self.fast_deletes: list[Any] = []
137
138        # Tracks deletion-order dependency for constraint checks. Only
139        # concrete model classes should be included, as the dependencies
140        # exist only between actual database tables.
141        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(
142            set
143        )  # {model: {models}}
144
145    def add(
146        self,
147        objs: Iterable[Any],
148        source: Any = None,
149        nullable: bool = False,
150        reverse_dependency: bool = False,
151    ) -> list[Any]:
152        """
153        Add 'objs' to the collection of objects to be deleted.  If the call is
154        the result of a cascade, 'source' should be the model that caused it,
155        and 'nullable' should be set to True if the relation can be null.
156
157        Return a list of all objects that were not already collected.
158        """
159        if not objs:
160            return []
161        new_objs = []
162        model = objs[0].__class__  # type: ignore[index]
163        instances = self.data[model]
164        for obj in objs:
165            if obj not in instances:
166                new_objs.append(obj)
167        instances.update(new_objs)
168        # Nullable relationships can be ignored -- they are nulled out before
169        # deleting, and therefore do not affect the order in which objects have
170        # to be deleted.
171        if source is not None and not nullable:
172            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
173        return new_objs
174
175    def add_dependency(
176        self, model: Any, dependency: Any, reverse_dependency: bool = False
177    ) -> None:
178        if reverse_dependency:
179            model, dependency = dependency, model
180        self.dependencies[model].add(dependency)
181        self.data.setdefault(dependency, set())
182
183    def add_field_update(
184        self, field: RelatedField, value: Any, objs: Iterable[Any]
185    ) -> None:
186        """
187        Schedule a field update. 'objs' must be a homogeneous iterable
188        collection of model instances (e.g. a QuerySet).
189        """
190        self.field_updates[field, value].append(objs)
191
192    def add_restricted_objects(self, field: RelatedField, objs: Iterable[Any]) -> None:
193        if objs:
194            model = objs[0].__class__  # type: ignore[index]
195            self.restricted_objects[model][field].update(objs)
196
197    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
198        if model in self.restricted_objects:
199            self.restricted_objects[model] = {
200                field: items - objs
201                for field, items in self.restricted_objects[model].items()
202            }
203
204    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
205        if model in self.restricted_objects:
206            objs = set(
207                qs.filter(
208                    id__in=[
209                        obj.id
210                        for objs in self.restricted_objects[model].values()
211                        for obj in objs
212                    ]
213                )
214            )
215            self.clear_restricted_objects_from_set(model, objs)
216
217    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
218        """
219        Determine if the objects in the given queryset-like or single object
220        can be fast-deleted. This can be done if there are no cascades, no
221        parents and no signal listeners for the object class.
222
223        The 'from_field' tells where we are coming from - we need this to
224        determine if the objects are in fact to be deleted. Allow also
225        skipping parent -> child -> parent chain preventing fast delete of
226        the child.
227        """
228        from plain.postgres.fields.related import RelatedField
229
230        if (
231            isinstance(from_field, RelatedField)
232            and from_field.remote_field.on_delete is not CASCADE
233        ):
234            return False
235        if hasattr(objs, "_model_meta"):
236            model = objs._model_meta.model
237        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
238            model = objs.model
239        else:
240            return False
241
242        # The use of from_field comes from the need to avoid cascade back to
243        # parent when parent delete is cascading to child.
244        meta = model._model_meta
245        return (
246            # Foreign keys pointing to this model.
247            all(
248                related.field.remote_field.on_delete is DO_NOTHING
249                for related in get_candidate_relations_to_delete(meta)
250            )
251        )
252
    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted.  'objs' must
        be a homogeneous iterable collection of model instances (e.g. a
        QuerySet).  If 'collect_related' is True, related objects will be
        handled by their respective on_delete handler.

        If 'objs' as a whole can be fast-deleted (see can_fast_delete), it is
        only appended to self.fast_deletes and nothing else is inspected.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after (the dependency edge is stored the
        other way round by add_dependency).

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raise ProtectedError if any on_delete handler raised it (all such
        errors for this call are gathered into one), and RestrictedError if
        'fail_on_restricted' is True and restricted objects remain that are
        not themselves collected for deletion.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        # Everything was already collected (or 'objs' was empty): nothing new
        # to follow.
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not collect_related:
            return

        # {related_model: [fields]} for relations whose rows can be deleted
        # without being fetched; resolved to querysets after the loop.
        model_fast_deletes = defaultdict(list)
        # {"'Model.field'": [objects]} gathered from ProtectedError raised by
        # handlers, so a single combined error can be raised afterwards.
        protected_objects = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete == DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            sub_objs = self.related_objects(related_model, [field], new_objs)
            # Restrict the queryset (via .only()) to the attributes that the
            # related model's own candidate relations refer to
            # (foreign_related_fields). Skip field deferring when some
            # relationships are select_related as interactions between both
            # features are hard to get right. This should only happen in
            # the rare cases where .related_objects is overridden anyway.
            if not sub_objs.sql_query.select_related:
                referenced_fields = set(
                    chain.from_iterable(
                        (rf.attname for rf in rel.field.foreign_related_fields)
                        for rel in get_candidate_relations_to_delete(
                            related_model._model_meta
                        )
                    )
                )
                sub_objs = sub_objs.only(*tuple(referenced_fields))
            # Lazy handlers (registered in _LAZY_ON_DELETE or flagged with
            # lazy_sub_objs) run unconditionally; for all others the handler
            # only runs when sub_objs is truthy. The `or` chain short-circuits,
            # so sub_objs is not truth-tested for lazy handlers.
            if (
                on_delete in _LAZY_ON_DELETE
                or getattr(on_delete, "lazy_sub_objs", False)
                or sub_objs
            ):
                try:
                    on_delete(self, field, sub_objs)
                except ProtectedError as error:
                    # Keep going so every protected relation is reported.
                    key = f"'{field.model.__name__}.{field.name}'"
                    protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        # One queryset per fast-deletable related model, covering all of the
        # fields through which it references new_objs.
        for related_model, related_fields in model_fast_deletes.items():
            sub_objs = self.related_objects(related_model, related_fields, new_objs)
            self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects.values():
                # {"'Model.field'": [objects]} for whatever is still
                # restricted after the clean-up above.
                restricted_objects = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = f"'{related_model.__name__}.{field.name}'"
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model {!r} because "
                        "they are referenced through restricted foreign keys: "
                        "{}.".format(
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )
372
373    def related_objects(
374        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
375    ) -> QuerySet:
376        """
377        Get a QuerySet of the related model to objs via related fields.
378        """
379        predicate = query_utils.Q.create(
380            [(f"{related_field.name}__in", objs) for related_field in related_fields],
381            connector=query_utils.Q.OR,
382        )
383        return related_model._model_meta.base_queryset.filter(predicate)
384
385    def sort(self) -> None:
386        sorted_models = []
387        concrete_models = set()
388        models = list(self.data)
389        while len(sorted_models) < len(models):
390            found = False
391            for model in models:
392                if model in sorted_models:
393                    continue
394                dependencies = self.dependencies.get(model)
395                if not (dependencies and dependencies.difference(concrete_models)):
396                    sorted_models.append(model)
397                    concrete_models.add(model)
398                    found = True
399            if not found:
400                return
401        self.data = defaultdict(
402            set, {model: self.data[model] for model in sorted_models}
403        )
404
405    def delete(self) -> tuple[int, dict[str, int]]:
406        # sort instance collections
407        for model, instances in self.data.items():
408            self.data[model] = sorted(instances, key=attrgetter("id"))
409
410        # if possible, bring the models in an order suitable for databases that
411        # don't support transactions or cannot defer constraint checks until the
412        # end of a transaction.
413        self.sort()
414        # number of objects deleted for each model label
415        deleted_counter = Counter()
416
417        # Optimize for the case with a single obj and no dependencies
418        if len(self.data) == 1 and len(instances) == 1:
419            instance = list(instances)[0]
420            if self.can_fast_delete(instance):
421                with transaction.mark_for_rollback_on_error():
422                    count = DeleteQuery(model).delete_batch([instance.id])
423                setattr(
424                    instance, model._model_meta.get_forward_field("id").attname, None
425                )
426                return count, {model.model_options.label: count}
427
428        with transaction.atomic(savepoint=False):
429            # fast deletes
430            for qs in self.fast_deletes:
431                count = qs._raw_delete()
432                if count:
433                    deleted_counter[qs.model.model_options.label] += count
434
435            # update fields
436            for (field, value), instances_list in self.field_updates.items():
437                assert field.name is not None
438                updates = []
439                objs = []
440                for instances in instances_list:
441                    if (
442                        isinstance(instances, QuerySet)
443                        and instances._result_cache is None
444                    ):
445                        updates.append(instances)
446                    else:
447                        objs.extend(instances)
448                if updates:
449                    combined_updates = reduce(or_, updates)
450                    combined_updates.update(**{field.name: value})
451                if objs:
452                    model = objs[0].__class__
453                    query = UpdateQuery(model)
454                    query.update_batch(
455                        list({obj.id for obj in objs}), {field.name: value}
456                    )
457
458            # reverse instance collections
459            for instances in self.data.values():
460                instances.reverse()
461
462            # delete instances
463            for model, instances in self.data.items():
464                query = DeleteQuery(model)
465                id_list = [obj.id for obj in instances]
466                count = query.delete_batch(id_list)
467                if count:
468                    deleted_counter[model.model_options.label] += count
469
470        for model, instances in self.data.items():
471            for instance in instances:
472                setattr(
473                    instance, model._model_meta.get_forward_field("id").attname, None
474                )
475        return sum(deleted_counter.values()), dict(deleted_counter)