1"""
2The main QuerySet implementation. This provides the public API for the ORM.
3"""
4
5import copy
6import operator
7import warnings
8from functools import cached_property
9from itertools import chain, islice
10
11import plain.runtime
12from plain import exceptions
13from plain.exceptions import ValidationError
14from plain.models import (
15 sql,
16 transaction,
17)
18from plain.models.constants import LOOKUP_SEP, OnConflict
19from plain.models.db import (
20 PLAIN_VERSION_PICKLE_KEY,
21 IntegrityError,
22 NotSupportedError,
23 db_connection,
24)
25from plain.models.expressions import Case, F, Value, When
26from plain.models.fields import (
27 AutoField,
28 DateField,
29 DateTimeField,
30 Field,
31)
32from plain.models.functions import Cast, Trunc
33from plain.models.query_utils import FilteredRelation, Q
34from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
35from plain.models.utils import (
36 create_namedtuple_class,
37 resolve_callables,
38)
39from plain.utils import timezone
40from plain.utils.functional import partition
41
# Number of results QuerySet.__repr__ displays before truncating.
REPR_OUTPUT_SIZE = 20

# Row cap applied by get(). It is one more than the count get() reports in its
# MultipleObjectsReturned message ("more than <limit - 1>"), so get() can tell
# "too many" apart from an exact count without fetching every row.
MAX_GET_RESULTS = 21
47
48
class BaseIterable:
    """Shared state for the row iterables: the source queryset and fetch options."""

    def __init__(
        self, queryset, chunked_fetch=False, chunk_size=GET_ITERATOR_CHUNK_SIZE
    ):
        # chunked_fetch / chunk_size are forwarded to the SQL compiler by most
        # subclasses' __iter__ implementations.
        self.queryset, self.chunked_fetch, self.chunk_size = (
            queryset,
            chunked_fetch,
            chunk_size,
        )
56
57
class ModelIterable(BaseIterable):
    """Iterable that yields a model instance for each row."""

    def __iter__(self):
        """
        Execute the queryset's SQL and yield one model instance per row.

        For each row this builds the main model via ``model_cls.from_db()``,
        lets each related populator (from ``get_related_populators()``) fill in
        related objects, sets annotation values as attributes, and finally
        attaches any objects registered in ``queryset._known_related_objects``.
        """
        queryset = self.queryset
        compiler = queryset.query.get_compiler()
        # Execute the query. This will also fill compiler.select, klass_info,
        # and annotations.
        results = compiler.execute_sql(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        select, klass_info, annotation_col_map = (
            compiler.select,
            compiler.klass_info,
            compiler.annotation_col_map,
        )
        model_cls = klass_info["model"]
        # The main model's columns occupy the contiguous slice
        # [first select_fields index, last select_fields index + 1) of each row.
        select_fields = klass_info["select_fields"]
        model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1
        # Attribute names passed to from_db(), in the same order as the slice.
        init_list = [
            f[0].target.attname for f in select[model_fields_start:model_fields_end]
        ]
        related_populators = get_related_populators(klass_info, select)
        # For each known relation: (field, {key: related obj}, getter that reads
        # the local key value(s) off an instance). A from_field of "self" means
        # the relation field's own attname is used as the key.
        known_related_objects = [
            (
                field,
                related_objs,
                operator.attrgetter(
                    *[
                        field.attname
                        if from_field == "self"
                        else queryset.model._meta.get_field(from_field).attname
                        for from_field in field.from_fields
                    ]
                ),
            )
            for field, related_objs in queryset._known_related_objects.items()
        ]
        for row in compiler.results_iter(results):
            obj = model_cls.from_db(init_list, row[model_fields_start:model_fields_end])
            for rel_populator in related_populators:
                rel_populator.populate(row, obj)
            if annotation_col_map:
                # Annotations live at fixed column positions outside the model slice.
                for attr_name, col_pos in annotation_col_map.items():
                    setattr(obj, attr_name, row[col_pos])

            # Add the known related objects to the model.
            for field, rel_objs, rel_getter in known_related_objects:
                # Avoid overwriting objects loaded by, e.g., select_related().
                if field.is_cached(obj):
                    continue
                rel_obj_id = rel_getter(obj)
                try:
                    rel_obj = rel_objs[rel_obj_id]
                except KeyError:
                    pass  # May happen in qs1 | qs2 scenarios.
                else:
                    setattr(obj, field.name, rel_obj)

            yield obj
118
119
class RawModelIterable(BaseIterable):
    """
    Iterable that yields one model instance per row of a raw queryset.
    """

    def __iter__(self):
        # Resolve everything that is loop-invariant before touching rows.
        raw_qs = self.queryset
        query = raw_qs.query
        compiler = db_connection.ops.compiler("SQLCompiler")(query, db_connection)
        rows = iter(query)

        try:
            init_names, init_positions, annotation_fields = (
                raw_qs.resolve_model_init_order()
            )
            model_cls = raw_qs.model
            if model_cls._meta.pk.attname not in init_names:
                raise exceptions.FieldDoesNotExist(
                    "Raw query must include the primary key"
                )
            # Pair every result column with its model field (None when the
            # column isn't a model field) so the compiler can look up converters.
            column_fields = [raw_qs.model_fields.get(c) for c in raw_qs.columns]
            converters = compiler.get_converters(
                [
                    f.get_col(f.model._meta.db_table) if f else None
                    for f in column_fields
                ]
            )
            if converters:
                rows = compiler.apply_converters(rows, converters)
            for values in rows:
                instance = model_cls.from_db(
                    init_names, [values[pos] for pos in init_positions]
                )
                for column, pos in annotation_fields or ():
                    setattr(instance, column, values[pos])
                yield instance
        finally:
            # Iteration finished or was abandoned: close the raw query's cursor.
            if hasattr(query, "cursor") and query.cursor:
                query.cursor.close()
160
161
class ValuesIterable(BaseIterable):
    """
    Iterable behind QuerySet.values(): each row is yielded as a dict keyed by
    column name.
    """

    def __iter__(self):
        query = self.queryset.query
        compiler = query.get_compiler()

        # Row layout: extra(select=...) columns first, then the values()
        # fields, then annotations.
        column_names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]
        column_count = len(column_names)
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        for row in rows:
            yield {column_names[pos]: row[pos] for pos in range(column_count)}
183
184
class ValuesListIterable(BaseIterable):
    """
    Iterable behind QuerySet.values_list(flat=False): each row is yielded as
    a tuple.
    """

    def __iter__(self):
        queryset = self.queryset
        query = queryset.query
        compiler = query.get_compiler()
        fetch_options = {
            "chunked_fetch": self.chunked_fetch,
            "chunk_size": self.chunk_size,
        }

        requested = queryset._fields
        if requested:
            # The SQL row always puts extra(select=...) columns first.
            row_layout = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
            desired_order = [
                *requested,
                *(name for name in query.annotation_select if name not in requested),
            ]
            if desired_order != row_layout:
                # Rearrange each row to follow the caller's field order.
                position_of = {name: pos for pos, name in enumerate(row_layout)}
                reorder = operator.itemgetter(
                    *(position_of[name] for name in desired_order)
                )
                return map(reorder, compiler.results_iter(**fetch_options))
        return compiler.results_iter(tuple_expected=True, **fetch_options)
222
223
class NamedValuesListIterable(ValuesListIterable):
    """
    Iterable behind QuerySet.values_list(named=True): each row is yielded as a
    namedtuple named after queryset._fields (or the full row layout when no
    fields were given).
    """

    def __iter__(self):
        queryset = self.queryset
        field_names = queryset._fields
        if not field_names:
            # No explicit field list: name the tuple after the row layout.
            query = queryset.query
            field_names = [
                *query.extra_select,
                *query.values_select,
                *query.annotation_select,
            ]
        row_class = create_namedtuple_class(*field_names)
        # tuple.__new__ wraps each already-ordered row without going through
        # the namedtuple class's own constructor.
        make_row = tuple.__new__
        for values in super().__iter__():
            yield make_row(row_class, values)
245
246
class FlatValuesListIterable(BaseIterable):
    """
    Iterable behind QuerySet.values_list(flat=True): yields the first column
    of every row as a bare value.
    """

    def __iter__(self):
        compiler = self.queryset.query.get_compiler()
        rows = compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        )
        yield from (row[0] for row in rows)
260
261
262class QuerySet:
263 """Represent a lazy database lookup for a set of objects."""
264
265 def __init__(self, model=None, query=None, hints=None):
266 self.model = model
267 self._hints = hints or {}
268 self._query = query or sql.Query(self.model)
269 self._result_cache = None
270 self._sticky_filter = False
271 self._for_write = False
272 self._prefetch_related_lookups = ()
273 self._prefetch_done = False
274 self._known_related_objects = {} # {rel_field: {pk: rel_obj}}
275 self._iterable_class = ModelIterable
276 self._fields = None
277 self._defer_next_filter = False
278 self._deferred_filter = None
279
280 @property
281 def query(self):
282 if self._deferred_filter:
283 negate, args, kwargs = self._deferred_filter
284 self._filter_or_exclude_inplace(negate, args, kwargs)
285 self._deferred_filter = None
286 return self._query
287
288 @query.setter
289 def query(self, value):
290 if value.values_select:
291 self._iterable_class = ValuesIterable
292 self._query = value
293
294 def as_manager(cls):
295 # Address the circular dependency between `Queryset` and `Manager`.
296 from plain.models.manager import Manager
297
298 manager = Manager.from_queryset(cls)()
299 manager._built_with_as_manager = True
300 return manager
301
302 as_manager.queryset_only = True
303 as_manager = classmethod(as_manager)
304
305 ########################
306 # PYTHON MAGIC METHODS #
307 ########################
308
309 def __deepcopy__(self, memo):
310 """Don't populate the QuerySet's cache."""
311 obj = self.__class__()
312 for k, v in self.__dict__.items():
313 if k == "_result_cache":
314 obj.__dict__[k] = None
315 else:
316 obj.__dict__[k] = copy.deepcopy(v, memo)
317 return obj
318
319 def __getstate__(self):
320 # Force the cache to be fully populated.
321 self._fetch_all()
322 return {**self.__dict__, PLAIN_VERSION_PICKLE_KEY: plain.runtime.__version__}
323
324 def __setstate__(self, state):
325 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
326 if pickled_version:
327 if pickled_version != plain.runtime.__version__:
328 warnings.warn(
329 f"Pickled queryset instance's Plain version {pickled_version} does not "
330 f"match the current version {plain.runtime.__version__}.",
331 RuntimeWarning,
332 stacklevel=2,
333 )
334 else:
335 warnings.warn(
336 "Pickled queryset instance's Plain version is not specified.",
337 RuntimeWarning,
338 stacklevel=2,
339 )
340 self.__dict__.update(state)
341
342 def __repr__(self):
343 data = list(self[: REPR_OUTPUT_SIZE + 1])
344 if len(data) > REPR_OUTPUT_SIZE:
345 data[-1] = "...(remaining elements truncated)..."
346 return f"<{self.__class__.__name__} {data!r}>"
347
348 def __len__(self):
349 self._fetch_all()
350 return len(self._result_cache)
351
352 def __iter__(self):
353 """
354 The queryset iterator protocol uses three nested iterators in the
355 default case:
356 1. sql.compiler.execute_sql()
357 - Returns 100 rows at time (constants.GET_ITERATOR_CHUNK_SIZE)
358 using cursor.fetchmany(). This part is responsible for
359 doing some column masking, and returning the rows in chunks.
360 2. sql.compiler.results_iter()
361 - Returns one row at time. At this point the rows are still just
362 tuples. In some cases the return values are converted to
363 Python values at this location.
364 3. self.iterator()
365 - Responsible for turning the rows into model objects.
366 """
367 self._fetch_all()
368 return iter(self._result_cache)
369
370 def __bool__(self):
371 self._fetch_all()
372 return bool(self._result_cache)
373
374 def __getitem__(self, k):
375 """Retrieve an item or slice from the set of results."""
376 if not isinstance(k, int | slice):
377 raise TypeError(
378 f"QuerySet indices must be integers or slices, not {type(k).__name__}."
379 )
380 if (isinstance(k, int) and k < 0) or (
381 isinstance(k, slice)
382 and (
383 (k.start is not None and k.start < 0)
384 or (k.stop is not None and k.stop < 0)
385 )
386 ):
387 raise ValueError("Negative indexing is not supported.")
388
389 if self._result_cache is not None:
390 return self._result_cache[k]
391
392 if isinstance(k, slice):
393 qs = self._chain()
394 if k.start is not None:
395 start = int(k.start)
396 else:
397 start = None
398 if k.stop is not None:
399 stop = int(k.stop)
400 else:
401 stop = None
402 qs.query.set_limits(start, stop)
403 return list(qs)[:: k.step] if k.step else qs
404
405 qs = self._chain()
406 qs.query.set_limits(k, k + 1)
407 qs._fetch_all()
408 return qs._result_cache[0]
409
410 def __class_getitem__(cls, *args, **kwargs):
411 return cls
412
413 def __and__(self, other):
414 self._check_operator_queryset(other, "&")
415 self._merge_sanity_check(other)
416 if isinstance(other, EmptyQuerySet):
417 return other
418 if isinstance(self, EmptyQuerySet):
419 return self
420 combined = self._chain()
421 combined._merge_known_related_objects(other)
422 combined.query.combine(other.query, sql.AND)
423 return combined
424
425 def __or__(self, other):
426 self._check_operator_queryset(other, "|")
427 self._merge_sanity_check(other)
428 if isinstance(self, EmptyQuerySet):
429 return other
430 if isinstance(other, EmptyQuerySet):
431 return self
432 query = (
433 self
434 if self.query.can_filter()
435 else self.model._base_manager.filter(pk__in=self.values("pk"))
436 )
437 combined = query._chain()
438 combined._merge_known_related_objects(other)
439 if not other.query.can_filter():
440 other = other.model._base_manager.filter(pk__in=other.values("pk"))
441 combined.query.combine(other.query, sql.OR)
442 return combined
443
444 def __xor__(self, other):
445 self._check_operator_queryset(other, "^")
446 self._merge_sanity_check(other)
447 if isinstance(self, EmptyQuerySet):
448 return other
449 if isinstance(other, EmptyQuerySet):
450 return self
451 query = (
452 self
453 if self.query.can_filter()
454 else self.model._base_manager.filter(pk__in=self.values("pk"))
455 )
456 combined = query._chain()
457 combined._merge_known_related_objects(other)
458 if not other.query.can_filter():
459 other = other.model._base_manager.filter(pk__in=other.values("pk"))
460 combined.query.combine(other.query, sql.XOR)
461 return combined
462
463 ####################################
464 # METHODS THAT DO DATABASE QUERIES #
465 ####################################
466
467 def _iterator(self, use_chunked_fetch, chunk_size):
468 iterable = self._iterable_class(
469 self,
470 chunked_fetch=use_chunked_fetch,
471 chunk_size=chunk_size or 2000,
472 )
473 if not self._prefetch_related_lookups or chunk_size is None:
474 yield from iterable
475 return
476
477 iterator = iter(iterable)
478 while results := list(islice(iterator, chunk_size)):
479 prefetch_related_objects(results, *self._prefetch_related_lookups)
480 yield from results
481
482 def iterator(self, chunk_size=None):
483 """
484 An iterator over the results from applying this QuerySet to the
485 database. chunk_size must be provided for QuerySets that prefetch
486 related objects. Otherwise, a default chunk_size of 2000 is supplied.
487 """
488 if chunk_size is None:
489 if self._prefetch_related_lookups:
490 raise ValueError(
491 "chunk_size must be provided when using QuerySet.iterator() after "
492 "prefetch_related()."
493 )
494 elif chunk_size <= 0:
495 raise ValueError("Chunk size must be strictly positive.")
496 use_chunked_fetch = not db_connection.settings_dict.get(
497 "DISABLE_SERVER_SIDE_CURSORS"
498 )
499 return self._iterator(use_chunked_fetch, chunk_size)
500
501 def aggregate(self, *args, **kwargs):
502 """
503 Return a dictionary containing the calculations (aggregation)
504 over the current queryset.
505
506 If args is present the expression is passed as a kwarg using
507 the Aggregate object's default alias.
508 """
509 if self.query.distinct_fields:
510 raise NotImplementedError("aggregate() + distinct(fields) not implemented.")
511 self._validate_values_are_expressions(
512 (*args, *kwargs.values()), method_name="aggregate"
513 )
514 for arg in args:
515 # The default_alias property raises TypeError if default_alias
516 # can't be set automatically or AttributeError if it isn't an
517 # attribute.
518 try:
519 arg.default_alias
520 except (AttributeError, TypeError):
521 raise TypeError("Complex aggregates require an alias")
522 kwargs[arg.default_alias] = arg
523
524 return self.query.chain().get_aggregation(kwargs)
525
526 def count(self):
527 """
528 Perform a SELECT COUNT() and return the number of records as an
529 integer.
530
531 If the QuerySet is already fully cached, return the length of the
532 cached results set to avoid multiple SELECT COUNT(*) calls.
533 """
534 if self._result_cache is not None:
535 return len(self._result_cache)
536
537 return self.query.get_count()
538
539 def get(self, *args, **kwargs):
540 """
541 Perform the query and return a single object matching the given
542 keyword arguments.
543 """
544 if self.query.combinator and (args or kwargs):
545 raise NotSupportedError(
546 f"Calling QuerySet.get(...) with filters after {self.query.combinator}() is not "
547 "supported."
548 )
549 clone = self._chain() if self.query.combinator else self.filter(*args, **kwargs)
550 if self.query.can_filter() and not self.query.distinct_fields:
551 clone = clone.order_by()
552 limit = None
553 if (
554 not clone.query.select_for_update
555 or db_connection.features.supports_select_for_update_with_limit
556 ):
557 limit = MAX_GET_RESULTS
558 clone.query.set_limits(high=limit)
559 num = len(clone)
560 if num == 1:
561 return clone._result_cache[0]
562 if not num:
563 raise self.model.DoesNotExist(
564 f"{self.model._meta.object_name} matching query does not exist."
565 )
566 raise self.model.MultipleObjectsReturned(
567 "get() returned more than one {} -- it returned {}!".format(
568 self.model._meta.object_name,
569 num if not limit or num < limit else "more than %s" % (limit - 1),
570 )
571 )
572
573 def create(self, **kwargs):
574 """
575 Create a new object with the given kwargs, saving it to the database
576 and returning the created object.
577 """
578 obj = self.model(**kwargs)
579 self._for_write = True
580 obj.save(force_insert=True)
581 return obj
582
583 def _prepare_for_bulk_create(self, objs):
584 for obj in objs:
585 if obj.pk is None:
586 # Populate new PK values.
587 obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
588 obj._prepare_related_fields_for_save(operation_name="bulk_create")
589
590 def _check_bulk_create_options(
591 self, update_conflicts, update_fields, unique_fields
592 ):
593 db_features = db_connection.features
594 if update_conflicts:
595 if not db_features.supports_update_conflicts:
596 raise NotSupportedError(
597 "This database backend does not support updating conflicts."
598 )
599 if not update_fields:
600 raise ValueError(
601 "Fields that will be updated when a row insertion fails "
602 "on conflicts must be provided."
603 )
604 if unique_fields and not db_features.supports_update_conflicts_with_target:
605 raise NotSupportedError(
606 "This database backend does not support updating "
607 "conflicts with specifying unique fields that can trigger "
608 "the upsert."
609 )
610 if not unique_fields and db_features.supports_update_conflicts_with_target:
611 raise ValueError(
612 "Unique fields that can trigger the upsert must be provided."
613 )
614 # Updating primary keys and non-concrete fields is forbidden.
615 if any(not f.concrete or f.many_to_many for f in update_fields):
616 raise ValueError(
617 "bulk_create() can only be used with concrete fields in "
618 "update_fields."
619 )
620 if any(f.primary_key for f in update_fields):
621 raise ValueError(
622 "bulk_create() cannot be used with primary keys in update_fields."
623 )
624 if unique_fields:
625 if any(not f.concrete or f.many_to_many for f in unique_fields):
626 raise ValueError(
627 "bulk_create() can only be used with concrete fields "
628 "in unique_fields."
629 )
630 return OnConflict.UPDATE
631 return None
632
    def bulk_create(
        self,
        objs,
        batch_size=None,
        update_conflicts=False,
        update_fields=None,
        unique_fields=None,
    ):
        """
        Insert each of the instances into the database in batches.

        save() is *not* called on the instances. Objects without a primary key
        are inserted without their AutoField columns. Values the backend returns
        (``opts.db_returning_fields``) are written back onto the objects only
        for rows the backend actually returned. That is guaranteed (asserted)
        when ``features.can_return_rows_from_bulk_insert`` is true and no
        conflict handling is requested. Multi-table models are not supported.

        :param objs: model instances (any iterable); returned as a list, or
            returned unchanged when falsy.
        :param batch_size: rows per INSERT; must be positive when given.
        :param update_conflicts: upsert on conflict (validated by
            _check_bulk_create_options()).
        :param update_fields: field names to update on conflict.
        :param unique_fields: field names ("pk" allowed) that trigger the upsert.
        :raises ValueError: for a non-positive batch_size or invalid options.
        """
        # Objects are inserted in two groups: those that already have a pk
        # (after _prepare_for_bulk_create()) and those that don't. For the
        # second group AutoField columns are left out of the INSERT, and the
        # database-generated values can only be recovered when the backend
        # returns rows from the bulk insert. Without that, the new primary
        # keys are not known, so child rows that reference them can't be
        # inserted in the same operation.
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")

        if not objs:
            return objs
        opts = self.model._meta
        if unique_fields:
            # Primary key is allowed in unique_fields.
            unique_fields = [
                self.model._meta.get_field(opts.pk.name if name == "pk" else name)
                for name in unique_fields
            ]
        if update_fields:
            update_fields = [self.model._meta.get_field(name) for name in update_fields]
        on_conflict = self._check_bulk_create_options(
            update_conflicts,
            update_fields,
            unique_fields,
        )
        self._for_write = True
        fields = opts.concrete_fields
        objs = list(objs)
        self._prepare_for_bulk_create(objs)
        with transaction.atomic(savepoint=False):
            # NOTE(review): assumes partition() returns (items where the
            # predicate is False, items where it is True), matching these
            # names — confirm against plain.utils.functional.partition.
            objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
            if objs_with_pk:
                returned_columns = self._batched_insert(
                    objs_with_pk,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields,
                    unique_fields=unique_fields,
                )
                # The pk was supplied by the caller, so never overwrite it.
                for obj_with_pk, results in zip(objs_with_pk, returned_columns):
                    for result, field in zip(results, opts.db_returning_fields):
                        if field != opts.pk:
                            setattr(obj_with_pk, field.attname, result)
                for obj_with_pk in objs_with_pk:
                    obj_with_pk._state.adding = False
            if objs_without_pk:
                fields = [f for f in fields if not isinstance(f, AutoField)]
                returned_columns = self._batched_insert(
                    objs_without_pk,
                    fields,
                    batch_size,
                    on_conflict=on_conflict,
                    update_fields=update_fields,
                    unique_fields=unique_fields,
                )
                if (
                    db_connection.features.can_return_rows_from_bulk_insert
                    and on_conflict is None
                ):
                    assert len(returned_columns) == len(objs_without_pk)
                # Only objects paired with a returned row are visited: if the
                # backend returned nothing, zip() is empty and these objects
                # keep their _state.adding value.
                for obj_without_pk, results in zip(objs_without_pk, returned_columns):
                    for result, field in zip(results, opts.db_returning_fields):
                        setattr(obj_without_pk, field.attname, result)
                    obj_without_pk._state.adding = False

        return objs
720
    def bulk_update(self, objs, fields, batch_size=None):
        """
        Update the given fields in each of the given objects in the database.

        Each batch becomes one UPDATE ... WHERE pk IN (...). Every field is set
        with a CASE expression holding one WHEN pk=<obj.pk> THEN <value> per
        object. All batches run inside one transaction.atomic(savepoint=False)
        block.

        :param objs: model instances (any iterable); all must have a pk.
        :param fields: names of concrete, non-primary-key fields to update.
        :param batch_size: maximum objects per UPDATE; clamped to the backend's
            bulk_batch_size() limit.
        :returns: the sum of the update() results across batches (0 for no objs).
        :raises ValueError: for a non-positive batch_size, no fields, an object
            without a pk, or a non-concrete / many-to-many / primary-key field.
        """
        if batch_size is not None and batch_size <= 0:
            raise ValueError("Batch size must be a positive integer.")
        if not fields:
            raise ValueError("Field names must be given to bulk_update().")
        objs = tuple(objs)
        if any(obj.pk is None for obj in objs):
            raise ValueError("All bulk_update() objects must have a primary key set.")
        fields = [self.model._meta.get_field(name) for name in fields]
        if any(not f.concrete or f.many_to_many for f in fields):
            raise ValueError("bulk_update() can only be used with concrete fields.")
        if any(f.primary_key for f in fields):
            raise ValueError("bulk_update() cannot be used with primary key fields.")
        if not objs:
            return 0
        for obj in objs:
            obj._prepare_related_fields_for_save(
                operation_name="bulk_update", fields=fields
            )
        # PK is used twice in the resulting update query, once in the filter
        # and once in the WHEN. Each field will also have one CAST.
        self._for_write = True
        max_batch_size = db_connection.ops.bulk_batch_size(["pk", "pk"] + fields, objs)
        batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
        requires_casting = db_connection.features.requires_casted_case_in_updates
        batches = (objs[i : i + batch_size] for i in range(0, len(objs), batch_size))
        updates = []
        for batch_objs in batches:
            update_kwargs = {}
            for field in fields:
                when_statements = []
                for obj in batch_objs:
                    attr = getattr(obj, field.attname)
                    # Plain Python values are wrapped in Value(); expressions
                    # (anything with resolve_expression) are used as-is.
                    if not hasattr(attr, "resolve_expression"):
                        attr = Value(attr, output_field=field)
                    when_statements.append(When(pk=obj.pk, then=attr))
                case_statement = Case(*when_statements, output_field=field)
                if requires_casting:
                    case_statement = Cast(case_statement, output_field=field)
                update_kwargs[field.attname] = case_statement
            updates.append(([obj.pk for obj in batch_objs], update_kwargs))
        rows_updated = 0
        queryset = self._chain()
        with transaction.atomic(savepoint=False):
            for pks, update_kwargs in updates:
                rows_updated += queryset.filter(pk__in=pks).update(**update_kwargs)
        return rows_updated
771
    def get_or_create(self, defaults=None, **kwargs):
        """
        Look up an object with the given kwargs, creating one if necessary.
        Return a tuple of (object, created), where created is a boolean
        specifying whether an object was created.

        :param defaults: extra values used only when creating; callables are
            resolved via resolve_callables() inside the creation transaction.
        :param kwargs: lookup filters. Entries without the lookup separator
            also become creation values (see _extract_model_params()).
        :raises IntegrityError | ValidationError: if creation fails and a
            second get() still finds nothing.
        """
        # The get() needs to be targeted at the write database in order
        # to avoid potential transaction consistency problems.
        self._for_write = True
        try:
            return self.get(**kwargs), False
        except self.model.DoesNotExist:
            params = self._extract_model_params(defaults, **kwargs)
            # Try to create an object using passed params.
            try:
                with transaction.atomic():
                    params = dict(resolve_callables(params))
                    return self.create(**params), True
            except (IntegrityError, ValidationError):
                # Since create() also validates by default,
                # we can get any kind of ValidationError here,
                # or it can flow through and get an IntegrityError from the database.
                # The main thing we're concerned about is uniqueness failures,
                # but ValidationError could include other things too.
                # In all cases though it should be fine to try the get() again
                # and return an existing object.
                try:
                    return self.get(**kwargs), False
                except self.model.DoesNotExist:
                    pass
                # Once the inner DoesNotExist handler has finished, this bare
                # raise re-raises the IntegrityError/ValidationError caught above.
                raise
803
804 def update_or_create(self, defaults=None, create_defaults=None, **kwargs):
805 """
806 Look up an object with the given kwargs, updating one with defaults
807 if it exists, otherwise create a new one. Optionally, an object can
808 be created with different values than defaults by using
809 create_defaults.
810 Return a tuple (object, created), where created is a boolean
811 specifying whether an object was created.
812 """
813 if create_defaults is None:
814 update_defaults = create_defaults = defaults or {}
815 else:
816 update_defaults = defaults or {}
817 self._for_write = True
818 with transaction.atomic():
819 # Lock the row so that a concurrent update is blocked until
820 # update_or_create() has performed its save.
821 obj, created = self.select_for_update().get_or_create(
822 create_defaults, **kwargs
823 )
824 if created:
825 return obj, created
826 for k, v in resolve_callables(update_defaults):
827 setattr(obj, k, v)
828
829 update_fields = set(update_defaults)
830 concrete_field_names = self.model._meta._non_pk_concrete_field_names
831 # update_fields does not support non-concrete fields.
832 if concrete_field_names.issuperset(update_fields):
833 # Add fields which are set on pre_save(), e.g. auto_now fields.
834 # This is to maintain backward compatibility as these fields
835 # are not updated unless explicitly specified in the
836 # update_fields list.
837 for field in self.model._meta.local_concrete_fields:
838 if not (
839 field.primary_key or field.__class__.pre_save is Field.pre_save
840 ):
841 update_fields.add(field.name)
842 if field.name != field.attname:
843 update_fields.add(field.attname)
844 obj.save(update_fields=update_fields)
845 else:
846 obj.save()
847 return obj, False
848
849 def _extract_model_params(self, defaults, **kwargs):
850 """
851 Prepare `params` for creating a model instance based on the given
852 kwargs; for use by get_or_create().
853 """
854 defaults = defaults or {}
855 params = {k: v for k, v in kwargs.items() if LOOKUP_SEP not in k}
856 params.update(defaults)
857 property_names = self.model._meta._property_names
858 invalid_params = []
859 for param in params:
860 try:
861 self.model._meta.get_field(param)
862 except exceptions.FieldDoesNotExist:
863 # It's okay to use a model's property if it has a setter.
864 if not (param in property_names and getattr(self.model, param).fset):
865 invalid_params.append(param)
866 if invalid_params:
867 raise exceptions.FieldError(
868 "Invalid field name(s) for model {}: '{}'.".format(
869 self.model._meta.object_name,
870 "', '".join(sorted(invalid_params)),
871 )
872 )
873 return params
874
875 def _earliest(self, *fields):
876 """
877 Return the earliest object according to fields (if given) or by the
878 model's Meta.get_latest_by.
879 """
880 if fields:
881 order_by = fields
882 else:
883 order_by = getattr(self.model._meta, "get_latest_by")
884 if order_by and not isinstance(order_by, tuple | list):
885 order_by = (order_by,)
886 if order_by is None:
887 raise ValueError(
888 "earliest() and latest() require either fields as positional "
889 "arguments or 'get_latest_by' in the model's Meta."
890 )
891 obj = self._chain()
892 obj.query.set_limits(high=1)
893 obj.query.clear_ordering(force=True)
894 obj.query.add_ordering(*order_by)
895 return obj.get()
896
897 def earliest(self, *fields):
898 if self.query.is_sliced:
899 raise TypeError("Cannot change a query once a slice has been taken.")
900 return self._earliest(*fields)
901
902 def latest(self, *fields):
903 """
904 Return the latest object according to fields (if given) or by the
905 model's Meta.get_latest_by.
906 """
907 if self.query.is_sliced:
908 raise TypeError("Cannot change a query once a slice has been taken.")
909 return self.reverse()._earliest(*fields)
910
911 def first(self):
912 """Return the first object of a query or None if no match is found."""
913 if self.ordered:
914 queryset = self
915 else:
916 self._check_ordering_first_last_queryset_aggregation(method="first")
917 queryset = self.order_by("pk")
918 for obj in queryset[:1]:
919 return obj
920
921 def last(self):
922 """Return the last object of a query or None if no match is found."""
923 if self.ordered:
924 queryset = self.reverse()
925 else:
926 self._check_ordering_first_last_queryset_aggregation(method="last")
927 queryset = self.order_by("-pk")
928 for obj in queryset[:1]:
929 return obj
930
931 def in_bulk(self, id_list=None, *, field_name="pk"):
932 """
933 Return a dictionary mapping each of the given IDs to the object with
934 that ID. If `id_list` isn't provided, evaluate the entire QuerySet.
935 """
936 if self.query.is_sliced:
937 raise TypeError("Cannot use 'limit' or 'offset' with in_bulk().")
938 opts = self.model._meta
939 unique_fields = [
940 constraint.fields[0]
941 for constraint in opts.total_unique_constraints
942 if len(constraint.fields) == 1
943 ]
944 if (
945 field_name != "pk"
946 and not opts.get_field(field_name).primary_key
947 and field_name not in unique_fields
948 and self.query.distinct_fields != (field_name,)
949 ):
950 raise ValueError(
951 f"in_bulk()'s field_name must be a unique field but {field_name!r} isn't."
952 )
953 if id_list is not None:
954 if not id_list:
955 return {}
956 filter_key = f"{field_name}__in"
957 batch_size = db_connection.features.max_query_params
958 id_list = tuple(id_list)
959 # If the database has a limit on the number of query parameters
960 # (e.g. SQLite), retrieve objects in batches if necessary.
961 if batch_size and batch_size < len(id_list):
962 qs = ()
963 for offset in range(0, len(id_list), batch_size):
964 batch = id_list[offset : offset + batch_size]
965 qs += tuple(self.filter(**{filter_key: batch}))
966 else:
967 qs = self.filter(**{filter_key: id_list})
968 else:
969 qs = self._chain()
970 return {getattr(obj, field_name): obj for obj in qs}
971
    def delete(self):
        """
        Delete the records in the current QuerySet.

        Related-object discovery and the deletion itself are delegated to
        ``plain.models.deletion.Collector``. Returns the pair produced by
        ``Collector.delete()`` (presumably the total number of deleted rows
        and a per-model breakdown -- TODO confirm against Collector).

        Raises NotSupportedError after union()/intersection()/difference(),
        and TypeError if the queryset is sliced, uses distinct(), or was
        produced by values()/values_list().
        """
        self._not_support_combined_queries("delete")
        if self.query.is_sliced:
            raise TypeError("Cannot use 'limit' or 'offset' with delete().")
        if self.query.distinct or self.query.distinct_fields:
            raise TypeError("Cannot call delete() after .distinct().")
        if self._fields is not None:
            raise TypeError("Cannot call delete() after .values() or .values_list()")

        del_query = self._chain()

        # The delete is actually 2 queries - one to find related objects,
        # and one to delete. Make sure that the discovery of related
        # objects is performed on the same database as the deletion.
        del_query._for_write = True

        # Disable non-supported fields.
        # Row locking, select_related joins and ordering are switched off on
        # the cloned query before it is handed to the collector.
        del_query.query.select_for_update = False
        del_query.query.select_related = False
        del_query.query.clear_ordering(force=True)

        # Local import; NOTE(review): presumably avoids an import cycle with
        # the deletion module -- confirm.
        from plain.models.deletion import Collector

        collector = Collector(origin=self)
        collector.collect(del_query)
        deleted, _rows_count = collector.delete()

        # Clear the result cache, in case this QuerySet gets reused.
        self._result_cache = None
        return deleted, _rows_count

    # NOTE(review): presumably a marker read elsewhere (e.g. when manager
    # methods are generated) meaning "QuerySet-only method" -- confirm.
    delete.queryset_only = True
1005
1006 def _raw_delete(self):
1007 """
1008 Delete objects found from the given queryset in single direct SQL
1009 query. No signals are sent and there is no protection for cascades.
1010 """
1011 query = self.query.clone()
1012 query.__class__ = sql.DeleteQuery
1013 cursor = query.get_compiler().execute_sql(CURSOR)
1014 if cursor:
1015 with cursor:
1016 return cursor.rowcount
1017 return 0
1018
    def update(self, **kwargs):
        """
        Update all elements in the current QuerySet, setting all the given
        fields to the appropriate values.

        ``kwargs`` map field names to new values. Annotations used in
        order_by() are inlined into the UPDATE's ordering; all annotations
        are then dropped so they don't appear in subqueries.

        Returns the result of the UPDATE compiler's ``execute_sql(CURSOR)``
        (presumably the number of affected rows -- TODO confirm).

        Raises NotSupportedError after union()/intersection()/difference(),
        TypeError on a sliced queryset, and FieldError when ordering by an
        aggregate annotation.
        """
        self._not_support_combined_queries("update")
        if self.query.is_sliced:
            raise TypeError("Cannot update a query once a slice has been taken.")
        # NOTE(review): marks this queryset as a write operation; how the flag
        # is consumed is not visible here.
        self._for_write = True
        query = self.query.chain(sql.UpdateQuery)
        query.add_update_values(kwargs)

        # Inline annotations in order_by(), if possible.
        new_order_by = []
        for col in query.order_by:
            alias = col
            descending = False
            # A leading "-" on a string entry means descending order.
            if isinstance(alias, str) and alias.startswith("-"):
                alias = alias.removeprefix("-")
                descending = True
            if annotation := query.annotations.get(alias):
                if getattr(annotation, "contains_aggregate", False):
                    raise exceptions.FieldError(
                        f"Cannot update when ordering by an aggregate: {annotation}"
                    )
                if descending:
                    annotation = annotation.desc()
                new_order_by.append(annotation)
            else:
                new_order_by.append(col)
        query.order_by = tuple(new_order_by)

        # Clear any annotations so that they won't be present in subqueries.
        query.annotations = {}
        # NOTE(review): per its name, the context manager flags an enclosing
        # transaction for rollback if the UPDATE raises -- confirm.
        with transaction.mark_for_rollback_on_error():
            rows = query.get_compiler().execute_sql(CURSOR)
        self._result_cache = None
        return rows
1057
1058 def _update(self, values):
1059 """
1060 A version of update() that accepts field objects instead of field names.
1061 Used primarily for model saving and not intended for use by general
1062 code (it requires too much poking around at model internals to be
1063 useful at that level).
1064 """
1065 if self.query.is_sliced:
1066 raise TypeError("Cannot update a query once a slice has been taken.")
1067 query = self.query.chain(sql.UpdateQuery)
1068 query.add_update_fields(values)
1069 # Clear any annotations so that they won't be present in subqueries.
1070 query.annotations = {}
1071 self._result_cache = None
1072 return query.get_compiler().execute_sql(CURSOR)
1073
1074 _update.queryset_only = False
1075
1076 def exists(self):
1077 """
1078 Return True if the QuerySet would have any results, False otherwise.
1079 """
1080 if self._result_cache is None:
1081 return self.query.has_results()
1082 return bool(self._result_cache)
1083
1084 def contains(self, obj):
1085 """
1086 Return True if the QuerySet contains the provided obj,
1087 False otherwise.
1088 """
1089 self._not_support_combined_queries("contains")
1090 if self._fields is not None:
1091 raise TypeError(
1092 "Cannot call QuerySet.contains() after .values() or .values_list()."
1093 )
1094 try:
1095 if obj._meta.concrete_model != self.model._meta.concrete_model:
1096 return False
1097 except AttributeError:
1098 raise TypeError("'obj' must be a model instance.")
1099 if obj.pk is None:
1100 raise ValueError("QuerySet.contains() cannot be used on unsaved objects.")
1101 if self._result_cache is not None:
1102 return obj in self._result_cache
1103 return self.filter(pk=obj.pk).exists()
1104
1105 def _prefetch_related_objects(self):
1106 # This method can only be called once the result cache has been filled.
1107 prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
1108 self._prefetch_done = True
1109
1110 def explain(self, *, format=None, **options):
1111 """
1112 Runs an EXPLAIN on the SQL query this QuerySet would perform, and
1113 returns the results.
1114 """
1115 return self.query.explain(format=format, **options)
1116
1117 ##################################################
1118 # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
1119 ##################################################
1120
1121 def raw(self, raw_query, params=(), translations=None):
1122 qs = RawQuerySet(
1123 raw_query,
1124 model=self.model,
1125 params=params,
1126 translations=translations,
1127 )
1128 qs._prefetch_related_lookups = self._prefetch_related_lookups[:]
1129 return qs
1130
1131 def _values(self, *fields, **expressions):
1132 clone = self._chain()
1133 if expressions:
1134 clone = clone.annotate(**expressions)
1135 clone._fields = fields
1136 clone.query.set_values(fields)
1137 return clone
1138
1139 def values(self, *fields, **expressions):
1140 fields += tuple(expressions)
1141 clone = self._values(*fields, **expressions)
1142 clone._iterable_class = ValuesIterable
1143 return clone
1144
1145 def values_list(self, *fields, flat=False, named=False):
1146 if flat and named:
1147 raise TypeError("'flat' and 'named' can't be used together.")
1148 if flat and len(fields) > 1:
1149 raise TypeError(
1150 "'flat' is not valid when values_list is called with more than one "
1151 "field."
1152 )
1153
1154 field_names = {f for f in fields if not hasattr(f, "resolve_expression")}
1155 _fields = []
1156 expressions = {}
1157 counter = 1
1158 for field in fields:
1159 if hasattr(field, "resolve_expression"):
1160 field_id_prefix = getattr(
1161 field, "default_alias", field.__class__.__name__.lower()
1162 )
1163 while True:
1164 field_id = field_id_prefix + str(counter)
1165 counter += 1
1166 if field_id not in field_names:
1167 break
1168 expressions[field_id] = field
1169 _fields.append(field_id)
1170 else:
1171 _fields.append(field)
1172
1173 clone = self._values(*_fields, **expressions)
1174 clone._iterable_class = (
1175 NamedValuesListIterable
1176 if named
1177 else FlatValuesListIterable
1178 if flat
1179 else ValuesListIterable
1180 )
1181 return clone
1182
1183 def dates(self, field_name, kind, order="ASC"):
1184 """
1185 Return a list of date objects representing all available dates for
1186 the given field_name, scoped to 'kind'.
1187 """
1188 if kind not in ("year", "month", "week", "day"):
1189 raise ValueError("'kind' must be one of 'year', 'month', 'week', or 'day'.")
1190 if order not in ("ASC", "DESC"):
1191 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1192 return (
1193 self.annotate(
1194 datefield=Trunc(field_name, kind, output_field=DateField()),
1195 plain_field=F(field_name),
1196 )
1197 .values_list("datefield", flat=True)
1198 .distinct()
1199 .filter(plain_field__isnull=False)
1200 .order_by(("-" if order == "DESC" else "") + "datefield")
1201 )
1202
1203 def datetimes(self, field_name, kind, order="ASC", tzinfo=None):
1204 """
1205 Return a list of datetime objects representing all available
1206 datetimes for the given field_name, scoped to 'kind'.
1207 """
1208 if kind not in ("year", "month", "week", "day", "hour", "minute", "second"):
1209 raise ValueError(
1210 "'kind' must be one of 'year', 'month', 'week', 'day', "
1211 "'hour', 'minute', or 'second'."
1212 )
1213 if order not in ("ASC", "DESC"):
1214 raise ValueError("'order' must be either 'ASC' or 'DESC'.")
1215
1216 if tzinfo is None:
1217 tzinfo = timezone.get_current_timezone()
1218
1219 return (
1220 self.annotate(
1221 datetimefield=Trunc(
1222 field_name,
1223 kind,
1224 output_field=DateTimeField(),
1225 tzinfo=tzinfo,
1226 ),
1227 plain_field=F(field_name),
1228 )
1229 .values_list("datetimefield", flat=True)
1230 .distinct()
1231 .filter(plain_field__isnull=False)
1232 .order_by(("-" if order == "DESC" else "") + "datetimefield")
1233 )
1234
1235 def none(self):
1236 """Return an empty QuerySet."""
1237 clone = self._chain()
1238 clone.query.set_empty()
1239 return clone
1240
1241 ##################################################################
1242 # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
1243 ##################################################################
1244
1245 def all(self):
1246 """
1247 Return a new QuerySet that is a copy of the current one. This allows a
1248 QuerySet to proxy for a model manager in some cases.
1249 """
1250 return self._chain()
1251
1252 def filter(self, *args, **kwargs):
1253 """
1254 Return a new QuerySet instance with the args ANDed to the existing
1255 set.
1256 """
1257 self._not_support_combined_queries("filter")
1258 return self._filter_or_exclude(False, args, kwargs)
1259
1260 def exclude(self, *args, **kwargs):
1261 """
1262 Return a new QuerySet instance with NOT (args) ANDed to the existing
1263 set.
1264 """
1265 self._not_support_combined_queries("exclude")
1266 return self._filter_or_exclude(True, args, kwargs)
1267
1268 def _filter_or_exclude(self, negate, args, kwargs):
1269 if (args or kwargs) and self.query.is_sliced:
1270 raise TypeError("Cannot filter a query once a slice has been taken.")
1271 clone = self._chain()
1272 if self._defer_next_filter:
1273 self._defer_next_filter = False
1274 clone._deferred_filter = negate, args, kwargs
1275 else:
1276 clone._filter_or_exclude_inplace(negate, args, kwargs)
1277 return clone
1278
1279 def _filter_or_exclude_inplace(self, negate, args, kwargs):
1280 if negate:
1281 self._query.add_q(~Q(*args, **kwargs))
1282 else:
1283 self._query.add_q(Q(*args, **kwargs))
1284
1285 def complex_filter(self, filter_obj):
1286 """
1287 Return a new QuerySet instance with filter_obj added to the filters.
1288
1289 filter_obj can be a Q object or a dictionary of keyword lookup
1290 arguments.
1291
1292 This exists to support framework features such as 'limit_choices_to',
1293 and usually it will be more natural to use other methods.
1294 """
1295 if isinstance(filter_obj, Q):
1296 clone = self._chain()
1297 clone.query.add_q(filter_obj)
1298 return clone
1299 else:
1300 return self._filter_or_exclude(False, args=(), kwargs=filter_obj)
1301
1302 def _combinator_query(self, combinator, *other_qs, all=False):
1303 # Clone the query to inherit the select list and everything
1304 clone = self._chain()
1305 # Clear limits and ordering so they can be reapplied
1306 clone.query.clear_ordering(force=True)
1307 clone.query.clear_limits()
1308 clone.query.combined_queries = (self.query,) + tuple(
1309 qs.query for qs in other_qs
1310 )
1311 clone.query.combinator = combinator
1312 clone.query.combinator_all = all
1313 return clone
1314
1315 def union(self, *other_qs, all=False):
1316 # If the query is an EmptyQuerySet, combine all nonempty querysets.
1317 if isinstance(self, EmptyQuerySet):
1318 qs = [q for q in other_qs if not isinstance(q, EmptyQuerySet)]
1319 if not qs:
1320 return self
1321 if len(qs) == 1:
1322 return qs[0]
1323 return qs[0]._combinator_query("union", *qs[1:], all=all)
1324 return self._combinator_query("union", *other_qs, all=all)
1325
1326 def intersection(self, *other_qs):
1327 # If any query is an EmptyQuerySet, return it.
1328 if isinstance(self, EmptyQuerySet):
1329 return self
1330 for other in other_qs:
1331 if isinstance(other, EmptyQuerySet):
1332 return other
1333 return self._combinator_query("intersection", *other_qs)
1334
1335 def difference(self, *other_qs):
1336 # If the query is an EmptyQuerySet, return it.
1337 if isinstance(self, EmptyQuerySet):
1338 return self
1339 return self._combinator_query("difference", *other_qs)
1340
1341 def select_for_update(self, nowait=False, skip_locked=False, of=(), no_key=False):
1342 """
1343 Return a new QuerySet instance that will select objects with a
1344 FOR UPDATE lock.
1345 """
1346 if nowait and skip_locked:
1347 raise ValueError("The nowait option cannot be used with skip_locked.")
1348 obj = self._chain()
1349 obj._for_write = True
1350 obj.query.select_for_update = True
1351 obj.query.select_for_update_nowait = nowait
1352 obj.query.select_for_update_skip_locked = skip_locked
1353 obj.query.select_for_update_of = of
1354 obj.query.select_for_no_key_update = no_key
1355 return obj
1356
1357 def select_related(self, *fields):
1358 """
1359 Return a new QuerySet instance that will select related objects.
1360
1361 If fields are specified, they must be ForeignKey fields and only those
1362 related objects are included in the selection.
1363
1364 If select_related(None) is called, clear the list.
1365 """
1366 self._not_support_combined_queries("select_related")
1367 if self._fields is not None:
1368 raise TypeError(
1369 "Cannot call select_related() after .values() or .values_list()"
1370 )
1371
1372 obj = self._chain()
1373 if fields == (None,):
1374 obj.query.select_related = False
1375 elif fields:
1376 obj.query.add_select_related(fields)
1377 else:
1378 obj.query.select_related = True
1379 return obj
1380
1381 def prefetch_related(self, *lookups):
1382 """
1383 Return a new QuerySet instance that will prefetch the specified
1384 Many-To-One and Many-To-Many related objects when the QuerySet is
1385 evaluated.
1386
1387 When prefetch_related() is called more than once, append to the list of
1388 prefetch lookups. If prefetch_related(None) is called, clear the list.
1389 """
1390 self._not_support_combined_queries("prefetch_related")
1391 clone = self._chain()
1392 if lookups == (None,):
1393 clone._prefetch_related_lookups = ()
1394 else:
1395 for lookup in lookups:
1396 if isinstance(lookup, Prefetch):
1397 lookup = lookup.prefetch_to
1398 lookup = lookup.split(LOOKUP_SEP, 1)[0]
1399 if lookup in self.query._filtered_relations:
1400 raise ValueError(
1401 "prefetch_related() is not supported with FilteredRelation."
1402 )
1403 clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
1404 return clone
1405
1406 def annotate(self, *args, **kwargs):
1407 """
1408 Return a query set in which the returned objects have been annotated
1409 with extra data or aggregations.
1410 """
1411 self._not_support_combined_queries("annotate")
1412 return self._annotate(args, kwargs, select=True)
1413
1414 def alias(self, *args, **kwargs):
1415 """
1416 Return a query set with added aliases for extra data or aggregations.
1417 """
1418 self._not_support_combined_queries("alias")
1419 return self._annotate(args, kwargs, select=False)
1420
1421 def _annotate(self, args, kwargs, select=True):
1422 self._validate_values_are_expressions(
1423 args + tuple(kwargs.values()), method_name="annotate"
1424 )
1425 annotations = {}
1426 for arg in args:
1427 # The default_alias property may raise a TypeError.
1428 try:
1429 if arg.default_alias in kwargs:
1430 raise ValueError(
1431 f"The named annotation '{arg.default_alias}' conflicts with the "
1432 "default name for another annotation."
1433 )
1434 except TypeError:
1435 raise TypeError("Complex annotations require an alias")
1436 annotations[arg.default_alias] = arg
1437 annotations.update(kwargs)
1438
1439 clone = self._chain()
1440 names = self._fields
1441 if names is None:
1442 names = set(
1443 chain.from_iterable(
1444 (field.name, field.attname)
1445 if hasattr(field, "attname")
1446 else (field.name,)
1447 for field in self.model._meta.get_fields()
1448 )
1449 )
1450
1451 for alias, annotation in annotations.items():
1452 if alias in names:
1453 raise ValueError(
1454 f"The annotation '{alias}' conflicts with a field on the model."
1455 )
1456 if isinstance(annotation, FilteredRelation):
1457 clone.query.add_filtered_relation(annotation, alias)
1458 else:
1459 clone.query.add_annotation(
1460 annotation,
1461 alias,
1462 select=select,
1463 )
1464 for alias, annotation in clone.query.annotations.items():
1465 if alias in annotations and annotation.contains_aggregate:
1466 if clone._fields is None:
1467 clone.query.group_by = True
1468 else:
1469 clone.query.set_group_by()
1470 break
1471
1472 return clone
1473
1474 def order_by(self, *field_names):
1475 """Return a new QuerySet instance with the ordering changed."""
1476 if self.query.is_sliced:
1477 raise TypeError("Cannot reorder a query once a slice has been taken.")
1478 obj = self._chain()
1479 obj.query.clear_ordering(force=True, clear_default=False)
1480 obj.query.add_ordering(*field_names)
1481 return obj
1482
1483 def distinct(self, *field_names):
1484 """
1485 Return a new QuerySet instance that will select only distinct results.
1486 """
1487 self._not_support_combined_queries("distinct")
1488 if self.query.is_sliced:
1489 raise TypeError(
1490 "Cannot create distinct fields once a slice has been taken."
1491 )
1492 obj = self._chain()
1493 obj.query.add_distinct_fields(*field_names)
1494 return obj
1495
1496 def extra(
1497 self,
1498 select=None,
1499 where=None,
1500 params=None,
1501 tables=None,
1502 order_by=None,
1503 select_params=None,
1504 ):
1505 """Add extra SQL fragments to the query."""
1506 self._not_support_combined_queries("extra")
1507 if self.query.is_sliced:
1508 raise TypeError("Cannot change a query once a slice has been taken.")
1509 clone = self._chain()
1510 clone.query.add_extra(select, select_params, where, params, tables, order_by)
1511 return clone
1512
1513 def reverse(self):
1514 """Reverse the ordering of the QuerySet."""
1515 if self.query.is_sliced:
1516 raise TypeError("Cannot reverse a query once a slice has been taken.")
1517 clone = self._chain()
1518 clone.query.standard_ordering = not clone.query.standard_ordering
1519 return clone
1520
1521 def defer(self, *fields):
1522 """
1523 Defer the loading of data for certain fields until they are accessed.
1524 Add the set of deferred fields to any existing set of deferred fields.
1525 The only exception to this is if None is passed in as the only
1526 parameter, in which case removal all deferrals.
1527 """
1528 self._not_support_combined_queries("defer")
1529 if self._fields is not None:
1530 raise TypeError("Cannot call defer() after .values() or .values_list()")
1531 clone = self._chain()
1532 if fields == (None,):
1533 clone.query.clear_deferred_loading()
1534 else:
1535 clone.query.add_deferred_loading(fields)
1536 return clone
1537
1538 def only(self, *fields):
1539 """
1540 Essentially, the opposite of defer(). Only the fields passed into this
1541 method and that are not already specified as deferred are loaded
1542 immediately when the queryset is evaluated.
1543 """
1544 self._not_support_combined_queries("only")
1545 if self._fields is not None:
1546 raise TypeError("Cannot call only() after .values() or .values_list()")
1547 if fields == (None,):
1548 # Can only pass None to defer(), not only(), as the rest option.
1549 # That won't stop people trying to do this, so let's be explicit.
1550 raise TypeError("Cannot pass None as an argument to only().")
1551 for field in fields:
1552 field = field.split(LOOKUP_SEP, 1)[0]
1553 if field in self.query._filtered_relations:
1554 raise ValueError("only() is not supported with FilteredRelation.")
1555 clone = self._chain()
1556 clone.query.add_immediate_loading(fields)
1557 return clone
1558
1559 ###################################
1560 # PUBLIC INTROSPECTION ATTRIBUTES #
1561 ###################################
1562
1563 @property
1564 def ordered(self):
1565 """
1566 Return True if the QuerySet is ordered -- i.e. has an order_by()
1567 clause or a default ordering on the model (or is empty).
1568 """
1569 if isinstance(self, EmptyQuerySet):
1570 return True
1571 if self.query.extra_order_by or self.query.order_by:
1572 return True
1573 elif (
1574 self.query.default_ordering
1575 and self.query.get_meta().ordering
1576 and
1577 # A default ordering doesn't affect GROUP BY queries.
1578 not self.query.group_by
1579 ):
1580 return True
1581 else:
1582 return False
1583
1584 ###################
1585 # PRIVATE METHODS #
1586 ###################
1587
1588 def _insert(
1589 self,
1590 objs,
1591 fields,
1592 returning_fields=None,
1593 raw=False,
1594 on_conflict=None,
1595 update_fields=None,
1596 unique_fields=None,
1597 ):
1598 """
1599 Insert a new record for the given model. This provides an interface to
1600 the InsertQuery class and is how Model.save() is implemented.
1601 """
1602 self._for_write = True
1603 query = sql.InsertQuery(
1604 self.model,
1605 on_conflict=on_conflict,
1606 update_fields=update_fields,
1607 unique_fields=unique_fields,
1608 )
1609 query.insert_values(fields, objs, raw=raw)
1610 return query.get_compiler().execute_sql(returning_fields)
1611
1612 _insert.queryset_only = False
1613
1614 def _batched_insert(
1615 self,
1616 objs,
1617 fields,
1618 batch_size,
1619 on_conflict=None,
1620 update_fields=None,
1621 unique_fields=None,
1622 ):
1623 """
1624 Helper method for bulk_create() to insert objs one batch at a time.
1625 """
1626 ops = db_connection.ops
1627 max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
1628 batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
1629 inserted_rows = []
1630 bulk_return = db_connection.features.can_return_rows_from_bulk_insert
1631 for item in [objs[i : i + batch_size] for i in range(0, len(objs), batch_size)]:
1632 if bulk_return and on_conflict is None:
1633 inserted_rows.extend(
1634 self._insert(
1635 item,
1636 fields=fields,
1637 returning_fields=self.model._meta.db_returning_fields,
1638 )
1639 )
1640 else:
1641 self._insert(
1642 item,
1643 fields=fields,
1644 on_conflict=on_conflict,
1645 update_fields=update_fields,
1646 unique_fields=unique_fields,
1647 )
1648 return inserted_rows
1649
1650 def _chain(self):
1651 """
1652 Return a copy of the current QuerySet that's ready for another
1653 operation.
1654 """
1655 obj = self._clone()
1656 if obj._sticky_filter:
1657 obj.query.filter_is_sticky = True
1658 obj._sticky_filter = False
1659 return obj
1660
1661 def _clone(self):
1662 """
1663 Return a copy of the current QuerySet. A lightweight alternative
1664 to deepcopy().
1665 """
1666 c = self.__class__(
1667 model=self.model,
1668 query=self.query.chain(),
1669 hints=self._hints,
1670 )
1671 c._sticky_filter = self._sticky_filter
1672 c._for_write = self._for_write
1673 c._prefetch_related_lookups = self._prefetch_related_lookups[:]
1674 c._known_related_objects = self._known_related_objects
1675 c._iterable_class = self._iterable_class
1676 c._fields = self._fields
1677 return c
1678
1679 def _fetch_all(self):
1680 if self._result_cache is None:
1681 self._result_cache = list(self._iterable_class(self))
1682 if self._prefetch_related_lookups and not self._prefetch_done:
1683 self._prefetch_related_objects()
1684
1685 def _next_is_sticky(self):
1686 """
1687 Indicate that the next filter call and the one following that should
1688 be treated as a single filter. This is only important when it comes to
1689 determining when to reuse tables for many-to-many filters. Required so
1690 that we can filter naturally on the results of related managers.
1691
1692 This doesn't return a clone of the current QuerySet (it returns
1693 "self"). The method is only used internally and should be immediately
1694 followed by a filter() that does create a clone.
1695 """
1696 self._sticky_filter = True
1697 return self
1698
1699 def _merge_sanity_check(self, other):
1700 """Check that two QuerySet classes may be merged."""
1701 if self._fields is not None and (
1702 set(self.query.values_select) != set(other.query.values_select)
1703 or set(self.query.extra_select) != set(other.query.extra_select)
1704 or set(self.query.annotation_select) != set(other.query.annotation_select)
1705 ):
1706 raise TypeError(
1707 f"Merging '{self.__class__.__name__}' classes must involve the same values in each case."
1708 )
1709
1710 def _merge_known_related_objects(self, other):
1711 """
1712 Keep track of all known related objects from either QuerySet instance.
1713 """
1714 for field, objects in other._known_related_objects.items():
1715 self._known_related_objects.setdefault(field, {}).update(objects)
1716
1717 def resolve_expression(self, *args, **kwargs):
1718 if self._fields and len(self._fields) > 1:
1719 # values() queryset can only be used as nested queries
1720 # if they are set up to select only a single field.
1721 raise TypeError("Cannot use multi-field values as a filter value.")
1722 query = self.query.resolve_expression(*args, **kwargs)
1723 return query
1724
1725 resolve_expression.queryset_only = True
1726
1727 def _add_hints(self, **hints):
1728 """
1729 Update hinting information for use by routers. Add new key/values or
1730 overwrite existing key/values.
1731 """
1732 self._hints.update(hints)
1733
1734 def _has_filters(self):
1735 """
1736 Check if this QuerySet has any filtering going on. This isn't
1737 equivalent with checking if all objects are present in results, for
1738 example, qs[1:]._has_filters() -> False.
1739 """
1740 return self.query.has_filters()
1741
1742 @staticmethod
1743 def _validate_values_are_expressions(values, method_name):
1744 invalid_args = sorted(
1745 str(arg) for arg in values if not hasattr(arg, "resolve_expression")
1746 )
1747 if invalid_args:
1748 raise TypeError(
1749 "QuerySet.{}() received non-expression(s): {}.".format(
1750 method_name,
1751 ", ".join(invalid_args),
1752 )
1753 )
1754
1755 def _not_support_combined_queries(self, operation_name):
1756 if self.query.combinator:
1757 raise NotSupportedError(
1758 f"Calling QuerySet.{operation_name}() after {self.query.combinator}() is not supported."
1759 )
1760
1761 def _check_operator_queryset(self, other, operator_):
1762 if self.query.combinator or other.query.combinator:
1763 raise TypeError(f"Cannot use {operator_} operator with combined queryset.")
1764
1765 def _check_ordering_first_last_queryset_aggregation(self, method):
1766 if isinstance(self.query.group_by, tuple) and not any(
1767 col.output_field is self.model._meta.pk for col in self.query.group_by
1768 ):
1769 raise TypeError(
1770 f"Cannot use QuerySet.{method}() on an unordered queryset performing "
1771 f"aggregation. Add an ordering with order_by()."
1772 )
1773
1774
class InstanceCheckMeta(type):
    """
    Metaclass whose isinstance() check matches QuerySets with empty queries.
    """

    def __instancecheck__(cls, instance):
        if not isinstance(instance, QuerySet):
            return False
        return instance.query.is_empty()
1778
1779
class EmptyQuerySet(metaclass=InstanceCheckMeta):
    """
    Marker class for detecting querysets emptied with .none():

        isinstance(qs.none(), EmptyQuerySet) -> True

    It exists only for isinstance() checks and cannot be instantiated.
    """

    def __init__(self, *args, **kwargs):
        raise TypeError("EmptyQuerySet can't be instantiated")
1788
1789
class RawQuerySet:
    """
    Provide an iterator which converts the results of raw SQL queries into
    annotated model instances.
    """

    def __init__(
        self,
        raw_query,
        model=None,
        query=None,
        params=(),
        translations=None,
        hints=None,
    ):
        self.raw_query = raw_query
        self.model = model
        self._hints = hints or {}
        self.query = query or sql.RawQuery(sql=raw_query, params=params)
        self.params = params
        self.translations = translations or {}
        self._result_cache = None
        self._prefetch_related_lookups = ()
        self._prefetch_done = False

    def resolve_model_init_order(self):
        """
        Resolve the init field names and value positions.

        Returns ``(model_init_names, model_init_order, annotation_fields)``:
        the attnames of model fields whose converted column appears in the
        results, the result-row position of each of those columns, and
        ``(column, position)`` pairs for result columns that don't map to a
        model field.
        """
        converter = db_connection.introspection.identifier_converter
        model_init_fields = [
            f for f in self.model._meta.fields if converter(f.column) in self.columns
        ]
        annotation_fields = [
            (column, pos)
            for pos, column in enumerate(self.columns)
            if column not in self.model_fields
        ]
        model_init_order = [
            self.columns.index(converter(f.column)) for f in model_init_fields
        ]
        model_init_names = [f.attname for f in model_init_fields]
        return model_init_names, model_init_order, annotation_fields

    def prefetch_related(self, *lookups):
        """Same as QuerySet.prefetch_related()"""
        clone = self._clone()
        if lookups == (None,):
            clone._prefetch_related_lookups = ()
        else:
            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups
        return clone

    def _prefetch_related_objects(self):
        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups)
        self._prefetch_done = True

    def _clone(self):
        """Same as QuerySet._clone()"""
        c = self.__class__(
            self.raw_query,
            model=self.model,
            query=self.query,
            params=self.params,
            translations=self.translations,
            hints=self._hints,
        )
        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
        return c

    def _fetch_all(self):
        if self._result_cache is None:
            self._result_cache = list(self.iterator())
        if self._prefetch_related_lookups and not self._prefetch_done:
            self._prefetch_related_objects()

    def __len__(self):
        self._fetch_all()
        return len(self._result_cache)

    def __bool__(self):
        self._fetch_all()
        return bool(self._result_cache)

    def __iter__(self):
        self._fetch_all()
        return iter(self._result_cache)

    def iterator(self):
        yield from RawModelIterable(self)

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.query}>"

    def __getitem__(self, k):
        # Index the result cache directly instead of copying the whole list
        # on every subscript; slicing still returns a new list.
        self._fetch_all()
        return self._result_cache[k]

    @cached_property
    def columns(self):
        """
        The result column names in query order, with ``translations``
        applied (query column name -> model field name). Translations for
        columns absent from the results are ignored.
        """
        columns = self.query.get_columns()
        # Adjust any column names which don't match field names
        for query_name, model_name in self.translations.items():
            # Ignore translations for nonexistent column names
            try:
                index = columns.index(query_name)
            except ValueError:
                pass
            else:
                columns[index] = model_name
        return columns

    @cached_property
    def model_fields(self):
        """
        A dict mapping identifier-converted column names to the model's
        Field instances.
        """
        converter = db_connection.introspection.identifier_converter
        model_fields = {}
        for field in self.model._meta.fields:
            _, column = field.get_attname_column()
            model_fields[converter(column)] = field
        return model_fields
1912
1913
class Prefetch:
    """
    Describe one prefetch_related() lookup.

    ``lookup`` is the relation path to traverse. ``queryset`` optionally
    customizes how the related objects are fetched (raw, values() and
    values_list() querysets are rejected with ValueError). ``to_attr``
    stores the result under a different attribute name by replacing the
    last component of ``prefetch_to``.
    """

    def __init__(self, lookup, queryset=None, to_attr=None):
        # Path traversed to perform the prefetch.
        self.prefetch_through = lookup
        # Path to the attribute that will hold the result.
        self.prefetch_to = lookup
        if queryset is not None:
            if isinstance(queryset, RawQuerySet) or (
                hasattr(queryset, "_iterable_class")
                and not issubclass(queryset._iterable_class, ModelIterable)
            ):
                raise ValueError(
                    "Prefetch querysets cannot use raw(), values(), and values_list()."
                )
        if to_attr:
            *parents, _ = lookup.split(LOOKUP_SEP)
            self.prefetch_to = LOOKUP_SEP.join([*parents, to_attr])

        self.queryset = queryset
        self.to_attr = to_attr

    def __getstate__(self):
        state = dict(self.__dict__)
        if self.queryset is None:
            return state
        frozen = self.queryset._chain()
        # Mark the copy as already evaluated so pickling doesn't run a query.
        frozen._result_cache = []
        frozen._prefetch_done = True
        state["queryset"] = frozen
        return state

    def add_prefix(self, prefix):
        head = prefix + LOOKUP_SEP
        self.prefetch_through = head + self.prefetch_through
        self.prefetch_to = head + self.prefetch_to

    def get_current_prefetch_to(self, level):
        parts = self.prefetch_to.split(LOOKUP_SEP)
        return LOOKUP_SEP.join(parts[: level + 1])

    def get_current_to_attr(self, level):
        parts = self.prefetch_to.split(LOOKUP_SEP)
        is_last_level = level == len(parts) - 1
        return parts[level], self.to_attr and is_last_level

    def get_current_queryset(self, level):
        at_target = self.get_current_prefetch_to(level) == self.prefetch_to
        return self.queryset if at_target else None

    def __eq__(self, other):
        if isinstance(other, Prefetch):
            return self.prefetch_to == other.prefetch_to
        return NotImplemented

    def __hash__(self):
        return hash((self.__class__, self.prefetch_to))
1973
1974
def normalize_prefetch_lookups(lookups, prefix=None):
    """
    Return ``lookups`` as a list of Prefetch objects, preserving order.

    Plain lookup strings are wrapped in a new Prefetch. When ``prefix`` is
    truthy it is prepended to every lookup via ``add_prefix()``. Note that
    Prefetch objects passed in are mutated in place, not copied.
    """
    normalized = []
    for item in lookups:
        prefetch = item if isinstance(item, Prefetch) else Prefetch(item)
        if prefix:
            prefetch.add_prefix(prefix)
        normalized.append(prefetch)
    return normalized
1985
1986
def prefetch_related_objects(model_instances, *related_lookups):
    """
    Populate prefetched object caches for a list of model instances based on
    the lookups/Prefetch instances given.

    ``related_lookups`` may be lookup strings or Prefetch objects. Each lookup
    is walked one LOOKUP_SEP-separated level at a time. At each level the
    current object list is either prefetched in bulk via prefetch_one_level()
    (when a prefetcher exists and some objects still need fetching), or
    traversed attribute-by-attribute. The resulting objects become the object
    list for the next level. Querysets returned by a prefetcher may contribute
    their own prefetch_related lookups, which are queued and processed too.

    Raises:
        ValueError: a lookup's ``prefetch_to`` path was already processed and
            this lookup also supplies a queryset; or the final level of a
            lookup does not resolve to something that supports prefetching.
        AttributeError: a level of a lookup cannot be found on the objects.
    """
    if not model_instances:
        return  # nothing to do

    # We need to be able to dynamically add to the list of prefetch_related
    # lookups that we look up (see below). So we need some book keeping to
    # ensure we don't do duplicate work.
    done_queries = {}  # prefetch_to path (e.g. 'foo__bar') -> objects fetched there

    auto_lookups = set()  # lookups queued from prefetchers' own querysets
    followed_descriptors = set()  # recursion protection

    # all_lookups is used as a stack (pop() from the end). Building it from
    # the reversed input means lookups are processed in the order given.
    all_lookups = normalize_prefetch_lookups(reversed(related_lookups))
    while all_lookups:
        lookup = all_lookups.pop()
        if lookup.prefetch_to in done_queries:
            # A second lookup for an already-filled path may not bring its own
            # queryset: it would silently be ignored.
            if lookup.queryset is not None:
                raise ValueError(
                    f"'{lookup.prefetch_to}' lookup was already seen with a different queryset. "
                    "You may need to adjust the ordering of your lookups."
                )

            continue

        # Top level, the list of objects to decorate is the result cache
        # from the primary QuerySet. It won't be for deeper levels.
        obj_list = model_instances

        through_attrs = lookup.prefetch_through.split(LOOKUP_SEP)
        for level, through_attr in enumerate(through_attrs):
            # Prepare main instances
            if not obj_list:
                break

            prefetch_to = lookup.get_current_prefetch_to(level)
            if prefetch_to in done_queries:
                # Skip any prefetching, and any object preparation
                obj_list = done_queries[prefetch_to]
                continue

            # Prepare objects:
            good_objects = True
            for obj in obj_list:
                # Since prefetching can re-use instances, it is possible to have
                # the same instance multiple times in obj_list, so obj might
                # already be prepared.
                if not hasattr(obj, "_prefetched_objects_cache"):
                    try:
                        obj._prefetched_objects_cache = {}
                    except (AttributeError, TypeError):
                        # Must be an immutable object from
                        # values_list(flat=True), for example (TypeError) or
                        # a QuerySet subclass that isn't returning Model
                        # instances (AttributeError), either in Plain or a 3rd
                        # party. prefetch_related() doesn't make sense, so quit.
                        good_objects = False
                        break
            if not good_objects:
                break

            # Descend down tree

            # We assume that objects retrieved are homogeneous (which is the premise
            # of prefetch_related), so what applies to first object applies to all.
            first_obj = obj_list[0]
            to_attr = lookup.get_current_to_attr(level)[0]
            prefetcher, descriptor, attr_found, is_fetched = get_prefetcher(
                first_obj, through_attr, to_attr
            )

            if not attr_found:
                raise AttributeError(
                    f"Cannot find '{through_attr}' on {first_obj.__class__.__name__} object, '{lookup.prefetch_through}' is an invalid "
                    "parameter to prefetch_related()"
                )

            if level == len(through_attrs) - 1 and prefetcher is None:
                # Last one, this *must* resolve to something that supports
                # prefetching, otherwise there is no point adding it and the
                # developer asking for it has made a mistake.
                raise ValueError(
                    f"'{lookup.prefetch_through}' does not resolve to an item that supports "
                    "prefetching - this is an invalid parameter to "
                    "prefetch_related()."
                )

            # Only objects whose relation isn't already loaded are sent to the
            # prefetcher. If nothing needs fetching, fall through to the
            # traversal branch below.
            obj_to_fetch = None
            if prefetcher is not None:
                obj_to_fetch = [obj for obj in obj_list if not is_fetched(obj)]

            if obj_to_fetch:
                obj_list, additional_lookups = prefetch_one_level(
                    obj_to_fetch,
                    prefetcher,
                    lookup,
                    level,
                )
                # We need to ensure we don't keep adding lookups from the
                # same relationships to stop infinite recursion. So, if we
                # are already on an automatically added lookup, don't add
                # the new lookups from relationships we've seen already.
                if not (
                    prefetch_to in done_queries
                    and lookup in auto_lookups
                    and descriptor in followed_descriptors
                ):
                    done_queries[prefetch_to] = obj_list
                    # Nested lookups are prefixed with the current path so they
                    # continue from the objects just fetched.
                    new_lookups = normalize_prefetch_lookups(
                        reversed(additional_lookups), prefetch_to
                    )
                    auto_lookups.update(new_lookups)
                    all_lookups.extend(new_lookups)
                followed_descriptors.add(descriptor)
            else:
                # Either a singly related object that has already been fetched
                # (e.g. via select_related), or hopefully some other property
                # that doesn't support prefetching but needs to be traversed.

                # We replace the current list of parent objects with the list
                # of related objects, filtering out empty or missing values so
                # that we can continue with nullable or reverse relations.
                new_obj_list = []
                for obj in obj_list:
                    if through_attr in getattr(obj, "_prefetched_objects_cache", ()):
                        # If related objects have been prefetched, use the
                        # cache rather than the object's through_attr.
                        new_obj = list(obj._prefetched_objects_cache.get(through_attr))
                    else:
                        try:
                            new_obj = getattr(obj, through_attr)
                        except exceptions.ObjectDoesNotExist:
                            continue
                    if new_obj is None:
                        continue
                    # We special-case `list` rather than something more generic
                    # like `Iterable` because we don't want to accidentally match
                    # user models that define __iter__.
                    if isinstance(new_obj, list):
                        new_obj_list.extend(new_obj)
                    else:
                        new_obj_list.append(new_obj)
                obj_list = new_obj_list
2133
2134
def get_prefetcher(instance, through_attr, to_attr):
    """
    Locate the object that can prefetch ``through_attr`` on ``instance``.

    Return a 4-tuple ``(prefetcher, descriptor, attr_found, is_fetched)``:

    - ``prefetcher``: the object exposing ``get_prefetch_queryset()``, or None.
    - ``descriptor``: the class-level attribute named ``through_attr``, or None.
    - ``attr_found``: False only when ``through_attr`` is not found at all.
    - ``is_fetched``: a callable taking an instance and returning True when the
      attribute has already been fetched for that instance.
    """

    def has_to_attr_attribute(obj):
        return hasattr(obj, to_attr)

    prefetcher = None
    is_fetched = has_to_attr_attribute

    # Look on the class first: reading a singly related attribute from the
    # instance would trigger its query, whereas the class gives the descriptor.
    descriptor = getattr(instance.__class__, through_attr, None)
    if descriptor is None:
        return prefetcher, descriptor, hasattr(instance, through_attr), is_fetched

    if not descriptor:
        # Present on the class but falsy: found, nothing to prefetch with.
        return prefetcher, descriptor, True, is_fetched

    if hasattr(descriptor, "get_prefetch_queryset"):
        # Singly related object: the descriptor itself does the prefetching.
        return descriptor, descriptor, True, descriptor.is_cached

    # The descriptor can't prefetch, so read the attribute from the instance
    # to reach objects such as many-related managers.
    related = getattr(instance, through_attr)
    if hasattr(related, "get_prefetch_queryset"):
        prefetcher = related

    if through_attr == to_attr:

        def in_prefetched_cache(obj):
            return through_attr in obj._prefetched_objects_cache

        is_fetched = in_prefetched_cache
    elif isinstance(getattr(instance.__class__, to_attr, None), cached_property):
        # hasattr() would compute and store a cached_property, so inspect the
        # instance __dict__ instead.

        def has_cached_property(obj):
            return to_attr in obj.__dict__

        is_fetched = has_cached_property

    return prefetcher, descriptor, True, is_fetched
2192
2193
def prefetch_one_level(instances, prefetcher, lookup, level):
    """
    Helper function for prefetch_related_objects().

    Run prefetches on all instances using the prefetcher object,
    assigning results to relevant caches in instance.

    Return the prefetched objects along with any additional prefetches that
    must be done due to prefetch_related lookups found from default managers.

    Raises:
        ValueError: ``to_attr`` is in effect at this level and names an
            existing field on the instances' model.
    """
    # prefetcher must have a method get_prefetch_queryset() which takes a list
    # of instances, and returns a tuple:

    # (queryset of instances of self.model that are related to passed in instances,
    # callable that gets value to be matched for returned instances,
    # callable that gets value to be matched for passed in instances,
    # boolean that is True for singly related objects,
    # cache or field name to assign to,
    # boolean that is True when the previous argument is a cache name vs a field name).

    # The 'values to be matched' must be hashable as they will be used
    # in a dictionary.

    (
        rel_qs,
        rel_obj_attr,
        instance_attr,
        single,
        cache_name,
        is_descriptor,
    ) = prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))
    # We have to handle the possibility that the QuerySet we just got back
    # contains some prefetch_related lookups. We don't want to trigger the
    # prefetch_related functionality by evaluating the query. Rather, we need
    # to merge in the prefetch_related lookups.
    # Copy the lookups in case it is a Prefetch object which could be reused
    # later (happens in nested prefetch_related).
    additional_lookups = [
        copy.copy(additional_lookup)
        for additional_lookup in getattr(rel_qs, "_prefetch_related_lookups", ())
    ]
    if additional_lookups:
        # Don't need to clone because the manager should have given us a fresh
        # instance, so we access an internal instead of using public interface
        # for performance reasons.
        rel_qs._prefetch_related_lookups = ()

    # Evaluate the related queryset once; its lookups were cleared above.
    all_related_objects = list(rel_qs)

    # Group related objects by their match value so each instance can find
    # its own related objects with a single dict lookup below.
    rel_obj_cache = {}
    for rel_obj in all_related_objects:
        rel_attr_val = rel_obj_attr(rel_obj)
        rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)

    to_attr, as_attr = lookup.get_current_to_attr(level)
    # Make sure `to_attr` does not conflict with a field.
    if as_attr and instances:
        # We assume that objects retrieved are homogeneous (which is the premise
        # of prefetch_related), so what applies to first object applies to all.
        model = instances[0].__class__
        try:
            model._meta.get_field(to_attr)
        except exceptions.FieldDoesNotExist:
            pass
        else:
            msg = "to_attr={} conflicts with a field on the {} model."
            raise ValueError(msg.format(to_attr, model.__name__))

    # Whether or not we're prefetching the last part of the lookup.
    leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level

    for obj in instances:
        instance_attr_val = instance_attr(obj)
        # Instances with no related objects get an empty list.
        vals = rel_obj_cache.get(instance_attr_val, [])

        if single:
            # Singly related: only the first match (or None) is stored.
            val = vals[0] if vals else None
            if as_attr:
                # A to_attr has been given for the prefetch.
                setattr(obj, to_attr, val)
            elif is_descriptor:
                # cache_name points to a field name in obj.
                # This field is a descriptor for a related object.
                setattr(obj, cache_name, val)
            else:
                # No to_attr has been given for this prefetch operation and the
                # cache_name does not point to a descriptor. Store the value of
                # the field in the object's field cache.
                obj._state.fields_cache[cache_name] = val
        else:
            if as_attr:
                setattr(obj, to_attr, vals)
            else:
                # Build a queryset from the instance's manager and pre-fill its
                # result cache so reading it later doesn't hit the database.
                # The custom Prefetch queryset is applied only at the leaf.
                manager = getattr(obj, to_attr)
                if leaf and lookup.queryset is not None:
                    qs = manager._apply_rel_filters(lookup.queryset)
                else:
                    qs = manager.get_queryset()
                qs._result_cache = vals
                # We don't want the individual qs doing prefetch_related now,
                # since we have merged this into the current work.
                qs._prefetch_done = True
                obj._prefetched_objects_cache[cache_name] = qs
    return all_related_objects, additional_lookups
2298
2299
class RelatedPopulator:
    """
    Build the model instance for one select_related() relation from a row.

    Each select_related() model is populated by its own RelatedPopulator.
    Instances are created from ``klass_info`` and ``select`` (both computed by
    the SQL compiler). From these they work out which slice of the row holds
    the model's columns, which model to instantiate, and how to link it to
    the object it was selected from.

    The actual object creation happens in populate(), which receives a row
    and the object the relation hangs off.
    """

    def __init__(self, klass_info, select):
        # Attributes computed here:
        # - cols_start / cols_end: the slice of the row holding this model's
        #   columns, in the order model_cls.from_db() receives them.
        # - init_list: attnames of the fetched columns, in row order.
        # - reorder_for_init: always None here. populate() calls it instead of
        #   slicing the row if it is ever set.
        # - model_cls: the model class to instantiate.
        # - pk_idx: position of the primary key within the column slice, used
        #   to tell whether a related object exists at all.
        # - related_populators: populators for relations selected beneath
        #   this model.
        # - local_setter / remote_setter: callables that link the new object
        #   and the object it was selected from, in each direction.
        select_fields = klass_info["select_fields"]
        start = select_fields[0]
        end = select_fields[-1] + 1

        self.cols_start = start
        self.cols_end = end
        self.init_list = [column[0].target.attname for column in select[start:end]]
        self.reorder_for_init = None

        self.model_cls = klass_info["model"]
        pk_attname = self.model_cls._meta.pk.attname
        self.pk_idx = self.init_list.index(pk_attname)
        self.related_populators = get_related_populators(klass_info, select)
        self.local_setter = klass_info["local_setter"]
        self.remote_setter = klass_info["remote_setter"]

    def populate(self, row, from_obj):
        """
        Create this relation's object from ``row`` and attach it to ``from_obj``.

        A None primary key in the selected columns means there is no related
        object. In that case ``local_setter`` receives None and nothing else
        runs.
        """
        reorder = self.reorder_for_init
        data = reorder(row) if reorder else row[self.cols_start : self.cols_end]

        if data[self.pk_idx] is None:
            self.local_setter(from_obj, None)
            return

        obj = self.model_cls.from_db(self.init_list, data)
        for child_populator in self.related_populators:
            child_populator.populate(row, obj)
        self.local_setter(from_obj, obj)
        if obj is not None:
            self.remote_setter(obj, from_obj)
2369
2370
def get_related_populators(klass_info, select):
    """
    Return one RelatedPopulator per entry in ``klass_info["related_klass_infos"]``.

    An empty list is returned when ``klass_info`` has no such key.
    """
    return [
        RelatedPopulator(related_info, select)
        for related_info in klass_info.get("related_klass_infos", [])
    ]