1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from collections.abc import Callable, Iterator, Sequence
11from functools import cached_property
12from itertools import chain, islice
13from typing import TYPE_CHECKING, Any, Never, Self, overload
14
15import psycopg
16
17import plain.runtime
18from plain.exceptions import ValidationError
19from plain.postgres import transaction
20from plain.postgres.constants import LOOKUP_SEP, OnConflict
21from plain.postgres.db import (
22 PLAIN_VERSION_PICKLE_KEY,
23 get_connection,
24)
25from plain.postgres.exceptions import (
26 FieldDoesNotExist,
27 FieldError,
28 ObjectDoesNotExist,
29)
30from plain.postgres.expressions import Case, F, ResolvableExpression, Value, When
31from plain.postgres.fields import (
32 Field,
33 PrimaryKeyField,
34)
35from plain.postgres.functions import Cast
36from plain.postgres.query_utils import FilteredRelation, Q
37from plain.postgres.sql import (
38 AND,
39 CURSOR,
40 OR,
41 XOR,
42 DeleteQuery,
43 InsertQuery,
44 Query,
45 RawQuery,
46 UpdateQuery,
47)
48from plain.postgres.utils import resolve_callables
49from plain.utils.functional import partition
50
51# Re-exports for public API
52__all__ = ["F", "Q", "QuerySet", "RawQuerySet", "Prefetch", "FilteredRelation"]
53
54if TYPE_CHECKING:
55 from plain.postgres import Model
56
57# The maximum number of results to fetch in a get() query.
58MAX_GET_RESULTS = 21
59
60# The maximum number of items to display in a QuerySet.__repr__
61REPR_OUTPUT_SIZE = 20
62
63
64class BaseIterable:
65 def __init__(
66 self,
67 queryset: QuerySet[Any],
68 chunked_fetch: bool = False,
69 ):
70 self.queryset = queryset
71 self.chunked_fetch = chunked_fetch
72
73 def __iter__(self) -> Iterator[Any]:
74 raise NotImplementedError(
75 "subclasses of BaseIterable must provide an __iter__() method"
76 )
77
78
79class ModelIterable(BaseIterable):
80 """Iterable that yields a model instance for each row."""
81
82 def __iter__(self) -> Iterator[Model]:
83 queryset = self.queryset
84 compiler = queryset.sql_query.get_compiler()
85 # Execute the query. This will also fill compiler.select, klass_info,
86 # and annotations.
87 results = compiler.execute_sql(chunked_fetch=self.chunked_fetch)
88 select, klass_info, annotation_col_map = (
89 compiler.select,
90 compiler.klass_info,
91 compiler.annotation_col_map,
92 )
93 # These are set by execute_sql() above
94 assert select is not None
95 assert klass_info is not None
96 model_cls = klass_info["model"]
97 select_fields = klass_info["select_fields"]
98 model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
99 init_list = [
100 f[0].target.attname for f in select[model_fields_start:model_fields_end]
101 ]
102 related_populators = get_related_populators(klass_info, select)
103 known_related_objects = [
104 (
105 field,
106 related_objs,
107 operator.attrgetter(field.attname),
108 )
109 for field, related_objs in queryset._known_related_objects.items()
110 ]
111 for row in compiler.results_iter(results):
112 obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
113 for rel_populator in related_populators:
114 rel_populator.populate(row, obj)
115 if annotation_col_map:
116 for attr_name, col_pos in annotation_col_map.items():
117 setattr(obj, attr_name, row[col_pos])
118
119 # Add the known related objects to the model.
120 for field, rel_objs, rel_getter in known_related_objects:
121 # Avoid overwriting objects loaded by, e.g., select_related().
122 if field.is_cached(obj):
123 continue
124 rel_obj_id = rel_getter(obj)
125 try:
126 rel_obj = rel_objs[rel_obj_id]
127 except KeyError:
128 pass # May happen in qs1 | qs2 scenarios.
129 else:
130 setattr(obj, field.name, rel_obj)
131
132 yield obj
133
134
135class RawModelIterable(BaseIterable):
136 """
137 Iterable that yields a model instance for each row from a raw queryset.
138 """
139
140 queryset: RawQuerySet
141
142 def __iter__(self) -> Iterator[Model]:
143 from plain.postgres.sql.compiler import apply_converters, get_converters
144
145 query = self.queryset.sql_query
146 connection = get_connection()
147 query_iterator: Iterator[Any] = iter(query)
148
149 try:
150 (
151 model_init_names,
152 model_init_pos,
153 annotation_fields,
154 ) = self.queryset.resolve_model_init_order()
155 model_cls = self.queryset.model
156 assert model_cls is not None
157 if "id" not in model_init_names:
158 raise FieldDoesNotExist("Raw query must include the primary key")
159 fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
160 converters = get_converters(
161 [
162 f.get_col(f.model.model_options.db_table) if f else None
163 for f in fields
164 ],
165 connection,
166 )
167 if converters:
168 query_iterator = apply_converters(
169 query_iterator, converters, connection
170 )
171 for values in query_iterator:
172 # Associate fields to values
173 model_init_values = [values[pos] for pos in model_init_pos]
174 instance = model_cls.from_db(model_init_names, model_init_values)
175 if annotation_fields:
176 for column, pos in annotation_fields:
177 setattr(instance, column, values[pos])
178 yield instance
179 finally:
180 # Done iterating the Query. If it has its own cursor, close it.
181 if hasattr(query, "cursor") and query.cursor:
182 query.cursor.close()
183
184
185class ValuesIterable(BaseIterable):
186 """
187 Iterable returned by QuerySet.values() that yields a dict for each row.
188 """
189
190 def __iter__(self) -> Iterator[dict[str, Any]]:
191 queryset = self.queryset
192 query = queryset.sql_query
193 compiler = query.get_compiler()
194
195 # extra(select=...) cols are always at the start of the row.
196 names = [
197 *query.extra_select,
198 *query.values_select,
199 *query.annotation_select,
200 ]
201 indexes = range(len(names))
202 for row in compiler.results_iter(chunked_fetch=self.chunked_fetch):
203 yield {names[i]: row[i] for i in indexes}
204
205
206class ValuesListIterable(BaseIterable):
207 """
208 Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
209 for each row.
210 """
211
212 def __iter__(self) -> Iterator[tuple[Any, ...]]:
213 queryset = self.queryset
214 query = queryset.sql_query
215 compiler = query.get_compiler()
216
217 if queryset._fields:
218 # extra(select=...) cols are always at the start of the row.
219 names = [
220 *query.extra_select,
221 *query.values_select,
222 *query.annotation_select,
223 ]
224 fields = [
225 *queryset._fields,
226 *(f for f in query.annotation_select if f not in queryset._fields),
227 ]
228 if fields != names:
229 # Reorder according to fields.
230 index_map = {name: idx for idx, name in enumerate(names)}
231 rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
232 return map(
233 rowfactory,
234 compiler.results_iter(chunked_fetch=self.chunked_fetch),
235 )
236 return iter(
237 compiler.results_iter(
238 tuple_expected=True,
239 chunked_fetch=self.chunked_fetch,
240 )
241 )
242
243
244class FlatValuesListIterable(BaseIterable):
245 """
246 Iterable returned by QuerySet.values_list(flat=True) that yields single
247 values.
248 """
249
250 def __iter__(self) -> Iterator[Any]:
251 queryset = self.queryset
252 compiler = queryset.sql_query.get_compiler()
253 for row in compiler.results_iter(chunked_fetch=self.chunked_fetch):
254 yield row[0]
255
256
257class QuerySet[T: "Model"]:
258 """
259 Represent a lazy database lookup for a set of objects.
260
261 Usage:
262 MyModel.query.filter(name="test").all()
263
264 Custom QuerySets:
265 from typing import Self
266
267 class TaskQuerySet(QuerySet["Task"]):
268 def active(self) -> Self:
269 return self.filter(is_active=True)
270
271 class Task(Model):
272 is_active = BooleanField(default=True)
273 query = TaskQuerySet()
274
275 Task.query.active().filter(name="test") # Full type inference
276
277 Custom methods should return `Self` to preserve type through method chaining.
278 """
279
280 # Instance attributes (set in from_model())
281 model: type[T]
282 _query: Query
283 _result_cache: list[T] | None
284 _sticky_filter: bool
285 _for_write: bool
286 _prefetch_related_lookups: tuple[Any, ...]
287 _prefetch_done: bool
288 _known_related_objects: dict[Any, dict[Any, Any]]
289 _iterable_class: type[BaseIterable]
290 _fields: tuple[str, ...] | None
291 _defer_next_filter: bool
292 _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None
293
294 def __init__(self):
295 """Minimal init for descriptor mode. Use from_model() to create instances."""
296 pass
297
298 @classmethod
299 def from_model(cls, model: type[T], query: Query | None = None) -> Self:
300 """Create a QuerySet instance bound to a model."""
301 instance = cls()
302 instance.model = model
303 instance._query = query or Query(model)
304 instance._result_cache = None
305 instance._sticky_filter = False
306 instance._for_write = False
307 instance._prefetch_related_lookups = ()
308 instance._prefetch_done = False
309 instance._known_related_objects = {}
310 instance._iterable_class = ModelIterable
311 instance._fields = None
312 instance._defer_next_filter = False
313 instance._deferred_filter = None
314 return instance
315
316 @overload
317 def __get__(self, instance: None, owner: type[T]) -> Self: ...
318
319 @overload
320 def __get__(self, instance: Model, owner: type[T]) -> Never: ...
321
322 def __get__(self, instance: Any, owner: type[T]) -> Self:
323 """Descriptor protocol - return a new QuerySet bound to the model."""
324 if instance is not None:
325 raise AttributeError(
326 f"QuerySet is only accessible from the model class, not instances. "
327 f"Use {owner.__name__}.query instead."
328 )
329 return self.from_model(owner)
330
331 @property
332 def sql_query(self) -> Query:
333 if self._deferred_filter:
334 negate, args, kwargs = self._deferred_filter
335 self._filter_or_exclude_inplace(negate, args, kwargs)
336 self._deferred_filter = None
337 return self._query
338
339 @sql_query.setter
340 def sql_query(self, value: Query) -> None:
341 if value.values_select:
342 self._iterable_class = ValuesIterable
343 self._query = value
344
345 ########################
346 # PYTHON MAGIC METHODS #
347 ########################
348
349 def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
350 """Don't populate the QuerySet's cache."""
351 obj = self.__class__.from_model(self.model)
352 for k, v in self.__dict__.items():
353 if k == "_result_cache":
354 obj.__dict__[k] = None
355 else:
356 obj.__dict__[k] = copy.deepcopy(v, memo)
357 return obj
358
359 def __getstate__(self) -> dict[str, Any]:
360 # Force the cache to be fully populated.
361 self._fetch_all()
362 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
363
364 def __setstate__(self, state: dict[str, Any]) -> None:
365 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
366 if pickled_version:
367 if pickled_version != plain.runtime.__version__:
368 warnings.warn(
369 f"Pickled queryset instance's Plain version {pickled_version} does not "
370 f"match the current version {plain.runtime.__version__}.",
371 RuntimeWarning,
372 stacklevel=2,
373 )
374 else:
375 warnings.warn(
376 "Pickled queryset instance's Plain version is not specified.",
377 RuntimeWarning,
378 stacklevel=2,
379 )
380 self.__dict__.update(state)
381
382 def __repr__(self) -> str:
383 # Don't run SQL from __repr__ — error reporters (Sentry, pdb,
384 # exception templates) call repr() on stack-frame locals to
385 # build error events. If the queryset hasn't been evaluated, a
386 # surprise SELECT inside an exception path is a known footgun.
387 if self._result_cache is None:
388 return f"<{self.__class__.__name__} [unevaluated]>"
389 data: list[Any] = list(self._result_cache[: REPR_OUTPUT_SIZE + 1])
390 if len(data) > REPR_OUTPUT_SIZE:
391 data[-1] = "...(remaining elements truncated)..."
392 return f"<{self.__class__.__name__} {data!r}>"
393
394 def __len__(self) -> int:
395 self._fetch_all()
396 assert self._result_cache is not None
397 return len(self._result_cache)
398
399 def __iter__(self) -> Iterator[T]:
400 """
401 The queryset iterator protocol uses three nested iterators in the
402 default case:
403 1. sql.compiler.execute_sql()
404 - Returns a flat iterable of rows: a list from fetchall()
405 for regular queries, or a streaming generator from
406 cursor.stream() when using .iterator().
407 2. sql.compiler.results_iter()
408 - Returns one row at a time. At this point the rows are still
409 just tuples. In some cases the return values are converted
410 to Python values at this location.
411 3. self.iterator()
412 - Responsible for turning the rows into model objects.
413 """
414 self._fetch_all()
415 assert self._result_cache is not None
416 return iter(self._result_cache)
417
418 def __bool__(self) -> bool:
419 self._fetch_all()
420 return bool(self._result_cache)
421
422 @overload
423 def __getitem__(self, k: int) -> T: ...
424
425 @overload
426 def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...
427
428 def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
429 """Retrieve an item or slice from the set of results."""
430 if not isinstance(k, int | slice):
431 raise TypeError(
432 f"QuerySet indices must be integers or slices, not {type(k).__name__}."
433 )
434 if (isinstance(k, int) and k < 0) or (
435 isinstance(k, slice)
436 and (
437 (k.start is not None and k.start < 0)
438 or (k.stop is not None and k.stop < 0)
439 )
440 ):
441 raise ValueError("Negative indexing is not supported.")
442
443 if self._result_cache is not None:
444 return self._result_cache[k]
445
446 if isinstance(k, slice):
447 qs = self._chain()
448 if k.start is not None:
449 start = int(k.start)
450 else:
451 start = None
452 if k.stop is not None:
453 stop = int(k.stop)
454 else:
455 stop = None
456 qs.sql_query.set_limits(start, stop)
457 return list(qs)[:: k.step] if k.step else qs
458
459 qs = self._chain()
460 qs.sql_query.set_limits(k, k + 1)
461 qs._fetch_all()
462 assert qs._result_cache is not None # _fetch_all guarantees this
463 return qs._result_cache[0]
464
465 def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
466 return cls
467
468 def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
469 self._merge_sanity_check(other)
470 if isinstance(other, EmptyQuerySet):
471 return other
472 if isinstance(self, EmptyQuerySet):
473 return self
474 combined = self._chain()
475 combined._merge_known_related_objects(other)
476 combined.sql_query.combine(other.sql_query, AND)
477 return combined
478
479 def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
480 self._merge_sanity_check(other)
481 if isinstance(self, EmptyQuerySet):
482 return other
483 if isinstance(other, EmptyQuerySet):
484 return self
485 query = (
486 self
487 if self.sql_query.can_filter()
488 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
489 )
490 combined = query._chain()
491 combined._merge_known_related_objects(other)
492 if not other.sql_query.can_filter():
493 other = other.model._model_meta.base_queryset.filter(
494 id__in=other.values("id")
495 )
496 combined.sql_query.combine(other.sql_query, OR)
497 return combined
498
499 def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
500 self._merge_sanity_check(other)
501 if isinstance(self, EmptyQuerySet):
502 return other
503 if isinstance(other, EmptyQuerySet):
504 return self
505 query = (
506 self
507 if self.sql_query.can_filter()
508 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
509 )
510 combined = query._chain()
511 combined._merge_known_related_objects(other)
512 if not other.sql_query.can_filter():
513 other = other.model._model_meta.base_queryset.filter(
514 id__in=other.values("id")
515 )
516 combined.sql_query.combine(other.sql_query, XOR)
517 return combined
518
519 ####################################
520 # METHODS THAT DO DATABASE QUERIES #
521 ####################################
522
523 def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
524 iterable = self._iterable_class(
525 self,
526 chunked_fetch=use_chunked_fetch,
527 )
528 if not self._prefetch_related_lookups or chunk_size is None:
529 yield from iterable
530 return
531
532 iterator = iter(iterable)
533 while results := list(islice(iterator, chunk_size)):
534 prefetch_related_objects(results, *self._prefetch_related_lookups)
535 yield from results
536
537 def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
538 """
539 An iterator over the results from applying this QuerySet to the
540 database. chunk_size must be provided for QuerySets that prefetch
541 related objects. Otherwise, a default chunk_size of 2000 is supplied.
542 """
543 if chunk_size is None:
544 if self._prefetch_related_lookups:
545 raise ValueError(
546 "chunk_size must be provided when using QuerySet.iterator() after "
547 "prefetch_related()."
548 )
549 elif chunk_size <= 0:
550 raise ValueError("Chunk size must be strictly positive.")
551 # PostgreSQL always supports server-side cursors for chunked fetches
552 return self._iterator(use_chunked_fetch=True, chunk_size=chunk_size)
553
554 def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
555 """
556 Return a dictionary containing the calculations (aggregation)
557 over the current queryset.
558
559 If args is present the expression is passed as a kwarg using
560 the Aggregate object's default alias.
561 """
562 if self.sql_query.distinct_fields:
563 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
564 self._validate_values_are_expressions(
565 (*args, *kwargs.values()), method_name="aggregate"
566 )
567 for arg in args:
568 # The default_alias property raises TypeError if default_alias
569 # can't be set automatically or AttributeError if it isn't an
570 # attribute.
571 try:
572 arg.default_alias
573 except (AttributeError, TypeError):
574 raise TypeError("Complex aggregates require an alias")
575 kwargs[arg.default_alias] = arg
576
577 return self.sql_query.chain().get_aggregation(kwargs)
578
579 def count(self) -> int:
580 """
581 Perform a SELECT COUNT() and return the number of records as an
582 integer.
583
584 If the QuerySet is already fully cached, return the length of the
585 cached results set to avoid multiple SELECT COUNT(*) calls.
586 """
587 if self._result_cache is not None:
588 return len(self._result_cache)
589
590 return self.sql_query.get_count()
591
592 def get(self, *args: Any, **kwargs: Any) -> T:
593 """
594 Perform the query and return a single object matching the given
595 keyword arguments.
596 """
597 clone = self.filter(*args, **kwargs)
598 if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
599 clone = clone.order_by()
600 limit = MAX_GET_RESULTS
601 clone.sql_query.set_limits(high=limit)
602 num = len(clone)
603 if num == 1:
604 assert clone._result_cache is not None # len() fetches results
605 return clone._result_cache[0]
606 if not num:
607 raise self.model.DoesNotExist(
608 f"{self.model.model_options.object_name} matching query does not exist."
609 )
610 raise self.model.MultipleObjectsReturned(
611 "get() returned more than one {} -- it returned {}!".format(
612 self.model.model_options.object_name,
613 num if not limit or num < limit else "more than %s" % (limit - 1),
614 )
615 )
616
617 def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
618 """
619 Perform the query and return a single object matching the given
620 keyword arguments, or None if no object is found.
621 """
622 try:
623 return self.get(*args, **kwargs)
624 except self.model.DoesNotExist:
625 return None
626
627 def create(self, **kwargs: Any) -> T:
628 """
629 Create a new object with the given kwargs, saving it to the database
630 and returning the created object.
631 """
632 obj = self.model(**kwargs)
633 self._for_write = True
634 obj.save(force_insert=True)
635 return obj
636
637 def _prepare_for_bulk_create(self, objs: list[T]) -> None:
638 from plain.postgres.fields.base import DefaultableField
639
640 id_field = self.model._model_meta.get_forward_field("id")
641 has_id_default = (
642 isinstance(id_field, DefaultableField) and id_field.has_default()
643 )
644 for obj in objs:
645 if obj.id is None and has_id_default:
646 # User-declared literal default on the PK — materialize it
647 # so the INSERT carries a Python value. Identity PKs take the
648 # DB's DEFAULT path and leave obj.id as None.
649 obj.id = id_field.get_default()
650 obj._prepare_related_fields_for_save(operation_name="bulk_create")
651
652 def _check_bulk_create_options(
653 self,
654 update_conflicts: bool,
655 update_fields: list[Field] | None,
656 unique_fields: list[Field] | None,
657 ) -> OnConflict | None:
658 if update_conflicts:
659 if not update_fields:
660 raise ValueError(
661 "Fields that will be updated when a row insertion fails "
662 "on conflicts must be provided."
663 )
664 if not unique_fields:
665 raise ValueError(
666 "Unique fields that can trigger the upsert must be provided."
667 )
668 # Updating primary keys and non-concrete fields is forbidden.
669 from plain.postgres.fields.related import ManyToManyField
670
671 if any(
672 not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
673 ):
674 raise ValueError(
675 "bulk_create() can only be used with concrete fields in "
676 "update_fields."
677 )
678 if any(f.primary_key for f in update_fields):
679 raise ValueError(
680 "bulk_create() cannot be used with primary keys in update_fields."
681 )
682 if unique_fields:
683 from plain.postgres.fields.related import ManyToManyField
684
685 if any(
686 not f.concrete or isinstance(f, ManyToManyField)
687 for f in unique_fields
688 ):
689 raise ValueError(
690 "bulk_create() can only be used with concrete fields "
691 "in unique_fields."
692 )
693 return OnConflict.UPDATE
694 return None
695
696 def bulk_create(
697 self,
698 objs: Sequence[T],
699 batch_size: int | None = None,
700 update_conflicts: bool = False,
701 update_fields: list[str] | None = None,
702 unique_fields: list[str] | None = None,
703 ) -> list[T]:
704 """
705 Insert each of the instances into the database. Do *not* call
706 save() on each of the instances. Primary keys are set on the objects
707 via the PostgreSQL RETURNING clause. Multi-table models are not supported.
708 """
709 if batch_size is not None and batch_size <= 0:
710 raise ValueError("Batch size must be a positive integer.")
711
712 objs = list(objs)
713 if not objs:
714 return objs
715 meta = self.model._model_meta
716 unique_fields_objs: list[Field] | None = None
717 update_fields_objs: list[Field] | None = None
718 if unique_fields:
719 unique_fields_objs = [
720 meta.get_forward_field(name) for name in unique_fields
721 ]
722 if update_fields:
723 update_fields_objs = [
724 meta.get_forward_field(name) for name in update_fields
725 ]
726 on_conflict = self._check_bulk_create_options(
727 update_conflicts,
728 update_fields_objs,
729 unique_fields_objs,
730 )
731 self._for_write = True
732 fields = meta.concrete_fields
733 self._prepare_for_bulk_create(objs)
734 with transaction.atomic(savepoint=False):
735 objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
736 if objs_with_id:
737 returned_columns = self._batched_insert(
738 objs_with_id,
739 fields,
740 batch_size,
741 on_conflict=on_conflict,
742 update_fields=update_fields_objs,
743 unique_fields=unique_fields_objs,
744 )
745 id_field = meta.get_forward_field("id")
746 for obj_with_id, results in zip(objs_with_id, returned_columns):
747 for result, field in zip(results, meta.db_returning_fields):
748 if field != id_field:
749 setattr(obj_with_id, field.attname, result)
750 for obj_with_id in objs_with_id:
751 obj_with_id._state.adding = False
752 if objs_without_id:
753 fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
754 returned_columns = self._batched_insert(
755 objs_without_id,
756 fields,
757 batch_size,
758 on_conflict=on_conflict,
759 update_fields=update_fields_objs,
760 unique_fields=unique_fields_objs,
761 )
762 if on_conflict is None:
763 assert len(returned_columns) == len(objs_without_id)
764 for obj_without_id, results in zip(objs_without_id, returned_columns):
765 for result, field in zip(results, meta.db_returning_fields):
766 setattr(obj_without_id, field.attname, result)
767 obj_without_id._state.adding = False
768
769 return objs
770
771 def bulk_update(
772 self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
773 ) -> int:
774 """
775 Update the given fields in each of the given objects in the database.
776 """
777 if batch_size is not None and batch_size <= 0:
778 raise ValueError("Batch size must be a positive integer.")
779 if not fields:
780 raise ValueError("Field names must be given to bulk_update().")
781 objs_tuple = tuple(objs)
782 if any(obj.id is None for obj in objs_tuple):
783 raise ValueError("All bulk_update() objects must have a primary key set.")
784 fields_list = [
785 self.model._model_meta.get_forward_field(name) for name in fields
786 ]
787 from plain.postgres.fields.related import ManyToManyField
788
789 if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
790 raise ValueError("bulk_update() can only be used with concrete fields.")
791 if any(f.primary_key for f in fields_list):
792 raise ValueError("bulk_update() cannot be used with primary key fields.")
793 if not objs_tuple:
794 return 0
795 for obj in objs_tuple:
796 obj._prepare_related_fields_for_save(
797 operation_name="bulk_update", fields=fields_list
798 )
799 # PK is used twice in the resulting update query, once in the filter
800 # and once in the WHEN. Each field will also have one CAST.
801 self._for_write = True
802 max_batch_size = len(objs_tuple)
803 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
804 batches = (
805 objs_tuple[i : i + batch_size]
806 for i in range(0, len(objs_tuple), batch_size)
807 )
808 updates = []
809 for batch_objs in batches:
810 update_kwargs = {}
811 for field in fields_list:
812 when_statements = []
813 for obj in batch_objs:
814 attr = getattr(obj, field.attname)
815 if not isinstance(attr, ResolvableExpression):
816 attr = Value(attr, output_field=field)
817 when_statements.append(When(id=obj.id, then=attr))
818 case_statement = Case(*when_statements, output_field=field)
819 # PostgreSQL requires casted CASE in updates
820 case_statement = Cast(case_statement, output_field=field)
821 update_kwargs[field.attname] = case_statement
822 updates.append(([obj.id for obj in batch_objs], update_kwargs))
823 rows_updated = 0
824 queryset = self._chain()
825 with transaction.atomic(savepoint=False):
826 for ids, update_kwargs in updates:
827 rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
828 return rows_updated
829
830 def get_or_create(
831 self, defaults: dict[str, Any] | None = None, **kwargs: Any
832 ) -> tuple[T, bool]:
833 """
834 Look up an object with the given kwargs, creating one if necessary.
835 Return a tuple of (object, created), where created is a boolean
836 specifying whether an object was created.
837 """
838 # The get() needs to be targeted at the write database in order
839 # to avoid potential transaction consistency problems.
840 self._for_write = True
841 try:
842 return self.get(**kwargs), False
843 except self.model.DoesNotExist:
844 params = self._extract_model_params(defaults, **kwargs)
845 # Try to create an object using passed params.
846 try:
847 with transaction.atomic():
848 params = dict(resolve_callables(params))
849 return self.create(**params), True
850 except (psycopg.IntegrityError, ValidationError):
851 # Since create() also validates by default,
852 # we can get any kind of ValidationError here,
853 # or it can flow through and get an IntegrityError from the database.
854 # The main thing we're concerned about is uniqueness failures,
855 # but ValidationError could include other things too.
856 # In all cases though it should be fine to try the get() again
857 # and return an existing object.
858 try:
859 return self.get(**kwargs), False
860 except self.model.DoesNotExist:
861 pass
862 raise
863
864 def update_or_create(
865 self,
866 defaults: dict[str, Any] | None = None,
867 create_defaults: dict[str, Any] | None = None,
868 **kwargs: Any,
869 ) -> tuple[T, bool]:
870 """
871 Look up an object with the given kwargs, updating one with defaults
872 if it exists, otherwise create a new one. Optionally, an object can
873 be created with different values than defaults by using
874 create_defaults.
875 Return a tuple (object, created), where created is a boolean
876 specifying whether an object was created.
877 """
878 if create_defaults is None:
879 update_defaults = create_defaults = defaults or {}
880 else:
881 update_defaults = defaults or {}
882 self._for_write = True
883 with transaction.atomic():
884 # Lock the row so that a concurrent update is blocked until
885 # update_or_create() has performed its save.
886 obj, created = self.select_for_update().get_or_create(
887 create_defaults, **kwargs
888 )
889 if created:
890 return obj, created
891 for k, v in resolve_callables(update_defaults):
892 setattr(obj, k, v)
893
894 update_fields = set(update_defaults)
895 concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
896 # update_fields does not support non-concrete fields.
897 if concrete_field_names.issuperset(update_fields):
898 # Add fields which are set on pre_save(), e.g. update_now fields.
899 # This is to maintain backward compatibility as these fields
900 # are not updated unless explicitly specified in the
901 # update_fields list.
902 for field in self.model._model_meta.local_concrete_fields:
903 if not (
904 field.primary_key or field.__class__.pre_save is Field.pre_save
905 ):
906 update_fields.add(field.name)
907 if field.name != field.attname:
908 update_fields.add(field.attname)
909 obj.save(update_fields=update_fields)
910 else:
911 obj.save()
912 return obj, False
913
914 def _extract_model_params(
915 self, defaults: dict[str, Any] | None, **kwargs: Any
916 ) -> dict[str, Any]:
917 """
918 Prepare `params` for creating a model instance based on the given
919 kwargs; for use by get_or_create().
920 """
921 defaults = defaults or {}
922 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
923 params.update(defaults)
924 property_names = self.model._model_meta._property_names
925 invalid_params = []
926 for param in params:
927 try:
928 self.model._model_meta.get_field(param)
929 except FieldDoesNotExist:
930 # It's okay to use a model's property if it has a setter.
931 if not (param in property_names and getattr(self.model, param).fset):
932 invalid_params.append(param)
933 if invalid_params:
934 raise FieldError(
935 "Invalid field name(s) for model {}: '{}'.".format(
936 self.model.model_options.object_name,
937 "', '".join(sorted(invalid_params)),
938 )
939 )
940 return params
941
942 def first(self) -> T | None:
943 """Return the first object of a query or None if no match is found."""
944 for obj in self[:1]:
945 return obj
946 return None
947
948 def last(self) -> T | None:
949 """Return the last object of a query or None if no match is found."""
950 queryset = self.reverse()
951 for obj in queryset[:1]:
952 return obj
953 return None
954
955 def delete(self) -> int:
956 """Delete the records in the current QuerySet.
957
958 Returns the number of parent rows deleted. Cascaded child rows are
959 handled by Postgres via the declared `on_delete` clauses and are not
960 included in the count.
961 """
962 if self.sql_query.is_sliced:
963 raise TypeError("Cannot use 'limit' or 'offset' with delete().")
964 if self.sql_query.distinct or self.sql_query.distinct_fields:
965 raise TypeError("Cannot call delete() after .distinct().")
966 if self._fields is not None:
967 raise TypeError("Cannot call delete() after .values() or .values_list()")
968
969 del_query = self._chain()
970 del_query._for_write = True
971 del_query.sql_query.select_for_update = False
972 del_query.sql_query.select_related = False
973 del_query.sql_query.clear_ordering(force=True)
974
975 # FK errors (RESTRICT / NO_ACTION) leave the DB transaction aborted.
976 # Mark the connection so outer atomic() blocks see the abort state
977 # even if the caller catches IntegrityError themselves.
978 with transaction.mark_for_rollback_on_error():
979 count = del_query._raw_delete()
980
981 # Clear the result cache, in case this QuerySet gets reused.
982 self._result_cache = None
983 return count
984
985 def _raw_delete(self) -> int:
986 """
987 Delete objects found from the given queryset in single direct SQL
988 query. No signals are sent and there is no protection for cascades.
989 """
990 query = self.sql_query.clone()
991 query.__class__ = DeleteQuery
992 cursor = query.get_compiler().execute_sql(CURSOR)
993 if cursor:
994 with cursor:
995 return cursor.rowcount
996 return 0
997
998 def update(self, **kwargs: Any) -> int:
999 """
1000 Update all elements in the current QuerySet, setting all the given
1001 fields to the appropriate values.
1002 """
1003 if self.sql_query.is_sliced:
1004 raise TypeError("Cannot update a query once a slice has been taken.")
1005 self._for_write = True
1006 query = self.sql_query.chain(UpdateQuery)
1007 query.add_update_values(kwargs)
1008
1009 # Inline annotations in order_by(), if possible.
1010 new_order_by = []
1011 for col in query.order_by:
1012 alias = col
1013 descending = False
1014 if isinstance(alias, str) and alias.startswith("-"):
1015 alias = alias.removeprefix("-")
1016 descending = True
1017 if annotation := query.annotations.get(alias):
1018 if getattr(annotation, "contains_aggregate", False):
1019 raise FieldError(
1020 f"Cannot update when ordering by an aggregate: {annotation}"
1021 )
1022 if descending:
1023 annotation = annotation.desc()
1024 new_order_by.append(annotation)
1025 else:
1026 new_order_by.append(col)
1027 query.order_by = tuple(new_order_by)
1028
1029 # Clear any annotations so that they won't be present in subqueries.
1030 query.annotations = {}
1031 with transaction.mark_for_rollback_on_error():
1032 rows = query.get_compiler().execute_sql(CURSOR)
1033 self._result_cache = None
1034 return rows
1035
1036 def _update(self, values: list[tuple[Field, Any]]) -> int:
1037 """
1038 A version of update() that accepts field objects instead of field names.
1039 Used primarily for model saving and not intended for use by general
1040 code (it requires too much poking around at model internals to be
1041 useful at that level).
1042 """
1043 if self.sql_query.is_sliced:
1044 raise TypeError("Cannot update a query once a slice has been taken.")
1045 query = self.sql_query.chain(UpdateQuery)
1046 query.add_update_fields(values)
1047 # Clear any annotations so that they won't be present in subqueries.
1048 query.annotations = {}
1049 self._result_cache = None
1050 return query.get_compiler().execute_sql(CURSOR)
1051
1052 def exists(self) -> bool:
1053 """
1054 Return True if the QuerySet would have any results, False otherwise.
1055 """
1056 if self._result_cache is None:
1057 return self.sql_query.has_results()
1058 return bool(self._result_cache)
1059
1060 def _prefetch_related_objects(self) -> None:
1061 # This method can only be called once the result cache has been filled.
1062 assert self._result_cache is not None
1063 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1064 self._prefetch_done = True
1065
1066 def explain(self, *, format: str | None = None, **options: Any) -> str:
1067 """
1068 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1069 returns the results.
1070 """
1071 return self.sql_query.explain(format=format, **options)
1072
1073 ##################################################
1074 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1075 ##################################################
1076
1077 def raw(
1078 self,
1079 raw_query: str,
1080 params: Sequence[Any] = (),
1081 translations: dict[str, str] | None = None,
1082 ) -> RawQuerySet:
1083 qs = RawQuerySet(
1084 raw_query,
1085 model=self.model,
1086 params=tuple(params),
1087 translations=translations,
1088 )
1089 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1090 return qs
1091
1092 def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1093 clone = self._chain()
1094 if expressions:
1095 clone = clone.annotate(**expressions)
1096 clone._fields = fields
1097 clone.sql_query.set_values(list(fields))
1098 return clone
1099
1100 def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1101 fields += tuple(expressions)
1102 clone = self._values(*fields, **expressions)
1103 clone._iterable_class = ValuesIterable
1104 return clone
1105
1106 def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
1107 if flat and len(fields) > 1:
1108 raise TypeError(
1109 "'flat' is not valid when values_list is called with more than one "
1110 "field."
1111 )
1112
1113 field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
1114 _fields = []
1115 expressions = {}
1116 counter = 1
1117 for field in fields:
1118 if isinstance(field, ResolvableExpression):
1119 field_id_prefix = getattr(
1120 field, "default_alias", field.__class__.__name__.lower()
1121 )
1122 while True:
1123 field_id = field_id_prefix + str(counter)
1124 counter += 1
1125 if field_id not in field_names:
1126 break
1127 expressions[field_id] = field
1128 _fields.append(field_id)
1129 else:
1130 _fields.append(field)
1131
1132 clone = self._values(*_fields, **expressions)
1133 clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
1134 return clone
1135
1136 def none(self) -> QuerySet[T]:
1137 """Return an empty QuerySet."""
1138 clone = self._chain()
1139 clone.sql_query.set_empty()
1140 return clone
1141
1142 ##################################################################
1143 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1144 ##################################################################
1145
1146 def all(self) -> Self:
1147 """
1148 Return a new QuerySet that is a copy of the current one. This allows a
1149 QuerySet to proxy for a model queryset in some cases.
1150 """
1151 obj = self._chain()
1152 # Preserve cache since all() doesn't modify the query.
1153 # This is important for prefetch_related() to work correctly.
1154 obj._result_cache = self._result_cache
1155 obj._prefetch_done = self._prefetch_done
1156 return obj
1157
1158 def filter(self, *args: Any, **kwargs: Any) -> Self:
1159 """
1160 Return a new QuerySet instance with the args ANDed to the existing
1161 set.
1162 """
1163 return self._filter_or_exclude(False, args, kwargs)
1164
1165 def exclude(self, *args: Any, **kwargs: Any) -> Self:
1166 """
1167 Return a new QuerySet instance with NOT (args) ANDed to the existing
1168 set.
1169 """
1170 return self._filter_or_exclude(True, args, kwargs)
1171
1172 def _filter_or_exclude(
1173 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1174 ) -> Self:
1175 if (args or kwargs) and self.sql_query.is_sliced:
1176 raise TypeError("Cannot filter a query once a slice has been taken.")
1177 clone = self._chain()
1178 if self._defer_next_filter:
1179 self._defer_next_filter = False
1180 clone._deferred_filter = negate, args, kwargs
1181 else:
1182 clone._filter_or_exclude_inplace(negate, args, kwargs)
1183 return clone
1184
1185 def _filter_or_exclude_inplace(
1186 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1187 ) -> None:
1188 if negate:
1189 self._query.add_q(~Q(*args, **kwargs))
1190 else:
1191 self._query.add_q(Q(*args, **kwargs))
1192
1193 def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
1194 """
1195 Return a new QuerySet instance with filter_obj added to the filters.
1196
1197 filter_obj can be a Q object or a dictionary of keyword lookup
1198 arguments.
1199
1200 This exists to support framework features such as 'limit_choices_to',
1201 and usually it will be more natural to use other methods.
1202 """
1203 if isinstance(filter_obj, Q):
1204 clone = self._chain()
1205 clone.sql_query.add_q(filter_obj)
1206 return clone
1207 else:
1208 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1209
1210 def select_for_update(
1211 self,
1212 nowait: bool = False,
1213 skip_locked: bool = False,
1214 of: tuple[str, ...] = (),
1215 no_key: bool = False,
1216 ) -> QuerySet[T]:
1217 """
1218 Return a new QuerySet instance that will select objects with a
1219 FOR UPDATE lock.
1220 """
1221 if nowait and skip_locked:
1222 raise ValueError("The nowait option cannot be used with skip_locked.")
1223 obj = self._chain()
1224 obj._for_write = True
1225 obj.sql_query.select_for_update = True
1226 obj.sql_query.select_for_update_nowait = nowait
1227 obj.sql_query.select_for_update_skip_locked = skip_locked
1228 obj.sql_query.select_for_update_of = of
1229 obj.sql_query.select_for_no_key_update = no_key
1230 return obj
1231
1232 def select_related(self, *fields: str | None) -> Self:
1233 """
1234 Return a new QuerySet instance that will select related objects.
1235
1236 If fields are specified, they must be ForeignKeyField fields and only those
1237 related objects are included in the selection.
1238
1239 If select_related(None) is called, clear the list.
1240 """
1241 if self._fields is not None:
1242 raise TypeError(
1243 "Cannot call select_related() after .values() or .values_list()"
1244 )
1245
1246 obj = self._chain()
1247 if fields == (None,):
1248 obj.sql_query.select_related = False
1249 elif fields:
1250 obj.sql_query.add_select_related(list(fields)) # ty: ignore[invalid-argument-type]
1251 else:
1252 obj.sql_query.select_related = True
1253 return obj
1254
1255 def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
1256 """
1257 Return a new QuerySet instance that will prefetch the specified
1258 Many-To-One and Many-To-Many related objects when the QuerySet is
1259 evaluated.
1260
1261 When prefetch_related() is called more than once, append to the list of
1262 prefetch lookups. If prefetch_related(None) is called, clear the list.
1263 """
1264 clone = self._chain()
1265 if lookups == (None,):
1266 clone._prefetch_related_lookups = ()
1267 else:
1268 for lookup in lookups:
1269 lookup_str: str
1270 if isinstance(lookup, Prefetch):
1271 lookup_str = lookup.prefetch_to
1272 else:
1273 assert isinstance(lookup, str)
1274 lookup_str = lookup
1275 lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
1276 if lookup_str in self.sql_query._filtered_relations:
1277 raise ValueError(
1278 "prefetch_related() is not supported with FilteredRelation."
1279 )
1280 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1281 return clone
1282
1283 def annotate(self, *args: Any, **kwargs: Any) -> Self:
1284 """
1285 Return a query set in which the returned objects have been annotated
1286 with extra data or aggregations.
1287 """
1288 return self._annotate(args, kwargs, select=True)
1289
1290 def alias(self, *args: Any, **kwargs: Any) -> Self:
1291 """
1292 Return a query set with added aliases for extra data or aggregations.
1293 """
1294 return self._annotate(args, kwargs, select=False)
1295
1296 def _annotate(
1297 self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
1298 ) -> Self:
1299 self._validate_values_are_expressions(
1300 args + tuple(kwargs.values()), method_name="annotate"
1301 )
1302 annotations = {}
1303 for arg in args:
1304 # The default_alias property may raise a TypeError.
1305 try:
1306 if arg.default_alias in kwargs:
1307 raise ValueError(
1308 f"The named annotation '{arg.default_alias}' conflicts with the "
1309 "default name for another annotation."
1310 )
1311 except TypeError:
1312 raise TypeError("Complex annotations require an alias")
1313 annotations[arg.default_alias] = arg
1314 annotations.update(kwargs)
1315
1316 clone = self._chain()
1317 names = self._fields
1318 if names is None:
1319 names = set(
1320 chain.from_iterable(
1321 (field.name, field.attname)
1322 if hasattr(field, "attname")
1323 else (field.name,)
1324 for field in self.model._model_meta.get_fields()
1325 )
1326 )
1327
1328 for alias, annotation in annotations.items():
1329 if alias in names:
1330 raise ValueError(
1331 f"The annotation '{alias}' conflicts with a field on the model."
1332 )
1333 if isinstance(annotation, FilteredRelation):
1334 clone.sql_query.add_filtered_relation(annotation, alias)
1335 else:
1336 clone.sql_query.add_annotation(
1337 annotation,
1338 alias,
1339 select=select,
1340 )
1341 for alias, annotation in clone.sql_query.annotations.items():
1342 if alias in annotations and annotation.contains_aggregate:
1343 if clone._fields is None:
1344 clone.sql_query.group_by = True
1345 else:
1346 clone.sql_query.set_group_by()
1347 break
1348
1349 return clone
1350
1351 def order_by(self, *field_names: str) -> Self:
1352 """Return a new QuerySet instance with the ordering changed."""
1353 if self.sql_query.is_sliced:
1354 raise TypeError("Cannot reorder a query once a slice has been taken.")
1355 obj = self._chain()
1356 obj.sql_query.clear_ordering(force=True, clear_default=False)
1357 obj.sql_query.add_ordering(*field_names)
1358 return obj
1359
1360 def distinct(self, *field_names: str) -> Self:
1361 """
1362 Return a new QuerySet instance that will select only distinct results.
1363 """
1364 if self.sql_query.is_sliced:
1365 raise TypeError(
1366 "Cannot create distinct fields once a slice has been taken."
1367 )
1368 obj = self._chain()
1369 obj.sql_query.add_distinct_fields(*field_names)
1370 return obj
1371
1372 def extra(
1373 self,
1374 select: dict[str, str] | None = None,
1375 where: list[str] | None = None,
1376 params: list[Any] | None = None,
1377 tables: list[str] | None = None,
1378 order_by: list[str] | None = None,
1379 select_params: list[Any] | None = None,
1380 ) -> QuerySet[T]:
1381 """Add extra SQL fragments to the query."""
1382 if self.sql_query.is_sliced:
1383 raise TypeError("Cannot change a query once a slice has been taken.")
1384 clone = self._chain()
1385 clone.sql_query.add_extra(
1386 select or {},
1387 select_params,
1388 where or [],
1389 params or [],
1390 tables or [],
1391 tuple(order_by) if order_by else (),
1392 )
1393 return clone
1394
1395 def reverse(self) -> QuerySet[T]:
1396 """Reverse the ordering of the QuerySet."""
1397 if self.sql_query.is_sliced:
1398 raise TypeError("Cannot reverse a query once a slice has been taken.")
1399 clone = self._chain()
1400 clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
1401 return clone
1402
1403 def defer(self, *fields: str | None) -> QuerySet[T]:
1404 """
1405 Defer the loading of data for certain fields until they are accessed.
1406 Add the set of deferred fields to any existing set of deferred fields.
1407 The only exception to this is if None is passed in as the only
1408 parameter, in which case removal all deferrals.
1409 """
1410 if self._fields is not None:
1411 raise TypeError("Cannot call defer() after .values() or .values_list()")
1412 clone = self._chain()
1413 if fields == (None,):
1414 clone.sql_query.clear_deferred_loading()
1415 else:
1416 clone.sql_query.add_deferred_loading(frozenset(fields)) # ty: ignore[invalid-argument-type]
1417 return clone
1418
1419 def only(self, *fields: str) -> QuerySet[T]:
1420 """
1421 Essentially, the opposite of defer(). Only the fields passed into this
1422 method and that are not already specified as deferred are loaded
1423 immediately when the queryset is evaluated.
1424 """
1425 if self._fields is not None:
1426 raise TypeError("Cannot call only() after .values() or .values_list()")
1427 if fields == (None,):
1428 # Can only pass None to defer(), not only(), as the rest option.
1429 # That won't stop people trying to do this, so let's be explicit.
1430 raise TypeError("Cannot pass None as an argument to only().")
1431 for field in fields:
1432 field = field.split(LOOKUP_SEP, 1)[0]
1433 if field in self.sql_query._filtered_relations:
1434 raise ValueError("only() is not supported with FilteredRelation.")
1435 clone = self._chain()
1436 clone.sql_query.add_immediate_loading(set(fields))
1437 return clone
1438
1439 ###################################
1440 # PUBLIC INTROSPECTION ATTRIBUTES #
1441 ###################################
1442
1443 @property
1444 def ordered(self) -> bool:
1445 """
1446 Return True if the QuerySet is ordered -- i.e. has an order_by()
1447 clause or a default ordering on the model (or is empty).
1448 """
1449 if isinstance(self, EmptyQuerySet):
1450 return True
1451 if self.sql_query.extra_order_by or self.sql_query.order_by:
1452 return True
1453 elif (
1454 self.sql_query.default_ordering
1455 and self.sql_query.model
1456 and self.sql_query.model._model_meta.ordering # ty: ignore[unresolved-attribute]
1457 and
1458 # A default ordering doesn't affect GROUP BY queries.
1459 not self.sql_query.group_by
1460 ):
1461 return True
1462 else:
1463 return False
1464
1465 ###################
1466 # PRIVATE METHODS #
1467 ###################
1468
1469 def _insert(
1470 self,
1471 objs: list[T],
1472 fields: list[Field],
1473 returning_fields: list[Field] | None = None,
1474 raw: bool = False,
1475 on_conflict: OnConflict | None = None,
1476 update_fields: list[Field] | None = None,
1477 unique_fields: list[Field] | None = None,
1478 ) -> list[tuple[Any, ...]] | None:
1479 """
1480 Insert a new record for the given model. This provides an interface to
1481 the InsertQuery class and is how Model.save() is implemented.
1482 """
1483 self._for_write = True
1484 query = InsertQuery(
1485 self.model,
1486 on_conflict=on_conflict if on_conflict else None,
1487 update_fields=update_fields,
1488 unique_fields=unique_fields,
1489 )
1490 query.insert_values(fields, objs, raw=raw)
1491 # InsertQuery returns SQLInsertCompiler which has different execute_sql signature
1492 return query.get_compiler().execute_sql(returning_fields)
1493
1494 def _batched_insert(
1495 self,
1496 objs: list[T],
1497 fields: list[Field],
1498 batch_size: int | None,
1499 on_conflict: OnConflict | None = None,
1500 update_fields: list[Field] | None = None,
1501 unique_fields: list[Field] | None = None,
1502 ) -> list[tuple[Any, ...]]:
1503 """
1504 Helper method for bulk_create() to insert objs one batch at a time.
1505 """
1506 max_batch_size = max(len(objs), 1)
1507 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1508 inserted_rows = []
1509 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1510 if on_conflict is None:
1511 inserted_rows.extend(
1512 self._insert( # ty: ignore[invalid-argument-type]
1513 item,
1514 fields=fields,
1515 returning_fields=self.model._model_meta.db_returning_fields,
1516 )
1517 )
1518 else:
1519 self._insert(
1520 item,
1521 fields=fields,
1522 on_conflict=on_conflict,
1523 update_fields=update_fields,
1524 unique_fields=unique_fields,
1525 )
1526 return inserted_rows
1527
1528 def _chain(self) -> Self:
1529 """
1530 Return a copy of the current QuerySet that's ready for another
1531 operation.
1532 """
1533 obj = self._clone()
1534 if obj._sticky_filter:
1535 obj.sql_query.filter_is_sticky = True
1536 obj._sticky_filter = False
1537 return obj
1538
1539 def _clone(self) -> Self:
1540 """
1541 Return a copy of the current QuerySet. A lightweight alternative
1542 to deepcopy().
1543 """
1544 c = self.__class__.from_model(
1545 model=self.model,
1546 query=self.sql_query.chain(),
1547 )
1548 c._sticky_filter = self._sticky_filter
1549 c._for_write = self._for_write
1550 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1551 c._known_related_objects = self._known_related_objects
1552 c._iterable_class = self._iterable_class
1553 c._fields = self._fields
1554 return c
1555
1556 def _fetch_all(self) -> None:
1557 if self._result_cache is None:
1558 self._result_cache = list(self._iterable_class(self))
1559 if self._prefetch_related_lookups and not self._prefetch_done:
1560 self._prefetch_related_objects()
1561
1562 def _next_is_sticky(self) -> QuerySet[T]:
1563 """
1564 Indicate that the next filter call and the one following that should
1565 be treated as a single filter. This is only important when it comes to
1566 determining when to reuse tables for many-to-many filters. Required so
1567 that we can filter naturally on the results of related managers.
1568
1569 This doesn't return a clone of the current QuerySet (it returns
1570 "self"). The method is only used internally and should be immediately
1571 followed by a filter() that does create a clone.
1572 """
1573 self._sticky_filter = True
1574 return self
1575
1576 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1577 """Check that two QuerySet classes may be merged."""
1578 if self._fields is not None and (
1579 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1580 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1581 or set(self.sql_query.annotation_select)
1582 != set(other.sql_query.annotation_select)
1583 ):
1584 raise TypeError(
1585 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1586 )
1587
1588 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1589 """
1590 Keep track of all known related objects from either QuerySet instance.
1591 """
1592 for field, objects in other._known_related_objects.items():
1593 self._known_related_objects.setdefault(field, {}).update(objects)
1594
1595 def resolve_expression(self, *args: Any, **kwargs: Any) -> Query:
1596 if self._fields and len(self._fields) > 1:
1597 # values() queryset can only be used as nested queries
1598 # if they are set up to select only a single field.
1599 raise TypeError("Cannot use multi-field values as a filter value.")
1600 query = self.sql_query.resolve_expression(*args, **kwargs)
1601 return query
1602
1603 def _has_filters(self) -> bool:
1604 """
1605 Check if this QuerySet has any filtering going on. This isn't
1606 equivalent with checking if all objects are present in results, for
1607 example, qs[1:]._has_filters() -> False.
1608 """
1609 return self.sql_query.has_filters()
1610
1611 @staticmethod
1612 def _validate_values_are_expressions(
1613 values: tuple[Any, ...], method_name: str
1614 ) -> None:
1615 invalid_args = sorted(
1616 str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
1617 )
1618 if invalid_args:
1619 raise TypeError(
1620 "QuerySet.{}() received non-expression(s): {}.".format(
1621 method_name,
1622 ", ".join(invalid_args),
1623 )
1624 )
1625
1626
1627class InstanceCheckMeta(type):
1628 def __instancecheck__(self, instance: object) -> bool:
1629 return isinstance(instance, QuerySet) and instance.sql_query.is_empty()
1630
1631
1632class EmptyQuerySet(metaclass=InstanceCheckMeta):
1633 """
1634 Marker class to checking if a queryset is empty by .none():
1635 isinstance(qs.none(), EmptyQuerySet) -> True
1636 """
1637
1638 def __init__(self, *args: Any, **kwargs: Any):
1639 raise TypeError("EmptyQuerySet can't be instantiated")
1640
1641
1642class RawQuerySet:
1643 """
1644 Provide an iterator which converts the results of raw SQL queries into
1645 annotated model instances.
1646 """
1647
1648 def __init__(
1649 self,
1650 raw_query: str,
1651 model: type[Model] | None = None,
1652 query: RawQuery | None = None,
1653 params: tuple[Any, ...] = (),
1654 translations: dict[str, str] | None = None,
1655 ):
1656 self.raw_query = raw_query
1657 self.model = model
1658 self.sql_query = query or RawQuery(sql=raw_query, params=params)
1659 self.params = params
1660 self.translations = translations or {}
1661 self._result_cache: list[Model] | None = None
1662 self._prefetch_related_lookups: tuple[Any, ...] = ()
1663 self._prefetch_done = False
1664
1665 def resolve_model_init_order(
1666 self,
1667 ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
1668 """Resolve the init field names and value positions."""
1669 model = self.model
1670 assert model is not None
1671 model_init_fields = [
1672 f for f in model._model_meta.fields if f.column in self.columns
1673 ]
1674 annotation_fields = [
1675 (column, pos)
1676 for pos, column in enumerate(self.columns)
1677 if column not in self.model_fields
1678 ]
1679 model_init_order = [self.columns.index(f.column) for f in model_init_fields]
1680 model_init_names = [f.attname for f in model_init_fields]
1681 return model_init_names, model_init_order, annotation_fields
1682
1683 def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
1684 """Same as QuerySet.prefetch_related()"""
1685 clone = self._clone()
1686 if lookups == (None,):
1687 clone._prefetch_related_lookups = ()
1688 else:
1689 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1690 return clone
1691
1692 def _prefetch_related_objects(self) -> None:
1693 assert self._result_cache is not None
1694 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1695 self._prefetch_done = True
1696
1697 def _clone(self) -> RawQuerySet:
1698 """Same as QuerySet._clone()"""
1699 c = self.__class__(
1700 self.raw_query,
1701 model=self.model,
1702 query=self.sql_query,
1703 params=self.params,
1704 translations=self.translations,
1705 )
1706 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1707 return c
1708
1709 def _fetch_all(self) -> None:
1710 if self._result_cache is None:
1711 self._result_cache = list(self.iterator())
1712 if self._prefetch_related_lookups and not self._prefetch_done:
1713 self._prefetch_related_objects()
1714
1715 def __len__(self) -> int:
1716 self._fetch_all()
1717 assert self._result_cache is not None
1718 return len(self._result_cache)
1719
1720 def __bool__(self) -> bool:
1721 self._fetch_all()
1722 return bool(self._result_cache)
1723
1724 def __iter__(self) -> Iterator[Model]:
1725 self._fetch_all()
1726 assert self._result_cache is not None
1727 return iter(self._result_cache)
1728
1729 def iterator(self) -> Iterator[Model]:
1730 yield from RawModelIterable(self) # ty: ignore[invalid-argument-type]
1731
1732 def __repr__(self) -> str:
1733 return f"<{self.__class__.__name__}: {self.sql_query}>"
1734
1735 def __getitem__(self, k: int | slice) -> Model | list[Model]:
1736 return list(self)[k]
1737
1738 @cached_property
1739 def columns(self) -> list[str]:
1740 """
1741 A list of model field names in the order they'll appear in the
1742 query results.
1743 """
1744 columns = self.sql_query.get_columns()
1745 # Adjust any column names which don't match field names
1746 for query_name, model_name in self.translations.items():
1747 # Ignore translations for nonexistent column names
1748 try:
1749 index = columns.index(query_name)
1750 except ValueError:
1751 pass
1752 else:
1753 columns[index] = model_name
1754 return columns
1755
1756 @cached_property
1757 def model_fields(self) -> dict[str, Field]:
1758 """A dict mapping column names to model field names."""
1759 model_fields = {}
1760 model = self.model
1761 assert model is not None
1762 for field in model._model_meta.fields:
1763 model_fields[field.column] = field
1764 return model_fields
1765
1766
1767class Prefetch:
1768 def __init__(
1769 self,
1770 lookup: str,
1771 queryset: QuerySet[Any] | None = None,
1772 to_attr: str | None = None,
1773 ):
1774 # `prefetch_through` is the path we traverse to perform the prefetch.
1775 self.prefetch_through = lookup
1776 # `prefetch_to` is the path to the attribute that stores the result.
1777 self.prefetch_to = lookup
1778 if queryset is not None and (
1779 isinstance(queryset, RawQuerySet)
1780 or (
1781 hasattr(queryset, "_iterable_class")
1782 and not issubclass(queryset._iterable_class, ModelIterable)
1783 )
1784 ):
1785 raise ValueError(
1786 "Prefetch querysets cannot use raw(), values(), and values_list()."
1787 )
1788 if to_attr:
1789 self.prefetch_to = LOOKUP_SEP.join(
1790 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
1791 )
1792
1793 self.queryset = queryset
1794 self.to_attr = to_attr
1795
1796 def __getstate__(self) -> dict[str, Any]:
1797 obj_dict = self.__dict__.copy()
1798 if self.queryset is not None:
1799 queryset = self.queryset._chain()
1800 # Prevent the QuerySet from being evaluated
1801 queryset._result_cache = []
1802 queryset._prefetch_done = True
1803 obj_dict["queryset"] = queryset
1804 return obj_dict
1805
1806 def add_prefix(self, prefix: str) -> None:
1807 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
1808 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
1809
1810 def get_current_prefetch_to(self, level: int) -> str:
1811 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
1812
1813 def get_current_to_attr(self, level: int) -> tuple[str, bool]:
1814 parts = self.prefetch_to.split(LOOKUP_SEP)
1815 to_attr = parts[level]
1816 as_attr = bool(self.to_attr and level == len(parts) - 1)
1817 return to_attr, as_attr
1818
1819 def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
1820 if self.get_current_prefetch_to(level) == self.prefetch_to:
1821 return self.queryset
1822 return None
1823
1824 def __eq__(self, other: object) -> bool:
1825 if not isinstance(other, Prefetch):
1826 return NotImplemented
1827 return self.prefetch_to == other.prefetch_to
1828
1829 def __hash__(self) -> int:
1830 return hash((self.__class__, self.prefetch_to))
1831
1832
1833def normalize_prefetch_lookups(
1834 lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
1835 prefix: str | None = None,
1836) -> list[Prefetch]:
1837 """Normalize lookups into Prefetch objects."""
1838 ret = []
1839 for lookup in lookups:
1840 if not isinstance(lookup, Prefetch):
1841 lookup = Prefetch(lookup)
1842 if prefix:
1843 lookup.add_prefix(prefix)
1844 ret.append(lookup)
1845 return ret
1846
1847
1848def prefetch_related_objects(
1849 model_instances: Sequence[Model], *related_lookups: str | Prefetch
1850) -> None:
1851 """
1852 Populate prefetched object caches for a list of model instances based on
1853 the lookups/Prefetch instances given.
1854 """
1855 if not model_instances:
1856 return # nothing to do
1857
1858 # We need to be able to dynamically add to the list of prefetch_related
1859 # lookups that we look up (see below). So we need some book keeping to
1860 # ensure we don't do duplicate work.
1861 done_queries = {} # dictionary of things like 'foo__bar': [results]
1862
1863 auto_lookups = set() # we add to this as we go through.
1864 followed_descriptors = set() # recursion protection
1865
1866 all_lookups = normalize_prefetch_lookups(list(reversed(related_lookups)))
1867 while all_lookups:
1868 lookup = all_lookups.pop()
1869 if lookup.prefetch_to in done_queries:
1870 if lookup.queryset is not None:
1871 raise ValueError(
1872 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
1873 "You may need to adjust the ordering of your lookups."
1874 )
1875
1876 continue
1877
1878 # Top level, the list of objects to decorate is the result cache
1879 # from the primary QuerySet. It won't be for deeper levels.
1880 obj_list = model_instances
1881
1882 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
1883 for level, through_attr in enumerate(through_attrs):
1884 # Prepare main instances
1885 if not obj_list:
1886 break
1887
1888 prefetch_to = lookup.get_current_prefetch_to(level)
1889 if prefetch_to in done_queries:
1890 # Skip any prefetching, and any object preparation
1891 obj_list = done_queries[prefetch_to]
1892 continue
1893
1894 # Prepare objects:
1895 good_objects = True
1896 for obj in obj_list:
1897 # Since prefetching can re-use instances, it is possible to have
1898 # the same instance multiple times in obj_list, so obj might
1899 # already be prepared.
1900 if not hasattr(obj, "_prefetched_objects_cache"):
1901 try:
1902 obj._prefetched_objects_cache = {}
1903 except (AttributeError, TypeError):
1904 # Must be an immutable object from
1905 # values_list(flat=True), for example (TypeError) or
1906 # a QuerySet subclass that isn't returning Model
1907 # instances (AttributeError), either in Plain or a 3rd
1908 # party. prefetch_related() doesn't make sense, so quit.
1909 good_objects = False
1910 break
1911 if not good_objects:
1912 break
1913
1914 # Descend down tree
1915
1916 # We assume that objects retrieved are homogeneous (which is the premise
1917 # of prefetch_related), so what applies to first object applies to all.
1918 first_obj = obj_list[0]
1919 to_attr = lookup.get_current_to_attr(level)[0]
1920 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
1921 first_obj, through_attr, to_attr
1922 )
1923
1924 if not attr_found:
1925 raise AttributeError(
1926 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
1927 "parameter to prefetch_related()"
1928 )
1929
1930 if level == len(through_attrs) - 1 and prefetcher is None:
1931 # Last one, this *must* resolve to something that supports
1932 # prefetching, otherwise there is no point adding it and the
1933 # developer asking for it has made a mistake.
1934 raise ValueError(
1935 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
1936 "prefetching - this is an invalid parameter to "
1937 "prefetch_related()."
1938 )
1939
1940 obj_to_fetch = None
1941 if prefetcher is not None:
1942 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
1943
1944 if obj_to_fetch:
1945 obj_list, additional_lookups = prefetch_one_level(
1946 obj_to_fetch,
1947 prefetcher,
1948 lookup,
1949 level,
1950 )
1951 # We need to ensure we don't keep adding lookups from the
1952 # same relationships to stop infinite recursion. So, if we
1953 # are already on an automatically added lookup, don't add
1954 # the new lookups from relationships we've seen already.
1955 if not (
1956 prefetch_to in done_queries
1957 and lookup in auto_lookups
1958 and descriptor in followed_descriptors
1959 ):
1960 done_queries[prefetch_to] = obj_list
1961 new_lookups = normalize_prefetch_lookups(
1962 list(reversed(additional_lookups)),
1963 prefetch_to,
1964 )
1965 auto_lookups.update(new_lookups)
1966 all_lookups.extend(new_lookups)
1967 followed_descriptors.add(descriptor)
1968 else:
1969 # Either a singly related object that has already been fetched
1970 # (e.g. via select_related), or hopefully some other property
1971 # that doesn't support prefetching but needs to be traversed.
1972
1973 # We replace the current list of parent objects with the list
1974 # of related objects, filtering out empty or missing values so
1975 # that we can continue with nullable or reverse relations.
1976 new_obj_list = []
1977 for obj in obj_list:
1978 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
1979 # If related objects have been prefetched, use the
1980 # cache rather than the object's through_attr.
1981 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # ty: ignore[invalid-argument-type]
1982 else:
1983 try:
1984 new_obj = getattr(obj, through_attr)
1985 except ObjectDoesNotExist:
1986 continue
1987 if new_obj is None:
1988 continue
1989 # We special-case `list` rather than something more generic
1990 # like `Iterable` because we don't want to accidentally match
1991 # user models that define __iter__.
1992 if isinstance(new_obj, list):
1993 new_obj_list.extend(new_obj)
1994 else:
1995 new_obj_list.append(new_obj)
1996 obj_list = new_obj_list
1997
1998
1999def get_prefetcher(
2000 instance: Model, through_attr: str, to_attr: str
2001) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
2002 """
2003 For the attribute 'through_attr' on the given instance, find
2004 an object that has a get_prefetch_queryset().
2005 Return a 4 tuple containing:
2006 (the object with get_prefetch_queryset (or None),
2007 the descriptor object representing this relationship (or None),
2008 a boolean that is False if the attribute was not found at all,
2009 a function that takes an instance and returns a boolean that is True if
2010 the attribute has already been fetched for that instance)
2011 """
2012
2013 def has_to_attr_attribute(instance: Model) -> bool:
2014 return hasattr(instance, to_attr)
2015
2016 prefetcher = None
2017 is_fetched: Callable[[Model], bool] = has_to_attr_attribute
2018
2019 # For singly related objects, we have to avoid getting the attribute
2020 # from the object, as this will trigger the query. So we first try
2021 # on the class, in order to get the descriptor object.
2022 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2023 if rel_obj_descriptor is None:
2024 attr_found = hasattr(instance, through_attr)
2025 else:
2026 attr_found = True
2027 if rel_obj_descriptor:
2028 # singly related object, descriptor object has the
2029 # get_prefetch_queryset() method.
2030 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2031 prefetcher = rel_obj_descriptor
2032 is_fetched = rel_obj_descriptor.is_cached
2033 else:
2034 # descriptor doesn't support prefetching, so we go ahead and get
2035 # the attribute on the instance rather than the class to
2036 # support many related managers
2037 rel_obj = getattr(instance, through_attr)
2038 if hasattr(rel_obj, "get_prefetch_queryset"):
2039 prefetcher = rel_obj
2040 if through_attr != to_attr:
2041 # Special case cached_property instances because hasattr
2042 # triggers attribute computation and assignment.
2043 if isinstance(
2044 getattr(instance.__class__, to_attr, None), cached_property
2045 ):
2046
2047 def has_cached_property(instance: Model) -> bool:
2048 return to_attr in instance.__dict__
2049
2050 is_fetched = has_cached_property
2051 else:
2052
2053 def in_prefetched_cache(instance: Model) -> bool:
2054 return through_attr in instance._prefetched_objects_cache
2055
2056 is_fetched = in_prefetched_cache
2057 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2058
2059
2060def prefetch_one_level(
2061 instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
2062) -> tuple[list[Model], list[Prefetch]]:
2063 """
2064 Helper function for prefetch_related_objects().
2065
2066 Run prefetches on all instances using the prefetcher object,
2067 assigning results to relevant caches in instance.
2068
2069 Return the prefetched objects along with any additional prefetches that
2070 must be done due to prefetch_related lookups found from default managers.
2071 """
2072 # prefetcher must have a method get_prefetch_queryset() which takes a list
2073 # of instances, and returns a tuple:
2074
2075 # (queryset of instances of self.model that are related to passed in instances,
2076 # callable that gets value to be matched for returned instances,
2077 # callable that gets value to be matched for passed in instances,
2078 # boolean that is True for singly related objects,
2079 # cache or field name to assign to,
2080 # boolean that is True when the previous argument is a cache name vs a field name).
2081
2082 # The 'values to be matched' must be hashable as they will be used
2083 # in a dictionary.
2084
2085 (
2086 rel_qs,
2087 rel_obj_attr,
2088 instance_attr,
2089 single,
2090 cache_name,
2091 is_descriptor,
2092 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2093 # We have to handle the possibility that the QuerySet we just got back
2094 # contains some prefetch_related lookups. We don't want to trigger the
2095 # prefetch_related functionality by evaluating the query. Rather, we need
2096 # to merge in the prefetch_related lookups.
2097 # Copy the lookups in case it is a Prefetch object which could be reused
2098 # later (happens in nested prefetch_related).
2099 additional_lookups = [
2100 copy.copy(additional_lookup)
2101 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2102 ]
2103 if additional_lookups:
2104 # Don't need to clone because the queryset should have given us a fresh
2105 # instance, so we access an internal instead of using public interface
2106 # for performance reasons.
2107 rel_qs._prefetch_related_lookups = ()
2108
2109 all_related_objects = list(rel_qs)
2110
2111 rel_obj_cache = {}
2112 for rel_obj in all_related_objects:
2113 rel_attr_val = rel_obj_attr(rel_obj)
2114 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2115
2116 to_attr, as_attr = lookup.get_current_to_attr(level)
2117 # Make sure `to_attr` does not conflict with a field.
2118 if as_attr and instances:
2119 # We assume that objects retrieved are homogeneous (which is the premise
2120 # of prefetch_related), so what applies to first object applies to all.
2121 model = instances[0].__class__
2122 try:
2123 model._model_meta.get_field(to_attr)
2124 except FieldDoesNotExist:
2125 pass
2126 else:
2127 msg = "to_attr={} conflicts with a field on the {} model."
2128 raise ValueError(msg.format(to_attr, model.__name__))
2129
2130 # Whether or not we're prefetching the last part of the lookup.
2131 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2132
2133 for obj in instances:
2134 instance_attr_val = instance_attr(obj)
2135 vals = rel_obj_cache.get(instance_attr_val, [])
2136
2137 if single:
2138 val = vals[0] if vals else None
2139 if as_attr:
2140 # A to_attr has been given for the prefetch.
2141 setattr(obj, to_attr, val)
2142 elif is_descriptor:
2143 # cache_name points to a field name in obj.
2144 # This field is a descriptor for a related object.
2145 setattr(obj, cache_name, val)
2146 else:
2147 # No to_attr has been given for this prefetch operation and the
2148 # cache_name does not point to a descriptor. Store the value of
2149 # the field in the object's field cache.
2150 obj._state.fields_cache[cache_name] = val
2151 else:
2152 if as_attr:
2153 setattr(obj, to_attr, vals)
2154 else:
2155 queryset = getattr(obj, to_attr)
2156 if leaf and lookup.queryset is not None:
2157 qs = queryset._apply_rel_filters(lookup.queryset)
2158 else:
2159 # Check if queryset is a QuerySet or a related manager
2160 # We need a QuerySet instance to cache the prefetched values
2161 if isinstance(queryset, QuerySet):
2162 # It's already a QuerySet, create a new instance
2163 qs = queryset.__class__.from_model(queryset.model)
2164 else:
2165 # It's a related manager, get its QuerySet
2166 # The manager's query property returns a properly filtered QuerySet
2167 qs = queryset.query
2168 qs._result_cache = vals
2169 # We don't want the individual qs doing prefetch_related now,
2170 # since we have merged this into the current work.
2171 qs._prefetch_done = True
2172 obj._prefetched_objects_cache[cache_name] = qs
2173 return all_related_objects, additional_lookups
2174
2175
2176class RelatedPopulator:
2177 """
2178 RelatedPopulator is used for select_related() object instantiation.
2179
2180 The idea is that each select_related() model will be populated by a
2181 different RelatedPopulator instance. The RelatedPopulator instances get
2182 klass_info and select (computed in SQLCompiler) plus the used db as
2183 input for initialization. That data is used to compute which columns
2184 to use, how to instantiate the model, and how to populate the links
2185 between the objects.
2186
2187 The actual creation of the objects is done in populate() method. This
2188 method gets row and from_obj as input and populates the select_related()
2189 model instance.
2190 """
2191
2192 def __init__(self, klass_info: dict[str, Any], select: list[Any]):
2193 # Pre-compute needed attributes. The attributes are:
2194 # - model_cls: the possibly deferred model class to instantiate
2195 # - either:
2196 # - cols_start, cols_end: usually the columns in the row are
2197 # in the same order model_cls.__init__ expects them, so we
2198 # can instantiate by model_cls(*row[cols_start:cols_end])
2199 # - reorder_for_init: When select_related descends to a child
2200 # class, then we want to reuse the already selected parent
2201 # data. However, in this case the parent data isn't necessarily
2202 # in the same order that Model.__init__ expects it to be, so
2203 # we have to reorder the parent data. The reorder_for_init
2204 # attribute contains a function used to reorder the field data
2205 # in the order __init__ expects it.
2206 # - id_idx: the index of the primary key field in the reordered
2207 # model data. Used to check if a related object exists at all.
2208 # - init_list: the field attnames fetched from the database. For
2209 # deferred models this isn't the same as all attnames of the
2210 # model's fields.
2211 # - related_populators: a list of RelatedPopulator instances if
2212 # select_related() descends to related models from this model.
2213 # - local_setter, remote_setter: Methods to set cached values on
2214 # the object being populated and on the remote object. Usually
2215 # these are Field.set_cached_value() methods.
2216 select_fields = klass_info["select_fields"]
2217
2218 self.cols_start = select_fields[0]
2219 self.cols_end = select_fields[-1] + 1
2220 self.init_list = [
2221 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2222 ]
2223 self.reorder_for_init = None
2224
2225 self.model_cls = klass_info["model"]
2226 self.id_idx = self.init_list.index("id")
2227 self.related_populators = get_related_populators(klass_info, select)
2228 self.local_setter = klass_info["local_setter"]
2229 self.remote_setter = klass_info["remote_setter"]
2230
2231 def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
2232 if self.reorder_for_init:
2233 obj_data = self.reorder_for_init(row)
2234 else:
2235 obj_data = row[self.cols_start : self.cols_end]
2236 if obj_data[self.id_idx] is None:
2237 obj = None
2238 else:
2239 obj = self.model_cls.from_db(self.init_list, obj_data)
2240 for rel_iter in self.related_populators:
2241 rel_iter.populate(row, obj)
2242 self.local_setter(from_obj, obj)
2243 if obj is not None:
2244 self.remote_setter(obj, from_obj)
2245
2246
2247def get_related_populators(
2248 klass_info: dict[str, Any], select: list[Any]
2249) -> list[RelatedPopulator]:
2250 iterators = []
2251 related_klass_infos = klass_info.get("related_klass_infos", [])
2252 for rel_klass_info in related_klass_infos:
2253 rel_cls = RelatedPopulator(rel_klass_info, select)
2254 iterators.append(rel_cls)
2255 return iterators