1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5from __future__ import annotations
6
7import copy
8import operator
9import warnings
10from collections.abc import Callable, Iterator
11from functools import cached_property
12from itertools import chain, islice
13from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar
14
15import plain.runtime
16from plain.exceptions import ValidationError
17from plain.models import (
18 sql,
19 transaction,
20)
21from plain.models.constants import LOOKUP_SEP, OnConflict
22from plain.models.db import (
23 PLAIN_VERSION_PICKLE_KEY,
24 IntegrityError,
25 NotSupportedError,
26 db_connection,
27)
28from plain.models.exceptions import (
29 FieldDoesNotExist,
30 FieldError,
31 ObjectDoesNotExist,
32)
33from plain.models.expressions import Case, F, Value, When
34from plain.models.fields import (
35 DateField,
36 DateTimeField,
37 Field,
38 PrimaryKeyField,
39)
40from plain.models.functions import Cast, Trunc
41from plain.models.query_utils import FilteredRelation, Q
42from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
43from plain.models.utils import (
44 create_namedtuple_class,
45 resolve_callables,
46)
47from plain.utils import timezone
48from plain.utils.functional import partition
49
50if TYPE_CHECKING:
51 from datetime import tzinfo
52
53 from plain.models import Model
54
55# Type variable for QuerySet generic
56T = TypeVar("T", bound="Model")

# The maximum number of results to fetch in a get() query.
MAX_GET_RESULTS = 21

# The maximum number of items to display in a QuerySet.__repr__
REPR_OUTPUT_SIZE = 20


class BaseIterable:
    def __init__(
        self,
        queryset: QuerySet[Any],
        chunked_fetch: bool = False,
        chunk_size: int = GET_ITERATOR_CHUNK_SIZE,
    ):
        self.queryset = queryset
        self.chunked_fetch = chunked_fetch
        self.chunk_size = chunk_size


class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""

    def __iter__(self) -> Iterator[Model]:  # type: ignore[misc]
        queryset = self.queryset
        compiler = queryset.sql_query.get_compiler()
        # Execute the query. This will also fill compiler.select, klass_info,
        # and annotations.
        results = compiler.execute_sql(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        select, klass_info, annotation_col_map = (
            compiler.select,
            compiler.klass_info,
            compiler.annotation_col_map,
        )
        model_cls = klass_info["model"]
        select_fields = klass_info["select_fields"]
        model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
        init_list = [
            f[0].target.attname for f in select[model_fields_start:model_fields_end]
        ]
        related_populators = get_related_populators(klass_info, select)
        known_related_objects = [
            (
                field,
                related_objs,
                operator.attrgetter(field.attname),
            )
            for field, related_objs in queryset._known_related_objects.items()
        ]
        for row in compiler.results_iter(results):
            obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
            for rel_populator in related_populators:
                rel_populator.populate(row, obj)
            if annotation_col_map:
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            # Add the known related objects to the model.
            for field, rel_objs, rel_getter in known_related_objects:
                # Avoid overwriting objects loaded by, e.g., select_related().
                if field.is_cached(obj):
                    continue
                rel_obj_id = rel_getter(obj)
                try:
                    rel_obj = rel_objs[rel_obj_id]
                except KeyError:
                    pass  # May happen in qs1 | qs2 scenarios.
                else:
                    setattr(obj, field.name, rel_obj)

            yield obj


class RawModelIterable(BaseIterable):
    """
    Iterable that yields a model instance for each row from a raw queryset.
    """

    def __iter__(self) -> Iterator[Model]:  # type: ignore[misc]
        # Cache some things for performance reasons outside the loop.
        query = self.queryset.sql_query
        compiler = db_connection.ops.compiler("SQLCompiler")(query, db_connection)
        query_iterator = iter(query)

        try:
            (
                model_init_names,
                model_init_pos,
                annotation_fields,
            ) = self.queryset.resolve_model_init_order()
            model_cls = self.queryset.model
            if "id" not in model_init_names:
                raise FieldDoesNotExist("Raw query must include the primary key")
            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
            converters = compiler.get_converters(
                [
                    f.get_col(f.model.model_options.db_table) if f else None
                    for f in fields
                ]
            )
            if converters:
                query_iterator = compiler.apply_converters(query_iterator, converters)
            for values in query_iterator:
                # Associate fields to values
                model_init_values = [values[pos] for pos in model_init_pos]
                instance = model_cls.from_db(model_init_names, model_init_values)
                if annotation_fields:
                    for column, pos in annotation_fields:
                        setattr(instance, column, values[pos])
                yield instance
        finally:
            # Done iterating the Query. If it has its own cursor, close it.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()


class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict for each row.
    """

    def __iter__(self) -> Iterator[dict[str, Any]]:  # type: ignore[misc]
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()

        # extra(select=...) cols are always at the start of the row.
        names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]
        indexes = range(len(names))
        for row in compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        ):
            yield {names[i]: row[i] for i in indexes}


class ValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=False) that yields a tuple
    for each row.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:  # type: ignore[misc]
        queryset = self.queryset
        query = queryset.sql_query
        compiler = query.get_compiler()

        if queryset._fields:
            # extra(select=...) cols are always at the start of the row.
            names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            fields = [
                *queryset._fields,
                *(f for f in query.annotation_select if f not in queryset._fields),
            ]
            if fields != names:
                # Reorder according to fields.
                index_map = {name: idx for idx, name in enumerate(names)}
                rowfactory = operator.itemgetter(*[index_map[f] for f in fields])
                return map(
                    rowfactory,
                    compiler.results_iter(
                        chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
                    ),
                )
        return compiler.results_iter(
            tuple_expected=True,
            chunked_fetch=self.chunked_fetch,
            chunk_size=self.chunk_size,
        )


class NamedValuesListIterable(ValuesListIterable):
    """
    Iterable returned by QuerySet.values_list(named=True) that yields a
    namedtuple for each row.
    """

    def __iter__(self) -> Iterator[tuple[Any, ...]]:  # type: ignore[misc]
        queryset = self.queryset
        if queryset._fields:
            names = queryset._fields
        else:
            query = queryset.sql_query
            names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
        tuple_class = create_namedtuple_class(*names)
        new = tuple.__new__
        for row in super().__iter__():
            yield new(tuple_class, row)


class FlatValuesListIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values_list(flat=True) that yields single
    values.
    """

    def __iter__(self) -> Iterator[Any]:  # type: ignore[misc]
        queryset = self.queryset
        compiler = queryset.sql_query.get_compiler()
        for row in compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        ):
            yield row[0]


class QuerySet(Generic[T]):
    """
    Represent a lazy database lookup for a set of objects.

    Usage:
        MyModel.query.filter(name="test").all()

    Custom QuerySets:
        from typing import Self

        class TaskQuerySet(QuerySet["Task"]):
            def active(self) -> Self:
                return self.filter(is_active=True)

        class Task(Model):
            is_active = BooleanField(default=True)
            query = TaskQuerySet()

        Task.query.active().filter(name="test")  # Full type inference

    Custom methods should return `Self` to preserve type through method chaining.
    """

    # Instance attributes (set in from_model())
    model: type[T]
    _query: sql.Query
    _result_cache: list[T] | None
    _sticky_filter: bool
    _for_write: bool
    _prefetch_related_lookups: tuple[Any, ...]
    _prefetch_done: bool
    _known_related_objects: dict[Any, dict[Any, Any]]
    _iterable_class: type[BaseIterable]
    _fields: tuple[str, ...] | None
    _defer_next_filter: bool
    _deferred_filter: tuple[bool, tuple[Any, ...], dict[str, Any]] | None

    def __init__(self):
        """Minimal init for descriptor mode. Use from_model() to create instances."""
        pass

    @classmethod
    def from_model(cls, model: type[T], query: sql.Query | None = None) -> Self:
        """Create a QuerySet instance bound to a model."""
        instance = cls()
        instance.model = model
        instance._query = query or sql.Query(model)
        instance._result_cache = None
        instance._sticky_filter = False
        instance._for_write = False
        instance._prefetch_related_lookups = ()
        instance._prefetch_done = False
        instance._known_related_objects = {}
        instance._iterable_class = ModelIterable
        instance._fields = None
        instance._defer_next_filter = False
        instance._deferred_filter = None
        return instance

    def __get__(self, instance: Any, owner: type[T]) -> Self:
        """Descriptor protocol - return a new QuerySet bound to the model."""
        if instance is not None:
            raise AttributeError(
                f"QuerySet is only accessible from the model class, not instances. "
                f"Use {owner.__name__}.query instead."
            )
        return self.from_model(owner)

    @property
    def sql_query(self) -> sql.Query:
        if self._deferred_filter:
            negate, args, kwargs = self._deferred_filter
            self._filter_or_exclude_inplace(negate, args, kwargs)
            self._deferred_filter = None
        return self._query

    @sql_query.setter
    def sql_query(self, value: sql.Query) -> None:
        if value.values_select:
            self._iterable_class = ValuesIterable
        self._query = value

    ########################
    # PYTHON MAGIC METHODS #
    ########################

    def __deepcopy__(self, memo: dict[int, Any]) -> QuerySet[T]:
        """Don't populate the QuerySet's cache."""
        obj = self.__class__.from_model(self.model)
        for k, v in self.__dict__.items():
            if k == "_result_cache":
                obj.__dict__[k] = None
            else:
                obj.__dict__[k] = copy.deepcopy(v, memo)
        return obj

    def __getstate__(self) -> dict[str, Any]:
        # Force the cache to be fully populated.
        self._fetch_all()
        return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}

    def __setstate__(self, state: dict[str, Any]) -> None:
        pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
        if pickled_version:
            if pickled_version != plain.runtime.__version__:
                warnings.warn(
                    f"Pickled queryset instance's Plain version {pickled_version} does not "
                    f"match the current version {plain.runtime.__version__}.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        else:
            warnings.warn(
                "Pickled queryset instance's Plain version is not specified.",
                RuntimeWarning,
                stacklevel=2,
            )
        self.__dict__.update(state)

    def __repr__(self) -> str:
        data = list(self[: REPR_OUTPUT_SIZE + 1])
        if len(data) > REPR_OUTPUT_SIZE:
            data[-1] = "...(remaining elements truncated)..."
        return f"<{self.__class__.__name__} {data!r}>"

    def __len__(self) -> int:
        self._fetch_all()
        return len(self._result_cache)  # type: ignore[arg-type]

    def __iter__(self) -> Iterator[T]:
        """
        The queryset iterator protocol uses three nested iterators in the
        default case:
            1. sql.compiler.execute_sql()
               - Returns 100 rows at a time (constants.GET_ITERATOR_CHUNK_SIZE)
                 using cursor.fetchmany(). This part is responsible for
                 doing some column masking, and returning the rows in chunks.
            2. sql.compiler.results_iter()
               - Returns one row at a time. At this point the rows are still
                 just tuples. In some cases the return values are converted to
                 Python values at this location.
            3. self.iterator()
               - Responsible for turning the rows into model objects.
        """
        self._fetch_all()
        return iter(self._result_cache)  # type: ignore[arg-type]

    def __bool__(self) -> bool:
        self._fetch_all()
        return bool(self._result_cache)

    def __getitem__(self, k: int | slice) -> T | QuerySet[T]:
        """Retrieve an item or slice from the set of results."""
        if not isinstance(k, int | slice):
            raise TypeError(
                f"QuerySet indices must be integers or slices, not {type(k).__name__}."
            )
        if (isinstance(k, int) and k < 0) or (
            isinstance(k, slice)
            and (
                (k.start is not None and k.start < 0)
                or (k.stop is not None and k.stop < 0)
            )
        ):
            raise ValueError("Negative indexing is not supported.")

        if self._result_cache is not None:
            return self._result_cache[k]

        if isinstance(k, slice):
            qs = self._chain()
            if k.start is not None:
                start = int(k.start)
            else:
                start = None
            if k.stop is not None:
                stop = int(k.stop)
            else:
                stop = None
            qs.sql_query.set_limits(start, stop)
            return list(qs)[:: k.step] if k.step else qs  # type: ignore[return-value]

        qs = self._chain()
        qs.sql_query.set_limits(k, k + 1)  # type: ignore[unsupported-operator]
        qs._fetch_all()
        return qs._result_cache[0]

    def __class_getitem__(cls, *args: Any, **kwargs: Any) -> type[QuerySet[Any]]:
        return cls

    def __and__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "&")
        self._merge_sanity_check(other)
        if isinstance(other, EmptyQuerySet):
            return other
        if isinstance(self, EmptyQuerySet):
            return self
        combined = self._chain()
        combined._merge_known_related_objects(other)
        combined.sql_query.combine(other.sql_query, sql.AND)
        return combined

    def __or__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "|")
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.OR)
        return combined

    def __xor__(self, other: QuerySet[T]) -> QuerySet[T]:
        self._check_operator_queryset(other, "^")
        self._merge_sanity_check(other)
        if isinstance(self, EmptyQuerySet):
            return other
        if isinstance(other, EmptyQuerySet):
            return self
        query = (
            self
            if self.sql_query.can_filter()
            else self.model._model_meta.base_queryset.filter(id__in=self.values("id"))
        )
        combined = query._chain()
        combined._merge_known_related_objects(other)
        if not other.sql_query.can_filter():
            other = other.model._model_meta.base_queryset.filter(
                id__in=other.values("id")
            )
        combined.sql_query.combine(other.sql_query, sql.XOR)
        return combined

    ####################################
    # METHODS THAT DO DATABASE QUERIES #
    ####################################

    def _iterator(self, use_chunked_fetch: bool, chunk_size: int | None) -> Iterator[T]:
        iterable = self._iterable_class(
            self,
            chunked_fetch=use_chunked_fetch,
            chunk_size=chunk_size or 2000,
        )
        if not self._prefetch_related_lookups or chunk_size is None:
            yield from iterable
            return

        iterator = iter(iterable)
        while results := list(islice(iterator, chunk_size)):
            prefetch_related_objects(results, *self._prefetch_related_lookups)
            yield from results

    def iterator(self, chunk_size: int | None = None) -> Iterator[T]:
        """
        An iterator over the results from applying this QuerySet to the
        database. chunk_size must be provided for QuerySets that prefetch
        related objects. Otherwise, a default chunk_size of 2000 is supplied.
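
        A usage sketch (``Task`` and ``process`` are hypothetical; rows are
        streamed instead of being cached on the QuerySet):

            for task in Task.query.filter(is_active=True).iterator(chunk_size=500):
                process(task)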
        """
        if chunk_size is None:
            if self._prefetch_related_lookups:
                raise ValueError(
                    "chunk_size must be provided when using QuerySet.iterator() after "
                    "prefetch_related()."
                )
        elif chunk_size <= 0:
            raise ValueError("Chunk size must be strictly positive.")
        use_chunked_fetch = not db_connection.settings_dict.get(
            "DISABLE_SERVER_SIDE_CURSORS"
        )
        return self._iterator(use_chunked_fetch, chunk_size)

    def aggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """
        Return a dictionary containing the calculations (aggregation)
        over the current queryset.

        If args are present, each expression is passed as a kwarg using
        the Aggregate object's default alias.
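
        A sketch, assuming a ``Count`` aggregate is available from
        plain.models (``Task`` is a hypothetical model):

            Task.query.aggregate(total=Count("id"))
            # -> {"total": 42}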
        """
        if self.sql_query.distinct_fields:
            raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
        self._validate_values_are_expressions(
            (*args, *kwargs.values()), method_name="aggregate"
        )
        for arg in args:
            # The default_alias property raises TypeError if default_alias
            # can't be set automatically or AttributeError if it isn't an
            # attribute.
            try:
                arg.default_alias
            except (AttributeError, TypeError):
                raise TypeError("Complex aggregates require an alias")
            kwargs[arg.default_alias] = arg

        return self.sql_query.chain().get_aggregation(kwargs)

    def count(self) -> int:
        """
        Perform a SELECT COUNT() and return the number of records as an
        integer.

        If the QuerySet is already fully cached, return the length of the
        cached results set to avoid multiple SELECT COUNT(*) calls.
        """
        if self._result_cache is not None:
            return len(self._result_cache)

        return self.sql_query.get_count()

    def get(self, *args: Any, **kwargs: Any) -> T:
        """
        Perform the query and return a single object matching the given
        keyword arguments.
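
        Raises ``DoesNotExist`` if no row matches and
        ``MultipleObjectsReturned`` if several do. A sketch (``Task`` is a
        hypothetical model):

            task = Task.query.get(id=1)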
        """
        if self.sql_query.combinator and (args or kwargs):
            raise NotSupportedError(
                f"Calling QuerySet.get(...) with filters after {self.sql_query.combinator}() is not "
                "supported."
            )
        clone = (
            self._chain() if self.sql_query.combinator else self.filter(*args, **kwargs)
        )
        if self.sql_query.can_filter() and not self.sql_query.distinct_fields:
            clone = clone.order_by()
        limit = None
        if (
            not clone.sql_query.select_for_update
            or db_connection.features.supports_select_for_update_with_limit
        ):
            limit = MAX_GET_RESULTS
            clone.sql_query.set_limits(high=limit)
        num = len(clone)
        if num == 1:
            return clone._result_cache[0]
        if not num:
            raise self.model.DoesNotExist(
                f"{self.model.model_options.object_name} matching query does not exist."
            )
        raise self.model.MultipleObjectsReturned(
            "get() returned more than one {} -- it returned {}!".format(
                self.model.model_options.object_name,
                num if not limit or num < limit else "more than %s" % (limit - 1),
            )
        )

    def get_or_none(self, *args: Any, **kwargs: Any) -> T | None:
        """
        Perform the query and return a single object matching the given
        keyword arguments, or None if no object is found.
        """
        try:
            return self.get(*args, **kwargs)
        except self.model.DoesNotExist:  # type: ignore[attr-defined]
            return None

    def create(self, **kwargs: Any) -> T:
        """
        Create a new object with the given kwargs, saving it to the database
        and returning the created object.
        """
        obj = self.model(**kwargs)  # type: ignore[misc]
        self._for_write = True
        obj.save(force_insert=True)  # type: ignore[attr-defined]
        return obj

    def _prepare_for_bulk_create(self, objs: list[T]) -> None:
        id_field = self.model._model_meta.get_field("id")
        for obj in objs:
            if obj.id is None:  # type: ignore[attr-defined]
                # Populate new primary key values.
                obj.id = id_field.get_id_value_on_save(obj)  # type: ignore[attr-defined]
            obj._prepare_related_fields_for_save(operation_name="bulk_create")  # type: ignore[attr-defined]

    def _check_bulk_create_options(
        self,
        update_conflicts: bool,
        update_fields: list[Field] | None,
        unique_fields: list[Field] | None,
    ) -> OnConflict | None:
        db_features = db_connection.features
        if update_conflicts:
            if not db_features.supports_update_conflicts:
                raise NotSupportedError(
                    "This database backend does not support updating conflicts."
                )
            if not update_fields:
                raise ValueError(
                    "Fields that will be updated when a row insertion fails "
                    "on conflicts must be provided."
                )
            if unique_fields and not db_features.supports_update_conflicts_with_target:
                raise NotSupportedError(
                    "This database backend does not support updating "
                    "conflicts with specifying unique fields that can trigger "
                    "the upsert."
                )
            if not unique_fields and db_features.supports_update_conflicts_with_target:
                raise ValueError(
                    "Unique fields that can trigger the upsert must be provided."
                )
            # Updating primary keys and non-concrete fields is forbidden.
            if any(not f.concrete or f.many_to_many for f in update_fields):
                raise ValueError(
                    "bulk_create() can only be used with concrete fields in "
                    "update_fields."
                )
            if any(f.primary_key for f in update_fields):
                raise ValueError(
                    "bulk_create() cannot be used with primary keys in update_fields."
                )
            if unique_fields:
                if any(not f.concrete or f.many_to_many for f in unique_fields):
                    raise ValueError(
                        "bulk_create() can only be used with concrete fields "
                        "in unique_fields."
                    )
            return OnConflict.UPDATE
        return None

    def bulk_create(
        self,
        objs: list[T],
        batch_size: int | None = None,
        update_conflicts: bool = False,
        update_fields: list[str] | None = None,
        unique_fields: list[str] | None = None,
    ) -> list[T]:
        """
        Insert each of the instances into the database. Do *not* call
        save() on each of the instances, and do not set the primary key
        attribute if it is an autoincrement field (except if
        features.can_return_rows_from_bulk_insert=True).
        Multi-table models are not supported.
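
        A hedged upsert sketch (assumes a hypothetical ``Product`` model with
        a unique ``sku`` field; backend support for conflict updates varies):

            Product.query.bulk_create(
                [Product(sku="a1", price=10), Product(sku="b2", price=20)],
                update_conflicts=True,
                unique_fields=["sku"],
                update_fields=["price"],
            )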
        """
        # When you bulk insert you don't get the primary keys back (if it's an
        # autoincrement, except if can_return_rows_from_bulk_insert=True), so
        # you can't insert into the child tables that reference this. There
        # are two workarounds:
        # 1) This could be implemented if you didn't have an autoincrement pk
        # 2) You could do it by doing O(n) normal inserts into the parent
        #    tables to get the primary keys back and then doing a single bulk
        #    insert into the childmost table.
        # We currently set the primary keys on the objects when using
        # PostgreSQL via the RETURNING ID clause. It should be possible for
        # Oracle as well, but the semantics for extracting the primary keys
        # are trickier so it's not done yet.
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        if not objs:
            return objs
        meta = self.model._model_meta
        if unique_fields:
            unique_fields = [meta.get_field(name) for name in unique_fields]
        if update_fields:
            update_fields = [meta.get_field(name) for name in update_fields]
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields,
            unique_fields,
        )
        self._for_write = True
        fields = meta.concrete_fields
        objs = list(objs)
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            objs_with_id, objs_without_id = partition(lambda o: o.id is None, objs)
            if objs_with_id:
                returned_columns = self._batched_insert(
                    objs_with_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields,
                    unique_fields=unique_fields,
                )
                id_field = meta.get_field("id")
                for obj_with_id, results in zip(objs_with_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        if field != id_field:
                            setattr(obj_with_id, field.attname, result)
                for obj_with_id in objs_with_id:
                    obj_with_id._state.adding = False
            if objs_without_id:
                fields = [f for f in fields if not isinstance(f, PrimaryKeyField)]
                returned_columns = self._batched_insert(
                    objs_without_id,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields,
                    unique_fields=unique_fields,
                )
                if (
                    db_connection.features.can_return_rows_from_bulk_insert
                    and on_conflict is None
                ):
                    assert len(returned_columns) == len(objs_without_id)
                for obj_without_id, results in zip(objs_without_id, returned_columns):
                    for result, field in zip(results, meta.db_returning_fields):
                        setattr(obj_without_id, field.attname, result)
                    obj_without_id._state.adding = False

        return objs

    def bulk_update(
        self, objs: list[T], fields: list[str], batch_size: int | None = None
    ) -> int:
        """
        Update the given fields in each of the given objects in the database.
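
        A sketch (``Task`` is a hypothetical model; every object must already
        have a primary key):

            tasks = list(Task.query.filter(is_active=True))
            for task in tasks:
                task.priority += 1
            Task.query.bulk_update(tasks, ["priority"])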
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")
        if not fields:
            raise ValueError("Field names must be given to bulk_update().")
        objs_tuple = tuple(objs)
        if any(obj.id is None for obj in objs_tuple):  # type: ignore[attr-defined]
            raise ValueError("All bulk_update() objects must have a primary key set.")
        fields_list = [self.model._model_meta.get_field(name) for name in fields]
        if any(not f.concrete or f.many_to_many for f in fields_list):
            raise ValueError("bulk_update() can only be used with concrete fields.")
        if any(f.primary_key for f in fields_list):
            raise ValueError("bulk_update() cannot be used with primary key fields.")
        if not objs_tuple:
            return 0
        for obj in objs_tuple:
            obj._prepare_related_fields_for_save(  # type: ignore[attr-defined]
                operation_name="bulk_update", fields=fields_list
            )
        # PK is used twice in the resulting update query, once in the filter
        # and once in the WHEN. Each field will also have one CAST.
        self._for_write = True
        max_batch_size = db_connection.ops.bulk_batch_size(
            ["id", "id"] + fields_list, objs_tuple
        )
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        requires_casting = db_connection.features.requires_casted_case_in_updates
        batches = (
            objs_tuple[i : i + batch_size]
            for i in range(0, len(objs_tuple), batch_size)
        )
        updates = []
        for batch_objs in batches:
            update_kwargs = {}
            for field in fields_list:
                when_statements = []
                for obj in batch_objs:
                    attr = getattr(obj, field.attname)
                    if not hasattr(attr, "resolve_expression"):
                        attr = Value(attr, output_field=field)
                    when_statements.append(When(id=obj.id, then=attr))  # type: ignore[attr-defined]
                case_statement = Case(*when_statements, output_field=field)
                if requires_casting:
                    case_statement = Cast(case_statement, output_field=field)
                update_kwargs[field.attname] = case_statement
            updates.append(([obj.id for obj in batch_objs], update_kwargs))  # type: ignore[attr-defined,misc]
        rows_updated = 0
        queryset = self._chain()
        with transaction.atomic(savepoint=False):
            for ids, update_kwargs in updates:
                rows_updated += queryset.filter(id__in=ids).update(**update_kwargs)
        return rows_updated

    def get_or_create(
        self, defaults: dict[str, Any] | None = None, **kwargs: Any
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.
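
        A sketch (``Tag`` is a hypothetical model; ``defaults`` is only used
        when a new row has to be created):

            tag, created = Tag.query.get_or_create(
                name="urgent", defaults={"color": "red"}
            )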
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:  # type: ignore[attr-defined]
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the database.
                # The main thing we're concerned about is uniqueness failures,
                # but ValidationError could include other things too.
                # In all cases though it should be fine to try the get() again
                # and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:  # type: ignore[attr-defined]
                    pass
                raise

    def update_or_create(
        self,
        defaults: dict[str, Any] | None = None,
        create_defaults: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> tuple[T, bool]:
        """
        Look up an object with the given kwargs, updating one with defaults
        if it exists, otherwise create a new one. Optionally, an object can
        be created with different values than defaults by using
        create_defaults.
        Return a tuple (object, created), where created is a boolean
        specifying whether an object was created.
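
        A sketch (``Counter`` is a hypothetical model):

            counter, created = Counter.query.update_or_create(
                key="visits",
                defaults={"value": 1},
            )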
        """
        if create_defaults is None:
            update_defaults = create_defaults = defaults or {}
        else:
            update_defaults = defaults or {}
        self._for_write = True
        with transaction.atomic():
            # Lock the row so that a concurrent update is blocked until
            # update_or_create() has performed its save.
            obj, created = self.select_for_update().get_or_create(
                create_defaults, **kwargs
            )
            if created:
                return obj, created
            for k, v in resolve_callables(update_defaults):
                setattr(obj, k, v)

            update_fields = set(update_defaults)
            concrete_field_names = self.model._model_meta._non_pk_concrete_field_names
            # update_fields does not support non-concrete fields.
            if concrete_field_names.issuperset(update_fields):
                # Add fields which are set on pre_save(), e.g. auto_now fields.
                # This is to maintain backward compatibility as these fields
                # are not updated unless explicitly specified in the
                # update_fields list.
                for field in self.model._model_meta.local_concrete_fields:
                    if not (
                        field.primary_key or field.__class__.pre_save is Field.pre_save
                    ):
                        update_fields.add(field.name)
                        if field.name != field.attname:
                            update_fields.add(field.attname)
                obj.save(update_fields=update_fields)  # type: ignore[attr-defined]
            else:
                obj.save()  # type: ignore[attr-defined]
        return obj, False

    def _extract_model_params(
        self, defaults: dict[str, Any] | None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Prepare `params` for creating a model instance based on the given
        kwargs; for use by get_or_create().
        """
        defaults = defaults or {}
        params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
        params.update(defaults)
        property_names = self.model._model_meta._property_names
        invalid_params = []
        for param in params:
            try:
                self.model._model_meta.get_field(param)
            except FieldDoesNotExist:
                # It's okay to use a model's property if it has a setter.
                if not (param in property_names and getattr(self.model, param).fset):
                    invalid_params.append(param)
        if invalid_params:
            raise FieldError(
                "Invalid field name(s) for model {}: '{}'.".format(
                    self.model.model_options.object_name,
                    "', '".join(sorted(invalid_params)),
                )
            )
        return params

    def first(self) -> T | None:
        """Return the first object of a query or None if no match is found."""
        for obj in self[:1]:
            return obj
        return None

    def last(self) -> T | None:
        """Return the last object of a query or None if no match is found."""
        queryset = self.reverse()
        for obj in queryset[:1]:
            return obj
        return None

    def in_bulk(
        self, id_list: list[Any] | None = None, *, field_name: str = "id"
    ) -> dict[Any, T]:
        """
        Return a dictionary mapping each of the given IDs to the object with
        that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
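
        A sketch (``Task`` is a hypothetical model):

            Task.query.in_bulk([1, 2, 3])
            # -> {1: <Task 1>, 2: <Task 2>, 3: <Task 3>}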
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
        meta = self.model._model_meta
        unique_fields = [
            constraint.fields[0]
            for constraint in self.model.model_options.total_unique_constraints
            if len(constraint.fields) == 1
        ]
        if (
            field_name != "id"
            and not meta.get_field(field_name).primary_key
            and field_name not in unique_fields
            and self.sql_query.distinct_fields != (field_name,)
        ):
            raise ValueError(
                f"in_bulk()'s field_name must be a unique field but {field_name!r} isn't."
            )
        if id_list is not None:
            if not id_list:
                return {}
            filter_key = f"{field_name}__in"
            batch_size = db_connection.features.max_query_params
            id_list_tuple = tuple(id_list)
            # If the database has a limit on the number of query parameters
            # (e.g. SQLite), retrieve objects in batches if necessary.
            if batch_size and batch_size < len(id_list_tuple):
                qs: tuple[T, ...] = ()
                for offset in range(0, len(id_list_tuple), batch_size):
                    batch = id_list_tuple[offset : offset + batch_size]
                    qs += tuple(self.filter(**{filter_key: batch}))
            else:
                qs = self.filter(**{filter_key: id_list_tuple})
        else:
            qs = self._chain()
        return {getattr(obj, field_name): obj for obj in qs}

    def delete(self) -> tuple[int, dict[str, int]]:
        """Delete the records in the current QuerySet."""
        self._not_support_combined_queries("delete")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with delete().")
        if self.sql_query.distinct or self.sql_query.distinct_fields:
            raise TypeError("Cannot call delete() after .distinct().")
        if self._fields is not None:
            raise TypeError("Cannot call delete() after .values() or .values_list()")

        del_query = self._chain()

        # The delete is actually 2 queries - one to find related objects,
        # and one to delete. Make sure that the discovery of related
        # objects is performed on the same database as the deletion.
        del_query._for_write = True

        # Disable non-supported fields.
        del_query.sql_query.select_for_update = False
        del_query.sql_query.select_related = False
        del_query.sql_query.clear_ordering(force=True)

        from plain.models.deletion import Collector

        collector = Collector(origin=self)
        collector.collect(del_query)
        deleted, _rows_count = collector.delete()

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
        return deleted, _rows_count

    def _raw_delete(self) -> int:
        """
        Delete objects found from the given queryset in a single direct SQL
        query. No signals are sent and there is no protection for cascades.
        """
        query = self.sql_query.clone()
        query.__class__ = sql.DeleteQuery
        cursor = query.get_compiler().execute_sql(CURSOR)
        if cursor:
            with cursor:
                return cursor.rowcount
        return 0

    def update(self, **kwargs: Any) -> int:
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.
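
        The return value is the number of rows the database reports as
        affected. A sketch (``Task`` is a hypothetical model):

            rows = Task.query.filter(is_active=False).update(archived=True)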
        """
        self._not_support_combined_queries("update")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        self._for_write = True
        query = self.sql_query.chain(sql.UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            if isinstance(alias, str) and alias.startswith("-"):
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                if getattr(annotation, "contains_aggregate", False):
                    raise FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        self._result_cache = None
        return rows

    def _update(self, values: list[tuple[Field, Any, Any]]) -> int:
        """
        A version of update() that accepts field objects instead of field names.
        Used primarily for model saving and not intended for use by general
        code (it requires too much poking around at model internals to be
        useful at that level).
        """
        if self.sql_query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        query = self.sql_query.chain(sql.UpdateQuery)
        query.add_update_fields(values)
        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        self._result_cache = None
        return query.get_compiler().execute_sql(CURSOR)

    def exists(self) -> bool:
        """
        Return True if the QuerySet would have any results, False otherwise.
        """
        if self._result_cache is None:
            return self.sql_query.has_results()
        return bool(self._result_cache)

    def contains(self, obj: T) -> bool:
        """
        Return True if the QuerySet contains the provided obj,
        False otherwise.
        """
        self._not_support_combined_queries("contains")
        if self._fields is not None:
            raise TypeError(
                "Cannot call QuerySet.contains() after .values() or .values_list()."
            )
        try:
            if obj.__class__ != self.model:
                return False
        except AttributeError:
            raise TypeError("'obj' must be a model instance.")
        if obj.id is None:  # type: ignore[attr-defined]
            raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
        if self._result_cache is not None:
            return obj in self._result_cache
        return self.filter(id=obj.id).exists()  # type: ignore[attr-defined]

    def _prefetch_related_objects(self) -> None:
        # This method can only be called once the result cache has been filled.
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def explain(self, *, format: str | None = None, **options: Any) -> str:
        """
        Run an EXPLAIN on the SQL query this QuerySet would perform, and
        return the results.
        """
        return self.sql_query.explain(format=format, **options)

    ##################################################
    # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
    ##################################################

    def raw(
        self,
        raw_query: str,
        params: tuple[Any, ...] = (),
        translations: dict[str, str] | None = None,
    ) -> RawQuerySet:
        qs = RawQuerySet(
            raw_query,
            model=self.model,
            params=params,
            translations=translations,
        )
        qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return qs

    def _values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
        clone = self._chain()
        if expressions:
            clone = clone.annotate(**expressions)
        clone._fields = fields  # type: ignore[assignment]
        clone.sql_query.set_values(fields)
        return clone

    def values(self, *fields: str, **expressions: Any) -> QuerySet[Any]:
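        """
        Return a QuerySet that yields a dict per row instead of model
        instances.

        A sketch (``Task`` is a hypothetical model):

            Task.query.values("id", "name")
            # -> [{"id": 1, "name": "first"}, ...]
        """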
        fields += tuple(expressions)
        clone = self._values(*fields, **expressions)
        clone._iterable_class = ValuesIterable
        return clone

    def values_list(
        self, *fields: str, flat: bool = False, named: bool = False
    ) -> QuerySet[Any]:
        if flat and named:
            raise TypeError("'flat' and 'named' can't be used together.")
        if flat and len(fields) > 1:
            raise TypeError(
                "'flat' is not valid when values_list is called with more than one "
                "field."
            )

        field_names = {f for f in fields if not hasattr(f, "resolve_expression")}
        _fields = []
        expressions = {}
        counter = 1
        for field in fields:
            if hasattr(field, "resolve_expression"):
                field_id_prefix = getattr(
                    field, "default_alias", field.__class__.__name__.lower()
                )
                while True:
                    field_id = field_id_prefix + str(counter)
                    counter += 1
                    if field_id not in field_names:
                        break
                expressions[field_id] = field
                _fields.append(field_id)
            else:
                _fields.append(field)

        clone = self._values(*_fields, **expressions)
        clone._iterable_class = (
            NamedValuesListIterable
            if named
            else FlatValuesListIterable
            if flat
            else ValuesListIterable
        )
        return clone

    def dates(self, field_name: str, kind: str, order: str = "ASC") -> QuerySet[Any]:
        """
        Return a list of date objects representing all available dates for
        the given field_name, scoped to 'kind'.
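
        A sketch (``Article`` is a hypothetical model with a ``published_at``
        date field; one date is returned per distinct month):

            Article.query.dates("published_at", "month")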
        """
        if kind not in ("year", "month", "week", "day"):
            raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
        if order not in ("ASC", "DESC"):
            raise ValueError("'order' must be either 'ASC' or 'DESC'.")
        return (
            self.annotate(
                datefield=Trunc(field_name, kind, output_field=DateField()),
                plain_field=F(field_name),
            )
            .values_list("datefield", flat=True)
            .distinct()
            .filter(plain_field__isnull=False)
            .order_by(("-" if order == "DESC" else "") + "datefield")
        )

    def datetimes(
        self,
        field_name: str,
        kind: str,
        order: str = "ASC",
        tzinfo: tzinfo | None = None,
    ) -> QuerySet[Any]:
        """
        Return a list of datetime objects representing all available
        datetimes for the given field_name, scoped to 'kind'.
        """
        if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
            raise ValueError(
                "'kind' must be one of 'year', 'month', 'week', 'day', "
                "'hour', 'minute', or 'second'."
            )
        if order not in ("ASC", "DESC"):
            raise ValueError("'order' must be either 'ASC' or 'DESC'.")

        if tzinfo is None:
            tzinfo = timezone.get_current_timezone()

        return (
            self.annotate(
                datetimefield=Trunc(
                    field_name,
                    kind,
                    output_field=DateTimeField(),
                    tzinfo=tzinfo,
                ),
                plain_field=F(field_name),
            )
            .values_list("datetimefield", flat=True)
            .distinct()
            .filter(plain_field__isnull=False)
            .order_by(("-" if order == "DESC" else "") + "datetimefield")
        )

    def none(self) -> QuerySet[T]:
        """Return an empty QuerySet."""
        clone = self._chain()
        clone.sql_query.set_empty()
        return clone

    ##################################################################
    # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
    ##################################################################

    def all(self) -> Self:
        """
        Return a new QuerySet that is a copy of the current one. This allows a
        QuerySet to proxy for a model queryset in some cases.
        """
        return self._chain()

    def filter(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with the args ANDed to the existing
        set.
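
        Keyword arguments use field lookups, and Q objects can express OR
        logic. A sketch (``Task`` is a hypothetical model):

            Task.query.filter(is_active=True, name__startswith="a")
            Task.query.filter(Q(priority=1) | Q(priority=2))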
        """
        self._not_support_combined_queries("filter")
        return self._filter_or_exclude(False, args, kwargs)

    def exclude(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a new QuerySet instance with NOT (args) ANDed to the existing
        set.
        """
        self._not_support_combined_queries("exclude")
        return self._filter_or_exclude(True, args, kwargs)

    def _filter_or_exclude(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> Self:
        if (args or kwargs) and self.sql_query.is_sliced:
            raise TypeError("Cannot filter a query once a slice has been taken.")
        clone = self._chain()
        if self._defer_next_filter:
            self._defer_next_filter = False
            clone._deferred_filter = negate, args, kwargs
        else:
            clone._filter_or_exclude_inplace(negate, args, kwargs)
        return clone

    def _filter_or_exclude_inplace(
        self, negate: bool, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        if negate:
            self._query.add_q(~Q(*args, **kwargs))  # type: ignore[unsupported-operator]
        else:
            self._query.add_q(Q(*args, **kwargs))

    def complex_filter(self, filter_obj: Q | dict[str, Any]) -> QuerySet[T]:
        """
        Return a new QuerySet instance with filter_obj added to the filters.

        filter_obj can be a Q object or a dictionary of keyword lookup
        arguments.

        This exists to support framework features such as 'limit_choices_to',
        and usually it will be more natural to use other methods.
        """
        if isinstance(filter_obj, Q):
            clone = self._chain()
            clone.sql_query.add_q(filter_obj)
            return clone
        else:
            return self._filter_or_exclude(False, args=(), kwargs=filter_obj)

    def _combinator_query(
        self, combinator: str, *other_qs: QuerySet[T], all: bool = False
    ) -> QuerySet[T]:
        # Clone the query to inherit the select list and everything
        clone = self._chain()
        # Clear limits and ordering so they can be reapplied
        clone.sql_query.clear_ordering(force=True)
        clone.sql_query.clear_limits()
        clone.sql_query.combined_queries = (self.sql_query,) + tuple(
            qs.sql_query for qs in other_qs
        )
        clone.sql_query.combinator = combinator
        clone.sql_query.combinator_all = all
        return clone

    def union(self, *other_qs: QuerySet[T], all: bool = False) -> QuerySet[T]:
        # If the query is an EmptyQuerySet, combine all nonempty querysets.
        if isinstance(self, EmptyQuerySet):
            qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
            if not qs:
                return self
            if len(qs) == 1:
                return qs[0]
            return qs[0]._combinator_query("union", *qs[1:], all=all)
        return self._combinator_query("union", *other_qs, all=all)

    def intersection(self, *other_qs: QuerySet[T]) -> QuerySet[T]:
        # If any query is an EmptyQuerySet, return it.
        if isinstance(self, EmptyQuerySet):
            return self
        for other in other_qs:
            if isinstance(other, EmptyQuerySet):
                return other
        return self._combinator_query("intersection", *other_qs)

    def difference(self, *other_qs: QuerySet[T]) -> QuerySet[T]:
        # If the query is an EmptyQuerySet, return it.
        if isinstance(self, EmptyQuerySet):
            return self
        return self._combinator_query("difference", *other_qs)

    def select_for_update(
        self,
        nowait: bool = False,
        skip_locked: bool = False,
        of: tuple[str, ...] = (),
        no_key: bool = False,
    ) -> QuerySet[T]:
        """
        Return a new QuerySet instance that will select objects with a
        FOR UPDATE lock.
        """
        if nowait and skip_locked:
            raise ValueError("The nowait option cannot be used with skip_locked.")
        obj = self._chain()
        obj._for_write = True
        obj.sql_query.select_for_update = True
        obj.sql_query.select_for_update_nowait = nowait
        obj.sql_query.select_for_update_skip_locked = skip_locked
        obj.sql_query.select_for_update_of = of
        obj.sql_query.select_for_no_key_update = no_key
        return obj

    def select_related(self, *fields: str | None) -> Self:
        """
        Return a new QuerySet instance that will select related objects.

        If fields are specified, they must be ForeignKey fields and only those
        related objects are included in the selection.

        If select_related(None) is called, clear the list.
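
        A sketch (``Task`` with a hypothetical ``project`` ForeignKey; the
        related row is fetched in the same query via a JOIN):

            task = Task.query.select_related("project").get(id=1)
            task.project  # already loaded, no extra query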
        """
        self._not_support_combined_queries("select_related")
        if self._fields is not None:
            raise TypeError(
                "Cannot call select_related() after .values() or .values_list()"
            )

        obj = self._chain()
        if fields == (None,):
            obj.sql_query.select_related = False
        elif fields:
            obj.sql_query.add_select_related(fields)
        else:
            obj.sql_query.select_related = True
        return obj

    def prefetch_related(self, *lookups: str | Prefetch | None) -> Self:
        """
        Return a new QuerySet instance that will prefetch the specified
        Many-To-One and Many-To-Many related objects when the QuerySet is
        evaluated.

        When prefetch_related() is called more than once, append to the list of
        prefetch lookups. If prefetch_related(None) is called, clear the list.
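
        A sketch (``Project`` with a hypothetical reverse ``tasks`` relation;
        the related objects are loaded in one extra batched query):

            for project in Project.query.prefetch_related("tasks"):
                project.tasks.all()  # served from the prefetch cache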
        """
        self._not_support_combined_queries("prefetch_related")
        clone = self._chain()
        if lookups == (None,):
            clone._prefetch_related_lookups = ()
        else:
            for lookup in lookups:
                lookup_str: str
                if isinstance(lookup, Prefetch):
                    lookup_str = lookup.prefetch_to
                else:
                    lookup_str = lookup  # type: ignore[assignment]
                lookup_str = lookup_str.split(LOOKUP_SEP, 1)[0]
                if lookup_str in self.sql_query._filtered_relations:
                    raise ValueError(
                        "prefetch_related() is not supported with FilteredRelation."
                    )
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

    def annotate(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set in which the returned objects have been annotated
        with extra data or aggregations.
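
        A sketch, assuming a ``Count`` aggregate is available from
        plain.models (``Project`` and its ``tasks`` relation are
        hypothetical):

            Project.query.annotate(task_count=Count("tasks"))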
        """
        self._not_support_combined_queries("annotate")
        return self._annotate(args, kwargs, select=True)

    def alias(self, *args: Any, **kwargs: Any) -> Self:
        """
        Return a query set with added aliases for extra data or aggregations.
        """
        self._not_support_combined_queries("alias")
        return self._annotate(args, kwargs, select=False)

    def _annotate(
        self, args: tuple[Any, ...], kwargs: dict[str, Any], select: bool = True
    ) -> Self:
        self._validate_values_are_expressions(
            args + tuple(kwargs.values()), method_name="annotate"
        )
        annotations = {}
        for arg in args:
            # The default_alias property may raise a TypeError.
            try:
                if arg.default_alias in kwargs:
                    raise ValueError(
                        f"The named annotation '{arg.default_alias}' conflicts with the "
                        "default name for another annotation."
                    )
            except TypeError:
                raise TypeError("Complex annotations require an alias")
            annotations[arg.default_alias] = arg
        annotations.update(kwargs)

        clone = self._chain()
        names = self._fields
        if names is None:
            names = set(
                chain.from_iterable(
                    (field.name, field.attname)
                    if hasattr(field, "attname")
                    else (field.name,)
                    for field in self.model._model_meta.get_fields()
                )
            )

        for alias, annotation in annotations.items():
            if alias in names:
                raise ValueError(
                    f"The annotation '{alias}' conflicts with a field on the model."
                )
            if isinstance(annotation, FilteredRelation):
                clone.sql_query.add_filtered_relation(annotation, alias)
            else:
                clone.sql_query.add_annotation(
                    annotation,
                    alias,
                    select=select,
                )
        for alias, annotation in clone.sql_query.annotations.items():
            if alias in annotations and annotation.contains_aggregate:
                if clone._fields is None:
                    clone.sql_query.group_by = True
                else:
                    clone.sql_query.set_group_by()
                break

        return clone

    def order_by(self, *field_names: str) -> Self:
        """Return a new QuerySet instance with the ordering changed."""
        if self.sql_query.is_sliced:
            raise TypeError("Cannot reorder a query once a slice has been taken.")
        obj = self._chain()
        obj.sql_query.clear_ordering(force=True, clear_default=False)
        obj.sql_query.add_ordering(*field_names)
        return obj

    def distinct(self, *field_names: str) -> Self:
        """
        Return a new QuerySet instance that will select only distinct results.
        """
        self._not_support_combined_queries("distinct")
        if self.sql_query.is_sliced:
            raise TypeError(
                "Cannot create distinct fields once a slice has been taken."
            )
        obj = self._chain()
        obj.sql_query.add_distinct_fields(*field_names)
        return obj

    def extra(
        self,
        select: dict[str, str] | None = None,
        where: list[str] | None = None,
        params: list[Any] | None = None,
        tables: list[str] | None = None,
        order_by: list[str] | None = None,
        select_params: list[Any] | None = None,
    ) -> QuerySet[T]:
        """Add extra SQL fragments to the query."""
        self._not_support_combined_queries("extra")
        if self.sql_query.is_sliced:
            raise TypeError("Cannot change a query once a slice has been taken.")
        clone = self._chain()
        clone.sql_query.add_extra(
            select, select_params, where, params, tables, order_by
        )
        return clone

    def reverse(self) -> QuerySet[T]:
        """Reverse the ordering of the QuerySet."""
        if self.sql_query.is_sliced:
            raise TypeError("Cannot reverse a query once a slice has been taken.")
        clone = self._chain()
        clone.sql_query.standard_ordering = not clone.sql_query.standard_ordering
        return clone

    def defer(self, *fields: str | None) -> QuerySet[T]:
        """
        Defer the loading of data for certain fields until they are accessed.
        Add the set of deferred fields to any existing set of deferred fields.
        The only exception to this is if None is passed in as the only
        parameter, in which case all deferrals are removed.
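
        A sketch (``Article`` is a hypothetical model; ``body`` is loaded
        lazily, with a second query, on first access):

            article = Article.query.defer("body").get(id=1)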
1599 """
1600 self._not_support_combined_queries("defer")
1601 if self._fields is not None:
1602 raise TypeError("Cannot call defer() after .values() or .values_list()")
1603 clone = self._chain()
1604 if fields == (None,):
1605 clone.sql_query.clear_deferred_loading()
1606 else:
1607 clone.sql_query.add_deferred_loading(fields)
1608 return clone
1609
1610 def only(self, *fields: str) -> QuerySet[T]:
1611 """
1612 Essentially, the opposite of defer(). Only the fields passed into this
1613 method and that are not already specified as deferred are loaded
1614 immediately when the queryset is evaluated.
1615 """
1616 self._not_support_combined_queries("only")
1617 if self._fields is not None:
1618 raise TypeError("Cannot call only() after .values() or .values_list()")
1619 if fields == (None,):
1620 # Can only pass None to defer(), not only(), as the rest option.
1621 # That won't stop people trying to do this, so let's be explicit.
1622 raise TypeError("Cannot pass None as an argument to only().")
1623 for field in fields:
1624 field = field.split(LOOKUP_SEP, 1)[0]
1625 if field in self.sql_query._filtered_relations:
1626 raise ValueError("only() is not supported with FilteredRelation.")
1627 clone = self._chain()
1628 clone.sql_query.add_immediate_loading(fields)
1629 return clone
1630
1631 ###################################
1632 # PUBLIC INTROSPECTION ATTRIBUTES #
1633 ###################################
1634
1635 @property
1636 def ordered(self) -> bool:
1637 """
1638 Return True if the QuerySet is ordered -- i.e. has an order_by()
1639 clause or a default ordering on the model (or is empty).
1640 """
1641 if isinstance(self, EmptyQuerySet):
1642 return True
1643 if self.sql_query.extra_order_by or self.sql_query.order_by:
1644 return True
1645 elif (
1646 self.sql_query.default_ordering
1647 and self.sql_query.get_model_meta().ordering
1648 and
1649 # A default ordering doesn't affect GROUP BY queries.
1650 not self.sql_query.group_by
1651 ):
1652 return True
1653 else:
1654 return False
1655
1656 ###################
1657 # PRIVATE METHODS #
1658 ###################
1659
1660 def _insert(
1661 self,
1662 objs: list[T],
1663 fields: list[Field],
1664 returning_fields: list[Field] | None = None,
1665 raw: bool = False,
1666 on_conflict: OnConflict | None = None,
1667 update_fields: list[Field] | None = None,
1668 unique_fields: list[Field] | None = None,
1669 ) -> list[tuple[Any, ...]] | None:
1670 """
        Insert new records for the given model. This provides an interface to
1672 the InsertQuery class and is how Model.save() is implemented.
1673 """
1674 self._for_write = True
1675 query = sql.InsertQuery(
1676 self.model,
1677 on_conflict=on_conflict.value if on_conflict else None, # type: ignore[attr-defined]
1678 update_fields=update_fields,
1679 unique_fields=unique_fields,
1680 )
1681 query.insert_values(fields, objs, raw=raw)
1682 return query.get_compiler().execute_sql(returning_fields)
1683
1684 def _batched_insert(
1685 self,
1686 objs: list[T],
1687 fields: list[Field],
1688 batch_size: int,
1689 on_conflict: OnConflict | None = None,
1690 update_fields: list[Field] | None = None,
1691 unique_fields: list[Field] | None = None,
1692 ) -> list[tuple[Any, ...]]:
1693 """
1694 Helper method for bulk_create() to insert objs one batch at a time.
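
        For example, with 10 objs, batch_size=4, and a backend cap of 3,
        the effective batch size is min(4, 3) == 3, producing batches of
        sizes 3, 3, 3, and 1.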
1695 """
1696 ops = db_connection.ops
1697 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1698 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1699 inserted_rows = []
1700 bulk_return = db_connection.features.can_return_rows_from_bulk_insert
1701 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1702 if bulk_return and on_conflict is None:
1703 inserted_rows.extend(
1704 self._insert(
1705 item,
1706 fields=fields,
1707 returning_fields=self.model._model_meta.db_returning_fields,
1708 )
1709 )
1710 else:
1711 self._insert(
1712 item,
1713 fields=fields,
1714 on_conflict=on_conflict,
1715 update_fields=update_fields,
1716 unique_fields=unique_fields,
1717 )
1718 return inserted_rows
1719
1720 def _chain(self) -> Self:
1721 """
1722 Return a copy of the current QuerySet that's ready for another
1723 operation.
1724 """
1725 obj = self._clone()
1726 if obj._sticky_filter:
1727 obj.sql_query.filter_is_sticky = True
1728 obj._sticky_filter = False
1729 return obj
1730
1731 def _clone(self) -> Self:
1732 """
1733 Return a copy of the current QuerySet. A lightweight alternative
1734 to deepcopy().
1735 """
1736 c = self.__class__.from_model(
1737 model=self.model,
1738 query=self.sql_query.chain(),
1739 )
1740 c._sticky_filter = self._sticky_filter
1741 c._for_write = self._for_write
1742 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1743 c._known_related_objects = self._known_related_objects
1744 c._iterable_class = self._iterable_class
1745 c._fields = self._fields
1746 return c
1747
1748 def _fetch_all(self) -> None:
1749 if self._result_cache is None:
1750 self._result_cache = list(self._iterable_class(self))
1751 if self._prefetch_related_lookups and not self._prefetch_done:
1752 self._prefetch_related_objects()
1753
1754 def _next_is_sticky(self) -> QuerySet[T]:
1755 """
1756 Indicate that the next filter call and the one following that should
1757 be treated as a single filter. This is only important when it comes to
1758 determining when to reuse tables for many-to-many filters. Required so
1759 that we can filter naturally on the results of related managers.
1760
1761 This doesn't return a clone of the current QuerySet (it returns
1762 "self"). The method is only used internally and should be immediately
1763 followed by a filter() that does create a clone.
1764 """
1765 self._sticky_filter = True
1766 return self
1767
1768 def _merge_sanity_check(self, other: QuerySet[T]) -> None:
1769 """Check that two QuerySet classes may be merged."""
1770 if self._fields is not None and (
1771 set(self.sql_query.values_select) != set(other.sql_query.values_select)
1772 or set(self.sql_query.extra_select) != set(other.sql_query.extra_select)
1773 or set(self.sql_query.annotation_select)
1774 != set(other.sql_query.annotation_select)
1775 ):
1776 raise TypeError(
1777 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1778 )
1779
1780 def _merge_known_related_objects(self, other: QuerySet[T]) -> None:
1781 """
1782 Keep track of all known related objects from either QuerySet instance.
1783 """
1784 for field, objects in other._known_related_objects.items():
1785 self._known_related_objects.setdefault(field, {}).update(objects)
1786
1787 def resolve_expression(self, *args: Any, **kwargs: Any) -> sql.Query:
1788 if self._fields and len(self._fields) > 1:
            # A values() queryset can only be used as a nested query
            # if it is set up to select only a single field.
1791 raise TypeError("Cannot use multi-field values as a filter value.")
1792 query = self.sql_query.resolve_expression(*args, **kwargs)
1793 return query
1794
1795 def _has_filters(self) -> bool:
1796 """
        Check if this QuerySet has any filtering going on. This isn't
        equivalent to checking whether all objects are present in the
        results; for example, qs[1:]._has_filters() -> False.
1800 """
1801 return self.sql_query.has_filters()
1802
1803 @staticmethod
1804 def _validate_values_are_expressions(
1805 values: tuple[Any, ...], method_name: str
1806 ) -> None:
1807 invalid_args = sorted(
1808 str(arg) for arg in values if not hasattr(arg, "resolve_expression")
1809 )
1810 if invalid_args:
1811 raise TypeError(
1812 "QuerySet.{}() received non-expression(s): {}.".format(
1813 method_name,
1814 ", ".join(invalid_args),
1815 )
1816 )
1817
1818 def _not_support_combined_queries(self, operation_name: str) -> None:
1819 if self.sql_query.combinator:
1820 raise NotSupportedError(
1821 f"Calling QuerySet.{operation_name}() after {self.sql_query.combinator}() is not supported."
1822 )
1823
1824 def _check_operator_queryset(self, other: QuerySet[T], operator_: str) -> None:
1825 if self.sql_query.combinator or other.sql_query.combinator:
1826 raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
1827
1828
1829class InstanceCheckMeta(type):
1830 def __instancecheck__(self, instance: object) -> bool:
1831 return isinstance(instance, QuerySet) and instance.sql_query.is_empty()
1832
1833
1834class EmptyQuerySet(metaclass=InstanceCheckMeta):
1835 """
    Marker class for checking whether a queryset is empty, e.g. via .none():
1837 isinstance(qs.none(), EmptyQuerySet) -> True
1838 """
1839
1840 def __init__(self, *args: Any, **kwargs: Any):
1841 raise TypeError("EmptyQuerySet can't be instantiated")
1842
1843
1844class RawQuerySet:
1845 """
1846 Provide an iterator which converts the results of raw SQL queries into
1847 annotated model instances.
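
    A hedged sketch (``Book`` is a hypothetical model; ``translations`` maps
    raw column names to model attribute names):

        rqs = RawQuerySet(
            "SELECT id, name FROM books",
            model=Book,
            translations={"name": "title"},
        )
        books = list(rqs)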
1848 """
1849
1850 def __init__(
1851 self,
1852 raw_query: str,
1853 model: type[Model] | None = None,
1854 query: sql.RawQuery | None = None,
1855 params: tuple[Any, ...] = (),
1856 translations: dict[str, str] | None = None,
1857 ):
1858 self.raw_query = raw_query
1859 self.model = model
1860 self.sql_query = query or sql.RawQuery(sql=raw_query, params=params)
1861 self.params = params
1862 self.translations = translations or {}
1863 self._result_cache: list[Model] | None = None
1864 self._prefetch_related_lookups: tuple[Any, ...] = ()
1865 self._prefetch_done = False
1866
1867 def resolve_model_init_order(
1868 self,
1869 ) -> tuple[list[str], list[int], list[tuple[str, int]]]:
1870 """Resolve the init field names and value positions."""
1871 converter = db_connection.introspection.identifier_converter
1872 model_init_fields = [
1873 f
1874 for f in self.model._model_meta.fields
1875 if converter(f.column) in self.columns
1876 ]
1877 annotation_fields = [
1878 (column, pos)
1879 for pos, column in enumerate(self.columns)
1880 if column not in self.model_fields
1881 ]
1882 model_init_order = [
1883 self.columns.index(converter(f.column)) for f in model_init_fields
1884 ]
1885 model_init_names = [f.attname for f in model_init_fields]
1886 return model_init_names, model_init_order, annotation_fields
1887
1888 def prefetch_related(self, *lookups: str | Prefetch | None) -> RawQuerySet:
1889 """Same as QuerySet.prefetch_related()"""
1890 clone = self._clone()
1891 if lookups == (None,):
1892 clone._prefetch_related_lookups = ()
1893 else:
1894 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1895 return clone
1896
1897 def _prefetch_related_objects(self) -> None:
1898 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1899 self._prefetch_done = True
1900
1901 def _clone(self) -> RawQuerySet:
1902 """Same as QuerySet._clone()"""
1903 c = self.__class__(
1904 self.raw_query,
1905 model=self.model,
1906 query=self.sql_query,
1907 params=self.params,
1908 translations=self.translations,
1909 )
1910 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1911 return c
1912
1913 def _fetch_all(self) -> None:
1914 if self._result_cache is None:
1915 self._result_cache = list(self.iterator())
1916 if self._prefetch_related_lookups and not self._prefetch_done:
1917 self._prefetch_related_objects()
1918
1919 def __len__(self) -> int:
1920 self._fetch_all()
1921 return len(self._result_cache) # type: ignore[arg-type]
1922
1923 def __bool__(self) -> bool:
1924 self._fetch_all()
1925 return bool(self._result_cache)
1926
1927 def __iter__(self) -> Iterator[Model]:
1928 self._fetch_all()
1929 return iter(self._result_cache) # type: ignore[arg-type]
1930
1931 def iterator(self) -> Iterator[Model]:
1932 yield from RawModelIterable(self)
1933
1934 def __repr__(self) -> str:
1935 return f"<{self.__class__.__name__}: {self.sql_query}>"
1936
1937 def __getitem__(self, k: int | slice) -> Model | list[Model]:
1938 return list(self)[k]
1939
1940 @cached_property
1941 def columns(self) -> list[str]:
1942 """
        A list of column names, adjusted by any ``translations``, in the
        order they'll appear in the query results.
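
        A hedged sketch: with ``translations={"name": "title"}``, a raw
        query column ``name`` is reported here as ``title``.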
1945 """
1946 columns = self.sql_query.get_columns()
1947 # Adjust any column names which don't match field names
1948 for query_name, model_name in self.translations.items():
1949 # Ignore translations for nonexistent column names
1950 try:
1951 index = columns.index(query_name)
1952 except ValueError:
1953 pass
1954 else:
1955 columns[index] = model_name
1956 return columns
1957
1958 @cached_property
1959 def model_fields(self) -> dict[str, Field]:
        """A dict mapping column names to model fields."""
1961 converter = db_connection.introspection.identifier_converter
1962 model_fields = {}
1963 for field in self.model._model_meta.fields:
1964 name, column = field.get_attname_column()
1965 model_fields[converter(column)] = field
1966 return model_fields
1967
1968
1969class Prefetch:
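    """
    Describe a prefetch_related() lookup, optionally with a custom queryset
    and/or a ``to_attr`` destination attribute.

    A hedged sketch (``Author`` and ``Book`` are hypothetical models, and
    model managers are assumed to be reachable via ``.query``):

        Author.query.prefetch_related(
            Prefetch(
                "books",
                queryset=Book.query.filter(published=True),
                to_attr="published_books",
            ),
        )
    """
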
1970 def __init__(
1971 self,
1972 lookup: str,
1973 queryset: QuerySet[Any] | None = None,
1974 to_attr: str | None = None,
1975 ):
1976 # `prefetch_through` is the path we traverse to perform the prefetch.
1977 self.prefetch_through = lookup
1978 # `prefetch_to` is the path to the attribute that stores the result.
1979 self.prefetch_to = lookup
1980 if queryset is not None and (
1981 isinstance(queryset, RawQuerySet)
1982 or (
1983 hasattr(queryset, "_iterable_class")
1984 and not issubclass(queryset._iterable_class, ModelIterable)
1985 )
1986 ):
1987 raise ValueError(
1988 "Prefetch querysets cannot use raw(), values(), and values_list()."
1989 )
1990 if to_attr:
1991 self.prefetch_to = LOOKUP_SEP.join(
1992 lookup.split(LOOKUP_SEP)[:-1] + [to_attr]
1993 )
1994
1995 self.queryset = queryset
1996 self.to_attr = to_attr
1997
1998 def __getstate__(self) -> dict[str, Any]:
1999 obj_dict = self.__dict__.copy()
2000 if self.queryset is not None:
2001 queryset = self.queryset._chain()
2002 # Prevent the QuerySet from being evaluated
2003 queryset._result_cache = []
2004 queryset._prefetch_done = True
2005 obj_dict["queryset"] = queryset
2006 return obj_dict
2007
2008 def add_prefix(self, prefix: str) -> None:
2009 self.prefetch_through = prefix + LOOKUP_SEP + self.prefetch_through
2010 self.prefetch_to = prefix + LOOKUP_SEP + self.prefetch_to
2011
2012 def get_current_prefetch_to(self, level: int) -> str:
2013 return LOOKUP_SEP.join(self.prefetch_to.split(LOOKUP_SEP)[: level + 1])
2014
2015 def get_current_to_attr(self, level: int) -> tuple[str, bool]:
2016 parts = self.prefetch_to.split(LOOKUP_SEP)
2017 to_attr = parts[level]
2018 as_attr = self.to_attr and level == len(parts) - 1
2019 return to_attr, as_attr
2020
2021 def get_current_queryset(self, level: int) -> QuerySet[Any] | None:
2022 if self.get_current_prefetch_to(level) == self.prefetch_to:
2023 return self.queryset
2024 return None
2025
2026 def __eq__(self, other: object) -> bool:
2027 if not isinstance(other, Prefetch):
2028 return NotImplemented
2029 return self.prefetch_to == other.prefetch_to
2030
2031 def __hash__(self) -> int:
2032 return hash((self.__class__, self.prefetch_to))
2033
2034
2035def normalize_prefetch_lookups(
2036 lookups: tuple[str | Prefetch, ...] | list[str | Prefetch],
2037 prefix: str | None = None,
2038) -> list[Prefetch]:
2039 """Normalize lookups into Prefetch objects."""
2040 ret = []
2041 for lookup in lookups:
2042 if not isinstance(lookup, Prefetch):
2043 lookup = Prefetch(lookup)
2044 if prefix:
2045 lookup.add_prefix(prefix)
2046 ret.append(lookup)
2047 return ret
2048
2049
2050def prefetch_related_objects(
2051 model_instances: list[Model], *related_lookups: str | Prefetch
2052) -> None:
2053 """
2054 Populate prefetched object caches for a list of model instances based on
2055 the lookups/Prefetch instances given.
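
    A hedged sketch (``Author`` is a hypothetical model; managers are
    assumed to be reachable via ``.query``):

        authors = list(Author.query.all())
        prefetch_related_objects(authors, "books", "books__publisher")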
2056 """
2057 if not model_instances:
2058 return # nothing to do
2059
    # We need to be able to dynamically add to the list of prefetch_related
    # lookups that we traverse (see below), so we need some bookkeeping to
    # ensure we don't do duplicate work.
2063 done_queries = {} # dictionary of things like 'foo__bar': [results]
2064
2065 auto_lookups = set() # we add to this as we go through.
2066 followed_descriptors = set() # recursion protection
2067
2068 all_lookups = normalize_prefetch_lookups(reversed(related_lookups)) # type: ignore[arg-type]
2069 while all_lookups:
2070 lookup = all_lookups.pop()
2071 if lookup.prefetch_to in done_queries:
2072 if lookup.queryset is not None:
2073 raise ValueError(
2074 f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
2075 "You may need to adjust the ordering of your lookups."
2076 )
2077
2078 continue
2079
        # At the top level, the list of objects to decorate is the result
        # cache from the primary QuerySet; at deeper levels it won't be.
2082 obj_list = model_instances
2083
2084 through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
2085 for level, through_attr in enumerate(through_attrs):
2086 # Prepare main instances
2087 if not obj_list:
2088 break
2089
2090 prefetch_to = lookup.get_current_prefetch_to(level)
2091 if prefetch_to in done_queries:
2092 # Skip any prefetching, and any object preparation
2093 obj_list = done_queries[prefetch_to]
2094 continue
2095
2096 # Prepare objects:
2097 good_objects = True
2098 for obj in obj_list:
2099 # Since prefetching can re-use instances, it is possible to have
2100 # the same instance multiple times in obj_list, so obj might
2101 # already be prepared.
2102 if not hasattr(obj, "_prefetched_objects_cache"):
2103 try:
2104 obj._prefetched_objects_cache = {}
2105 except (AttributeError, TypeError):
2106 # Must be an immutable object from
2107 # values_list(flat=True), for example (TypeError) or
2108 # a QuerySet subclass that isn't returning Model
2109 # instances (AttributeError), either in Plain or a 3rd
2110 # party. prefetch_related() doesn't make sense, so quit.
2111 good_objects = False
2112 break
2113 if not good_objects:
2114 break
2115
2116 # Descend down tree
2117
            # We assume that objects retrieved are homogeneous (which is the
            # premise of prefetch_related), so what applies to the first
            # object applies to all of them.
2120 first_obj = obj_list[0]
2121 to_attr = lookup.get_current_to_attr(level)[0]
2122 prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
2123 first_obj, through_attr, to_attr
2124 )
2125
2126 if not attr_found:
2127 raise AttributeError(
2128 f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
2129 "parameter to prefetch_related()"
2130 )
2131
2132 if level == len(through_attrs) - 1 and prefetcher is None:
2133 # Last one, this *must* resolve to something that supports
2134 # prefetching, otherwise there is no point adding it and the
2135 # developer asking for it has made a mistake.
2136 raise ValueError(
2137 f"'{lookup.prefetch_through}' does not resolve to an item that supports "
2138 "prefetching - this is an invalid parameter to "
2139 "prefetch_related()."
2140 )
2141
2142 obj_to_fetch = None
2143 if prefetcher is not None:
2144 obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]
2145
2146 if obj_to_fetch:
2147 obj_list, additional_lookups = prefetch_one_level(
2148 obj_to_fetch,
2149 prefetcher,
2150 lookup,
2151 level,
2152 )
2153 # We need to ensure we don't keep adding lookups from the
2154 # same relationships to stop infinite recursion. So, if we
2155 # are already on an automatically added lookup, don't add
2156 # the new lookups from relationships we've seen already.
2157 if not (
2158 prefetch_to in done_queries
2159 and lookup in auto_lookups
2160 and descriptor in followed_descriptors
2161 ):
2162 done_queries[prefetch_to] = obj_list
2163 new_lookups = normalize_prefetch_lookups(
2164 reversed(additional_lookups), # type: ignore[arg-type]
2165 prefetch_to,
2166 )
2167 auto_lookups.update(new_lookups)
2168 all_lookups.extend(new_lookups)
2169 followed_descriptors.add(descriptor)
2170 else:
2171 # Either a singly related object that has already been fetched
2172 # (e.g. via select_related), or hopefully some other property
2173 # that doesn't support prefetching but needs to be traversed.
2174
2175 # We replace the current list of parent objects with the list
2176 # of related objects, filtering out empty or missing values so
2177 # that we can continue with nullable or reverse relations.
2178 new_obj_list = []
2179 for obj in obj_list:
2180 if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
2181 # If related objects have been prefetched, use the
2182 # cache rather than the object's through_attr.
2183 new_obj = list(obj._prefetched_objects_cache.get(through_attr)) # type: ignore[arg-type]
2184 else:
2185 try:
2186 new_obj = getattr(obj, through_attr)
2187 except ObjectDoesNotExist:
2188 continue
2189 if new_obj is None:
2190 continue
2191 # We special-case `list` rather than something more generic
2192 # like `Iterable` because we don't want to accidentally match
2193 # user models that define __iter__.
2194 if isinstance(new_obj, list):
2195 new_obj_list.extend(new_obj)
2196 else:
2197 new_obj_list.append(new_obj)
2198 obj_list = new_obj_list
2199
2200
2201def get_prefetcher(
2202 instance: Model, through_attr: str, to_attr: str
2203) -> tuple[Any, Any, bool, Callable[[Model], bool]]:
2204 """
2205 For the attribute 'through_attr' on the given instance, find
2206 an object that has a get_prefetch_queryset().
    Return a 4-tuple containing:
2208 (the object with get_prefetch_queryset (or None),
2209 the descriptor object representing this relationship (or None),
2210 a boolean that is False if the attribute was not found at all,
2211 a function that takes an instance and returns a boolean that is True if
2212 the attribute has already been fetched for that instance)
2213 """
2214
2215 def has_to_attr_attribute(instance: Model) -> bool:
2216 return hasattr(instance, to_attr)
2217
2218 prefetcher = None
2219 is_fetched: Callable[[Model], bool] = has_to_attr_attribute
2220
2221 # For singly related objects, we have to avoid getting the attribute
2222 # from the object, as this will trigger the query. So we first try
2223 # on the class, in order to get the descriptor object.
2224 rel_obj_descriptor = getattr(instance.__class__, through_attr, None)
2225 if rel_obj_descriptor is None:
2226 attr_found = hasattr(instance, through_attr)
2227 else:
2228 attr_found = True
2229 if rel_obj_descriptor:
2230 # singly related object, descriptor object has the
2231 # get_prefetch_queryset() method.
2232 if hasattr(rel_obj_descriptor, "get_prefetch_queryset"):
2233 prefetcher = rel_obj_descriptor
2234 is_fetched = rel_obj_descriptor.is_cached
2235 else:
2236 # descriptor doesn't support prefetching, so we go ahead and get
2237 # the attribute on the instance rather than the class to
2238 # support many related managers
2239 rel_obj = getattr(instance, through_attr)
2240 if hasattr(rel_obj, "get_prefetch_queryset"):
2241 prefetcher = rel_obj
2242 if through_attr != to_attr:
2243 # Special case cached_property instances because hasattr
2244 # triggers attribute computation and assignment.
2245 if isinstance(
2246 getattr(instance.__class__, to_attr, None), cached_property
2247 ):
2248
2249 def has_cached_property(instance: Model) -> bool:
2250 return to_attr in instance.__dict__
2251
2252 is_fetched = has_cached_property
2253 else:
2254
2255 def in_prefetched_cache(instance: Model) -> bool:
2256 return through_attr in instance._prefetched_objects_cache # type: ignore[attr-defined]
2257
2258 is_fetched = in_prefetched_cache
2259 return prefetcher, rel_obj_descriptor, attr_found, is_fetched
2260
2261
2262def prefetch_one_level(
2263 instances: list[Model], prefetcher: Any, lookup: Prefetch, level: int
2264) -> tuple[list[Model], list[Prefetch]]:
2265 """
2266 Helper function for prefetch_related_objects().
2267
2268 Run prefetches on all instances using the prefetcher object,
2269 assigning results to relevant caches in instance.
2270
2271 Return the prefetched objects along with any additional prefetches that
2272 must be done due to prefetch_related lookups found from default managers.
2273 """
2274 # prefetcher must have a method get_prefetch_queryset() which takes a list
2275 # of instances, and returns a tuple:
2276
2277 # (queryset of instances of self.model that are related to passed in instances,
2278 # callable that gets value to be matched for returned instances,
2279 # callable that gets value to be matched for passed in instances,
2280 # boolean that is True for singly related objects,
2281 # cache or field name to assign to,
2282 # boolean that is True when the previous argument is a cache name vs a field name).
2283
2284 # The 'values to be matched' must be hashable as they will be used
2285 # in a dictionary.
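    #
    # As a hedged illustration, a reverse foreign-key prefetcher for a
    # hypothetical Author -> Book relation might return roughly:
    #
    #   (Book.query.filter(author__in=instances),  # related instances
    #    operator.attrgetter("author_id"),         # value matched on Book
    #    operator.attrgetter("id"),                # value matched on Author
    #    False,                                    # many related objects
    #    "books",                                  # cache name to assign to
    #    False)                                    # unused for many-related lookups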
2286
2287 (
2288 rel_qs,
2289 rel_obj_attr,
2290 instance_attr,
2291 single,
2292 cache_name,
2293 is_descriptor,
2294 ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
2295 # We have to handle the possibility that the QuerySet we just got back
2296 # contains some prefetch_related lookups. We don't want to trigger the
2297 # prefetch_related functionality by evaluating the query. Rather, we need
2298 # to merge in the prefetch_related lookups.
2299 # Copy the lookups in case it is a Prefetch object which could be reused
2300 # later (happens in nested prefetch_related).
2301 additional_lookups = [
2302 copy.copy(additional_lookup)
2303 for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
2304 ]
2305 if additional_lookups:
        # No need to clone: the queryset should have given us a fresh
        # instance, so we access an internal attribute instead of the public
        # interface, for performance reasons.
2309 rel_qs._prefetch_related_lookups = ()
2310
2311 all_related_objects = list(rel_qs)
2312
2313 rel_obj_cache = {}
2314 for rel_obj in all_related_objects:
2315 rel_attr_val = rel_obj_attr(rel_obj)
2316 rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
2317
2318 to_attr, as_attr = lookup.get_current_to_attr(level)
2319 # Make sure `to_attr` does not conflict with a field.
2320 if as_attr and instances:
        # We assume that objects retrieved are homogeneous (which is the premise
        # of prefetch_related), so what applies to the first object applies to
        # all of them.
2323 model = instances[0].__class__
2324 try:
2325 model._model_meta.get_field(to_attr)
2326 except FieldDoesNotExist:
2327 pass
2328 else:
2329 msg = "to_attr={} conflicts with a field on the {} model."
2330 raise ValueError(msg.format(to_attr, model.__name__))
2331
2332 # Whether or not we're prefetching the last part of the lookup.
2333 leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level
2334
2335 for obj in instances:
2336 instance_attr_val = instance_attr(obj)
2337 vals = rel_obj_cache.get(instance_attr_val, [])
2338
2339 if single:
2340 val = vals[0] if vals else None
2341 if as_attr:
2342 # A to_attr has been given for the prefetch.
2343 setattr(obj, to_attr, val)
2344 elif is_descriptor:
2345 # cache_name points to a field name in obj.
2346 # This field is a descriptor for a related object.
2347 setattr(obj, cache_name, val)
2348 else:
2349 # No to_attr has been given for this prefetch operation and the
2350 # cache_name does not point to a descriptor. Store the value of
2351 # the field in the object's field cache.
2352 obj._state.fields_cache[cache_name] = val # type: ignore[index]
2353 else:
2354 if as_attr:
2355 setattr(obj, to_attr, vals)
2356 else:
2357 queryset = getattr(obj, to_attr)
2358 if leaf and lookup.queryset is not None:
2359 qs = queryset._apply_rel_filters(lookup.queryset)
2360 else:
2361 # Check if queryset is a QuerySet or a related manager
2362 # We need a QuerySet instance to cache the prefetched values
2363 if isinstance(queryset, QuerySet):
2364 # It's already a QuerySet, create a new instance
2365 qs = queryset.__class__.from_model(queryset.model)
2366 else:
2367 # It's a related manager, get its QuerySet
2368 # The manager's query property returns a properly filtered QuerySet
2369 qs = queryset.query
2370 qs._result_cache = vals
2371 # We don't want the individual qs doing prefetch_related now,
2372 # since we have merged this into the current work.
2373 qs._prefetch_done = True
2374 obj._prefetched_objects_cache[cache_name] = qs
2375 return all_related_objects, additional_lookups
2376
2377
2378class RelatedPopulator:
2379 """
2380 RelatedPopulator is used for select_related() object instantiation.
2381
2382 The idea is that each select_related() model will be populated by a
2383 different RelatedPopulator instance. The RelatedPopulator instances get
    klass_info and select (computed in SQLCompiler) as input for
    initialization. That data is used to compute which columns
2386 to use, how to instantiate the model, and how to populate the links
2387 between the objects.
2388
    The actual creation of the objects is done in the populate() method. This
2390 method gets row and from_obj as input and populates the select_related()
2391 model instance.
2392 """
2393
2394 def __init__(self, klass_info: dict[str, Any], select: list[Any]):
2395 # Pre-compute needed attributes. The attributes are:
2396 # - model_cls: the possibly deferred model class to instantiate
2397 # - either:
2398 # - cols_start, cols_end: usually the columns in the row are
2399 # in the same order model_cls.__init__ expects them, so we
2400 # can instantiate by model_cls(*row[cols_start:cols_end])
2401 # - reorder_for_init: When select_related descends to a child
2402 # class, then we want to reuse the already selected parent
2403 # data. However, in this case the parent data isn't necessarily
2404 # in the same order that Model.__init__ expects it to be, so
2405 # we have to reorder the parent data. The reorder_for_init
2406 # attribute contains a function used to reorder the field data
2407 # in the order __init__ expects it.
2408 # - id_idx: the index of the primary key field in the reordered
2409 # model data. Used to check if a related object exists at all.
2410 # - init_list: the field attnames fetched from the database. For
2411 # deferred models this isn't the same as all attnames of the
2412 # model's fields.
2413 # - related_populators: a list of RelatedPopulator instances if
2414 # select_related() descends to related models from this model.
2415 # - local_setter, remote_setter: Methods to set cached values on
2416 # the object being populated and on the remote object. Usually
2417 # these are Field.set_cached_value() methods.
2418 select_fields = klass_info["select_fields"]
2419
2420 self.cols_start = select_fields[0]
2421 self.cols_end = select_fields[-1] + 1
2422 self.init_list = [
2423 f[0].target.attname for f in select[self.cols_start : self.cols_end]
2424 ]
2425 self.reorder_for_init = None
2426
2427 self.model_cls = klass_info["model"]
2428 self.id_idx = self.init_list.index("id")
2429 self.related_populators = get_related_populators(klass_info, select)
2430 self.local_setter = klass_info["local_setter"]
2431 self.remote_setter = klass_info["remote_setter"]
2432
2433 def populate(self, row: tuple[Any, ...], from_obj: Model) -> None:
2434 if self.reorder_for_init:
2435 obj_data = self.reorder_for_init(row)
2436 else:
2437 obj_data = row[self.cols_start : self.cols_end]
2438 if obj_data[self.id_idx] is None:
2439 obj = None
2440 else:
2441 obj = self.model_cls.from_db(self.init_list, obj_data)
2442 for rel_iter in self.related_populators:
2443 rel_iter.populate(row, obj)
2444 self.local_setter(from_obj, obj)
2445 if obj is not None:
2446 self.remote_setter(obj, from_obj)
2447
2448
2449def get_related_populators(
2450 klass_info: dict[str, Any], select: list[Any]
2451) -> list[RelatedPopulator]:
    populators = []
    related_klass_infos = klass_info.get("related_klass_infos", [])
    for rel_klass_info in related_klass_infos:
        populators.append(RelatedPopulator(rel_klass_info, select))
    return populators