1from __future__ import annotations
2
3from collections import Counter, defaultdict
4from collections.abc import Callable, Generator, Iterable
5from functools import partial, reduce
6from itertools import chain
7from operator import attrgetter, or_
8from typing import TYPE_CHECKING, Any
9
10from plain.models import (
11 query_utils,
12 sql,
13 transaction,
14)
15from plain.models.db import IntegrityError, db_connection
16from plain.models.meta import Meta
17from plain.models.query import QuerySet
18
19if TYPE_CHECKING:
20 from plain.models.fields import Field
21
22
class ProtectedError(IntegrityError):
    """Integrity error raised when a protected reference blocks a deletion.

    The offending objects are kept on ``protected_objects`` and are also
    forwarded, after the message, as a positional argument to the base
    exception.
    """

    def __init__(self, msg: str, protected_objects: Iterable[Any]) -> None:
        super().__init__(msg, protected_objects)
        self.protected_objects = protected_objects
27
28
class RestrictedError(IntegrityError):
    """Integrity error raised when a restricted reference blocks a deletion.

    The offending objects are kept on ``restricted_objects`` and are also
    forwarded, after the message, as a positional argument to the base
    exception.
    """

    def __init__(self, msg: str, restricted_objects: Iterable[Any]) -> None:
        super().__init__(msg, restricted_objects)
        self.restricted_objects = restricted_objects
33
34
def CASCADE(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that cascades the deletion to ``sub_objs``.

    The referencing objects are handed to ``collector.collect`` with the
    referenced model as ``source``; restricted-object checking is deferred
    (``fail_on_restricted=False``). When the field allows NULL and the
    database cannot defer constraint checks, the field is additionally
    scheduled to be set to ``None``.
    """
    collector.collect(
        sub_objs,
        source=field.remote_field.model,  # type: ignore[attr-defined]
        nullable=field.allow_null,
        fail_on_restricted=False,
    )
    if not field.allow_null:
        return
    if db_connection.features.can_defer_constraint_checks:
        return
    collector.add_field_update(field, None, sub_objs)
44
45
def PROTECT(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that refuses the deletion.

    Always raises ``ProtectedError``; the message names the referenced model
    and the referencing ``Model.field`` (taken from the first element of
    ``sub_objs``), and ``sub_objs`` is attached to the exception.
    """
    message = (
        f"Cannot delete some instances of model '{field.remote_field.model.__name__}' because they are "  # type: ignore[attr-defined]
        f"referenced through a protected foreign key: '{sub_objs[0].__class__.__name__}.{field.name}'"
    )
    raise ProtectedError(message, sub_objs)
52
53
def RESTRICT(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that records ``sub_objs`` as restricted.

    The objects are registered on the collector under ``field``, and a
    deletion-order dependency from the referenced model to the referencing
    model is added.
    """
    collector.add_restricted_objects(field, sub_objs)
    referenced_model = field.remote_field.model  # type: ignore[attr-defined]
    referencing_model = field.model
    collector.add_dependency(referenced_model, referencing_model)
57
58
def SET(value: Any) -> Callable[[Collector, Field, Any], None]:
    """Build an on_delete handler that sets the field to ``value``.

    If ``value`` is callable it is invoked each time the handler runs and its
    result is used; otherwise ``value`` itself is used. The returned handler
    carries a ``deconstruct`` attribute and is flagged ``lazy_sub_objs``.
    """
    # Decided once, when the handler is built, not on every invocation.
    value_is_factory = callable(value)

    def set_on_delete(collector: Collector, field: Field, sub_objs: Any) -> None:
        new_value = value() if value_is_factory else value
        collector.add_field_update(field, new_value, sub_objs)

    set_on_delete.deconstruct = lambda: ("plain.models.SET", (value,), {})  # type: ignore[attr-defined]
    set_on_delete.lazy_sub_objs = True  # type: ignore[attr-defined]
    return set_on_delete
73
74
def SET_NULL(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that schedules ``field`` to be set to ``None``."""
    null_value = None
    collector.add_field_update(field, null_value, sub_objs)


# Flag read by Collector.collect: the handler is invoked without first
# testing the truthiness of sub_objs.
SET_NULL.lazy_sub_objs = True  # type: ignore[attr-defined]
80
81
def SET_DEFAULT(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that schedules ``field`` to be set to its default.

    The default is obtained from ``field.get_default()`` when the handler
    runs.
    """
    default_value = field.get_default()
    collector.add_field_update(field, default_value, sub_objs)


# Flag read by Collector.collect: the handler is invoked without first
# testing the truthiness of sub_objs.
SET_DEFAULT.lazy_sub_objs = True  # type: ignore[attr-defined]
87
88
def DO_NOTHING(collector: Collector, field: Field, sub_objs: Any) -> None:
    """on_delete handler that takes no action for the referencing objects."""
    return None
91
92
def get_candidate_relations_to_delete(meta: Meta) -> Generator[Any, None, None]:
    """Yield the relations of ``meta`` that deletion has to inspect.

    A field returned by ``meta.get_fields(include_hidden=True)`` is a
    candidate when it is auto-created, not concrete, and one-to-many.
    """
    every_field = meta.get_fields(include_hidden=True)
    return (
        relation
        for relation in every_field
        if relation.auto_created and not relation.concrete and relation.one_to_many
    )
101
102
103class Collector:
104 def __init__(self, origin: Any = None) -> None:
105 # A Model or QuerySet object.
106 self.origin = origin
107 # Initially, {model: {instances}}, later values become lists.
108 self.data: defaultdict[Any, Any] = defaultdict(set)
109 # {(field, value): [instances, …]}
110 self.field_updates: defaultdict[tuple[Field, Any], list[Any]] = defaultdict(
111 list
112 )
113 # {model: {field: {instances}}}
114 self.restricted_objects: defaultdict[Any, Any] = defaultdict(
115 partial(defaultdict, set)
116 )
117 # fast_deletes is a list of queryset-likes that can be deleted without
118 # fetching the objects into memory.
119 self.fast_deletes: list[Any] = []
120
121 # Tracks deletion-order dependency for databases without transactions
122 # or ability to defer constraint checks. Only concrete model classes
123 # should be included, as the dependencies exist only between actual
124 # database tables.
125 self.dependencies: defaultdict[Any, set[Any]] = defaultdict(
126 set
127 ) # {model: {models}}
128
129 def add(
130 self,
131 objs: Iterable[Any],
132 source: Any = None,
133 nullable: bool = False,
134 reverse_dependency: bool = False,
135 ) -> list[Any]:
136 """
137 Add 'objs' to the collection of objects to be deleted. If the call is
138 the result of a cascade, 'source' should be the model that caused it,
139 and 'nullable' should be set to True if the relation can be null.
140
141 Return a list of all objects that were not already collected.
142 """
143 if not objs:
144 return []
145 new_objs = []
146 model = objs[0].__class__
147 instances = self.data[model]
148 for obj in objs:
149 if obj not in instances:
150 new_objs.append(obj)
151 instances.update(new_objs)
152 # Nullable relationships can be ignored -- they are nulled out before
153 # deleting, and therefore do not affect the order in which objects have
154 # to be deleted.
155 if source is not None and not nullable:
156 self.add_dependency(source, model, reverse_dependency=reverse_dependency)
157 return new_objs
158
159 def add_dependency(
160 self, model: Any, dependency: Any, reverse_dependency: bool = False
161 ) -> None:
162 if reverse_dependency:
163 model, dependency = dependency, model
164 self.dependencies[model].add(dependency)
165 self.data.setdefault(dependency, self.data.default_factory())
166
167 def add_field_update(self, field: Field, value: Any, objs: Iterable[Any]) -> None:
168 """
169 Schedule a field update. 'objs' must be a homogeneous iterable
170 collection of model instances (e.g. a QuerySet).
171 """
172 self.field_updates[field, value].append(objs)
173
174 def add_restricted_objects(self, field: Field, objs: Iterable[Any]) -> None:
175 if objs:
176 model = objs[0].__class__
177 self.restricted_objects[model][field].update(objs)
178
179 def clear_restricted_objects_from_set(self, model: Any, objs: set[Any]) -> None:
180 if model in self.restricted_objects:
181 self.restricted_objects[model] = {
182 field: items - objs
183 for field, items in self.restricted_objects[model].items()
184 }
185
186 def clear_restricted_objects_from_queryset(self, model: Any, qs: QuerySet) -> None:
187 if model in self.restricted_objects:
188 objs = set(
189 qs.filter(
190 id__in=[
191 obj.id
192 for objs in self.restricted_objects[model].values()
193 for obj in objs
194 ]
195 )
196 )
197 self.clear_restricted_objects_from_set(model, objs)
198
199 def can_fast_delete(self, objs: Any, from_field: Any = None) -> bool:
200 """
201 Determine if the objects in the given queryset-like or single object
202 can be fast-deleted. This can be done if there are no cascades, no
203 parents and no signal listeners for the object class.
204
205 The 'from_field' tells where we are coming from - we need this to
206 determine if the objects are in fact to be deleted. Allow also
207 skipping parent -> child -> parent chain preventing fast delete of
208 the child.
209 """
210 if from_field and from_field.remote_field.on_delete is not CASCADE:
211 return False
212 if hasattr(objs, "_model_meta"):
213 model = objs._model_meta.model
214 elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
215 model = objs.model
216 else:
217 return False
218
219 # The use of from_field comes from the need to avoid cascade back to
220 # parent when parent delete is cascading to child.
221 meta = model._model_meta
222 return (
223 # Foreign keys pointing to this model.
224 all(
225 related.field.remote_field.on_delete is DO_NOTHING
226 for related in get_candidate_relations_to_delete(meta)
227 )
228 )
229
230 def get_del_batches(self, objs: list[Any], fields: list[Field]) -> list[list[Any]]:
231 """
232 Return the objs in suitably sized batches for the used db_connection.
233 """
234 field_names = [field.name for field in fields]
235 conn_batch_size = max(
236 db_connection.ops.bulk_batch_size(field_names, objs),
237 1,
238 )
239 if len(objs) > conn_batch_size:
240 return [
241 objs[i : i + conn_batch_size]
242 for i in range(0, len(objs), conn_batch_size)
243 ]
244 else:
245 return [objs]
246
    def collect(
        self,
        objs: Iterable[Any],
        source: Any = None,
        nullable: bool = False,
        collect_related: bool = True,
        reverse_dependency: bool = False,
        fail_on_restricted: bool = True,
    ) -> None:
        """
        Add 'objs' to the collection of objects to be deleted. 'objs' must be
        a homogeneous iterable collection of model instances (e.g. a
        QuerySet). If 'collect_related' is True, related objects will be
        handled by their respective on_delete handler.

        If 'objs' can be fast-deleted, it is only queued in
        self.fast_deletes and nothing else is collected.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after.

        If 'fail_on_restricted' is False, error won't be raised even if it's
        prohibited to delete such objects due to RESTRICT, that defers
        restricted object checking in recursive calls where the top-level call
        may need to collect more objects to determine whether restricted ones
        can be deleted.

        Raise ProtectedError if any on_delete handler raised it for the
        collected objects, and RestrictedError if 'fail_on_restricted' is True
        and restricted objects remain that are not themselves being deleted.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(
            objs, source, nullable, reverse_dependency=reverse_dependency
        )
        # Nothing new was collected, so there are no new relations to follow.
        if not new_objs:
            return

        model = new_objs[0].__class__

        if not collect_related:
            return

        # {related_model: [fields]} for related models whose rows can be
        # deleted without being fetched.
        model_fast_deletes = defaultdict(list)
        # {"'Model.field'": [objects]} gathered from ProtectedError raised by
        # on_delete handlers, so that all of them are reported at once.
        protected_objects = defaultdict(list)
        for related in get_candidate_relations_to_delete(model._model_meta):
            field = related.field
            on_delete = field.remote_field.on_delete
            if on_delete == DO_NOTHING:
                continue
            related_model = related.related_model
            if self.can_fast_delete(related_model, from_field=field):
                model_fast_deletes[related_model].append(field)
                continue
            batches = self.get_del_batches(new_objs, [field])
            for batch in batches:
                sub_objs = self.related_objects(related_model, [field], batch)
                # Only the fields referenced by relations pointing at the
                # related model are loaded; everything else is deferred.
                # Skip field deferring when some relationships are
                # select_related as interactions between both features are
                # hard to get right. This should only happen in the rare
                # cases where .related_objects is overridden anyway.
                if not sub_objs.sql_query.select_related:
                    referenced_fields = set(
                        chain.from_iterable(
                            (rf.attname for rf in rel.field.foreign_related_fields)
                            for rel in get_candidate_relations_to_delete(
                                related_model._model_meta
                            )
                        )
                    )
                    sub_objs = sub_objs.only(*tuple(referenced_fields))
                # Handlers flagged lazy_sub_objs are called without testing
                # the truthiness of sub_objs first.
                if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
                    try:
                        on_delete(self, field, sub_objs)
                    except ProtectedError as error:
                        key = f"'{field.model.__name__}.{field.name}'"
                        protected_objects[key] += error.protected_objects
        if protected_objects:
            raise ProtectedError(
                "Cannot delete some instances of model {!r} because they are "
                "referenced through protected foreign keys: {}.".format(
                    model.__name__,
                    ", ".join(protected_objects),
                ),
                set(chain.from_iterable(protected_objects.values())),
            )
        # Queue one queryset per batch for every fast-deletable related model.
        for related_model, related_fields in model_fast_deletes.items():
            batches = self.get_del_batches(new_objs, related_fields)
            for batch in batches:
                sub_objs = self.related_objects(related_model, related_fields, batch)
                self.fast_deletes.append(sub_objs)

        if fail_on_restricted:
            # Raise an error if collected restricted objects (RESTRICT) aren't
            # candidates for deletion also collected via CASCADE.
            for related_model, instances in self.data.items():
                self.clear_restricted_objects_from_set(related_model, instances)
            for qs in self.fast_deletes:
                self.clear_restricted_objects_from_queryset(qs.model, qs)
            if self.restricted_objects.values():
                # {"'Model.field'": [objects]} for the restricted objects
                # that are still left after the clearing above.
                restricted_objects = defaultdict(list)
                for related_model, fields in self.restricted_objects.items():
                    for field, objs in fields.items():
                        if objs:
                            key = f"'{related_model.__name__}.{field.name}'"
                            restricted_objects[key] += objs
                if restricted_objects:
                    raise RestrictedError(
                        "Cannot delete some instances of model {!r} because "
                        "they are referenced through restricted foreign keys: "
                        "{}.".format(
                            model.__name__,
                            ", ".join(restricted_objects),
                        ),
                        set(chain.from_iterable(restricted_objects.values())),
                    )
366
367 def related_objects(
368 self, related_model: Any, related_fields: list[Field], objs: Iterable[Any]
369 ) -> QuerySet:
370 """
371 Get a QuerySet of the related model to objs via related fields.
372 """
373 predicate = query_utils.Q.create(
374 [(f"{related_field.name}__in", objs) for related_field in related_fields],
375 connector=query_utils.Q.OR,
376 )
377 return related_model._model_meta.base_queryset.filter(predicate)
378
379 def sort(self) -> None:
380 sorted_models = []
381 concrete_models = set()
382 models = list(self.data)
383 while len(sorted_models) < len(models):
384 found = False
385 for model in models:
386 if model in sorted_models:
387 continue
388 dependencies = self.dependencies.get(model)
389 if not (dependencies and dependencies.difference(concrete_models)):
390 sorted_models.append(model)
391 concrete_models.add(model)
392 found = True
393 if not found:
394 return
395 self.data = {model: self.data[model] for model in sorted_models}
396
397 def delete(self) -> tuple[int, dict[str, int]]:
398 # sort instance collections
399 for model, instances in self.data.items():
400 self.data[model] = sorted(instances, key=attrgetter("id"))
401
402 # if possible, bring the models in an order suitable for databases that
403 # don't support transactions or cannot defer constraint checks until the
404 # end of a transaction.
405 self.sort()
406 # number of objects deleted for each model label
407 deleted_counter = Counter()
408
409 # Optimize for the case with a single obj and no dependencies
410 if len(self.data) == 1 and len(instances) == 1:
411 instance = list(instances)[0]
412 if self.can_fast_delete(instance):
413 with transaction.mark_for_rollback_on_error():
414 count = sql.DeleteQuery(model).delete_batch([instance.id])
415 setattr(instance, model._model_meta.get_field("id").attname, None)
416 return count, {model.model_options.label: count}
417
418 with transaction.atomic(savepoint=False):
419 # fast deletes
420 for qs in self.fast_deletes:
421 count = qs._raw_delete()
422 if count:
423 deleted_counter[qs.model.model_options.label] += count
424
425 # update fields
426 for (field, value), instances_list in self.field_updates.items():
427 updates = []
428 objs = []
429 for instances in instances_list:
430 if (
431 isinstance(instances, QuerySet)
432 and instances._result_cache is None
433 ):
434 updates.append(instances)
435 else:
436 objs.extend(instances)
437 if updates:
438 combined_updates = reduce(or_, updates)
439 combined_updates.update(**{field.name: value})
440 if objs:
441 model = objs[0].__class__
442 query = sql.UpdateQuery(model)
443 query.update_batch(
444 list({obj.id for obj in objs}), {field.name: value}
445 )
446
447 # reverse instance collections
448 for instances in self.data.values():
449 instances.reverse()
450
451 # delete instances
452 for model, instances in self.data.items():
453 query = sql.DeleteQuery(model)
454 id_list = [obj.id for obj in instances]
455 count = query.delete_batch(id_list)
456 if count:
457 deleted_counter[model.model_options.label] += count
458
459 for model, instances in self.data.items():
460 for instance in instances:
461 setattr(instance, model._model_meta.get_field("id").attname, None)
462 return sum(deleted_counter.values()), dict(deleted_counter)