1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from collections.abc import Callable, Iterator, Sequence
11from functools import cached_property
12from itertools import chain, islice
13from typing import TYPE_CHECKING, Any, Never, Self, overload
14
15import plain.runtime
16from plain.exceptions import ValidationError
17from plain.postgres import transaction
18from plain.postgres.constants import LOOKUP_SEP, OnConflict
19from plain.postgres.db import (
20 PLAIN_VERSION_PICKLE_KEY,
21 IntegrityError,
22 get_connection,
23)
24from plain.postgres.exceptions import (
25 FieldDoesNotExist,
26 FieldError,
27 ObjectDoesNotExist,
28)
29from plain.postgres.expressions import Case, F, ResolvableExpression, Value, When
30from plain.postgres.fields import (
31 Field,
32 PrimaryKeyField,
33)
34from plain.postgres.functions import Cast
35from plain.postgres.query_utils import FilteredRelation, Q
36from plain.postgres.sql import (
37 AND,
38 CURSOR,
39 OR,
40 XOR,
41 DeleteQuery,
42 InsertQuery,
43 Query,
44 RawQuery,
45 UpdateQuery,
46)
47from plain.postgres.utils import resolve_callables
48from plain.utils.functional import partition
49
50# Re-exports for public API
51__all__ = ["F", "Q", "QuerySet", "RawQuerySet", "Prefetch", "FilteredRelation"]
52
53if TYPE_CHECKING:
54 from plain.postgres import Model
55
56# The maximum number of results to fetch in a get() query.
57MAX_GET_RESULTS = 21
58
59# The maximum number of items to display in a QuerySet.__repr__
60REPR_OUTPUT_SIZE = 20
61
62
class BaseIterable:
    """Common base for the iterables that turn compiler rows into results."""

    def __init__(
        self,
        queryset: "QuerySet[Any]",
        chunked_fetch: bool = False,
    ):
        # chunked_fetch=True asks the compiler to stream rows rather than
        # fetching them all at once.
        self.queryset = queryset
        self.chunked_fetch = chunked_fetch

    def __iter__(self) -> Iterator[Any]:
        raise NotImplementedError(
            "subclasses of BaseIterable must provide an __iter__() method"
        )
76
77
class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""

    def __iter__(self) -> Iterator[Model]:
        queryset = self.queryset
        compiler = queryset.sql_query.get_compiler()
        # Execute the query. This will also fill compiler.select, klass_info,
        # and annotations.
        results = compiler.execute_sql(chunked_fetch=self.chunked_fetch)
        select, klass_info, annotation_col_map = (
            compiler.select,
            compiler.klass_info,
            compiler.annotation_col_map,
        )
        # These are set by execute_sql() above
        assert select is not None
        assert klass_info is not None
        model_cls = klass_info["model"]
        select_fields = klass_info["select_fields"]
        # Contiguous span of row columns that belong to the model itself
        # (as opposed to select_related columns or annotation columns).
        model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
        init_list = [
            f[0].target.attname for f in select[model_fields_start:model_fields_end]
        ]
        # Populators hydrate select_related() objects from the same row.
        related_populators = get_related_populators(klass_info, select)
        known_related_objects = [
            (
                field,
                related_objs,
                operator.attrgetter(field.attname),
            )
            for field, related_objs in queryset._known_related_objects.items()
        ]
        for row in compiler.results_iter(results):
            obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
            for rel_populator in related_populators:
                rel_populator.populate(row, obj)
            if annotation_col_map:
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            # Add the known related objects to the model.
            for field, rel_objs, rel_getter in known_related_objects:
                # Avoid overwriting objects loaded by, e.g., select_related().
                if field.is_cached(obj):
                    continue
                rel_obj_id = rel_getter(obj)
                try:
                    rel_obj = rel_objs[rel_obj_id]
                except KeyError:
                    pass  # May happen in qs1 | qs2 scenarios.
                else:
                    setattr(obj, field.name, rel_obj)

            yield obj
132
133
class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.
    """

    queryset: RawQuerySet

    def __iter__(self) -> Iterator[Model]:
        from plain.postgres.sql.compiler import apply_converters, get_converters

        query = self.queryset.sql_query
        connection = get_connection()
        query_iterator: Iterator[Any] = iter(query)

        try:
            # Map the raw SELECT columns onto model constructor arguments
            # (names plus their positions in each result row).
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            assert model_cls is not None
            if "id" not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            # Pair each selected column with its model field (None for columns
            # that are not model fields) so value converters can be located.
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = get_converters(
                [
                    f.get_col(f.model.model_options.db_table) if f else None
                    for f in fields
                ],
                connection,
            )
            if converters:
                query_iterator = apply_converters(
                    query_iterator, converters, connection
                )
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(model_init_names, model_init_values)
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()
182
183
class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict for each row.
    """

    def __iter__(self) -> Iterator[dict[str, Any]]:
        query = self.queryset.sql_query
        compiler = query.get_compiler()

        # extra(select=...) cols are always at the start of the row.
        column_names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]
        for row in compiler.results_iter(chunked_fetch=self.chunked_fetch):
            # Pair each selected name with its positional value in the row.
            yield {name: row[pos] for pos, name in enumerate(column_names)}
203
204
class ValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
    for each row.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()

        if queryset._fields:
            # extra(select=...) cols are always at the start of the row.
            names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            # The caller-requested order, with any remaining annotation
            # columns appended at the end.
            fields = [
                *queryset._fields,
                *(f for f in query.annotation_select if f not in queryset._fields),
            ]
            if fields != names:
                # Reorder according to fields.
                index_map = {name: idx for idx, name in enumerate(names)}
                rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
                return map(
                    rowfactory,
                    compiler.results_iter(chunked_fetch=self.chunked_fetch),
                )
        # Row order already matches the requested fields; yield rows as-is.
        return iter(
            compiler.results_iter(
                tuple_expected=True,
                chunked_fetch=self.chunked_fetch,
            )
        )
241
242
class FlatValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=True) that yields single
    values.
    """

    def __iter__(self) -> Iterator[Any]:
        compiler = self.queryset.sql_query.get_compiler()
        rows = compiler.results_iter(chunked_fetch=self.chunked_fetch)
        # Each row is a 1-tuple; unwrap to yield the bare value.
        return map(operator.itemgetter(0), rows)
254
255
256class QuerySet[T: "Model"]:
257 """
258 Represent a lazy database lookup for a set of objects.
259
260 Usage:
261 MyModel.query.filter(name="test").all()
262
263 Custom QuerySets:
264 from typing import Self
265
266 class TaskQuerySet(QuerySet["Task"]):
267 def active(self) -> Self:
268 return self.filter(is_active=True)
269
270 class Task(Model):
271 is_active = BooleanField(default=True)
272 query = TaskQuerySet()
273
274 Task.query.active().filter(name="test") # Full type inference
275
276 Custom methods should return `Self` to preserve type through method chaining.
277 """
278
279 # Instance attributes (set in from_model())
280 model: type[T]
281 _query: Query
282 _result_cache: list[T] | None
283 _sticky_filter: bool
284 _for_write: bool
285 _prefetch_related_lookups: tuple[Any, ...]
286 _prefetch_done: bool
287 _known_related_objects: dict[Any, dict[Any, Any]]
288 _iterable_class: type[BaseIterable]
289 _fields: tuple[str, ...] | None
290 _defer_next_filter: bool
291 _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None
292
293 def __init__(self):
294 """Minimal init for descriptor mode. Use from_model() to create instances."""
295 pass
296
297 @classmethod
298 def from_model(cls, model: type[T], query: Query | None = None) -> Self:
299 """Create a QuerySet instance bound to a model."""
300 instance = cls()
301 instance.model = model
302 instance._query = query or Query(model)
303 instance._result_cache = None
304 instance._sticky_filter = False
305 instance._for_write = False
306 instance._prefetch_related_lookups = ()
307 instance._prefetch_done = False
308 instance._known_related_objects = {}
309 instance._iterable_class = ModelIterable
310 instance._fields = None
311 instance._defer_next_filter = False
312 instance._deferred_filter = None
313 return instance
314
315 @overload
316 def __get__(self, instance: None, owner: type[T]) -> Self: ...
317
318 @overload
319 def __get__(self, instance: Model, owner: type[T]) -> Never: ...
320
321 def __get__(self, instance: Any, owner: type[T]) -> Self:
322 """Descriptor protocol - return a new QuerySet bound to the model."""
323 if instance is not None:
324 raise AttributeError(
325 f"QuerySet is only accessible from the model class, not instances. "
326 f"Use {owner.__name__}.query instead."
327 )
328 return self.from_model(owner)
329
    @property
    def sql_query(self) -> Query:
        """The underlying sql.Query, applying any deferred filter first."""
        if self._deferred_filter:
            # A filter()/exclude() was postponed; apply it now that the
            # query object is actually needed.
            negate, args, kwargs = self._deferred_filter
            self._filter_or_exclude_inplace(negate, args, kwargs)
            self._deferred_filter = None
        return self._query
337
    @sql_query.setter
    def sql_query(self, value: Query) -> None:
        # A query with values_select produces dict rows, so the iterable
        # class must be switched to match.
        if value.values_select:
            self._iterable_class = ValuesIterable
        self._query = value
343
344 ########################
345 # PYTHON MAGIC METHODS #
346 ########################
347
348 def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
349 """Don't populate the QuerySet's cache."""
350 obj = self.__class__.from_model(self.model)
351 for k, v in self.__dict__.items():
352 if k == "_result_cache":
353 obj.__dict__[k] = None
354 else:
355 obj.__dict__[k] = copy.deepcopy(v, memo)
356 return obj
357
358 def __getstate__(self) -> dict[str, Any]:
359 # Force the cache to be fully populated.
360 self._fetch_all()
361 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
362
363 def __setstate__(self, state: dict[str, Any]) -> None:
364 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
365 if pickled_version:
366 if pickled_version != plain.runtime.__version__:
367 warnings.warn(
368 f"Pickled queryset instance's Plain version {pickled_version} does not "
369 f"match the current version {plain.runtime.__version__}.",
370 RuntimeWarning,
371 stacklevel=2,
372 )
373 else:
374 warnings.warn(
375 "Pickled queryset instance's Plain version is not specified.",
376 RuntimeWarning,
377 stacklevel=2,
378 )
379 self.__dict__.update(state)
380
381 def __repr__(self) -> str:
382 data = list(self[: REPR_OUTPUT_SIZE + 1])
383 if len(data) > REPR_OUTPUT_SIZE:
384 data[-1] = "...(remaining elements truncated)..."
385 return f"<{self.__class__.__name__} {data!r}>"
386
387 def __len__(self) -> int:
388 self._fetch_all()
389 assert self._result_cache is not None
390 return len(self._result_cache)
391
    def __iter__(self) -> Iterator[T]:
        """
        The queryset iterator protocol uses three nested iterators in the
        default case:
            1. sql.compiler.execute_sql()
               - Returns a flat iterable of rows: a list from fetchall()
                 for regular queries, or a streaming generator from
                 cursor.stream() when using .iterator().
            2. sql.compiler.results_iter()
               - Returns one row at a time. At this point the rows are still
                 just tuples. In some cases the return values are converted
                 to Python values at this location.
            3. self.iterator()
               - Responsible for turning the rows into model objects.
        """
        # Fully evaluate (and cache) the queryset, then iterate the cache.
        self._fetch_all()
        assert self._result_cache is not None
        return iter(self._result_cache)
410
411 def __bool__(self) -> bool:
412 self._fetch_all()
413 return bool(self._result_cache)
414
    @overload
    def __getitem__(self, k: int) -> T: ...

    @overload
    def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...

    def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
        """Retrieve an item or slice from the set of results."""
        if not isinstance(k, int | slice):
            raise TypeError(
                f"QuerySet indices must be integers or slices, not {type(k).__name__}."
            )
        # SQL LIMIT/OFFSET cannot express negative indexing without knowing
        # the total result count up front, so reject it outright.
        if (isinstance(k, int) and k < 0) or (
            isinstance(k, slice)
            and (
                (k.start is not None and k.start < 0)
                or (k.stop is not None and k.stop < 0)
            )
        ):
            raise ValueError("Negative indexing is not supported.")

        # Already evaluated: serve from the cache without touching the DB.
        if self._result_cache is not None:
            return self._result_cache[k]

        if isinstance(k, slice):
            # Return a new, still-lazy queryset restricted via LIMIT/OFFSET.
            qs = self._chain()
            if k.start is not None:
                start = int(k.start)
            else:
                start = None
            if k.stop is not None:
                stop = int(k.stop)
            else:
                stop = None
            qs.sql_query.set_limits(start, stop)
            # A step can't be pushed into SQL; evaluate and slice in Python.
            return list(qs)[:: k.step] if k.step else qs

        # Single index: fetch exactly one row via LIMIT/OFFSET.
        qs = self._chain()
        qs.sql_query.set_limits(k, k + 1)
        qs._fetch_all()
        assert qs._result_cache is not None  # _fetch_all guarantees this
        return qs._result_cache[0]
457
458 def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
459 return cls
460
    def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
        """Intersect two querysets by ANDing their filter conditions."""
        self._merge_sanity_check(other)
        # Anything ANDed with an empty queryset is empty.
        if isinstance(other, EmptyQuerySet):
            return other
        if isinstance(self, EmptyQuerySet):
            return self
        combined = self._chain()
        combined._merge_known_related_objects(other)
        combined.sql_query.combine(other.sql_query, AND)
        return combined
471
    def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
        """Union two querysets by ORing their filter conditions."""
        self._merge_sanity_check(other)
        # Anything ORed with an empty queryset is the other operand.
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        # A sliced query can't take more filters; fall back to an id__in
        # subquery against the model's base queryset.
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, OR)
        return combined
491
    def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
        """Symmetric difference: rows matching exactly one of the querysets."""
        self._merge_sanity_check(other)
        # XOR with an empty queryset is the other operand.
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        # A sliced query can't take more filters; fall back to an id__in
        # subquery against the model's base queryset.
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, XOR)
        return combined
511
512 ####################################
513 # METHODS THAT DO DATABASE QUERIES #
514 ####################################
515
    def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
        # Low-level iterable that turns compiler rows into results.
        iterable = self._iterable_class(
            self,
            chunked_fetch=use_chunked_fetch,
        )
        if not self._prefetch_related_lookups or chunk_size is None:
            yield from iterable
            return

        # With prefetches configured, consume results in chunks so each
        # chunk's related objects are fetched with a bounded set of queries.
        iterator = iter(iterable)
        while results := list(islice(iterator, chunk_size)):
            prefetch_related_objects(results, *self._prefetch_related_lookups)
            yield from results
529
    def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
        """
        An iterator over the results from applying this QuerySet to the
        database. chunk_size must be provided for QuerySets that prefetch
        related objects; it controls how many results are gathered before
        each round of prefetch queries. Results are streamed from the
        database using a server-side cursor.
        """
        if chunk_size is None:
            if self._prefetch_related_lookups:
                raise ValueError(
                    "chunk_size must be provided when using QuerySet.iterator() after "
                    "prefetch_related()."
                )
        elif chunk_size <= 0:
            raise ValueError("Chunk size must be strictly positive.")
        # PostgreSQL always supports server-side cursors for chunked fetches
        return self._iterator(use_chunked_fetch=True, chunk_size=chunk_size)
546
    def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """
        Return a dictionary containing the calculations (aggregation)
        over the current queryset.

        If args is present the expression is passed as a kwarg using
        the Aggregate object's default alias.
        """
        if self.sql_query.distinct_fields:
            raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
        self._validate_values_are_expressions(
            (*args, *kwargs.values()), method_name="aggregate"
        )
        for arg in args:
            # The default_alias property raises TypeError if default_alias
            # can't be set automatically or AttributeError if it isn't an
            # attribute.
            try:
                arg.default_alias
            except (AttributeError, TypeError):
                raise TypeError("Complex aggregates require an alias")
            kwargs[arg.default_alias] = arg

        # Run the aggregation on a chained copy so this queryset's own
        # query object is left untouched.
        return self.sql_query.chain().get_aggregation(kwargs)
571
572 def count(self) -> int:
573 """
574 Perform a SELECT COUNT() and return the number of records as an
575 integer.
576
577 If the QuerySet is already fully cached, return the length of the
578 cached results set to avoid multiple SELECT COUNT(*) calls.
579 """
580 if self._result_cache is not None:
581 return len(self._result_cache)
582
583 return self.sql_query.get_count()
584
    def get(self, *args: Any, **kwargs: Any) -> T:
        """
        Perform the query and return a single object matching the given
        keyword arguments.

        Raises the model's DoesNotExist when nothing matches and
        MultipleObjectsReturned when more than one row matches.
        """
        clone = self.filter(*args, **kwargs)
        # Ordering is irrelevant for a single-row lookup, so drop it unless
        # the query is sliced or uses DISTINCT ON fields.
        if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
            clone = clone.order_by()
        # Bound the fetch so a badly targeted get() can't pull a whole
        # table just to report how many duplicates matched.
        limit = MAX_GET_RESULTS
        clone.sql_query.set_limits(high=limit)
        num = len(clone)
        if num == 1:
            assert clone._result_cache is not None  # len() fetches results
            return clone._result_cache[0]
        if not num:
            raise self.model.DoesNotExist(
                f"{self.model.model_options.object_name} matching query does not exist."
            )
        raise self.model.MultipleObjectsReturned(
            "get() returned more than one {} -- it returned {}!".format(
                self.model.model_options.object_name,
                num if not limit or num < limit else "more than %s" % (limit - 1),
            )
        )
609
610 def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
611 """
612 Perform the query and return a single object matching the given
613 keyword arguments, or None if no object is found.
614 """
615 try:
616 return self.get(*args, **kwargs)
617 except self.model.DoesNotExist:
618 return None
619
620 def create(self, **kwargs: Any) -> T:
621 """
622 Create a new object with the given kwargs, saving it to the database
623 and returning the created object.
624 """
625 obj = self.model(**kwargs)
626 self._for_write = True
627 obj.save(force_insert=True)
628 return obj
629
    def _prepare_for_bulk_create(self, objs: list[T]) -> None:
        """Assign missing primary keys and prepare related fields before insert."""
        id_field = self.model._model_meta.get_forward_field("id")
        for obj in objs:
            if obj.id is None:
                # Populate new primary key values.
                obj.id = id_field.get_id_value_on_save(obj)
            obj._prepare_related_fields_for_save(operation_name="bulk_create")
637
638 def _check_bulk_create_options(
639 self,
640 update_conflicts: bool,
641 update_fields: list[Field] | None,
642 unique_fields: list[Field] | None,
643 ) -> OnConflict | None:
644 if update_conflicts:
645 if not update_fields:
646 raise ValueError(
647 "Fields that will be updated when a row insertion fails "
648 "on conflicts must be provided."
649 )
650 if not unique_fields:
651 raise ValueError(
652 "Unique fields that can trigger the upsert must be provided."
653 )
654 # Updating primary keys and non-concrete fields is forbidden.
655 from plain.postgres.fields.related import ManyToManyField
656
657 if any(
658 not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
659 ):
660 raise ValueError(
661 "bulk_create() can only be used with concrete fields in "
662 "update_fields."
663 )
664 if any(f.primary_key for f in update_fields):
665 raise ValueError(
666 "bulk_create() cannot be used with primary keys in update_fields."
667 )
668 if unique_fields:
669 from plain.postgres.fields.related import ManyToManyField
670
671 if any(
672 not f.concrete or isinstance(f, ManyToManyField)
673 for f in unique_fields
674 ):
675 raise ValueError(
676 "bulk_create() can only be used with concrete fields "
677 "in unique_fields."
678 )
679 return OnConflict.UPDATE
680 return None
681
    def bulk_create(
        self,
        objs: Sequence[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances. Primary keys are set on the objects
        via the PostgreSQL RETURNING clause. Multi-table models are not supported.
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        objs = list(objs)
        if not objs:
            return objs
        meta = self.model._model_meta
        # Resolve field names to Field objects for conflict handling.
        unique_fields_objs: list[Field] | None = None
        update_fields_objs: list[Field] | None = None
        if unique_fields:
            unique_fields_objs = [
                meta.get_forward_field(name) for name in unique_fields
            ]
        if update_fields:
            update_fields_objs = [
                meta.get_forward_field(name) for name in update_fields
            ]
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields_objs,
            unique_fields_objs,
        )
        self._for_write = True
        fields = meta.concrete_fields
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            # Objects with a preset primary key are inserted with it; the
            # rest rely on the database to generate one.
            objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
            if objs_with_id:
                returned_columns = self._batched_insert(
                    objs_with_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                id_field = meta.get_forward_field("id")
                for obj_with_id, results in zip(objs_with_id, returned_columns):
                    # The id is already set; only copy back other
                    # database-computed columns from RETURNING.
                    for result, field in zip(results, meta.db_returning_fields):
                        if field != id_field:
                            setattr(obj_with_id, field.attname, result)
                for obj_with_id in objs_with_id:
                    obj_with_id._state.adding = False
            if objs_without_id:
                # Exclude the primary key column so the database assigns it.
                fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
                returned_columns = self._batched_insert(
                    objs_without_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                if on_conflict is None:
                    # Without conflict handling every row is inserted, so
                    # RETURNING rows line up one-to-one with the objects.
                    assert len(returned_columns) == len(objs_without_id)
                for obj_without_id, results in zip(objs_without_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        setattr(obj_without_id, field.attname, result)
                    obj_without_id._state.adding = False

        return objs
756
    def bulk_update(
        self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
    ) -> int:
        """
        Update the given fields in each of the given objects in the database.

        Returns the total number of rows matched across all batches.
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")
        if not fields:
            raise ValueError("Field names must be given to bulk_update().")
        objs_tuple = tuple(objs)
        if any(obj.id is None for obj in objs_tuple):
            raise ValueError("All bulk_update() objects must have a primary key set.")
        fields_list = [
            self.model._model_meta.get_forward_field(name) for name in fields
        ]
        from plain.postgres.fields.related import ManyToManyField

        if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
            raise ValueError("bulk_update() can only be used with concrete fields.")
        if any(f.primary_key for f in fields_list):
            raise ValueError("bulk_update() cannot be used with primary key fields.")
        if not objs_tuple:
            return 0
        for obj in objs_tuple:
            obj._prepare_related_fields_for_save(
                operation_name="bulk_update", fields=fields_list
            )
        # PK is used twice in the resulting update query, once in the filter
        # and once in the WHEN. Each field will also have one CAST.
        self._for_write = True
        max_batch_size = len(objs_tuple)
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        batches = (
            objs_tuple[i : i + batch_size]
            for i in range(0, len(objs_tuple), batch_size)
        )
        # Build one UPDATE per batch: each field becomes a CASE expression
        # keyed on the primary key, so many rows update in a single query.
        updates = []
        for batch_objs in batches:
            update_kwargs = {}
            for field in fields_list:
                when_statements = []
                for obj in batch_objs:
                    attr = getattr(obj, field.attname)
                    if not isinstance(attr, ResolvableExpression):
                        attr = Value(attr, output_field=field)
                    when_statements.append(When(id=obj.id, then=attr))
                case_statement = Case(*when_statements, output_field=field)
                # PostgreSQL requires casted CASE in updates
                case_statement = Cast(case_statement, output_field=field)
                update_kwargs[field.attname] = case_statement
            updates.append(([obj.id for obj in batch_objs], update_kwargs))
        rows_updated = 0
        queryset = self._chain()
        with transaction.atomic(savepoint=False):
            for ids, update_kwargs in updates:
                rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
        return rows_updated
815
    def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the database.
                # The main thing we're concerned about is uniqueness failures,
                # but ValidationError could include other things too.
                # In all cases though it should be fine to try the get() again
                # and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:
                    pass
                # The retry failed too; surface the original create() error.
                raise
849
    def update_or_create(
        self,
        defaults: dict[str, Any] | None = None,
        create_defaults: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, updating one with defaults
        if it exists, otherwise create a new one. Optionally, an object can
        be created with different values than defaults by using
        create_defaults.
        Return a tuple (object, created), where created is a boolean
        specifying whether an object was created.
        """
        # Without explicit create_defaults, the same defaults feed both the
        # create path and the update path.
        if create_defaults is None:
            update_defaults = create_defaults = defaults or {}
        else:
            update_defaults = defaults or {}
        self._for_write = True
        with transaction.atomic():
            # Lock the row so that a concurrent update is blocked until
            # update_or_create() has performed its save.
            obj, created = self.select_for_update().get_or_create(
                create_defaults, **kwargs
            )
            if created:
                return obj, created
            for k, v in resolve_callables(update_defaults):
                setattr(obj, k, v)

            update_fields = set(update_defaults)
            concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
            # update_fields does not support non-concrete fields.
            if concrete_field_names.issuperset(update_fields):
                # Add fields which are set on pre_save(), e.g. auto_now fields.
                # This is to maintain backward compatibility as these fields
                # are not updated unless explicitly specified in the
                # update_fields list.
                for field in self.model._model_meta.local_concrete_fields:
                    if not (
                        field.primary_key or field.__class__.pre_save is Field.pre_save
                    ):
                        update_fields.add(field.name)
                        if field.name != field.attname:
                            update_fields.add(field.attname)
                obj.save(update_fields=update_fields)
            else:
                obj.save()
        return obj, False
899
900 def _extract_model_params(
901 self, defaults: dict[str, Any] | None, **kwargs: Any
902 ) -> dict[str, Any]:
903 """
904 Prepare `params` for creating a model instance based on the given
905 kwargs; for use by get_or_create().
906 """
907 defaults = defaults or {}
908 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
909 params.update(defaults)
910 property_names = self.model._model_meta._property_names
911 invalid_params = []
912 for param in params:
913 try:
914 self.model._model_meta.get_field(param)
915 except FieldDoesNotExist:
916 # It's okay to use a model's property if it has a setter.
917 if not (param in property_names and getattr(self.model, param).fset):
918 invalid_params.append(param)
919 if invalid_params:
920 raise FieldError(
921 "Invalid field name(s) for model {}: '{}'.".format(
922 self.model.model_options.object_name,
923 "', '".join(sorted(invalid_params)),
924 )
925 )
926 return params
927
928 def first(self) -> T | None:
929 """Return the first object of a query or None if no match is found."""
930 for obj in self[:1]:
931 return obj
932 return None
933
934 def last(self) -> T | None:
935 """Return the last object of a query or None if no match is found."""
936 queryset = self.reverse()
937 for obj in queryset[:1]:
938 return obj
939 return None
940
    def delete(self) -> tuple[int, dict[str, int]]:
        """
        Delete the records in the current QuerySet.

        Returns a tuple of (total rows deleted, per-model deletion counts).
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with delete().")
        if self.sql_query.distinct or self.sql_query.distinct_fields:
            raise TypeError("Cannot call delete() after .distinct().")
        if self._fields is not None:
            raise TypeError("Cannot call delete() after .values() or .values_list()")

        del_query = self._chain()

        # The delete is actually 2 queries - one to find related objects,
        # and one to delete. Make sure that the discovery of related
        # objects is performed on the same database as the deletion.
        del_query._for_write = True

        # Disable non-supported fields.
        del_query.sql_query.select_for_update = False
        del_query.sql_query.select_related = False
        del_query.sql_query.clear_ordering(force=True)

        from plain.postgres.deletion import Collector

        # The Collector walks relations so cascading deletes are applied.
        collector = Collector(origin=self)
        collector.collect(del_query)
        deleted, _rows_count = collector.delete()

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
        return deleted, _rows_count
971
972 def _raw_delete(self) -> int:
973 """
974 Delete objects found from the given queryset in single direct SQL
975 query. No signals are sent and there is no protection for cascades.
976 """
977 query = self.sql_query.clone()
978 query.__class__ = DeleteQuery
979 cursor = query.get_compiler().execute_sql(CURSOR)
980 if cursor:
981 with cursor:
982 return cursor.rowcount
983 return 0
984
    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.

        Returns the row count reported by the database cursor. Clears this
        queryset's result cache as a side effect.
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        self._for_write = True
        query = self.sql_query.chain(UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            if isinstance(alias, str) and alias.startswith("-"):
                # A leading "-" marks descending order; strip it so the
                # annotation lookup below matches the bare alias.
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                # Aggregates cannot be inlined into an UPDATE's ordering.
                if getattr(annotation, "contains_aggregate", False):
                    raise FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                # Not an annotation reference; keep the ordering term as-is.
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        self._result_cache = None
        return rows
1022
1023 def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
1024 """
1025 A version of update() that accepts field objects instead of field names.
1026 Used primarily for model saving and not intended for use by general
1027 code (it requires too much poking around at model internals to be
1028 useful at that level).
1029 """
1030 if self.sql_query.is_sliced:
1031 raise TypeError("Cannot update a query once a slice has been taken.")
1032 query = self.sql_query.chain(UpdateQuery)
1033 query.add_update_fields(values)
1034 # Clear any annotations so that they won't be present in subqueries.
1035 query.annotations = {}
1036 self._result_cache = None
1037 return query.get_compiler().execute_sql(CURSOR)
1038
1039 def exists(self) -> bool:
1040 """
1041 Return True if the QuerySet would have any results, False otherwise.
1042 """
1043 if self._result_cache is None:
1044 return self.sql_query.has_results()
1045 return bool(self._result_cache)
1046
1047 def _prefetch_related_objects(self) -> None:
1048 # This method can only be called once the result cache has been filled.
1049 assert self._result_cache is not None
1050 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1051 self._prefetch_done = True
1052
1053 def explain(self, *, format: str | None = None, **options: Any) -> str:
1054 """
1055 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1056 returns the results.
1057 """
1058 return self.sql_query.explain(format=format, **options)
1059
1060 ##################################################
1061 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1062 ##################################################
1063
1064 def raw(
1065 self,
1066 raw_query: str,
1067 params: Sequence[Any] = (),
1068 translations: dict[str, str] | None = None,
1069 ) -> RawQuerySet:
1070 qs = RawQuerySet(
1071 raw_query,
1072 model=self.model,
1073 params=tuple(params),
1074 translations=translations,
1075 )
1076 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1077 return qs
1078
1079 def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1080 clone = self._chain()
1081 if expressions:
1082 clone = clone.annotate(**expressions)
1083 clone._fields = fields
1084 clone.sql_query.set_values(list(fields))
1085 return clone
1086
1087 def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1088 fields += tuple(expressions)
1089 clone = self._values(*fields, **expressions)
1090 clone._iterable_class = ValuesIterable
1091 return clone
1092
    def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
        """
        Return a QuerySet that yields value rows (via ValuesListIterable)
        instead of model instances; flat=True switches to
        FlatValuesListIterable and only permits a single field.

        Expression arguments get generated aliases ("<name><counter>") that
        are guaranteed not to collide with the string field names passed
        alongside them.
        """
        if flat and len(fields) > 1:
            raise TypeError(
                "'flat' is not valid when values_list is called with more than one "
                "field."
            )

        # The plain-string field names, used below to avoid alias collisions.
        field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
        _fields = []
        expressions = {}
        counter = 1
        for field in fields:
            if isinstance(field, ResolvableExpression):
                # Derive an alias base from the expression, then bump the
                # shared counter until the alias is unique.
                field_id_prefix = getattr(
                    field, "default_alias", field.__class__.__name__.lower()
                )
                while True:
                    field_id = field_id_prefix + str(counter)
                    counter += 1
                    if field_id not in field_names:
                        break
                expressions[field_id] = field
                _fields.append(field_id)
            else:
                _fields.append(field)

        clone = self._values(*_fields, **expressions)
        clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
        return clone
1122
1123 def none(self) -> QuerySet[T]:
1124 """Return an empty QuerySet."""
1125 clone = self._chain()
1126 clone.sql_query.set_empty()
1127 return clone
1128
1129 ##################################################################
1130 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1131 ##################################################################
1132
1133 def all(self) -> Self:
1134 """
1135 Return a new QuerySet that is a copy of the current one. This allows a
1136 QuerySet to proxy for a model queryset in some cases.
1137 """
1138 obj = self._chain()
1139 # Preserve cache since all() doesn't modify the query.
1140 # This is important for prefetch_related() to work correctly.
1141 obj._result_cache = self._result_cache
1142 obj._prefetch_done = self._prefetch_done
1143 return obj
1144
1145 def filter(self, *args: Any, **kwargs: Any) -> Self:
1146 """
1147 Return a new QuerySet instance with the args ANDed to the existing
1148 set.
1149 """
1150 return self._filter_or_exclude(False, args, kwargs)
1151
1152 def exclude(self, *args: Any, **kwargs: Any) -> Self:
1153 """
1154 Return a new QuerySet instance with NOT (args) ANDed to the existing
1155 set.
1156 """
1157 return self._filter_or_exclude(True, args, kwargs)
1158
1159 def _filter_or_exclude(
1160 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1161 ) -> Self:
1162 if (args or kwargs) and self.sql_query.is_sliced:
1163 raise TypeError("Cannot filter a query once a slice has been taken.")
1164 clone = self._chain()
1165 if self._defer_next_filter:
1166 self._defer_next_filter = False
1167 clone._deferred_filter = negate, args, kwargs
1168 else:
1169 clone._filter_or_exclude_inplace(negate, args, kwargs)
1170 return clone
1171
1172 def _filter_or_exclude_inplace(
1173 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1174 ) -> None:
1175 if negate:
1176 self._query.add_q(~Q(*args, **kwargs))
1177 else:
1178 self._query.add_q(Q(*args, **kwargs))
1179
1180 def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
1181 """
1182 Return a new QuerySet instance with filter_obj added to the filters.
1183
1184 filter_obj can be a Q object or a dictionary of keyword lookup
1185 arguments.
1186
1187 This exists to support framework features such as 'limit_choices_to',
1188 and usually it will be more natural to use other methods.
1189 """
1190 if isinstance(filter_obj, Q):
1191 clone = self._chain()
1192 clone.sql_query.add_q(filter_obj)
1193 return clone
1194 else:
1195 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1196
1197 def select_for_update(
1198 self,
1199 nowait: bool = False,
1200 skip_locked: bool = False,
1201 of: tuple[str, ...] = (),
1202 no_key: bool = False,
1203 ) -> QuerySet[T]:
1204 """
1205 Return a new QuerySet instance that will select objects with a
1206 FOR UPDATE lock.
1207 """
1208 if nowait and skip_locked:
1209 raise ValueError("The nowait option cannot be used with skip_locked.")
1210 obj = self._chain()
1211 obj._for_write = True
1212 obj.sql_query.select_for_update = True
1213 obj.sql_query.select_for_update_nowait = nowait
1214 obj.sql_query.select_for_update_skip_locked = skip_locked
1215 obj.sql_query.select_for_update_of = of
1216 obj.sql_query.select_for_no_key_update = no_key
1217 return obj
1218
1219 def select_related(self, *fields: str | None) -> Self:
1220 """
1221 Return a new QuerySet instance that will select related objects.
1222
1223 If fields are specified, they must be ForeignKeyField fields and only those
1224 related objects are included in the selection.
1225
1226 If select_related(None) is called, clear the list.
1227 """
1228 if self._fields is not None:
1229 raise TypeError(
1230 "Cannot call select_related() after .values() or .values_list()"
1231 )
1232
1233 obj = self._chain()
1234 if fields == (None,):
1235 obj.sql_query.select_related = False
1236 elif fields:
1237 obj.sql_query.add_select_related(list(fields)) # type: ignore[arg-type]
1238 else:
1239 obj.sql_query.select_related = True
1240 return obj
1241
1242 def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
1243 """
1244 Return a new QuerySet instance that will prefetch the specified
1245 Many-To-One and Many-To-Many related objects when the QuerySet is
1246 evaluated.
1247
1248 When prefetch_related() is called more than once, append to the list of
1249 prefetch lookups. If prefetch_related(None) is called, clear the list.
1250 """
1251 clone = self._chain()
1252 if lookups == (None,):
1253 clone._prefetch_related_lookups = ()
1254 else:
1255 for lookup in lookups:
1256 lookup_str: str
1257 if isinstance(lookup, Prefetch):
1258 lookup_str = lookup.prefetch_to
1259 else:
1260 assert isinstance(lookup, str)
1261 lookup_str = lookup
1262 lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
1263 if lookup_str in self.sql_query._filtered_relations:
1264 raise ValueError(
1265 "prefetch_related() is not supported with FilteredRelation."
1266 )
1267 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1268 return clone
1269
1270 def annotate(self, *args: Any, **kwargs: Any) -> Self:
1271 """
1272 Return a query set in which the returned objects have been annotated
1273 with extra data or aggregations.
1274 """
1275 return self._annotate(args, kwargs, select=True)
1276
1277 def alias(self, *args: Any, **kwargs: Any) -> Self:
1278 """
1279 Return a query set with added aliases for extra data or aggregations.
1280 """
1281 return self._annotate(args, kwargs, select=False)
1282
    def _annotate(
        self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
    ) -> Self:
        """
        Shared implementation for annotate() (select=True) and alias()
        (select=False).

        Positional expressions must provide a usable `default_alias`.
        Raises ValueError when an alias collides with a model field or with
        another annotation's default name.
        """
        self._validate_values_are_expressions(
            args + tuple(kwargs.values()), method_name="annotate"
        )
        annotations = {}
        for arg in args:
            # The default_alias property may raise a TypeError.
            try:
                if arg.default_alias in kwargs:
                    raise ValueError(
                        f"The named annotation '{arg.default_alias}' conflicts with the "
                        "default name for another annotation."
                    )
            except TypeError:
                raise TypeError("Complex annotations require an alias")
            annotations[arg.default_alias] = arg
        annotations.update(kwargs)

        clone = self._chain()
        # Names the new aliases must not clash with: the explicit
        # values()/values_list() selection if one is set, otherwise every
        # field name/attname on the model.
        names = self._fields
        if names is None:
            names = set(
                chain.from_iterable(
                    (field.name, field.attname)
                    if hasattr(field, "attname")
                    else (field.name,)
                    for field in self.model._model_meta.get_fields()
                )
            )

        for alias, annotation in annotations.items():
            if alias in names:
                raise ValueError(
                    f"The annotation '{alias}' conflicts with a field on the model."
                )
            if isinstance(annotation, FilteredRelation):
                clone.sql_query.add_filtered_relation(annotation, alias)
            else:
                clone.sql_query.add_annotation(
                    annotation,
                    alias,
                    select=select,
                )
        # If any newly added annotation aggregates, the query needs grouping.
        for alias, annotation in clone.sql_query.annotations.items():
            if alias in annotations and annotation.contains_aggregate:
                if clone._fields is None:
                    clone.sql_query.group_by = True
                else:
                    clone.sql_query.set_group_by()
                break

        return clone
1337
1338 def order_by(self, *field_names: str) -> Self:
1339 """Return a new QuerySet instance with the ordering changed."""
1340 if self.sql_query.is_sliced:
1341 raise TypeError("Cannot reorder a query once a slice has been taken.")
1342 obj = self._chain()
1343 obj.sql_query.clear_ordering(force=True, clear_default=False)
1344 obj.sql_query.add_ordering(*field_names)
1345 return obj
1346
1347 def distinct(self, *field_names: str) -> Self:
1348 """
1349 Return a new QuerySet instance that will select only distinct results.
1350 """
1351 if self.sql_query.is_sliced:
1352 raise TypeError(
1353 "Cannot create distinct fields once a slice has been taken."
1354 )
1355 obj = self._chain()
1356 obj.sql_query.add_distinct_fields(*field_names)
1357 return obj
1358
1359 def extra(
1360 self,
1361 select: dict[str, str] | None = None,
1362 where: list[str] | None = None,
1363 params: list[Any] | None = None,
1364 tables: list[str] | None = None,
1365 order_by: list[str] | None = None,
1366 select_params: list[Any] | None = None,
1367 ) -> QuerySet[T]:
1368 """Add extra SQL fragments to the query."""
1369 if self.sql_query.is_sliced:
1370 raise TypeError("Cannot change a query once a slice has been taken.")
1371 clone = self._chain()
1372 clone.sql_query.add_extra(
1373 select or {},
1374 select_params,
1375 where or [],
1376 params or [],
1377 tables or [],
1378 tuple(order_by) if order_by else (),
1379 )
1380 return clone
1381
1382 def reverse(self) -> QuerySet[T]:
1383 """Reverse the ordering of the QuerySet."""
1384 if self.sql_query.is_sliced:
1385 raise TypeError("Cannot reverse a query once a slice has been taken.")
1386 clone = self._chain()
1387 clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
1388 return clone
1389
1390 def defer(self, *fields: str | None) -> QuerySet[T]:
1391 """
1392 Defer the loading of data for certain fields until they are accessed.
1393 Add the set of deferred fields to any existing set of deferred fields.
1394 The only exception to this is if None is passed in as the only
1395 parameter, in which case removal all deferrals.
1396 """
1397 if self._fields is not None:
1398 raise TypeError("Cannot call defer() after .values() or .values_list()")
1399 clone = self._chain()
1400 if fields == (None,):
1401 clone.sql_query.clear_deferred_loading()
1402 else:
1403 clone.sql_query.add_deferred_loading(frozenset(fields))
1404 return clone
1405
1406 def only(self, *fields: str) -> QuerySet[T]:
1407 """
1408 Essentially, the opposite of defer(). Only the fields passed into this
1409 method and that are not already specified as deferred are loaded
1410 immediately when the queryset is evaluated.
1411 """
1412 if self._fields is not None:
1413 raise TypeError("Cannot call only() after .values() or .values_list()")
1414 if fields == (None,):
1415 # Can only pass None to defer(), not only(), as the rest option.
1416 # That won't stop people trying to do this, so let's be explicit.
1417 raise TypeError("Cannot pass None as an argument to only().")
1418 for field in fields:
1419 field = field.split(LOOKUP_SEP, 1)[0]
1420 if field in self.sql_query._filtered_relations:
1421 raise ValueError("only() is not supported with FilteredRelation.")
1422 clone = self._chain()
1423 clone.sql_query.add_immediate_loading(set(fields))
1424 return clone
1425
1426 ###################################
1427 # PUBLIC INTROSPECTION ATTRIBUTES #
1428 ###################################
1429
1430 @property
1431 def ordered(self) -> bool:
1432 """
1433 Return True if the QuerySet is ordered -- i.e. has an order_by()
1434 clause or a default ordering on the model (or is empty).
1435 """
1436 if isinstance(self, EmptyQuerySet):
1437 return True
1438 if self.sql_query.extra_order_by or self.sql_query.order_by:
1439 return True
1440 elif (
1441 self.sql_query.default_ordering
1442 and self.sql_query.model
1443 and self.sql_query.model._model_meta.ordering # type: ignore[arg-type]
1444 and
1445 # A default ordering doesn't affect GROUP BY queries.
1446 not self.sql_query.group_by
1447 ):
1448 return True
1449 else:
1450 return False
1451
1452 ###################
1453 # PRIVATE METHODS #
1454 ###################
1455
1456 def _insert(
1457 self,
1458 objs: list[T],
1459 fields: list[Field],
1460 returning_fields: list[Field] | None = None,
1461 raw: bool = False,
1462 on_conflict: OnConflict | None = None,
1463 update_fields: list[Field] | None = None,
1464 unique_fields: list[Field] | None = None,
1465 ) -> list[tuple[Any, ...]] | None:
1466 """
1467 Insert a new record for the given model. This provides an interface to
1468 the InsertQuery class and is how Model.save() is implemented.
1469 """
1470 self._for_write = True
1471 query = InsertQuery(
1472 self.model,
1473 on_conflict=on_conflict if on_conflict else None,
1474 update_fields=update_fields,
1475 unique_fields=unique_fields,
1476 )
1477 query.insert_values(fields, objs, raw=raw)
1478 # InsertQuery returns SQLInsertCompiler which has different execute_sql signature
1479 return query.get_compiler().execute_sql(returning_fields)
1480
1481 def _batched_insert(
1482 self,
1483 objs: list[T],
1484 fields: list[Field],
1485 batch_size: int | None,
1486 on_conflict: OnConflict | None = None,
1487 update_fields: list[Field] | None = None,
1488 unique_fields: list[Field] | None = None,
1489 ) -> list[tuple[Any, ...]]:
1490 """
1491 Helper method for bulk_create() to insert objs one batch at a time.
1492 """
1493 max_batch_size = max(len(objs), 1)
1494 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1495 inserted_rows = []
1496 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1497 if on_conflict is None:
1498 inserted_rows.extend(
1499 self._insert( # type: ignore[arg-type]
1500 item,
1501 fields=fields,
1502 returning_fields=self.model._model_meta.db_returning_fields,
1503 )
1504 )
1505 else:
1506 self._insert(
1507 item,
1508 fields=fields,
1509 on_conflict=on_conflict,
1510 update_fields=update_fields,
1511 unique_fields=unique_fields,
1512 )
1513 return inserted_rows
1514
1515 def _chain(self) -> Self:
1516 """
1517 Return a copy of the current QuerySet that's ready for another
1518 operation.
1519 """
1520 obj = self._clone()
1521 if obj._sticky_filter:
1522 obj.sql_query.filter_is_sticky = True
1523 obj._sticky_filter = False
1524 return obj
1525
1526 def _clone(self) -> Self:
1527 """
1528 Return a copy of the current QuerySet. A lightweight alternative
1529 to deepcopy().
1530 """
1531 c = self.__class__.from_model(
1532 model=self.model,
1533 query=self.sql_query.chain(),
1534 )
1535 c._sticky_filter = self._sticky_filter
1536 c._for_write = self._for_write
1537 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1538 c._known_related_objects = self._known_related_objects
1539 c._iterable_class = self._iterable_class
1540 c._fields = self._fields
1541 return c
1542
1543 def _fetch_all(self) -> None:
1544 if self._result_cache is None:
1545 self._result_cache = list(self._iterable_class(self))
1546 if self._prefetch_related_lookups and not self._prefetch_done:
1547 self._prefetch_related_objects()
1548
1549 def _next_is_sticky(self) -> QuerySet[T]:
1550 """
1551 Indicate that the next filter call and the one following that should
1552 be treated as a single filter. This is only important when it comes to
1553 determining when to reuse tables for many-to-many filters. Required so
1554 that we can filter naturally on the results of related managers.
1555
1556 This doesn't return a clone of the current QuerySet (it returns
1557 "self"). The method is only used internally and should be immediately
1558 followed by a filter() that does create a clone.
1559 """
1560 self._sticky_filter = True
1561 return self
1562
1563 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1564 """Check that two QuerySet classes may be merged."""
1565 if self._fields is not None and (
1566 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1567 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1568 or set(self.sql_query.annotation_select)
1569 != set(other.sql_query.annotation_select)
1570 ):
1571 raise TypeError(
1572 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1573 )
1574
1575 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1576 """
1577 Keep track of all known related objects from either QuerySet instance.
1578 """
1579 for field, objects in other._known_related_objects.items():
1580 self._known_related_objects.setdefault(field, {}).update(objects)
1581
1582 def resolve_expression(self, *args: Any, **kwargs: Any) -> Query:
1583 if self._fields and len(self._fields) > 1:
1584 # values() queryset can only be used as nested queries
1585 # if they are set up to select only a single field.
1586 raise TypeError("Cannot use multi-field values as a filter value.")
1587 query = self.sql_query.resolve_expression(*args, **kwargs)
1588 return query
1589
1590 def _has_filters(self) -> bool:
1591 """
1592 Check if this QuerySet has any filtering going on. This isn't
1593 equivalent with checking if all objects are present in results, for
1594 example, qs[1:]._has_filters() -> False.
1595 """
1596 return self.sql_query.has_filters()
1597
1598 @staticmethod
1599 def _validate_values_are_expressions(
1600 values: tuple[Any, ...], method_name: str
1601 ) -> None:
1602 invalid_args = sorted(
1603 str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
1604 )
1605 if invalid_args:
1606 raise TypeError(
1607 "QuerySet.{}() received non-expression(s): {}.".format(
1608 method_name,
1609 ", ".join(invalid_args),
1610 )
1611 )
1612
1613
class InstanceCheckMeta(type):
    """Metaclass whose instance check matches only empty QuerySets."""

    def __instancecheck__(self, instance: object) -> bool:
        # True only for QuerySet objects whose underlying query is empty.
        if not isinstance(instance, QuerySet):
            return False
        return instance.sql_query.is_empty()
1617
1618
class EmptyQuerySet(metaclass=InstanceCheckMeta):
    """
    Marker type for detecting querysets emptied via .none():
    isinstance(qs.none(), EmptyQuerySet) -> True

    It exists purely for that instance check and can never be constructed.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        raise TypeError("EmptyQuerySet can't be instantiated")
1627
1628
class RawQuerySet:
    """
    Provide an iterator which converts the results of raw SQL queries into
    annotated model instances.

    Mirrors a small part of the QuerySet API (iteration, len(), bool(),
    indexing, prefetch_related()) on top of a RawQuery.
    """

    def __init__(
        self,
        raw_query: str,
        model: type[Model] | None = None,
        query: RawQuery | None = None,
        params: tuple[Any, ...] = (),
        translations: dict[str, str] | None = None,
    ):
        # Original SQL text, kept so the queryset can be cloned.
        self.raw_query = raw_query
        self.model = model
        # Wrap the SQL in a RawQuery unless one was supplied.
        self.sql_query = query or RawQuery(sql=raw_query, params=params)
        self.params = params
        # Maps query column names to model field names (see `columns`).
        self.translations = translations or {}
        self._result_cache: list[Model] | None = None
        self._prefetch_related_lookups: tuple[Any, ...] = ()
        self._prefetch_done = False

    def resolve_model_init_order(
        self,
    ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
        """Resolve the init field names and value positions."""
        model = self.model
        assert model is not None
        # Model fields that actually appear among the query's columns.
        model_init_fields = [
            f for f in model._model_meta.fields if f.column in self.columns
        ]
        # Columns with no matching model field are treated as annotations,
        # recorded together with their position in the result row.
        annotation_fields = [
            (column, pos)
            for pos, column in enumerate(self.columns)
            if column not in self.model_fields
        ]
        model_init_order = [self.columns.index(f.column) for f in model_init_fields]
        model_init_names = [f.attname for f in model_init_fields]
        return model_init_names, model_init_order, annotation_fields

    def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
        """Same as QuerySet.prefetch_related()"""
        clone = self._clone()
        if lookups == (None,):
            # prefetch_related(None) clears all pending lookups.
            clone._prefetch_related_lookups = ()
        else:
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

    def _prefetch_related_objects(self) -> None:
        # Requires the result cache to be populated first.
        assert self._result_cache is not None
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def _clone(self) -> RawQuerySet:
        """Same as QuerySet._clone()"""
        c = self.__class__(
            self.raw_query,
            model=self.model,
            query=self.sql_query,
            params=self.params,
            translations=self.translations,
        )
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return c

    def _fetch_all(self) -> None:
        # Materialize results once, then run any pending prefetches.
        if self._result_cache is None:
            self._result_cache = list(self.iterator())
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    def __len__(self) -> int:
        self._fetch_all()
        assert self._result_cache is not None
        return len(self._result_cache)

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    def __iter__(self) -> Iterator[Model]:
        self._fetch_all()
        assert self._result_cache is not None
        return iter(self._result_cache)

    def iterator(self) -> Iterator[Model]:
        # Stream rows without populating the result cache.
        yield from RawModelIterable(self)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.sql_query}>"

    def __getitem__(self, k: int | slice) -> Model | list[Model]:
        # NOTE: indexing or slicing evaluates the entire queryset first.
        return list(self)[k]

    @cached_property
    def columns(self) -> list[str]:
        """
        A list of model field names in the order they'll appear in the
        query results.
        """
        columns = self.sql_query.get_columns()
        # Adjust any column names which don't match field names
        for query_name, model_name in self.translations.items():
            # Ignore translations for nonexistent column names
            try:
                index = columns.index(query_name)
            except ValueError:
                pass
            else:
                columns[index] = model_name
        return columns

    @cached_property
    def model_fields(self) -> dict[str, Field]:
        """A dict mapping column names to model field names."""
        model_fields = {}
        model = self.model
        assert model is not None
        for field in model._model_meta.fields:
            model_fields[field.column] = field
        return model_fields
1752
1753
class Prefetch:
    """
    Describe a single prefetch_related() lookup, optionally with a custom
    queryset and a destination attribute name (to_attr).
    """

    def __init__(
        self,
        lookup: str,
        queryset: QuerySet[Any] | None = None,
        to_attr: str | None = None,
    ):
        # `prefetch_through` is the path we traverse to perform the prefetch.
        self.prefetch_through = lookup
        # `prefetch_to` is the path to the attribute that stores the result.
        self.prefetch_to = lookup
        if queryset is not None:
            unusable = isinstance(queryset, RawQuerySet) or (
                hasattr(queryset, "_iterable_class")
                and not issubclass(queryset._iterable_class, ModelIterable)
            )
            if unusable:
                raise ValueError(
                    "Prefetch querysets cannot use raw(), values(), and values_list()."
                )
        if to_attr:
            # Replace the final path segment with the destination attribute.
            segments = lookup.split(LOOKUP_SEP)
            self.prefetch_to = LOOKUP_SEP.join(segments[:-1] + [to_attr])

        self.queryset = queryset
        self.to_attr = to_attr

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        if self.queryset is not None:
            pickled_qs = self.queryset._chain()
            # Mark the clone as already evaluated so pickling never triggers
            # a fetch.
            pickled_qs._result_cache = []
            pickled_qs._prefetch_done = True
            state["queryset"] = pickled_qs
        return state

    def add_prefix(self, prefix: str) -> None:
        # Re-root both paths underneath the given prefix.
        self.prefetch_through = LOOKUP_SEP.join((prefix, self.prefetch_through))
        self.prefetch_to = LOOKUP_SEP.join((prefix, self.prefetch_to))

    def get_current_prefetch_to(self, level: int) -> str:
        segments = self.prefetch_to.split(LOOKUP_SEP)
        return LOOKUP_SEP.join(segments[: level + 1])

    def get_current_to_attr(self, level: int) -> tuple[str, bool]:
        segments = self.prefetch_to.split(LOOKUP_SEP)
        is_final_level = level == len(segments) - 1
        return segments[level], bool(self.to_attr and is_final_level)

    def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
        if self.get_current_prefetch_to(level) != self.prefetch_to:
            return None
        return self.queryset

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Prefetch):
            return self.prefetch_to == other.prefetch_to
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.__class__, self.prefetch_to))
1818
1819
def normalize_prefetch_lookups(
    lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
    prefix: str | None = None,
) -> list[Prefetch]:
    """Coerce each lookup to a Prefetch, optionally re-rooted under prefix."""
    normalized = []
    for item in lookups:
        prefetch = item if isinstance(item, Prefetch) else Prefetch(item)
        if prefix:
            prefetch.add_prefix(prefix)
        normalized.append(prefetch)
    return normalized
1833
1834
1835def prefetch_related_objects(
1836 model_instances: Sequence[Model], *related_lookups: str | Prefetch
1837) -> None:
1838 """
1839 Populate prefetched object caches for a list of model instances based on
1840 the lookups/Prefetch instances given.
1841 """
1842 if not model_instances:
1843 return # nothing to do
1844
1845 # We need to be able to dynamically add to the list of prefetch_related
1846 # lookups that we look up (see below). So we need some book keeping to
1847 # ensure we don't do duplicate work.
1848 done_queries = {} # dictionary of things like 'foo__bar': [results]
1849
1850 auto_lookups = set() # we add to this as we go through.
1851 followed_descriptors = set() # recursion protection
1852
1853 all_lookups = normalize_prefetch_lookups(reversed(related_lookups)) # type: ignore[arg-type]
1854 while all_lookups:
1855 lookup = all_lookups.pop()
1856 if lookup.prefetch_to in done_queries:
1857 if lookup.queryset is not None:
1858 raise ValueError(
1859 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
1860 "You may need to adjust the ordering of your lookups."
1861 )
1862
1863 continue
1864
1865 # Top level, the list of objects to decorate is the result cache
1866 # from the primary QuerySet. It won't be for deeper levels.
1867 obj_list = model_instances
1868
1869 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
1870 for level, through_attr in enumerate(through_attrs):
1871 # Prepare main instances
1872 if not obj_list:
1873 break
1874
1875 prefetch_to = lookup.get_current_prefetch_to(level)
1876 if prefetch_to in done_queries:
1877 # Skip any prefetching, and any object preparation
1878 obj_list = done_queries[prefetch_to]
1879 continue
1880
1881 # Prepare objects:
1882 good_objects = True
1883 for obj in obj_list:
1884 # Since prefetching can re-use instances, it is possible to have
1885 # the same instance multiple times in obj_list, so obj might
1886 # already be prepared.
1887 if not hasattr(obj, "_prefetched_objects_cache"):
1888 try:
1889 obj._prefetched_objects_cache = {}
1890 except (AttributeError, TypeError):
1891 # Must be an immutable object from
1892 # values_list(flat=True), for example (TypeError) or
1893 # a QuerySet subclass that isn't returning Model
1894 # instances (AttributeError), either in Plain or a 3rd
1895 # party. prefetch_related() doesn't make sense, so quit.
1896 good_objects = False
1897 break
1898 if not good_objects:
1899 break
1900
1901 # Descend down tree
1902
1903 # We assume that objects retrieved are homogeneous (which is the premise
1904 # of prefetch_related), so what applies to first object applies to all.
1905 first_obj = obj_list[0]
1906 to_attr = lookup.get_current_to_attr(level)[0]
1907 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
1908 first_obj, through_attr, to_attr
1909 )
1910
1911 if not attr_found:
1912 raise AttributeError(
1913 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
1914 "parameter to prefetch_related()"
1915 )
1916
1917 if level == len(through_attrs) - 1 and prefetcher is None:
1918 # Last one, this *must* resolve to something that supports
1919 # prefetching, otherwise there is no point adding it and the
1920 # developer asking for it has made a mistake.
1921 raise ValueError(
1922 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
1923 "prefetching - this is an invalid parameter to "
1924 "prefetch_related()."
1925 )
1926
1927 obj_to_fetch = None
1928 if prefetcher is not None:
1929 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
1930
1931 if obj_to_fetch:
1932 obj_list, additional_lookups = prefetch_one_level(
1933 obj_to_fetch,
1934 prefetcher,
1935 lookup,
1936 level,
1937 )
1938 # We need to ensure we don't keep adding lookups from the
1939 # same relationships to stop infinite recursion. So, if we
1940 # are already on an automatically added lookup, don't add
1941 # the new lookups from relationships we've seen already.
1942 if not (
1943 prefetch_to in done_queries
1944 and lookup in auto_lookups
1945 and descriptor in followed_descriptors
1946 ):
1947 done_queries[prefetch_to] = obj_list
1948 new_lookups = normalize_prefetch_lookups(
1949 reversed(additional_lookups), # type: ignore[arg-type]
1950 prefetch_to,
1951 )
1952 auto_lookups.update(new_lookups)
1953 all_lookups.extend(new_lookups)
1954 followed_descriptors.add(descriptor)
1955 else:
1956 # Either a singly related object that has already been fetched
1957 # (e.g. via select_related), or hopefully some other property
1958 # that doesn't support prefetching but needs to be traversed.
1959
1960 # We replace the current list of parent objects with the list
1961 # of related objects, filtering out empty or missing values so
1962 # that we can continue with nullable or reverse relations.
1963 new_obj_list = []
1964 for obj in obj_list:
1965 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
1966 # If related objects have been prefetched, use the
1967 # cache rather than the object's through_attr.
1968 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # type: ignore[arg-type]
1969 else:
1970 try:
1971 new_obj = getattr(obj, through_attr)
1972 except ObjectDoesNotExist:
1973 continue
1974 if new_obj is None:
1975 continue
1976 # We special-case `list` rather than something more generic
1977 # like `Iterable` because we don't want to accidentally match
1978 # user models that define __iter__.
1979 if isinstance(new_obj, list):
1980 new_obj_list.extend(new_obj)
1981 else:
1982 new_obj_list.append(new_obj)
1983 obj_list = new_obj_list
1984
1985
def get_prefetcher(
    instance: Model, through_attr: str, to_attr: str
) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
    """
    Locate an object capable of prefetching 'through_attr' on `instance`.

    Return a 4-tuple of:
    - the object exposing get_prefetch_queryset(), or None,
    - the descriptor object representing the relationship, or None,
    - False if the attribute was not found at all, True otherwise,
    - a callable taking an instance and returning True when the attribute
      has already been fetched for that instance.
    """

    def has_to_attr(obj: Model) -> bool:
        # Default "already fetched" check: the target attribute exists.
        return hasattr(obj, to_attr)

    prefetcher = None
    is_fetched: Callable[[Model], bool] = has_to_attr

    # Look the attribute up on the class first: for singly related objects,
    # reading it off the instance would trigger a database query.
    descriptor = getattr(instance.__class__, through_attr, None)
    attr_found = descriptor is not None or hasattr(instance, through_attr)
    if descriptor:
        if hasattr(descriptor, "get_prefetch_queryset"):
            # Singly related object: the descriptor itself knows how to
            # prefetch and whether the value is already cached.
            prefetcher = descriptor
            is_fetched = descriptor.is_cached
        else:
            # The descriptor can't prefetch, so read the attribute from the
            # instance instead — this supports many-related managers.
            rel_obj = getattr(instance, through_attr)
            if hasattr(rel_obj, "get_prefetch_queryset"):
                prefetcher = rel_obj
            if through_attr != to_attr:
                to_attr_class_val = getattr(instance.__class__, to_attr, None)
                if isinstance(to_attr_class_val, cached_property):

                    def has_cached_value(obj: Model) -> bool:
                        # hasattr() would compute and assign a cached_property,
                        # so inspect the instance dict directly.
                        return to_attr in obj.__dict__

                    is_fetched = has_cached_value
                else:

                    def has_prefetched_value(obj: Model) -> bool:
                        return through_attr in obj._prefetched_objects_cache

                    is_fetched = has_prefetched_value
    return prefetcher, descriptor, attr_found, is_fetched
2045
2046
def prefetch_one_level(
    instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
) -> tuple[list[Model], list[Prefetch]]:
    """
    Helper function for prefetch_related_objects().

    Run prefetches on all instances using the prefetcher object,
    assigning results to relevant caches in instance.

    Args:
        instances: Homogeneous model instances to attach results to.
        prefetcher: Object exposing get_prefetch_queryset() — a descriptor or
            related manager located by get_prefetcher().
        lookup: The Prefetch describing the full lookup path.
        level: Index of the segment of the lookup path currently fetched.

    Return the prefetched objects along with any additional prefetches that
    must be done due to prefetch_related lookups found from default managers.
    """
    # prefetcher must have a method get_prefetch_queryset() which takes a list
    # of instances, and returns a tuple:

    # (queryset of instances of self.model that are related to passed in instances,
    #  callable that gets value to be matched for returned instances,
    #  callable that gets value to be matched for passed in instances,
    #  boolean that is True for singly related objects,
    #  cache or field name to assign to,
    #  boolean that is True when the previous argument is a cache name vs a field name).

    # The 'values to be matched' must be hashable as they will be used
    # in a dictionary.

    (
        rel_qs,
        rel_obj_attr,
        instance_attr,
        single,
        cache_name,
        is_descriptor,
    ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
    # We have to handle the possibility that the QuerySet we just got back
    # contains some prefetch_related lookups. We don't want to trigger the
    # prefetch_related functionality by evaluating the query. Rather, we need
    # to merge in the prefetch_related lookups.
    # Copy the lookups in case it is a Prefetch object which could be reused
    # later (happens in nested prefetch_related).
    additional_lookups = [
        copy.copy(additional_lookup)
        for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
    ]
    if additional_lookups:
        # Don't need to clone because the queryset should have given us a fresh
        # instance, so we access an internal instead of using public interface
        # for performance reasons. This must happen BEFORE rel_qs is evaluated
        # below, or evaluation would recursively run these prefetches itself.
        rel_qs._prefetch_related_lookups = ()

    # Evaluate the related queryset once; all matching below works off this list.
    all_related_objects = list(rel_qs)

    # Group related objects by the value they match against, so each instance
    # can look up its related objects in O(1).
    rel_obj_cache = {}
    for rel_obj in all_related_objects:
        rel_attr_val = rel_obj_attr(rel_obj)
        rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)

    to_attr, as_attr = lookup.get_current_to_attr(level)
    # Make sure `to_attr` does not conflict with a field.
    if as_attr and instances:
        # We assume that objects retrieved are homogeneous (which is the premise
        # of prefetch_related), so what applies to first object applies to all.
        model = instances[0].__class__
        try:
            model._model_meta.get_field(to_attr)
        except FieldDoesNotExist:
            pass
        else:
            msg = "to_attr={} conflicts with a field on the {} model."
            raise ValueError(msg.format(to_attr, model.__name__))

    # Whether or not we're prefetching the last part of the lookup.
    leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level

    for obj in instances:
        instance_attr_val = instance_attr(obj)
        vals = rel_obj_cache.get(instance_attr_val, [])

        if single:
            # Singly related object: at most one match per instance.
            val = vals[0] if vals else None
            if as_attr:
                # A to_attr has been given for the prefetch.
                setattr(obj, to_attr, val)
            elif is_descriptor:
                # cache_name points to a field name in obj.
                # This field is a descriptor for a related object.
                setattr(obj, cache_name, val)
            else:
                # No to_attr has been given for this prefetch operation and the
                # cache_name does not point to a descriptor. Store the value of
                # the field in the object's field cache.
                obj._state.fields_cache[cache_name] = val  # type: ignore[index]
        else:
            if as_attr:
                # Multi-valued prefetch assigned directly to a plain attribute.
                setattr(obj, to_attr, vals)
            else:
                queryset = getattr(obj, to_attr)
                if leaf and lookup.queryset is not None:
                    # A custom queryset was supplied for the final segment;
                    # re-apply the relation filters to it.
                    qs = queryset._apply_rel_filters(lookup.queryset)
                else:
                    # Check if queryset is a QuerySet or a related manager
                    # We need a QuerySet instance to cache the prefetched values
                    if isinstance(queryset, QuerySet):
                        # It's already a QuerySet, create a new instance
                        qs = queryset.__class__.from_model(queryset.model)
                    else:
                        # It's a related manager, get its QuerySet
                        # The manager's query property returns a properly filtered QuerySet
                        qs = queryset.query
                # Prime the result cache so accessing the relation later does
                # not hit the database again.
                qs._result_cache = vals
                # We don't want the individual qs doing prefetch_related now,
                # since we have merged this into the current work.
                qs._prefetch_done = True
                obj._prefetched_objects_cache[cache_name] = qs
    return all_related_objects, additional_lookups
2161
2162
class RelatedPopulator:
    """
    Instantiate select_related() objects from result rows.

    Each model reached via select_related() gets its own RelatedPopulator.
    Construction takes the klass_info and select list computed by SQLCompiler
    and precomputes which row columns to use and how to wire the resulting
    object to its parent. populate() then turns a row slice into a model
    instance (or None) and links it to from_obj.
    """

    def __init__(self, klass_info: dict[str, Any], select: list[Any]):
        """
        Precompute per-model population state:

        - cols_start/cols_end: slice of the row holding this model's columns,
          in the order model_cls.__init__ expects them.
        - reorder_for_init: optional callable reordering row data for
          __init__ when the selected order differs (unused here, kept for
          compatibility with populate()).
        - init_list: attnames fetched from the database — for deferred models
          this may be a subset of the model's fields.
        - model_cls: the (possibly deferred) model class to instantiate.
        - id_idx: index of the primary key within the slice, used to detect
          a missing related row (NULL join).
        - related_populators: populators for models select_related() reaches
          from this one.
        - local_setter/remote_setter: callables caching the relation on the
          populated object and on the remote object respectively.
        """
        field_indexes = klass_info["select_fields"]
        self.cols_start = field_indexes[0]
        self.cols_end = field_indexes[-1] + 1
        self.init_list = [
            col[0].target.attname for col in select[self.cols_start : self.cols_end]
        ]
        self.reorder_for_init = None
        self.model_cls = klass_info["model"]
        self.id_idx = self.init_list.index("id")
        self.related_populators = get_related_populators(klass_info, select)
        self.local_setter = klass_info["local_setter"]
        self.remote_setter = klass_info["remote_setter"]

    def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
        """Build this model's instance from `row` and attach it to `from_obj`."""
        if self.reorder_for_init:
            obj_data = self.reorder_for_init(row)
        else:
            obj_data = row[self.cols_start : self.cols_end]
        # A NULL primary key means the outer join matched no related row.
        if obj_data[self.id_idx] is None:
            obj = None
        else:
            obj = self.model_cls.from_db(self.init_list, obj_data)
            for populator in self.related_populators:
                populator.populate(row, obj)
        self.local_setter(from_obj, obj)
        if obj is not None:
            self.remote_setter(obj, from_obj)
2232
2233
def get_related_populators(
    klass_info: dict[str, Any], select: list[Any]
) -> list[RelatedPopulator]:
    """Build a RelatedPopulator for each related klass_info entry."""
    return [
        RelatedPopulator(rel_klass_info, select)
        for rel_klass_info in klass_info.get("related_klass_infos", [])
    ]