1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from collections.abc import Callable, Iterator, Sequence
11from functools import cached_property
12from itertools import chain, islice
13from typing import TYPE_CHECKING, Any, Generic, Never, Self, TypeVar, overload
14
15import plain.runtime
16from plain.exceptions import ValidationError
17from plain.models import transaction
18from plain.models.constants import LOOKUP_SEP, OnConflict
19from plain.models.db import (
20 PLAIN_VERSION_PICKLE_KEY,
21 IntegrityError,
22 db_connection,
23)
24from plain.models.exceptions import (
25 FieldDoesNotExist,
26 FieldError,
27 ObjectDoesNotExist,
28)
29from plain.models.expressions import Case, F, ResolvableExpression, Value, When
30from plain.models.fields import (
31 Field,
32 PrimaryKeyField,
33)
34from plain.models.functions import Cast
35from plain.models.query_utils import FilteredRelation, Q
36from plain.models.sql import (
37 AND,
38 CURSOR,
39 GET_ITERATOR_CHUNK_SIZE,
40 OR,
41 XOR,
42 DeleteQuery,
43 InsertQuery,
44 Query,
45 RawQuery,
46 UpdateQuery,
47)
48from plain.models.utils import resolve_callables
49from plain.utils.functional import partition
50
# Re-exports for public API.
# NOTE(review): "RawQuerySet" and "Prefetch" are not defined in the visible part
# of this module; presumably they are defined further down — confirm.
__all__ = ["F", "Q", "QuerySet", "RawQuerySet", "Prefetch", "FilteredRelation"]

if TYPE_CHECKING:
    # Only needed for annotations; T's bound below refers to it by name.
    from plain.models import Model

# Type variable for QuerySet generic: the model class whose instances a
# QuerySet yields.
T = TypeVar("T", bound="Model")

# The maximum number of results to fetch in a get() query. get() applies this
# as the query's upper limit; when that many rows come back it reports
# "more than MAX_GET_RESULTS - 1" rather than an exact count.
MAX_GET_RESULTS = 21

# The maximum number of items to display in a QuerySet.__repr__. __repr__
# fetches one extra row so it can tell when the output was truncated.
REPR_OUTPUT_SIZE = 20
65
66
class BaseIterable:
    """
    Base for the iterables that turn a QuerySet's SQL results into Python
    objects.

    This class only records the queryset and the fetch options; subclasses
    implement __iter__() and forward chunked_fetch/chunk_size to the SQL
    compiler.
    """

    def __init__(
        self,
        queryset: QuerySet[Any],
        chunked_fetch: bool = False,
        chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
    ):
        # Stored verbatim; read by subclasses when they run the query.
        self.queryset, self.chunked_fetch, self.chunk_size = (
            queryset,
            chunked_fetch,
            chunk_size,
        )

    def __iter__(self) -> Iterator[Any]:
        message = "subclasses of BaseIterable must provide an __iter__() method"
        raise NotImplementedError(message)
82
83
class ModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row.

    Each instance is built from the main model's columns. It is then given
    any related objects filled in by the related populators, annotation
    values as plain attributes, and related objects the queryset already
    knows about.
    """

    def __iter__(self) -> Iterator[Model]:
        qs = self.queryset
        compiler = qs.sql_query.get_compiler()
        # Executing the query is what fills in compiler.select,
        # compiler.klass_info and compiler.annotation_col_map.
        results = compiler.execute_sql(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        select = compiler.select
        klass_info = compiler.klass_info
        annotation_col_map = compiler.annotation_col_map
        assert select is not None
        assert klass_info is not None

        model_cls = klass_info["model"]
        select_fields = klass_info["select_fields"]
        # The main model's columns form one contiguous run in every row.
        own_columns = slice(select_fields[0], select_fields[-1] + 1)
        init_list = [entry[0].target.attname for entry in select[own_columns]]
        related_populators = get_related_populators(klass_info, select)
        known_related = [
            (field, objs_by_id, operator.attrgetter(field.attname))
            for field, objs_by_id in qs._known_related_objects.items()
        ]

        for row in compiler.results_iter(results):
            obj = model_cls.from_db(init_list, row[own_columns])
            for populator in related_populators:
                populator.populate(row, obj)
            if annotation_col_map:
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            for field, objs_by_id, read_id in known_related:
                # Don't replace a related object already cached on obj,
                # e.g. one loaded by select_related().
                if field.is_cached(obj):
                    continue
                rel_id = read_id(obj)
                try:
                    rel_obj = objs_by_id[rel_id]
                except KeyError:
                    # May happen in qs1 | qs2 scenarios.
                    continue
                setattr(obj, field.name, rel_obj)

            yield obj
140
141
class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.

    Model attributes are mapped from result columns using
    RawQuerySet.resolve_model_init_order(). The remaining
    ``annotation_fields`` columns are set on each instance as extra
    attributes.
    """

    # Narrows the inherited attribute: this iterable always wraps a RawQuerySet.
    queryset: RawQuerySet

    def __iter__(self) -> Iterator[Model]:
        # Cache some things for performance reasons outside the loop.
        # RawQuery is not a Query subclass, so we directly get SQLCompiler
        from plain.models.sql.compiler import SQLCompiler

        query = self.queryset.sql_query
        # Within this method the compiler is only used for value converters
        # (get_converters/apply_converters); the SQL comes from the raw query.
        # NOTE(review): the meaning of the positional True is defined by
        # SQLCompiler's signature — confirm.
        compiler = SQLCompiler(query, db_connection, True)  # type: ignore[arg-type]
        # Iterating the RawQuery yields the result rows indexed below.
        query_iterator = iter(query)

        try:
            # model_init_names / model_init_pos: model attribute names and the
            # row positions holding their values. annotation_fields:
            # (attribute name, row position) pairs set on the instance
            # directly (presumably columns with no matching model field).
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            assert model_cls is not None
            # Instances can't be built without the primary key column.
            if "id" not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            # One entry per result column: the matching model field, or None
            # for columns without one (passed to get_converters as None).
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = compiler.get_converters(
                [
                    f.get_col(f.model.model_options.db_table) if f else None
                    for f in fields
                ]
            )
            if converters:
                # Wrap the row iterator so the field converters are applied.
                query_iterator = compiler.apply_converters(query_iterator, converters)
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(model_init_names, model_init_values)
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            # Because this is a generator, this also runs when the consumer
            # stops early (generator close) or an exception escapes above.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()
189
190
class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict for each row.

    Dict keys follow the row's column order: extra(select=...) columns,
    then the values() fields, then annotations.
    """

    def __iter__(self) -> Iterator[dict[str, Any]]:
        query = self.queryset.sql_query
        compiler = query.get_compiler()

        # extra(select=...) cols are always at the start of the row.
        names = (*query.extra_select, *query.values_select, *query.annotation_select)
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        for row in rows:
            yield {name: row[position] for position, name in enumerate(names)}
212
213
class ValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
    for each row.

    When values_list() named specific fields, each row is reordered to match
    the requested field order (followed by any annotations not already
    named). Note that __iter__ is deliberately not a generator: the query is
    prepared as soon as iteration starts.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()
        fetch_options = {
            "chunked_fetch": self.chunked_fetch,
            "chunk_size": self.chunk_size,
        }

        requested = queryset._fields
        if requested:
            # Order the database returns columns in; extra(select=...) cols
            # are always at the start of the row.
            db_order = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            wanted_order = [
                *requested,
                *(name for name in query.annotation_select if name not in requested),
            ]
            if wanted_order != db_order:
                position_of = {name: idx for idx, name in enumerate(db_order)}
                reorder = operator.itemgetter(
                    *(position_of[name] for name in wanted_order)
                )
                return map(reorder, compiler.results_iter(**fetch_options))

        return iter(compiler.results_iter(tuple_expected=True, **fetch_options))
253
254
class FlatValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=True) that yields single
    values: the first column of each result row.
    """

    def __iter__(self) -> Iterator[Any]:
        compiler = self.queryset.sql_query.get_compiler()
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        for row in rows:
            # Unwrap the single selected column.
            yield row[0]
268
269
270class QuerySet(Generic[T]):
271 """
272 Represent a lazy database lookup for a set of objects.
273
274 Usage:
275 MyModel.query.filter(name="test").all()
276
277 Custom QuerySets:
278 from typing import Self
279
280 class TaskQuerySet(QuerySet["Task"]):
281 def active(self) -> Self:
282 return self.filter(is_active=True)
283
284 class Task(Model):
285 is_active = BooleanField(default=True)
286 query = TaskQuerySet()
287
288 Task.query.active().filter(name="test") # Full type inference
289
290 Custom methods should return `Self` to preserve type through method chaining.
291 """
292
293 # Instance attributes (set in from_model())
294 model: type[T]
295 _query: Query
296 _result_cache: list[T] | None
297 _sticky_filter: bool
298 _for_write: bool
299 _prefetch_related_lookups: tuple[Any, ...]
300 _prefetch_done: bool
301 _known_related_objects: dict[Any, dict[Any, Any]]
302 _iterable_class: type[BaseIterable]
303 _fields: tuple[str, ...] | None
304 _defer_next_filter: bool
305 _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None
306
307 def __init__(self):
308 """Minimal init for descriptor mode. Use from_model() to create instances."""
309 pass
310
311 @classmethod
312 def from_model(cls, model: type[T], query: Query | None = None) -> Self:
313 """Create a QuerySet instance bound to a model."""
314 instance = cls()
315 instance.model = model
316 instance._query = query or Query(model)
317 instance._result_cache = None
318 instance._sticky_filter = False
319 instance._for_write = False
320 instance._prefetch_related_lookups = ()
321 instance._prefetch_done = False
322 instance._known_related_objects = {}
323 instance._iterable_class = ModelIterable
324 instance._fields = None
325 instance._defer_next_filter = False
326 instance._deferred_filter = None
327 return instance
328
329 @overload
330 def __get__(self, instance: None, owner: type[T]) -> Self: ...
331
332 @overload
333 def __get__(self, instance: Model, owner: type[T]) -> Never: ...
334
335 def __get__(self, instance: Any, owner: type[T]) -> Self:
336 """Descriptor protocol - return a new QuerySet bound to the model."""
337 if instance is not None:
338 raise AttributeError(
339 f"QuerySet is only accessible from the model class, not instances. "
340 f"Use {owner.__name__}.query instead."
341 )
342 return self.from_model(owner)
343
344 @property
345 def sql_query(self) -> Query:
346 if self._deferred_filter:
347 negate, args, kwargs = self._deferred_filter
348 self._filter_or_exclude_inplace(negate, args, kwargs)
349 self._deferred_filter = None
350 return self._query
351
352 @sql_query.setter
353 def sql_query(self, value: Query) -> None:
354 if value.values_select:
355 self._iterable_class = ValuesIterable
356 self._query = value
357
358 ########################
359 # PYTHON MAGIC METHODS #
360 ########################
361
362 def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
363 """Don't populate the QuerySet's cache."""
364 obj = self.__class__.from_model(self.model)
365 for k, v in self.__dict__.items():
366 if k == "_result_cache":
367 obj.__dict__[k] = None
368 else:
369 obj.__dict__[k] = copy.deepcopy(v, memo)
370 return obj
371
372 def __getstate__(self) -> dict[str, Any]:
373 # Force the cache to be fully populated.
374 self._fetch_all()
375 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
376
377 def __setstate__(self, state: dict[str, Any]) -> None:
378 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
379 if pickled_version:
380 if pickled_version != plain.runtime.__version__:
381 warnings.warn(
382 f"Pickled queryset instance's Plain version {pickled_version} does not "
383 f"match the current version {plain.runtime.__version__}.",
384 RuntimeWarning,
385 stacklevel=2,
386 )
387 else:
388 warnings.warn(
389 "Pickled queryset instance's Plain version is not specified.",
390 RuntimeWarning,
391 stacklevel=2,
392 )
393 self.__dict__.update(state)
394
395 def __repr__(self) -> str:
396 data = list(self[: REPR_OUTPUT_SIZE + 1])
397 if len(data) > REPR_OUTPUT_SIZE:
398 data[-1] = "...(remaining elements truncated)..."
399 return f"<{self.__class__.__name__} {data!r}>"
400
401 def __len__(self) -> int:
402 self._fetch_all()
403 assert self._result_cache is not None
404 return len(self._result_cache)
405
406 def __iter__(self) -> Iterator[T]:
407 """
408 The queryset iterator protocol uses three nested iterators in the
409 default case:
410 1. sql.compiler.execute_sql()
411 - Returns 100 rows at time (constants.GET_ITERATOR_CHUNK_SIZE)
412 using cursor.fetchmany(). This part is responsible for
413 doing some column masking, and returning the rows in chunks.
414 2. sql.compiler.results_iter()
415 - Returns one row at time. At this point the rows are still just
416 tuples. In some cases the return values are converted to
417 Python values at this location.
418 3. self.iterator()
419 - Responsible for turning the rows into model objects.
420 """
421 self._fetch_all()
422 assert self._result_cache is not None
423 return iter(self._result_cache)
424
425 def __bool__(self) -> bool:
426 self._fetch_all()
427 return bool(self._result_cache)
428
429 @overload
430 def __getitem__(self, k: int) -> T: ...
431
432 @overload
433 def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...
434
435 def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
436 """Retrieve an item or slice from the set of results."""
437 if not isinstance(k, int | slice):
438 raise TypeError(
439 f"QuerySet indices must be integers or slices, not {type(k).__name__}."
440 )
441 if (isinstance(k, int) and k < 0) or (
442 isinstance(k, slice)
443 and (
444 (k.start is not None and k.start < 0)
445 or (k.stop is not None and k.stop < 0)
446 )
447 ):
448 raise ValueError("Negative indexing is not supported.")
449
450 if self._result_cache is not None:
451 return self._result_cache[k]
452
453 if isinstance(k, slice):
454 qs = self._chain()
455 if k.start is not None:
456 start = int(k.start)
457 else:
458 start = None
459 if k.stop is not None:
460 stop = int(k.stop)
461 else:
462 stop = None
463 qs.sql_query.set_limits(start, stop)
464 return list(qs)[:: k.step] if k.step else qs
465
466 qs = self._chain()
467 qs.sql_query.set_limits(k, k + 1)
468 qs._fetch_all()
469 assert qs._result_cache is not None # _fetch_all guarantees this
470 return qs._result_cache[0]
471
472 def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
473 return cls
474
475 def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
476 self._merge_sanity_check(other)
477 if isinstance(other, EmptyQuerySet):
478 return other
479 if isinstance(self, EmptyQuerySet):
480 return self
481 combined = self._chain()
482 combined._merge_known_related_objects(other)
483 combined.sql_query.combine(other.sql_query, AND)
484 return combined
485
486 def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
487 self._merge_sanity_check(other)
488 if isinstance(self, EmptyQuerySet):
489 return other
490 if isinstance(other, EmptyQuerySet):
491 return self
492 query = (
493 self
494 if self.sql_query.can_filter()
495 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
496 )
497 combined = query._chain()
498 combined._merge_known_related_objects(other)
499 if not other.sql_query.can_filter():
500 other = other.model._model_meta.base_queryset.filter(
501 id__in=other.values("id")
502 )
503 combined.sql_query.combine(other.sql_query, OR)
504 return combined
505
506 def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
507 self._merge_sanity_check(other)
508 if isinstance(self, EmptyQuerySet):
509 return other
510 if isinstance(other, EmptyQuerySet):
511 return self
512 query = (
513 self
514 if self.sql_query.can_filter()
515 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
516 )
517 combined = query._chain()
518 combined._merge_known_related_objects(other)
519 if not other.sql_query.can_filter():
520 other = other.model._model_meta.base_queryset.filter(
521 id__in=other.values("id")
522 )
523 combined.sql_query.combine(other.sql_query, XOR)
524 return combined
525
526 ####################################
527 # METHODS THAT DO DATABASE QUERIES #
528 ####################################
529
530 def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
531 iterable = self._iterable_class(
532 self,
533 chunked_fetch=use_chunked_fetch,
534 chunk_size=chunk_size or 2000,
535 )
536 if not self._prefetch_related_lookups or chunk_size is None:
537 yield from iterable
538 return
539
540 iterator = iter(iterable)
541 while results := list(islice(iterator, chunk_size)):
542 prefetch_related_objects(results, *self._prefetch_related_lookups)
543 yield from results
544
545 def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
546 """
547 An iterator over the results from applying this QuerySet to the
548 database. chunk_size must be provided for QuerySets that prefetch
549 related objects. Otherwise, a default chunk_size of 2000 is supplied.
550 """
551 if chunk_size is None:
552 if self._prefetch_related_lookups:
553 raise ValueError(
554 "chunk_size must be provided when using QuerySet.iterator() after "
555 "prefetch_related()."
556 )
557 elif chunk_size <= 0:
558 raise ValueError("Chunk size must be strictly positive.")
559 # PostgreSQL always supports server-side cursors for chunked fetches
560 return self._iterator(use_chunked_fetch=True, chunk_size=chunk_size)
561
562 def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
563 """
564 Return a dictionary containing the calculations (aggregation)
565 over the current queryset.
566
567 If args is present the expression is passed as a kwarg using
568 the Aggregate object's default alias.
569 """
570 if self.sql_query.distinct_fields:
571 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
572 self._validate_values_are_expressions(
573 (*args, *kwargs.values()), method_name="aggregate"
574 )
575 for arg in args:
576 # The default_alias property raises TypeError if default_alias
577 # can't be set automatically or AttributeError if it isn't an
578 # attribute.
579 try:
580 arg.default_alias
581 except (AttributeError, TypeError):
582 raise TypeError("Complex aggregates require an alias")
583 kwargs[arg.default_alias] = arg
584
585 return self.sql_query.chain().get_aggregation(kwargs)
586
587 def count(self) -> int:
588 """
589 Perform a SELECT COUNT() and return the number of records as an
590 integer.
591
592 If the QuerySet is already fully cached, return the length of the
593 cached results set to avoid multiple SELECT COUNT(*) calls.
594 """
595 if self._result_cache is not None:
596 return len(self._result_cache)
597
598 return self.sql_query.get_count()
599
600 def get(self, *args: Any, **kwargs: Any) -> T:
601 """
602 Perform the query and return a single object matching the given
603 keyword arguments.
604 """
605 clone = self.filter(*args, **kwargs)
606 if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
607 clone = clone.order_by()
608 limit = MAX_GET_RESULTS
609 clone.sql_query.set_limits(high=limit)
610 num = len(clone)
611 if num == 1:
612 assert clone._result_cache is not None # len() fetches results
613 return clone._result_cache[0]
614 if not num:
615 raise self.model.DoesNotExist(
616 f"{self.model.model_options.object_name} matching query does not exist."
617 )
618 raise self.model.MultipleObjectsReturned(
619 "get() returned more than one {} -- it returned {}!".format(
620 self.model.model_options.object_name,
621 num if not limit or num < limit else "more than %s" % (limit - 1),
622 )
623 )
624
625 def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
626 """
627 Perform the query and return a single object matching the given
628 keyword arguments, or None if no object is found.
629 """
630 try:
631 return self.get(*args, **kwargs)
632 except self.model.DoesNotExist:
633 return None
634
635 def create(self, **kwargs: Any) -> T:
636 """
637 Create a new object with the given kwargs, saving it to the database
638 and returning the created object.
639 """
640 obj = self.model(**kwargs)
641 self._for_write = True
642 obj.save(force_insert=True)
643 return obj
644
645 def _prepare_for_bulk_create(self, objs: list[T]) -> None:
646 id_field = self.model._model_meta.get_forward_field("id")
647 for obj in objs:
648 if obj.id is None:
649 # Populate new primary key values.
650 obj.id = id_field.get_id_value_on_save(obj)
651 obj._prepare_related_fields_for_save(operation_name="bulk_create")
652
653 def _check_bulk_create_options(
654 self,
655 update_conflicts: bool,
656 update_fields: list[Field] | None,
657 unique_fields: list[Field] | None,
658 ) -> OnConflict | None:
659 if update_conflicts:
660 if not update_fields:
661 raise ValueError(
662 "Fields that will be updated when a row insertion fails "
663 "on conflicts must be provided."
664 )
665 if not unique_fields:
666 raise ValueError(
667 "Unique fields that can trigger the upsert must be provided."
668 )
669 # Updating primary keys and non-concrete fields is forbidden.
670 from plain.models.fields.related import ManyToManyField
671
672 if any(
673 not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
674 ):
675 raise ValueError(
676 "bulk_create() can only be used with concrete fields in "
677 "update_fields."
678 )
679 if any(f.primary_key for f in update_fields):
680 raise ValueError(
681 "bulk_create() cannot be used with primary keys in update_fields."
682 )
683 if unique_fields:
684 from plain.models.fields.related import ManyToManyField
685
686 if any(
687 not f.concrete or isinstance(f, ManyToManyField)
688 for f in unique_fields
689 ):
690 raise ValueError(
691 "bulk_create() can only be used with concrete fields "
692 "in unique_fields."
693 )
694 return OnConflict.UPDATE
695 return None
696
    def bulk_create(
        self,
        objs: Sequence[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances. Primary keys are set on the objects
        via the PostgreSQL RETURNING clause. Multi-table models are not supported.

        Args:
            objs: Instances to insert. They are copied into a new list, which
                is what is returned.
            batch_size: Optional batch size passed through to
                _batched_insert(); must be positive when given.
            update_conflicts: If true, insert as an upsert
                (OnConflict.UPDATE) instead of a plain insert.
            update_fields: Names of fields to update on conflict; required
                when update_conflicts is true.
            unique_fields: Names of fields whose conflict triggers the
                update; required when update_conflicts is true.

        Returns:
            The list of objects, with database-returned column values set on
            them and ``_state.adding`` set to False. An empty input returns an
            empty list without touching the database.

        Raises:
            ValueError: for a non-positive batch_size, or for invalid upsert
                options (see _check_bulk_create_options()).
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        objs = list(objs)
        if not objs:
            return objs
        meta = self.model._model_meta
        # Resolve the field names to Field objects.
        unique_fields_objs: list[Field] | None = None
        update_fields_objs: list[Field] | None = None
        if unique_fields:
            unique_fields_objs = [
                meta.get_forward_field(name) for name in unique_fields
            ]
        if update_fields:
            update_fields_objs = [
                meta.get_forward_field(name) for name in update_fields
            ]
        # None for a plain insert, OnConflict.UPDATE for an upsert.
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields_objs,
            unique_fields_objs,
        )
        self._for_write = True
        fields = meta.concrete_fields
        # Ask the id field for values where missing; some objects may still
        # have id None afterwards and are handled in the second branch below.
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            # NOTE(review): per the variable names, partition() is expected to
            # return (items where the predicate is False, items where it is
            # True), i.e. (objects with an id, objects without one) — confirm
            # in plain.utils.functional.
            objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
            if objs_with_id:
                # The ids are already known, so every concrete field (id
                # included) is inserted.
                returned_columns = self._batched_insert(
                    objs_with_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                id_field = meta.get_forward_field("id")
                # Copy returned values back, except the id we supplied ourselves.
                for obj_with_id, results in zip(objs_with_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        if field != id_field:
                            setattr(obj_with_id, field.attname, result)
                for obj_with_id in objs_with_id:
                    obj_with_id._state.adding = False
            if objs_without_id:
                # Leave the primary key column out of the INSERT; its value
                # comes back with the returned columns.
                fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
                returned_columns = self._batched_insert(
                    objs_without_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                if on_conflict is None:
                    # A plain insert is expected to return one row per object.
                    assert len(returned_columns) == len(objs_without_id)
                # Copy every returned value (including the new id) back.
                for obj_without_id, results in zip(objs_without_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        setattr(obj_without_id, field.attname, result)
                    obj_without_id._state.adding = False

        return objs
771
772 def bulk_update(
773 self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
774 ) -> int:
775 """
776 Update the given fields in each of the given objects in the database.
777 """
778 if batch_size is not None and batch_size <= 0:
779 raise ValueError("Batch size must be a positive integer.")
780 if not fields:
781 raise ValueError("Field names must be given to bulk_update().")
782 objs_tuple = tuple(objs)
783 if any(obj.id is None for obj in objs_tuple):
784 raise ValueError("All bulk_update() objects must have a primary key set.")
785 fields_list = [
786 self.model._model_meta.get_forward_field(name) for name in fields
787 ]
788 from plain.models.fields.related import ManyToManyField
789
790 if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
791 raise ValueError("bulk_update() can only be used with concrete fields.")
792 if any(f.primary_key for f in fields_list):
793 raise ValueError("bulk_update() cannot be used with primary key fields.")
794 if not objs_tuple:
795 return 0
796 for obj in objs_tuple:
797 obj._prepare_related_fields_for_save(
798 operation_name="bulk_update", fields=fields_list
799 )
800 # PK is used twice in the resulting update query, once in the filter
801 # and once in the WHEN. Each field will also have one CAST.
802 self._for_write = True
803 max_batch_size = len(objs_tuple)
804 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
805 batches = (
806 objs_tuple[i : i + batch_size]
807 for i in range(0, len(objs_tuple), batch_size)
808 )
809 updates = []
810 for batch_objs in batches:
811 update_kwargs = {}
812 for field in fields_list:
813 when_statements = []
814 for obj in batch_objs:
815 attr = getattr(obj, field.attname)
816 if not isinstance(attr, ResolvableExpression):
817 attr = Value(attr, output_field=field)
818 when_statements.append(When(id=obj.id, then=attr))
819 case_statement = Case(*when_statements, output_field=field)
820 # PostgreSQL requires casted CASE in updates
821 case_statement = Cast(case_statement, output_field=field)
822 update_kwargs[field.attname] = case_statement
823 updates.append(([obj.id for obj in batch_objs], update_kwargs))
824 rows_updated = 0
825 queryset = self._chain()
826 with transaction.atomic(savepoint=False):
827 for ids, update_kwargs in updates:
828 rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
829 return rows_updated
830
    def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.

        When creating, the non-lookup kwargs plus ``defaults`` (validated by
        _extract_model_params()) are passed to create(). If create() fails
        with IntegrityError or ValidationError, the lookup is retried once.
        If it still finds nothing, the original create() error is re-raised.

        Raises:
            FieldError: if kwargs/defaults name something that is not a field
                or a settable property.
            IntegrityError / ValidationError: if creation fails and the retry
                finds no existing object.
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    # Resolve any callable values (see resolve_callables)
                    # before creating.
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the database.
                # The main thing we're concerned about is uniqueness failures,
                # but ValidationError could include other things too.
                # In all cases though it should be fine to try the get() again
                # and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:
                    pass
                # Bare raise: re-raises the IntegrityError/ValidationError from
                # create(), not the DoesNotExist that was just handled.
                raise
864
    def update_or_create(
        self,
        defaults: dict[str, Any] | None = None,
        create_defaults: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, updating one with defaults
        if it exists, otherwise create a new one. Optionally, an object can
        be created with different values than defaults by using
        create_defaults.
        Return a tuple (object, created), where created is a boolean
        specifying whether an object was created.

        The lookup runs through select_for_update().get_or_create() inside a
        transaction. For an existing object, the resolved ``defaults`` are
        assigned and it is saved with update_fields limited to those names,
        plus every local concrete non-pk field whose class overrides
        Field.pre_save(). If any default key is not a non-pk concrete field
        name, a full save() is done instead.
        """
        # Without create_defaults, the same defaults serve both the create and
        # the update path.
        if create_defaults is None:
            update_defaults = create_defaults = defaults or {}
        else:
            update_defaults = defaults or {}
        self._for_write = True
        with transaction.atomic():
            # Lock the row so that a concurrent update is blocked until
            # update_or_create() has performed its save.
            obj, created = self.select_for_update().get_or_create(
                create_defaults, **kwargs
            )
            if created:
                return obj, created
            # Apply the update values, with callables resolved via
            # resolve_callables().
            for k, v in resolve_callables(update_defaults):
                setattr(obj, k, v)

            update_fields = set(update_defaults)
            concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
            # update_fields does not support non-concrete fields.
            if concrete_field_names.issuperset(update_fields):
                # Add fields which are set on pre_save(), e.g. auto_now fields.
                # This is to maintain backward compatibility as these fields
                # are not updated unless explicitly specified in the
                # update_fields list.
                for field in self.model._model_meta.local_concrete_fields:
                    if not (
                        field.primary_key or field.__class__.pre_save is Field.pre_save
                    ):
                        update_fields.add(field.name)
                        if field.name != field.attname:
                            update_fields.add(field.attname)
                obj.save(update_fields=update_fields)
            else:
                # At least one default is not a non-pk concrete field, so
                # save all fields.
                obj.save()
        return obj, False
914
915 def _extract_model_params(
916 self, defaults: dict[str, Any] | None, **kwargs: Any
917 ) -> dict[str, Any]:
918 """
919 Prepare `params` for creating a model instance based on the given
920 kwargs; for use by get_or_create().
921 """
922 defaults = defaults or {}
923 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
924 params.update(defaults)
925 property_names = self.model._model_meta._property_names
926 invalid_params = []
927 for param in params:
928 try:
929 self.model._model_meta.get_field(param)
930 except FieldDoesNotExist:
931 # It's okay to use a model's property if it has a setter.
932 if not (param in property_names and getattr(self.model, param).fset):
933 invalid_params.append(param)
934 if invalid_params:
935 raise FieldError(
936 "Invalid field name(s) for model {}: '{}'.".format(
937 self.model.model_options.object_name,
938 "', '".join(sorted(invalid_params)),
939 )
940 )
941 return params
942
943 def first(self) -> T | None:
944 """Return the first object of a query or None if no match is found."""
945 for obj in self[:1]:
946 return obj
947 return None
948
949 def last(self) -> T | None:
950 """Return the last object of a query or None if no match is found."""
951 queryset = self.reverse()
952 for obj in queryset[:1]:
953 return obj
954 return None
955
956 def delete(self) -> tuple[int, dict[str, int]]:
957 """Delete the records in the current QuerySet."""
958 if self.sql_query.is_sliced:
959 raise TypeError("Cannot use 'limit' or 'offset' with delete().")
960 if self.sql_query.distinct or self.sql_query.distinct_fields:
961 raise TypeError("Cannot call delete() after .distinct().")
962 if self._fields is not None:
963 raise TypeError("Cannot call delete() after .values() or .values_list()")
964
965 del_query = self._chain()
966
967 # The delete is actually 2 queries - one to find related objects,
968 # and one to delete. Make sure that the discovery of related
969 # objects is performed on the same database as the deletion.
970 del_query._for_write = True
971
972 # Disable non-supported fields.
973 del_query.sql_query.select_for_update = False
974 del_query.sql_query.select_related = False
975 del_query.sql_query.clear_ordering(force=True)
976
977 from plain.models.deletion import Collector
978
979 collector = Collector(origin=self)
980 collector.collect(del_query)
981 deleted, _rows_count = collector.delete()
982
983 # Clear the result cache, in case this QuerySet gets reused.
984 self._result_cache = None
985 return deleted, _rows_count
986
987 def _raw_delete(self) -> int:
988 """
989 Delete objects found from the given queryset in single direct SQL
990 query. No signals are sent and there is no protection for cascades.
991 """
992 query = self.sql_query.clone()
993 query.__class__ = DeleteQuery
994 cursor = query.get_compiler().execute_sql(CURSOR)
995 if cursor:
996 with cursor:
997 return cursor.rowcount
998 return 0
999
    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.

        Args:
            **kwargs: Field names mapped to new values; passed to
                ``UpdateQuery.add_update_values()``.

        Returns:
            The value returned by the UPDATE compiler's ``execute_sql(CURSOR)``
            (presumably the number of affected rows — TODO confirm).

        Raises:
            TypeError: If the QuerySet has been sliced.
            FieldError: If the ordering refers to an aggregate annotation.
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        # Marks this QuerySet itself (not a clone) as used for writing.
        self._for_write = True
        query = self.sql_query.chain(UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        # Annotations are cleared further down, so any order_by entry naming
        # an annotation alias is swapped for the annotation expression itself.
        # A leading "-" is turned into .desc() on that expression.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            if isinstance(alias, str) and alias.startswith("-"):
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                # Aggregates cannot be inlined into an UPDATE's ordering.
                if getattr(annotation, "contains_aggregate", False):
                    raise FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        # NOTE(review): presumably marks the surrounding transaction for
        # rollback if execute_sql raises — see plain.models.transaction.
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        self._result_cache = None
        return rows
1037
1038 def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
1039 """
1040 A version of update() that accepts field objects instead of field names.
1041 Used primarily for model saving and not intended for use by general
1042 code (it requires too much poking around at model internals to be
1043 useful at that level).
1044 """
1045 if self.sql_query.is_sliced:
1046 raise TypeError("Cannot update a query once a slice has been taken.")
1047 query = self.sql_query.chain(UpdateQuery)
1048 query.add_update_fields(values)
1049 # Clear any annotations so that they won't be present in subqueries.
1050 query.annotations = {}
1051 self._result_cache = None
1052 return query.get_compiler().execute_sql(CURSOR)
1053
1054 def exists(self) -> bool:
1055 """
1056 Return True if the QuerySet would have any results, False otherwise.
1057 """
1058 if self._result_cache is None:
1059 return self.sql_query.has_results()
1060 return bool(self._result_cache)
1061
1062 def _prefetch_related_objects(self) -> None:
1063 # This method can only be called once the result cache has been filled.
1064 assert self._result_cache is not None
1065 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1066 self._prefetch_done = True
1067
1068 def explain(self, *, format: str | None = None, **options: Any) -> str:
1069 """
1070 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1071 returns the results.
1072 """
1073 return self.sql_query.explain(format=format, **options)
1074
1075 ##################################################
1076 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1077 ##################################################
1078
1079 def raw(
1080 self,
1081 raw_query: str,
1082 params: Sequence[Any] = (),
1083 translations: dict[str, str] | None = None,
1084 ) -> RawQuerySet:
1085 qs = RawQuerySet(
1086 raw_query,
1087 model=self.model,
1088 params=tuple(params),
1089 translations=translations,
1090 )
1091 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1092 return qs
1093
1094 def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1095 clone = self._chain()
1096 if expressions:
1097 clone = clone.annotate(**expressions)
1098 clone._fields = fields
1099 clone.sql_query.set_values(list(fields))
1100 return clone
1101
1102 def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1103 fields += tuple(expressions)
1104 clone = self._values(*fields, **expressions)
1105 clone._iterable_class = ValuesIterable
1106 return clone
1107
1108 def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
1109 if flat and len(fields) > 1:
1110 raise TypeError(
1111 "'flat' is not valid when values_list is called with more than one "
1112 "field."
1113 )
1114
1115 field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
1116 _fields = []
1117 expressions = {}
1118 counter = 1
1119 for field in fields:
1120 if isinstance(field, ResolvableExpression):
1121 field_id_prefix = getattr(
1122 field, "default_alias", field.__class__.__name__.lower()
1123 )
1124 while True:
1125 field_id = field_id_prefix + str(counter)
1126 counter += 1
1127 if field_id not in field_names:
1128 break
1129 expressions[field_id] = field
1130 _fields.append(field_id)
1131 else:
1132 _fields.append(field)
1133
1134 clone = self._values(*_fields, **expressions)
1135 clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
1136 return clone
1137
1138 def none(self) -> QuerySet[T]:
1139 """Return an empty QuerySet."""
1140 clone = self._chain()
1141 clone.sql_query.set_empty()
1142 return clone
1143
1144 ##################################################################
1145 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1146 ##################################################################
1147
1148 def all(self) -> Self:
1149 """
1150 Return a new QuerySet that is a copy of the current one. This allows a
1151 QuerySet to proxy for a model queryset in some cases.
1152 """
1153 obj = self._chain()
1154 # Preserve cache since all() doesn't modify the query.
1155 # This is important for prefetch_related() to work correctly.
1156 obj._result_cache = self._result_cache
1157 obj._prefetch_done = self._prefetch_done
1158 return obj
1159
1160 def filter(self, *args: Any, **kwargs: Any) -> Self:
1161 """
1162 Return a new QuerySet instance with the args ANDed to the existing
1163 set.
1164 """
1165 return self._filter_or_exclude(False, args, kwargs)
1166
1167 def exclude(self, *args: Any, **kwargs: Any) -> Self:
1168 """
1169 Return a new QuerySet instance with NOT (args) ANDed to the existing
1170 set.
1171 """
1172 return self._filter_or_exclude(True, args, kwargs)
1173
1174 def _filter_or_exclude(
1175 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1176 ) -> Self:
1177 if (args or kwargs) and self.sql_query.is_sliced:
1178 raise TypeError("Cannot filter a query once a slice has been taken.")
1179 clone = self._chain()
1180 if self._defer_next_filter:
1181 self._defer_next_filter = False
1182 clone._deferred_filter = negate, args, kwargs
1183 else:
1184 clone._filter_or_exclude_inplace(negate, args, kwargs)
1185 return clone
1186
1187 def _filter_or_exclude_inplace(
1188 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1189 ) -> None:
1190 if negate:
1191 self._query.add_q(~Q(*args, **kwargs))
1192 else:
1193 self._query.add_q(Q(*args, **kwargs))
1194
1195 def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
1196 """
1197 Return a new QuerySet instance with filter_obj added to the filters.
1198
1199 filter_obj can be a Q object or a dictionary of keyword lookup
1200 arguments.
1201
1202 This exists to support framework features such as 'limit_choices_to',
1203 and usually it will be more natural to use other methods.
1204 """
1205 if isinstance(filter_obj, Q):
1206 clone = self._chain()
1207 clone.sql_query.add_q(filter_obj)
1208 return clone
1209 else:
1210 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1211
1212 def select_for_update(
1213 self,
1214 nowait: bool = False,
1215 skip_locked: bool = False,
1216 of: tuple[str, ...] = (),
1217 no_key: bool = False,
1218 ) -> QuerySet[T]:
1219 """
1220 Return a new QuerySet instance that will select objects with a
1221 FOR UPDATE lock.
1222 """
1223 if nowait and skip_locked:
1224 raise ValueError("The nowait option cannot be used with skip_locked.")
1225 obj = self._chain()
1226 obj._for_write = True
1227 obj.sql_query.select_for_update = True
1228 obj.sql_query.select_for_update_nowait = nowait
1229 obj.sql_query.select_for_update_skip_locked = skip_locked
1230 obj.sql_query.select_for_update_of = of
1231 obj.sql_query.select_for_no_key_update = no_key
1232 return obj
1233
1234 def select_related(self, *fields: str | None) -> Self:
1235 """
1236 Return a new QuerySet instance that will select related objects.
1237
1238 If fields are specified, they must be ForeignKeyField fields and only those
1239 related objects are included in the selection.
1240
1241 If select_related(None) is called, clear the list.
1242 """
1243 if self._fields is not None:
1244 raise TypeError(
1245 "Cannot call select_related() after .values() or .values_list()"
1246 )
1247
1248 obj = self._chain()
1249 if fields == (None,):
1250 obj.sql_query.select_related = False
1251 elif fields:
1252 obj.sql_query.add_select_related(list(fields)) # type: ignore[arg-type]
1253 else:
1254 obj.sql_query.select_related = True
1255 return obj
1256
1257 def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
1258 """
1259 Return a new QuerySet instance that will prefetch the specified
1260 Many-To-One and Many-To-Many related objects when the QuerySet is
1261 evaluated.
1262
1263 When prefetch_related() is called more than once, append to the list of
1264 prefetch lookups. If prefetch_related(None) is called, clear the list.
1265 """
1266 clone = self._chain()
1267 if lookups == (None,):
1268 clone._prefetch_related_lookups = ()
1269 else:
1270 for lookup in lookups:
1271 lookup_str: str
1272 if isinstance(lookup, Prefetch):
1273 lookup_str = lookup.prefetch_to
1274 else:
1275 assert isinstance(lookup, str)
1276 lookup_str = lookup
1277 lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
1278 if lookup_str in self.sql_query._filtered_relations:
1279 raise ValueError(
1280 "prefetch_related() is not supported with FilteredRelation."
1281 )
1282 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1283 return clone
1284
1285 def annotate(self, *args: Any, **kwargs: Any) -> Self:
1286 """
1287 Return a query set in which the returned objects have been annotated
1288 with extra data or aggregations.
1289 """
1290 return self._annotate(args, kwargs, select=True)
1291
1292 def alias(self, *args: Any, **kwargs: Any) -> Self:
1293 """
1294 Return a query set with added aliases for extra data or aggregations.
1295 """
1296 return self._annotate(args, kwargs, select=False)
1297
    def _annotate(
        self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
    ) -> Self:
        """
        Shared implementation of annotate() and alias().

        Positional expressions are named by their ``default_alias``; keyword
        expressions by their keyword. FilteredRelation values are registered
        as filtered relations; everything else is added as an annotation,
        selected only when ``select`` is true.

        Raises:
            TypeError: If a value is not an expression, or a positional
                expression has no usable default alias.
            ValueError: If a name clashes with another annotation or with a
                field name.
        """
        self._validate_values_are_expressions(
            args + tuple(kwargs.values()), method_name="annotate"
        )
        annotations = {}
        for arg in args:
            # The default_alias property may raise a TypeError.
            try:
                if arg.default_alias in kwargs:
                    raise ValueError(
                        f"The named annotation '{arg.default_alias}' conflicts with the "
                        "default name for another annotation."
                    )
            except TypeError:
                raise TypeError("Complex annotations require an alias")
            annotations[arg.default_alias] = arg
        # Keyword annotations are added after (and may override) positional ones.
        annotations.update(kwargs)

        clone = self._chain()
        # After values()/values_list(), only the selected names can clash.
        # Otherwise, check every model field's name and (if present) attname.
        names = self._fields
        if names is None:
            names = set(
                chain.from_iterable(
                    (field.name, field.attname)
                    if hasattr(field, "attname")
                    else (field.name,)
                    for field in self.model._model_meta.get_fields()
                )
            )

        for alias, annotation in annotations.items():
            if alias in names:
                raise ValueError(
                    f"The annotation '{alias}' conflicts with a field on the model."
                )
            if isinstance(annotation, FilteredRelation):
                clone.sql_query.add_filtered_relation(annotation, alias)
            else:
                clone.sql_query.add_annotation(
                    annotation,
                    alias,
                    select=select,
                )
        # If any annotation added here is an aggregate, enable grouping:
        # group_by=True when no values() fields are set, otherwise
        # set_group_by(). One aggregate is enough, hence the break.
        for alias, annotation in clone.sql_query.annotations.items():
            if alias in annotations and annotation.contains_aggregate:
                if clone._fields is None:
                    clone.sql_query.group_by = True
                else:
                    clone.sql_query.set_group_by()
                break

        return clone
1352
1353 def order_by(self, *field_names: str) -> Self:
1354 """Return a new QuerySet instance with the ordering changed."""
1355 if self.sql_query.is_sliced:
1356 raise TypeError("Cannot reorder a query once a slice has been taken.")
1357 obj = self._chain()
1358 obj.sql_query.clear_ordering(force=True, clear_default=False)
1359 obj.sql_query.add_ordering(*field_names)
1360 return obj
1361
1362 def distinct(self, *field_names: str) -> Self:
1363 """
1364 Return a new QuerySet instance that will select only distinct results.
1365 """
1366 if self.sql_query.is_sliced:
1367 raise TypeError(
1368 "Cannot create distinct fields once a slice has been taken."
1369 )
1370 obj = self._chain()
1371 obj.sql_query.add_distinct_fields(*field_names)
1372 return obj
1373
1374 def extra(
1375 self,
1376 select: dict[str, str] | None = None,
1377 where: list[str] | None = None,
1378 params: list[Any] | None = None,
1379 tables: list[str] | None = None,
1380 order_by: list[str] | None = None,
1381 select_params: list[Any] | None = None,
1382 ) -> QuerySet[T]:
1383 """Add extra SQL fragments to the query."""
1384 if self.sql_query.is_sliced:
1385 raise TypeError("Cannot change a query once a slice has been taken.")
1386 clone = self._chain()
1387 clone.sql_query.add_extra(
1388 select or {},
1389 select_params,
1390 where or [],
1391 params or [],
1392 tables or [],
1393 tuple(order_by) if order_by else (),
1394 )
1395 return clone
1396
1397 def reverse(self) -> QuerySet[T]:
1398 """Reverse the ordering of the QuerySet."""
1399 if self.sql_query.is_sliced:
1400 raise TypeError("Cannot reverse a query once a slice has been taken.")
1401 clone = self._chain()
1402 clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
1403 return clone
1404
1405 def defer(self, *fields: str | None) -> QuerySet[T]:
1406 """
1407 Defer the loading of data for certain fields until they are accessed.
1408 Add the set of deferred fields to any existing set of deferred fields.
1409 The only exception to this is if None is passed in as the only
1410 parameter, in which case removal all deferrals.
1411 """
1412 if self._fields is not None:
1413 raise TypeError("Cannot call defer() after .values() or .values_list()")
1414 clone = self._chain()
1415 if fields == (None,):
1416 clone.sql_query.clear_deferred_loading()
1417 else:
1418 clone.sql_query.add_deferred_loading(frozenset(fields))
1419 return clone
1420
1421 def only(self, *fields: str) -> QuerySet[T]:
1422 """
1423 Essentially, the opposite of defer(). Only the fields passed into this
1424 method and that are not already specified as deferred are loaded
1425 immediately when the queryset is evaluated.
1426 """
1427 if self._fields is not None:
1428 raise TypeError("Cannot call only() after .values() or .values_list()")
1429 if fields == (None,):
1430 # Can only pass None to defer(), not only(), as the rest option.
1431 # That won't stop people trying to do this, so let's be explicit.
1432 raise TypeError("Cannot pass None as an argument to only().")
1433 for field in fields:
1434 field = field.split(LOOKUP_SEP, 1)[0]
1435 if field in self.sql_query._filtered_relations:
1436 raise ValueError("only() is not supported with FilteredRelation.")
1437 clone = self._chain()
1438 clone.sql_query.add_immediate_loading(set(fields))
1439 return clone
1440
1441 ###################################
1442 # PUBLIC INTROSPECTION ATTRIBUTES #
1443 ###################################
1444
1445 @property
1446 def ordered(self) -> bool:
1447 """
1448 Return True if the QuerySet is ordered -- i.e. has an order_by()
1449 clause or a default ordering on the model (or is empty).
1450 """
1451 if isinstance(self, EmptyQuerySet):
1452 return True
1453 if self.sql_query.extra_order_by or self.sql_query.order_by:
1454 return True
1455 elif (
1456 self.sql_query.default_ordering
1457 and self.sql_query.model
1458 and self.sql_query.model._model_meta.ordering # type: ignore[arg-type]
1459 and
1460 # A default ordering doesn't affect GROUP BY queries.
1461 not self.sql_query.group_by
1462 ):
1463 return True
1464 else:
1465 return False
1466
1467 ###################
1468 # PRIVATE METHODS #
1469 ###################
1470
1471 def _insert(
1472 self,
1473 objs: list[T],
1474 fields: list[Field],
1475 returning_fields: list[Field] | None = None,
1476 raw: bool = False,
1477 on_conflict: OnConflict | None = None,
1478 update_fields: list[Field] | None = None,
1479 unique_fields: list[Field] | None = None,
1480 ) -> list[tuple[Any, ...]] | None:
1481 """
1482 Insert a new record for the given model. This provides an interface to
1483 the InsertQuery class and is how Model.save() is implemented.
1484 """
1485 self._for_write = True
1486 query = InsertQuery(
1487 self.model,
1488 on_conflict=on_conflict.value if on_conflict else None,
1489 update_fields=update_fields,
1490 unique_fields=unique_fields,
1491 )
1492 query.insert_values(fields, objs, raw=raw)
1493 # InsertQuery returns SQLInsertCompiler which has different execute_sql signature
1494 return query.get_compiler().execute_sql(returning_fields)
1495
1496 def _batched_insert(
1497 self,
1498 objs: list[T],
1499 fields: list[Field],
1500 batch_size: int | None,
1501 on_conflict: OnConflict | None = None,
1502 update_fields: list[Field] | None = None,
1503 unique_fields: list[Field] | None = None,
1504 ) -> list[tuple[Any, ...]]:
1505 """
1506 Helper method for bulk_create() to insert objs one batch at a time.
1507 """
1508 max_batch_size = max(len(objs), 1)
1509 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1510 inserted_rows = []
1511 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1512 if on_conflict is None:
1513 inserted_rows.extend(
1514 self._insert( # type: ignore[arg-type]
1515 item,
1516 fields=fields,
1517 returning_fields=self.model._model_meta.db_returning_fields,
1518 )
1519 )
1520 else:
1521 self._insert(
1522 item,
1523 fields=fields,
1524 on_conflict=on_conflict,
1525 update_fields=update_fields,
1526 unique_fields=unique_fields,
1527 )
1528 return inserted_rows
1529
1530 def _chain(self) -> Self:
1531 """
1532 Return a copy of the current QuerySet that's ready for another
1533 operation.
1534 """
1535 obj = self._clone()
1536 if obj._sticky_filter:
1537 obj.sql_query.filter_is_sticky = True
1538 obj._sticky_filter = False
1539 return obj
1540
1541 def _clone(self) -> Self:
1542 """
1543 Return a copy of the current QuerySet. A lightweight alternative
1544 to deepcopy().
1545 """
1546 c = self.__class__.from_model(
1547 model=self.model,
1548 query=self.sql_query.chain(),
1549 )
1550 c._sticky_filter = self._sticky_filter
1551 c._for_write = self._for_write
1552 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1553 c._known_related_objects = self._known_related_objects
1554 c._iterable_class = self._iterable_class
1555 c._fields = self._fields
1556 return c
1557
1558 def _fetch_all(self) -> None:
1559 if self._result_cache is None:
1560 self._result_cache = list(self._iterable_class(self))
1561 if self._prefetch_related_lookups and not self._prefetch_done:
1562 self._prefetch_related_objects()
1563
1564 def _next_is_sticky(self) -> QuerySet[T]:
1565 """
1566 Indicate that the next filter call and the one following that should
1567 be treated as a single filter. This is only important when it comes to
1568 determining when to reuse tables for many-to-many filters. Required so
1569 that we can filter naturally on the results of related managers.
1570
1571 This doesn't return a clone of the current QuerySet (it returns
1572 "self"). The method is only used internally and should be immediately
1573 followed by a filter() that does create a clone.
1574 """
1575 self._sticky_filter = True
1576 return self
1577
1578 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1579 """Check that two QuerySet classes may be merged."""
1580 if self._fields is not None and (
1581 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1582 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1583 or set(self.sql_query.annotation_select)
1584 != set(other.sql_query.annotation_select)
1585 ):
1586 raise TypeError(
1587 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1588 )
1589
1590 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1591 """
1592 Keep track of all known related objects from either QuerySet instance.
1593 """
1594 for field, objects in other._known_related_objects.items():
1595 self._known_related_objects.setdefault(field, {}).update(objects)
1596
1597 def resolve_expression(self, *args: Any, **kwargs: Any) -> Query:
1598 if self._fields and len(self._fields) > 1:
1599 # values() queryset can only be used as nested queries
1600 # if they are set up to select only a single field.
1601 raise TypeError("Cannot use multi-field values as a filter value.")
1602 query = self.sql_query.resolve_expression(*args, **kwargs)
1603 return query
1604
1605 def _has_filters(self) -> bool:
1606 """
1607 Check if this QuerySet has any filtering going on. This isn't
1608 equivalent with checking if all objects are present in results, for
1609 example, qs[1:]._has_filters() -> False.
1610 """
1611 return self.sql_query.has_filters()
1612
1613 @staticmethod
1614 def _validate_values_are_expressions(
1615 values: tuple[Any, ...], method_name: str
1616 ) -> None:
1617 invalid_args = sorted(
1618 str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
1619 )
1620 if invalid_args:
1621 raise TypeError(
1622 "QuerySet.{}() received non-expression(s): {}.".format(
1623 method_name,
1624 ", ".join(invalid_args),
1625 )
1626 )
1627
1628
class InstanceCheckMeta(type):
    """Metaclass whose isinstance() check accepts QuerySets with an empty query."""

    def __instancecheck__(cls, instance: object) -> bool:
        if not isinstance(instance, QuerySet):
            return False
        return instance.sql_query.is_empty()
1632
1633
class EmptyQuerySet(metaclass=InstanceCheckMeta):
    """
    Marker class for checking whether a QuerySet's query is empty, such as
    one returned by ``.none()``:

        isinstance(qs.none(), EmptyQuerySet) -> True

    It is only meant for isinstance() checks and cannot be instantiated.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        raise TypeError("EmptyQuerySet can't be instantiated")
1642
1643
class RawQuerySet:
    """
    Lazily evaluated results of a raw SQL query, turned into model instances
    by ``RawModelIterable`` and cached on first use.

    prefetch_related() lookups are applied after the rows are fetched.
    """

    def __init__(
        self,
        raw_query: str,
        model: type[Model] | None = None,
        query: RawQuery | None = None,
        params: tuple[Any, ...] = (),
        translations: dict[str, str] | None = None,
    ):
        self.raw_query = raw_query
        self.model = model
        # Reuse a supplied RawQuery (as _clone() does); otherwise build one.
        self.sql_query = query or RawQuery(sql=raw_query, params=params)
        self.params = params
        # Maps column names in the results to the names they should be
        # treated as (see ``columns``).
        self.translations = translations or {}
        self._result_cache: list[Model] | None = None
        self._prefetch_related_lookups: tuple[Any, ...] = ()
        self._prefetch_done = False

    def resolve_model_init_order(
        self,
    ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
        """
        Work out how result columns map onto model construction.

        Returns a triple of:
          * the ``attname`` of every model field whose column is in the results,
          * the position of each such column within ``self.columns``,
          * ``(column, position)`` pairs for columns that match no model field.
        """
        model = self.model
        assert model is not None
        columns = self.columns
        present = [
            field for field in model._model_meta.fields if field.column in columns
        ]
        extras = [
            (name, position)
            for position, name in enumerate(columns)
            if name not in self.model_fields
        ]
        positions = [columns.index(field.column) for field in present]
        attnames = [field.attname for field in present]
        return attnames, positions, extras

    def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
        """Same contract as QuerySet.prefetch_related(): append lookups, or clear with None."""
        clone = self._clone()
        if lookups == (None,):
            clone._prefetch_related_lookups = ()
        else:
            clone._prefetch_related_lookups += lookups
        return clone

    def _prefetch_related_objects(self) -> None:
        # Only valid once the result cache has been filled.
        cached = self._result_cache
        assert cached is not None
        prefetch_related_objects(cached, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def _clone(self) -> RawQuerySet:
        """Copy this RawQuerySet, sharing its RawQuery and copying its prefetch lookups."""
        twin = self.__class__(
            self.raw_query,
            model=self.model,
            query=self.sql_query,
            params=self.params,
            translations=self.translations,
        )
        twin._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return twin

    def _fetch_all(self) -> None:
        if self._result_cache is None:
            self._result_cache = [*self.iterator()]
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    def __len__(self) -> int:
        self._fetch_all()
        cached = self._result_cache
        assert cached is not None
        return len(cached)

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    def __iter__(self) -> Iterator[Model]:
        self._fetch_all()
        cached = self._result_cache
        assert cached is not None
        return iter(cached)

    def iterator(self) -> Iterator[Model]:
        """Yield model instances from RawModelIterable without touching the result cache."""
        yield from RawModelIterable(self)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name}: {self.sql_query}>"

    def __getitem__(self, k: int | slice) -> Model | list[Model]:
        rows = list(self)
        return rows[k]

    @cached_property
    def columns(self) -> list[str]:
        """
        Column names in result order, with ``translations`` applied. Any
        translation whose source column is absent is ignored.
        """
        names = self.sql_query.get_columns()
        for source_name, target_name in self.translations.items():
            if source_name in names:
                names[names.index(source_name)] = target_name
        return names

    @cached_property
    def model_fields(self) -> dict[str, Field]:
        """Map each model field's database column name to the Field object."""
        model = self.model
        assert model is not None
        return {field.column: field for field in model._model_meta.fields}
1767
1768
class Prefetch:
    """
    One prefetch_related() lookup, optionally with a custom queryset and/or
    an attribute name to store the results under.
    """

    def __init__(
        self,
        lookup: str,
        queryset: QuerySet[Any] | None = None,
        to_attr: str | None = None,
    ):
        # Path traversed to perform the prefetch.
        self.prefetch_through = lookup
        # Path of the attribute that ends up holding the result.
        self.prefetch_to = lookup
        if queryset is not None:
            rejected = isinstance(queryset, RawQuerySet) or (
                hasattr(queryset, "_iterable_class")
                and not issubclass(queryset._iterable_class, ModelIterable)
            )
            if rejected:
                raise ValueError(
                    "Prefetch querysets cannot use raw(), values(), and values_list()."
                )
        if to_attr:
            # Replace the final path component with the target attribute.
            *parents, _ = lookup.split(LOOKUP_SEP)
            self.prefetch_to = LOOKUP_SEP.join([*parents, to_attr])

        self.queryset = queryset
        self.to_attr = to_attr

    def __getstate__(self) -> dict[str, Any]:
        state = dict(self.__dict__)
        if self.queryset is None:
            return state
        frozen = self.queryset._chain()
        # Pre-fill the cache so pickling never evaluates the QuerySet.
        frozen._result_cache = []
        frozen._prefetch_done = True
        state["queryset"] = frozen
        return state

    def add_prefix(self, prefix: str) -> None:
        """Prepend ``prefix`` as a new leading component of both lookup paths."""
        self.prefetch_through = LOOKUP_SEP.join((prefix, self.prefetch_through))
        self.prefetch_to = LOOKUP_SEP.join((prefix, self.prefetch_to))

    def get_current_prefetch_to(self, level: int) -> str:
        """Return ``prefetch_to`` truncated to its first ``level + 1`` components."""
        parts = self.prefetch_to.split(LOOKUP_SEP)
        return LOOKUP_SEP.join(parts[: level + 1])

    def get_current_to_attr(self, level: int) -> tuple[str, bool]:
        """
        Return the ``prefetch_to`` component at ``level``, and whether results
        at that level go into ``to_attr`` (only at the final level).
        """
        parts = self.prefetch_to.split(LOOKUP_SEP)
        is_final_level = level == len(parts) - 1
        return parts[level], bool(self.to_attr and is_final_level)

    def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
        """Return the custom queryset only at the level that reaches ``prefetch_to``."""
        at_target = self.get_current_prefetch_to(level) == self.prefetch_to
        return self.queryset if at_target else None

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Prefetch):
            return self.prefetch_to == other.prefetch_to
        return NotImplemented

    def __hash__(self) -> int:
        return hash((self.__class__, self.prefetch_to))
1833
1834
def normalize_prefetch_lookups(
    lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
    prefix: str | None = None,
) -> list[Prefetch]:
    """
    Return ``lookups`` as Prefetch objects, wrapping plain strings.

    When ``prefix`` is given, it is prepended (via ``add_prefix``) to every
    resulting Prefetch, including Prefetch objects passed in, which are
    modified in place.
    """
    normalized: list[Prefetch] = []
    for item in lookups:
        prefetch = item if isinstance(item, Prefetch) else Prefetch(item)
        if prefix:
            prefetch.add_prefix(prefix)
        normalized.append(prefetch)
    return normalized
1848
1849
1850def prefetch_related_objects(
1851 model_instances: Sequence[Model], *related_lookups: str | Prefetch
1852) -> None:
1853 """
1854 Populate prefetched object caches for a list of model instances based on
1855 the lookups/Prefetch instances given.
1856 """
1857 if not model_instances:
1858 return # nothing to do
1859
1860 # We need to be able to dynamically add to the list of prefetch_related
1861 # lookups that we look up (see below). So we need some book keeping to
1862 # ensure we don't do duplicate work.
1863 done_queries = {} # dictionary of things like 'foo__bar': [results]
1864
1865 auto_lookups = set() # we add to this as we go through.
1866 followed_descriptors = set() # recursion protection
1867
1868 all_lookups = normalize_prefetch_lookups(reversed(related_lookups)) # type: ignore[arg-type]
1869 while all_lookups:
1870 lookup = all_lookups.pop()
1871 if lookup.prefetch_to in done_queries:
1872 if lookup.queryset is not None:
1873 raise ValueError(
1874 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
1875 "You may need to adjust the ordering of your lookups."
1876 )
1877
1878 continue
1879
1880 # Top level, the list of objects to decorate is the result cache
1881 # from the primary QuerySet. It won't be for deeper levels.
1882 obj_list = model_instances
1883
1884 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
1885 for level, through_attr in enumerate(through_attrs):
1886 # Prepare main instances
1887 if not obj_list:
1888 break
1889
1890 prefetch_to = lookup.get_current_prefetch_to(level)
1891 if prefetch_to in done_queries:
1892 # Skip any prefetching, and any object preparation
1893 obj_list = done_queries[prefetch_to]
1894 continue
1895
1896 # Prepare objects:
1897 good_objects = True
1898 for obj in obj_list:
1899 # Since prefetching can re-use instances, it is possible to have
1900 # the same instance multiple times in obj_list, so obj might
1901 # already be prepared.
1902 if not hasattr(obj, "_prefetched_objects_cache"):
1903 try:
1904 obj._prefetched_objects_cache = {}
1905 except (AttributeError, TypeError):
1906 # Must be an immutable object from
1907 # values_list(flat=True), for example (TypeError) or
1908 # a QuerySet subclass that isn't returning Model
1909 # instances (AttributeError), either in Plain or a 3rd
1910 # party. prefetch_related() doesn't make sense, so quit.
1911 good_objects = False
1912 break
1913 if not good_objects:
1914 break
1915
1916 # Descend down tree
1917
1918 # We assume that objects retrieved are homogeneous (which is the premise
1919 # of prefetch_related), so what applies to first object applies to all.
1920 first_obj = obj_list[0]
1921 to_attr = lookup.get_current_to_attr(level)[0]
1922 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
1923 first_obj, through_attr, to_attr
1924 )
1925
1926 if not attr_found:
1927 raise AttributeError(
1928 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
1929 "parameter to prefetch_related()"
1930 )
1931
1932 if level == len(through_attrs) - 1 and prefetcher is None:
1933 # Last one, this *must* resolve to something that supports
1934 # prefetching, otherwise there is no point adding it and the
1935 # developer asking for it has made a mistake.
1936 raise ValueError(
1937 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
1938 "prefetching - this is an invalid parameter to "
1939 "prefetch_related()."
1940 )
1941
1942 obj_to_fetch = None
1943 if prefetcher is not None:
1944 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
1945
1946 if obj_to_fetch:
1947 obj_list, additional_lookups = prefetch_one_level(
1948 obj_to_fetch,
1949 prefetcher,
1950 lookup,
1951 level,
1952 )
1953 # We need to ensure we don't keep adding lookups from the
1954 # same relationships to stop infinite recursion. So, if we
1955 # are already on an automatically added lookup, don't add
1956 # the new lookups from relationships we've seen already.
1957 if not (
1958 prefetch_to in done_queries
1959 and lookup in auto_lookups
1960 and descriptor in followed_descriptors
1961 ):
1962 done_queries[prefetch_to] = obj_list
1963 new_lookups = normalize_prefetch_lookups(
1964 reversed(additional_lookups), # type: ignore[arg-type]
1965 prefetch_to,
1966 )
1967 auto_lookups.update(new_lookups)
1968 all_lookups.extend(new_lookups)
1969 followed_descriptors.add(descriptor)
1970 else:
1971 # Either a singly related object that has already been fetched
1972 # (e.g. via select_related), or hopefully some other property
1973 # that doesn't support prefetching but needs to be traversed.
1974
1975 # We replace the current list of parent objects with the list
1976 # of related objects, filtering out empty or missing values so
1977 # that we can continue with nullable or reverse relations.
1978 new_obj_list = []
1979 for obj in obj_list:
1980 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
1981 # If related objects have been prefetched, use the
1982 # cache rather than the object's through_attr.
1983 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # type: ignore[arg-type]
1984 else:
1985 try:
1986 new_obj = getattr(obj, through_attr)
1987 except ObjectDoesNotExist:
1988 continue
1989 if new_obj is None:
1990 continue
1991 # We special-case `list` rather than something more generic
1992 # like `Iterable` because we don't want to accidentally match
1993 # user models that define __iter__.
1994 if isinstance(new_obj, list):
1995 new_obj_list.extend(new_obj)
1996 else:
1997 new_obj_list.append(new_obj)
1998 obj_list = new_obj_list
1999
2000
def get_prefetcher(
    instance: Model, through_attr: str, to_attr: str
) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
    """
    Find the object able to prefetch ``through_attr`` for ``instance``.

    Return a 4-tuple ``(prefetcher, descriptor, attr_found, is_fetched)``:

    - prefetcher: an object exposing ``get_prefetch_queryset()``, or None;
    - descriptor: the class-level attribute named ``through_attr`` (None when
      the class has no such attribute);
    - attr_found: False when neither the class nor the instance provides
      ``through_attr``;
    - is_fetched: a callable taking an instance and returning True when the
      value has already been fetched for that instance.
    """

    def has_to_attr_attribute(obj: Model) -> bool:
        return hasattr(obj, to_attr)

    # Look on the class first: reading a singly related descriptor through the
    # instance would trigger the query we are trying to batch.
    class_attr = getattr(instance.__class__, through_attr, None)

    if class_attr is None:
        # Nothing on the class; the attribute may still live on the instance.
        found_on_instance = hasattr(instance, through_attr)
        return None, class_attr, found_on_instance, has_to_attr_attribute

    if not class_attr:
        # Present but falsy: nothing to prefetch through.
        return None, class_attr, True, has_to_attr_attribute

    if hasattr(class_attr, "get_prefetch_queryset"):
        # Singly related object: the descriptor prefetches by itself and can
        # tell whether a value is already cached on an instance.
        return class_attr, class_attr, True, class_attr.is_cached

    # The class attribute cannot prefetch, so resolve it on the instance;
    # this is how many-related managers are reached.
    resolved = getattr(instance, through_attr)
    prefetcher = resolved if hasattr(resolved, "get_prefetch_queryset") else None

    if through_attr == to_attr:

        def in_prefetched_cache(obj: Model) -> bool:
            return through_attr in obj._prefetched_objects_cache

        return prefetcher, class_attr, True, in_prefetched_cache

    if isinstance(getattr(instance.__class__, to_attr, None), cached_property):
        # hasattr() on a cached_property would compute and store the value,
        # so inspect the instance __dict__ instead.

        def has_cached_property(obj: Model) -> bool:
            return to_attr in obj.__dict__

        return prefetcher, class_attr, True, has_cached_property

    return prefetcher, class_attr, True, has_to_attr_attribute
2060
2061
def prefetch_one_level(
    instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
) -> tuple[list[Model], list[Prefetch]]:
    """
    Prefetch one level of ``lookup`` for ``instances``.

    This is a helper for prefetch_related_objects().

    ``prefetcher.get_prefetch_queryset(instances, queryset)`` must return a
    6-tuple:

    - rel_qs: the related objects to fetch for ``instances``;
    - rel_obj_attr: callable giving the match key of a related object;
    - instance_attr: callable giving the match key of an instance;
    - single: True when each instance receives at most one related object;
    - cache_name: where the result is stored on each instance;
    - is_descriptor: for single relations, True when ``cache_name`` is an
      attribute assigned with setattr() rather than a key in
      ``obj._state.fields_cache``.

    Match keys are used as dict keys, so they must be hashable.

    Raises ValueError when the lookup's to_attr names a field on the
    instances' model.

    Return the fetched related objects together with copies of any
    prefetch_related lookups that were attached to ``rel_qs``, which the
    caller still has to process.
    """
    (
        rel_qs,
        rel_obj_attr,
        instance_attr,
        single,
        cache_name,
        is_descriptor,
    ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))

    # rel_qs may carry its own prefetch_related lookups. Instead of letting
    # list(rel_qs) run them, hand them back to the caller and clear them here.
    # They are copied because a Prefetch object may be reused later by nested
    # lookups. rel_qs is assumed to be a fresh QuerySet, so it is mutated in
    # place rather than cloned.
    nested_lookups = list(
        map(copy.copy, getattr(rel_qs, "_prefetch_related_lookups", ()))
    )
    if nested_lookups:
        rel_qs._prefetch_related_lookups = ()

    fetched = list(rel_qs)

    # Group the related objects by their match key.
    by_key: dict[Any, list[Model]] = {}
    for related in fetched:
        by_key.setdefault(rel_obj_attr(related), []).append(related)

    to_attr, as_attr = lookup.get_current_to_attr(level)
    if as_attr and instances:
        # Instances are assumed homogeneous (the premise of prefetch_related),
        # so checking the first one's model covers them all.
        model = instances[0].__class__
        try:
            model._model_meta.get_field(to_attr)
            clashes = True
        except FieldDoesNotExist:
            clashes = False
        if clashes:
            msg = "to_attr={} conflicts with a field on the {} model."
            raise ValueError(msg.format(to_attr, model.__name__))

    # True when this level is the last segment of the lookup path.
    is_leaf = lookup.prefetch_through.count(LOOKUP_SEP) == level

    for obj in instances:
        matches = by_key.get(instance_attr(obj), [])

        if single:
            first_match = matches[0] if matches else None
            if as_attr:
                # An explicit to_attr was requested.
                setattr(obj, to_attr, first_match)
            elif is_descriptor:
                # cache_name is an attribute (descriptor) on obj.
                setattr(obj, cache_name, first_match)
            else:
                # No to_attr and no descriptor: store in the field cache.
                obj._state.fields_cache[cache_name] = first_match  # type: ignore[index]
            continue

        if as_attr:
            setattr(obj, to_attr, matches)
            continue

        related_access = getattr(obj, to_attr)
        if is_leaf and lookup.queryset is not None:
            # The user-supplied Prefetch queryset, passed through the
            # accessor's _apply_rel_filters().
            qs = related_access._apply_rel_filters(lookup.queryset)
        elif isinstance(related_access, QuerySet):
            # NOTE(review): from_model() builds a new QuerySet for the same
            # model; confirm no filtering on `related_access` needs carrying over.
            qs = related_access.__class__.from_model(related_access.model)
        else:
            # A related manager: its `query` attribute supplies the QuerySet.
            qs = related_access.query
        # Pre-fill the result cache with the matched objects. _prefetch_done
        # stops this QuerySet from running its own prefetch_related, since
        # this pass has already merged that work.
        qs._result_cache = matches
        qs._prefetch_done = True
        obj._prefetched_objects_cache[cache_name] = qs

    return fetched, nested_lookups
2176
2177
class RelatedPopulator:
    """
    Build select_related() model instances from joined result rows.

    There is one RelatedPopulator per model reached through select_related().
    It is built from the ``klass_info`` dict and ``select`` list produced when
    compiling the query, and it works out which slice of each row belongs to
    its model. populate() then turns that slice into a model instance and
    links it to the parent object. Populators for deeper relations are nested
    in ``related_populators``.
    """

    def __init__(self, klass_info: dict[str, Any], select: list[Any]):
        # Attributes computed here:
        # - cols_start, cols_end: the row slice spanning the first through
        #   the last position listed in klass_info["select_fields"].
        # - init_list: attnames of the selected columns in that slice, in order.
        # - reorder_for_init: optional callable mapping a full row to this
        #   model's field data. It is always None here, so the slice is used.
        # - model_cls: the model class instantiated via from_db().
        # - id_idx: position of "id" in init_list. A None value there means no
        #   related object exists for the row.
        # - related_populators: populators for models reached from this one.
        # - local_setter, remote_setter: callables from klass_info that link
        #   the built object and its parent in each direction.
        field_positions = klass_info["select_fields"]

        self.cols_start = field_positions[0]
        self.cols_end = field_positions[-1] + 1
        self.init_list = [
            column[0].target.attname
            for column in select[self.cols_start : self.cols_end]
        ]
        self.reorder_for_init = None

        self.model_cls = klass_info["model"]
        self.id_idx = self.init_list.index("id")
        self.related_populators = get_related_populators(klass_info, select)
        self.local_setter = klass_info["local_setter"]
        self.remote_setter = klass_info["remote_setter"]

    def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
        """
        Build this populator's object from ``row`` and attach it to ``from_obj``.

        When the row's id column is None, ``from_obj`` is linked to None and
        nothing else happens. Otherwise the object is created with
        ``model_cls.from_db()`` and nested populators fill in its own
        relations. Then local_setter(from_obj, obj) and
        remote_setter(obj, from_obj) are called.
        """
        field_values = (
            self.reorder_for_init(row)
            if self.reorder_for_init
            else row[self.cols_start : self.cols_end]
        )

        if field_values[self.id_idx] is None:
            # No related object for this row.
            self.local_setter(from_obj, None)
            return

        built = self.model_cls.from_db(self.init_list, field_values)
        for child in self.related_populators:
            child.populate(row, built)
        self.local_setter(from_obj, built)
        if built is not None:
            self.remote_setter(built, from_obj)
2247
2248
def get_related_populators(
    klass_info: dict[str, Any], select: list[Any]
) -> list[RelatedPopulator]:
    """
    Create one RelatedPopulator for each entry in ``klass_info["related_klass_infos"]``.

    Return an empty list when ``klass_info`` has no "related_klass_infos" key.
    """
    return [
        RelatedPopulator(child_info, select)
        for child_info in klass_info.get("related_klass_infos", [])
    ]