1from __future__ import annotations
2
3import copy
4import warnings
5from collections.abc import Iterable, Iterator, Sequence
6from itertools import chain
7from typing import TYPE_CHECKING, Any, cast
8
9if TYPE_CHECKING:
10 from plain.postgres.meta import Meta
11 from plain.postgres.options import Options
12
13import plain.runtime
14from plain.exceptions import NON_FIELD_ERRORS, ValidationError
15from plain.postgres import models_registry, transaction, types
16from plain.postgres.constants import LOOKUP_SEP
17from plain.postgres.constraints import CheckConstraint, UniqueConstraint
18from plain.postgres.db import (
19 PLAIN_VERSION_PICKLE_KEY,
20 DatabaseError,
21)
22from plain.postgres.deletion import Collector
23from plain.postgres.dialect import MAX_NAME_LENGTH
24from plain.postgres.exceptions import (
25 DoesNotExistDescriptor,
26 FieldDoesNotExist,
27 MultipleObjectsReturnedDescriptor,
28)
29from plain.postgres.expressions import RawSQL, Value
30from plain.postgres.fields import NOT_PROVIDED, Field
31from plain.postgres.fields.related import RelatedField
32from plain.postgres.fields.reverse_related import ForeignObjectRel
33from plain.postgres.meta import Meta
34from plain.postgres.options import Options
35from plain.postgres.query import F, Q, QuerySet
36from plain.preflight import PreflightResult
37from plain.utils.encoding import force_str
38from plain.utils.hashable import make_hashable
39
40
class Deferred:
    """Sentinel type standing in for a field value that has not been loaded."""

    def __repr__(self) -> str:
        return "<Deferred field>"

    # str() renders exactly the same text as repr().
    __str__ = __repr__


# Shared sentinel instance; code in this module compares against it by
# identity (``is`` / ``is not``).
DEFERRED = Deferred()
50
51
class ModelBase(type):
    """Metaclass for all models.

    Every model must inherit from ``Model`` directly; subclassing another
    concrete model is rejected with ``TypeError``.
    """

    def __new__(
        cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], **kwargs: Any
    ) -> type:
        # The root ``Model`` class is created with no bases: build it as-is.
        if not bases:
            return super().__new__(cls, name, bases, attrs)

        # First base that is a Model subclass other than Model itself, if any.
        offending = next(
            (base for base in bases if base is not Model and issubclass(base, Model)),
            None,
        )
        if offending is not None:
            raise TypeError(
                f"A model can't extend another model: {name} extends {offending}"
            )

        return super().__new__(cls, name, bases, attrs, **kwargs)
70
71
class ModelStateFieldsCacheDescriptor:
    """Lazily give each instance its own ``fields_cache`` dict.

    On first access through an instance, a fresh empty dict is assigned to the
    instance's ``fields_cache`` attribute. This is a non-data descriptor (no
    ``__set__``), so the instance attribute shadows it on later lookups. It is
    therefore meant to be installed under the name ``fields_cache``.
    """

    def __get__(
        self, instance: ModelState | None, cls: type | None = None
    ) -> ModelStateFieldsCacheDescriptor | dict[str, Any]:
        # Class-level access returns the descriptor itself.
        if instance is None:
            return self
        cache: dict[str, Any] = {}
        instance.fields_cache = cache
        return cache
80
81
class ModelState:
    """Store model instance state."""

    # True until the instance is saved (Model.save_base) or built from a
    # database row (Model.from_db). Uniqueness validation uses it to treat the
    # object as new, which is necessary for correct validation of instances
    # with explicit (non-auto) primary keys. It also affects saving: when the
    # primary key field has a default, an instance that is still ``adding``
    # skips the UPDATE attempt and is INSERTed directly.
    adding = True
    # Per-instance dict, created lazily on first access by
    # ModelStateFieldsCacheDescriptor.
    fields_cache = ModelStateFieldsCacheDescriptor()
91
92
93class Model(metaclass=ModelBase):
94 # Every model gets an automatic id field
95 id: int = types.PrimaryKeyField()
96
97 # Descriptors for other model behavior
98 query: QuerySet[Model] = QuerySet()
99 model_options: Options = Options()
100 _model_meta: Meta = Meta()
101 DoesNotExist = DoesNotExistDescriptor()
102 MultipleObjectsReturned = MultipleObjectsReturnedDescriptor()
103
104 def __init__(self, **kwargs: Any):
105 # Alias some things as locals to avoid repeat global lookups
106 cls = self.__class__
107 meta = cls._model_meta
108 _setattr = setattr
109 _DEFERRED = DEFERRED
110
111 # Set up the storage for instance state
112 self._state = ModelState()
113
114 # Process all fields from kwargs or use defaults
115 for field in meta.fields:
116 from plain.postgres.fields.related import RelatedField
117
118 is_related_object = False
119 # Virtual field
120 if field.attname not in kwargs and field.column is None:
121 continue
122 if isinstance(field, RelatedField) and isinstance(
123 field.remote_field, ForeignObjectRel
124 ):
125 try:
126 # Assume object instance was passed in.
127 rel_obj = kwargs.pop(field.name)
128 is_related_object = True
129 except KeyError:
130 try:
131 # Object instance wasn't passed in -- must be an ID.
132 val = kwargs.pop(field.attname)
133 except KeyError:
134 val = field.get_default()
135 else:
136 try:
137 val = kwargs.pop(field.attname)
138 except KeyError:
139 # This is done with an exception rather than the
140 # default argument on pop because we don't want
141 # get_default() to be evaluated, and then not used.
142 # Refs #12057.
143 val = field.get_default()
144
145 if is_related_object:
146 # If we are passed a related instance, set it using the
147 # field.name instead of field.attname (e.g. "user" instead of
148 # "user_id") so that the object gets properly cached (and type
149 # checked) by the RelatedObjectDescriptor.
150 if rel_obj is not _DEFERRED:
151 _setattr(self, field.name, rel_obj)
152 else:
153 if val is not _DEFERRED:
154 _setattr(self, field.attname, val)
155
156 # Handle any remaining kwargs (properties or virtual fields)
157 property_names = meta._property_names
158 unexpected = ()
159 for prop, value in kwargs.items():
160 # Any remaining kwargs must correspond to properties or virtual
161 # fields.
162 if prop in property_names:
163 if value is not _DEFERRED:
164 _setattr(self, prop, value)
165 else:
166 try:
167 meta.get_field(prop)
168 except FieldDoesNotExist:
169 unexpected += (prop,)
170 else:
171 if value is not _DEFERRED:
172 _setattr(self, prop, value)
173 if unexpected:
174 unexpected_names = ", ".join(repr(n) for n in unexpected)
175 raise TypeError(
176 f"{cls.__name__}() got unexpected keyword arguments: {unexpected_names}"
177 )
178
179 super().__init__()
180
181 @classmethod
182 def from_db(cls, field_names: Iterable[str], values: Sequence[Any]) -> Model:
183 if len(values) != len(cls._model_meta.concrete_fields):
184 values_iter = iter(values)
185 values = [
186 next(values_iter) if f.attname in field_names else DEFERRED
187 for f in cls._model_meta.concrete_fields
188 ]
189 # Build kwargs dict from field names and values
190 field_dict = dict(
191 zip((f.attname for f in cls._model_meta.concrete_fields), values)
192 )
193 new = cls(**field_dict)
194 new._state.adding = False
195 return new
196
197 def __repr__(self) -> str:
198 return f"<{self.__class__.__name__}: {self.id}>"
199
200 def __str__(self) -> str:
201 return f"{self.__class__.__name__} object ({self.id})"
202
203 def __eq__(self, other: object) -> bool:
204 if not isinstance(other, Model):
205 return NotImplemented
206 if self.__class__ != other.__class__:
207 return False
208 my_id = self.id
209 if my_id is None:
210 return self is other
211 return my_id == other.id
212
213 def __hash__(self) -> int:
214 if self.id is None:
215 raise TypeError("Model instances without primary key value are unhashable")
216 return hash(self.id)
217
218 def __reduce__(self) -> tuple[Any, tuple[Any, ...], dict[str, Any]]:
219 data = self.__getstate__()
220 data[PLAIN_VERSION_PICKLE_KEY] = plain.runtime.__version__
221 class_id = (
222 self.model_options.package_label,
223 self.model_options.object_name,
224 )
225 return model_unpickle, (class_id,), data
226
227 def __getstate__(self) -> dict[str, Any]:
228 """Hook to allow choosing the attributes to pickle."""
229 state = self.__dict__.copy()
230 state["_state"] = copy.copy(state["_state"])
231 state["_state"].fields_cache = state["_state"].fields_cache.copy()
232 # memoryview cannot be pickled, so cast it to bytes and store
233 # separately.
234 _memoryview_attrs = []
235 for attr, value in state.items():
236 if isinstance(value, memoryview):
237 _memoryview_attrs.append((attr, bytes(value)))
238 if _memoryview_attrs:
239 state["_memoryview_attrs"] = _memoryview_attrs
240 for attr, value in _memoryview_attrs:
241 state.pop(attr)
242 return state
243
244 def __setstate__(self, state: dict[str, Any]) -> None:
245 pickled_version = state.get(PLAIN_VERSION_PICKLE_KEY)
246 if pickled_version:
247 if pickled_version != plain.runtime.__version__:
248 warnings.warn(
249 f"Pickled model instance's Plain version {pickled_version} does not "
250 f"match the current version {plain.runtime.__version__}.",
251 RuntimeWarning,
252 stacklevel=2,
253 )
254 else:
255 warnings.warn(
256 "Pickled model instance's Plain version is not specified.",
257 RuntimeWarning,
258 stacklevel=2,
259 )
260 if "_memoryview_attrs" in state:
261 for attr, value in state.pop("_memoryview_attrs"):
262 state[attr] = memoryview(value)
263 self.__dict__.update(state)
264
265 def get_deferred_fields(self) -> set[str]:
266 """
267 Return a set containing names of deferred fields for this instance.
268 """
269 return {
270 f.attname
271 for f in self._model_meta.concrete_fields
272 if f.attname not in self.__dict__
273 }
274
275 def refresh_from_db(self, fields: list[str] | None = None) -> None:
276 """
277 Reload field values from the database.
278
279 Fields can be used to specify which fields to reload. If fields is
280 None, then all non-deferred fields are reloaded.
281
282 When accessing deferred fields of an instance, the deferred loading
283 of the field will call this method.
284 """
285 if fields is None:
286 self._prefetched_objects_cache = {}
287 else:
288 prefetched_objects_cache = getattr(self, "_prefetched_objects_cache", {})
289 for field in fields:
290 if field in prefetched_objects_cache:
291 del prefetched_objects_cache[field]
292 fields.remove(field)
293 if not fields:
294 return
295 if any(LOOKUP_SEP in f for f in fields):
296 raise ValueError(
297 f'Found "{LOOKUP_SEP}" in fields argument. Relations and transforms '
298 "are not allowed in fields."
299 )
300
301 db_instance_qs = self._model_meta.base_queryset.filter(id=self.id)
302
303 # Use provided fields, if not set then reload all non-deferred fields.
304 deferred_fields = self.get_deferred_fields()
305 if fields is not None:
306 fields = list(fields)
307 db_instance_qs = db_instance_qs.only(*fields)
308 elif deferred_fields:
309 fields = [
310 f.attname
311 for f in self._model_meta.concrete_fields
312 if f.attname not in deferred_fields
313 ]
314 db_instance_qs = db_instance_qs.only(*fields)
315
316 db_instance = db_instance_qs.get()
317 non_loaded_fields = db_instance.get_deferred_fields()
318 for field in self._model_meta.concrete_fields:
319 if field.attname in non_loaded_fields:
320 # This field wasn't refreshed - skip ahead.
321 continue
322 setattr(self, field.attname, getattr(db_instance, field.attname))
323 # Clear cached foreign keys.
324 if isinstance(field, RelatedField) and field.is_cached(self):
325 field.delete_cached_value(self)
326
327 # Clear cached relations.
328 for field in self._model_meta.related_objects:
329 if field.is_cached(self):
330 field.delete_cached_value(self)
331
332 def serializable_value(self, field_name: str) -> Any:
333 """
334 Return the value of the field name for this instance. If the field is
335 a foreign key, return the id value instead of the object. If there's
336 no Field object with this name on the model, return the model
337 attribute's value.
338
339 Used to serialize a field's value (in the serializer, or form output,
340 for example). Normally, you would just access the attribute directly
341 and not use this method.
342 """
343 try:
344 field = self._model_meta.get_forward_field(field_name)
345 except FieldDoesNotExist:
346 return getattr(self, field_name)
347 return getattr(self, field.attname)
348
349 def save(
350 self,
351 *,
352 clean_and_validate: bool = True,
353 force_insert: bool = False,
354 force_update: bool = False,
355 update_fields: Iterable[str] | None = None,
356 ) -> None:
357 """
358 Save the current instance. Override this in a subclass if you want to
359 control the saving process.
360
361 The 'force_insert' and 'force_update' parameters can be used to insist
362 that the "save" must be an SQL INSERT or UPDATE, respectively.
363 Normally, they should not be set.
364 """
365 self._prepare_related_fields_for_save(operation_name="save")
366
367 if force_insert and (force_update or update_fields):
368 raise ValueError("Cannot force both insert and updating in model saving.")
369
370 deferred_fields = self.get_deferred_fields()
371 if update_fields is not None:
372 # If update_fields is empty, skip the save. We do also check for
373 # no-op saves later on for inheritance cases. This bailout is
374 # still needed for skipping signal sending.
375 if not update_fields:
376 return
377
378 update_fields = frozenset(update_fields)
379 field_names = self._model_meta._non_pk_concrete_field_names
380 non_model_fields = update_fields.difference(field_names)
381
382 if non_model_fields:
383 raise ValueError(
384 "The following fields do not exist in this model, are m2m "
385 "fields, or are non-concrete fields: {}".format(
386 ", ".join(non_model_fields)
387 )
388 )
389
390 # If this model is deferred, automatically do an "update_fields" save
391 # on the loaded fields.
392 elif not force_insert and deferred_fields:
393 field_names = set()
394 for field in self._model_meta.concrete_fields:
395 if not field.primary_key and not hasattr(field, "through"):
396 field_names.add(field.attname)
397 loaded_fields = field_names.difference(deferred_fields)
398 if loaded_fields:
399 update_fields = frozenset(loaded_fields)
400
401 if clean_and_validate:
402 self.full_clean(exclude=deferred_fields)
403
404 self.save_base(
405 force_insert=force_insert,
406 force_update=force_update,
407 update_fields=update_fields,
408 )
409
410 def save_base(
411 self,
412 *,
413 raw: bool = False,
414 force_insert: bool = False,
415 force_update: bool = False,
416 update_fields: Iterable[str] | None = None,
417 ) -> None:
418 """
419 Handle the parts of saving which should be done only once per save,
420 yet need to be done in raw saves, too. This includes some sanity
421 checks and signal sending.
422
423 The 'raw' argument is telling save_base not to save any parent
424 models and not to do any changes to the values before save. This
425 is used by fixture loading.
426 """
427 assert not (force_insert and (force_update or update_fields))
428 assert update_fields is None or update_fields
429 cls = self.__class__
430
431 with transaction.mark_for_rollback_on_error():
432 self._save_table(
433 raw=raw,
434 cls=cls,
435 force_insert=force_insert,
436 force_update=force_update,
437 update_fields=update_fields,
438 )
439 # Once saved, this is no longer a to-be-added instance.
440 self._state.adding = False
441
    def _save_table(
        self,
        *,
        raw: bool,
        cls: type[Model],
        force_insert: bool = False,
        force_update: bool = False,
        update_fields: Iterable[str] | None = None,
    ) -> bool:
        """
        Do the heavy-lifting involved in saving. Update or insert the data
        for a single table.

        If the instance has a primary key value (possibly obtained here from
        the id field's ``get_id_value_on_save``) and no INSERT is forced, an
        UPDATE is tried first. If that UPDATE matches nothing, an INSERT is
        done. After an INSERT, values for ``db_returning_fields`` from the
        first returned row are copied back onto the instance.

        Args:
            raw: when True, UPDATE values are the raw instance attributes
                rather than ``Field.pre_save`` results; also forwarded to the
                INSERT.
            cls: model class whose ``_model_meta`` describes the table.
            force_insert: skip the UPDATE attempt.
            force_update: require the UPDATE to match a row.
            update_fields: if given, limit the UPDATE to fields whose name or
                attname appears here.

        Returns:
            True if an UPDATE matched a row, False if an INSERT was done.

        Raises:
            ValueError: if an update is forced (``force_update`` or
                ``update_fields``) but no primary key value is available.
            DatabaseError: if a forced update or an ``update_fields`` save
                did not affect any rows.
        """
        meta = cls._model_meta
        non_pks = [f for f in meta.local_concrete_fields if not f.primary_key]

        if update_fields:
            # Either a field's name or its attname selects it.
            non_pks = [
                f
                for f in non_pks
                if f.name in update_fields or f.attname in update_fields
            ]

        id_val = self.id
        if id_val is None:
            # Let the id field supply a value before saving and store it on
            # the instance. NOTE(review): presumably for client-side generated
            # keys; confirm against get_id_value_on_save.
            id_field = meta.get_forward_field("id")
            id_val = id_field.get_id_value_on_save(self)
            setattr(self, id_field.attname, id_val)
        id_set = id_val is not None
        if not id_set and (force_update or update_fields):
            raise ValueError("Cannot force an update in save() with no primary key.")
        updated = False
        # Skip an UPDATE when adding an instance and primary key has a default.
        # (The truthiness test also treats falsy defaults as "no default".)
        if (
            not raw
            and not force_insert
            and self._state.adding
            and meta.get_forward_field("id").default
            and meta.get_forward_field("id").default is not NOT_PROVIDED
        ):
            force_insert = True
        # If possible, try an UPDATE. If that doesn't update anything, do an INSERT.
        if id_set and not force_insert:
            base_qs = meta.base_queryset
            # (field, None, value) triples for the UPDATE; pre_save() runs
            # only for non-raw saves.
            values = [
                (
                    f,
                    None,
                    (getattr(self, f.attname) if raw else f.pre_save(self, False)),
                )
                for f in non_pks
            ]
            forced_update = bool(update_fields or force_update)
            updated = self._do_update(
                base_qs, id_val, values, update_fields, forced_update
            )
            if force_update and not updated:
                raise DatabaseError("Forced update did not affect any rows.")
            if update_fields and not updated:
                raise DatabaseError("Save with update_fields did not affect any rows.")
        if not updated:
            fields = meta.local_concrete_fields
            if not id_set:
                # No primary key value: leave the id column out of the INSERT
                # (presumably so the database assigns it).
                id_field = meta.get_forward_field("id")
                fields = [f for f in fields if f is not id_field]

            returning_fields = meta.db_returning_fields
            results = self._do_insert(meta.base_queryset, fields, returning_fields, raw)
            if results:
                # Copy the returned values from the first row back onto the
                # instance.
                for value, field in zip(results[0], returning_fields):
                    setattr(self, field.attname, value)
        return updated
514
515 def _do_update(
516 self,
517 base_qs: QuerySet,
518 id_val: Any,
519 values: list[tuple[Any, Any, Any]],
520 update_fields: Iterable[str] | None,
521 forced_update: bool,
522 ) -> bool:
523 """
524 Try to update the model. Return True if the model was updated (if an
525 update query was done and a matching row was found in the DB).
526 """
527 filtered = base_qs.filter(id=id_val)
528 if not values:
529 # We can end up here when saving a model in inheritance chain where
530 # update_fields doesn't target any field in current model. In that
531 # case we just say the update succeeded. Another case ending up here
532 # is a model with just PK - in that case check that the PK still
533 # exists.
534 return update_fields is not None or filtered.exists()
535 return filtered._update(values) > 0
536
537 def _do_insert(
538 self,
539 manager: QuerySet,
540 fields: Sequence[Any],
541 returning_fields: Sequence[Any],
542 raw: bool,
543 ) -> list[tuple[Any, ...]] | None:
544 """
545 Do an INSERT. If returning_fields is defined then this method should
546 return the newly created data for the model.
547 """
548 return manager._insert(
549 [self],
550 fields=list(fields),
551 returning_fields=list(returning_fields) if returning_fields else None,
552 raw=raw,
553 )
554
555 def _prepare_related_fields_for_save(
556 self, operation_name: str, fields: Sequence[Any] | None = None
557 ) -> None:
558 # Ensure that a model instance without a PK hasn't been assigned to
559 # a ForeignKeyField on this model. If the field is nullable, allowing the save would result in silent data loss.
560 for field in self._model_meta.concrete_fields:
561 if fields and field not in fields:
562 continue
563 # If the related field isn't cached, then an instance hasn't been
564 # assigned and there's no need to worry about this check.
565 if isinstance(field, RelatedField) and field.is_cached(self):
566 obj = getattr(self, field.name, None)
567 if not obj:
568 continue
569 # A pk may have been assigned manually to a model instance not
570 # saved to the database (or auto-generated in a case like
571 # UUIDField), but we allow the save to proceed and rely on the
572 # database to raise an IntegrityError if applicable. If
573 # constraints aren't supported by the database, there's the
574 # unavoidable risk of data corruption.
575 if obj.id is None:
576 # Remove the object from a related instance cache.
577 if not field.remote_field.multiple:
578 field.remote_field.delete_cached_value(obj)
579 raise ValueError(
580 f"{operation_name}() prohibited to prevent data loss due to unsaved "
581 f"related object '{field.name}'."
582 )
583 elif getattr(self, field.attname) in field.empty_values:
584 # Set related object if it has been saved after an
585 # assignment.
586 setattr(self, field.name, obj)
587 # If the relationship's pk/to_field was changed, clear the
588 # cached relationship.
589 if getattr(obj, field.target_field.attname) != getattr(
590 self, field.attname
591 ):
592 field.delete_cached_value(self)
593
594 def delete(self) -> tuple[int, dict[str, int]]:
595 if self.id is None:
596 raise ValueError(
597 f"{self.model_options.object_name} object can't be deleted because its id attribute is set "
598 "to None."
599 )
600 collector = Collector(origin=self)
601 collector.collect([self])
602 return collector.delete()
603
604 def get_field_display(self, field_name: str) -> str:
605 """Get the display value for a field, especially useful for fields with choices."""
606 # Get the field object from the field name
607 field = self._model_meta.get_forward_field(field_name)
608 value = getattr(self, field.attname)
609
610 # If field has no choices, just return the value as string
611 if not hasattr(field, "flatchoices") or not field.flatchoices:
612 return force_str(value, strings_only=True)
613
614 # For fields with choices, look up the display value
615 choices_dict = dict(make_hashable(field.flatchoices))
616 return force_str(
617 choices_dict.get(make_hashable(value), value), strings_only=True
618 )
619
620 def _get_field_value_map(
621 self, meta: Meta | None, exclude: set[str] | None = None
622 ) -> dict[str, Value]:
623 if exclude is None:
624 exclude = set()
625 meta = meta or self._model_meta
626 return {
627 field.name: Value(getattr(self, field.attname), field)
628 for field in meta.local_concrete_fields
629 if field.name not in exclude
630 }
631
632 def prepare_database_save(self, field: Any) -> Any:
633 if self.id is None:
634 raise ValueError(
635 f"Unsaved model instance {self!r} cannot be used in an ORM query."
636 )
637 return getattr(self, field.remote_field.get_related_field().attname)
638
639 def clean(self) -> None:
640 """
641 Hook for doing any extra model-wide validation after clean() has been
642 called on every field by self.clean_fields. Any ValidationError raised
643 by this method will not be associated with a particular field; it will
644 have a special-case association with the field defined by NON_FIELD_ERRORS.
645 """
646 pass
647
648 def validate_unique(self, exclude: set[str] | None = None) -> None:
649 """
650 Check unique constraints on the model and raise ValidationError if any
651 failed.
652 """
653 unique_checks = self._get_unique_checks(exclude=exclude)
654
655 if errors := self._perform_unique_checks(unique_checks):
656 raise ValidationError(errors)
657
658 def _get_unique_checks(
659 self, exclude: set[str] | None = None
660 ) -> list[tuple[type[Model], tuple[str, ...]]]:
661 """
662 Return a list of checks to perform. Since validate_unique() could be
663 called from a ModelForm, some fields may have been excluded; we can't
664 perform a unique check on a model that is missing fields involved
665 in that check. Fields that did not validate should also be excluded,
666 but they need to be passed in via the exclude argument.
667 """
668 if exclude is None:
669 exclude = set()
670 unique_checks = []
671
672 # Gather a list of checks for fields declared as unique and add them to
673 # the list of checks.
674
675 fields_with_class = [(self.__class__, self._model_meta.local_fields)]
676
677 for model_class, fields in fields_with_class:
678 for f in fields:
679 name = f.name
680 if name in exclude:
681 continue
682 if f.primary_key:
683 unique_checks.append((model_class, (name,)))
684
685 return unique_checks
686
687 def _perform_unique_checks(
688 self, unique_checks: list[tuple[type[Model], tuple[str, ...]]]
689 ) -> dict[str, list[ValidationError]]:
690 errors = {}
691
692 for model_class, unique_check in unique_checks:
693 # Try to look up an existing object with the same values as this
694 # object's values for all the unique field.
695
696 lookup_kwargs = {}
697 for field_name in unique_check:
698 f = self._model_meta.get_forward_field(field_name)
699 lookup_value = getattr(self, f.attname)
700 if lookup_value is None:
701 # no value, skip the lookup
702 continue
703 if f.primary_key and not self._state.adding:
704 # no need to check for unique primary key when editing
705 continue
706 lookup_kwargs[str(field_name)] = lookup_value
707
708 # some fields were skipped, no reason to do the check
709 if len(unique_check) != len(lookup_kwargs):
710 continue
711
712 qs = model_class.query.filter(**lookup_kwargs)
713
714 # Exclude the current object from the query if we are editing an
715 # instance (as opposed to creating a new one)
716 # Use the primary key defined by model_class. In previous versions
717 # this could differ from `self.id` due to model inheritance.
718 model_class_id = getattr(self, "id")
719 if not self._state.adding and model_class_id is not None:
720 qs = qs.exclude(id=model_class_id)
721 if qs.exists():
722 if len(unique_check) == 1:
723 key = unique_check[0]
724 else:
725 key = NON_FIELD_ERRORS
726 errors.setdefault(key, []).append(
727 self.unique_error_message(model_class, unique_check)
728 )
729
730 return errors
731
732 def unique_error_message(
733 self, model_class: type[Model], unique_check: tuple[str, ...]
734 ) -> ValidationError:
735 meta = model_class._model_meta
736
737 params = {
738 "model": self,
739 "model_class": model_class,
740 "model_name": model_class.model_options.model_name,
741 "unique_check": unique_check,
742 }
743
744 if len(unique_check) == 1:
745 field = meta.get_forward_field(unique_check[0])
746 params["field_label"] = field.name # type: ignore[assignment]
747 return ValidationError(
748 message=field.error_messages["unique"],
749 code="unique",
750 params=params,
751 )
752 else:
753 field_names = [meta.get_forward_field(f).name for f in unique_check]
754
755 # Put an "and" before the last one
756 field_names[-1] = f"and {field_names[-1]}"
757
758 if len(field_names) > 2:
759 # Comma join if more than 2
760 params["field_label"] = ", ".join(cast(list[str], field_names))
761 else:
762 # Just a space if there are only 2
763 params["field_label"] = " ".join(cast(list[str], field_names))
764
765 # Use the first field as the message format...
766 message = meta.get_forward_field(unique_check[0]).error_messages["unique"]
767
768 return ValidationError(
769 message=message,
770 code="unique",
771 params=params,
772 )
773
774 def get_constraints(self) -> list[tuple[type[Model], list[Any]]]:
775 constraints: list[tuple[type[Model], list[Any]]] = [
776 (self.__class__, list(self.model_options.constraints))
777 ]
778 return constraints
779
780 def validate_constraints(self, exclude: set[str] | None = None) -> None:
781 constraints = self.get_constraints()
782
783 errors = {}
784 for model_class, model_constraints in constraints:
785 for constraint in model_constraints:
786 try:
787 constraint.validate(model_class, self, exclude=exclude)
788 except ValidationError as e:
789 if (
790 getattr(e, "code", None) == "unique"
791 and len(constraint.fields) == 1
792 ):
793 errors.setdefault(constraint.fields[0], []).append(e)
794 else:
795 errors = e.update_error_dict(errors)
796 if errors:
797 raise ValidationError(errors)
798
799 def full_clean(
800 self,
801 *,
802 exclude: set[str] | Iterable[str] | None = None,
803 validate_unique: bool = True,
804 validate_constraints: bool = True,
805 ) -> None:
806 """
807 Call clean_fields(), clean(), validate_unique(), and
808 validate_constraints() on the model. Raise a ValidationError for any
809 errors that occur.
810 """
811 errors = {}
812 if exclude is None:
813 exclude = set()
814 else:
815 exclude = set(exclude)
816
817 try:
818 self.clean_fields(exclude=exclude)
819 except ValidationError as e:
820 errors = e.update_error_dict(errors)
821
822 # Form.clean() is run even if other validation fails, so do the
823 # same with Model.clean() for consistency.
824 try:
825 self.clean()
826 except ValidationError as e:
827 errors = e.update_error_dict(errors)
828
829 # Run unique checks, but only for fields that passed validation.
830 if validate_unique:
831 for name in errors:
832 if name != NON_FIELD_ERRORS and name not in exclude:
833 exclude.add(name)
834 try:
835 self.validate_unique(exclude=exclude)
836 except ValidationError as e:
837 errors = e.update_error_dict(errors)
838
839 # Run constraints checks, but only for fields that passed validation.
840 if validate_constraints:
841 for name in errors:
842 if name != NON_FIELD_ERRORS and name not in exclude:
843 exclude.add(name)
844 try:
845 self.validate_constraints(exclude=exclude)
846 except ValidationError as e:
847 errors = e.update_error_dict(errors)
848
849 if errors:
850 raise ValidationError(errors)
851
852 def clean_fields(self, exclude: set[str] | None = None) -> None:
853 """
854 Clean all fields and raise a ValidationError containing a dict
855 of all validation errors if any occur.
856 """
857 if exclude is None:
858 exclude = set()
859
860 errors = {}
861 for f in self._model_meta.fields:
862 if f.name in exclude:
863 continue
864 # Skip validation for empty fields with required=False. The developer
865 # is responsible for making sure they have a valid value.
866 raw_value = getattr(self, f.attname)
867 if not f.required and raw_value in f.empty_values:
868 continue
869 try:
870 setattr(self, f.attname, f.clean(raw_value, self))
871 except ValidationError as e:
872 errors[f.name] = e.error_list
873
874 if errors:
875 raise ValidationError(errors)
876
877 @classmethod
878 def preflight(cls) -> list[PreflightResult]:
879 errors: list[PreflightResult] = []
880
881 errors += [
882 *cls._check_fields(),
883 *cls._check_m2m_through_same_relationship(),
884 *cls._check_long_column_names(),
885 ]
886 clash_errors = (
887 *cls._check_id_field(),
888 *cls._check_field_name_clashes(),
889 *cls._check_model_name_db_lookup_clashes(),
890 *cls._check_property_name_related_field_accessor_clashes(),
891 *cls._check_single_primary_key(),
892 )
893 errors.extend(clash_errors)
894 # If there are field name clashes, hide consequent column name
895 # clashes.
896 if not clash_errors:
897 errors.extend(cls._check_column_name_clashes())
898 errors += [
899 *cls._check_indexes(),
900 *cls._check_ordering(),
901 *cls._check_constraints(),
902 ]
903
904 return errors
905
906 @classmethod
907 def _check_fields(cls) -> list[PreflightResult]:
908 """Perform all field checks."""
909 errors: list[PreflightResult] = []
910 for field in cls._model_meta.local_fields:
911 errors.extend(field.preflight(from_model=cls))
912 for field in cls._model_meta.local_many_to_many:
913 errors.extend(field.preflight(from_model=cls))
914 return errors
915
916 @classmethod
917 def _check_m2m_through_same_relationship(cls) -> list[PreflightResult]:
918 """Check if no relationship model is used by more than one m2m field."""
919
920 errors: list[PreflightResult] = []
921 seen_intermediary_signatures = []
922
923 fields = cls._model_meta.local_many_to_many
924
925 # Skip when the target model wasn't found.
926 fields = (f for f in fields if isinstance(f.remote_field.model, ModelBase))
927
928 # Skip when the relationship model wasn't found.
929 fields = (f for f in fields if isinstance(f.remote_field.through, ModelBase))
930
931 for f in fields:
932 signature = (
933 f.remote_field.model,
934 cls,
935 f.remote_field.through,
936 f.remote_field.through_fields,
937 )
938 if signature in seen_intermediary_signatures:
939 errors.append(
940 PreflightResult(
941 fix="The model has two identical many-to-many relations "
942 f"through the intermediate model '{f.remote_field.through.model_options.label}'.",
943 obj=cls,
944 id="postgres.duplicate_many_to_many_relations",
945 )
946 )
947 else:
948 seen_intermediary_signatures.append(signature)
949 return errors
950
951 @classmethod
952 def _check_id_field(cls) -> list[PreflightResult]:
953 """Disallow user-defined fields named ``id``."""
954 if any(
955 f
956 for f in cls._model_meta.local_fields
957 if f.name == "id" and not f.auto_created
958 ):
959 return [
960 PreflightResult(
961 fix="'id' is a reserved word that cannot be used as a field name.",
962 obj=cls,
963 id="postgres.reserved_field_name_id",
964 )
965 ]
966 return []
967
968 @classmethod
969 def _check_field_name_clashes(cls) -> list[PreflightResult]:
970 """Forbid field shadowing in multi-table inheritance."""
971 errors: list[PreflightResult] = []
972 used_fields = {} # name or attname -> field
973
974 for f in cls._model_meta.local_fields:
975 clash = used_fields.get(f.name) or used_fields.get(f.attname) or None
976 # Note that we may detect clash between user-defined non-unique
977 # field "id" and automatically added unique field "id", both
978 # defined at the same model. This special case is considered in
979 # _check_id_field and here we ignore it.
980 id_conflict = (
981 f.name == "id" and clash and clash.name == "id" and clash.model == cls
982 )
983 if clash and not id_conflict:
984 errors.append(
985 PreflightResult(
986 fix=f"The field '{f.name}' clashes with the field '{clash.name}' "
987 f"from model '{clash.model.model_options}'.",
988 obj=f,
989 id="postgres.field_name_clash",
990 )
991 )
992 used_fields[f.name] = f
993 used_fields[f.attname] = f
994
995 return errors
996
997 @classmethod
998 def _check_column_name_clashes(cls) -> list[PreflightResult]:
999 # Store a list of column names which have already been used by other fields.
1000 used_column_names: list[str] = []
1001 errors: list[PreflightResult] = []
1002
1003 for f in cls._model_meta.local_fields:
1004 column_name = f.column
1005
1006 # Ensure the column name is not already in use.
1007 if column_name and column_name in used_column_names:
1008 errors.append(
1009 PreflightResult(
1010 fix=f"Field '{f.name}' has column name '{column_name}' that is used by "
1011 "another field.",
1012 obj=cls,
1013 id="postgres.db_column_clash",
1014 )
1015 )
1016 else:
1017 used_column_names.append(column_name)
1018
1019 return errors
1020
1021 @classmethod
1022 def _check_model_name_db_lookup_clashes(cls) -> list[PreflightResult]:
1023 errors: list[PreflightResult] = []
1024 model_name = cls.__name__
1025 if model_name.startswith("_") or model_name.endswith("_"):
1026 errors.append(
1027 PreflightResult(
1028 fix=f"The model name '{model_name}' cannot start or end with an underscore "
1029 "as it collides with the query lookup syntax.",
1030 obj=cls,
1031 id="postgres.model_name_underscore_bounds",
1032 )
1033 )
1034 elif LOOKUP_SEP in model_name:
1035 errors.append(
1036 PreflightResult(
1037 fix=f"The model name '{model_name}' cannot contain double underscores as "
1038 "it collides with the query lookup syntax.",
1039 obj=cls,
1040 id="postgres.model_name_double_underscore",
1041 )
1042 )
1043 return errors
1044
1045 @classmethod
1046 def _check_property_name_related_field_accessor_clashes(
1047 cls,
1048 ) -> list[PreflightResult]:
1049 errors: list[PreflightResult] = []
1050 property_names = cls._model_meta._property_names
1051 related_field_accessors = (
1052 f.get_attname()
1053 for f in cls._model_meta._get_fields(reverse=False)
1054 if isinstance(f, RelatedField)
1055 )
1056 for accessor in related_field_accessors:
1057 if accessor in property_names:
1058 errors.append(
1059 PreflightResult(
1060 fix=f"The property '{accessor}' clashes with a related field "
1061 "accessor.",
1062 obj=cls,
1063 id="postgres.property_related_field_clash",
1064 )
1065 )
1066 return errors
1067
1068 @classmethod
1069 def _check_single_primary_key(cls) -> list[PreflightResult]:
1070 errors: list[PreflightResult] = []
1071 if sum(1 for f in cls._model_meta.local_fields if f.primary_key) > 1:
1072 errors.append(
1073 PreflightResult(
1074 fix="The model cannot have more than one field with "
1075 "'primary_key=True'.",
1076 obj=cls,
1077 id="postgres.multiple_primary_keys",
1078 )
1079 )
1080 return errors
1081
1082 @classmethod
1083 def _check_indexes(cls) -> list[PreflightResult]:
1084 """Check fields, names, and conditions of indexes."""
1085 errors: list[PreflightResult] = []
1086 references: set[str] = set()
1087 for index in cls.model_options.indexes:
1088 # Index name can't start with an underscore or a number
1089 if index.name[0] == "_" or index.name[0].isdigit():
1090 errors.append(
1091 PreflightResult(
1092 fix=f"The index name '{index.name}' cannot start with an underscore "
1093 "or a number.",
1094 obj=cls,
1095 id="postgres.index_name_invalid_start",
1096 ),
1097 )
1098 if len(index.name) > index.max_name_length:
1099 errors.append(
1100 PreflightResult(
1101 fix="The index name '%s' cannot be longer than %d " # noqa: UP031
1102 "characters." % (index.name, index.max_name_length),
1103 obj=cls,
1104 id="postgres.index_name_too_long",
1105 ),
1106 )
1107 if index.contains_expressions:
1108 for expression in index.expressions:
1109 references.update(
1110 ref[0] for ref in cls._get_expr_references(expression)
1111 )
1112 # Check fields referenced in indexes
1113 fields = [
1114 field
1115 for index in cls.model_options.indexes
1116 for field, _ in index.fields_orders
1117 ]
1118 fields += [
1119 include for index in cls.model_options.indexes for include in index.include
1120 ]
1121 fields += references
1122 errors.extend(cls._check_local_fields(fields, "indexes"))
1123 return errors
1124
1125 @classmethod
1126 def _check_local_fields(
1127 cls, fields: Iterable[str], option: str
1128 ) -> list[PreflightResult]:
1129 # In order to avoid hitting the relation tree prematurely, we use our
1130 # own fields_map instead of using get_field()
1131 forward_fields_map: dict[str, Field] = {}
1132 for field in cls._model_meta._get_fields(reverse=False):
1133 forward_fields_map[field.name] = field
1134 if hasattr(field, "attname"):
1135 forward_fields_map[field.attname] = field
1136
1137 errors: list[PreflightResult] = []
1138 for field_name in fields:
1139 try:
1140 field = forward_fields_map[field_name]
1141 except KeyError:
1142 errors.append(
1143 PreflightResult(
1144 fix=f"'{option}' refers to the nonexistent field '{field_name}'.",
1145 obj=cls,
1146 id="postgres.nonexistent_field_reference",
1147 )
1148 )
1149 else:
1150 from plain.postgres.fields.related import ManyToManyField
1151
1152 if isinstance(field, ManyToManyField):
1153 errors.append(
1154 PreflightResult(
1155 fix=f"'{option}' refers to a ManyToManyField '{field_name}', but "
1156 f"ManyToManyFields are not permitted in '{option}'.",
1157 obj=cls,
1158 id="postgres.m2m_field_in_meta_option",
1159 )
1160 )
1161 elif field not in cls._model_meta.local_fields:
1162 errors.append(
1163 PreflightResult(
1164 fix=f"'{option}' refers to field '{field_name}' which is not local to model "
1165 f"'{cls.model_options.object_name}'. This issue may be caused by multi-table inheritance.",
1166 obj=cls,
1167 id="postgres.non_local_field_reference",
1168 )
1169 )
1170 return errors
1171
1172 @classmethod
1173 def _check_ordering(cls) -> list[PreflightResult]:
1174 """
1175 Check "ordering" option -- is it a list of strings and do all fields
1176 exist?
1177 """
1178
1179 if not cls.model_options.ordering:
1180 return []
1181
1182 if not isinstance(cls.model_options.ordering, list | tuple):
1183 return [
1184 PreflightResult(
1185 fix="'ordering' must be a tuple or list (even if you want to order by "
1186 "only one field).",
1187 obj=cls,
1188 id="postgres.ordering_not_tuple_or_list",
1189 )
1190 ]
1191
1192 errors: list[PreflightResult] = []
1193 fields = cls.model_options.ordering
1194
1195 # Skip expressions and '?' fields.
1196 fields = (f for f in fields if isinstance(f, str) and f != "?")
1197
1198 # Convert "-field" to "field".
1199 fields = (f.removeprefix("-") for f in fields)
1200
1201 # Separate related fields and non-related fields.
1202 _fields = []
1203 related_fields = []
1204 for f in fields:
1205 if LOOKUP_SEP in f:
1206 related_fields.append(f)
1207 else:
1208 _fields.append(f)
1209 fields = _fields
1210
1211 # Check related fields.
1212 for field in related_fields:
1213 _cls = cls
1214 fld = None
1215 for part in field.split(LOOKUP_SEP):
1216 try:
1217 fld = _cls._model_meta.get_field(part) # type: ignore[unresolved-attribute]
1218 if isinstance(fld, RelatedField):
1219 _cls = fld.path_infos[-1].to_meta.model
1220 else:
1221 _cls = None
1222 except (FieldDoesNotExist, AttributeError):
1223 if fld is None or (
1224 not isinstance(fld, Field)
1225 or (
1226 fld.get_transform(part) is None
1227 and fld.get_lookup(part) is None
1228 )
1229 ):
1230 errors.append(
1231 PreflightResult(
1232 fix="'ordering' refers to the nonexistent field, "
1233 f"related field, or lookup '{field}'.",
1234 obj=cls,
1235 id="postgres.ordering_nonexistent_field",
1236 )
1237 )
1238
1239 # Check for invalid or nonexistent fields in ordering.
1240 invalid_fields = []
1241
1242 # Any field name that is not present in field_names does not exist.
1243 # Also, ordering by m2m fields is not allowed.
1244 meta = cls._model_meta
1245 valid_fields = set(
1246 chain.from_iterable(
1247 (f.name, f.attname)
1248 if not (f.auto_created and not f.concrete)
1249 else (f.field.related_query_name(),)
1250 for f in chain(meta.fields, meta.related_objects)
1251 )
1252 )
1253
1254 invalid_fields.extend(set(fields) - valid_fields)
1255
1256 for invalid_field in invalid_fields:
1257 errors.append(
1258 PreflightResult(
1259 fix="'ordering' refers to the nonexistent field, related "
1260 f"field, or lookup '{invalid_field}'.",
1261 obj=cls,
1262 id="postgres.ordering_nonexistent_field",
1263 )
1264 )
1265 return errors
1266
1267 @classmethod
1268 def _check_long_column_names(cls) -> list[PreflightResult]:
1269 """
1270 Check that any auto-generated column names are shorter than the limits
1271 for each database in which the model will be created.
1272 """
1273 errors: list[PreflightResult] = []
1274
1275 # PostgreSQL has a 63-character limit on identifier names and doesn't
1276 # silently truncate, so we check for names that are too long
1277 allowed_len = MAX_NAME_LENGTH
1278
1279 for f in cls._model_meta.local_fields:
1280 column_name = f.column
1281
1282 # Check if column name is too long for the database.
1283 if column_name is not None and len(column_name) > allowed_len:
1284 errors.append(
1285 PreflightResult(
1286 fix=f'Column name too long for field "{column_name}". '
1287 f'Maximum length is "{allowed_len}" for the database.',
1288 obj=cls,
1289 id="postgres.column_name_too_long",
1290 )
1291 )
1292
1293 for f in cls._model_meta.local_many_to_many:
1294 # Skip nonexistent models.
1295 if isinstance(f.remote_field.through, str):
1296 continue
1297
1298 # Check if column name for the M2M field is too long for the database.
1299 for m2m in f.remote_field.through._model_meta.local_fields:
1300 rel_name = m2m.column
1301 if rel_name is not None and len(rel_name) > allowed_len:
1302 errors.append(
1303 PreflightResult(
1304 fix="Column name too long for M2M field "
1305 f'"{rel_name}". Maximum length is "{allowed_len}" for the database.',
1306 obj=cls,
1307 id="postgres.m2m_column_name_too_long",
1308 )
1309 )
1310
1311 return errors
1312
1313 @classmethod
1314 def _get_expr_references(cls, expr: Any) -> Iterator[tuple[str, ...]]:
1315 if isinstance(expr, Q):
1316 for child in expr.children:
1317 if isinstance(child, tuple):
1318 lookup, value = child
1319 yield tuple(lookup.split(LOOKUP_SEP))
1320 yield from cls._get_expr_references(value)
1321 else:
1322 yield from cls._get_expr_references(child)
1323 elif isinstance(expr, F):
1324 yield tuple(expr.name.split(LOOKUP_SEP))
1325 elif hasattr(expr, "get_source_expressions"):
1326 for src_expr in expr.get_source_expressions():
1327 yield from cls._get_expr_references(src_expr)
1328
1329 @classmethod
1330 def _check_constraints(cls) -> list[PreflightResult]:
1331 errors: list[PreflightResult] = []
1332 fields = set(
1333 chain.from_iterable(
1334 (*constraint.fields, *constraint.include)
1335 for constraint in cls.model_options.constraints
1336 if isinstance(constraint, UniqueConstraint)
1337 )
1338 )
1339 references = set()
1340 for constraint in cls.model_options.constraints:
1341 if isinstance(constraint, UniqueConstraint):
1342 if isinstance(constraint.condition, Q):
1343 references.update(cls._get_expr_references(constraint.condition))
1344 if constraint.contains_expressions:
1345 for expression in constraint.expressions:
1346 references.update(cls._get_expr_references(expression))
1347 elif isinstance(constraint, CheckConstraint):
1348 if isinstance(constraint.check, Q):
1349 references.update(cls._get_expr_references(constraint.check))
1350 if any(isinstance(expr, RawSQL) for expr in constraint.check.flatten()):
1351 errors.append(
1352 PreflightResult(
1353 fix=f"Check constraint {constraint.name!r} contains "
1354 f"RawSQL() expression and won't be validated "
1355 f"during the model full_clean(). "
1356 "Silence this warning if you don't care about it.",
1357 warning=True,
1358 obj=cls,
1359 id="postgres.constraint_name_collision_autogenerated",
1360 ),
1361 )
1362 for field_name, *lookups in references:
1363 fields.add(field_name)
1364 if not lookups:
1365 # If it has no lookups it cannot result in a JOIN.
1366 continue
1367 try:
1368 field = cls._model_meta.get_field(field_name)
1369 from plain.postgres.fields.related import ManyToManyField
1370 from plain.postgres.fields.reverse_related import ForeignKeyRel
1371
1372 if (
1373 not isinstance(field, RelatedField)
1374 or isinstance(field, ManyToManyField)
1375 or isinstance(field, ForeignKeyRel)
1376 ):
1377 continue
1378 except FieldDoesNotExist:
1379 continue
1380 # JOIN must happen at the first lookup.
1381 first_lookup = lookups[0]
1382 if (
1383 hasattr(field, "get_transform")
1384 and hasattr(field, "get_lookup")
1385 and field.get_transform(first_lookup) is None
1386 and field.get_lookup(first_lookup) is None
1387 ):
1388 errors.append(
1389 PreflightResult(
1390 fix=f"'constraints' refers to the joined field '{LOOKUP_SEP.join([field_name] + lookups)}'.",
1391 obj=cls,
1392 id="postgres.constraint_refers_to_joined_field",
1393 )
1394 )
1395 errors.extend(cls._check_local_fields(fields, "constraints"))
1396 return errors
1397
1398
1399########
1400# MISC #
1401########
1402
1403
def model_unpickle(model_id: tuple[str, str] | type[Model]) -> Model:
    """
    Create a bare instance of a model class while unpickling.

    Used to unpickle Model subclasses (e.g. ones with deferred fields).
    ``model_id`` is either a tuple of arguments for
    ``models_registry.get_model()`` or, for objects pickled by earlier
    versions, the model class itself. The instance is created with
    ``__new__`` only, so ``__init__`` does not run.
    """
    model = (
        models_registry.get_model(*model_id)
        if isinstance(model_id, tuple)
        # Backwards compat - the model was cached directly in earlier versions.
        else model_id
    )
    return model.__new__(model)
1412
1413
# Pickle protocol marker - functions don't normally have this attribute, so it
# is set explicitly to flag model_unpickle as a safe unpickling constructor.
# NOTE(review): modern pickle appears not to consult this attribute (the check
# was dropped per PEP 307); presumably kept for compatibility — confirm.
model_unpickle.__safe_for_unpickle__ = True  # type: ignore[attr-defined]