1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from abc import ABC, abstractmethod
11from collections.abc import Callable, Iterator, Sequence
12from functools import cached_property
13from itertools import chain, islice
14from typing import TYPE_CHECKING, Any, Generic, Never, Self, TypeVar, overload
15
16import plain.runtime
17from plain.exceptions import ValidationError
18from plain.models import (
19 sql,
20 transaction,
21)
22from plain.models.constants import LOOKUP_SEP, OnConflict
23from plain.models.db import (
24 PLAIN_VERSION_PICKLE_KEY,
25 IntegrityError,
26 NotSupportedError,
27 db_connection,
28)
29from plain.models.exceptions import (
30 FieldDoesNotExist,
31 FieldError,
32 ObjectDoesNotExist,
33)
34from plain.models.expressions import Case, F, ResolvableExpression, Value, When
35from plain.models.fields import (
36 DateField,
37 DateTimeField,
38 Field,
39 PrimaryKeyField,
40)
41from plain.models.functions import Cast, Trunc
42from plain.models.query_utils import FilteredRelation, Q
43from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
44from plain.models.utils import resolve_callables
45from plain.utils import timezone
46from plain.utils.functional import partition
47
48if TYPE_CHECKING:
49 from datetime import tzinfo
50
51 from plain.models import Model
52
53# Type variable for QuerySet generic
54T = TypeVar("T", bound="Model")
55
56# The maximum number of results to fetch in a get() query.
57MAX_GET_RESULTS = 21
58
59# The maximum number of items to display in a QuerySet.__repr__
60REPR_OUTPUT_SIZE = 20
61
62
class BaseIterable(ABC):
    def __init__(
        self,
        queryset: QuerySet[Any],
        chunked_fetch: bool = False,
        chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
    ):
        self.queryset = queryset
        self.chunked_fetch = chunked_fetch
        self.chunk_size = chunk_size

    @abstractmethod
    def __iter__(self) -> Iterator[Any]: ...


class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""

    def __iter__(self) -> Iterator[Model]:
        queryset = self.queryset
        compiler = queryset.sql_query.get_compiler()
        # Execute the query. This will also fill compiler.select, klass_info,
        # and annotations.
        results = compiler.execute_sql(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        select, klass_info, annotation_col_map = (
            compiler.select,
            compiler.klass_info,
            compiler.annotation_col_map,
        )
        model_cls = klass_info["model"]
        select_fields = klass_info["select_fields"]
        model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
        init_list = [
            f[0].target.attname for f in select[model_fields_start:model_fields_end]
        ]
        related_populators = get_related_populators(klass_info, select)
        known_related_objects = [
            (
                field,
                related_objs,
                operator.attrgetter(field.attname),
            )
            for field, related_objs in queryset._known_related_objects.items()
        ]
        for row in compiler.results_iter(results):
            obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
            for rel_populator in related_populators:
                rel_populator.populate(row, obj)
            if annotation_col_map:
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            # Add the known related objects to the model.
            for field, rel_objs, rel_getter in known_related_objects:
                # Avoid overwriting objects loaded by, e.g., select_related().
                if field.is_cached(obj):
                    continue
                rel_obj_id = rel_getter(obj)
                try:
                    rel_obj = rel_objs[rel_obj_id]
                except KeyError:
                    pass  # May happen in qs1 | qs2 scenarios.
                else:
                    setattr(obj, field.name, rel_obj)

            yield obj


class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.
    """

    queryset: RawQuerySet

    def __iter__(self) -> Iterator[Model]:
        # Cache some things for performance reasons outside the loop.
        query = self.queryset.sql_query
        compiler = db_connection.ops.compiler("SQLCompiler")(query, db_connection)
        query_iterator = iter(query)

        try:
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            assert model_cls is not None
            if "id" not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = compiler.get_converters(
                [
                    f.get_col(f.model.model_options.db_table) if f else None
                    for f in fields
                ]
            )
            if converters:
                query_iterator = compiler.apply_converters(query_iterator, converters)
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(model_init_names, model_init_values)
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()


class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict for each row.
    """

    def __iter__(self) -> Iterator[dict[str, Any]]:
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()

        # extra(select=...) cols are always at the start of the row.
        names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]
        indexes = range(len(names))
        for row in compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        ):
            yield {names[i]: row[i] for i in indexes}


class ValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
    for each row.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()

        if queryset._fields:
            # extra(select=...) cols are always at the start of the row.
            names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            fields = [
                *queryset._fields,
                *(f for f in query.annotation_select if f not in queryset._fields),
            ]
            if fields != names:
                # Reorder according to fields.
                index_map = {name: idx for idx, name in enumerate(names)}
                rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
                return map(
                    rowfactory,
                    compiler.results_iter(
                        chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
                    ),
                )
        return compiler.results_iter(
            tuple_expected=True,
            chunked_fetch=self.chunked_fetch,
            chunk_size=self.chunk_size,
        )


class FlatValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=True) that yields single
    values.
    """

    def __iter__(self) -> Iterator[Any]:
        queryset = self.queryset
        compiler = queryset.sql_query.get_compiler()
        for row in compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        ):
            yield row[0]


class QuerySet(Generic[T]):
    """
    Represent a lazy database lookup for a set of objects.

    Usage:
        MyModel.query.filter(name="test").all()

    Custom QuerySets:
        from typing import Self

        class TaskQuerySet(QuerySet["Task"]):
            def active(self) -> Self:
                return self.filter(is_active=True)

        class Task(Model):
            is_active = BooleanField(default=True)
            query = TaskQuerySet()

        Task.query.active().filter(name="test")  # Full type inference

    Custom methods should return `Self` to preserve type through method chaining.
    """

    # Instance attributes (set in from_model())
    model: type[T]
    _query: sql.Query
    _result_cache: list[T] | None
    _sticky_filter: bool
    _for_write: bool
    _prefetch_related_lookups: tuple[Any, ...]
    _prefetch_done: bool
    _known_related_objects: dict[Any, dict[Any, Any]]
    _iterable_class: type[BaseIterable]
    _fields: tuple[str, ...] | None
    _defer_next_filter: bool
    _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None

    def __init__(self):
        """Minimal init for descriptor mode. Use from_model() to create instances."""
        pass

    @classmethod
    def from_model(cls, model: type[T], query: sql.Query | None = None) -> Self:
        """Create a QuerySet instance bound to a model."""
        instance = cls()
        instance.model = model
        instance._query = query or sql.Query(model)
        instance._result_cache = None
        instance._sticky_filter = False
        instance._for_write = False
        instance._prefetch_related_lookups = ()
        instance._prefetch_done = False
        instance._known_related_objects = {}
        instance._iterable_class = ModelIterable
        instance._fields = None
        instance._defer_next_filter = False
        instance._deferred_filter = None
        return instance

    @overload
    def __get__(self, instance: None, owner: type[T]) -> Self: ...

    @overload
    def __get__(self, instance: Model, owner: type[T]) -> Never: ...

    def __get__(self, instance: Any, owner: type[T]) -> Self:
        """Descriptor protocol - return a new QuerySet bound to the model."""
        if instance is not None:
            raise AttributeError(
                f"QuerySet is only accessible from the model class, not instances. "
                f"Use {owner.__name__}.query instead."
            )
        return self.from_model(owner)

    @property
    def sql_query(self) -> sql.Query:
        if self._deferred_filter:
            negate, args, kwargs = self._deferred_filter
            self._filter_or_exclude_inplace(negate, args, kwargs)
            self._deferred_filter = None
        return self._query

    @sql_query.setter
    def sql_query(self, value: sql.Query) -> None:
        if value.values_select:
            self._iterable_class = ValuesIterable
        self._query = value

    ########################
    # PYTHON MAGIC METHODS #
    ########################

    def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
        """Don't populate the QuerySet's cache."""
        obj = self.__class__.from_model(self.model)
        for k, v in self.__dict__.items():
            if k == "_result_cache":
                obj.__dict__[k] = None
            else:
                obj.__dict__[k] = copy.deepcopy(v, memo)
        return obj

    def __getstate__(self) -> dict[str, Any]:
        # Force the cache to be fully populated.
        self._fetch_all()
        return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}

    def __setstate__(self, state: dict[str, Any]) -> None:
        pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
        if pickled_version:
            if pickled_version != plain.runtime.__version__:
                warnings.warn(
                    f"Pickled queryset instance's Plain version {pickled_version} does not "
                    f"match the current version {plain.runtime.__version__}.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        else:
            warnings.warn(
                "Pickled queryset instance's Plain version is not specified.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.__dict__.update(state)

    def __repr__(self) -> str:
        data = list(self[: REPR_OUTPUT_SIZE + 1])
        if len(data) > REPR_OUTPUT_SIZE:
            data[-1] = "...(remaining elements truncated)..."
        return f"<{self.__class__.__name__} {data!r}>"

    def __len__(self) -> int:
        self._fetch_all()
        assert self._result_cache is not None
        return len(self._result_cache)

    def __iter__(self) -> Iterator[T]:
        """
        The queryset iterator protocol uses three nested iterators in the
        default case:
            1. sql.compiler.execute_sql()
               - Returns 100 rows at a time (constants.GET_ITERATOR_CHUNK_SIZE)
                 using cursor.fetchmany(). This part is responsible for
                 doing some column masking, and returning the rows in chunks.
            2. sql.compiler.results_iter()
               - Returns one row at a time. At this point the rows are still
                 just tuples. In some cases the return values are converted to
                 Python values at this location.
            3. self.iterator()
               - Responsible for turning the rows into model objects.
        """
        self._fetch_all()
        assert self._result_cache is not None
        return iter(self._result_cache)

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    @overload
    def __getitem__(self, k: int) -> T: ...

    @overload
    def __getitem__(self, k: slice) -> QuerySet[T] | list[T]: ...

    def __getitem__(self, k: int | slice) -> T | QuerySet[T] | list[T]:
        """Retrieve an item or slice from the set of results."""
        if not isinstance(k, int | slice):
            raise TypeError(
                f"QuerySet indices must be integers or slices, not {type(k).__name__}."
            )
        if (isinstance(k, int) and k < 0) or (
            isinstance(k, slice)
            and (
                (k.start is not None and k.start < 0)
                or (k.stop is not None and k.stop < 0)
            )
        ):
            raise ValueError("Negative indexing is not supported.")

        if self._result_cache is not None:
            return self._result_cache[k]

        if isinstance(k, slice):
            qs = self._chain()
            if k.start is not None:
                start = int(k.start)
            else:
                start = None
            if k.stop is not None:
                stop = int(k.stop)
            else:
                stop = None
            qs.sql_query.set_limits(start, stop)
            return list(qs)[:: k.step] if k.step else qs

        qs = self._chain()
        qs.sql_query.set_limits(k, k + 1)
        qs._fetch_all()
        assert qs._result_cache is not None  # _fetch_all guarantees this
        return qs._result_cache[0]

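    # Illustrative usage (``Task`` is a hypothetical model): slicing is lazy
    # and maps to LIMIT/OFFSET, while integer indexing evaluates a single row.
    #
    #     page = Task.query.order_by("id")[10:20]    # unevaluated QuerySet
    #     rows = list(page)                          # SELECT ... LIMIT 10 OFFSET 10
    #     first_row = Task.query.order_by("id")[0]   # SELECT ... LIMIT 1
    #     Task.query.all()[-1]                       # raises ValueError
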
    def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
        return cls

    def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "&")
        self._merge_sanity_check(other)
        if isinstance(other, EmptyQuerySet):
            return other
        if isinstance(self, EmptyQuerySet):
            return self
        combined = self._chain()
        combined._merge_known_related_objects(other)
        combined.sql_query.combine(other.sql_query, sql.AND)
        return combined

    def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "|")
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.OR)
        return combined

    def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "^")
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.XOR)
        return combined

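    # Illustrative usage (hypothetical ``Task`` model): querysets support
    # set-style operators that merge the underlying WHERE clauses.
    #
    #     urgent = Task.query.filter(priority__gte=8)
    #     active = Task.query.filter(is_active=True)
    #     both = urgent & active         # rows matching both filters
    #     either = urgent | active       # rows matching either filter
    #     exactly_one = urgent ^ active  # rows matching exactly one filter
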
    ####################################
    # METHODS THAT DO DATABASE QUERIES #
    ####################################

    def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
        iterable = self._iterable_class(
            self,
            chunked_fetch=use_chunked_fetch,
            chunk_size=chunk_size or 2000,
        )
        if not self._prefetch_related_lookups or chunk_size is None:
            yield from iterable
            return

        iterator = iter(iterable)
        while results := list(islice(iterator, chunk_size)):
            prefetch_related_objects(results, *self._prefetch_related_lookups)
            yield from results

    def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
        """
        An iterator over the results from applying this QuerySet to the
        database. chunk_size must be provided for QuerySets that prefetch
        related objects. Otherwise, a default chunk_size of 2000 is supplied.
        """
        if chunk_size is None:
            if self._prefetch_related_lookups:
                raise ValueError(
                    "chunk_size must be provided when using QuerySet.iterator() after "
                    "prefetch_related()."
                )
        elif chunk_size <= 0:
            raise ValueError("Chunk size must be strictly positive.")
        use_chunked_fetch = not db_connection.settings_dict.get(
            "DISABLE_SERVER_SIDE_CURSORS"
        )
        return self._iterator(use_chunked_fetch, chunk_size)

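    # Illustrative usage (hypothetical ``Task`` model): iterator() streams
    # results without populating the result cache, keeping memory flat for
    # large tables. chunk_size is required once prefetch_related() is in play.
    #
    #     for task in Task.query.filter(is_active=True).iterator():
    #         handle(task)  # ``handle`` is a placeholder
    #
    #     qs = Task.query.prefetch_related("tags")  # ``tags`` is hypothetical
    #     for task in qs.iterator(chunk_size=500):  # prefetch runs per chunk
    #         handle(task)
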
    def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """
        Return a dictionary containing the calculations (aggregation)
        over the current queryset.

        If args are present, each expression is passed as a kwarg using
        the Aggregate object's default alias.
        """
        if self.sql_query.distinct_fields:
            raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
        self._validate_values_are_expressions(
            (*args, *kwargs.values()), method_name="aggregate"
        )
        for arg in args:
            # The default_alias property raises TypeError if default_alias
            # can't be set automatically or AttributeError if it isn't an
            # attribute.
            try:
                arg.default_alias
            except (AttributeError, TypeError):
                raise TypeError("Complex aggregates require an alias")
            kwargs[arg.default_alias] = arg

        return self.sql_query.chain().get_aggregation(kwargs)

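    # Illustrative usage: aggregate() evaluates immediately and returns a
    # dict. ``Count`` and ``Max`` stand in for this ORM's aggregate
    # expressions (exact import path assumed, not shown in this module).
    #
    #     stats = Task.query.filter(is_active=True).aggregate(
    #         total=Count("id"),
    #         highest=Max("priority"),
    #     )
    #     # -> {"total": 42, "highest": 9}
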
    def count(self) -> int:
        """
        Perform a SELECT COUNT() and return the number of records as an
        integer.

        If the QuerySet is already fully cached, return the length of the
        cached results set to avoid multiple SELECT COUNT(*) calls.
        """
        if self._result_cache is not None:
            return len(self._result_cache)

        return self.sql_query.get_count()

    def get(self, *args: Any, **kwargs: Any) -> T:
        """
        Perform the query and return a single object matching the given
        keyword arguments.
        """
        if self.sql_query.combinator and (args or kwargs):
            raise NotSupportedError(
                f"Calling QuerySet.get(...) with filters after {self.sql_query.combinator}() is not "
                "supported."
            )
        clone = (
            self._chain() if self.sql_query.combinator else self.filter(*args, **kwargs)
        )
        if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
            clone = clone.order_by()
        limit = None
        if (
            not clone.sql_query.select_for_update
            or db_connection.features.supports_select_for_update_with_limit
        ):
            limit = MAX_GET_RESULTS
            clone.sql_query.set_limits(high=limit)
        num = len(clone)
        if num == 1:
            assert clone._result_cache is not None  # len() fetches results
            return clone._result_cache[0]
        if not num:
            raise self.model.DoesNotExist(
                f"{self.model.model_options.object_name} matching query does not exist."
            )
        raise self.model.MultipleObjectsReturned(
            "get() returned more than one {} -- it returned {}!".format(
                self.model.model_options.object_name,
                num if not limit or num < limit else f"more than {limit - 1}",
            )
        )

    def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
        """
        Perform the query and return a single object matching the given
        keyword arguments, or None if no object is found.
        """
        try:
            return self.get(*args, **kwargs)
        except self.model.DoesNotExist:
            return None

    def create(self, **kwargs: Any) -> T:
        """
        Create a new object with the given kwargs, saving it to the database
        and returning the created object.
        """
        obj = self.model(**kwargs)
        self._for_write = True
        obj.save(force_insert=True)
        return obj

    def _prepare_for_bulk_create(self, objs: list[T]) -> None:
        id_field = self.model._model_meta.get_forward_field("id")
        for obj in objs:
            if obj.id is None:
                # Populate new primary key values.
                obj.id = id_field.get_id_value_on_save(obj)
            obj._prepare_related_fields_for_save(operation_name="bulk_create")

    def _check_bulk_create_options(
        self,
        update_conflicts: bool,
        update_fields: list[Field] | None,
        unique_fields: list[Field] | None,
    ) -> OnConflict | None:
        db_features = db_connection.features
        if update_conflicts:
            if not db_features.supports_update_conflicts:
                raise NotSupportedError(
                    "This database backend does not support updating conflicts."
                )
            if not update_fields:
                raise ValueError(
                    "Fields that will be updated when a row insertion fails "
                    "on conflicts must be provided."
                )
            if unique_fields and not db_features.supports_update_conflicts_with_target:
                raise NotSupportedError(
                    "This database backend does not support updating "
                    "conflicts with specifying unique fields that can trigger "
                    "the upsert."
                )
            if not unique_fields and db_features.supports_update_conflicts_with_target:
                raise ValueError(
                    "Unique fields that can trigger the upsert must be provided."
                )
            # Updating primary keys and non-concrete fields is forbidden.
            from plain.models.fields.related import ManyToManyField

            if any(
                not f.concrete or isinstance(f, ManyToManyField) for f in update_fields
            ):
                raise ValueError(
                    "bulk_create() can only be used with concrete fields in "
                    "update_fields."
                )
            if any(f.primary_key for f in update_fields):
                raise ValueError(
                    "bulk_create() cannot be used with primary keys in update_fields."
                )
            if unique_fields:
                if any(
                    not f.concrete or isinstance(f, ManyToManyField)
                    for f in unique_fields
                ):
                    raise ValueError(
                        "bulk_create() can only be used with concrete fields "
                        "in unique_fields."
                    )
            return OnConflict.UPDATE
        return None

    def bulk_create(
        self,
        objs: Sequence[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances, and do not set the primary key
        attribute if it is an autoincrement field (except if
        features.can_return_rows_from_bulk_insert=True). Multi-table models
        are not supported.
        """
        # When you bulk insert you don't get the primary keys back (if it's an
        # autoincrement, except if can_return_rows_from_bulk_insert=True), so
        # you can't insert into the child tables which references this. There
        # are two workarounds:
        # 1) This could be implemented if you didn't have an autoincrement pk
        # 2) You could do it by doing O(n) normal inserts into the parent
        #    tables to get the primary keys back and then doing a single bulk
        #    insert into the childmost table.
        # We currently set the primary keys on the objects when using
        # PostgreSQL via the RETURNING ID clause. It should be possible for
        # Oracle as well, but the semantics for extracting the primary keys is
        # trickier so it's not done yet.
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        objs = list(objs)
        if not objs:
            return objs
        meta = self.model._model_meta
        unique_fields_objs: list[Field] | None = None
        update_fields_objs: list[Field] | None = None
        if unique_fields:
            unique_fields_objs = [
                meta.get_forward_field(name) for name in unique_fields
            ]
        if update_fields:
            update_fields_objs = [
                meta.get_forward_field(name) for name in update_fields
            ]
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields_objs,
            unique_fields_objs,
        )
        self._for_write = True
        fields = meta.concrete_fields
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
            if objs_with_id:
                returned_columns = self._batched_insert(
                    objs_with_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                id_field = meta.get_forward_field("id")
                for obj_with_id, results in zip(objs_with_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        if field != id_field:
                            setattr(obj_with_id, field.attname, result)
                for obj_with_id in objs_with_id:
                    obj_with_id._state.adding = False
            if objs_without_id:
                fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
                returned_columns = self._batched_insert(
                    objs_without_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields_objs,
                    unique_fields=unique_fields_objs,
                )
                if (
                    db_connection.features.can_return_rows_from_bulk_insert
                    and on_conflict is None
                ):
                    assert len(returned_columns) == len(objs_without_id)
                for obj_without_id, results in zip(objs_without_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        setattr(obj_without_id, field.attname, result)
                    obj_without_id._state.adding = False

        return objs

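    # Illustrative usage: one INSERT per batch rather than one per object.
    # With update_conflicts=True the insert becomes an upsert keyed on
    # unique_fields, on backends that support conflict targets.
    #
    #     tasks = [Task(name=f"task-{i}") for i in range(1000)]
    #     Task.query.bulk_create(tasks, batch_size=200)
    #
    #     Task.query.bulk_create(
    #         tasks,
    #         update_conflicts=True,
    #         unique_fields=["name"],      # assumes a unique constraint on name
    #         update_fields=["priority"],  # columns rewritten on conflict
    #     )
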
    def bulk_update(
        self, objs: Sequence[T], fields: list[str], batch_size: int | None = None
    ) -> int:
        """
        Update the given fields in each of the given objects in the database.
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")
        if not fields:
            raise ValueError("Field names must be given to bulk_update().")
        objs_tuple = tuple(objs)
        if any(obj.id is None for obj in objs_tuple):
            raise ValueError("All bulk_update() objects must have a primary key set.")
        fields_list = [
            self.model._model_meta.get_forward_field(name) for name in fields
        ]
        from plain.models.fields.related import ManyToManyField

        if any(not f.concrete or isinstance(f, ManyToManyField) for f in fields_list):
            raise ValueError("bulk_update() can only be used with concrete fields.")
        if any(f.primary_key for f in fields_list):
            raise ValueError("bulk_update() cannot be used with primary key fields.")
        if not objs_tuple:
            return 0
        for obj in objs_tuple:
            obj._prepare_related_fields_for_save(
                operation_name="bulk_update", fields=fields_list
            )
        # PK is used twice in the resulting update query, once in the filter
        # and once in the WHEN. Each field will also have one CAST.
        self._for_write = True
        max_batch_size = db_connection.ops.bulk_batch_size(
            ["id", "id"] + fields_list, objs_tuple
        )
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        requires_casting = db_connection.features.requires_casted_case_in_updates
        batches = (
            objs_tuple[i : i + batch_size]
            for i in range(0, len(objs_tuple), batch_size)
        )
        updates = []
        for batch_objs in batches:
            update_kwargs = {}
            for field in fields_list:
                when_statements = []
                for obj in batch_objs:
                    attr = getattr(obj, field.attname)
                    if not isinstance(attr, ResolvableExpression):
                        attr = Value(attr, output_field=field)
                    when_statements.append(When(id=obj.id, then=attr))
                case_statement = Case(*when_statements, output_field=field)
                if requires_casting:
                    case_statement = Cast(case_statement, output_field=field)
                update_kwargs[field.attname] = case_statement
            updates.append(([obj.id for obj in batch_objs], update_kwargs))
        rows_updated = 0
        queryset = self._chain()
        with transaction.atomic(savepoint=False):
            for ids, update_kwargs in updates:
                rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
        return rows_updated

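    # Illustrative usage: bulk_update() compiles each batch into a single
    # UPDATE with CASE/WHEN per field; every object must already have an id.
    #
    #     tasks = list(Task.query.filter(is_active=True))
    #     for task in tasks:
    #         task.priority += 1
    #     rows = Task.query.bulk_update(tasks, ["priority"], batch_size=100)
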
    def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the
                # database. The main thing we're concerned about is
                # uniqueness failures, but ValidationError could include
                # other things too. In all cases though it should be fine to
                # try the get() again and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:
                    pass
                raise

    def update_or_create(
        self,
        defaults: dict[str, Any] | None = None,
        create_defaults: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, updating one with defaults
        if it exists, otherwise create a new one. Optionally, an object can
        be created with different values than defaults by using
        create_defaults.
        Return a tuple (object, created), where created is a boolean
        specifying whether an object was created.
        """
        if create_defaults is None:
            update_defaults = create_defaults = defaults or {}
        else:
            update_defaults = defaults or {}
        self._for_write = True
        with transaction.atomic():
            # Lock the row so that a concurrent update is blocked until
            # update_or_create() has performed its save.
            obj, created = self.select_for_update().get_or_create(
                create_defaults, **kwargs
            )
            if created:
                return obj, created
            for k, v in resolve_callables(update_defaults):
                setattr(obj, k, v)

            update_fields = set(update_defaults)
            concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
            # update_fields does not support non-concrete fields.
            if concrete_field_names.issuperset(update_fields):
                # Add fields which are set on pre_save(), e.g. auto_now fields.
                # This is to maintain backward compatibility as these fields
                # are not updated unless explicitly specified in the
                # update_fields list.
                for field in self.model._model_meta.local_concrete_fields:
                    if not (
                        field.primary_key or field.__class__.pre_save is Field.pre_save
                    ):
                        update_fields.add(field.name)
                        if field.name != field.attname:
                            update_fields.add(field.attname)
                obj.save(update_fields=update_fields)
            else:
                obj.save()
        return obj, False

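    # Illustrative usage: the keyword arguments identify the row; defaults
    # feed the update path and create_defaults the create path.
    #
    #     task, created = Task.query.get_or_create(
    #         name="daily-report",
    #         defaults={"priority": 3},  # only used if the row is created
    #     )
    #     task, created = Task.query.update_or_create(
    #         name="daily-report",
    #         defaults={"priority": 5},             # applied when updating
    #         create_defaults={"is_active": True},  # applied when creating
    #     )
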
    def _extract_model_params(
        self, defaults: dict[str, Any] | None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Prepare `params` for creating a model instance based on the given
        kwargs; for use by get_or_create().
        """
        defaults = defaults or {}
        params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
        params.update(defaults)
        property_names = self.model._model_meta._property_names
        invalid_params = []
        for param in params:
            try:
                self.model._model_meta.get_field(param)
            except FieldDoesNotExist:
                # It's okay to use a model's property if it has a setter.
                if not (param in property_names and getattr(self.model, param).fset):
                    invalid_params.append(param)
        if invalid_params:
            raise FieldError(
                "Invalid field name(s) for model {}: '{}'.".format(
                    self.model.model_options.object_name,
                    "', '".join(sorted(invalid_params)),
                )
            )
        return params

    def first(self) -> T | None:
        """Return the first object of a query or None if no match is found."""
        for obj in self[:1]:
            return obj
        return None

    def last(self) -> T | None:
        """Return the last object of a query or None if no match is found."""
        queryset = self.reverse()
        for obj in queryset[:1]:
            return obj
        return None

    def in_bulk(
        self, id_list: list[Any] | None = None, *, field_name: str = "id"
    ) -> dict[Any, T]:
        """
        Return a dictionary mapping each of the given IDs to the object with
        that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
        meta = self.model._model_meta
        unique_fields = [
            constraint.fields[0]
            for constraint in self.model.model_options.total_unique_constraints
            if len(constraint.fields) == 1
        ]
        if (
            field_name != "id"
            and not meta.get_forward_field(field_name).primary_key
            and field_name not in unique_fields
            and self.sql_query.distinct_fields != (field_name,)
        ):
            raise ValueError(
                f"in_bulk()'s field_name must be a unique field but {field_name!r} isn't."
            )
        if id_list is not None:
            if not id_list:
                return {}
            filter_key = f"{field_name}__in"
            batch_size = db_connection.features.max_query_params
            id_list_tuple = tuple(id_list)
            # If the database has a limit on the number of query parameters
            # (e.g. SQLite), retrieve objects in batches if necessary.
            if batch_size and batch_size < len(id_list_tuple):
                qs: tuple[T, ...] = ()
                for offset in range(0, len(id_list_tuple), batch_size):
                    batch = id_list_tuple[offset : offset + batch_size]
                    qs += tuple(self.filter(**{filter_key: batch}))
            else:
                qs = self.filter(**{filter_key: id_list_tuple})
        else:
            qs = self._chain()
        return {getattr(obj, field_name): obj for obj in qs}

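    # Illustrative usage: in_bulk() returns a key -> instance mapping in one
    # query (or several batches if the backend caps query parameters).
    #
    #     by_id = Task.query.in_bulk([1, 2, 3])
    #     # -> {1: <Task: 1>, 2: <Task: 2>, 3: <Task: 3>}
    #
    #     by_name = Task.query.in_bulk(["a", "b"], field_name="name")
    #     # field_name must be unique, e.g. backed by a unique constraint
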
    def delete(self) -> tuple[int, dict[str, int]]:
        """Delete the records in the current QuerySet."""
        self._not_support_combined_queries("delete")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with delete().")
        if self.sql_query.distinct or self.sql_query.distinct_fields:
            raise TypeError("Cannot call delete() after .distinct().")
        if self._fields is not None:
            raise TypeError("Cannot call delete() after .values() or .values_list()")

        del_query = self._chain()

        # The delete is actually 2 queries - one to find related objects,
        # and one to delete. Make sure that the discovery of related
        # objects is performed on the same database as the deletion.
        del_query._for_write = True

        # Disable non-supported fields.
        del_query.sql_query.select_for_update = False
        del_query.sql_query.select_related = False
        del_query.sql_query.clear_ordering(force=True)

        from plain.models.deletion import Collector

        collector = Collector(origin=self)
        collector.collect(del_query)
        deleted, _rows_count = collector.delete()

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
        return deleted, _rows_count

    def _raw_delete(self) -> int:
        """
        Delete objects found from the given queryset in a single direct SQL
        query. No signals are sent and there is no protection for cascades.
        """
        query = self.sql_query.clone()
        query.__class__ = sql.DeleteQuery
        cursor = query.get_compiler().execute_sql(CURSOR)
        if cursor:
            with cursor:
                return cursor.rowcount
        return 0

    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.
        """
        self._not_support_combined_queries("update")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        self._for_write = True
        query = self.sql_query.chain(sql.UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            if isinstance(alias, str) and alias.startswith("-"):
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                if getattr(annotation, "contains_aggregate", False):
                    raise FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        self._result_cache = None
        return rows

    def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
        """
        A version of update() that accepts field objects instead of field
        names. Used primarily for model saving and not intended for use by
        general code (it requires too much poking around at model internals
        to be useful at that level).
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        query = self.sql_query.chain(sql.UpdateQuery)
        query.add_update_fields(values)
        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        self._result_cache = None
        return query.get_compiler().execute_sql(CURSOR)

    def exists(self) -> bool:
        """
        Return True if the QuerySet would have any results, False otherwise.
        """
        if self._result_cache is None:
            return self.sql_query.has_results()
        return bool(self._result_cache)

    def contains(self, obj: T) -> bool:
        """
        Return True if the QuerySet contains the provided obj,
        False otherwise.
        """
        self._not_support_combined_queries("contains")
        if self._fields is not None:
            raise TypeError(
                "Cannot call QuerySet.contains() after .values() or .values_list()."
            )
        try:
            if obj.__class__ != self.model:
                return False
        except AttributeError:
            raise TypeError("'obj' must be a model instance.")
        if obj.id is None:
            raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
        if self._result_cache is not None:
            return obj in self._result_cache
        return self.filter(id=obj.id).exists()

    def _prefetch_related_objects(self) -> None:
        # This method can only be called once the result cache has been filled.
        assert self._result_cache is not None
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def explain(self, *, format: str | None = None, **options: Any) -> str:
        """
        Run an EXPLAIN on the SQL query this QuerySet would perform, and
        return the results.
        """
        return self.sql_query.explain(format=format, **options)

    ##################################################
    # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
    ##################################################

    def raw(
        self,
        raw_query: str,
        params: tuple[Any, ...] = (),
        translations: dict[str, str] | None = None,
    ) -> RawQuerySet:
        qs = RawQuerySet(
            raw_query,
            model=self.model,
            params=params,
            translations=translations,
        )
        qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return qs

    def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
        clone = self._chain()
        if expressions:
            clone = clone.annotate(**expressions)
        clone._fields = fields
        clone.sql_query.set_values(list(fields))
        return clone

    def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
        fields += tuple(expressions)
        clone = self._values(*fields, **expressions)
        clone._iterable_class = ValuesIterable
        return clone

    def values_list(self, *fields: str, flat: bool = False) -> QuerySet[Any]:
        if flat and len(fields) > 1:
            raise TypeError(
                "'flat' is not valid when values_list is called with more than one "
                "field."
            )

        field_names = {f for f in fields if not isinstance(f, ResolvableExpression)}
        _fields = []
        expressions = {}
        counter = 1
        for field in fields:
            if isinstance(field, ResolvableExpression):
                field_id_prefix = getattr(
                    field, "default_alias", field.__class__.__name__.lower()
                )
                while True:
                    field_id = field_id_prefix + str(counter)
                    counter += 1
                    if field_id not in field_names:
                        break
                expressions[field_id] = field
                _fields.append(field_id)
            else:
                _fields.append(field)

        clone = self._values(*_fields, **expressions)
        clone._iterable_class = FlatValuesListIterable if flat else ValuesListIterable
        return clone

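    # Illustrative usage: values() yields dicts, values_list() yields tuples,
    # and flat=True unwraps single-column tuples into bare values.
    #
    #     Task.query.values("id", "name")          # [{"id": 1, "name": "a"}, ...]
    #     Task.query.values_list("id", "name")     # [(1, "a"), ...]
    #     Task.query.values_list("id", flat=True)  # [1, 2, ...]
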
    def dates(self, field_name: str, kind: str, order: str = "ASC") -> QuerySet[Any]:
        """
        Return a list of date objects representing all available dates for
        the given field_name, scoped to 'kind'.
        """
        if kind not in ("year", "month", "week", "day"):
            raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
        if order not in ("ASC", "DESC"):
            raise ValueError("'order' must be either 'ASC' or 'DESC'.")
        return (
            self.annotate(
                datefield=Trunc(field_name, kind, output_field=DateField()),
                plain_field=F(field_name),
            )
            .values_list("datefield", flat=True)
            .distinct()
            .filter(plain_field__isnull=False)
            .order_by(("-" if order == "DESC" else "") + "datefield")
        )

    def datetimes(
        self,
        field_name: str,
        kind: str,
        order: str = "ASC",
        tzinfo: tzinfo | None = None,
    ) -> QuerySet[Any]:
        """
        Return a list of datetime objects representing all available
        datetimes for the given field_name, scoped to 'kind'.
        """
        if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
            raise ValueError(
                "'kind' must be one of 'year', 'month', 'week', 'day', "
                "'hour', 'minute', or 'second'."
            )
        if order not in ("ASC", "DESC"):
            raise ValueError("'order' must be either 'ASC' or 'DESC'.")

        if tzinfo is None:
            tzinfo = timezone.get_current_timezone()

        return (
            self.annotate(
                datetimefield=Trunc(
                    field_name,
                    kind,
                    output_field=DateTimeField(),
                    tzinfo=tzinfo,
                ),
                plain_field=F(field_name),
            )
            .values_list("datetimefield", flat=True)
            .distinct()
            .filter(plain_field__isnull=False)
            .order_by(("-" if order == "DESC" else "") + "datetimefield")
        )

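    # Illustrative usage: both methods return distinct truncated values for
    # the field, ordered as requested. ``created_at`` stands in for a
    # DateField/DateTimeField on the model.
    #
    #     Task.query.dates("created_at", "month")  # one date per month seen
    #     Task.query.datetimes("created_at", "hour", order="DESC")
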
    def none(self) -> QuerySet[T]:
        """Return an empty QuerySet."""
        clone = self._chain()
        clone.sql_query.set_empty()
        return clone

    ##################################################################
    # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
    ##################################################################

    def all(self) -> Self:
        """
        Return a new QuerySet that is a copy of the current one. This allows a
        QuerySet to proxy for a model queryset in some cases.
        """
        return self._chain()

    def filter(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with the args ANDed to the existing
        set.
        """
        self._not_support_combined_queries("filter")
        return self._filter_or_exclude(False, args, kwargs)

    def exclude(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with NOT (args) ANDed to the existing
        set.
        """
        self._not_support_combined_queries("exclude")
        return self._filter_or_exclude(True, args, kwargs)

    def _filter_or_exclude(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Self:
        if (args or kwargs) and self.sql_query.is_sliced:
            raise TypeError("Cannot filter a query once a slice has been taken.")
        clone = self._chain()
        if self._defer_next_filter:
            self._defer_next_filter = False
            clone._deferred_filter = negate, args, kwargs
        else:
            clone._filter_or_exclude_inplace(negate, args, kwargs)
        return clone

    def _filter_or_exclude_inplace(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        if negate:
            self._query.add_q(~Q(*args, **kwargs))
        else:
            self._query.add_q(Q(*args, **kwargs))

    def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
        """
        Return a new QuerySet instance with filter_obj added to the filters.

        filter_obj can be a Q object or a dictionary of keyword lookup
        arguments.

        This exists to support framework features such as 'limit_choices_to',
        and usually it will be more natural to use other methods.
        """
        if isinstance(filter_obj, Q):
            clone = self._chain()
            clone.sql_query.add_q(filter_obj)
            return clone
        else:
            return self._filter_or_exclude(False, args=(), kwargs=filter_obj)

    def _combinator_query(
        self, combinator: str, *other_qs: QuerySet[T], all: bool = False
    ) -> QuerySet[T]:
        # Clone the query to inherit the select list and everything
        clone = self._chain()
        # Clear limits and ordering so they can be reapplied
        clone.sql_query.clear_ordering(force=True)
        clone.sql_query.clear_limits()
        clone.sql_query.combined_queries = (self.sql_query,) + tuple(
            qs.sql_query for qs in other_qs
        )
        clone.sql_query.combinator = combinator
        clone.sql_query.combinator_all = all
        return clone

    def union(self, *other_qs: QuerySet[T], all: bool = False) -> QuerySet[T]:
        # If the query is an EmptyQuerySet, combine all nonempty querysets.
        if isinstance(self, EmptyQuerySet):
            qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
            if not qs:
                return self
            if len(qs) == 1:
                return qs[0]
            return qs[0]._combinator_query("union", *qs[1:], all=all)
        return self._combinator_query("union", *other_qs, all=all)

    def intersection(self, *other_qs: QuerySet[T]) -> QuerySet[T]:
        # If any query is an EmptyQuerySet, return it.
        if isinstance(self, EmptyQuerySet):
            return self
        for other in other_qs:
            if isinstance(other, EmptyQuerySet):
                return other
        return self._combinator_query("intersection", *other_qs)

    def difference(self, *other_qs: QuerySet[T]) -> QuerySet[T]:
        # If the query is an EmptyQuerySet, return it.
        if isinstance(self, EmptyQuerySet):
            return self
        return self._combinator_query("difference", *other_qs)

    def select_for_update(
        self,
        nowait: bool = False,
        skip_locked: bool = False,
        of: tuple[str, ...] = (),
        no_key: bool = False,
    ) -> QuerySet[T]:
        """
        Return a new QuerySet instance that will select objects with a
        FOR UPDATE lock.
        """
        if nowait and skip_locked:
            raise ValueError("The nowait option cannot be used with skip_locked.")
        obj = self._chain()
        obj._for_write = True
        obj.sql_query.select_for_update = True
        obj.sql_query.select_for_update_nowait = nowait
        obj.sql_query.select_for_update_skip_locked = skip_locked
        obj.sql_query.select_for_update_of = of
        obj.sql_query.select_for_no_key_update = no_key
        return obj

    def select_related(self, *fields: str | None) -> Self:
        """
        Return a new QuerySet instance that will select related objects.

        If fields are specified, they must be ForeignKeyField fields and only
        those related objects are included in the selection.

        If select_related(None) is called, clear the list.
        """
        self._not_support_combined_queries("select_related")
        if self._fields is not None:
            raise TypeError(
                "Cannot call select_related() after .values() or .values_list()"
            )

        obj = self._chain()
        if fields == (None,):
            obj.sql_query.select_related = False
        elif fields:
            obj.sql_query.add_select_related(list(fields))  # type: ignore[arg-type]
        else:
            obj.sql_query.select_related = True
        return obj

    def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
        """
        Return a new QuerySet instance that will prefetch the specified
        Many-To-One and Many-To-Many related objects when the QuerySet is
        evaluated.

        When prefetch_related() is called more than once, append to the list
        of prefetch lookups. If prefetch_related(None) is called, clear the
        list.
        """
        self._not_support_combined_queries("prefetch_related")
        clone = self._chain()
        if lookups == (None,):
            clone._prefetch_related_lookups = ()
        else:
            for lookup in lookups:
                lookup_str: str
                if isinstance(lookup, Prefetch):
                    lookup_str = lookup.prefetch_to
                else:
                    assert isinstance(lookup, str)
                    lookup_str = lookup
                lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
                if lookup_str in self.sql_query._filtered_relations:
                    raise ValueError(
                        "prefetch_related() is not supported with FilteredRelation."
                    )
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

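    # Illustrative usage: select_related() follows foreign keys with a JOIN
    # in the same SELECT; prefetch_related() issues one extra query per
    # lookup and joins the results in Python. ``project``/``tags``/``Tag``
    # are hypothetical related fields and models, and the Prefetch(...,
    # queryset=...) form is assumed to mirror the usual Django-style API.
    #
    #     Task.query.select_related("project")   # one query with a JOIN
    #     Task.query.prefetch_related("tags")    # two queries total
    #     Task.query.prefetch_related(
    #         Prefetch("tags", queryset=Tag.query.filter(is_active=True))
    #     )
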
    def annotate(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set in which the returned objects have been annotated
        with extra data or aggregations.
        """
        self._not_support_combined_queries("annotate")
        return self._annotate(args, kwargs, select=True)

    def alias(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set with added aliases for extra data or aggregations.
        """
        self._not_support_combined_queries("alias")
        return self._annotate(args, kwargs, select=False)

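    # Illustrative usage: annotate() adds the expression to the SELECT list
    # and to each returned instance; alias() only makes it available to later
    # filter()/order_by() calls. ``Count`` is an assumed aggregate import.
    #
    #     for task in Task.query.annotate(tag_count=Count("tags")):
    #         print(task.tag_count)
    #
    #     Task.query.alias(tag_count=Count("tags")).filter(tag_count__gt=2)
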
    def _annotate(
        self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
    ) -> Self:
        self._validate_values_are_expressions(
            args + tuple(kwargs.values()), method_name="annotate"
        )
        annotations = {}
        for arg in args:
            # The default_alias property may raise a TypeError.
            try:
                if arg.default_alias in kwargs:
                    raise ValueError(
                        f"The named annotation '{arg.default_alias}' conflicts with the "
                        "default name for another annotation."
                    )
            except TypeError:
                raise TypeError("Complex annotations require an alias")
            annotations[arg.default_alias] = arg
        annotations.update(kwargs)

        clone = self._chain()
        names = self._fields
        if names is None:
            names = set(
                chain.from_iterable(
                    (field.name, field.attname)
                    if hasattr(field, "attname")
                    else (field.name,)
                    for field in self.model._model_meta.get_fields()
                )
            )

        for alias, annotation in annotations.items():
            if alias in names:
                raise ValueError(
                    f"The annotation '{alias}' conflicts with a field on the model."
                )
            if isinstance(annotation, FilteredRelation):
                clone.sql_query.add_filtered_relation(annotation, alias)
            else:
                clone.sql_query.add_annotation(
                    annotation,
                    alias,
                    select=select,
                )
        for alias, annotation in clone.sql_query.annotations.items():
            if alias in annotations and annotation.contains_aggregate:
                if clone._fields is None:
                    clone.sql_query.group_by = True
                else:
                    clone.sql_query.set_group_by()
                break

        return clone

    def order_by(self, *field_names: str) -> Self:
        """Return a new QuerySet instance with the ordering changed."""
        if self.sql_query.is_sliced:
            raise TypeError("Cannot reorder a query once a slice has been taken.")
        obj = self._chain()
        obj.sql_query.clear_ordering(force=True, clear_default=False)
        obj.sql_query.add_ordering(*field_names)
        return obj

    def distinct(self, *field_names: str) -> Self:
        """
        Return a new QuerySet instance that will select only distinct results.
        """
        self._not_support_combined_queries("distinct")
        if self.sql_query.is_sliced:
            raise TypeError(
                "Cannot create distinct fields once a slice has been taken."
            )
        obj = self._chain()
        obj.sql_query.add_distinct_fields(*field_names)
        return obj

    def extra(
        self,
        select: dict[str, str] | None = None,
        where: list[str] | None = None,
        params: list[Any] | None = None,
        tables: list[str] | None = None,
        order_by: list[str] | None = None,
        select_params: list[Any] | None = None,
    ) -> QuerySet[T]:
        """Add extra SQL fragments to the query."""
        self._not_support_combined_queries("extra")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot change a query once a slice has been taken.")
        clone = self._chain()
        clone.sql_query.add_extra(
            select or {},
            select_params,
            where or [],
            params or [],
            tables or [],
            tuple(order_by) if order_by else (),
        )
        return clone

    def reverse(self) -> QuerySet[T]:
        """Reverse the ordering of the QuerySet."""
        if self.sql_query.is_sliced:
            raise TypeError("Cannot reverse a query once a slice has been taken.")
        clone = self._chain()
        clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
        return clone

    def defer(self, *fields: str | None) -> QuerySet[T]:
        """
        Defer the loading of data for certain fields until they are accessed.
        Add the set of deferred fields to any existing set of deferred fields.
        The only exception to this is if None is passed in as the only
        parameter, in which case all deferrals are removed.
        """
        self._not_support_combined_queries("defer")
        if self._fields is not None:
            raise TypeError("Cannot call defer() after .values() or .values_list()")
        clone = self._chain()
        if fields == (None,):
            clone.sql_query.clear_deferred_loading()
        else:
            clone.sql_query.add_deferred_loading(frozenset(fields))  # type: ignore[arg-type]
        return clone

    def only(self, *fields: str) -> QuerySet[T]:
        """
        Essentially, the opposite of defer(). Only the fields passed into this
        method and that are not already specified as deferred are loaded
        immediately when the queryset is evaluated.
        """
        self._not_support_combined_queries("only")
        if self._fields is not None:
            raise TypeError("Cannot call only() after .values() or .values_list()")
        if fields == (None,):
            # Can only pass None to defer(), not only(), as the rest option.
            # That won't stop people trying to do this, so let's be explicit.
            raise TypeError("Cannot pass None as an argument to only().")
        for field in fields:
            field = field.split(LOOKUP_SEP, 1)[0]
            if field in self.sql_query._filtered_relations:
                raise ValueError("only() is not supported with FilteredRelation.")
        clone = self._chain()
        clone.sql_query.add_immediate_loading(set(fields))
        return clone

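    # Illustrative usage: both methods trim the SELECT list; a deferred
    # column is loaded lazily (one extra query) on first attribute access.
    # ``description`` stands in for a large text field.
    #
    #     Task.query.defer("description")   # load everything but description
    #     Task.query.only("id", "name")     # load just id and name up front
    #     Task.query.defer(None)            # clear all deferrals
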
    ###################################
    # PUBLIC INTROSPECTION ATTRIBUTES #
    ###################################

    @property
    def ordered(self) -> bool:
        """
        Return True if the QuerySet is ordered -- i.e. has an order_by()
        clause or a default ordering on the model (or is empty).
        """
        if isinstance(self, EmptyQuerySet):
            return True
        if self.sql_query.extra_order_by or self.sql_query.order_by:
            return True
        elif (
            self.sql_query.default_ordering
            and self.sql_query.model
            and self.sql_query.model._model_meta.ordering  # type: ignore[possibly-missing-attribute]
            # A default ordering doesn't affect GROUP BY queries.
            and not self.sql_query.group_by
        ):
            return True
        else:
            return False

1670 ###################
1671 # PRIVATE METHODS #
1672 ###################
1673
1674 def _insert(
1675 self,
1676 objs: list[T],
1677 fields: list[Field],
1678 returning_fields: list[Field] | None = None,
1679 raw: bool = False,
1680 on_conflict: OnConflict | None = None,
1681 update_fields: list[Field] | None = None,
1682 unique_fields: list[Field] | None = None,
1683 ) -> list[tuple[Any, ...]] | None:
1684 """
        Insert new records for the given model. This provides an interface to
1686 the InsertQuery class and is how Model.save() is implemented.
1687 """
1688 self._for_write = True
1689 query = sql.InsertQuery(
1690 self.model,
1691 on_conflict=on_conflict.value if on_conflict else None,
1692 update_fields=update_fields,
1693 unique_fields=unique_fields,
1694 )
1695 query.insert_values(fields, objs, raw=raw)
1696 return query.get_compiler().execute_sql(returning_fields)
1697
1698 def _batched_insert(
1699 self,
1700 objs: list[T],
1701 fields: list[Field],
1702 batch_size: int | None,
1703 on_conflict: OnConflict | None = None,
1704 update_fields: list[Field] | None = None,
1705 unique_fields: list[Field] | None = None,
1706 ) -> list[tuple[Any, ...]]:
1707 """
1708 Helper method for bulk_create() to insert objs one batch at a time.
1709 """
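        # Batching sketch: with 2500 objs and an effective batch size of
        # 1000, the slicing below yields objs[0:1000], objs[1000:2000], and
        # objs[2000:2500] -- three INSERT statements.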
1710 ops = db_connection.ops
1711 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1712 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1713 inserted_rows = []
1714 bulk_return = db_connection.features.can_return_rows_from_bulk_insert
1715 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1716 if bulk_return and on_conflict is None:
1717 inserted_rows.extend(
1718 self._insert( # type: ignore[arg-type]
1719 item,
1720 fields=fields,
1721 returning_fields=self.model._model_meta.db_returning_fields,
1722 )
1723 )
1724 else:
1725 self._insert(
1726 item,
1727 fields=fields,
1728 on_conflict=on_conflict,
1729 update_fields=update_fields,
1730 unique_fields=unique_fields,
1731 )
1732 return inserted_rows
1733
1734 def _chain(self) -> Self:
1735 """
1736 Return a copy of the current QuerySet that's ready for another
1737 operation.
1738 """
1739 obj = self._clone()
1740 if obj._sticky_filter:
1741 obj.sql_query.filter_is_sticky = True
1742 obj._sticky_filter = False
1743 return obj
1744
1745 def _clone(self) -> Self:
1746 """
1747 Return a copy of the current QuerySet. A lightweight alternative
1748 to deepcopy().
1749 """
1750 c = self.__class__.from_model(
1751 model=self.model,
1752 query=self.sql_query.chain(),
1753 )
1754 c._sticky_filter = self._sticky_filter
1755 c._for_write = self._for_write
1756 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1757 c._known_related_objects = self._known_related_objects
1758 c._iterable_class = self._iterable_class
1759 c._fields = self._fields
1760 return c
1761
1762 def _fetch_all(self) -> None:
1763 if self._result_cache is None:
1764 self._result_cache = list(self._iterable_class(self))
1765 if self._prefetch_related_lookups and not self._prefetch_done:
1766 self._prefetch_related_objects()
1767
1768 def _next_is_sticky(self) -> QuerySet[T]:
1769 """
1770 Indicate that the next filter call and the one following that should
1771 be treated as a single filter. This is only important when it comes to
1772 determining when to reuse tables for many-to-many filters. Required so
1773 that we can filter naturally on the results of related managers.
1774
1775 This doesn't return a clone of the current QuerySet (it returns
1776 "self"). The method is only used internally and should be immediately
1777 followed by a filter() that does create a clone.
1778 """
1779 self._sticky_filter = True
1780 return self
1781
1782 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1783 """Check that two QuerySet classes may be merged."""
1784 if self._fields is not None and (
1785 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1786 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1787 or set(self.sql_query.annotation_select)
1788 != set(other.sql_query.annotation_select)
1789 ):
1790 raise TypeError(
1791 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1792 )
1793
1794 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1795 """
1796 Keep track of all known related objects from either QuerySet instance.
1797 """
1798 for field, objects in other._known_related_objects.items():
1799 self._known_related_objects.setdefault(field, {}).update(objects)
1800
1801 def resolve_expression(self, *args: Any, **kwargs: Any) -> sql.Query:
1802 if self._fields and len(self._fields) > 1:
            # A values() queryset can only be used as a nested query
            # if it is set up to select only a single field.
1805 raise TypeError("Cannot use multi-field values as a filter value.")
1806 query = self.sql_query.resolve_expression(*args, **kwargs)
1807 return query
1808
1809 def _has_filters(self) -> bool:
1810 """
        Check if this QuerySet has any filtering going on. This isn't
        equivalent to checking whether all objects are present in the
        results; for example, qs[1:]._has_filters() -> False.
1814 """
1815 return self.sql_query.has_filters()
1816
1817 @staticmethod
1818 def _validate_values_are_expressions(
1819 values: tuple[Any, ...], method_name: str
1820 ) -> None:
1821 invalid_args = sorted(
1822 str(arg) for arg in values if not isinstance(arg, ResolvableExpression)
1823 )
1824 if invalid_args:
1825 raise TypeError(
1826 "QuerySet.{}() received non-expression(s): {}.".format(
1827 method_name,
1828 ", ".join(invalid_args),
1829 )
1830 )
1831
1832 def _not_support_combined_queries(self, operation_name: str) -> None:
1833 if self.sql_query.combinator:
1834 raise NotSupportedError(
1835 f"Calling QuerySet.{operation_name}() after {self.sql_query.combinator}() is not supported."
1836 )
1837
1838 def _check_operator_queryset(self, other: QuerySet[T], operator_: str) -> None:
1839 if self.sql_query.combinator or other.sql_query.combinator:
1840 raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
1841
1842
1843class InstanceCheckMeta(type):
1844 def __instancecheck__(self, instance: object) -> bool:
1845 return isinstance(instance, QuerySet) and instance.sql_query.is_empty()
1846
1847
1848class EmptyQuerySet(metaclass=InstanceCheckMeta):
1849 """
    Marker class for checking whether a queryset was made empty via .none():
1851 isinstance(qs.none(), EmptyQuerySet) -> True
1852 """
1853
1854 def __init__(self, *args: Any, **kwargs: Any):
1855 raise TypeError("EmptyQuerySet can't be instantiated")
1856
1857
1858class RawQuerySet:
1859 """
1860 Provide an iterator which converts the results of raw SQL queries into
1861 annotated model instances.
1862 """
1863
1864 def __init__(
1865 self,
1866 raw_query: str,
1867 model: type[Model] | None = None,
1868 query: sql.RawQuery | None = None,
1869 params: tuple[Any, ...] = (),
1870 translations: dict[str, str] | None = None,
1871 ):
1872 self.raw_query = raw_query
1873 self.model = model
1874 self.sql_query = query or sql.RawQuery(sql=raw_query, params=params)
1875 self.params = params
1876 self.translations = translations or {}
1877 self._result_cache: list[Model] | None = None
1878 self._prefetch_related_lookups: tuple[Any, ...] = ()
1879 self._prefetch_done = False
1880
1881 def resolve_model_init_order(
1882 self,
1883 ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
1884 """Resolve the init field names and value positions."""
1885 converter = db_connection.introspection.identifier_converter
1886 model = self.model
1887 assert model is not None
1888 model_init_fields = [
1889 f for f in model._model_meta.fields if converter(f.column) in self.columns
1890 ]
1891 annotation_fields = [
1892 (column, pos)
1893 for pos, column in enumerate(self.columns)
1894 if column not in self.model_fields
1895 ]
1896 model_init_order = [
1897 self.columns.index(converter(f.column)) for f in model_init_fields
1898 ]
1899 model_init_names = [f.attname for f in model_init_fields]
1900 return model_init_names, model_init_order, annotation_fields
1901
1902 def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
1903 """Same as QuerySet.prefetch_related()"""
1904 clone = self._clone()
1905 if lookups == (None,):
1906 clone._prefetch_related_lookups = ()
1907 else:
1908 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1909 return clone
1910
1911 def _prefetch_related_objects(self) -> None:
1912 assert self._result_cache is not None
1913 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1914 self._prefetch_done = True
1915
1916 def _clone(self) -> RawQuerySet:
1917 """Same as QuerySet._clone()"""
1918 c = self.__class__(
1919 self.raw_query,
1920 model=self.model,
1921 query=self.sql_query,
1922 params=self.params,
1923 translations=self.translations,
1924 )
1925 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1926 return c
1927
1928 def _fetch_all(self) -> None:
1929 if self._result_cache is None:
1930 self._result_cache = list(self.iterator())
1931 if self._prefetch_related_lookups and not self._prefetch_done:
1932 self._prefetch_related_objects()
1933
1934 def __len__(self) -> int:
1935 self._fetch_all()
1936 assert self._result_cache is not None
1937 return len(self._result_cache)
1938
1939 def __bool__(self) -> bool:
1940 self._fetch_all()
1941 return bool(self._result_cache)
1942
1943 def __iter__(self) -> Iterator[Model]:
1944 self._fetch_all()
1945 assert self._result_cache is not None
1946 return iter(self._result_cache)
1947
1948 def iterator(self) -> Iterator[Model]:
1949 yield from RawModelIterable(self) # type: ignore[arg-type]
1950
1951 def __repr__(self) -> str:
1952 return f"<{self.__class__.__name__}: {self.sql_query}>"
1953
1954 def __getitem__(self, k: int | slice) -> Model | list[Model]:
1955 return list(self)[k]
1956
1957 @cached_property
1958 def columns(self) -> list[str]:
1959 """
        A list of column names, adjusted for any translations, in the order
        they'll appear in the query results.
1962 """
1963 columns = self.sql_query.get_columns()
1964 # Adjust any column names which don't match field names
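        # e.g. with translations={"first": "first_name"} (illustrative), a
        # raw "first" column is reported here as "first_name".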
1965 for query_name, model_name in self.translations.items():
1966 # Ignore translations for nonexistent column names
1967 try:
1968 index = columns.index(query_name)
1969 except ValueError:
1970 pass
1971 else:
1972 columns[index] = model_name
1973 return columns
1974
1975 @cached_property
1976 def model_fields(self) -> dict[str, Field]:
        """A dict mapping column names to model fields."""
1978 converter = db_connection.introspection.identifier_converter
1979 model_fields = {}
1980 model = self.model
1981 assert model is not None
1982 for field in model._model_meta.fields:
1983 name, column = field.get_attname_column()
1984 model_fields[converter(column)] = field
1985 return model_fields
1986
1987
1988class Prefetch:
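    """
    Describe a prefetch_related() lookup, optionally with a custom queryset
    and/or a to_attr destination for the results.

    Illustrative usage (model, field, and lookup names hypothetical):

        Prefetch("comments")
        Prefetch("comments", queryset=comments_qs.filter(approved=True))
        Prefetch("comments", to_attr="approved_comments")
    """
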
1989 def __init__(
1990 self,
1991 lookup: str,
1992 queryset: QuerySet[Any] | None = None,
1993 to_attr: str | None = None,
1994 ):
1995 # `prefetch_through` is the path we traverse to perform the prefetch.
1996 self.prefetch_through = lookup
1997 # `prefetch_to` is the path to the attribute that stores the result.
1998 self.prefetch_to = lookup
1999 if queryset is not None and (
2000 isinstance(queryset, RawQuerySet)
2001 or (
2002 hasattr(queryset, "_iterable_class")
2003 and not issubclass(queryset._iterable_class, ModelIterable)
2004 )
2005 ):
2006 raise ValueError(
                "Prefetch querysets cannot use raw(), values(), or values_list()."
2008 )
2009 if to_attr:
2010 self.prefetch_to = LOOKUP_SEP.join(
2011 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
2012 )
2013
2014 self.queryset = queryset
2015 self.to_attr = to_attr
2016
2017 def __getstate__(self) -> dict[str, Any]:
2018 obj_dict = self.__dict__.copy()
2019 if self.queryset is not None:
2020 queryset = self.queryset._chain()
2021 # Prevent the QuerySet from being evaluated
2022 queryset._result_cache = []
2023 queryset._prefetch_done = True
2024 obj_dict["queryset"] = queryset
2025 return obj_dict
2026
2027 def add_prefix(self, prefix: str) -> None:
2028 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
2029 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
2030
2031 def get_current_prefetch_to(self, level: int) -> str:
2032 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
2033
2034 def get_current_to_attr(self, level: int) -> tuple[str, bool]:
2035 parts = self.prefetch_to.split(LOOKUP_SEP)
2036 to_attr = parts[level]
2037 as_attr = bool(self.to_attr and level == len(parts) - 1)
2038 return to_attr, as_attr
2039
2040 def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
2041 if self.get_current_prefetch_to(level) == self.prefetch_to:
2042 return self.queryset
2043 return None
2044
2045 def __eq__(self, other: object) -> bool:
2046 if not isinstance(other, Prefetch):
2047 return NotImplemented
2048 return self.prefetch_to == other.prefetch_to
2049
2050 def __hash__(self) -> int:
2051 return hash((self.__class__, self.prefetch_to))
2052
2053
2054def normalize_prefetch_lookups(
2055 lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
2056 prefix: str | None = None,
2057) -> list[Prefetch]:
2058 """Normalize lookups into Prefetch objects."""
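    # e.g. (lookup names illustrative):
    #   normalize_prefetch_lookups(["author", Prefetch("tags")], "book")
    #   -> [Prefetch("book__author"), Prefetch("book__tags")]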
2059 ret = []
2060 for lookup in lookups:
2061 if not isinstance(lookup, Prefetch):
2062 lookup = Prefetch(lookup)
2063 if prefix:
2064 lookup.add_prefix(prefix)
2065 ret.append(lookup)
2066 return ret
2067
2068
2069def prefetch_related_objects(
2070 model_instances: Sequence[Model], *related_lookups: str | Prefetch
2071) -> None:
2072 """
2073 Populate prefetched object caches for a list of model instances based on
2074 the lookups/Prefetch instances given.
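
    Illustrative standalone use (queryset and lookup names hypothetical):

        articles = list(qs)
        prefetch_related_objects(articles, "author", Prefetch("tags"))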
2075 """
2076 if not model_instances:
2077 return # nothing to do
2078
2079 # We need to be able to dynamically add to the list of prefetch_related
    # lookups that we look up (see below). So we need some bookkeeping to
2081 # ensure we don't do duplicate work.
2082 done_queries = {} # dictionary of things like 'foo__bar': [results]
2083
2084 auto_lookups = set() # we add to this as we go through.
2085 followed_descriptors = set() # recursion protection
2086
2087 all_lookups = normalize_prefetch_lookups(reversed(related_lookups)) # type: ignore[arg-type]
2088 while all_lookups:
2089 lookup = all_lookups.pop()
2090 if lookup.prefetch_to in done_queries:
2091 if lookup.queryset is not None:
2092 raise ValueError(
2093 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
2094 "You may need to adjust the ordering of your lookups."
                )
            continue
2098
        # At the top level, the list of objects to decorate is the result
        # cache from the primary QuerySet; deeper levels reassign obj_list
        # as they descend.
2101 obj_list = model_instances
2102
2103 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
2104 for level, through_attr in enumerate(through_attrs):
2105 # Prepare main instances
2106 if not obj_list:
2107 break
2108
2109 prefetch_to = lookup.get_current_prefetch_to(level)
2110 if prefetch_to in done_queries:
2111 # Skip any prefetching, and any object preparation
2112 obj_list = done_queries[prefetch_to]
2113 continue
2114
2115 # Prepare objects:
2116 good_objects = True
2117 for obj in obj_list:
2118 # Since prefetching can re-use instances, it is possible to have
2119 # the same instance multiple times in obj_list, so obj might
2120 # already be prepared.
2121 if not hasattr(obj, "_prefetched_objects_cache"):
2122 try:
2123 obj._prefetched_objects_cache = {}
2124 except (AttributeError, TypeError):
2125 # Must be an immutable object from
2126 # values_list(flat=True), for example (TypeError) or
2127 # a QuerySet subclass that isn't returning Model
2128 # instances (AttributeError), either in Plain or a 3rd
2129 # party. prefetch_related() doesn't make sense, so quit.
2130 good_objects = False
2131 break
2132 if not good_objects:
2133 break
2134
2135 # Descend down tree
2136
            # We assume that the objects retrieved are homogeneous (which is
            # the premise of prefetch_related), so what applies to the first
            # object applies to all.
2139 first_obj = obj_list[0]
2140 to_attr = lookup.get_current_to_attr(level)[0]
2141 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
2142 first_obj, through_attr, to_attr
2143 )
2144
2145 if not attr_found:
2146 raise AttributeError(
2147 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
2148 "parameter to prefetch_related()"
2149 )
2150
2151 if level == len(through_attrs) - 1 and prefetcher is None:
2152 # Last one, this *must* resolve to something that supports
2153 # prefetching, otherwise there is no point adding it and the
2154 # developer asking for it has made a mistake.
2155 raise ValueError(
2156 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
2157 "prefetching - this is an invalid parameter to "
2158 "prefetch_related()."
2159 )
2160
2161 obj_to_fetch = None
2162 if prefetcher is not None:
2163 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
2164
2165 if obj_to_fetch:
2166 obj_list, additional_lookups = prefetch_one_level(
2167 obj_to_fetch,
2168 prefetcher,
2169 lookup,
2170 level,
2171 )
2172 # We need to ensure we don't keep adding lookups from the
2173 # same relationships to stop infinite recursion. So, if we
2174 # are already on an automatically added lookup, don't add
2175 # the new lookups from relationships we've seen already.
2176 if not (
2177 prefetch_to in done_queries
2178 and lookup in auto_lookups
2179 and descriptor in followed_descriptors
2180 ):
2181 done_queries[prefetch_to] = obj_list
2182 new_lookups = normalize_prefetch_lookups(
2183 reversed(additional_lookups), # type: ignore[arg-type]
2184 prefetch_to,
2185 )
2186 auto_lookups.update(new_lookups)
2187 all_lookups.extend(new_lookups)
2188 followed_descriptors.add(descriptor)
2189 else:
2190 # Either a singly related object that has already been fetched
2191 # (e.g. via select_related), or hopefully some other property
2192 # that doesn't support prefetching but needs to be traversed.
2193
2194 # We replace the current list of parent objects with the list
2195 # of related objects, filtering out empty or missing values so
2196 # that we can continue with nullable or reverse relations.
2197 new_obj_list = []
2198 for obj in obj_list:
2199 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
2200 # If related objects have been prefetched, use the
2201 # cache rather than the object's through_attr.
2202 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # type: ignore[arg-type]
2203 else:
2204 try:
2205 new_obj = getattr(obj, through_attr)
2206 except ObjectDoesNotExist:
2207 continue
2208 if new_obj is None:
2209 continue
2210 # We special-case `list` rather than something more generic
2211 # like `Iterable` because we don't want to accidentally match
2212 # user models that define __iter__.
2213 if isinstance(new_obj, list):
2214 new_obj_list.extend(new_obj)
2215 else:
2216 new_obj_list.append(new_obj)
2217 obj_list = new_obj_list
2218
2219
2220def get_prefetcher(
2221 instance: Model, through_attr: str, to_attr: str
2222) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
2223 """
2224 For the attribute 'through_attr' on the given instance, find
2225 an object that has a get_prefetch_queryset().
    Return a 4-tuple containing:
2227 (the object with get_prefetch_queryset (or None),
2228 the descriptor object representing this relationship (or None),
2229 a boolean that is False if the attribute was not found at all,
2230 a function that takes an instance and returns a boolean that is True if
2231 the attribute has already been fetched for that instance)
2232 """
2233
2234 def has_to_attr_attribute(instance: Model) -> bool:
2235 return hasattr(instance, to_attr)
2236
2237 prefetcher = None
2238 is_fetched: Callable[[Model], bool] = has_to_attr_attribute
2239
2240 # For singly related objects, we have to avoid getting the attribute
2241 # from the object, as this will trigger the query. So we first try
2242 # on the class, in order to get the descriptor object.
2243 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2244 if rel_obj_descriptor is None:
2245 attr_found = hasattr(instance, through_attr)
2246 else:
2247 attr_found = True
2248 if rel_obj_descriptor:
        # Singly related object: the descriptor itself has the
        # get_prefetch_queryset() method.
2251 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2252 prefetcher = rel_obj_descriptor
2253 is_fetched = rel_obj_descriptor.is_cached
2254 else:
            # The descriptor doesn't support prefetching, so go ahead and get
            # the attribute on the instance rather than the class, in order
            # to support managers for to-many relations.
2258 rel_obj = getattr(instance, through_attr)
2259 if hasattr(rel_obj, "get_prefetch_queryset"):
2260 prefetcher = rel_obj
2261 if through_attr != to_attr:
2262 # Special case cached_property instances because hasattr
2263 # triggers attribute computation and assignment.
2264 if isinstance(
2265 getattr(instance.__class__, to_attr, None), cached_property
2266 ):
2267
2268 def has_cached_property(instance: Model) -> bool:
2269 return to_attr in instance.__dict__
2270
2271 is_fetched = has_cached_property
2272 else:
2273
2274 def in_prefetched_cache(instance: Model) -> bool:
2275 return through_attr in instance._prefetched_objects_cache
2276
2277 is_fetched = in_prefetched_cache
2278 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2279
2280
2281def prefetch_one_level(
2282 instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
2283) -> tuple[list[Model], list[Prefetch]]:
2284 """
2285 Helper function for prefetch_related_objects().
2286
2287 Run prefetches on all instances using the prefetcher object,
2288 assigning results to relevant caches in instance.
2289
2290 Return the prefetched objects along with any additional prefetches that
2291 must be done due to prefetch_related lookups found from default managers.
2292 """
2293 # prefetcher must have a method get_prefetch_queryset() which takes a list
2294 # of instances, and returns a tuple:
2295
2296 # (queryset of instances of self.model that are related to passed in instances,
2297 # callable that gets value to be matched for returned instances,
2298 # callable that gets value to be matched for passed in instances,
2299 # boolean that is True for singly related objects,
2300 # cache or field name to assign to,
2301 # boolean that is True when the previous argument is a cache name vs a field name).
2302
2303 # The 'values to be matched' must be hashable as they will be used
2304 # in a dictionary.
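    # Sketch of that contract for a hypothetical reverse foreign key
    # "comments" pointing from Comment.article back to Article (all names
    # illustrative):
    #
    #   rel_qs        -- Comment rows whose article_id is in the instance ids
    #   rel_obj_attr  -- lambda comment: comment.article_id
    #   instance_attr -- lambda article: article.id
    #   single        -- False (a to-many relation)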
2305
2306 (
2307 rel_qs,
2308 rel_obj_attr,
2309 instance_attr,
2310 single,
2311 cache_name,
2312 is_descriptor,
2313 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2314 # We have to handle the possibility that the QuerySet we just got back
2315 # contains some prefetch_related lookups. We don't want to trigger the
2316 # prefetch_related functionality by evaluating the query. Rather, we need
2317 # to merge in the prefetch_related lookups.
2318 # Copy the lookups in case it is a Prefetch object which could be reused
2319 # later (happens in nested prefetch_related).
2320 additional_lookups = [
2321 copy.copy(additional_lookup)
2322 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2323 ]
2324 if additional_lookups:
2325 # Don't need to clone because the queryset should have given us a fresh
2326 # instance, so we access an internal instead of using public interface
2327 # for performance reasons.
2328 rel_qs._prefetch_related_lookups = ()
2329
2330 all_related_objects = list(rel_qs)
2331
2332 rel_obj_cache = {}
2333 for rel_obj in all_related_objects:
2334 rel_attr_val = rel_obj_attr(rel_obj)
2335 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2336
2337 to_attr, as_attr = lookup.get_current_to_attr(level)
2338 # Make sure `to_attr` does not conflict with a field.
2339 if as_attr and instances:
        # We assume that the objects retrieved are homogeneous (which is the
        # premise of prefetch_related), so what applies to the first object
        # applies to all.
2342 model = instances[0].__class__
2343 try:
2344 model._model_meta.get_field(to_attr)
2345 except FieldDoesNotExist:
2346 pass
2347 else:
2348 msg = "to_attr={} conflicts with a field on the {} model."
2349 raise ValueError(msg.format(to_attr, model.__name__))
2350
2351 # Whether or not we're prefetching the last part of the lookup.
2352 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2353
2354 for obj in instances:
2355 instance_attr_val = instance_attr(obj)
2356 vals = rel_obj_cache.get(instance_attr_val, [])
2357
2358 if single:
2359 val = vals[0] if vals else None
2360 if as_attr:
2361 # A to_attr has been given for the prefetch.
2362 setattr(obj, to_attr, val)
2363 elif is_descriptor:
2364 # cache_name points to a field name in obj.
2365 # This field is a descriptor for a related object.
2366 setattr(obj, cache_name, val)
2367 else:
2368 # No to_attr has been given for this prefetch operation and the
2369 # cache_name does not point to a descriptor. Store the value of
2370 # the field in the object's field cache.
2371 obj._state.fields_cache[cache_name] = val # type: ignore[assignment]
2372 else:
2373 if as_attr:
2374 setattr(obj, to_attr, vals)
2375 else:
2376 queryset = getattr(obj, to_attr)
2377 if leaf and lookup.queryset is not None:
2378 qs = queryset._apply_rel_filters(lookup.queryset)
2379 else:
2380 # Check if queryset is a QuerySet or a related manager
2381 # We need a QuerySet instance to cache the prefetched values
2382 if isinstance(queryset, QuerySet):
2383 # It's already a QuerySet, create a new instance
2384 qs = queryset.__class__.from_model(queryset.model)
2385 else:
2386 # It's a related manager, get its QuerySet
2387 # The manager's query property returns a properly filtered QuerySet
2388 qs = queryset.query
2389 qs._result_cache = vals
2390 # We don't want the individual qs doing prefetch_related now,
2391 # since we have merged this into the current work.
2392 qs._prefetch_done = True
2393 obj._prefetched_objects_cache[cache_name] = qs
2394 return all_related_objects, additional_lookups
2395
2396
2397class RelatedPopulator:
2398 """
2399 RelatedPopulator is used for select_related() object instantiation.
2400
2401 The idea is that each select_related() model will be populated by a
2402 different RelatedPopulator instance. The RelatedPopulator instances get
    klass_info and select (computed in SQLCompiler) as input for
    initialization. That data is used to compute which columns
2405 to use, how to instantiate the model, and how to populate the links
2406 between the objects.
2407
    The actual creation of the objects is done in the populate() method. This
2409 method gets row and from_obj as input and populates the select_related()
2410 model instance.
2411 """
2412
2413 def __init__(self, klass_info: dict[str, Any], select: list[Any]):
2414 # Pre-compute needed attributes. The attributes are:
2415 # - model_cls: the possibly deferred model class to instantiate
2416 # - either:
2417 # - cols_start, cols_end: usually the columns in the row are
2418 # in the same order model_cls.__init__ expects them, so we
2419 # can instantiate by model_cls(*row[cols_start:cols_end])
2420 # - reorder_for_init: When select_related descends to a child
2421 # class, then we want to reuse the already selected parent
2422 # data. However, in this case the parent data isn't necessarily
2423 # in the same order that Model.__init__ expects it to be, so
2424 # we have to reorder the parent data. The reorder_for_init
2425 # attribute contains a function used to reorder the field data
2426 # in the order __init__ expects it.
2427 # - id_idx: the index of the primary key field in the reordered
2428 # model data. Used to check if a related object exists at all.
2429 # - init_list: the field attnames fetched from the database. For
2430 # deferred models this isn't the same as all attnames of the
2431 # model's fields.
2432 # - related_populators: a list of RelatedPopulator instances if
2433 # select_related() descends to related models from this model.
2434 # - local_setter, remote_setter: Methods to set cached values on
2435 # the object being populated and on the remote object. Usually
2436 # these are Field.set_cached_value() methods.
2437 select_fields = klass_info["select_fields"]
2438
2439 self.cols_start = select_fields[0]
2440 self.cols_end = select_fields[-1] + 1
2441 self.init_list = [
2442 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2443 ]
2444 self.reorder_for_init = None
2445
2446 self.model_cls = klass_info["model"]
2447 self.id_idx = self.init_list.index("id")
2448 self.related_populators = get_related_populators(klass_info, select)
2449 self.local_setter = klass_info["local_setter"]
2450 self.remote_setter = klass_info["remote_setter"]
2451
2452 def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
2453 if self.reorder_for_init:
2454 obj_data = self.reorder_for_init(row)
2455 else:
2456 obj_data = row[self.cols_start : self.cols_end]
2457 if obj_data[self.id_idx] is None:
2458 obj = None
2459 else:
2460 obj = self.model_cls.from_db(self.init_list, obj_data)
2461 for rel_iter in self.related_populators:
2462 rel_iter.populate(row, obj)
2463 self.local_setter(from_obj, obj)
2464 if obj is not None:
2465 self.remote_setter(obj, from_obj)
2466
2467
2468def get_related_populators(
2469 klass_info: dict[str, Any], select: list[Any]
2470) -> list[RelatedPopulator]:
2471 iterators = []
2472 related_klass_infos = klass_info.get("related_klass_infos", [])
2473 for rel_klass_info in related_klass_infos:
2474 rel_cls = RelatedPopulator(rel_klass_info, select)
2475 iterators.append(rel_cls)
2476 return iterators