1"""
2Helper functions for creating Form classes from Plain models
3and database field objects.
4"""
5
6from __future__ import annotations
7
8from itertools import chain
9from typing import TYPE_CHECKING, Any
10
11from plain.exceptions import (
12 NON_FIELD_ERRORS,
13 ImproperlyConfigured,
14 ValidationError,
15)
16from plain.forms import fields
17from plain.forms.fields import ChoiceField, Field
18from plain.forms.forms import BaseForm, DeclarativeFieldsMetaclass
19from plain.models.exceptions import FieldError
20
21if TYPE_CHECKING:
22 from plain.models.fields import Field as ModelField
23
# Names exported by ``from <this module> import *``. The module-level helpers
# construct_instance() and modelfield_to_formfield() are not listed here.
__all__ = (
    "ModelForm",
    "BaseModelForm",
    "model_to_dict",
    "fields_for_model",
    "ModelChoiceField",
    "ModelMultipleChoiceField",
)
32
33
def construct_instance(
    form: BaseModelForm,
    instance: Any,
    fields: list[str] | tuple[str, ...] | None = None,
) -> Any:
    """
    Construct and return a model instance from the bound ``form``'s
    ``cleaned_data``, but do not save the returned instance to the database.

    Every non-primary-key model field that appears in ``cleaned_data`` (and in
    ``fields``, when given) is written onto ``instance`` through the model
    field's ``save_form_data()``. ``instance`` is modified in place and also
    returned.
    """
    from plain import models

    cleaned_data = form.cleaned_data
    for f in instance._model_meta.fields:
        if isinstance(f, models.PrimaryKeyField) or f.name not in cleaned_data:
            continue
        if fields is not None and f.name not in fields:
            continue

        # Leave the model default in place for a field that was absent from
        # both the submitted data and files AND cleaned to an empty value.
        # Unchecked checkboxes are also absent from POST data, but they still
        # get saved as long as their cleaned value (presumably False) is not
        # one of the form field's empty_values.
        html_name = form.add_prefix(f.name)
        if (
            f.has_default()
            and html_name not in form.data
            and html_name not in form.files
            and cleaned_data.get(f.name) in form[f.name].field.empty_values
        ):
            continue

        f.save_form_data(instance, cleaned_data[f.name])

    return instance
73
74
75# ModelForms #################################################################
76
77
def model_to_dict(
    instance: Any, fields: list[str] | tuple[str, ...] | None = None
) -> dict[str, Any]:
    """
    Return a dict of ``instance``'s field values, suitable for passing as a
    Form's ``initial`` keyword argument.

    Concrete fields and many-to-many fields are included, keyed by field name,
    with values taken from each field's ``value_from_object()``.

    ``fields`` is an optional list of field names. If provided, only those
    fields are included in the result.
    """
    meta = instance._model_meta
    candidates = chain(meta.concrete_fields, meta.many_to_many)
    return {
        field.name: field.value_from_object(instance)
        for field in candidates
        if fields is None or field.name in fields
    }
95
96
def fields_for_model(
    model: type[Any],
    fields: list[str] | tuple[str, ...] | None = None,
    formfield_callback: Any = None,
    error_messages: dict[str, Any] | None = None,
    field_classes: dict[str, type[Field]] | None = None,
) -> dict[str, Field | None]:
    """
    Return a dictionary containing form fields for the given model.

    ``fields`` is an optional list of field names. If provided, return only the
    named fields, in that order. A name that does not produce a form field
    (unknown to the model) maps to None so callers can detect it; a model
    field whose form field is falsy (e.g. None for primary keys) is omitted.

    ``formfield_callback`` is a callable that takes a model field (plus the
    ``error_messages``/``form_class`` keyword arguments below) and returns a
    form field. When None, modelfield_to_formfield() is used.

    ``error_messages`` is a dictionary of model field names mapped to a
    dictionary of error messages.

    ``field_classes`` is a dictionary of model field names mapped to a form
    field class.

    Raises TypeError if ``formfield_callback`` is given but is not callable.
    """
    # Validate the callback up front so misconfiguration is reported even
    # when no model field ends up being visited.
    if formfield_callback is not None and not callable(formfield_callback):
        raise TypeError("formfield_callback must be a function or callable")

    field_dict: dict[str, Field | None] = {}
    ignored: set[str] = set()
    meta = model._model_meta

    for f in sorted(chain(meta.concrete_fields, meta.many_to_many)):
        if fields is not None and f.name not in fields:
            continue

        kwargs: dict[str, Any] = {}
        if error_messages and f.name in error_messages:
            kwargs["error_messages"] = error_messages[f.name]
        if field_classes and f.name in field_classes:
            kwargs["form_class"] = field_classes[f.name]

        if formfield_callback is None:
            formfield = modelfield_to_formfield(f, **kwargs)
        else:
            formfield = formfield_callback(f, **kwargs)

        if formfield:
            field_dict[f.name] = formfield
        else:
            ignored.add(f.name)

    if fields:
        # Re-order to match ``fields``; unknown names map to None.
        field_dict = {
            name: field_dict.get(name) for name in fields if name not in ignored
        }
    return field_dict
147
148
class ModelFormOptions:
    """
    Options read from a ModelForm's inner ``Meta`` class.

    Each option that ``options`` does not define (or every option, when
    ``options`` is None) is set to None.
    """

    def __init__(self, options: Any = None) -> None:
        def option(attr_name: str) -> Any:
            # Missing Meta attributes fall back to None.
            return getattr(options, attr_name, None)

        self.model: type[Any] | None = option("model")
        self.fields: list[str] | tuple[str, ...] | None = option("fields")
        self.error_messages: dict[str, Any] | None = option("error_messages")
        self.field_classes: dict[str, type[Field]] | None = option("field_classes")
        self.formfield_callback: Any = option("formfield_callback")
162
163
class ModelFormMetaclass(DeclarativeFieldsMetaclass):
    """
    Metaclass that computes ``base_fields`` for ModelForm subclasses.

    It reads the inner ``Meta`` class into ``_meta`` (a ModelFormOptions),
    generates form fields for ``Meta.model`` restricted to ``Meta.fields``, and
    lets fields declared directly on the form override generated ones.
    """

    def __new__(
        mcs: type[ModelFormMetaclass],
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
    ) -> type[BaseModelForm]:
        new_class = super().__new__(mcs, name, bases, attrs)  # type: ignore[invalid-super-argument]

        # The class deriving directly and only from BaseModelForm (ModelForm
        # itself) has no Meta to process.
        if bases == (BaseModelForm,):
            return new_class

        opts = new_class._meta = ModelFormOptions(getattr(new_class, "Meta", None))

        # A string passed to `fields` is likely a mistake where the user
        # typed ('foo') instead of ('foo',).
        if isinstance(opts.fields, str):
            msg = (
                f"{new_class.__name__}.Meta.fields cannot be a string. "
                f"Did you mean to type: ('{opts.fields}',)?"
            )
            raise TypeError(msg)

        if opts.model:
            # If a model is defined, extract form fields from it.
            if opts.fields is None:
                raise ImproperlyConfigured(
                    "Creating a ModelForm without the 'fields' attribute "
                    f"is prohibited; form {name} "
                    "needs updating."
                )

            fields = fields_for_model(
                opts.model,
                opts.fields,
                opts.formfield_callback,
                opts.error_messages,
                opts.field_classes,
            )

            # Names listed in Meta.fields that produced no form field are only
            # valid if the form declares a field with that name itself.
            none_model_fields = {k for k, v in fields.items() if not v}
            missing_fields = none_model_fields.difference(new_class.declared_fields)
            if missing_fields:
                # Sorted so the message is stable regardless of set ordering.
                raise FieldError(
                    f"Unknown field(s) ({', '.join(sorted(missing_fields))}) "
                    f"specified for {opts.model.__name__}"
                )
            # Override default model fields with any custom declared ones
            # (plus, include all the other declared fields).
            fields.update(new_class.declared_fields)
        else:
            fields = new_class.declared_fields

        new_class.base_fields = fields

        return new_class
223
224
class BaseModelForm(BaseForm):
    """
    Form that edits a single model instance.

    The model, field list and per-field options come from ``self._meta`` (a
    ModelFormOptions that ModelFormMetaclass attaches to ModelForm
    subclasses). Validation copies ``cleaned_data`` onto ``self.instance`` via
    construct_instance() and runs the instance's own ``full_clean()``;
    ``save()`` then persists the instance and its many-to-many data.
    """

    def __init__(
        self,
        *,
        request: Any,
        auto_id: str = "id_%s",
        prefix: str | None = None,
        initial: dict[str, Any] | None = None,
        instance: Any = None,
    ) -> None:
        """
        Bind the form to ``instance``, or to a fresh ``opts.model()`` when
        ``instance`` is None.

        The form's initial data is built from ``instance``'s field values
        (restricted to ``Meta.fields``), with ``initial`` applied on top.
        ``request``, ``auto_id`` and ``prefix`` are passed to BaseForm.

        Raises ValueError if no model class is configured.
        """
        opts = self._meta
        if opts.model is None:
            raise ValueError("ModelForm has no model class specified.")
        if instance is None:
            # if we didn't get an instance, instantiate a new one
            self.instance = opts.model()
            object_data = {}
        else:
            self.instance = instance
            object_data = model_to_dict(instance, opts.fields)
        # if initial was provided, it should override the values from instance
        if initial is not None:
            object_data.update(initial)
        # self._validate_unique will be set to True by BaseModelForm.clean().
        # It is False by default so overriding self.clean() and failing to call
        # super will stop validate_unique from being called.
        self._validate_unique = False
        super().__init__(
            request=request,
            auto_id=auto_id,
            prefix=prefix,
            initial=object_data,
        )

    def _get_validation_exclusions(self) -> set[str]:
        """
        Return the names of model fields to skip during model validation and
        unique checks.

        A field is excluded when it is not on the form, when it is a declared
        form field left out of ``Meta.fields``, when it already failed form
        validation, or when it is empty on an optional form field while the
        model field is required.

        For backwards-compatibility, exclude several types of fields from model
        validation. See tickets #12507, #12521, #12553.
        """
        exclude: set[str] = set()
        # Build up a list of fields that should be excluded from model field
        # validation and unique checks.
        for f in self.instance._model_meta.fields:
            field = f.name
            # Exclude fields that aren't on the form. The developer may be
            # adding these values to the model after form validation.
            if field not in self.fields:
                exclude.add(f.name)

            # Don't perform model validation on fields that were defined
            # manually on the form and excluded via the ModelForm's Meta
            # class. See #12901.
            elif self._meta.fields and field not in self._meta.fields:
                exclude.add(f.name)

            # Exclude fields that failed form validation. There's no need for
            # the model fields to validate them as well.
            elif field in self._errors:
                exclude.add(f.name)

            # Exclude empty fields that are not required by the form, if the
            # underlying model field is required. This keeps the model field
            # from raising a required error. Note: don't exclude the field from
            # validation if the model field allows blanks. If it does, the blank
            # value may be included in a unique check, so cannot be excluded
            # from validation.
            else:
                form_field = self.fields[field]
                field_value = self.cleaned_data.get(field)
                if (
                    f.required
                    and not form_field.required
                    and field_value in form_field.empty_values
                ):
                    exclude.add(f.name)
        return exclude

    def clean(self) -> dict[str, Any]:
        """
        Enable the unique check in _post_clean() and return ``cleaned_data``.

        Subclasses overriding clean() must call super() to keep uniqueness
        validation.
        """
        self._validate_unique = True
        return self.cleaned_data

    def _update_errors(self, errors: ValidationError) -> None:
        """
        Add model-level ``errors`` to the form.

        Before adding them, any message whose ``code`` has an override (in
        ``Meta.error_messages[NON_FIELD_ERRORS]`` for non-field errors, or the
        matching form field's ``error_messages``) is replaced by that override.
        """
        # Override any validation error messages defined at the model level
        # with those defined at the form level.
        opts = self._meta

        # Allow the model generated by construct_instance() to raise
        # ValidationError and have them handled in the same way as others.
        if hasattr(errors, "error_dict"):
            error_dict = errors.error_dict
        else:
            error_dict = {NON_FIELD_ERRORS: errors}

        for field, messages in error_dict.items():
            if (
                field == NON_FIELD_ERRORS
                and opts.error_messages
                and NON_FIELD_ERRORS in opts.error_messages
            ):
                error_messages = opts.error_messages[NON_FIELD_ERRORS]
            elif field in self.fields:
                error_messages = self.fields[field].error_messages
            else:
                continue

            # Only ValidationError entries carry a code that can be overridden.
            for message in messages:
                if (
                    isinstance(message, ValidationError)
                    and message.code in error_messages
                ):
                    message.message = error_messages[message.code]

        self.add_error(None, errors)

    def _post_clean(self) -> None:
        """
        Copy ``cleaned_data`` onto ``self.instance`` and run model validation.

        Presumably invoked by BaseForm after clean() (TODO confirm against
        BaseForm); the unique check only runs if clean() set
        ``self._validate_unique``. Validation errors are added to the form
        rather than raised.
        """
        opts = self._meta

        exclude = self._get_validation_exclusions()

        try:
            self.instance = construct_instance(self, self.instance, opts.fields)
        except ValidationError as e:
            self._update_errors(e)

        try:
            self.instance.full_clean(exclude=exclude, validate_unique=False)
        except ValidationError as e:
            self._update_errors(e)

        # Validate uniqueness if needed.
        if self._validate_unique:
            self.validate_unique()

    def validate_unique(self) -> None:
        """
        Call the instance's validate_unique() method and update the form's
        validation errors if any were raised.
        """
        exclude = self._get_validation_exclusions()
        try:
            self.instance.validate_unique(exclude=exclude)
        except ValidationError as e:
            self._update_errors(e)

    def _save_m2m(self) -> None:
        """
        Save many-to-many field data from ``cleaned_data`` onto
        ``self.instance``.

        Only fields in ``Meta.fields`` (when set) that are present in
        ``cleaned_data`` and expose ``save_form_data()`` are saved.
        """
        cleaned_data = self.cleaned_data
        fields = self._meta.fields
        meta = self.instance._model_meta

        for f in meta.many_to_many:
            if not hasattr(f, "save_form_data"):
                continue
            if fields and f.name not in fields:
                continue
            if f.name in cleaned_data:
                f.save_form_data(self.instance, cleaned_data[f.name])

    def save(self, commit: bool = True) -> Any:
        """
        Save this form's self.instance object if commit=True. Otherwise, add
        a save_m2m() method to the form which can be called after the instance
        is saved manually at a later time. Return the model instance.

        Raises ValueError if the form has errors.
        """
        if self.errors:
            raise ValueError(
                "The {} could not be {} because the data didn't validate.".format(
                    self.instance.model_options.object_name,
                    "created" if self.instance._state.adding else "changed",
                )
            )
        if commit:
            # If committing, save the instance and the m2m data immediately.
            # Model validation already ran in _post_clean(), hence
            # clean_and_validate=False.
            self.instance.save(clean_and_validate=False)
            self._save_m2m()
        else:
            # If not committing, add a method to the form to allow deferred
            # saving of m2m data.
            self.save_m2m = self._save_m2m
        return self.instance
407
408
class ModelForm(BaseModelForm, metaclass=ModelFormMetaclass):
    """
    Base class for user-defined model forms.

    Subclasses declare an inner ``Meta`` class that ModelFormMetaclass reads
    (``model``, ``fields``, ``error_messages``, ``field_classes``,
    ``formfield_callback``). ``fields`` is required whenever ``model`` is set.
    """

    pass
411
412
413# Fields #####################################################################
414
415
class ModelChoiceIteratorValue:
    """
    Choice value produced by ModelChoiceIterator.

    Pairs the prepared ``value`` with the model ``instance`` it came from.
    Stringifies, hashes and compares as the wrapped ``value``, so it can stand
    in for the raw value (for example as a dict key).
    """

    def __init__(self, value: Any, instance: Any) -> None:
        self.value = value
        self.instance = instance

    def __eq__(self, other: object) -> bool:
        # Unwrap a peer so two wrappers compare by their underlying values.
        rhs = other.value if isinstance(other, ModelChoiceIteratorValue) else other
        return self.value == rhs

    def __hash__(self) -> int:
        # Must agree with __eq__: equal to the raw value, so hash like it.
        return hash(self.value)

    def __str__(self) -> str:
        return str(self.value)
431
432
class ModelChoiceIterator:
    """
    Lazily yield ``(value, label)`` choice pairs for a ModelChoiceField.

    When the field has an ``empty_label``, a ``("", empty_label)`` pair comes
    first; each object of the field's queryset then becomes
    ``(ModelChoiceIteratorValue, str(obj))``.
    """

    def __init__(self, field: ModelChoiceField) -> None:
        self.field = field
        self.queryset = field.queryset

    def __iter__(self) -> Any:
        if self.field.empty_label is not None:
            yield ("", self.field.empty_label)
        # Can't use iterator() when the queryset uses prefetch_related().
        if self.queryset._prefetch_related_lookups:
            rows = self.queryset
        else:
            rows = self.queryset.iterator()
        yield from map(self.choice, rows)

    def __len__(self) -> int:
        # count() issues a query, but avoids caching every row in memory.
        # Usually the choices are only iterated and __len__() is not called.
        blank_slot = 0 if self.field.empty_label is None else 1
        return self.queryset.count() + blank_slot

    def __bool__(self) -> bool:
        if self.field.empty_label is not None:
            return True
        return self.queryset.exists()

    def choice(self, obj: Any) -> tuple[ModelChoiceIteratorValue, str]:
        """Return the ``(value, label)`` pair for a single queryset object."""
        prepared = self.field.prepare_value(obj)
        return ModelChoiceIteratorValue(prepared, obj), str(obj)
462
463
class ModelChoiceField(ChoiceField):
    """
    A ChoiceField whose choices are a model QuerySet.

    Cleaning turns a submitted id into the matching object from ``queryset``.
    """

    # Subclassing ChoiceField keeps isinstance() checks working, but almost
    # none of ChoiceField's implementation is used.
    default_error_messages = {
        "invalid_choice": "Select a valid choice. That choice is not one of the available choices.",
    }
    iterator = ModelChoiceIterator

    def __init__(
        self,
        queryset: Any,
        *,
        empty_label: str | None = "---------",
        required: bool = True,
        initial: Any = None,
        **kwargs: Any,
    ) -> None:
        # Deliberately bypass ChoiceField.__init__(); only Field's setup is needed.
        Field.__init__(self, required=required, initial=initial, **kwargs)
        # A required field that already has an initial value gets no blank option.
        self.empty_label = None if (required and initial is not None) else empty_label
        self.queryset = queryset

    def __deepcopy__(self, memo: dict[int, Any]) -> ModelChoiceField:
        clone = super(ChoiceField, self).__deepcopy__(memo)
        # Give the copy its own queryset so a fresh ModelChoiceIterator is
        # used, bug #11183.
        if self.queryset is not None:
            clone.queryset = self.queryset.all()
        return clone

    def _get_queryset(self) -> Any:
        return self._queryset

    def _set_queryset(self, queryset: Any) -> None:
        if queryset is None:
            self._queryset = None
        else:
            self._queryset = queryset.all()

    queryset = property(_get_queryset, _set_queryset)

    def _get_choices(self) -> ModelChoiceIterator:
        # Without manually assigned choices, build a brand-new iterator on
        # every access so the queryset is evaluated lazily and never reused in
        # a consumed state.
        if not hasattr(self, "_choices"):
            return self.iterator(self)
        # Someone set `choices` explicitly; honour that.
        return self._choices

    choices = property(_get_choices, ChoiceField._set_choices)

    def prepare_value(self, value: Any) -> Any:
        # Model instances are represented by their id.
        if not hasattr(value, "_model_meta"):
            return super().prepare_value(value)
        return value.id

    def to_python(self, value: Any) -> Any:
        if value in self.empty_values:
            return None
        try:
            if isinstance(value, self.queryset.model):
                value = value.id
            value = self.queryset.get(id=value)
        except (ValueError, TypeError, self.queryset.model.DoesNotExist):
            raise ValidationError(
                self.error_messages["invalid_choice"],
                code="invalid_choice",
                params={"value": value},
            )
        return value

    def validate(self, value: Any) -> None:
        return Field.validate(self, value)

    def has_changed(self, initial: Any, data: Any) -> bool:
        # None on either side is treated as "" before comparing as strings.
        before = "" if initial is None else initial
        after = "" if data is None else data
        return str(self.prepare_value(before)) != str(after)
557
558
class ModelMultipleChoiceField(ModelChoiceField):
    """
    A MultipleChoiceField whose choices are a model QuerySet.

    The submitted value is a list of ids; cleaning returns a QuerySet of the
    matching objects from ``queryset``.
    """

    default_error_messages = {
        "invalid_list": "Enter a list of values.",
        "invalid_choice": "Select a valid choice. %(value)s is not one of the available choices.",
        "invalid_id_value": "'%(id)s' is not a valid value.",
    }

    def __init__(self, queryset: Any, **kwargs: Any) -> None:
        # A multiple selection never offers a blank option.
        super().__init__(queryset, empty_label=None, **kwargs)

    def to_python(self, value: Any) -> list[Any]:
        return list(self._check_values(value)) if value else []

    def clean(self, value: Any) -> Any:
        value = self.prepare_value(value)
        if not value:
            if self.required:
                raise ValidationError(self.error_messages["required"], code="required")
            return self.queryset.none()
        if not isinstance(value, list | tuple):
            raise ValidationError(
                self.error_messages["invalid_list"],
                code="invalid_list",
            )
        matches = self._check_values(value)
        # This override replaces the inherited clean(), so custom validators
        # have to be run here explicitly.
        self.run_validators(value)
        return matches

    def _check_values(self, value: Any) -> Any:
        """
        Given a list of possible id values, return a QuerySet of the
        corresponding objects. Raise a ValidationError if a given value is
        invalid (not a valid id, not in the queryset, etc.)
        """
        # Deduplicate so each id is only checked and queried once.
        try:
            unique_ids = frozenset(value)
        except TypeError:
            # Unhashable items (a list of lists, for example).
            raise ValidationError(
                self.error_messages["invalid_list"],
                code="invalid_list",
            )
        for candidate in unique_ids:
            try:
                self.queryset.filter(id=candidate)
            except (ValueError, TypeError):
                raise ValidationError(
                    self.error_messages["invalid_id_value"],
                    code="invalid_id_value",
                    params={"id": candidate},
                )
        matches = self.queryset.filter(id__in=unique_ids)
        found_ids = {str(obj.id) for obj in matches}
        for candidate in unique_ids:
            if str(candidate) not in found_ids:
                raise ValidationError(
                    self.error_messages["invalid_choice"],
                    code="invalid_choice",
                    params={"value": candidate},
                )
        return matches

    def prepare_value(self, value: Any) -> Any:
        is_collection = (
            hasattr(value, "__iter__")
            and not isinstance(value, str)
            and not hasattr(value, "_model_meta")
        )
        if not is_collection:
            return super().prepare_value(value)
        # Bind outside the comprehension: zero-argument super() needs the
        # method's own frame.
        prepare_one = super().prepare_value
        return [prepare_one(item) for item in value]

    def has_changed(self, initial: Any, data: Any) -> bool:
        initial = [] if initial is None else initial
        data = [] if data is None else data
        if len(initial) != len(data):
            return True
        initial_set = {str(item) for item in self.prepare_value(initial)}
        data_set = {str(item) for item in data}
        return data_set != initial_set

    def value_from_form_data(self, data: Any, files: Any, html_name: str) -> Any:
        return data.getlist(html_name)
652
653
def modelfield_to_formfield(
    modelfield: ModelField,
    form_class: type[Field] | None = None,
    choices_form_class: type[Field] | None = None,
    **kwargs: Any,
) -> Field | None:
    """
    Return the default form field for ``modelfield``.

    ``form_class`` forces a specific form field class. ``choices_form_class``
    replaces TypedChoiceField for model fields that define ``choices``.
    Remaining ``kwargs`` are merged over the computed defaults (``required``
    and, when the model field has a default, ``initial``). For choice fields
    only the keyword arguments TypedChoiceField understands are kept.

    Return None for primary key fields unless a form class was resolved
    (explicit ``form_class`` or the choices path).
    """
    defaults: dict[str, Any] = {
        "required": modelfield.required,
    }

    if modelfield.has_default():
        defaults["initial"] = modelfield.get_default()

    if modelfield.choices is not None:
        # Fields with choices get special treatment. Offer a blank choice
        # unless the field is required and already has an initial value.
        include_blank = not modelfield.required or not (
            modelfield.has_default() or "initial" in kwargs
        )
        defaults["choices"] = modelfield.get_choices(include_blank=include_blank)
        defaults["coerce"] = modelfield.to_python
        if modelfield.allow_null:
            defaults["empty_value"] = None
        if choices_form_class is not None:
            form_class = choices_form_class
        else:
            form_class = fields.TypedChoiceField
        # Many of the subclass-specific formfield arguments (min_value,
        # max_value) don't apply for choice fields, so be sure to only pass
        # the values that TypedChoiceField will understand.
        choice_kwargs = (
            "coerce",
            "empty_value",
            "choices",
            "required",
            "initial",
            "error_messages",
        )
        kwargs = {k: v for k, v in kwargs.items() if k in choice_kwargs}

    defaults.update(kwargs)

    if form_class is not None:
        return form_class(**defaults)

    # Avoid a circular import
    from plain import models

    # Primary key fields aren't rendered by default
    if isinstance(modelfield, models.PrimaryKeyField):
        return None

    if isinstance(modelfield, models.BooleanField):
        form_class = (
            fields.NullBooleanField if modelfield.allow_null else fields.BooleanField
        )
        # In HTML checkboxes, 'required' means "must be checked" which is
        # different from the choices case ("must select some value").
        # required=False allows unchecked checkboxes.
        defaults["required"] = False
        return form_class(**defaults)

    if isinstance(modelfield, models.DecimalField):
        return fields.DecimalField(
            max_digits=modelfield.max_digits,
            decimal_places=modelfield.decimal_places,
            **defaults,
        )

    if isinstance(modelfield, models.fields.PositiveIntegerRelDbTypeMixin):  # type: ignore[attr-defined]
        return fields.IntegerField(min_value=0, **defaults)

    # TextField is checked before CharField; keep this order.
    if isinstance(modelfield, models.TextField):
        # max_length is then enforced by both the form field and the model
        # field; the duplicate check is accepted so the form field knows the
        # limit too.
        return fields.CharField(max_length=modelfield.max_length, **defaults)

    if isinstance(modelfield, models.CharField):
        # As above, max_length is deliberately validated on both sides.
        if modelfield.allow_null:
            defaults["empty_value"] = None
        return fields.CharField(
            max_length=modelfield.max_length,
            **defaults,
        )

    if isinstance(modelfield, models.JSONField):
        return fields.JSONField(
            encoder=modelfield.encoder, decoder=modelfield.decoder, **defaults
        )

    if isinstance(modelfield, models.ForeignKey):
        return ModelChoiceField(
            queryset=modelfield.remote_field.model.query,  # type: ignore[attr-defined]
            **defaults,
        )

    # TODO related (OneToOne, m2m)

    # If there's a form field of the exact same name, use it
    # (models.URLField -> forms.URLField)
    if hasattr(fields, modelfield.__class__.__name__):
        form_class = getattr(fields, modelfield.__class__.__name__)
        return form_class(**defaults)

    # Default to CharField if we didn't find anything else
    return fields.CharField(**defaults)