1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from abc import ABC, abstractmethod
11from collections.abc import Callable, Iterator, Sequence
12from functools import cached_property
13from itertools import chain, islice
14from typing import TYPE_CHECKING, Any, Generic, Never, Self, TypeVar, overload
15
16import plain.runtime
17from plain.exceptions import ValidationError
18from plain.models import (
19 sql,
20 transaction,
21)
22from plain.models.constants import LOOKUP_SEP, OnConflict
23from plain.models.db import (
24 PLAIN_VERSION_PICKLE_KEY,
25 IntegrityError,
26 NotSupportedError,
27 db_connection,
28)
29from plain.models.exceptions import (
30 FieldDoesNotExist,
31 FieldError,
32 ObjectDoesNotExist,
33)
34from plain.models.expressions import Case, F, ResolvableExpression, Value, When
35from plain.models.fields import (
36 Field,
37 PrimaryKeyField,
38)
39from plain.models.functions import Cast
40from plain.models.query_utils import FilteredRelation, Q
41from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
42from plain.models.utils import resolve_callables
43from plain.utils.functional import partition
44
45# Re-exports for public API
46__all__ = ["F", "Q", "QuerySet", "RawQuerySet", "Prefetch", "FilteredRelation"]
47
48if TYPE_CHECKING:
49 from plain.models import Model
50
51# Type variable for QuerySet generic
52T = TypeVar("T", bound="Model")
53
54# The maximum number of results to fetch in a get() query.
55MAX_GET_RESULTS = 21
56
57# The maximum number of items to display in a QuerySet.__repr__
58REPR_OUTPUT_SIZE = 20
59
60
61class BaseIterable(ABC):
62 def __init__(
63 self,
64 queryset: QuerySet[Any],
65 chunked_fetch: bool = False,
66 chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
67 ):
68 self.queryset = queryset
69 self.chunked_fetch = chunked_fetch
70 self.chunk_size = chunk_size
71
72 @abstractmethod
73 def __iter__(self) -> Iterator[Any]: ...
74
75
76class ModelIterable(BaseIterable):
77 """Iterable that yields a model instance for each row."""
78
79 def __iter__(self) -> Iterator[Model]:
80 queryset = self.queryset
81 compiler = queryset.sql_query.get_compiler()
82 # Execute the query. This will also fill compiler.select, klass_info,
83 # and annotations.
84 results = compiler.execute_sql(
85 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
86 )
87 select, klass_info, annotation_col_map = (
88 compiler.select,
89 compiler.klass_info,
90 compiler.annotation_col_map,
91 )
92 # These are set by execute_sql() above
93 assert select is not None
94 assert klass_info is not None
95 model_cls = klass_info["model"]
96 select_fields = klass_info["select_fields"]
97 model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
98 init_list = [
99 f[0].target.attname for f in select[model_fields_start:model_fields_end]
100 ]
101 related_populators = get_related_populators(klass_info, select)
102 known_related_objects = [
103 (
104 field,
105 related_objs,
106 operator.attrgetter(field.attname),
107 )
108 for field, related_objs in queryset._known_related_objects.items()
109 ]
110 for row in compiler.results_iter(results):
111 obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
112 for rel_populator in related_populators:
113 rel_populator.populate(row, obj)
114 if annotation_col_map:
115 for attr_name, col_pos in annotation_col_map.items():
116 setattr(obj, attr_name, row[col_pos])
117
118 # Add the known related objects to the model.
119 for field, rel_objs, rel_getter in known_related_objects:
120 # Avoid overwriting objects loaded by, e.g., select_related().
121 if field.is_cached(obj):
122 continue
123 rel_obj_id = rel_getter(obj)
124 try:
125 rel_obj = rel_objs[rel_obj_id]
126 except KeyError:
127 pass # May happen in qs1 | qs2 scenarios.
128 else:
129 setattr(obj, field.name, rel_obj)
130
131 yield obj
132
133
134class RawModelIterable(BaseIterable):
135 """
136 Iterable that yields a model instance for each row from a raw queryset.
137 """
138
139 queryset: RawQuerySet
140
141 def __iter__(self) -> Iterator[Model]:
142 # Cache some things for performance reasons outside the loop.
143 # RawQuery is not a Query subclass, so we directly get SQLCompiler
144 from plain.models.sql.query import Query as SqlQuery
145
146 query = self.queryset.sql_query
147 compiler = db_connection.ops.compilers[SqlQuery](query, db_connection, True)
148 query_iterator = iter(query)
149
150 try:
151 (
152 model_init_names,
153 model_init_pos,
154 annotation_fields,
155 ) = self.queryset.resolve_model_init_order()
156 model_cls = self.queryset.model
157 assert model_cls is not None
158 if "id" not in model_init_names:
159 raise FieldDoesNotExist("Raw query must include the primary key")
160 fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
161 converters = compiler.get_converters(
162 [
163 f.get_col(f.model.model_options.db_table) if f else None
164 for f in fields
165 ]
166 )
167 if converters:
168 query_iterator = compiler.apply_converters(query_iterator, converters)
169 for values in query_iterator:
170 # Associate fields to values
171 model_init_values = [values[pos] for pos in model_init_pos]
172 instance = model_cls.from_db(model_init_names, model_init_values)
173 if annotation_fields:
174 for column, pos in annotation_fields:
175 setattr(instance, column, values[pos])
176 yield instance
177 finally:
178 # Done iterating the Query. If it has its own cursor, close it.
179 if hasattr(query, "cursor") and query.cursor:
180 query.cursor.close()
181
182
183class ValuesIterable(BaseIterable):
184 """
185 Iterable returned by QuerySet.values() that yields a dict for each row.
186 """
187
188 def __iter__(self) -> Iterator[dict[str, Any]]:
189 queryset = self.queryset
190 query = queryset.sql_query
191 compiler = query.get_compiler()
192
193 # extra(select=...) cols are always at the start of the row.
194 names = [
195 *query.extra_select,
196 *query.values_select,
197 *query.annotation_select,
198 ]
199 indexes = range(len(names))
200 for row in compiler.results_iter(
201 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
202 ):
203 yield {names[i]: row[i] for i in indexes}
204
205
206class ValuesListIterable(BaseIterable):
207 """
208 Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
209 for each row.
210 """
211
212 def __iter__(self) -> Iterator[tuple[Any, ...]]:
213 queryset = self.queryset
214 query = queryset.sql_query
215 compiler = query.get_compiler()
216
217 if queryset._fields:
218 # extra(select=...) cols are always at the start of the row.
219 names = [
220 *query.extra_select,
221 *query.values_select,
222 *query.annotation_select,
223 ]
224 fields = [
225 *queryset._fields,
226 *(f for f in query.annotation_select if f not in queryset._fields),
227 ]
228 if fields != names:
229 # Reorder according to fields.
230 index_map = {name: idx for idx, name in enumerate(names)}
231 rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
232 return map(
233 rowfactory,
234 compiler.results_iter(
235 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
236 ),
237 )
238 return iter(
239 compiler.results_iter(
240 tuple_expected=True,
241 chunked_fetch=self.chunked_fetch,
242 chunk_size=self.chunk_size,
243 )
244 )
245
246
247class FlatValuesListIterable(BaseIterable):
248 """
249 Iterable returned by QuerySet.values_list(flat=True) that yields single
250 values.
251 """
252
253 def __iter__(self) -> Iterator[Any]:
254 queryset = self.queryset
255 compiler = queryset.sql_query.get_compiler()
256 for row in compiler.results_iter(
257 chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
258 ):
259 yield row[0]
260
261
262class QuerySet(Generic[T]):
263 """
264 Represent a lazy database lookup for a set of objects.
265
266 Usage:
267 MyModel.query.filter(name="test").all()
268
269 Custom QuerySets:
270 from typing import Self
271
272 class TaskQuerySet(QuerySet["Task"]):
273 def active(self) -> Self:
274 return self.filter(is_active=True)
275
276 class Task(Model):
277 is_active = BooleanField(default=True)
278 query = TaskQuerySet()
279
280 Task.query.active().filter(name="test") # Full type inference
281
282 Custom methods should return `Self` to preserve type through method chaining.
283 """
284
285 # Instance attributes (set in from_model())
286 model: type[T]
287 _query: sql.Query
288 _result_cache: list[T] | None
289 _sticky_filter: bool
290 _for_write: bool
291 _prefetch_related_lookups: tuple[Any, ...]
292 _prefetch_done: bool
293 _known_related_objects: dict[Any, dict[Any, Any]]
294 _iterable_class: type[BaseIterable]
295 _fields: tuple[str, ...] | None
296 _defer_next_filter: bool
297 _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None
298
299 def __init__(self):
300 """Minimal init for descriptor mode. Use from_model() to create instances."""
301 pass
302
303 @classmethod
304 def from_model(cls, model: type[T], query: sql.Query | None = None) -> Self:
305 """Create a QuerySet instance bound to a model."""
306 instance = cls()
307 instance.model = model
308 instance._query = query or sql.Query(model)
309 instance._result_cache = None
310 instance._sticky_filter = False
311 instance._for_write = False
312 instance._prefetch_related_lookups = ()
313 instance._prefetch_done = False
314 instance._known_related_objects = {}
315 instance._iterable_class = ModelIterable
316 instance._fields = None
317 instance._defer_next_filter = False
318 instance._deferred_filter = None
319 return instance
320
321 @overload
322 def __get__(self, instance: None, owner: type[T]) -> Self: ...
323
324 @overload
325 def __get__(self, instance: Model, owner: type[T]) -> Never: ...
326
327 def __get__(self, instance: Any, owner: type[T]) -> Self:
328 """Descriptor protocol - return a new QuerySet bound to the model."""
329 if instance is not None:
330 raise AttributeError(
331 f"QuerySet is only accessible from the model class, not instances. "
332 f"Use {owner.__name__}.query instead."
333 )
334 return self.from_model(owner)
335
336 @property
337 def sql_query(self) -> sql.Query:
338 if self._deferred_filter:
339 negate, args, kwargs = self._deferred_filter
340 self._filter_or_exclude_inplace(negate, args, kwargs)
341 self._deferred_filter = None
342 return self._query
343
344 @sql_query.setter
345 def sql_query(self, value: sql.Query) -> None:
346 if value.values_select:
347 self._iterable_class = ValuesIterable
348 self._query = value
349
350 ########################
351 # PYTHON MAGIC METHODS #
352 ########################
353
354 def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
355 """Don't populate the QuerySet's cache."""
356 obj = self.__class__.from_model(self.model)
357 for k, v in self.__dict__.items():
358 if k == "_result_cache":
359 obj.__dict__[k] = None
360 else:
361 obj.__dict__[k] = copy.deepcopy(v, memo)
362 return obj
363
364 def __getstate__(self) -> dict[str, Any]:
365 # Force the cache to be fully populated.
366 self._fetch_all()
367 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
368
369 def __setstate__(self, state: dict[str, Any]) -> None:
370 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
371 if pickled_version:
372 if pickled_version != plain.runtime.__version__:
373 warnings.warn(
374 f"Pickled queryset instance's Plain version {pickled_version} does not "
375 f"match the current version {plain.runtime.__version__}.",
376 RuntimeWarning,
377 stacklevel=2,
378 )
379 else:
380 warnings.warn(
381 "Pickled queryset instance's Plain version is not specified.",
382 RuntimeWarning,
383 stacklevel=2,
384 )
385 self.__dict__.update(state)
386
387 def __repr__(self) -> str:
388 data = list(self[: REPR_OUTPUT_SIZE + 1])
389 if len(data) > REPR_OUTPUT_SIZE:
390 data[-1] = "...(remaining elements truncated)..."
391 return f"<{self.__class__.__name__} {data!r}>"
392
393 def __len__(self) -> int:
394 self._fetch_all()
395 assert self._result_cache is not None
396 return len(self._result_cache)
397
398 def __iter__(self) -> Iterator[T]:
399 """
400 The queryset iterator protocol uses three nested iterators in the
401 default case:
402 1. sql.compiler.execute_sql()
403 - Returns 100 rows at time (constants.GET_ITERATOR_CHUNK_SIZE)
404 using cursor.fetchmany(). This part is responsible for
405 doing some column masking, and returning the rows in chunks.
406 2. sql.compiler.results_iter()
407 - Returns one row at time. At this point the rows are still just
408 tuples. In some cases the return values are converted to
409 Python values at this location.
410 3. self.iterator()
411 - Responsible for turning the rows into model objects.
412 """
        self._fetch_all()
        assert self._result_cache is not None
        return iter(self._result_cache)

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    @overload
    def __getitem__(self, k: int) -> T: ...

    @overload
    def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...

    def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
        """Retrieve an item or slice from the set of results."""
        if not isinstance(k, int | slice):
            raise TypeError(
                f"QuerySet indices must be integers or slices, not {type(k).__name__}."
            )
        if (isinstance(k, int) and k < 0) or (
            isinstance(k, slice)
            and (
                (k.start is not None and k.start < 0)
                or (k.stop is not None and k.stop < 0)
            )
        ):
            raise ValueError("Negative indexing is not supported.")

        if self._result_cache is not None:
            return self._result_cache[k]

        if isinstance(k, slice):
            qs = self._chain()
            if k.start is not None:
                start = int(k.start)
            else:
                start = None
            if k.stop is not None:
                stop = int(k.stop)
            else:
                stop = None
            qs.sql_query.set_limits(start, stop)
            return list(qs)[:: k.step] if k.step else qs

        qs = self._chain()
        qs.sql_query.set_limits(k, k + 1)
        qs._fetch_all()
        assert qs._result_cache is not None  # _fetch_all guarantees this
        return qs._result_cache[0]

    def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
        return cls

    def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._merge_sanity_check(other)
        if isinstance(other, EmptyQuerySet):
            return other
        if isinstance(self, EmptyQuerySet):
            return self
        combined = self._chain()
        combined._merge_known_related_objects(other)
        combined.sql_query.combine(other.sql_query, sql.AND)
        return combined

    def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.OR)
        return combined

    def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.XOR)
        return combined

    ####################################
    # METHODS THAT DO DATABASE QUERIES #
    ####################################

    def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
        iterable = self._iterable_class(
            self,
            chunked_fetch=use_chunked_fetch,
            chunk_size=chunk_size or 2000,
        )
        if not self._prefetch_related_lookups or chunk_size is None:
            yield from iterable
            return

        iterator = iter(iterable)
        while results := list(islice(iterator, chunk_size)):
            prefetch_related_objects(results, *self._prefetch_related_lookups)
            yield from results

    def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
        """
        An iterator over the results from applying this QuerySet to the
        database. chunk_size must be provided for QuerySets that prefetch
        related objects. Otherwise, a default chunk_size of 2000 is supplied.
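
        Illustrative usage (hypothetical ``Task`` model; ``process`` is a
        placeholder), streaming rows without building the full result cache:

            for task in Task.query.filter(is_active=True).iterator(chunk_size=500):
                process(task)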
542 """
543 if chunk_size is None:
544 if self._prefetch_related_lookups:
545 raise ValueError(
546 "chunk_size must be provided when using QuerySet.iterator() after "
547 "prefetch_related()."
548 )
549 elif chunk_size <= 0:
550 raise ValueError("Chunk size must be strictly positive.")
551 use_chunked_fetch = not db_connection.settings_dict.get(
552 "DISABLE_SERVER_SIDE_CURSORS"
553 )
554 return self._iterator(use_chunked_fetch, chunk_size)
555
556 def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
557 """
558 Return a dictionary containing the calculations (aggregation)
559 over the current queryset.
560
561 If args is present the expression is passed as a kwarg using
562 the Aggregate object's default alias.
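
        Illustrative usage (hypothetical ``Task`` model; assumes an available
        ``Count`` aggregate expression):

            Task.query.aggregate(total=Count("id"))
            # -> {"total": 42}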
563 """
564 if self.sql_query.distinct_fields:
565 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
566 self._validate_values_are_expressions(
567 (*args, *kwargs.values()), method_name="aggregate"
568 )
569 for arg in args:
570 # The default_alias property raises TypeError if default_alias
571 # can't be set automatically or AttributeError if it isn't an
572 # attribute.
573 try:
574 arg.default_alias
575 except (AttributeError, TypeError):
576 raise TypeError("Complex aggregates require an alias")
577 kwargs[arg.default_alias] = arg
578
579 return self.sql_query.chain().get_aggregation(kwargs)
580
581 def count(self) -> int:
582 """
583 Perform a SELECT COUNT() and return the number of records as an
584 integer.
585
586 If the QuerySet is already fully cached, return the length of the
587 cached results set to avoid multiple SELECT COUNT(*) calls.
588 """
589 if self._result_cache is not None:
590 return len(self._result_cache)
591
592 return self.sql_query.get_count()
593
594 def get(self, *args: Any, **kwargs: Any) -> T:
595 """
596 Perform the query and return a single object matching the given
597 keyword arguments.
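
        Illustrative usage (hypothetical ``Task`` model):

            task = Task.query.get(id=1)
            # Raises Task.DoesNotExist if no row matches, and
            # Task.MultipleObjectsReturned if more than one does.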
598 """
599 clone = self.filter(*args, **kwargs)
600 if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
601 clone = clone.order_by()
602 limit = None
603 if (
604 not clone.sql_query.select_for_update
605 or db_connection.features.supports_select_for_update_with_limit
606 ):
607 limit = MAX_GET_RESULTS
608 clone.sql_query.set_limits(high=limit)
609 num = len(clone)
610 if num == 1:
611 assert clone._result_cache is not None # len() fetches results
612 return clone._result_cache[0]
613 if not num:
614 raise self.model.DoesNotExist(
615 f"{self.model.model_options.object_name} matching query does not exist."
616 )
617 raise self.model.MultipleObjectsReturned(
618 "get() returned more than one {} -- it returned {}!".format(
619 self.model.model_options.object_name,
620 num if not limit or num < limit else "more than %s" % (limit - 1),
621 )
622 )

    def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
        """
        Perform the query and return a single object matching the given
        keyword arguments, or None if no object is found.
        """
        try:
            return self.get(*args, **kwargs)
        except self.model.DoesNotExist:
            return None

    def create(self, **kwargs: Any) -> T:
        """
        Create a new object with the given kwargs, saving it to the database
        and returning the created object.
        """
        obj = self.model(**kwargs)
        self._for_write = True
        obj.save(force_insert=True)
        return obj

    def _prepare_for_bulk_create(self, objs: list[T]) -> None:
        id_field = self.model._model_meta.get_forward_field("id")
        for obj in objs:
            if obj.id is None:
                # Populate new primary key values.
                obj.id = id_field.get_id_value_on_save(obj)
            obj._prepare_related_fields_for_save(operation_name="bulk_create")

    def _check_bulk_create_options(
        self,
        update_conflicts: bool,
        update_fields: list[Field] | None,
        unique_fields: list[Field] | None,
    ) -> OnConflict | None:
        db_features = db_connection.features
        if update_conflicts:
            if not db_features.supports_update_conflicts:
                raise NotSupportedError(
                    "This database backend does not support updating conflicts."
                )
            if not update_fields:
                raise ValueError(
                    "Fields that will be updated when a row insertion fails "
                    "on conflicts must be provided."
                )
            if unique_fields and not db_features.supports_update_conflicts_with_target:
                raise NotSupportedError(
                    "This database backend does not support updating "
                    "conflicts with specifying unique fields that can trigger "
                    "the upsert."
                )
            if not unique_fields and db_features.supports_update_conflicts_with_target:
                raise ValueError(
                    "Unique fields that can trigger the upsert must be provided."
                )
            # Updating primary keys and non-concrete fields is forbidden.
            from plain.models.fields.related import ManyToManyField

            if any(
                not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
            ):
                raise ValueError(
                    "bulk_create() can only be used with concrete fields in "
                    "update_fields."
                )
            if any(f.primary_key for f in update_fields):
                raise ValueError(
                    "bulk_create() cannot be used with primary keys in update_fields."
                )
            if unique_fields:
                if any(
                    not f.concrete or isinstance(f, ManyToManyField)
                    for f in unique_fields
                ):
                    raise ValueError(
                        "bulk_create() can only be used with concrete fields "
                        "in unique_fields."
                    )
            return OnConflict.UPDATE
        return None

    def bulk_create(
        self,
        objs: Sequence[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances, and do not set the primary key
        attribute if it is an autoincrement field (except if
        features.can_return_rows_from_bulk_insert=True). Multi-table models
        are not supported.
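
        Illustrative usage (hypothetical ``Task`` model with a unique
        ``name`` field), upserting on conflict where the backend supports it:

            Task.query.bulk_create(
                [Task(name="a"), Task(name="b")],
                update_conflicts=True,
                unique_fields=["name"],
                update_fields=["is_active"],
            )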
720 """
721 # When you bulk insert you don't get the primary keys back (if it's an
722 # autoincrement, except if can_return_rows_from_bulk_insert=True), so
723 # you can't insert into the child tables which references this. There
724 # are two workarounds:
725 # 1) This could be implemented if you didn't have an autoincrement pk
726 # 2) You could do it by doing O(n) normal inserts into the parent
727 # tables to get the primary keys back and then doing a single bulk
728 # insert into the childmost table.
729 # We currently set the primary keys on the objects when using
730 # PostgreSQL via the RETURNING ID clause. It should be possible for
731 # Oracle as well, but the semantics for extracting the primary keys is
732 # trickier so it's not done yet.
733 if batch_size is not None and batch_size <= 0:
734 raise ValueError("Batch size must be a positive integer.")
735
736 objs = list(objs)
737 if not objs:
738 return objs
739 meta = self.model._model_meta
740 unique_fields_objs: list[Field] | None = None
741 update_fields_objs: list[Field] | None = None
742 if unique_fields:
743 unique_fields_objs = [
744 meta.get_forward_field(name) for name in unique_fields
745 ]
746 if update_fields:
747 update_fields_objs = [
748 meta.get_forward_field(name) for name in update_fields
749 ]
750 on_conflict = self._check_bulk_create_options(
751 update_conflicts,
752 update_fields_objs,
753 unique_fields_objs,
754 )
755 self._for_write = True
756 fields = meta.concrete_fields
757 self._prepare_for_bulk_create(objs)
758 with transaction.atomic(savepoint=False):
759 objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
760 if objs_with_id:
761 returned_columns = self._batched_insert(
762 objs_with_id,
763 fields,
764 batch_size,
765 on_conflict=on_conflict,
766 update_fields=update_fields_objs,
767 unique_fields=unique_fields_objs,
768 )
769 id_field = meta.get_forward_field("id")
770 for obj_with_id, results in zip(objs_with_id, returned_columns):
771 for result, field in zip(results, meta.db_returning_fields):
772 if field != id_field:
773 setattr(obj_with_id, field.attname, result)
774 for obj_with_id in objs_with_id:
775 obj_with_id._state.adding = False
776 if objs_without_id:
777 fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
778 returned_columns = self._batched_insert(
779 objs_without_id,
780 fields,
781 batch_size,
782 on_conflict=on_conflict,
783 update_fields=update_fields_objs,
784 unique_fields=unique_fields_objs,
785 )
786 if (
787 db_connection.features.can_return_rows_from_bulk_insert
788 and on_conflict is None
789 ):
790 assert len(returned_columns) == len(objs_without_id)
791 for obj_without_id, results in zip(objs_without_id, returned_columns):
792 for result, field in zip(results, meta.db_returning_fields):
793 setattr(obj_without_id, field.attname, result)
794 obj_without_id._state.adding = False
795
796 return objs
797
798 def bulk_update(
799 self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
800 ) -> int:
801 """
802 Update the given fields in each of the given objects in the database.
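
        Illustrative usage (hypothetical ``Task`` model with a ``priority``
        field); returns the number of rows updated:

            tasks = list(Task.query.filter(is_active=True))
            for task in tasks:
                task.priority += 1
            Task.query.bulk_update(tasks, ["priority"])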
803 """
804 if batch_size is not None and batch_size <= 0:
805 raise ValueError("Batch size must be a positive integer.")
806 if not fields:
807 raise ValueError("Field names must be given to bulk_update().")
808 objs_tuple = tuple(objs)
809 if any(obj.id is None for obj in objs_tuple):
810 raise ValueError("All bulk_update() objects must have a primary key set.")
811 fields_list = [
812 self.model._model_meta.get_forward_field(name) for name in fields
813 ]
814 from plain.models.fields.related import ManyToManyField
815
816 if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
817 raise ValueError("bulk_update() can only be used with concrete fields.")
818 if any(f.primary_key for f in fields_list):
819 raise ValueError("bulk_update() cannot be used with primary key fields.")
820 if not objs_tuple:
821 return 0
822 for obj in objs_tuple:
823 obj._prepare_related_fields_for_save(
824 operation_name="bulk_update", fields=fields_list
825 )
826 # PK is used twice in the resulting update query, once in the filter
827 # and once in the WHEN. Each field will also have one CAST.
828 self._for_write = True
829 max_batch_size = db_connection.ops.bulk_batch_size(
830 ["id", "id"] + fields_list, objs_tuple
831 )
832 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
833 requires_casting = db_connection.features.requires_casted_case_in_updates
834 batches = (
835 objs_tuple[i : i + batch_size]
836 for i in range(0, len(objs_tuple), batch_size)
837 )
838 updates = []
839 for batch_objs in batches:
840 update_kwargs = {}
841 for field in fields_list:
842 when_statements = []
843 for obj in batch_objs:
844 attr = getattr(obj, field.attname)
845 if not isinstance(attr, ResolvableExpression):
846 attr = Value(attr, output_field=field)
847 when_statements.append(When(id=obj.id, then=attr))
848 case_statement = Case(*when_statements, output_field=field)
849 if requires_casting:
850 case_statement = Cast(case_statement, output_field=field)
851 update_kwargs[field.attname] = case_statement
852 updates.append(([obj.id for obj in batch_objs], update_kwargs))
853 rows_updated = 0
854 queryset = self._chain()
855 with transaction.atomic(savepoint=False):
856 for ids, update_kwargs in updates:
857 rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
858 return rows_updated
859
860 def get_or_create(
861 self, defaults: dict[str, Any] | None = None, **kwargs: Any
862 ) -> tuple[T, bool]:
863 """
864 Look up an object with the given kwargs, creating one if necessary.
865 Return a tuple of (object, created), where created is a boolean
866 specifying whether an object was created.
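
        Illustrative usage (hypothetical ``Task`` model with a unique
        ``name`` field):

            task, created = Task.query.get_or_create(
                name="cleanup",
                defaults={"is_active": True},
            )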
867 """
868 # The get() needs to be targeted at the write database in order
869 # to avoid potential transaction consistency problems.
870 self._for_write = True
871 try:
872 return self.get(**kwargs), False
873 except self.model.DoesNotExist:
874 params = self._extract_model_params(defaults, **kwargs)
875 # Try to create an object using passed params.
876 try:
877 with transaction.atomic():
878 params = dict(resolve_callables(params))
879 return self.create(**params), True
880 except (IntegrityError, ValidationError):
881 # Since create() also validates by default,
882 # we can get any kind of ValidationError here,
883 # or it can flow through and get an IntegrityError from the database.
884 # The main thing we're concerned about is uniqueness failures,
885 # but ValidationError could include other things too.
886 # In all cases though it should be fine to try the get() again
887 # and return an existing object.
888 try:
889 return self.get(**kwargs), False
890 except self.model.DoesNotExist:
891 pass
892 raise
893
894 def update_or_create(
895 self,
896 defaults: dict[str, Any] | None = None,
897 create_defaults: dict[str, Any] | None = None,
898 **kwargs: Any,
899 ) -> tuple[T, bool]:
900 """
901 Look up an object with the given kwargs, updating one with defaults
902 if it exists, otherwise create a new one. Optionally, an object can
903 be created with different values than defaults by using
904 create_defaults.
905 Return a tuple (object, created), where created is a boolean
906 specifying whether an object was created.
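
        Illustrative usage (hypothetical ``Task`` model with a unique
        ``name`` field):

            task, created = Task.query.update_or_create(
                name="cleanup",
                defaults={"is_active": False},
            )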
907 """
908 if create_defaults is None:
909 update_defaults = create_defaults = defaults or {}
910 else:
911 update_defaults = defaults or {}
912 self._for_write = True
913 with transaction.atomic():
914 # Lock the row so that a concurrent update is blocked until
915 # update_or_create() has performed its save.
916 obj, created = self.select_for_update().get_or_create(
917 create_defaults, **kwargs
918 )
919 if created:
920 return obj, created
921 for k, v in resolve_callables(update_defaults):
922 setattr(obj, k, v)
923
924 update_fields = set(update_defaults)
925 concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
926 # update_fields does not support non-concrete fields.
927 if concrete_field_names.issuperset(update_fields):
928 # Add fields which are set on pre_save(), e.g. auto_now fields.
929 # This is to maintain backward compatibility as these fields
930 # are not updated unless explicitly specified in the
931 # update_fields list.
932 for field in self.model._model_meta.local_concrete_fields:
933 if not (
934 field.primary_key or field.__class__.pre_save is Field.pre_save
935 ):
936 update_fields.add(field.name)
937 if field.name != field.attname:
938 update_fields.add(field.attname)
939 obj.save(update_fields=update_fields)
940 else:
941 obj.save()
942 return obj, False
943
944 def _extract_model_params(
945 self, defaults: dict[str, Any] | None, **kwargs: Any
946 ) -> dict[str, Any]:
947 """
948 Prepare `params` for creating a model instance based on the given
949 kwargs; for use by get_or_create().
950 """
951 defaults = defaults or {}
952 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
953 params.update(defaults)
954 property_names = self.model._model_meta._property_names
955 invalid_params = []
956 for param in params:
957 try:
958 self.model._model_meta.get_field(param)
959 except FieldDoesNotExist:
960 # It's okay to use a model's property if it has a setter.
961 if not (param in property_names and getattr(self.model, param).fset):
962 invalid_params.append(param)
963 if invalid_params:
964 raise FieldError(
965 "Invalid field name(s) for model {}: '{}'.".format(
966 self.model.model_options.object_name,
967 "', '".join(sorted(invalid_params)),
968 )
969 )
970 return params
971
972 def first(self) -> T | None:
973 """Return the first object of a query or None if no match is found."""
974 for obj in self[:1]:
975 return obj
976 return None
977
978 def last(self) -> T | None:
979 """Return the last object of a query or None if no match is found."""
980 queryset = self.reverse()
981 for obj in queryset[:1]:
982 return obj
983 return None
984
985 def delete(self) -> tuple[int, dict[str, int]]:
986 """Delete the records in the current QuerySet."""
987 if self.sql_query.is_sliced:
988 raise TypeError("Cannot use 'limit' or 'offset' with delete().")
989 if self.sql_query.distinct or self.sql_query.distinct_fields:
990 raise TypeError("Cannot call delete() after .distinct().")
991 if self._fields is not None:
992 raise TypeError("Cannot call delete() after .values() or .values_list()")
993
994 del_query = self._chain()
995
996 # The delete is actually 2 queries - one to find related objects,
997 # and one to delete. Make sure that the discovery of related
998 # objects is performed on the same database as the deletion.
999 del_query._for_write = True
1000
1001 # Disable non-supported fields.
1002 del_query.sql_query.select_for_update = False
1003 del_query.sql_query.select_related = False
1004 del_query.sql_query.clear_ordering(force=True)
1005
1006 from plain.models.deletion import Collector
1007
1008 collector = Collector(origin=self)
1009 collector.collect(del_query)
1010 deleted, _rows_count = collector.delete()
1011
1012 # Clear the result cache, in case this QuerySet gets reused.
1013 self._result_cache = None
1014 return deleted, _rows_count
1015
1016 def _raw_delete(self) -> int:
1017 """
1018 Delete objects found from the given queryset in single direct SQL
1019 query. No signals are sent and there is no protection for cascades.
1020 """
        query = self.sql_query.clone()
        query.__class__ = sql.DeleteQuery
        cursor = query.get_compiler().execute_sql(CURSOR)
        if cursor:
            with cursor:
                return cursor.rowcount
        return 0

    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.
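
        Illustrative usage (hypothetical ``Task`` model); returns the number
        of rows matched:

            Task.query.filter(is_active=True).update(priority=F("priority") + 1)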
1033 """
1034 if self.sql_query.is_sliced:
1035 raise TypeError("Cannot update a query once a slice has been taken.")
1036 self._for_write = True
1037 query = self.sql_query.chain(sql.UpdateQuery)
1038 query.add_update_values(kwargs)
1039
1040 # Inline annotations in order_by(), if possible.
1041 new_order_by = []
1042 for col in query.order_by:
1043 alias = col
1044 descending = False
1045 if isinstance(alias, str) and alias.startswith("-"):
1046 alias = alias.removeprefix("-")
1047 descending = True
1048 if annotation := query.annotations.get(alias):
1049 if getattr(annotation, "contains_aggregate", False):
1050 raise FieldError(
1051 f"Cannot update when ordering by an aggregate: {annotation}"
1052 )
1053 if descending:
1054 annotation = annotation.desc()
1055 new_order_by.append(annotation)
1056 else:
1057 new_order_by.append(col)
1058 query.order_by = tuple(new_order_by)
1059
1060 # Clear any annotations so that they won't be present in subqueries.
1061 query.annotations = {}
1062 with transaction.mark_for_rollback_on_error():
1063 rows = query.get_compiler().execute_sql(CURSOR)
1064 self._result_cache = None
1065 return rows
1066
1067 def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
1068 """
1069 A version of update() that accepts field objects instead of field names.
1070 Used primarily for model saving and not intended for use by general
1071 code (it requires too much poking around at model internals to be
1072 useful at that level).
1073 """
1074 if self.sql_query.is_sliced:
1075 raise TypeError("Cannot update a query once a slice has been taken.")
1076 query = self.sql_query.chain(sql.UpdateQuery)
1077 query.add_update_fields(values)
1078 # Clear any annotations so that they won't be present in subqueries.
1079 query.annotations = {}
1080 self._result_cache = None
1081 return query.get_compiler().execute_sql(CURSOR)
1082
1083 def exists(self) -> bool:
1084 """
1085 Return True if the QuerySet would have any results, False otherwise.
1086 """
1087 if self._result_cache is None:
1088 return self.sql_query.has_results()
1089 return bool(self._result_cache)
1090
1091 def _prefetch_related_objects(self) -> None:
1092 # This method can only be called once the result cache has been filled.
1093 assert self._result_cache is not None
1094 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1095 self._prefetch_done = True
1096
1097 def explain(self, *, format: str | None = None, **options: Any) -> str:
1098 """
1099 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1100 returns the results.
1101 """
1102 return self.sql_query.explain(format=format, **options)
1103
1104 ##################################################
1105 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1106 ##################################################
1107
1108 def raw(
1109 self,
1110 raw_query: str,
1111 params: Sequence[Any] = (),
1112 translations: dict[str, str] | None = None,
1113 ) -> RawQuerySet:
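        """
        Return a RawQuerySet for the given raw SQL against this model.

        Illustrative usage (hypothetical ``Task`` model and table name; the
        parameter placeholder style depends on the database backend):

            Task.query.raw("SELECT * FROM tasks WHERE priority > %s", [2])
        """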
        qs = RawQuerySet(
            raw_query,
            model=self.model,
            params=tuple(params),
            translations=translations,
        )
        qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return qs

    def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
        clone = self._chain()
        if expressions:
            clone = clone.annotate(**expressions)
        clone._fields = fields
        clone.sql_query.set_values(list(fields))
        return clone

    def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
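        """
        Return a QuerySet that yields dicts (one per row) instead of model
        instances.

        Illustrative usage (hypothetical ``Task`` model):

            Task.query.values("id", "name")
            # -> [{"id": 1, "name": "..."}, ...]
        """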
        fields += tuple(expressions)
        clone = self._values(*fields, **expressions)
        clone._iterable_class = ValuesIterable
        return clone

    def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
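        """
        Return a QuerySet that yields tuples (or single values when
        flat=True) instead of model instances.

        Illustrative usage (hypothetical ``Task`` model):

            Task.query.values_list("id", "name")     # -> [(1, "..."), ...]
            Task.query.values_list("id", flat=True)  # -> [1, 2, ...]
        """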
        if flat and len(fields) > 1:
            raise TypeError(
                "'flat' is not valid when values_list is called with more than one "
                "field."
            )

        field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
        _fields = []
        expressions = {}
        counter = 1
        for field in fields:
            if isinstance(field, ResolvableExpression):
                field_id_prefix = getattr(
                    field, "default_alias", field.__class__.__name__.lower()
                )
                while True:
                    field_id = field_id_prefix + str(counter)
                    counter += 1
                    if field_id not in field_names:
                        break
                expressions[field_id] = field
                _fields.append(field_id)
            else:
                _fields.append(field)

        clone = self._values(*_fields, **expressions)
        clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
        return clone

    def none(self) -> QuerySet[T]:
        """Return an empty QuerySet."""
        clone = self._chain()
        clone.sql_query.set_empty()
        return clone

    ##################################################################
    # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
    ##################################################################

    def all(self) -> Self:
        """
        Return a new QuerySet that is a copy of the current one. This allows a
        QuerySet to proxy for a model queryset in some cases.
        """
        obj = self._chain()
        # Preserve cache since all() doesn't modify the query.
        # This is important for prefetch_related() to work correctly.
        obj._result_cache = self._result_cache
        obj._prefetch_done = self._prefetch_done
        return obj

    def filter(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with the args ANDed to the existing
        set.
        """
        return self._filter_or_exclude(False, args, kwargs)

    def exclude(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with NOT (args) ANDed to the existing
        set.
        """
        return self._filter_or_exclude(True, args, kwargs)

    def _filter_or_exclude(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Self:
        if (args or kwargs) and self.sql_query.is_sliced:
            raise TypeError("Cannot filter a query once a slice has been taken.")
        clone = self._chain()
        if self._defer_next_filter:
            self._defer_next_filter = False
            clone._deferred_filter = negate, args, kwargs
        else:
            clone._filter_or_exclude_inplace(negate, args, kwargs)
        return clone

    def _filter_or_exclude_inplace(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        if negate:
            self._query.add_q(~Q(*args, **kwargs))
        else:
            self._query.add_q(Q(*args, **kwargs))

    def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
        """
        Return a new QuerySet instance with filter_obj added to the filters.

        filter_obj can be a Q object or a dictionary of keyword lookup
        arguments.

        This exists to support framework features such as 'limit_choices_to',
        and usually it will be more natural to use other methods.
        """
        if isinstance(filter_obj, Q):
            clone = self._chain()
            clone.sql_query.add_q(filter_obj)
            return clone
        else:
            return self._filter_or_exclude(False, args=(), kwargs=filter_obj)

    def select_for_update(
        self,
        nowait: bool = False,
        skip_locked: bool = False,
        of: tuple[str, ...] = (),
        no_key: bool = False,
    ) -> QuerySet[T]:
        """
        Return a new QuerySet instance that will select objects with a
        FOR UPDATE lock.
        """
        if nowait and skip_locked:
            raise ValueError("The nowait option cannot be used with skip_locked.")
        obj = self._chain()
        obj._for_write = True
        obj.sql_query.select_for_update = True
        obj.sql_query.select_for_update_nowait = nowait
        obj.sql_query.select_for_update_skip_locked = skip_locked
        obj.sql_query.select_for_update_of = of
        obj.sql_query.select_for_no_key_update = no_key
        return obj

    def select_related(self, *fields: str | None) -> Self:
        """
        Return a new QuerySet instance that will select related objects.

        If fields are specified, they must be ForeignKeyField fields and only
        those related objects are included in the selection.

        If select_related(None) is called, clear the list.
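
        Illustrative usage (hypothetical ``Task`` model with a ``project``
        foreign key); the related row is fetched in the same query:

            task = Task.query.select_related("project").get(id=1)
            task.project  # no additional query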
1271 """
1272 if self._fields is not None:
1273 raise TypeError(
1274 "Cannot call select_related() after .values() or .values_list()"
1275 )
1276
1277 obj = self._chain()
1278 if fields == (None,):
1279 obj.sql_query.select_related = False
1280 elif fields:
1281 obj.sql_query.add_select_related(list(fields)) # type: ignore[arg-type]
1282 else:
1283 obj.sql_query.select_related = True
1284 return obj
1285
1286 def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
1287 """
1288 Return a new QuerySet instance that will prefetch the specified
1289 Many-To-One and Many-To-Many related objects when the QuerySet is
1290 evaluated.
1291
1292 When prefetch_related() is called more than once, append to the list of
1293 prefetch lookups. If prefetch_related(None) is called, clear the list.
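
        Illustrative usage (hypothetical ``Project`` model with a reverse
        ``tasks`` relation); the related objects are loaded in a separate
        query when this QuerySet is evaluated:

            for project in Project.query.prefetch_related("tasks"):
                project.tasks.all()  # served from the prefetch cache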
1294 """
1295 clone = self._chain()
1296 if lookups == (None,):
1297 clone._prefetch_related_lookups = ()
1298 else:
1299 for lookup in lookups:
1300 lookup_str: str
1301 if isinstance(lookup, Prefetch):
1302 lookup_str = lookup.prefetch_to
1303 else:
1304 assert isinstance(lookup, str)
1305 lookup_str = lookup
1306 lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
1307 if lookup_str in self.sql_query._filtered_relations:
1308 raise ValueError(
1309 "prefetch_related() is not supported with FilteredRelation."
1310 )
1311 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1312 return clone
1313
1314 def annotate(self, *args: Any, **kwargs: Any) -> Self:
1315 """
1316 Return a query set in which the returned objects have been annotated
1317 with extra data or aggregations.
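
        Illustrative usage (hypothetical ``Project`` model with a reverse
        ``tasks`` relation; assumes an available ``Count`` aggregate):

            projects = Project.query.annotate(task_count=Count("tasks"))
            projects[0].task_count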
1318 """
1319 return self._annotate(args, kwargs, select=True)
1320
1321 def alias(self, *args: Any, **kwargs: Any) -> Self:
1322 """
1323 Return a query set with added aliases for extra data or aggregations.
1324 """
1325 return self._annotate(args, kwargs, select=False)
1326
1327 def _annotate(
1328 self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
1329 ) -> Self:
1330 self._validate_values_are_expressions(
1331 args + tuple(kwargs.values()), method_name="annotate"
1332 )
1333 annotations = {}
1334 for arg in args:
1335 # The default_alias property may raise a TypeError.
1336 try:
1337 if arg.default_alias in kwargs:
1338 raise ValueError(
1339 f"The named annotation '{arg.default_alias}' conflicts with the "
1340 "default name for another annotation."
1341 )
1342 except TypeError:
1343 raise TypeError("Complex annotations require an alias")
1344 annotations[arg.default_alias] = arg
1345 annotations.update(kwargs)
1346
1347 clone = self._chain()
1348 names = self._fields
1349 if names is None:
1350 names = set(
1351 chain.from_iterable(
1352 (field.name, field.attname)
1353 if hasattr(field, "attname")
1354 else (field.name,)
1355 for field in self.model._model_meta.get_fields()
1356 )
1357 )
1358
1359 for alias, annotation in annotations.items():
1360 if alias in names:
1361 raise ValueError(
1362 f"The annotation '{alias}' conflicts with a field on the model."
1363 )
1364 if isinstance(annotation, FilteredRelation):
1365 clone.sql_query.add_filtered_relation(annotation, alias)
1366 else:
1367 clone.sql_query.add_annotation(
1368 annotation,
1369 alias,
1370 select=select,
1371 )
1372 for alias, annotation in clone.sql_query.annotations.items():
1373 if alias in annotations and annotation.contains_aggregate:
1374 if clone._fields is None:
1375 clone.sql_query.group_by = True
1376 else:
1377 clone.sql_query.set_group_by()
1378 break
1379
1380 return clone
1381
1382 def order_by(self, *field_names: str) -> Self:
1383 """Return a new QuerySet instance with the ordering changed."""
1384 if self.sql_query.is_sliced:
1385 raise TypeError("Cannot reorder a query once a slice has been taken.")
1386 obj = self._chain()
1387 obj.sql_query.clear_ordering(force=True, clear_default=False)
1388 obj.sql_query.add_ordering(*field_names)
1389 return obj
1390
1391 def distinct(self, *field_names: str) -> Self:
1392 """
1393 Return a new QuerySet instance that will select only distinct results.
1394 """
1395 if self.sql_query.is_sliced:
1396 raise TypeError(
1397 "Cannot create distinct fields once a slice has been taken."
1398 )
1399 obj = self._chain()
1400 obj.sql_query.add_distinct_fields(*field_names)
1401 return obj
1402
1403 def extra(
1404 self,
1405 select: dict[str, str] | None = None,
1406 where: list[str] | None = None,
1407 params: list[Any] | None = None,
1408 tables: list[str] | None = None,
1409 order_by: list[str] | None = None,
1410 select_params: list[Any] | None = None,
1411 ) -> QuerySet[T]:
1412 """Add extra SQL fragments to the query."""
1413 if self.sql_query.is_sliced:
1414 raise TypeError("Cannot change a query once a slice has been taken.")
1415 clone = self._chain()
1416 clone.sql_query.add_extra(
1417 select or {},
1418 select_params,
1419 where or [],
1420 params or [],
1421 tables or [],
1422 tuple(order_by) if order_by else (),
1423 )
1424 return clone
1425
1426 def reverse(self) -> QuerySet[T]:
1427 """Reverse the ordering of the QuerySet."""
1428 if self.sql_query.is_sliced:
1429 raise TypeError("Cannot reverse a query once a slice has been taken.")
1430 clone = self._chain()
1431 clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
1432 return clone
1433
1434 def defer(self, *fields: str | None) -> QuerySet[T]:
1435 """
1436 Defer the loading of data for certain fields until they are accessed.
1437 Add the set of deferred fields to any existing set of deferred fields.
1438 The only exception to this is if None is passed in as the only
        parameter, in which case all deferrals are removed.
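
        Illustrative usage (hypothetical ``Task`` model with a large
        ``description`` column):

            Task.query.defer("description")  # column loaded lazily on access
            Task.query.defer(None)           # remove all deferrals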
1440 """
1441 if self._fields is not None:
1442 raise TypeError("Cannot call defer() after .values() or .values_list()")
1443 clone = self._chain()
1444 if fields == (None,):
1445 clone.sql_query.clear_deferred_loading()
1446 else:
1447 clone.sql_query.add_deferred_loading(frozenset(fields)) # type: ignore[arg-type]
1448 return clone
1449
1450 def only(self, *fields: str) -> QuerySet[T]:
1451 """
1452 Essentially, the opposite of defer(). Only the fields passed into this
1453 method and that are not already specified as deferred are loaded
1454 immediately when the queryset is evaluated.
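
        Illustrative usage (hypothetical ``Task`` model):

            Task.query.only("id", "name")  # other fields load on first access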
1455 """
1456 if self._fields is not None:
1457 raise TypeError("Cannot call only() after .values() or .values_list()")
1458 if fields == (None,):
1459 # Can only pass None to defer(), not only(), as the rest option.
1460 # That won't stop people trying to do this, so let's be explicit.
1461 raise TypeError("Cannot pass None as an argument to only().")
1462 for field in fields:
1463 field = field.split(LOOKUP_SEP, 1)[0]
1464 if field in self.sql_query._filtered_relations:
1465 raise ValueError("only() is not supported with FilteredRelation.")
1466 clone = self._chain()
1467 clone.sql_query.add_immediate_loading(set(fields))
1468 return clone
1469
1470 ###################################
1471 # PUBLIC INTROSPECTION ATTRIBUTES #
1472 ###################################
1473
1474 @property
1475 def ordered(self) -> bool:
1476 """
1477 Return True if the QuerySet is ordered -- i.e. has an order_by()
1478 clause or a default ordering on the model (or is empty).
1479 """
1480 if isinstance(self, EmptyQuerySet):
1481 return True
1482 if self.sql_query.extra_order_by or self.sql_query.order_by:
1483 return True
1484 elif (
1485 self.sql_query.default_ordering
1486 and self.sql_query.model
1487 and self.sql_query.model._model_meta.ordering # type: ignore[possibly-missing-attribute]
1488 and
1489 # A default ordering doesn't affect GROUP BY queries.
1490 not self.sql_query.group_by
1491 ):
1492 return True
1493 else:
1494 return False
1495
1496 ###################
1497 # PRIVATE METHODS #
1498 ###################
1499
1500 def _insert(
1501 self,
1502 objs: list[T],
1503 fields: list[Field],
1504 returning_fields: list[Field] | None = None,
1505 raw: bool = False,
1506 on_conflict: OnConflict | None = None,
1507 update_fields: list[Field] | None = None,
1508 unique_fields: list[Field] | None = None,
1509 ) -> list[tuple[Any, ...]] | None:
1510 """
1511 Insert a new record for the given model. This provides an interface to
1512 the InsertQuery class and is how Model.save() is implemented.
1513 """
1514 self._for_write = True
1515 query = sql.InsertQuery(
1516 self.model,
1517 on_conflict=on_conflict.value if on_conflict else None,
1518 update_fields=update_fields,
1519 unique_fields=unique_fields,
1520 )
1521 query.insert_values(fields, objs, raw=raw)
1522 # InsertQuery returns SQLInsertCompiler which has different execute_sql signature
1523 return query.get_compiler().execute_sql(returning_fields) # type: ignore[arg-type]
1524
1525 def _batched_insert(
1526 self,
1527 objs: list[T],
1528 fields: list[Field],
1529 batch_size: int | None,
1530 on_conflict: OnConflict | None = None,
1531 update_fields: list[Field] | None = None,
1532 unique_fields: list[Field] | None = None,
1533 ) -> list[tuple[Any, ...]]:
1534 """
1535 Helper method for bulk_create() to insert objs one batch at a time.
1536 """
1537 ops = db_connection.ops
1538 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1539 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1540 inserted_rows = []
1541 bulk_return = db_connection.features.can_return_rows_from_bulk_insert
1542 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1543 if bulk_return and on_conflict is None:
1544 inserted_rows.extend(
1545 self._insert( # type: ignore[arg-type]
1546 item,
1547 fields=fields,
1548 returning_fields=self.model._model_meta.db_returning_fields,
1549 )
1550 )
1551 else:
1552 self._insert(
1553 item,
1554 fields=fields,
1555 on_conflict=on_conflict,
1556 update_fields=update_fields,
1557 unique_fields=unique_fields,
1558 )
1559 return inserted_rows
1560
1561 def _chain(self) -> Self:
1562 """
1563 Return a copy of the current QuerySet that's ready for another
1564 operation.
1565 """
1566 obj = self._clone()
1567 if obj._sticky_filter:
1568 obj.sql_query.filter_is_sticky = True
1569 obj._sticky_filter = False
1570 return obj
1571
1572 def _clone(self) -> Self:
1573 """
1574 Return a copy of the current QuerySet. A lightweight alternative
1575 to deepcopy().
1576 """
1577 c = self.__class__.from_model(
1578 model=self.model,
1579 query=self.sql_query.chain(),
1580 )
1581 c._sticky_filter = self._sticky_filter
1582 c._for_write = self._for_write
1583 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1584 c._known_related_objects = self._known_related_objects
1585 c._iterable_class = self._iterable_class
1586 c._fields = self._fields
1587 return c
1588
1589 def _fetch_all(self) -> None:
1590 if self._result_cache is None:
1591 self._result_cache = list(self._iterable_class(self))
1592 if self._prefetch_related_lookups and not self._prefetch_done:
1593 self._prefetch_related_objects()
1594
1595 def _next_is_sticky(self) -> QuerySet[T]:
1596 """
1597 Indicate that the next filter call and the one following that should
1598 be treated as a single filter. This is only important when it comes to
1599 determining when to reuse tables for many-to-many filters. Required so
1600 that we can filter naturally on the results of related managers.
1601
1602 This doesn't return a clone of the current QuerySet (it returns
1603 "self"). The method is only used internally and should be immediately
1604 followed by a filter() that does create a clone.
1605 """
1606 self._sticky_filter = True
1607 return self
1608
1609 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1610 """Check that two QuerySet classes may be merged."""
1611 if self._fields is not None and (
1612 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1613 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1614 or set(self.sql_query.annotation_select)
1615 != set(other.sql_query.annotation_select)
1616 ):
1617 raise TypeError(
1618 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1619 )
1620
1621 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1622 """
1623 Keep track of all known related objects from either QuerySet instance.
1624 """
1625 for field, objects in other._known_related_objects.items():
1626 self._known_related_objects.setdefault(field, {}).update(objects)

    def resolve_expression(self, *args: Any, **kwargs: Any) -> sql.Query:
        if self._fields and len(self._fields) > 1:
            # values() querysets can only be used as nested queries
            # if they are set up to select only a single field.
            raise TypeError("Cannot use multi-field values as a filter value.")
        query = self.sql_query.resolve_expression(*args, **kwargs)
        return query

    def _has_filters(self) -> bool:
        """
        Check if this QuerySet has any filtering going on. This isn't
        equivalent to checking if all objects are present in the results; for
        example, qs[1:]._has_filters() -> False.
        """
        return self.sql_query.has_filters()

    @staticmethod
    def _validate_values_are_expressions(
        values: tuple[Any, ...], method_name: str
    ) -> None:
        invalid_args = sorted(
            str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
        )
        if invalid_args:
            raise TypeError(
                "QuerySet.{}() received non-expression(s): {}.".format(
                    method_name,
                    ", ".join(invalid_args),
                )
            )


class InstanceCheckMeta(type):
    def __instancecheck__(self, instance: object) -> bool:
        return isinstance(instance, QuerySet) and instance.sql_query.is_empty()


class EmptyQuerySet(metaclass=InstanceCheckMeta):
1666 """
1667 Marker class to checking if a queryset is empty by .none():
1668 isinstance(qs.none(), EmptyQuerySet) -> True
1669 """
1670
1671 def __init__(self, *args: Any, **kwargs: Any):
1672 raise TypeError("EmptyQuerySet can't be instantiated")
1673
1674
1675class RawQuerySet:
1676 """
1677 Provide an iterator which converts the results of raw SQL queries into
1678 annotated model instances.
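
    Example (illustrative; Author is a hypothetical model):

        qs = RawQuerySet(
            "SELECT * FROM author WHERE id = %s",
            model=Author,
            params=(1,),
        )
        for author in qs:
            ...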
1679 """
1680
1681 def __init__(
1682 self,
1683 raw_query: str,
1684 model: type[Model] | None = None,
1685 query: sql.RawQuery | None = None,
1686 params: tuple[Any, ...] = (),
1687 translations: dict[str, str] | None = None,
1688 ):
1689 self.raw_query = raw_query
1690 self.model = model
1691 self.sql_query = query or sql.RawQuery(sql=raw_query, params=params)
1692 self.params = params
1693 self.translations = translations or {}
1694 self._result_cache: list[Model] | None = None
1695 self._prefetch_related_lookups: tuple[Any, ...] = ()
1696 self._prefetch_done = False
1697
1698 def resolve_model_init_order(
1699 self,
1700 ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
1701 """Resolve the init field names and value positions."""
1702 converter = db_connection.introspection.identifier_converter
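        # Illustrative (hypothetical query): for columns ["id", "name",
        # "upper_name"] where only "id" and "name" map to model fields, this
        # returns (["id", "name"], [0, 1], [("upper_name", 2)]).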
1703 model = self.model
1704 assert model is not None
1705 model_init_fields = [
1706 f for f in model._model_meta.fields if converter(f.column) in self.columns
1707 ]
1708 annotation_fields = [
1709 (column, pos)
1710 for pos, column in enumerate(self.columns)
1711 if column not in self.model_fields
1712 ]
1713 model_init_order = [
1714 self.columns.index(converter(f.column)) for f in model_init_fields
1715 ]
1716 model_init_names = [f.attname for f in model_init_fields]
1717 return model_init_names, model_init_order, annotation_fields
1718
1719 def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
1720 """Same as QuerySet.prefetch_related()"""
1721 clone = self._clone()
1722 if lookups == (None,):
1723 clone._prefetch_related_lookups = ()
1724 else:
1725 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1726 return clone
1727
1728 def _prefetch_related_objects(self) -> None:
1729 assert self._result_cache is not None
1730 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1731 self._prefetch_done = True
1732
1733 def _clone(self) -> RawQuerySet:
1734 """Same as QuerySet._clone()"""
1735 c = self.__class__(
1736 self.raw_query,
1737 model=self.model,
1738 query=self.sql_query,
1739 params=self.params,
1740 translations=self.translations,
1741 )
1742 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1743 return c
1744
1745 def _fetch_all(self) -> None:
1746 if self._result_cache is None:
1747 self._result_cache = list(self.iterator())
1748 if self._prefetch_related_lookups and not self._prefetch_done:
1749 self._prefetch_related_objects()
1750
1751 def __len__(self) -> int:
1752 self._fetch_all()
1753 assert self._result_cache is not None
1754 return len(self._result_cache)
1755
1756 def __bool__(self) -> bool:
1757 self._fetch_all()
1758 return bool(self._result_cache)
1759
1760 def __iter__(self) -> Iterator[Model]:
1761 self._fetch_all()
1762 assert self._result_cache is not None
1763 return iter(self._result_cache)
1764
1765 def iterator(self) -> Iterator[Model]:
1766 yield from RawModelIterable(self) # type: ignore[arg-type]
1767
1768 def __repr__(self) -> str:
1769 return f"<{self.__class__.__name__}: {self.sql_query}>"
1770
1771 def __getitem__(self, k: int | slice) -> Model | list[Model]:
1772 return list(self)[k]
1773
1774 @cached_property
1775 def columns(self) -> list[str]:
1776 """
        A list of column names, adjusted by any translations, in the order
        they'll appear in the query results.
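
        Illustrative: with translations={"name_in_db": "name"}, a query
        column "name_in_db" is reported here as "name".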
1779 """
1780 columns = self.sql_query.get_columns()
1781 # Adjust any column names which don't match field names
1782 for query_name, model_name in self.translations.items():
1783 # Ignore translations for nonexistent column names
1784 try:
1785 index = columns.index(query_name)
1786 except ValueError:
1787 pass
1788 else:
1789 columns[index] = model_name
1790 return columns
1791
1792 @cached_property
    def model_fields(self) -> dict[str, Field]:
        """A dict mapping column names to model fields."""
1795 converter = db_connection.introspection.identifier_converter
1796 model_fields = {}
1797 model = self.model
1798 assert model is not None
1799 for field in model._model_meta.fields:
1800 name, column = field.get_attname_column()
1801 model_fields[converter(column)] = field
1802 return model_fields
1803
1804
1805class Prefetch:
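    """
    Describe a prefetch_related() lookup: the relation path to traverse,
    an optional custom queryset to use for it, and an optional attribute
    name (to_attr) to store the results on.

    Example (illustrative; the querysets and the "books" relation are
    hypothetical):

        authors = author_queryset.prefetch_related(
            Prefetch(
                "books",
                queryset=book_queryset.order_by("title"),
                to_attr="sorted_books",
            )
        )
    """
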
1806 def __init__(
1807 self,
1808 lookup: str,
1809 queryset: QuerySet[Any] | None = None,
1810 to_attr: str | None = None,
1811 ):
1812 # `prefetch_through` is the path we traverse to perform the prefetch.
1813 self.prefetch_through = lookup
1814 # `prefetch_to` is the path to the attribute that stores the result.
1815 self.prefetch_to = lookup
1816 if queryset is not None and (
1817 isinstance(queryset, RawQuerySet)
1818 or (
1819 hasattr(queryset, "_iterable_class")
1820 and not issubclass(queryset._iterable_class, ModelIterable)
1821 )
1822 ):
            raise ValueError(
                "Prefetch querysets cannot use raw(), values(), or values_list()."
            )
1826 if to_attr:
1827 self.prefetch_to = LOOKUP_SEP.join(
1828 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
1829 )
1830
1831 self.queryset = queryset
1832 self.to_attr = to_attr
1833
1834 def __getstate__(self) -> dict[str, Any]:
1835 obj_dict = self.__dict__.copy()
1836 if self.queryset is not None:
1837 queryset = self.queryset._chain()
1838 # Prevent the QuerySet from being evaluated
1839 queryset._result_cache = []
1840 queryset._prefetch_done = True
1841 obj_dict["queryset"] = queryset
1842 return obj_dict
1843
1844 def add_prefix(self, prefix: str) -> None:
1845 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
1846 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
1847
1848 def get_current_prefetch_to(self, level: int) -> str:
1849 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
1850
1851 def get_current_to_attr(self, level: int) -> tuple[str, bool]:
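        # Illustrative: Prefetch("author__books", to_attr="book_list") has
        # prefetch_to "author__book_list"; at level 1 this returns
        # ("book_list", True).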
1852 parts = self.prefetch_to.split(LOOKUP_SEP)
1853 to_attr = parts[level]
1854 as_attr = bool(self.to_attr and level == len(parts) - 1)
1855 return to_attr, as_attr
1856
1857 def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
1858 if self.get_current_prefetch_to(level) == self.prefetch_to:
1859 return self.queryset
1860 return None
1861
1862 def __eq__(self, other: object) -> bool:
1863 if not isinstance(other, Prefetch):
1864 return NotImplemented
1865 return self.prefetch_to == other.prefetch_to
1866
1867 def __hash__(self) -> int:
1868 return hash((self.__class__, self.prefetch_to))
1869
1870
1871def normalize_prefetch_lookups(
1872 lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
1873 prefix: str | None = None,
1874) -> list[Prefetch]:
1875 """Normalize lookups into Prefetch objects."""
1876 ret = []
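    # Illustrative: ("books", Prefetch("publisher")) with prefix "author"
    # becomes [Prefetch("author__books"), Prefetch("author__publisher")].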
1877 for lookup in lookups:
1878 if not isinstance(lookup, Prefetch):
1879 lookup = Prefetch(lookup)
1880 if prefix:
1881 lookup.add_prefix(prefix)
1882 ret.append(lookup)
1883 return ret
1884
1885
1886def prefetch_related_objects(
1887 model_instances: Sequence[Model], *related_lookups: str | Prefetch
1888) -> None:
1889 """
1890 Populate prefetched object caches for a list of model instances based on
1891 the lookups/Prefetch instances given.
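
    Example (illustrative; the instances and lookups are hypothetical):

        authors = list(author_queryset)
        prefetch_related_objects(authors, "books", Prefetch("profile"))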
1892 """
1893 if not model_instances:
1894 return # nothing to do
1895
    # We need to be able to dynamically add to the list of prefetch_related
    # lookups that we process (see below), so we need some bookkeeping to
    # ensure we don't do duplicate work.
1899 done_queries = {} # dictionary of things like 'foo__bar': [results]
1900
1901 auto_lookups = set() # we add to this as we go through.
1902 followed_descriptors = set() # recursion protection
1903
1904 all_lookups = normalize_prefetch_lookups(reversed(related_lookups)) # type: ignore[arg-type]
1905 while all_lookups:
1906 lookup = all_lookups.pop()
1907 if lookup.prefetch_to in done_queries:
1908 if lookup.queryset is not None:
1909 raise ValueError(
1910 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
1911 "You may need to adjust the ordering of your lookups."
1912 )
1913
1914 continue
1915
        # At the top level, the list of objects to decorate is the result
        # cache from the primary QuerySet; it won't be at deeper levels.
1918 obj_list = model_instances
1919
1920 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
1921 for level, through_attr in enumerate(through_attrs):
1922 # Prepare main instances
1923 if not obj_list:
1924 break
1925
1926 prefetch_to = lookup.get_current_prefetch_to(level)
1927 if prefetch_to in done_queries:
1928 # Skip any prefetching, and any object preparation
1929 obj_list = done_queries[prefetch_to]
1930 continue
1931
1932 # Prepare objects:
1933 good_objects = True
1934 for obj in obj_list:
1935 # Since prefetching can re-use instances, it is possible to have
1936 # the same instance multiple times in obj_list, so obj might
1937 # already be prepared.
1938 if not hasattr(obj, "_prefetched_objects_cache"):
1939 try:
1940 obj._prefetched_objects_cache = {}
1941 except (AttributeError, TypeError):
1942 # Must be an immutable object from
1943 # values_list(flat=True), for example (TypeError) or
1944 # a QuerySet subclass that isn't returning Model
1945 # instances (AttributeError), either in Plain or a 3rd
1946 # party. prefetch_related() doesn't make sense, so quit.
1947 good_objects = False
1948 break
1949 if not good_objects:
1950 break
1951
1952 # Descend down tree
1953
            # We assume that the objects retrieved are homogeneous (which is
            # the premise of prefetch_related), so what applies to the first
            # object applies to all.
1956 first_obj = obj_list[0]
1957 to_attr = lookup.get_current_to_attr(level)[0]
1958 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
1959 first_obj, through_attr, to_attr
1960 )
1961
1962 if not attr_found:
1963 raise AttributeError(
                    f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} "
                    f"object; '{lookup.prefetch_through}' is an invalid "
                    "parameter to prefetch_related()"
1966 )
1967
1968 if level == len(through_attrs) - 1 and prefetcher is None:
1969 # Last one, this *must* resolve to something that supports
1970 # prefetching, otherwise there is no point adding it and the
1971 # developer asking for it has made a mistake.
1972 raise ValueError(
                    f"'{lookup.prefetch_through}' does not resolve to an item that "
                    "supports prefetching; this is an invalid parameter to "
                    "prefetch_related()."
1976 )
1977
1978 obj_to_fetch = None
1979 if prefetcher is not None:
1980 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
1981
1982 if obj_to_fetch:
1983 obj_list, additional_lookups = prefetch_one_level(
1984 obj_to_fetch,
1985 prefetcher,
1986 lookup,
1987 level,
1988 )
                # To stop infinite recursion, we need to ensure we don't keep
                # adding lookups from the same relationships: if we are
                # already on an automatically added lookup, don't add new
                # lookups from relationships we've seen already.
1993 if not (
1994 prefetch_to in done_queries
1995 and lookup in auto_lookups
1996 and descriptor in followed_descriptors
1997 ):
1998 done_queries[prefetch_to] = obj_list
1999 new_lookups = normalize_prefetch_lookups(
2000 reversed(additional_lookups), # type: ignore[arg-type]
2001 prefetch_to,
2002 )
2003 auto_lookups.update(new_lookups)
2004 all_lookups.extend(new_lookups)
2005 followed_descriptors.add(descriptor)
2006 else:
2007 # Either a singly related object that has already been fetched
2008 # (e.g. via select_related), or hopefully some other property
2009 # that doesn't support prefetching but needs to be traversed.
2010
2011 # We replace the current list of parent objects with the list
2012 # of related objects, filtering out empty or missing values so
2013 # that we can continue with nullable or reverse relations.
2014 new_obj_list = []
2015 for obj in obj_list:
2016 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
2017 # If related objects have been prefetched, use the
2018 # cache rather than the object's through_attr.
2019 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # type: ignore[arg-type]
2020 else:
2021 try:
2022 new_obj = getattr(obj, through_attr)
2023 except ObjectDoesNotExist:
2024 continue
2025 if new_obj is None:
2026 continue
2027 # We special-case `list` rather than something more generic
2028 # like `Iterable` because we don't want to accidentally match
2029 # user models that define __iter__.
2030 if isinstance(new_obj, list):
2031 new_obj_list.extend(new_obj)
2032 else:
2033 new_obj_list.append(new_obj)
2034 obj_list = new_obj_list
2035
2036
2037def get_prefetcher(
2038 instance: Model, through_attr: str, to_attr: str
2039) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
2040 """
2041 For the attribute 'through_attr' on the given instance, find
2042 an object that has a get_prefetch_queryset().
    Return a 4-tuple containing:
2044 (the object with get_prefetch_queryset (or None),
2045 the descriptor object representing this relationship (or None),
2046 a boolean that is False if the attribute was not found at all,
2047 a function that takes an instance and returns a boolean that is True if
2048 the attribute has already been fetched for that instance)
2049 """
2050
2051 def has_to_attr_attribute(instance: Model) -> bool:
2052 return hasattr(instance, to_attr)
2053
2054 prefetcher = None
2055 is_fetched: Callable[[Model], bool] = has_to_attr_attribute
2056
2057 # For singly related objects, we have to avoid getting the attribute
2058 # from the object, as this will trigger the query. So we first try
2059 # on the class, in order to get the descriptor object.
2060 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2061 if rel_obj_descriptor is None:
2062 attr_found = hasattr(instance, through_attr)
2063 else:
2064 attr_found = True
2065 if rel_obj_descriptor:
            # A singly related object: the descriptor object has the
            # get_prefetch_queryset() method.
2068 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2069 prefetcher = rel_obj_descriptor
2070 is_fetched = rel_obj_descriptor.is_cached
2071 else:
                # The descriptor doesn't support prefetching, so we get the
                # attribute on the instance rather than the class, in order
                # to support managers for many-related objects.
2075 rel_obj = getattr(instance, through_attr)
2076 if hasattr(rel_obj, "get_prefetch_queryset"):
2077 prefetcher = rel_obj
2078 if through_attr != to_attr:
2079 # Special case cached_property instances because hasattr
2080 # triggers attribute computation and assignment.
2081 if isinstance(
2082 getattr(instance.__class__, to_attr, None), cached_property
2083 ):
2084
2085 def has_cached_property(instance: Model) -> bool:
2086 return to_attr in instance.__dict__
2087
2088 is_fetched = has_cached_property
2089 else:
2090
2091 def in_prefetched_cache(instance: Model) -> bool:
2092 return through_attr in instance._prefetched_objects_cache
2093
2094 is_fetched = in_prefetched_cache
2095 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2096
2097
2098def prefetch_one_level(
2099 instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
2100) -> tuple[list[Model], list[Prefetch]]:
2101 """
2102 Helper function for prefetch_related_objects().
2103
2104 Run prefetches on all instances using the prefetcher object,
2105 assigning results to relevant caches in instance.
2106
2107 Return the prefetched objects along with any additional prefetches that
2108 must be done due to prefetch_related lookups found from default managers.
2109 """
2110 # prefetcher must have a method get_prefetch_queryset() which takes a list
2111 # of instances, and returns a tuple:
2112
2113 # (queryset of instances of self.model that are related to passed in instances,
2114 # callable that gets value to be matched for returned instances,
2115 # callable that gets value to be matched for passed in instances,
2116 # boolean that is True for singly related objects,
2117 # cache or field name to assign to,
2118 # boolean that is True when the previous argument is a cache name vs a field name).
2119
2120 # The 'values to be matched' must be hashable as they will be used
2121 # in a dictionary.
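    #
    # An illustrative (hypothetical) return value for a reverse foreign key
    # whose results are cached under "books":
    #
    #   (related_book_queryset,
    #    operator.attrgetter("author_id"),  # value matched on related objects
    #    operator.attrgetter("id"),         # value matched on these instances
    #    False,                             # many related objects per instance
    #    "books",                           # cache name to assign to
    #    False)                             # not a descriptor attribute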
2122
2123 (
2124 rel_qs,
2125 rel_obj_attr,
2126 instance_attr,
2127 single,
2128 cache_name,
2129 is_descriptor,
2130 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2131 # We have to handle the possibility that the QuerySet we just got back
2132 # contains some prefetch_related lookups. We don't want to trigger the
2133 # prefetch_related functionality by evaluating the query. Rather, we need
2134 # to merge in the prefetch_related lookups.
2135 # Copy the lookups in case it is a Prefetch object which could be reused
2136 # later (happens in nested prefetch_related).
2137 additional_lookups = [
2138 copy.copy(additional_lookup)
2139 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2140 ]
2141 if additional_lookups:
        # We don't need to clone because the queryset should have given us a
        # fresh instance, so we access an internal attribute instead of the
        # public interface for performance reasons.
2145 rel_qs._prefetch_related_lookups = ()
2146
2147 all_related_objects = list(rel_qs)
2148
2149 rel_obj_cache = {}
2150 for rel_obj in all_related_objects:
2151 rel_attr_val = rel_obj_attr(rel_obj)
2152 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2153
2154 to_attr, as_attr = lookup.get_current_to_attr(level)
2155 # Make sure `to_attr` does not conflict with a field.
2156 if as_attr and instances:
        # We assume that the objects retrieved are homogeneous (which is the
        # premise of prefetch_related), so what applies to the first object
        # applies to all.
2159 model = instances[0].__class__
2160 try:
2161 model._model_meta.get_field(to_attr)
2162 except FieldDoesNotExist:
2163 pass
2164 else:
            raise ValueError(
                f"to_attr={to_attr} conflicts with a field on the {model.__name__} model."
            )
2167
2168 # Whether or not we're prefetching the last part of the lookup.
2169 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2170
2171 for obj in instances:
2172 instance_attr_val = instance_attr(obj)
2173 vals = rel_obj_cache.get(instance_attr_val, [])
2174
2175 if single:
2176 val = vals[0] if vals else None
2177 if as_attr:
2178 # A to_attr has been given for the prefetch.
2179 setattr(obj, to_attr, val)
2180 elif is_descriptor:
2181 # cache_name points to a field name in obj.
2182 # This field is a descriptor for a related object.
2183 setattr(obj, cache_name, val)
2184 else:
2185 # No to_attr has been given for this prefetch operation and the
2186 # cache_name does not point to a descriptor. Store the value of
2187 # the field in the object's field cache.
2188 obj._state.fields_cache[cache_name] = val # type: ignore[assignment]
2189 else:
2190 if as_attr:
2191 setattr(obj, to_attr, vals)
2192 else:
2193 queryset = getattr(obj, to_attr)
2194 if leaf and lookup.queryset is not None:
2195 qs = queryset._apply_rel_filters(lookup.queryset)
2196 else:
2197 # Check if queryset is a QuerySet or a related manager
2198 # We need a QuerySet instance to cache the prefetched values
2199 if isinstance(queryset, QuerySet):
2200 # It's already a QuerySet, create a new instance
2201 qs = queryset.__class__.from_model(queryset.model)
2202 else:
2203 # It's a related manager, get its QuerySet
2204 # The manager's query property returns a properly filtered QuerySet
2205 qs = queryset.query
2206 qs._result_cache = vals
2207 # We don't want the individual qs doing prefetch_related now,
2208 # since we have merged this into the current work.
2209 qs._prefetch_done = True
2210 obj._prefetched_objects_cache[cache_name] = qs
2211 return all_related_objects, additional_lookups
2212
2213
2214class RelatedPopulator:
2215 """
2216 RelatedPopulator is used for select_related() object instantiation.
2217
2218 The idea is that each select_related() model will be populated by a
    different RelatedPopulator instance. The RelatedPopulator instances get
    klass_info and select (computed in SQLCompiler) as input for
    initialization. That data is used to compute which columns
2222 to use, how to instantiate the model, and how to populate the links
2223 between the objects.
2224
    The actual creation of the objects is done in the populate() method.
    This method gets row and from_obj as input and populates the
    select_related() model instance.
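
    Illustrative (hypothetical query): for select_related("author") on a
    Book queryset, each result row holds the Book columns followed by the
    Author columns; ModelIterable builds the Book instance and this class
    builds the Author instance from row[cols_start:cols_end], attaching it
    via local_setter.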
2228 """
2229
2230 def __init__(self, klass_info: dict[str, Any], select: list[Any]):
2231 # Pre-compute needed attributes. The attributes are:
2232 # - model_cls: the possibly deferred model class to instantiate
2233 # - either:
2234 # - cols_start, cols_end: usually the columns in the row are
2235 # in the same order model_cls.__init__ expects them, so we
2236 # can instantiate by model_cls(*row[cols_start:cols_end])
2237 # - reorder_for_init: When select_related descends to a child
2238 # class, then we want to reuse the already selected parent
2239 # data. However, in this case the parent data isn't necessarily
2240 # in the same order that Model.__init__ expects it to be, so
2241 # we have to reorder the parent data. The reorder_for_init
2242 # attribute contains a function used to reorder the field data
2243 # in the order __init__ expects it.
2244 # - id_idx: the index of the primary key field in the reordered
2245 # model data. Used to check if a related object exists at all.
2246 # - init_list: the field attnames fetched from the database. For
2247 # deferred models this isn't the same as all attnames of the
2248 # model's fields.
2249 # - related_populators: a list of RelatedPopulator instances if
2250 # select_related() descends to related models from this model.
2251 # - local_setter, remote_setter: Methods to set cached values on
2252 # the object being populated and on the remote object. Usually
2253 # these are Field.set_cached_value() methods.
2254 select_fields = klass_info["select_fields"]
2255
2256 self.cols_start = select_fields[0]
2257 self.cols_end = select_fields[-1] + 1
2258 self.init_list = [
2259 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2260 ]
2261 self.reorder_for_init = None
2262
2263 self.model_cls = klass_info["model"]
2264 self.id_idx = self.init_list.index("id")
2265 self.related_populators = get_related_populators(klass_info, select)
2266 self.local_setter = klass_info["local_setter"]
2267 self.remote_setter = klass_info["remote_setter"]
2268
2269 def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
2270 if self.reorder_for_init:
2271 obj_data = self.reorder_for_init(row)
2272 else:
2273 obj_data = row[self.cols_start : self.cols_end]
2274 if obj_data[self.id_idx] is None:
2275 obj = None
2276 else:
2277 obj = self.model_cls.from_db(self.init_list, obj_data)
2278 for rel_iter in self.related_populators:
2279 rel_iter.populate(row, obj)
2280 self.local_setter(from_obj, obj)
2281 if obj is not None:
2282 self.remote_setter(obj, from_obj)
2283
2284
2285def get_related_populators(
2286 klass_info: dict[str, Any], select: list[Any]
2287) -> list[RelatedPopulator]:
    populators = []
    related_klass_infos = klass_info.get("related_klass_infos", [])
    for rel_klass_info in related_klass_infos:
        populators.append(RelatedPopulator(rel_klass_info, select))
    return populators