1from __future__ import annotations
2
3from collections import Counter, defaultdict
4from collections.abc import Callable, Generator, Iterable
5from functools import partial, reduce
6from itertools import chain
7from operator import attrgetter, or_
8from typing import TYPE_CHECKING, Any
9
10from plain.postgres import (
11 query_utils,
12 transaction,
13)
14from plain.postgres.db import IntegrityError
15from plain.postgres.meta import Meta
16from plain.postgres.query import QuerySet
17from plain.postgres.sql import DeleteQuery, UpdateQuery
18
19if TYPE_CHECKING:
20 from plain.postgres.fields import Field
21 from plain.postgres.fields.related import RelatedField
22 from plain.postgres.fields.reverse_related import ForeignKeyRel
23
# Handlers in this set skip the sub_objs emptiness check in Collector.collect,
# allowing the handler to run even when there are no related objects. External
# on_delete handlers can still opt-in by setting ``lazy_sub_objs = True`` as an
# attribute -- the Collector checks both this set and getattr as a fallback.
_LAZY_ON_DELETE: set[Callable[..., Any]] = set()
29
30
class ProtectedError(IntegrityError):
    """Integrity error carrying the objects that blocked a deletion.

    The blocking objects are exposed as ``protected_objects`` and are also
    forwarded to the base exception as its second argument.
    """

    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
        super().__init__(msg, protected_objects)
        self.protected_objects = protected_objects
35
36
class RestrictedError(IntegrityError):
    """Integrity error carrying the objects that restricted a deletion.

    The restricting objects are exposed as ``restricted_objects`` and are
    also forwarded to the base exception as its second argument.
    """

    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
        super().__init__(msg, restricted_objects)
        self.restricted_objects = restricted_objects
41
42
def CASCADE(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that collects ``sub_objs`` for deletion as well.

    The model on the remote side of ``field`` is passed as the cascade source
    and ``field.allow_null`` as the nullable flag. Restricted-object checking
    is disabled for this nested ``collect`` call (``fail_on_restricted=False``).
    """
    source_model = field.remote_field.model
    is_nullable = field.allow_null
    collector.collect(
        sub_objs,
        source=source_model,
        nullable=is_nullable,
        fail_on_restricted=False,
    )
50
51
def PROTECT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that always refuses the deletion.

    Raises:
        ProtectedError: unconditionally, naming the referenced model and the
            referencing ``Model.field``; ``sub_objs`` is attached as the
            protected objects.
    """
    referenced_name = field.remote_field.model.__name__
    # The referencing model name is read from the first related object.
    referencing_name = sub_objs[0].__class__.__name__
    message = (
        f"Cannot delete some instances of model '{referenced_name}' because they are "
        f"referenced through a protected foreign key: '{referencing_name}.{field.name}'"
    )
    raise ProtectedError(message, sub_objs)
58
59
def RESTRICT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that records ``sub_objs`` as restricted.

    The objects are registered on the collector, and the model owning
    ``field`` is added as a dependency of the model on the remote side.
    """
    collector.add_restricted_objects(field, sub_objs)
    referenced_model = field.remote_field.model
    referencing_model = field.model
    collector.add_dependency(referenced_model, referencing_model)
63
64
class _SetOnDelete:
    """Callable on-delete handler produced by :func:`SET`.

    Holds either a plain value or a callable; a callable is invoked each time
    the handler runs to obtain the value to assign.
    """

    # Marks the handler as lazy: Collector.collect looks this attribute up via
    # getattr and, when true, calls the handler without testing sub_objs.
    lazy_sub_objs: bool = True

    def __init__(self, value: Any) -> None:
        # Decide once whether the stored value must be called on use.
        self._resolve = callable(value)
        self._value = value

    def __call__(
        self, collector: Collector, field: RelatedField, sub_objs: Any
    ) -> None:
        """Schedule ``field`` on ``sub_objs`` to be updated to the value."""
        if self._resolve:
            resolved = self._value()
        else:
            resolved = self._value
        collector.add_field_update(field, resolved, sub_objs)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, args, kwargs)`` describing how to rebuild this handler."""
        path = "plain.postgres.SET"
        args = (self._value,)
        kwargs: dict[str, Any] = {}
        return (path, args, kwargs)
82
83
def SET(value: Any) -> _SetOnDelete:
    """Build an on-delete handler that assigns ``value`` to the field.

    ``value`` may be a plain value or a callable; it is wrapped in a
    ``_SetOnDelete`` instance, which is returned.
    """
    handler = _SetOnDelete(value)
    return handler
86
87
def SET_NULL(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that schedules ``field`` on ``sub_objs`` to become ``None``."""
    null_value = None
    collector.add_field_update(field, null_value, sub_objs)
90
91
# Register SET_NULL as lazy so Collector.collect calls it without first
# testing whether the related queryset is empty.
_LAZY_ON_DELETE.add(SET_NULL)
93
94
def SET_DEFAULT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that schedules ``field`` on ``sub_objs`` to take its default.

    The default is obtained from ``field.get_default()`` at the time the
    handler runs.
    """
    default_value = field.get_default()
    collector.add_field_update(field, default_value, sub_objs)
97
98
# Register SET_DEFAULT as lazy so Collector.collect calls it without first
# testing whether the related queryset is empty.
_LAZY_ON_DELETE.add(SET_DEFAULT)
100
101
def DO_NOTHING(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """On-delete handler that takes no action on the related objects."""
    return None
104
105
def get_candidate_relations_to_delete(
    meta: Meta,
) -> Generator[ForeignKeyRel]:
    """Return a lazy iterable of reverse relations that deletion must consider.

    Candidates are the fields of ``meta`` (reverse ones included) that are
    auto-created, non-concrete ``ForeignKeyRel`` instances.
    """
    from plain.postgres.fields.reverse_related import ForeignKeyRel

    # Candidates come from N-1 and 1-1 relations only; N-N (many-to-many)
    # relations are never candidates for deletion.
    every_field = meta.get_fields(include_reverse=True)
    return (
        relation
        for relation in every_field
        if relation.auto_created
        and not relation.concrete
        and isinstance(relation, ForeignKeyRel)
    )
118
119
class Collector:
    """Gather the objects affected by a deletion and then perform it.

    ``collect`` walks the reverse relations of the objects being deleted and
    dispatches to each relation's ``on_delete`` handler, which in turn records
    further deletions, field updates, restricted objects or fast deletes on
    this collector. ``delete`` then applies everything inside a transaction.
    """

    def __init__(self, origin: Any = None) -> None:
        # A Model or QuerySet object.
        self.origin = origin
        # Initially, {model: {instances}}, later values become lists.
        self.data: defaultdict[Any, Any] = defaultdict(set)
        # {(field, value): [instances, ...]}
        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
            list
        )
        # {model: {field: {instances}}}
        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
            partial(defaultdict, set)
        )
        # fast_deletes is a list of queryset-likes that can be deleted without
        # fetching the objects into memory.
        self.fast_deletes: list[Any] = []

        # Tracks deletion-order dependency for constraint checks. Only
        # concrete model classes should be included, as the dependencies
        # exist only between actual database tables.
        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(
            set
        )  # {model: {models}}

    def add(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        reverse_dependency: bool = False,
    ) -> list[Any]:
        """
        Add 'objs' to the collection of objects to be deleted. If the call is
        the result of a cascade, 'source' should be the model that caused it,
        and 'nullable' should be set to True if the relation can be null.

        Return a list of all objects that were not already collected.
        """
        if not objs:
            return []
        # 'objs' is homogeneous, so the first element determines the model.
        model = objs[0].__class__  # type: ignore[index]
        instances = self.data[model]
        new_objs = [obj for obj in objs if obj not in instances]
        instances.update(new_objs)
        # Nullable relationships can be ignored -- they are nulled out before
        # deleting, and therefore do not affect the order in which objects have
        # to be deleted.
        if source is not None and not nullable:
            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
        return new_objs

    def add_dependency(
        self, model: Any, dependency: Any, reverse_dependency: bool = False
    ) -> None:
        """
        Record that 'dependency' must be ordered before 'model' by sort().
        If 'reverse_dependency' is True the two roles are swapped.
        """
        if reverse_dependency:
            model, dependency = dependency, model
        self.dependencies[model].add(dependency)
        # Make sure the dependency is a key of self.data so sort() sees it.
        self.data.setdefault(dependency, set())

    def add_field_update(
        self, field: RelatedField, value: Any, objs: Iterable[Any]
    ) -> None:
        """
        Schedule a field update. 'objs' must be a homogeneous iterable
        collection of model instances (e.g. a QuerySet).
        """
        self.field_updates[field, value].append(objs)

    def add_restricted_objects(self, field: RelatedField, objs: Iterable[Any]) -> None:
        """
        Record 'objs' as restricted through 'field', keyed by their model.
        An empty 'objs' is ignored.
        """
        if objs:
            model = objs[0].__class__  # type: ignore[index]
            self.restricted_objects[model][field].update(objs)

    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
        """
        Remove 'objs' from every restricted set recorded for 'model'.
        """
        if model in self.restricted_objects:
            # Rebuild as a defaultdict(set): a plain dict here would make a
            # later add_restricted_objects() for a new field raise KeyError.
            self.restricted_objects[model] = defaultdict(
                set,
                {
                    field: items - objs
                    for field, items in self.restricted_objects[model].items()
                },
            )

    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
        """
        Remove from the restricted sets of 'model' every restricted object
        that is also matched by the queryset 'qs'.
        """
        if model in self.restricted_objects:
            restricted_ids = [
                obj.id
                for restricted in self.restricted_objects[model].values()
                for obj in restricted
            ]
            objs = set(qs.filter(id__in=restricted_ids))
            self.clear_restricted_objects_from_set(model, objs)

    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
        """
        Determine if the objects in the given queryset-like, model class or
        single object can be fast-deleted, i.e. deleted without being fetched
        into memory.

        'objs' is accepted when it exposes '_model_meta', or when it exposes
        both 'model' and '_raw_delete'; anything else is not fast-deletable.

        The 'from_field' tells where we are coming from: if it is a
        RelatedField whose on_delete is not CASCADE, the objects are not
        going to be deleted and False is returned.

        Otherwise fast deletion is possible only when every candidate relation
        pointing at the model uses DO_NOTHING.
        """
        from plain.postgres.fields.related import RelatedField

        if (
            isinstance(from_field, RelatedField)
            and from_field.remote_field.on_delete is not CASCADE
        ):
            return False
        if hasattr(objs, "_model_meta"):
            model = objs._model_meta.model
        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
            model = objs.model
        else:
            return False

        meta = model._model_meta
        # Foreign keys pointing to this model must all be DO_NOTHING.
        return all(
            related.field.remote_field.on_delete is DO_NOTHING
            for related in get_candidate_relations_to_delete(meta)
        )

    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted. 'objs' must be
        a homogeneous iterable collection of model instances (e.g. a
        QuerySet). If 'collect_related' is True, related objects will be
        handled by their respective on_delete handler.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after.

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raises ProtectedError if any handler reported protected objects, and
        RestrictedError if restricted objects remain that are not themselves
        collected for deletion.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not collect_related:
            return

        model_fast_deletes = defaultdict(list)
        protected_objects = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            # Identity check, consistent with can_fast_delete().
            if on_delete is DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            sub_objs = self.related_objects(related_model, [field], new_objs)
            # Load only the fields that relations pointing at the related
            # model refer to. Skip this narrowing when select_related is in
            # use on the queryset, as interactions between both features are
            # hard to get right.
            if not sub_objs.sql_query.select_related:
                referenced_fields = set(
                    chain.from_iterable(
                        (rf.attname for rf in rel.field.foreign_related_fields)
                        for rel in get_candidate_relations_to_delete(
                            related_model._model_meta
                        )
                    )
                )
                sub_objs = sub_objs.only(*referenced_fields)
            # Lazy handlers run unconditionally; the others only when there
            # is at least one related object (evaluating sub_objs last keeps
            # the short-circuit from touching the queryset for lazy ones).
            if (
                on_delete in _LAZY_ON_DELETE
                or getattr(on_delete, "lazy_sub_objs", False)
                or sub_objs
            ):
                try:
                    on_delete(self, field, sub_objs)
                except ProtectedError as error:
                    # Accumulate so every protected relation is reported at once.
                    key = f"'{field.model.__name__}.{field.name}'"
                    protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        for related_model, related_fields in model_fast_deletes.items():
            sub_objs = self.related_objects(related_model, related_fields, new_objs)
            self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects:
                restricted_objects = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = f"'{related_model.__name__}.{field.name}'"
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model {!r} because "
                        "they are referenced through restricted foreign keys: "
                        "{}.".format(
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )

    def related_objects(
        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
    ) -> QuerySet:
        """
        Get a QuerySet of the related model to objs via related fields.
        """
        # Match rows referencing any of 'objs' through any of the fields.
        predicate = query_utils.Q.create(
            [(f"{related_field.name}__in", objs) for related_field in related_fields],
            connector=query_utils.Q.OR,
        )
        return related_model._model_meta.base_queryset.filter(predicate)

    def sort(self) -> None:
        """
        Reorder self.data so that every model comes after the models recorded
        as its dependencies. If no such order exists (a dependency cycle),
        self.data is left untouched.
        """
        sorted_models: list[Any] = []
        # Same members as sorted_models, kept as a set for O(1) membership.
        placed: set[Any] = set()
        models = list(self.data)
        while len(sorted_models) < len(models):
            found = False
            for model in models:
                if model in placed:
                    continue
                dependencies = self.dependencies.get(model)
                if not (dependencies and dependencies.difference(placed)):
                    sorted_models.append(model)
                    placed.add(model)
                    found = True
            if not found:
                # No progress in a full pass: give up and keep current order.
                return
        self.data = defaultdict(
            set, {model: self.data[model] for model in sorted_models}
        )

    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Delete everything that has been collected.

        Return a tuple of the total number of deleted rows and a dict mapping
        each model label to the number of rows deleted for it. The id
        attribute of every deleted in-memory instance is set to None.
        """
        # sort instance collections
        for model, instances in self.data.items():
            self.data[model] = sorted(instances, key=attrgetter("id"))

        # if possible, bring the models into dependency order (see sort()).
        self.sort()
        # number of objects deleted for each model label
        deleted_counter: Counter[str] = Counter()

        # Optimize for the case with a single obj and no dependencies
        if len(self.data) == 1:
            [(model, instances)] = self.data.items()
            if len(instances) == 1:
                instance = instances[0]
                if self.can_fast_delete(instance):
                    with transaction.mark_for_rollback_on_error():
                        count = DeleteQuery(model).delete_batch([instance.id])
                    setattr(
                        instance,
                        model._model_meta.get_forward_field("id").attname,
                        None,
                    )
                    return count, {model.model_options.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model.model_options.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                assert field.name is not None
                updates = []
                objs = []
                for instances in instances_list:
                    # Unevaluated querysets are updated in the database
                    # directly; everything else is updated by id.
                    if (
                        isinstance(instances, QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    update_model = objs[0].__class__
                    query = UpdateQuery(update_model)
                    query.update_batch(
                        list({obj.id for obj in objs}), {field.name: value}
                    )

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = DeleteQuery(model)
                id_list = [obj.id for obj in instances]
                count = query.delete_batch(id_list)
                if count:
                    deleted_counter[model.model_options.label] += count

        # Clear the id on the in-memory instances that were deleted.
        for model, instances in self.data.items():
            for instance in instances:
                setattr(
                    instance, model._model_meta.get_forward_field("id").attname, None
                )
        return sum(deleted_counter.values()), dict(deleted_counter)