1from __future__ import annotations
2
3from collections import Counter, defaultdict
4from collections.abc import Callable, Generator, Iterable
5from functools import partial, reduce
6from itertools import chain
7from operator import attrgetter, or_
8from typing import TYPE_CHECKING, Any
9
10from plain.models import (
11 query_utils,
12 sql,
13 transaction,
14)
15from plain.models.db import IntegrityError, db_connection
16from plain.models.meta import Meta
17from plain.models.query import QuerySet
18
19if TYPE_CHECKING:
20 from plain.models.fields import Field
21 from plain.models.fields.related import RelatedField
22 from plain.models.fields.reverse_related import ForeignKeyRel
23
24
class ProtectedError(IntegrityError):
    """
    Raised when a deletion is blocked by a PROTECT foreign key.

    ``protected_objects`` holds the referencing objects that blocked it; it is
    also passed as the second exception argument.
    """

    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
        super().__init__(msg, protected_objects)
        self.protected_objects = protected_objects
29
30
class RestrictedError(IntegrityError):
    """
    Raised when a deletion is blocked by RESTRICT foreign keys whose
    referencing objects are not themselves being deleted.

    ``restricted_objects`` holds those referencing objects; it is also passed
    as the second exception argument.
    """

    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
        super().__init__(msg, restricted_objects)
        self.restricted_objects = restricted_objects
35
36
def CASCADE(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """
    on_delete handler: delete the referencing objects together with the target.

    The referencing objects are collected recursively with restriction checks
    deferred to the top-level collect() call.
    """
    nullable = field.allow_null
    collector.collect(
        sub_objs,
        source=field.remote_field.model,
        nullable=nullable,
        fail_on_restricted=False,
    )
    # When constraint checks can't be deferred, null out nullable FKs first
    # (field updates run before row deletions) so deletion order doesn't matter.
    if nullable and not db_connection.features.can_defer_constraint_checks:
        collector.add_field_update(field, None, sub_objs)
46
47
def PROTECT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """
    on_delete handler: refuse the deletion by raising ProtectedError.

    ``sub_objs`` must be indexable; its first element names the referencing
    model in the error message.
    """
    target_name = field.remote_field.model.__name__
    referrer_name = sub_objs[0].__class__.__name__
    message = (
        f"Cannot delete some instances of model '{target_name}' because they are "
        f"referenced through a protected foreign key: '{referrer_name}.{field.name}'"
    )
    raise ProtectedError(message, sub_objs)
54
55
def RESTRICT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """
    on_delete handler: record the referencing objects as restricted.

    The top-level Collector.collect() call raises RestrictedError for any of
    them that are not also collected for deletion by other means.
    """
    collector.add_restricted_objects(field, sub_objs)
    # Ordering dependency between the target model and the referencing model.
    target_model, referencing_model = field.remote_field.model, field.model
    collector.add_dependency(target_model, referencing_model)
59
60
def SET(value: Any) -> Callable[[Collector, RelatedField, Any], None]:
    """
    Build an on_delete handler that sets the foreign key to ``value``.

    If ``value`` is callable it is called every time the handler runs and its
    result is used; otherwise ``value`` itself is used. The handler is
    deconstructible (for serialization) and marked ``lazy_sub_objs``.
    """
    value_is_factory = callable(value)

    def set_on_delete(
        collector: Collector, field: RelatedField, sub_objs: Any
    ) -> None:
        new_value = value() if value_is_factory else value
        collector.add_field_update(field, new_value, sub_objs)

    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})  # type: ignore[attr-defined]
    set_on_delete.lazy_sub_objs = True  # type: ignore[attr-defined]
    return set_on_delete
79
80
def SET_NULL(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """on_delete handler: schedule the referencing foreign key to be set to None."""
    null_value = None
    collector.add_field_update(field, null_value, sub_objs)


# lazy_sub_objs tells Collector.collect to invoke this handler without first
# evaluating sub_objs for emptiness.
setattr(SET_NULL, "lazy_sub_objs", True)
86
87
def SET_DEFAULT(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """on_delete handler: schedule the referencing foreign key to be reset to its default."""
    default_value = field.get_default()
    collector.add_field_update(field, default_value, sub_objs)


# lazy_sub_objs tells Collector.collect to invoke this handler without first
# evaluating sub_objs for emptiness.
setattr(SET_DEFAULT, "lazy_sub_objs", True)
93
94
def DO_NOTHING(collector: Collector, field: RelatedField, sub_objs: Any) -> None:
    """
    on_delete handler that takes no action.

    Collector.collect skips relations using this handler, and
    Collector.can_fast_delete treats it as not blocking a fast delete.
    """
97
98
def get_candidate_relations_to_delete(
    meta: Meta,
) -> Generator[ForeignKeyRel, None, None]:
    """
    Return a generator over the reverse relations on ``meta`` whose on_delete
    handlers matter when deleting the model's instances.

    Only auto-created, non-concrete ``ForeignKeyRel`` objects qualify; these
    come from N-1 and 1-1 relations. N-N (many-to-many) relations are not
    candidates for deletion.
    """
    from plain.models.fields.reverse_related import ForeignKeyRel

    every_field = meta.get_fields(include_reverse=True)
    return (
        rel
        for rel in every_field
        if rel.auto_created and not rel.concrete and isinstance(rel, ForeignKeyRel)
    )
111
112
class Collector:
    """
    Collect the objects affected by deleting some model instances, then
    delete them.

    Typical use is ``collect(...)`` followed by ``delete()``. ``collect`` asks
    each reverse foreign key's ``on_delete`` handler what to do with the
    referencing rows; ``delete`` then runs the queued fast deletes, field
    updates and row deletions.
    """

    def __init__(self, origin: Any = None) -> None:
        # A Model or QuerySet object.
        self.origin = origin
        # Initially, {model: {instances}}, later values become lists.
        self.data: defaultdict[Any, Any] = defaultdict(set)
        # {(field, value): [iterables of instances (lists or QuerySets), …]}
        self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
            list
        )
        # {model: {field: {instances}}}
        self.restricted_objects: defaultdict[Any, Any] = defaultdict(
            partial(defaultdict, set)
        )
        # fast_deletes is a list of queryset-likes that can be deleted without
        # fetching the objects into memory.
        self.fast_deletes: list[Any] = []

        # Tracks deletion-order dependency for databases without transactions
        # or ability to defer constraint checks. Only concrete model classes
        # should be included, as the dependencies exist only between actual
        # database tables.
        # {model: {models}}
        self.dependencies: defaultdict[Any, set[Any]] = defaultdict(set)

    def add(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        reverse_dependency: bool = False,
    ) -> list[Any]:
        """
        Add 'objs' to the collection of objects to be deleted. If the call is
        the result of a cascade, 'source' should be the model that caused it,
        and 'nullable' should be set to True if the relation can be null.

        Return a list of all objects that were not already collected.
        """
        if not objs:
            return []
        model = objs[0].__class__
        instances = self.data[model]
        new_objs = [obj for obj in objs if obj not in instances]
        instances.update(new_objs)
        # Nullable relationships can be ignored -- they are nulled out before
        # deleting, and therefore do not affect the order in which objects have
        # to be deleted.
        if source is not None and not nullable:
            self.add_dependency(source, model, reverse_dependency=reverse_dependency)
        return new_objs

    def add_dependency(
        self, model: Any, dependency: Any, reverse_dependency: bool = False
    ) -> None:
        """
        Record that 'model' depends on 'dependency' (swapped when
        'reverse_dependency' is True). sort() places a model only after all
        of its dependencies.
        """
        if reverse_dependency:
            model, dependency = dependency, model
        self.dependencies[model].add(dependency)
        # Make sure the dependency is in self.data so sort() can place it.
        self.data.setdefault(dependency, set())

    def add_field_update(
        self, field: RelatedField, value: Any, objs: Iterable[Any]
    ) -> None:
        """
        Schedule a field update. 'objs' must be a homogeneous iterable
        collection of model instances (e.g. a QuerySet).
        """
        self.field_updates[field, value].append(objs)

    def add_restricted_objects(self, field: RelatedField, objs: Iterable[Any]) -> None:
        """Record 'objs' as restricted through 'field' (used by RESTRICT)."""
        if objs:
            model = objs[0].__class__
            self.restricted_objects[model][field].update(objs)

    def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
        """Remove 'objs' from every restricted-object set recorded for 'model'."""
        if model in self.restricted_objects:
            # Keep the per-model mapping a defaultdict(set): a plain dict would
            # make a later add_restricted_objects() for a new field of this
            # model raise KeyError.
            self.restricted_objects[model] = defaultdict(
                set,
                {
                    field: items - objs
                    for field, items in self.restricted_objects[model].items()
                },
            )

    def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
        """
        Remove from the restricted objects of 'model' those that 'qs' also
        matches (i.e. that are going to be deleted through 'qs').
        """
        if model in self.restricted_objects:
            restricted_ids = [
                obj.id
                for field_objs in self.restricted_objects[model].values()
                for obj in field_objs
            ]
            objs = set(qs.filter(id__in=restricted_ids))
            self.clear_restricted_objects_from_set(model, objs)

    def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
        """
        Determine if the objects in the given queryset-like or single object
        can be fast-deleted, i.e. without fetching them first. This is only
        possible when every reverse foreign key pointing at their model uses
        DO_NOTHING.

        'from_field' is the relation being cascaded through, if any; a
        RelatedField whose on_delete isn't CASCADE means the objects are not
        necessarily being deleted, so fast delete is refused.
        """
        from plain.models.fields.related import RelatedField

        if (
            isinstance(from_field, RelatedField)
            and from_field.remote_field.on_delete is not CASCADE
        ):
            return False
        if hasattr(objs, "_model_meta"):
            model = objs._model_meta.model
        elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
            model = objs.model
        else:
            return False

        # Foreign keys pointing to this model.
        return all(
            related.field.remote_field.on_delete is DO_NOTHING
            for related in get_candidate_relations_to_delete(model._model_meta)
        )

    def get_del_batches(self, objs: list[Any], fields: list[Field]) -> list[list[Any]]:
        """
        Return the objs in suitably sized batches for the used db_connection.
        """
        field_names = [field.name for field in fields]
        conn_batch_size = max(
            db_connection.ops.bulk_batch_size(field_names, objs),
            1,
        )
        if len(objs) > conn_batch_size:
            return [
                objs[i : i + conn_batch_size]
                for i in range(0, len(objs), conn_batch_size)
            ]
        return [objs]

    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted as well as all
        parent instances. 'objs' must be a homogeneous iterable collection of
        model instances (e.g. a QuerySet). If 'collect_related' is True,
        related objects will be handled by their respective on_delete handler.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after. (Needed for cascading to parent
        models, the one case in which the cascade follows the forwards
        direction of an FK rather than the reverse direction.)

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raises ProtectedError / RestrictedError when PROTECT / RESTRICT
        relations forbid the deletion.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        if not new_objs or not collect_related:
            return

        model = new_objs[0].__class__
        self._collect_related(model, new_objs)

        if fail_on_restricted:
            self._check_restricted_objects(model)

    def _referenced_field_names(self, related_model: Any) -> tuple[str, ...]:
        """Return attnames of fields on 'related_model' that other models' FKs reference."""
        return tuple(
            set(
                chain.from_iterable(
                    (rf.attname for rf in rel.field.foreign_related_fields)
                    for rel in get_candidate_relations_to_delete(
                        related_model._model_meta
                    )
                )
            )
        )

    def _collect_related(self, model: Any, new_objs: list[Any]) -> None:
        """Run each reverse FK's on_delete handler for the objects referencing 'new_objs'."""
        model_fast_deletes: defaultdict[Any, list[Any]] = defaultdict(list)
        protected_objects: defaultdict[str, list[Any]] = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete is DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            # Depends only on related_model, so compute it once per relation
            # rather than once per batch.
            referenced_fields = self._referenced_field_names(related_model)
            for batch in self.get_del_batches(new_objs, [field]):
                sub_objs = self.related_objects(related_model, [field], batch)
                # Non-referenced fields can be deferred as they'll never be
                # exposed to the user. Skip field deferring when some
                # relationships are select_related as interactions between both
                # features are hard to get right. This should only happen in
                # the rare cases where .related_objects is overridden anyway.
                if not sub_objs.sql_query.select_related:
                    sub_objs = sub_objs.only(*referenced_fields)
                # Lazy handlers are called without evaluating sub_objs; others
                # only when there is at least one referencing object.
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs)
                    except ProtectedError as error:
                        key = f"'{field.model.__name__}.{field.name}'"
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        for related_model, related_fields in model_fast_deletes.items():
            for batch in self.get_del_batches(new_objs, related_fields):
                self.fast_deletes.append(
                    self.related_objects(related_model, related_fields, batch)
                )

    def _check_restricted_objects(self, model: Any) -> None:
        """
        Raise RestrictedError if objects collected via RESTRICT aren't also
        going to be deleted (via self.data or self.fast_deletes).
        """
        for related_model, instances in self.data.items():
            self.clear_restricted_objects_from_set(related_model, instances)
        for qs in self.fast_deletes:
            self.clear_restricted_objects_from_queryset(qs.model, qs)
        restricted_objects: defaultdict[str, list[Any]] = defaultdict(list)
        for related_model, fields in self.restricted_objects.items():
            for field, objs in fields.items():
                if objs:
                    key = f"'{related_model.__name__}.{field.name}'"
                    restricted_objects[key] += objs
        if restricted_objects:
            raise RestrictedError(
                "Cannot delete some instances of model {!r} because "
                "they are referenced through restricted foreign keys: "
                "{}.".format(
                    model.__name__,
                    ", ".join(restricted_objects),
                ),
                set(chain.from_iterable(restricted_objects.values())),
            )

    def related_objects(
        self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
    ) -> QuerySet:
        """
        Get a QuerySet of the related model to objs via related fields.
        """
        predicate = query_utils.Q.create(
            [(f"{related_field.name}__in", objs) for related_field in related_fields],
            connector=query_utils.Q.OR,
        )
        return related_model._model_meta.base_queryset.filter(predicate)

    def sort(self) -> None:
        """
        Reorder self.data so each model comes after the models it depends on.

        If the dependencies can't be satisfied (a cycle), self.data is left in
        its current order.
        """
        sorted_models: list[Any] = []
        placed: set[Any] = set()
        models = list(self.data)
        while len(sorted_models) < len(models):
            found = False
            for model in models:
                if model in placed:
                    continue
                dependencies = self.dependencies.get(model)
                if not (dependencies and dependencies.difference(placed)):
                    sorted_models.append(model)
                    placed.add(model)
                    found = True
            if not found:
                return
        self.data = defaultdict(
            set, {model: self.data[model] for model in sorted_models}
        )

    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Delete everything collected and return
        ``(total_deleted, {model_label: deleted_count})``.

        The id attribute of each deleted instance in self.data is set to None.
        """
        # Sort each model's instances by id; values become lists from here on.
        for model, instances in list(self.data.items()):
            self.data[model] = sorted(instances, key=attrgetter("id"))

        # if possible, bring the models in an order suitable for databases that
        # don't support transactions or cannot defer constraint checks until the
        # end of a transaction.
        self.sort()
        # number of objects deleted for each model label
        deleted_counter: Counter[str] = Counter()

        # Optimize for the case with a single obj and no dependencies. Skip it
        # when other work is queued (e.g. by an earlier collect() on this
        # collector) so that work isn't silently dropped.
        if len(self.data) == 1 and not self.fast_deletes and not self.field_updates:
            [(model, instances)] = self.data.items()
            if len(instances) == 1:
                instance = instances[0]
                if self.can_fast_delete(instance):
                    with transaction.mark_for_rollback_on_error():
                        count = sql.DeleteQuery(model).delete_batch([instance.id])
                    setattr(
                        instance,
                        model._model_meta.get_forward_field("id").attname,
                        None,
                    )
                    return count, {model.model_options.label: count}

        with transaction.atomic(savepoint=False):
            # fast deletes
            for qs in self.fast_deletes:
                count = qs._raw_delete()
                if count:
                    deleted_counter[qs.model.model_options.label] += count

            # update fields
            for (field, value), instances_list in self.field_updates.items():
                assert field.name is not None
                updates = []
                objs = []
                for instances in instances_list:
                    # Unevaluated querysets are combined into one UPDATE;
                    # anything else is updated by id.
                    if (
                        isinstance(instances, QuerySet)
                        and instances._result_cache is None
                    ):
                        updates.append(instances)
                    else:
                        objs.extend(instances)
                if updates:
                    combined_updates = reduce(or_, updates)
                    combined_updates.update(**{field.name: value})
                if objs:
                    model = objs[0].__class__
                    query = sql.UpdateQuery(model)
                    query.update_batch(
                        list({obj.id for obj in objs}), {field.name: value}
                    )

            # reverse instance collections
            for instances in self.data.values():
                instances.reverse()

            # delete instances
            for model, instances in self.data.items():
                query = sql.DeleteQuery(model)
                id_list = [obj.id for obj in instances]
                count = query.delete_batch(id_list)
                if count:
                    deleted_counter[model.model_options.label] += count

        for model, instances in self.data.items():
            if not instances:
                continue
            id_attname = model._model_meta.get_forward_field("id").attname
            for instance in instances:
                setattr(instance, id_attname, None)
        return sum(deleted_counter.values()), dict(deleted_counter)