1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from abc import ABC, abstractmethod
11from collections.abc import Callable, Iterator, Sequence
12from functools import cached_property
13from itertools import chain, islice
14from typing import TYPE_CHECKING, Any, Generic, Never, Self, TypeVar, overload
15
16import plain.runtime
17from plain.exceptions import ValidationError
18from plain.models import transaction
19from plain.models.constants import LOOKUP_SEP, OnConflict
20from plain.models.db import (
21 PLAIN_VERSION_PICKLE_KEY,
22 IntegrityError,
23 NotSupportedError,
24 db_connection,
25)
26from plain.models.exceptions import (
27 FieldDoesNotExist,
28 FieldError,
29 ObjectDoesNotExist,
30)
31from plain.models.expressions import Case, F, ResolvableExpression, Value, When
32from plain.models.fields import (
33 Field,
34 PrimaryKeyField,
35)
36from plain.models.functions import Cast
37from plain.models.query_utils import FilteredRelation, Q
38from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
39from plain.models.sql.query import Query, RawQuery
40from plain.models.sql.subqueries import DeleteQuery, InsertQuery, UpdateQuery
41from plain.models.sql.where import AND, OR, XOR
42from plain.models.utils import resolve_callables
43from plain.utils.functional import partition
44
# Re-exports for public API
__all__ = ["F", "Q", "QuerySet", "RawQuerySet", "Prefetch", "FilteredRelation"]

if TYPE_CHECKING:
    from plain.models import Model

# Type variable for QuerySet generic
T = TypeVar("T", bound="Model")

# The maximum number of results to fetch in a get() query. One past the
# reportable count, so get() can say "more than <limit - 1>" in its
# MultipleObjectsReturned message without fetching the full result set.
MAX_GET_RESULTS = 21

# The maximum number of items to display in a QuerySet.__repr__
REPR_OUTPUT_SIZE = 20
59
60
class BaseIterable(ABC):
    """
    Abstract base for the row-iteration strategies a QuerySet can use.

    Subclasses implement __iter__ to turn compiler rows into the shape the
    caller asked for (model instances, dicts, tuples, or scalars).
    """

    def __init__(
        self,
        queryset: QuerySet[Any],
        chunked_fetch: bool = False,
        chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
    ):
        # Stash the fetch configuration; __iter__ forwards it to the compiler.
        self.chunk_size = chunk_size
        self.chunked_fetch = chunked_fetch
        self.queryset = queryset

    @abstractmethod
    def __iter__(self) -> Iterator[Any]: ...
74
75
class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""

    def __iter__(self) -> Iterator[Model]:
        qs = self.queryset
        compiler = qs.sql_query.get_compiler()
        # Executing the query also populates compiler.select,
        # compiler.klass_info and compiler.annotation_col_map as side effects.
        raw_results = compiler.execute_sql(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        select = compiler.select
        klass_info = compiler.klass_info
        annotation_col_map = compiler.annotation_col_map
        # These are set by execute_sql() above
        assert select is not None
        assert klass_info is not None
        model_cls = klass_info["model"]
        select_fields = klass_info["select_fields"]
        fields_start = select_fields[0]
        fields_end = select_fields[-1] + 1
        init_list = [
            col[0].target.attname for col in select[fields_start:fields_end]
        ]
        related_populators = get_related_populators(klass_info, select)
        known_related_objects = [
            (field, rel_objs, operator.attrgetter(field.attname))
            for field, rel_objs in qs._known_related_objects.items()
        ]
        for row in compiler.results_iter(raw_results):
            obj = model_cls.from_db(init_list, row[fields_start:fields_end])
            for populator in related_populators:
                populator.populate(row, obj)
            if annotation_col_map:
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            # Attach known related objects without clobbering anything that
            # select_related() already cached on the instance.
            for field, rel_objs, rel_getter in known_related_objects:
                if field.is_cached(obj):
                    continue
                rel_obj_id = rel_getter(obj)
                try:
                    rel_obj = rel_objs[rel_obj_id]
                except KeyError:
                    pass  # May happen in qs1 | qs2 scenarios.
                else:
                    setattr(obj, field.name, rel_obj)

            yield obj
132
133
class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.
    """

    queryset: RawQuerySet

    def __iter__(self) -> Iterator[Model]:
        # Cache some things for performance reasons outside the loop.
        # RawQuery is not a Query subclass, so we directly get SQLCompiler
        query = self.queryset.sql_query
        compiler = db_connection.ops.compilers[Query](query, db_connection, True)
        query_iterator = iter(query)

        try:
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            assert model_cls is not None
            if "id" not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            # model_fields.get() returns None for columns that don't map to a
            # model field; get_converters() receives None for those positions.
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = compiler.get_converters(
                [
                    f.get_col(f.model.model_options.db_table) if f else None
                    for f in fields
                ]
            )
            if converters:
                query_iterator = compiler.apply_converters(query_iterator, converters)
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(model_init_names, model_init_values)
                # Extra selected columns become plain attributes on the instance.
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            # Runs even when iteration is abandoned early or raises.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()
179
180
class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict for each row.
    """

    def __iter__(self) -> Iterator[dict[str, Any]]:
        query = self.queryset.sql_query
        compiler = query.get_compiler()

        # extra(select=...) cols are always at the start of the row.
        column_names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        for row in rows:
            yield {name: row[i] for i, name in enumerate(column_names)}
202
203
class ValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
    for each row.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        qs = self.queryset
        query = qs.sql_query
        compiler = query.get_compiler()

        if qs._fields:
            # extra(select=...) cols are always at the start of the row.
            names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            fields = [
                *qs._fields,
                *(name for name in query.annotation_select if name not in qs._fields),
            ]
            if fields != names:
                # The caller asked for a column order different from the one
                # the compiler produces; remap every row accordingly.
                index_map = {name: pos for pos, name in enumerate(names)}
                getter = operator.itemgetter(*[index_map[name] for name in fields])
                rows = compiler.results_iter(
                    chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
                )
                return map(getter, rows)
        return iter(
            compiler.results_iter(
                tuple_expected=True,
                chunked_fetch=self.chunked_fetch,
                chunk_size=self.chunk_size,
            )
        )
243
244
class FlatValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=True) that yields single
    values.
    """

    def __iter__(self) -> Iterator[Any]:
        compiler = self.queryset.sql_query.get_compiler()
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        # Each row has exactly one selected column; unwrap it.
        yield from map(operator.itemgetter(0), rows)
258
259
260class QuerySet(Generic[T]):
261 """
262 Represent a lazy database lookup for a set of objects.
263
264 Usage:
265 MyModel.query.filter(name="test").all()
266
267 Custom QuerySets:
268 from typing import Self
269
270 class TaskQuerySet(QuerySet["Task"]):
271 def active(self) -> Self:
272 return self.filter(is_active=True)
273
274 class Task(Model):
275 is_active = BooleanField(default=True)
276 query = TaskQuerySet()
277
278 Task.query.active().filter(name="test") # Full type inference
279
280 Custom methods should return `Self` to preserve type through method chaining.
281 """
282
    # Instance attributes (set in from_model())
    model: type[T]  # model class this queryset is bound to
    _query: Query  # underlying SQL Query; access via the sql_query property
    _result_cache: list[T] | None  # filled by _fetch_all(); None until evaluated
    _sticky_filter: bool
    _for_write: bool  # set to True before write operations (create/update/delete)
    _prefetch_related_lookups: tuple[Any, ...]  # lookups queued for prefetching
    _prefetch_done: bool
    _known_related_objects: dict[Any, dict[Any, Any]]  # {field: {related_id: obj}}
    _iterable_class: type[BaseIterable]  # row->object strategy; ModelIterable by default
    _fields: tuple[str, ...] | None  # set when values()/values_list() narrowed the row
    _defer_next_filter: bool
    _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None  # (negate, args, kwargs), applied lazily by sql_query
296
297 def __init__(self):
298 """Minimal init for descriptor mode. Use from_model() to create instances."""
299 pass
300
301 @classmethod
302 def from_model(cls, model: type[T], query: Query | None = None) -> Self:
303 """Create a QuerySet instance bound to a model."""
304 instance = cls()
305 instance.model = model
306 instance._query = query or Query(model)
307 instance._result_cache = None
308 instance._sticky_filter = False
309 instance._for_write = False
310 instance._prefetch_related_lookups = ()
311 instance._prefetch_done = False
312 instance._known_related_objects = {}
313 instance._iterable_class = ModelIterable
314 instance._fields = None
315 instance._defer_next_filter = False
316 instance._deferred_filter = None
317 return instance
318
319 @overload
320 def __get__(self, instance: None, owner: type[T]) -> Self: ...
321
322 @overload
323 def __get__(self, instance: Model, owner: type[T]) -> Never: ...
324
325 def __get__(self, instance: Any, owner: type[T]) -> Self:
326 """Descriptor protocol - return a new QuerySet bound to the model."""
327 if instance is not None:
328 raise AttributeError(
329 f"QuerySet is only accessible from the model class, not instances. "
330 f"Use {owner.__name__}.query instead."
331 )
332 return self.from_model(owner)
333
    @property
    def sql_query(self) -> Query:
        # Apply any lazily-registered filter (see _deferred_filter) before
        # handing out the underlying Query, so callers always see the
        # fully-filtered query. Cleared only after the filter applied cleanly.
        if self._deferred_filter:
            negate, args, kwargs = self._deferred_filter
            self._filter_or_exclude_inplace(negate, args, kwargs)
            self._deferred_filter = None
        return self._query

    @sql_query.setter
    def sql_query(self, value: Query) -> None:
        # A query with values_select produces dict rows, so switch the
        # iteration strategy to match.
        if value.values_select:
            self._iterable_class = ValuesIterable
        self._query = value
347
348 ########################
349 # PYTHON MAGIC METHODS #
350 ########################
351
352 def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
353 """Don't populate the QuerySet's cache."""
354 obj = self.__class__.from_model(self.model)
355 for k, v in self.__dict__.items():
356 if k == "_result_cache":
357 obj.__dict__[k] = None
358 else:
359 obj.__dict__[k] = copy.deepcopy(v, memo)
360 return obj
361
362 def __getstate__(self) -> dict[str, Any]:
363 # Force the cache to be fully populated.
364 self._fetch_all()
365 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
366
367 def __setstate__(self, state: dict[str, Any]) -> None:
368 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
369 if pickled_version:
370 if pickled_version != plain.runtime.__version__:
371 warnings.warn(
372 f"Pickled queryset instance's Plain version {pickled_version} does not "
373 f"match the current version {plain.runtime.__version__}.",
374 RuntimeWarning,
375 stacklevel=2,
376 )
377 else:
378 warnings.warn(
379 "Pickled queryset instance's Plain version is not specified.",
380 RuntimeWarning,
381 stacklevel=2,
382 )
383 self.__dict__.update(state)
384
385 def __repr__(self) -> str:
386 data = list(self[: REPR_OUTPUT_SIZE + 1])
387 if len(data) > REPR_OUTPUT_SIZE:
388 data[-1] = "...(remaining elements truncated)..."
389 return f"<{self.__class__.__name__} {data!r}>"
390
391 def __len__(self) -> int:
392 self._fetch_all()
393 assert self._result_cache is not None
394 return len(self._result_cache)
395
396 def __iter__(self) -> Iterator[T]:
397 """
398 The queryset iterator protocol uses three nested iterators in the
399 default case:
400 1. sql.compiler.execute_sql()
401 - Returns 100 rows at time (constants.GET_ITERATOR_CHUNK_SIZE)
402 using cursor.fetchmany(). This part is responsible for
403 doing some column masking, and returning the rows in chunks.
404 2. sql.compiler.results_iter()
405 - Returns one row at time. At this point the rows are still just
406 tuples. In some cases the return values are converted to
407 Python values at this location.
408 3. self.iterator()
409 - Responsible for turning the rows into model objects.
410 """
411 self._fetch_all()
412 assert self._result_cache is not None
413 return iter(self._result_cache)
414
    def __bool__(self) -> bool:
        """Evaluate the queryset and report whether any row matched."""
        self._fetch_all()
        return bool(self._result_cache)
418
419 @overload
420 def __getitem__(self, k: int) -> T: ...
421
422 @overload
423 def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...
424
425 def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
426 """Retrieve an item or slice from the set of results."""
427 if not isinstance(k, int | slice):
428 raise TypeError(
429 f"QuerySet indices must be integers or slices, not {type(k).__name__}."
430 )
431 if (isinstance(k, int) and k < 0) or (
432 isinstance(k, slice)
433 and (
434 (k.start is not None and k.start < 0)
435 or (k.stop is not None and k.stop < 0)
436 )
437 ):
438 raise ValueError("Negative indexing is not supported.")
439
440 if self._result_cache is not None:
441 return self._result_cache[k]
442
443 if isinstance(k, slice):
444 qs = self._chain()
445 if k.start is not None:
446 start = int(k.start)
447 else:
448 start = None
449 if k.stop is not None:
450 stop = int(k.stop)
451 else:
452 stop = None
453 qs.sql_query.set_limits(start, stop)
454 return list(qs)[:: k.step] if k.step else qs
455
456 qs = self._chain()
457 qs.sql_query.set_limits(k, k + 1)
458 qs._fetch_all()
459 assert qs._result_cache is not None # _fetch_all guarantees this
460 return qs._result_cache[0]
461
    def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
        # Subscription (QuerySet[SomeModel]) returns the class unchanged, so
        # it has no runtime effect beyond supporting the typing syntax.
        return cls
464
465 def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
466 self._merge_sanity_check(other)
467 if isinstance(other, EmptyQuerySet):
468 return other
469 if isinstance(self, EmptyQuerySet):
470 return self
471 combined = self._chain()
472 combined._merge_known_related_objects(other)
473 combined.sql_query.combine(other.sql_query, AND)
474 return combined
475
476 def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
477 self._merge_sanity_check(other)
478 if isinstance(self, EmptyQuerySet):
479 return other
480 if isinstance(other, EmptyQuerySet):
481 return self
482 query = (
483 self
484 if self.sql_query.can_filter()
485 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
486 )
487 combined = query._chain()
488 combined._merge_known_related_objects(other)
489 if not other.sql_query.can_filter():
490 other = other.model._model_meta.base_queryset.filter(
491 id__in=other.values("id")
492 )
493 combined.sql_query.combine(other.sql_query, OR)
494 return combined
495
496 def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
497 self._merge_sanity_check(other)
498 if isinstance(self, EmptyQuerySet):
499 return other
500 if isinstance(other, EmptyQuerySet):
501 return self
502 query = (
503 self
504 if self.sql_query.can_filter()
505 else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
506 )
507 combined = query._chain()
508 combined._merge_known_related_objects(other)
509 if not other.sql_query.can_filter():
510 other = other.model._model_meta.base_queryset.filter(
511 id__in=other.values("id")
512 )
513 combined.sql_query.combine(other.sql_query, XOR)
514 return combined
515
516 ####################################
517 # METHODS THAT DO DATABASE QUERIES #
518 ####################################
519
520 def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
521 iterable = self._iterable_class(
522 self,
523 chunked_fetch=use_chunked_fetch,
524 chunk_size=chunk_size or 2000,
525 )
526 if not self._prefetch_related_lookups or chunk_size is None:
527 yield from iterable
528 return
529
530 iterator = iter(iterable)
531 while results := list(islice(iterator, chunk_size)):
532 prefetch_related_objects(results, *self._prefetch_related_lookups)
533 yield from results
534
535 def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
536 """
537 An iterator over the results from applying this QuerySet to the
538 database. chunk_size must be provided for QuerySets that prefetch
539 related objects. Otherwise, a default chunk_size of 2000 is supplied.
540 """
541 if chunk_size is None:
542 if self._prefetch_related_lookups:
543 raise ValueError(
544 "chunk_size must be provided when using QuerySet.iterator() after "
545 "prefetch_related()."
546 )
547 elif chunk_size <= 0:
548 raise ValueError("Chunk size must be strictly positive.")
549 use_chunked_fetch = not db_connection.settings_dict.get(
550 "DISABLE_SERVER_SIDE_CURSORS"
551 )
552 return self._iterator(use_chunked_fetch, chunk_size)
553
554 def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
555 """
556 Return a dictionary containing the calculations (aggregation)
557 over the current queryset.
558
559 If args is present the expression is passed as a kwarg using
560 the Aggregate object's default alias.
561 """
562 if self.sql_query.distinct_fields:
563 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
564 self._validate_values_are_expressions(
565 (*args, *kwargs.values()), method_name="aggregate"
566 )
567 for arg in args:
568 # The default_alias property raises TypeError if default_alias
569 # can't be set automatically or AttributeError if it isn't an
570 # attribute.
571 try:
572 arg.default_alias
573 except (AttributeError, TypeError):
574 raise TypeError("Complex aggregates require an alias")
575 kwargs[arg.default_alias] = arg
576
577 return self.sql_query.chain().get_aggregation(kwargs)
578
579 def count(self) -> int:
580 """
581 Perform a SELECT COUNT() and return the number of records as an
582 integer.
583
584 If the QuerySet is already fully cached, return the length of the
585 cached results set to avoid multiple SELECT COUNT(*) calls.
586 """
587 if self._result_cache is not None:
588 return len(self._result_cache)
589
590 return self.sql_query.get_count()
591
592 def get(self, *args: Any, **kwargs: Any) -> T:
593 """
594 Perform the query and return a single object matching the given
595 keyword arguments.
596 """
597 clone = self.filter(*args, **kwargs)
598 if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
599 clone = clone.order_by()
600 limit = None
601 if (
602 not clone.sql_query.select_for_update
603 or db_connection.features.supports_select_for_update_with_limit
604 ):
605 limit = MAX_GET_RESULTS
606 clone.sql_query.set_limits(high=limit)
607 num = len(clone)
608 if num == 1:
609 assert clone._result_cache is not None # len() fetches results
610 return clone._result_cache[0]
611 if not num:
612 raise self.model.DoesNotExist(
613 f"{self.model.model_options.object_name} matching query does not exist."
614 )
615 raise self.model.MultipleObjectsReturned(
616 "get() returned more than one {} -- it returned {}!".format(
617 self.model.model_options.object_name,
618 num if not limit or num < limit else "more than %s" % (limit - 1),
619 )
620 )
621
    def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
        """
        Perform the query and return a single object matching the given
        keyword arguments, or None if no object is found.
        """
        try:
            return self.get(*args, **kwargs)
        except self.model.DoesNotExist:
            # Only the "no match" case is converted to None;
            # MultipleObjectsReturned still propagates to the caller.
            return None
631
632 def create(self, **kwargs: Any) -> T:
633 """
634 Create a new object with the given kwargs, saving it to the database
635 and returning the created object.
636 """
637 obj = self.model(**kwargs)
638 self._for_write = True
639 obj.save(force_insert=True)
640 return obj
641
642 def _prepare_for_bulk_create(self, objs: list[T]) -> None:
643 id_field = self.model._model_meta.get_forward_field("id")
644 for obj in objs:
645 if obj.id is None:
646 # Populate new primary key values.
647 obj.id = id_field.get_id_value_on_save(obj)
648 obj._prepare_related_fields_for_save(operation_name="bulk_create")
649
650 def _check_bulk_create_options(
651 self,
652 update_conflicts: bool,
653 update_fields: list[Field] | None,
654 unique_fields: list[Field] | None,
655 ) -> OnConflict | None:
656 db_features = db_connection.features
657 if update_conflicts:
658 if not db_features.supports_update_conflicts:
659 raise NotSupportedError(
660 "This database backend does not support updating conflicts."
661 )
662 if not update_fields:
663 raise ValueError(
664 "Fields that will be updated when a row insertion fails "
665 "on conflicts must be provided."
666 )
667 if unique_fields and not db_features.supports_update_conflicts_with_target:
668 raise NotSupportedError(
669 "This database backend does not support updating "
670 "conflicts with specifying unique fields that can trigger "
671 "the upsert."
672 )
673 if not unique_fields and db_features.supports_update_conflicts_with_target:
674 raise ValueError(
675 "Unique fields that can trigger the upsert must be provided."
676 )
677 # Updating primary keys and non-concrete fields is forbidden.
678 from plain.models.fields.related import ManyToManyField
679
680 if any(
681 not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
682 ):
683 raise ValueError(
684 "bulk_create() can only be used with concrete fields in "
685 "update_fields."
686 )
687 if any(f.primary_key for f in update_fields):
688 raise ValueError(
689 "bulk_create() cannot be used with primary keys in update_fields."
690 )
691 if unique_fields:
692 from plain.models.fields.related import ManyToManyField
693
694 if any(
695 not f.concrete or isinstance(f, ManyToManyField)
696 for f in unique_fields
697 ):
698 raise ValueError(
699 "bulk_create() can only be used with concrete fields "
700 "in unique_fields."
701 )
702 return OnConflict.UPDATE
703 return None
704
    def bulk_create(
        self,
        objs: Sequence[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances, and do not set the primary key attribute if it is an
        autoincrement field (except if features.can_return_rows_from_bulk_insert=True).
        Multi-table models are not supported.
        """
        # When you bulk insert you don't get the primary keys back (if it's an
        # autoincrement, except if can_return_rows_from_bulk_insert=True), so
        # you can't insert into the child tables which references this. There
        # are two workarounds:
        # 1) This could be implemented if you didn't have an autoincrement pk
        # 2) You could do it by doing O(n) normal inserts into the parent
        #    tables to get the primary keys back and then doing a single bulk
        #    insert into the childmost table.
        # We currently set the primary keys on the objects when using
        # PostgreSQL via the RETURNING ID clause. It should be possible for
        # Oracle as well, but the semantics for extracting the primary keys is
        # trickier so it's not done yet.
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        objs = list(objs)
        if not objs:
            return objs
        meta = self.model._model_meta
        # Resolve the given field names into Field instances so they can be
        # validated and passed to the insert machinery.
        unique_fields_objs: list[Field] | None = None
        update_fields_objs: list[Field] | None = None
        if unique_fields:
            unique_fields_objs = [
                meta.get_forward_field(name) for name in unique_fields
            ]
        if update_fields:
            update_fields_objs = [
                meta.get_forward_field(name) for name in update_fields
            ]
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields_objs,
            unique_fields_objs,
        )
        self._for_write = True
        fields = meta.concrete_fields
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            # Split objects on whether they already carry a primary key, so
            # that the id column is only included in the insert when set.
            objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
            if objs_with_id:
                returned_columns = self._batched_insert(
                    objs_with_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                id_field = meta.get_forward_field("id")
                # Copy any database-computed values back, except the id which
                # these objects already have.
                for obj_with_id, results in zip(objs_with_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        if field != id_field:
                            setattr(obj_with_id, field.attname, result)
                for obj_with_id in objs_with_id:
                    obj_with_id._state.adding = False
            if objs_without_id:
                # Let the database generate the primary key for these rows.
                fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
                returned_columns = self._batched_insert(
                    objs_without_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                if (
                    db_connection.features.can_return_rows_from_bulk_insert
                    and on_conflict is None
                ):
                    assert len(returned_columns) == len(objs_without_id)
                for obj_without_id, results in zip(objs_without_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        setattr(obj_without_id, field.attname, result)
                    obj_without_id._state.adding = False

        return objs
795
    def bulk_update(
        self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
    ) -> int:
        """
        Update the given fields in each of the given objects in the database.

        Returns the total number of rows matched across all batched UPDATE
        statements. Raises ValueError for invalid arguments (no fields,
        missing primary keys, non-concrete or primary-key fields).
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")
        if not fields:
            raise ValueError("Field names must be given to bulk_update().")
        objs_tuple = tuple(objs)
        if any(obj.id is None for obj in objs_tuple):
            raise ValueError("All bulk_update() objects must have a primary key set.")
        fields_list = [
            self.model._model_meta.get_forward_field(name) for name in fields
        ]
        from plain.models.fields.related import ManyToManyField

        if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
            raise ValueError("bulk_update() can only be used with concrete fields.")
        if any(f.primary_key for f in fields_list):
            raise ValueError("bulk_update() cannot be used with primary key fields.")
        if not objs_tuple:
            return 0
        for obj in objs_tuple:
            obj._prepare_related_fields_for_save(
                operation_name="bulk_update", fields=fields_list
            )
        # PK is used twice in the resulting update query, once in the filter
        # and once in the WHEN. Each field will also have one CAST.
        self._for_write = True
        max_batch_size = db_connection.ops.bulk_batch_size(
            ["id", "id"] + fields_list, objs_tuple
        )
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        requires_casting = db_connection.features.requires_casted_case_in_updates
        batches = (
            objs_tuple[i : i + batch_size]
            for i in range(0, len(objs_tuple), batch_size)
        )
        # Build the per-batch CASE expressions up front; the UPDATEs run
        # below inside a single transaction.
        updates = []
        for batch_objs in batches:
            update_kwargs = {}
            for field in fields_list:
                when_statements = []
                for obj in batch_objs:
                    attr = getattr(obj, field.attname)
                    # Wrap plain values so they can participate in the CASE
                    # expression; expressions pass through untouched.
                    if not isinstance(attr, ResolvableExpression):
                        attr = Value(attr, output_field=field)
                    when_statements.append(When(id=obj.id, then=attr))
                case_statement = Case(*when_statements, output_field=field)
                if requires_casting:
                    case_statement = Cast(case_statement, output_field=field)
                update_kwargs[field.attname] = case_statement
            updates.append(([obj.id for obj in batch_objs], update_kwargs))
        rows_updated = 0
        queryset = self._chain()
        with transaction.atomic(savepoint=False):
            for ids, update_kwargs in updates:
                rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
        return rows_updated
857
    def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the database.
                # The main thing we're concerned about is uniqueness failures,
                # but ValidationError could include other things too.
                # In all cases though it should be fine to try the get() again
                # and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:
                    pass
                # Still no match: surface the original create() failure.
                raise
891
    def update_or_create(
        self,
        defaults: dict[str, Any] | None = None,
        create_defaults: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, updating one with defaults
        if it exists, otherwise create a new one. Optionally, an object can
        be created with different values than defaults by using
        create_defaults.
        Return a tuple (object, created), where created is a boolean
        specifying whether an object was created.
        """
        if create_defaults is None:
            # Without explicit create_defaults, the same defaults apply both
            # to updating an existing row and to creating a new one.
            update_defaults = create_defaults = defaults or {}
        else:
            update_defaults = defaults or {}
        self._for_write = True
        with transaction.atomic():
            # Lock the row so that a concurrent update is blocked until
            # update_or_create() has performed its save.
            obj, created = self.select_for_update().get_or_create(
                create_defaults, **kwargs
            )
            if created:
                return obj, created
            for k, v in resolve_callables(update_defaults):
                setattr(obj, k, v)

            update_fields = set(update_defaults)
            concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
            # update_fields does not support non-concrete fields.
            if concrete_field_names.issuperset(update_fields):
                # Add fields which are set on pre_save(), e.g. auto_now fields.
                # This is to maintain backward compatibility as these fields
                # are not updated unless explicitly specified in the
                # update_fields list.
                for field in self.model._model_meta.local_concrete_fields:
                    if not (
                        field.primary_key or field.__class__.pre_save is Field.pre_save
                    ):
                        update_fields.add(field.name)
                        if field.name != field.attname:
                            update_fields.add(field.attname)
                obj.save(update_fields=update_fields)
            else:
                # Fall back to a full save when any non-concrete field was
                # touched.
                obj.save()
        return obj, False
941
942 def _extract_model_params(
943 self, defaults: dict[str, Any] | None, **kwargs: Any
944 ) -> dict[str, Any]:
945 """
946 Prepare `params` for creating a model instance based on the given
947 kwargs; for use by get_or_create().
948 """
949 defaults = defaults or {}
950 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
951 params.update(defaults)
952 property_names = self.model._model_meta._property_names
953 invalid_params = []
954 for param in params:
955 try:
956 self.model._model_meta.get_field(param)
957 except FieldDoesNotExist:
958 # It's okay to use a model's property if it has a setter.
959 if not (param in property_names and getattr(self.model, param).fset):
960 invalid_params.append(param)
961 if invalid_params:
962 raise FieldError(
963 "Invalid field name(s) for model {}: '{}'.".format(
964 self.model.model_options.object_name,
965 "', '".join(sorted(invalid_params)),
966 )
967 )
968 return params
969
970 def first(self) -> T | None:
971 """Return the first object of a query or None if no match is found."""
972 for obj in self[:1]:
973 return obj
974 return None
975
976 def last(self) -> T | None:
977 """Return the last object of a query or None if no match is found."""
978 queryset = self.reverse()
979 for obj in queryset[:1]:
980 return obj
981 return None
982
    def delete(self) -> tuple[int, dict[str, int]]:
        """Delete the records in the current QuerySet.

        Returns a tuple of (total objects deleted, per-model deletion counts)
        as produced by Collector.delete(). Raises TypeError if the queryset
        is sliced, distinct, or restricted via values()/values_list().
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with delete().")
        if self.sql_query.distinct or self.sql_query.distinct_fields:
            raise TypeError("Cannot call delete() after .distinct().")
        if self._fields is not None:
            raise TypeError("Cannot call delete() after .values() or .values_list()")

        del_query = self._chain()

        # The delete is actually 2 queries - one to find related objects,
        # and one to delete. Make sure that the discovery of related
        # objects is performed on the same database as the deletion.
        del_query._for_write = True

        # Disable non-supported fields.
        del_query.sql_query.select_for_update = False
        del_query.sql_query.select_related = False
        del_query.sql_query.clear_ordering(force=True)

        # Local import: Collector is only needed for deletes.
        from plain.models.deletion import Collector

        collector = Collector(origin=self)
        collector.collect(del_query)
        deleted, _rows_count = collector.delete()

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
        return deleted, _rows_count
1013
1014 def _raw_delete(self) -> int:
1015 """
1016 Delete objects found from the given queryset in single direct SQL
1017 query. No signals are sent and there is no protection for cascades.
1018 """
1019 query = self.sql_query.clone()
1020 query.__class__ = DeleteQuery
1021 cursor = query.get_compiler().execute_sql(CURSOR)
1022 if cursor:
1023 with cursor:
1024 return cursor.rowcount
1025 return 0
1026
    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.

        Returns the row count reported by the database cursor. Raises
        TypeError on sliced querysets and FieldError when the ordering
        references an aggregate annotation.
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        self._for_write = True
        query = self.sql_query.chain(UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            # A leading "-" on a string ordering term means descending.
            if isinstance(alias, str) and alias.startswith("-"):
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                if getattr(annotation, "contains_aggregate", False):
                    raise FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                # Not an annotation; keep the ordering term as-is.
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        # Cached results are stale after an UPDATE.
        self._result_cache = None
        return rows
1064
    def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
        """
        A version of update() that accepts field objects instead of field names.
        Used primarily for model saving and not intended for use by general
        code (it requires too much poking around at model internals to be
        useful at that level).

        ``values`` is a list of (field, model, value) tuples as consumed by
        UpdateQuery.add_update_fields(). Returns the cursor's row count.
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        query = self.sql_query.chain(UpdateQuery)
        query.add_update_fields(values)
        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        # Cached results are stale after an UPDATE.
        self._result_cache = None
        return query.get_compiler().execute_sql(CURSOR)
1080
1081 def exists(self) -> bool:
1082 """
1083 Return True if the QuerySet would have any results, False otherwise.
1084 """
1085 if self._result_cache is None:
1086 return self.sql_query.has_results()
1087 return bool(self._result_cache)
1088
    def _prefetch_related_objects(self) -> None:
        """Run the configured prefetch lookups against the cached results."""
        # This method can only be called once the result cache has been filled.
        assert self._result_cache is not None
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True
1094
    def explain(self, *, format: str | None = None, **options: Any) -> str:
        """
        Runs an EXPLAIN on the SQL query this QuerySet would perform, and
        returns the results.

        ``format`` and any extra ``options`` are forwarded unchanged to the
        underlying query's explain() implementation.
        """
        return self.sql_query.explain(format=format, **options)
1101
1102 ##################################################
1103 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1104 ##################################################
1105
1106 def raw(
1107 self,
1108 raw_query: str,
1109 params: Sequence[Any] = (),
1110 translations: dict[str, str] | None = None,
1111 ) -> RawQuerySet:
1112 qs = RawQuerySet(
1113 raw_query,
1114 model=self.model,
1115 params=tuple(params),
1116 translations=translations,
1117 )
1118 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1119 return qs
1120
    def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
        """Shared helper for values()/values_list(): annotate ``expressions``
        and restrict the selected columns to ``fields``."""
        clone = self._chain()
        if expressions:
            # Add expression annotations before restricting the selection.
            clone = clone.annotate(**expressions)
        clone._fields = fields
        clone.sql_query.set_values(list(fields))
        return clone
1128
1129 def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
1130 fields += tuple(expressions)
1131 clone = self._values(*fields, **expressions)
1132 clone._iterable_class = ValuesIterable
1133 return clone
1134
1135 def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
1136 if flat and len(fields) > 1:
1137 raise TypeError(
1138 "'flat' is not valid when values_list is called with more than one "
1139 "field."
1140 )
1141
1142 field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
1143 _fields = []
1144 expressions = {}
1145 counter = 1
1146 for field in fields:
1147 if isinstance(field, ResolvableExpression):
1148 field_id_prefix = getattr(
1149 field, "default_alias", field.__class__.__name__.lower()
1150 )
1151 while True:
1152 field_id = field_id_prefix + str(counter)
1153 counter += 1
1154 if field_id not in field_names:
1155 break
1156 expressions[field_id] = field
1157 _fields.append(field_id)
1158 else:
1159 _fields.append(field)
1160
1161 clone = self._values(*_fields, **expressions)
1162 clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
1163 return clone
1164
1165 def none(self) -> QuerySet[T]:
1166 """Return an empty QuerySet."""
1167 clone = self._chain()
1168 clone.sql_query.set_empty()
1169 return clone
1170
1171 ##################################################################
1172 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1173 ##################################################################
1174
1175 def all(self) -> Self:
1176 """
1177 Return a new QuerySet that is a copy of the current one. This allows a
1178 QuerySet to proxy for a model queryset in some cases.
1179 """
1180 obj = self._chain()
1181 # Preserve cache since all() doesn't modify the query.
1182 # This is important for prefetch_related() to work correctly.
1183 obj._result_cache = self._result_cache
1184 obj._prefetch_done = self._prefetch_done
1185 return obj
1186
    def filter(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with the args ANDed to the existing
        set.

        ``args`` are Q objects (combined with the ``kwargs`` field lookups
        into a single Q condition).
        """
        return self._filter_or_exclude(False, args, kwargs)
1193
    def exclude(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with NOT (args) ANDed to the existing
        set.

        Same arguments as filter(); the combined condition is negated before
        being applied.
        """
        return self._filter_or_exclude(True, args, kwargs)
1200
1201 def _filter_or_exclude(
1202 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1203 ) -> Self:
1204 if (args or kwargs) and self.sql_query.is_sliced:
1205 raise TypeError("Cannot filter a query once a slice has been taken.")
1206 clone = self._chain()
1207 if self._defer_next_filter:
1208 self._defer_next_filter = False
1209 clone._deferred_filter = negate, args, kwargs
1210 else:
1211 clone._filter_or_exclude_inplace(negate, args, kwargs)
1212 return clone
1213
1214 def _filter_or_exclude_inplace(
1215 self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
1216 ) -> None:
1217 if negate:
1218 self._query.add_q(~Q(*args, **kwargs))
1219 else:
1220 self._query.add_q(Q(*args, **kwargs))
1221
1222 def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
1223 """
1224 Return a new QuerySet instance with filter_obj added to the filters.
1225
1226 filter_obj can be a Q object or a dictionary of keyword lookup
1227 arguments.
1228
1229 This exists to support framework features such as 'limit_choices_to',
1230 and usually it will be more natural to use other methods.
1231 """
1232 if isinstance(filter_obj, Q):
1233 clone = self._chain()
1234 clone.sql_query.add_q(filter_obj)
1235 return clone
1236 else:
1237 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1238
1239 def select_for_update(
1240 self,
1241 nowait: bool = False,
1242 skip_locked: bool = False,
1243 of: tuple[str, ...] = (),
1244 no_key: bool = False,
1245 ) -> QuerySet[T]:
1246 """
1247 Return a new QuerySet instance that will select objects with a
1248 FOR UPDATE lock.
1249 """
1250 if nowait and skip_locked:
1251 raise ValueError("The nowait option cannot be used with skip_locked.")
1252 obj = self._chain()
1253 obj._for_write = True
1254 obj.sql_query.select_for_update = True
1255 obj.sql_query.select_for_update_nowait = nowait
1256 obj.sql_query.select_for_update_skip_locked = skip_locked
1257 obj.sql_query.select_for_update_of = of
1258 obj.sql_query.select_for_no_key_update = no_key
1259 return obj
1260
1261 def select_related(self, *fields: str | None) -> Self:
1262 """
1263 Return a new QuerySet instance that will select related objects.
1264
1265 If fields are specified, they must be ForeignKeyField fields and only those
1266 related objects are included in the selection.
1267
1268 If select_related(None) is called, clear the list.
1269 """
1270 if self._fields is not None:
1271 raise TypeError(
1272 "Cannot call select_related() after .values() or .values_list()"
1273 )
1274
1275 obj = self._chain()
1276 if fields == (None,):
1277 obj.sql_query.select_related = False
1278 elif fields:
1279 obj.sql_query.add_select_related(list(fields)) # type: ignore[arg-type]
1280 else:
1281 obj.sql_query.select_related = True
1282 return obj
1283
1284 def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
1285 """
1286 Return a new QuerySet instance that will prefetch the specified
1287 Many-To-One and Many-To-Many related objects when the QuerySet is
1288 evaluated.
1289
1290 When prefetch_related() is called more than once, append to the list of
1291 prefetch lookups. If prefetch_related(None) is called, clear the list.
1292 """
1293 clone = self._chain()
1294 if lookups == (None,):
1295 clone._prefetch_related_lookups = ()
1296 else:
1297 for lookup in lookups:
1298 lookup_str: str
1299 if isinstance(lookup, Prefetch):
1300 lookup_str = lookup.prefetch_to
1301 else:
1302 assert isinstance(lookup, str)
1303 lookup_str = lookup
1304 lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
1305 if lookup_str in self.sql_query._filtered_relations:
1306 raise ValueError(
1307 "prefetch_related() is not supported with FilteredRelation."
1308 )
1309 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1310 return clone
1311
    def annotate(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set in which the returned objects have been annotated
        with extra data or aggregations.

        Unlike alias(), annotated values are included in the query's SELECT.
        """
        return self._annotate(args, kwargs, select=True)
1318
    def alias(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set with added aliases for extra data or aggregations.

        Unlike annotate(), aliased values are not selected by the query.
        """
        return self._annotate(args, kwargs, select=False)
1324
    def _annotate(
        self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
    ) -> Self:
        """Shared implementation for annotate()/alias().

        ``select`` controls whether the annotations are added to the SELECT
        clause (annotate) or kept as aliases only (alias). Raises ValueError
        when an alias collides with a model field or another annotation, and
        TypeError for complex annotations passed positionally without an
        alias.
        """
        self._validate_values_are_expressions(
            args + tuple(kwargs.values()), method_name="annotate"
        )
        annotations = {}
        for arg in args:
            # The default_alias property may raise a TypeError.
            try:
                if arg.default_alias in kwargs:
                    raise ValueError(
                        f"The named annotation '{arg.default_alias}' conflicts with the "
                        "default name for another annotation."
                    )
            except TypeError:
                raise TypeError("Complex annotations require an alias")
            annotations[arg.default_alias] = arg
        annotations.update(kwargs)

        clone = self._chain()
        # Names that annotation aliases must not collide with: either the
        # explicit values()/values_list() selection, or every field name and
        # attname on the model.
        names = self._fields
        if names is None:
            names = set(
                chain.from_iterable(
                    (field.name, field.attname)
                    if hasattr(field, "attname")
                    else (field.name,)
                    for field in self.model._model_meta.get_fields()
                )
            )

        for alias, annotation in annotations.items():
            if alias in names:
                raise ValueError(
                    f"The annotation '{alias}' conflicts with a field on the model."
                )
            if isinstance(annotation, FilteredRelation):
                clone.sql_query.add_filtered_relation(annotation, alias)
            else:
                clone.sql_query.add_annotation(
                    annotation,
                    alias,
                    select=select,
                )
        # If any of the newly added annotations aggregates, the query needs
        # grouping; one match is enough.
        for alias, annotation in clone.sql_query.annotations.items():
            if alias in annotations and annotation.contains_aggregate:
                if clone._fields is None:
                    clone.sql_query.group_by = True
                else:
                    clone.sql_query.set_group_by()
                break

        return clone
1379
1380 def order_by(self, *field_names: str) -> Self:
1381 """Return a new QuerySet instance with the ordering changed."""
1382 if self.sql_query.is_sliced:
1383 raise TypeError("Cannot reorder a query once a slice has been taken.")
1384 obj = self._chain()
1385 obj.sql_query.clear_ordering(force=True, clear_default=False)
1386 obj.sql_query.add_ordering(*field_names)
1387 return obj
1388
1389 def distinct(self, *field_names: str) -> Self:
1390 """
1391 Return a new QuerySet instance that will select only distinct results.
1392 """
1393 if self.sql_query.is_sliced:
1394 raise TypeError(
1395 "Cannot create distinct fields once a slice has been taken."
1396 )
1397 obj = self._chain()
1398 obj.sql_query.add_distinct_fields(*field_names)
1399 return obj
1400
1401 def extra(
1402 self,
1403 select: dict[str, str] | None = None,
1404 where: list[str] | None = None,
1405 params: list[Any] | None = None,
1406 tables: list[str] | None = None,
1407 order_by: list[str] | None = None,
1408 select_params: list[Any] | None = None,
1409 ) -> QuerySet[T]:
1410 """Add extra SQL fragments to the query."""
1411 if self.sql_query.is_sliced:
1412 raise TypeError("Cannot change a query once a slice has been taken.")
1413 clone = self._chain()
1414 clone.sql_query.add_extra(
1415 select or {},
1416 select_params,
1417 where or [],
1418 params or [],
1419 tables or [],
1420 tuple(order_by) if order_by else (),
1421 )
1422 return clone
1423
1424 def reverse(self) -> QuerySet[T]:
1425 """Reverse the ordering of the QuerySet."""
1426 if self.sql_query.is_sliced:
1427 raise TypeError("Cannot reverse a query once a slice has been taken.")
1428 clone = self._chain()
1429 clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
1430 return clone
1431
1432 def defer(self, *fields: str | None) -> QuerySet[T]:
1433 """
1434 Defer the loading of data for certain fields until they are accessed.
1435 Add the set of deferred fields to any existing set of deferred fields.
1436 The only exception to this is if None is passed in as the only
1437 parameter, in which case removal all deferrals.
1438 """
1439 if self._fields is not None:
1440 raise TypeError("Cannot call defer() after .values() or .values_list()")
1441 clone = self._chain()
1442 if fields == (None,):
1443 clone.sql_query.clear_deferred_loading()
1444 else:
1445 clone.sql_query.add_deferred_loading(frozenset(fields))
1446 return clone
1447
1448 def only(self, *fields: str) -> QuerySet[T]:
1449 """
1450 Essentially, the opposite of defer(). Only the fields passed into this
1451 method and that are not already specified as deferred are loaded
1452 immediately when the queryset is evaluated.
1453 """
1454 if self._fields is not None:
1455 raise TypeError("Cannot call only() after .values() or .values_list()")
1456 if fields == (None,):
1457 # Can only pass None to defer(), not only(), as the rest option.
1458 # That won't stop people trying to do this, so let's be explicit.
1459 raise TypeError("Cannot pass None as an argument to only().")
1460 for field in fields:
1461 field = field.split(LOOKUP_SEP, 1)[0]
1462 if field in self.sql_query._filtered_relations:
1463 raise ValueError("only() is not supported with FilteredRelation.")
1464 clone = self._chain()
1465 clone.sql_query.add_immediate_loading(set(fields))
1466 return clone
1467
1468 ###################################
1469 # PUBLIC INTROSPECTION ATTRIBUTES #
1470 ###################################
1471
1472 @property
1473 def ordered(self) -> bool:
1474 """
1475 Return True if the QuerySet is ordered -- i.e. has an order_by()
1476 clause or a default ordering on the model (or is empty).
1477 """
1478 if isinstance(self, EmptyQuerySet):
1479 return True
1480 if self.sql_query.extra_order_by or self.sql_query.order_by:
1481 return True
1482 elif (
1483 self.sql_query.default_ordering
1484 and self.sql_query.model
1485 and self.sql_query.model._model_meta.ordering # type: ignore[arg-type]
1486 and
1487 # A default ordering doesn't affect GROUP BY queries.
1488 not self.sql_query.group_by
1489 ):
1490 return True
1491 else:
1492 return False
1493
1494 ###################
1495 # PRIVATE METHODS #
1496 ###################
1497
    def _insert(
        self,
        objs: list[T],
        fields: list[Field],
        returning_fields: list[Field] | None = None,
        raw: bool = False,
        on_conflict: OnConflict | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
    ) -> list[tuple[Any, ...]] | None:
        """
        Insert a new record for the given model. This provides an interface to
        the InsertQuery class and is how Model.save() is implemented.

        ``returning_fields`` selects which column values the database should
        hand back for the inserted rows; the conflict-related parameters are
        forwarded to InsertQuery unchanged.
        """
        self._for_write = True
        query = InsertQuery(
            self.model,
            on_conflict=on_conflict.value if on_conflict else None,
            update_fields=update_fields,
            unique_fields=unique_fields,
        )
        query.insert_values(fields, objs, raw=raw)
        # InsertQuery returns SQLInsertCompiler which has different execute_sql signature
        return query.get_compiler().execute_sql(returning_fields)  # type: ignore[arg-type]
1522
    def _batched_insert(
        self,
        objs: list[T],
        fields: list[Field],
        batch_size: int | None,
        on_conflict: OnConflict | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
    ) -> list[tuple[Any, ...]]:
        """
        Helper method for bulk_create() to insert objs one batch at a time.

        Returns the rows the database sent back for the inserted objects;
        this is only populated when the backend can return rows from bulk
        inserts and no conflict handling was requested.
        """
        ops = db_connection.ops
        # Never exceed what the backend says it can handle in one statement.
        max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        inserted_rows = []
        bulk_return = db_connection.features.can_return_rows_from_bulk_insert
        for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
            if bulk_return and on_conflict is None:
                inserted_rows.extend(
                    self._insert(  # type: ignore[arg-type]
                        item,
                        fields=fields,
                        returning_fields=self.model._model_meta.db_returning_fields,
                    )
                )
            else:
                self._insert(
                    item,
                    fields=fields,
                    on_conflict=on_conflict,
                    update_fields=update_fields,
                    unique_fields=unique_fields,
                )
        return inserted_rows
1558
1559 def _chain(self) -> Self:
1560 """
1561 Return a copy of the current QuerySet that's ready for another
1562 operation.
1563 """
1564 obj = self._clone()
1565 if obj._sticky_filter:
1566 obj.sql_query.filter_is_sticky = True
1567 obj._sticky_filter = False
1568 return obj
1569
    def _clone(self) -> Self:
        """
        Return a copy of the current QuerySet. A lightweight alternative
        to deepcopy().

        The SQL query is chained (copied); evaluation state such as
        _result_cache is deliberately not carried over.
        """
        c = self.__class__.from_model(
            model=self.model,
            query=self.sql_query.chain(),
        )
        c._sticky_filter = self._sticky_filter
        c._for_write = self._for_write
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        c._known_related_objects = self._known_related_objects
        c._iterable_class = self._iterable_class
        c._fields = self._fields
        return c
1586
    def _fetch_all(self) -> None:
        """Evaluate the queryset (once) and run any pending prefetches."""
        if self._result_cache is None:
            self._result_cache = list(self._iterable_class(self))
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()
1592
    def _next_is_sticky(self) -> QuerySet[T]:
        """
        Indicate that the next filter call and the one following that should
        be treated as a single filter. This is only important when it comes to
        determining when to reuse tables for many-to-many filters. Required so
        that we can filter naturally on the results of related managers.

        This doesn't return a clone of the current QuerySet (it returns
        "self"). The method is only used internally and should be immediately
        followed by a filter() that does create a clone.
        """
        # Deliberately mutates self instead of cloning; see docstring.
        self._sticky_filter = True
        return self
1606
1607 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1608 """Check that two QuerySet classes may be merged."""
1609 if self._fields is not None and (
1610 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1611 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1612 or set(self.sql_query.annotation_select)
1613 != set(other.sql_query.annotation_select)
1614 ):
1615 raise TypeError(
1616 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1617 )
1618
1619 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1620 """
1621 Keep track of all known related objects from either QuerySet instance.
1622 """
1623 for field, objects in other._known_related_objects.items():
1624 self._known_related_objects.setdefault(field, {}).update(objects)
1625
1626 def resolve_expression(self, *args: Any, **kwargs: Any) -> Query:
1627 if self._fields and len(self._fields) > 1:
1628 # values() queryset can only be used as nested queries
1629 # if they are set up to select only a single field.
1630 raise TypeError("Cannot use multi-field values as a filter value.")
1631 query = self.sql_query.resolve_expression(*args, **kwargs)
1632 return query
1633
    def _has_filters(self) -> bool:
        """
        Check if this QuerySet has any filtering going on. This isn't
        equivalent with checking if all objects are present in results, for
        example, qs[1:]._has_filters() -> False.
        """
        # Delegates entirely to the underlying SQL query.
        return self.sql_query.has_filters()
1641
1642 @staticmethod
1643 def _validate_values_are_expressions(
1644 values: tuple[Any, ...], method_name: str
1645 ) -> None:
1646 invalid_args = sorted(
1647 str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
1648 )
1649 if invalid_args:
1650 raise TypeError(
1651 "QuerySet.{}() received non-expression(s): {}.".format(
1652 method_name,
1653 ", ".join(invalid_args),
1654 )
1655 )
1656
1657
class InstanceCheckMeta(type):
    # Metaclass for EmptyQuerySet: makes isinstance(qs, EmptyQuerySet) true
    # for any QuerySet whose underlying SQL query is known to be empty.
    def __instancecheck__(self, instance: object) -> bool:
        return isinstance(instance, QuerySet) and instance.sql_query.is_empty()
1661
1662
class EmptyQuerySet(metaclass=InstanceCheckMeta):
    """
    Marker class for checking whether a queryset is empty via .none():
    isinstance(qs.none(), EmptyQuerySet) -> True
    """

    def __init__(self, *args: Any, **kwargs: Any):
        # Purely a marker type; the isinstance() magic lives in the metaclass.
        raise TypeError("EmptyQuerySet can't be instantiated")
1671
1672
class RawQuerySet:
    """
    Provide an iterator which converts the results of raw SQL queries into
    annotated model instances.
    """

    def __init__(
        self,
        raw_query: str,
        model: type[Model] | None = None,
        query: RawQuery | None = None,
        params: tuple[Any, ...] = (),
        translations: dict[str, str] | None = None,
    ):
        # ``translations`` maps query column names to model field names.
        self.raw_query = raw_query
        self.model = model
        self.sql_query = query or RawQuery(sql=raw_query, params=params)
        self.params = params
        self.translations = translations or {}
        self._result_cache: list[Model] | None = None
        self._prefetch_related_lookups: tuple[Any, ...] = ()
        self._prefetch_done = False

    def resolve_model_init_order(
        self,
    ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
        """Resolve the init field names and value positions."""
        converter = db_connection.introspection.identifier_converter
        model = self.model
        assert model is not None
        # Model fields that actually appear in the query's column list.
        model_init_fields = [
            f for f in model._model_meta.fields if converter(f.column) in self.columns
        ]
        # Columns that don't map to a model field are treated as annotations.
        annotation_fields = [
            (column, pos)
            for pos, column in enumerate(self.columns)
            if column not in self.model_fields
        ]
        model_init_order = [
            self.columns.index(converter(f.column)) for f in model_init_fields
        ]
        model_init_names = [f.attname for f in model_init_fields]
        return model_init_names, model_init_order, annotation_fields

    def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
        """Same as QuerySet.prefetch_related()"""
        clone = self._clone()
        if lookups == (None,):
            # Passing None clears any configured lookups.
            clone._prefetch_related_lookups = ()
        else:
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

    def _prefetch_related_objects(self) -> None:
        # Requires the result cache to be populated first.
        assert self._result_cache is not None
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def _clone(self) -> RawQuerySet:
        """Same as QuerySet._clone()"""
        c = self.__class__(
            self.raw_query,
            model=self.model,
            query=self.sql_query,
            params=self.params,
            translations=self.translations,
        )
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return c

    def _fetch_all(self) -> None:
        # Evaluate the raw query once and run any pending prefetches.
        if self._result_cache is None:
            self._result_cache = list(self.iterator())
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    def __len__(self) -> int:
        self._fetch_all()
        assert self._result_cache is not None
        return len(self._result_cache)

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    def __iter__(self) -> Iterator[Model]:
        self._fetch_all()
        assert self._result_cache is not None
        return iter(self._result_cache)

    def iterator(self) -> Iterator[Model]:
        """Yield model instances from the database, bypassing the result cache."""
        yield from RawModelIterable(self)  # type: ignore[arg-type]

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.sql_query}>"

    def __getitem__(self, k: int | slice) -> Model | list[Model]:
        # Indexing/slicing forces full evaluation of the raw query.
        return list(self)[k]

    @cached_property
    def columns(self) -> list[str]:
        """
        A list of model field names in the order they'll appear in the
        query results.
        """
        columns = self.sql_query.get_columns()
        # Adjust any column names which don't match field names
        for query_name, model_name in self.translations.items():
            # Ignore translations for nonexistent column names
            try:
                index = columns.index(query_name)
            except ValueError:
                pass
            else:
                columns[index] = model_name
        return columns

    @cached_property
    def model_fields(self) -> dict[str, Field]:
        """A dict mapping column names to model field names."""
        converter = db_connection.introspection.identifier_converter
        model_fields = {}
        model = self.model
        assert model is not None
        for field in model._model_meta.fields:
            model_fields[converter(field.column)] = field
        return model_fields
1800
1801
class Prefetch:
    """Describe a single prefetch_related() lookup, optionally with a custom
    queryset and/or a custom attribute (``to_attr``) to store results on."""

    def __init__(
        self,
        lookup: str,
        queryset: QuerySet[Any] | None = None,
        to_attr: str | None = None,
    ):
        # `prefetch_through` is the path we traverse to perform the prefetch.
        self.prefetch_through = lookup
        # `prefetch_to` is the path to the attribute that stores the result.
        self.prefetch_to = lookup
        if queryset is not None and (
            isinstance(queryset, RawQuerySet)
            or (
                hasattr(queryset, "_iterable_class")
                and not issubclass(queryset._iterable_class, ModelIterable)
            )
        ):
            raise ValueError(
                "Prefetch querysets cannot use raw(), values(), and values_list()."
            )
        if to_attr:
            # Store results under to_attr instead of the final lookup segment.
            self.prefetch_to = LOOKUP_SEP.join(
                lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
            )

        self.queryset = queryset
        self.to_attr = to_attr

    def __getstate__(self) -> dict[str, Any]:
        # Pickling hook: make the attached queryset picklable without
        # evaluating it.
        obj_dict = self.__dict__.copy()
        if self.queryset is not None:
            queryset = self.queryset._chain()
            # Prevent the QuerySet from being evaluated
            queryset._result_cache = []
            queryset._prefetch_done = True
            obj_dict["queryset"] = queryset
        return obj_dict

    def add_prefix(self, prefix: str) -> None:
        # Prepend a relation path segment; used when nesting prefetches.
        self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
        self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to

    def get_current_prefetch_to(self, level: int) -> str:
        # The storage path truncated to the given traversal depth.
        return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])

    def get_current_to_attr(self, level: int) -> tuple[str, bool]:
        # Returns (attribute name, whether it's a custom to_attr) for level.
        parts = self.prefetch_to.split(LOOKUP_SEP)
        to_attr = parts[level]
        as_attr = bool(self.to_attr and level == len(parts) - 1)
        return to_attr, as_attr

    def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
        # The custom queryset only applies at the final level of the lookup.
        if self.get_current_prefetch_to(level) == self.prefetch_to:
            return self.queryset
        return None

    def __eq__(self, other: object) -> bool:
        # Equality (and hashing below) is based solely on the destination path.
        if not isinstance(other, Prefetch):
            return NotImplemented
        return self.prefetch_to == other.prefetch_to

    def __hash__(self) -> int:
        return hash((self.__class__, self.prefetch_to))
1866
1867
def normalize_prefetch_lookups(
    lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
    prefix: str | None = None,
) -> list[Prefetch]:
    """Normalize lookups into Prefetch objects, optionally prefixed."""
    normalized = []
    for lookup in lookups:
        # Strings are wrapped; existing Prefetch objects pass through.
        prefetch = lookup if isinstance(lookup, Prefetch) else Prefetch(lookup)
        if prefix:
            prefetch.add_prefix(prefix)
        normalized.append(prefetch)
    return normalized
1881
1882
def prefetch_related_objects(
    model_instances: Sequence[Model], *related_lookups: str | Prefetch
) -> None:
    """
    Populate prefetched object caches for a list of model instances based on
    the lookups/Prefetch instances given.

    Lookups are processed from a work list; additional lookups discovered on
    fetched querysets (nested prefetch_related) are appended to it as we go.
    Raises ValueError for duplicate lookups with custom querysets or lookups
    that don't support prefetching, and AttributeError for unknown attributes.
    """
    if not model_instances:
        return  # nothing to do

    # We need to be able to dynamically add to the list of prefetch_related
    # lookups that we look up (see below). So we need some book keeping to
    # ensure we don't do duplicate work.
    done_queries = {}  # dictionary of things like 'foo__bar': [results]

    auto_lookups = set()  # we add to this as we go through.
    followed_descriptors = set()  # recursion protection

    # pop() takes from the end of the list, so reverse the input to process
    # lookups in the caller-supplied order.
    all_lookups = normalize_prefetch_lookups(reversed(related_lookups))  # type: ignore[arg-type]
    while all_lookups:
        lookup = all_lookups.pop()
        if lookup.prefetch_to in done_queries:
            if lookup.queryset is not None:
                # A custom queryset can't be merged with results already
                # fetched for the same path by an earlier lookup.
                raise ValueError(
                    f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
                    "You may need to adjust the ordering of your lookups."
                )

            continue

        # Top level, the list of objects to decorate is the result cache
        # from the primary QuerySet. It won't be for deeper levels.
        obj_list = model_instances

        # Walk one relation segment per iteration; obj_list is replaced by
        # the related objects at each level.
        through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
        for level, through_attr in enumerate(through_attrs):
            # Prepare main instances
            if not obj_list:
                break

            prefetch_to = lookup.get_current_prefetch_to(level)
            if prefetch_to in done_queries:
                # Skip any prefetching, and any object preparation
                obj_list = done_queries[prefetch_to]
                continue

            # Prepare objects:
            good_objects = True
            for obj in obj_list:
                # Since prefetching can re-use instances, it is possible to have
                # the same instance multiple times in obj_list, so obj might
                # already be prepared.
                if not hasattr(obj, "_prefetched_objects_cache"):
                    try:
                        obj._prefetched_objects_cache = {}
                    except (AttributeError, TypeError):
                        # Must be an immutable object from
                        # values_list(flat=True), for example (TypeError) or
                        # a QuerySet subclass that isn't returning Model
                        # instances (AttributeError), either in Plain or a 3rd
                        # party. prefetch_related() doesn't make sense, so quit.
                        good_objects = False
                        break
            if not good_objects:
                break

            # Descend down tree

            # We assume that objects retrieved are homogeneous (which is the premise
            # of prefetch_related), so what applies to first object applies to all.
            first_obj = obj_list[0]
            to_attr = lookup.get_current_to_attr(level)[0]
            prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
                first_obj, through_attr, to_attr
            )

            if not attr_found:
                raise AttributeError(
                    f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
                    "parameter to prefetch_related()"
                )

            if level == len(through_attrs) - 1 and prefetcher is None:
                # Last one, this *must* resolve to something that supports
                # prefetching, otherwise there is no point adding it and the
                # developer asking for it has made a mistake.
                raise ValueError(
                    f"'{lookup.prefetch_through}' does not resolve to an item that supports "
                    "prefetching - this is an invalid parameter to "
                    "prefetch_related()."
                )

            # Only fetch objects whose relation isn't already cached for this
            # level.
            obj_to_fetch = None
            if prefetcher is not None:
                obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]

            if obj_to_fetch:
                obj_list, additional_lookups = prefetch_one_level(
                    obj_to_fetch,
                    prefetcher,
                    lookup,
                    level,
                )
                # We need to ensure we don't keep adding lookups from the
                # same relationships to stop infinite recursion. So, if we
                # are already on an automatically added lookup, don't add
                # the new lookups from relationships we've seen already.
                if not (
                    prefetch_to in done_queries
                    and lookup in auto_lookups
                    and descriptor in followed_descriptors
                ):
                    done_queries[prefetch_to] = obj_list
                    new_lookups = normalize_prefetch_lookups(
                        reversed(additional_lookups),  # type: ignore[arg-type]
                        prefetch_to,
                    )
                    auto_lookups.update(new_lookups)
                    all_lookups.extend(new_lookups)
                followed_descriptors.add(descriptor)
            else:
                # Either a singly related object that has already been fetched
                # (e.g. via select_related), or hopefully some other property
                # that doesn't support prefetching but needs to be traversed.

                # We replace the current list of parent objects with the list
                # of related objects, filtering out empty or missing values so
                # that we can continue with nullable or reverse relations.
                new_obj_list = []
                for obj in obj_list:
                    if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
                        # If related objects have been prefetched, use the
                        # cache rather than the object's through_attr.
                        new_obj = list(obj._prefetched_objects_cache.get(through_attr))  # type: ignore[arg-type]
                    else:
                        try:
                            new_obj = getattr(obj, through_attr)
                        except ObjectDoesNotExist:
                            continue
                    if new_obj is None:
                        continue
                    # We special-case `list` rather than something more generic
                    # like `Iterable` because we don't want to accidentally match
                    # user models that define __iter__.
                    if isinstance(new_obj, list):
                        new_obj_list.extend(new_obj)
                    else:
                        new_obj_list.append(new_obj)
                obj_list = new_obj_list
2032
2033
def get_prefetcher(
    instance: Model, through_attr: str, to_attr: str
) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
    """
    For the attribute 'through_attr' on the given instance, find
    an object that has a get_prefetch_queryset().
    Return a 4 tuple containing:
    (the object with get_prefetch_queryset (or None),
     the descriptor object representing this relationship (or None),
     a boolean that is False if the attribute was not found at all,
     a function that takes an instance and returns a boolean that is True if
     the attribute has already been fetched for that instance)
    """

    def has_to_attr_attribute(instance: Model) -> bool:
        # Default "already fetched" check: the target attribute simply exists
        # on the instance.
        return hasattr(instance, to_attr)

    prefetcher = None
    is_fetched: Callable[[Model], bool] = has_to_attr_attribute

    # For singly related objects, we have to avoid getting the attribute
    # from the object, as this will trigger the query. So we first try
    # on the class, in order to get the descriptor object.
    rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
    if rel_obj_descriptor is None:
        attr_found = hasattr(instance, through_attr)
    else:
        attr_found = True
        if rel_obj_descriptor:
            # singly related object, descriptor object has the
            # get_prefetch_queryset() method.
            if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
                prefetcher = rel_obj_descriptor
                is_fetched = rel_obj_descriptor.is_cached
            else:
                # descriptor doesn't support prefetching, so we go ahead and get
                # the attribute on the instance rather than the class to
                # support many related managers
                rel_obj = getattr(instance, through_attr)
                if hasattr(rel_obj, "get_prefetch_queryset"):
                    prefetcher = rel_obj
                if through_attr != to_attr:
                    # Special case cached_property instances because hasattr
                    # triggers attribute computation and assignment.
                    if isinstance(
                        getattr(instance.__class__, to_attr, None), cached_property
                    ):

                        def has_cached_property(instance: Model) -> bool:
                            # cached_property stores its result in the instance
                            # __dict__, so presence there means "fetched"
                            # without triggering computation.
                            return to_attr in instance.__dict__

                        is_fetched = has_cached_property
                else:

                    def in_prefetched_cache(instance: Model) -> bool:
                        # No custom to_attr: results land in the per-instance
                        # prefetch cache keyed by the traversed attribute.
                        return through_attr in instance._prefetched_objects_cache

                    is_fetched = in_prefetched_cache
    return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2093
2094
def prefetch_one_level(
    instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
) -> tuple[list[Model], list[Prefetch]]:
    """
    Helper function for prefetch_related_objects().

    Run prefetches on all instances using the prefetcher object,
    assigning results to relevant caches in instance.

    Return the prefetched objects along with any additional prefetches that
    must be done due to prefetch_related lookups found from default managers.
    """
    # prefetcher must have a method get_prefetch_queryset() which takes a list
    # of instances, and returns a tuple:

    # (queryset of instances of self.model that are related to passed in instances,
    #  callable that gets value to be matched for returned instances,
    #  callable that gets value to be matched for passed in instances,
    #  boolean that is True for singly related objects,
    #  cache or field name to assign to,
    #  boolean that is True when the previous argument is a cache name vs a field name).

    # The 'values to be matched' must be hashable as they will be used
    # in a dictionary.

    (
        rel_qs,
        rel_obj_attr,
        instance_attr,
        single,
        cache_name,
        is_descriptor,
    ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
    # We have to handle the possibility that the QuerySet we just got back
    # contains some prefetch_related lookups. We don't want to trigger the
    # prefetch_related functionality by evaluating the query. Rather, we need
    # to merge in the prefetch_related lookups.
    # Copy the lookups in case it is a Prefetch object which could be reused
    # later (happens in nested prefetch_related).
    additional_lookups = [
        copy.copy(additional_lookup)
        for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
    ]
    if additional_lookups:
        # Don't need to clone because the queryset should have given us a fresh
        # instance, so we access an internal instead of using public interface
        # for performance reasons.
        rel_qs._prefetch_related_lookups = ()

    # Evaluate the related queryset here (single query for the whole level).
    all_related_objects = list(rel_qs)

    # Group fetched objects by their match value so each instance can find
    # its related objects with a single dict lookup.
    rel_obj_cache = {}
    for rel_obj in all_related_objects:
        rel_attr_val = rel_obj_attr(rel_obj)
        rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)

    to_attr, as_attr = lookup.get_current_to_attr(level)
    # Make sure `to_attr` does not conflict with a field.
    if as_attr and instances:
        # We assume that objects retrieved are homogeneous (which is the premise
        # of prefetch_related), so what applies to first object applies to all.
        model = instances[0].__class__
        try:
            model._model_meta.get_field(to_attr)
        except FieldDoesNotExist:
            pass
        else:
            msg = "to_attr={} conflicts with a field on the {} model."
            raise ValueError(msg.format(to_attr, model.__name__))

    # Whether or not we're prefetching the last part of the lookup.
    leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level

    for obj in instances:
        instance_attr_val = instance_attr(obj)
        vals = rel_obj_cache.get(instance_attr_val, [])

        if single:
            # Singly related object: assign the one match (or None).
            val = vals[0] if vals else None
            if as_attr:
                # A to_attr has been given for the prefetch.
                setattr(obj, to_attr, val)
            elif is_descriptor:
                # cache_name points to a field name in obj.
                # This field is a descriptor for a related object.
                setattr(obj, cache_name, val)
            else:
                # No to_attr has been given for this prefetch operation and the
                # cache_name does not point to a descriptor. Store the value of
                # the field in the object's field cache.
                obj._state.fields_cache[cache_name] = val  # type: ignore[index]
        else:
            if as_attr:
                setattr(obj, to_attr, vals)
            else:
                queryset = getattr(obj, to_attr)
                if leaf and lookup.queryset is not None:
                    # Custom queryset at the leaf: re-apply the relation
                    # filters so the cached queryset matches it.
                    qs = queryset._apply_rel_filters(lookup.queryset)
                else:
                    # Check if queryset is a QuerySet or a related manager
                    # We need a QuerySet instance to cache the prefetched values
                    if isinstance(queryset, QuerySet):
                        # It's already a QuerySet, create a new instance
                        qs = queryset.__class__.from_model(queryset.model)
                    else:
                        # It's a related manager, get its QuerySet
                        # The manager's query property returns a properly filtered QuerySet
                        qs = queryset.query
                # Pre-populate the result cache so iterating qs never hits
                # the database.
                qs._result_cache = vals
                # We don't want the individual qs doing prefetch_related now,
                # since we have merged this into the current work.
                qs._prefetch_done = True
                obj._prefetched_objects_cache[cache_name] = qs
    return all_related_objects, additional_lookups
2209
2210
class RelatedPopulator:
    """
    Instantiate and attach select_related() objects from result rows.

    Each model reached through select_related() gets its own
    RelatedPopulator, built from the compiler's klass_info and select data.
    populate() consumes a result row, builds the related instance (or None
    when the row's slice holds a NULL primary key), and links it to the
    parent object via the local/remote setters.
    """

    def __init__(self, klass_info: dict[str, Any], select: list[Any]):
        # Pre-computed attributes:
        # - model_cls: the possibly deferred model class to instantiate.
        # - cols_start/cols_end: slice of the row holding this model's data.
        # - reorder_for_init: optional callable to reorder row data before
        #   instantiation when the fetched order differs from what the model
        #   expects; populate() checks it before use (None here).
        # - id_idx: position of the primary key within the slice; None there
        #   means no related row exists.
        # - init_list: attnames of the fields fetched from the database (may
        #   be a subset of the model's fields for deferred models).
        # - related_populators: populators for models nested below this one.
        # - local_setter/remote_setter: callables caching the link on each
        #   side of the relation (usually Field.set_cached_value methods).
        select_fields = klass_info["select_fields"]

        self.cols_start = select_fields[0]
        self.cols_end = select_fields[-1] + 1
        self.init_list = [
            entry[0].target.attname
            for entry in select[self.cols_start : self.cols_end]
        ]
        self.reorder_for_init = None

        self.model_cls = klass_info["model"]
        self.id_idx = self.init_list.index("id")
        self.related_populators = get_related_populators(klass_info, select)
        self.local_setter = klass_info["local_setter"]
        self.remote_setter = klass_info["remote_setter"]

    def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
        """Build the related object from *row* and link it to *from_obj*."""
        reorder = self.reorder_for_init
        obj_data = reorder(row) if reorder else row[self.cols_start : self.cols_end]
        if obj_data[self.id_idx] is None:
            # NULL primary key in this slice: the join matched no row.
            obj = None
        else:
            obj = self.model_cls.from_db(self.init_list, obj_data)
            for child in self.related_populators:
                child.populate(row, obj)
        self.local_setter(from_obj, obj)
        if obj is not None:
            self.remote_setter(obj, from_obj)
2280
2281
def get_related_populators(
    klass_info: dict[str, Any], select: list[Any]
) -> list[RelatedPopulator]:
    """Build one RelatedPopulator per related klass_info entry."""
    return [
        RelatedPopulator(rel_info, select)
        for rel_info in klass_info.get("related_klass_infos", [])
    ]