Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Form classes
  3"""
  4
  5from __future__ import annotations
  6
  7import copy
  8from functools import cached_property
  9from typing import TYPE_CHECKING, Any
 10
 11from plain.exceptions import NON_FIELD_ERRORS
 12from plain.utils.datastructures import MultiValueDict
 13
 14from .exceptions import ValidationError
 15from .fields import Field, FileField
 16
 17if TYPE_CHECKING:
 18    from plain.http import Request
 19
 20    from .boundfield import BoundField
 21
 22__all__ = ("BaseForm", "Form")
 23
 24
 25class DeclarativeFieldsMetaclass(type):
 26    """Collect Fields declared on the base classes."""
 27
 28    def __new__(
 29        mcs: type[DeclarativeFieldsMetaclass],
 30        name: str,
 31        bases: tuple[type, ...],
 32        attrs: dict[str, Any],
 33    ) -> type:
 34        # Collect fields from current class and remove them from attrs.
 35        attrs["declared_fields"] = {
 36            key: attrs.pop(key)
 37            for key, value in list(attrs.items())
 38            if isinstance(value, Field)
 39        }
 40
 41        new_class = super().__new__(mcs, name, bases, attrs)  # type: ignore[misc]
 42
 43        # Walk through the MRO.
 44        declared_fields: dict[str, Field] = {}
 45        for base in reversed(new_class.__mro__):
 46            # Collect fields from base class.
 47            if hasattr(base, "declared_fields"):
 48                declared_fields.update(getattr(base, "declared_fields"))
 49
 50            # Field shadowing.
 51            for attr, value in base.__dict__.items():
 52                if value is None and attr in declared_fields:
 53                    declared_fields.pop(attr)
 54
 55        setattr(new_class, "base_fields", declared_fields)
 56        setattr(new_class, "declared_fields", declared_fields)
 57
 58        return new_class
 59
 60
 61class BaseForm:
 62    """
 63    The main implementation of all the Form logic. Note that this class is
 64    different than Form. See the comments by the Form class for more info. Any
 65    improvements to the form API should be made to this class, not to the Form
 66    class.
 67    """
 68
 69    # Set by DeclarativeFieldsMetaclass
 70    base_fields: dict[str, Field]
 71
 72    prefix: str | None = None
 73
 74    def __init__(
 75        self,
 76        *,
 77        request: Request,
 78        auto_id: str | bool = "id_%s",
 79        prefix: str | None = None,
 80        initial: dict[str, Any] | None = None,
 81    ):
 82        # Forms can handle both JSON and form data
 83        self.is_json_request = request.headers.get("Content-Type", "").startswith(
 84            "application/json"
 85        )
 86        if self.is_json_request:
 87            self.data = request.json_data
 88            self.files = MultiValueDict()
 89        else:
 90            self.data = request.form_data
 91            self.files = request.files
 92
 93        self.is_bound = bool(self.data or self.files)
 94
 95        self._auto_id = auto_id
 96        if prefix is not None:
 97            self.prefix = prefix
 98        self.initial = initial or {}
 99        self._errors: dict[str, list[str]] | None = (
100            None  # Stores the errors after clean() has been called.
101        )
102
103        # The base_fields class attribute is the *class-wide* definition of
104        # fields. Because a particular *instance* of the class might want to
105        # alter self.fields, we create self.fields here by copying base_fields.
106        # Instances should always modify self.fields; they should not modify
107        # self.base_fields.
108        self.fields: dict[str, Field] = copy.deepcopy(self.base_fields)
109        self._bound_fields_cache: dict[str, BoundField] = {}
110
111    def __repr__(self) -> str:
112        if self._errors is None:
113            is_valid = "Unknown"
114        else:
115            is_valid = self.is_bound and not self._errors
116        return "<{cls} bound={bound}, valid={valid}, fields=({fields})>".format(
117            cls=self.__class__.__name__,
118            bound=self.is_bound,
119            valid=is_valid,
120            fields=";".join(self.fields),
121        )
122
123    def _bound_items(self) -> Any:
124        """Yield (name, bf) pairs, where bf is a BoundField object."""
125        for name in self.fields:
126            yield name, self[name]
127
128    def __iter__(self) -> Any:
129        """Yield the form's fields as BoundField objects."""
130        for name in self.fields:
131            yield self[name]
132
133    def __getitem__(self, name: str) -> BoundField:
134        """Return a BoundField with the given name."""
135        try:
136            field = self.fields[name]
137        except KeyError:
138            raise KeyError(
139                "Key '{}' not found in '{}'. Choices are: {}.".format(
140                    name,
141                    self.__class__.__name__,
142                    ", ".join(sorted(self.fields)),
143                )
144            )
145        if name not in self._bound_fields_cache:
146            self._bound_fields_cache[name] = field.get_bound_field(self, name)
147        return self._bound_fields_cache[name]
148
149    @property
150    def errors(self) -> dict[str, list[str]]:
151        """Return an error dict for the data provided for the form."""
152        if self._errors is None:
153            self.full_clean()
154        assert self._errors is not None, "full_clean should initialize _errors"
155        return self._errors
156
157    def is_valid(self) -> bool:
158        """Return True if the form has no errors, or False otherwise."""
159        return self.is_bound and not self.errors
160
161    def add_prefix(self, field_name: str) -> str:
162        """
163        Return the field name with a prefix appended, if this Form has a
164        prefix set.
165
166        Subclasses may wish to override.
167        """
168        return f"{self.prefix}-{field_name}" if self.prefix else field_name
169
170    @property
171    def non_field_errors(self) -> list[str]:
172        """
173        Return a list of errors that aren't associated with a particular
174        field -- i.e., from Form.clean(). Return an empty list if there
175        are none.
176        """
177        return self.errors.get(
178            NON_FIELD_ERRORS,
179            [],
180        )
181
182    def add_error(self, field: str | None, error: ValidationError) -> None:
183        """
184        Update the content of `self._errors`.
185
186        The `field` argument is the name of the field to which the errors
187        should be added. If it's None, treat the errors as NON_FIELD_ERRORS.
188
189        The `error` argument can be a single error, a list of errors, or a
190        dictionary that maps field names to lists of errors. An "error" can be
191        either a simple string or an instance of ValidationError with its
192        message attribute set and a "list or dictionary" can be an actual
193        `list` or `dict` or an instance of ValidationError with its
194        `error_list` or `error_dict` attribute set.
195
196        If `error` is a dictionary, the `field` argument *must* be None and
197        errors will be added to the fields that correspond to the keys of the
198        dictionary.
199        """
200        if not isinstance(error, ValidationError):
201            raise TypeError(
202                "The argument `error` must be an instance of "
203                f"`ValidationError`, not `{type(error).__name__}`."
204            )
205
206        error_dict: dict[str, Any]
207        if hasattr(error, "error_dict"):
208            if field is not None:
209                raise TypeError(
210                    "The argument `field` must be `None` when the `error` "
211                    "argument contains errors for multiple fields."
212                )
213            else:
214                error_dict = error.error_dict
215        else:
216            error_dict = {field or NON_FIELD_ERRORS: error.error_list}
217
218        class ValidationErrors(list):
219            def __iter__(self) -> Any:
220                for err in super().__iter__():
221                    # TODO make sure this works...
222                    yield next(iter(err))
223
224        for field_key, error_list in error_dict.items():
225            # Accessing self.errors ensures _errors is initialized
226            if field_key not in self.errors:
227                if field_key != NON_FIELD_ERRORS and field_key not in self.fields:
228                    raise ValueError(
229                        f"'{self.__class__.__name__}' has no field named '{field_key}'."
230                    )
231                assert self._errors is not None, "errors property initializes _errors"
232                self._errors[field_key] = ValidationErrors()
233
234            assert self._errors is not None, "errors property initializes _errors"
235            self._errors[field_key].extend(error_list)
236
237            # The field had an error, so removed it from the final data
238            # (we use getattr here so errors can be added to uncleaned forms)
239            if field_key in getattr(self, "cleaned_data", {}):
240                del self.cleaned_data[field_key]
241
242    def full_clean(self) -> None:
243        """
244        Clean all of self.data and populate self._errors and self.cleaned_data.
245        """
246        self._errors = {}
247        if not self.is_bound:  # Stop further processing.
248            return None
249        self.cleaned_data = {}
250
251        self._clean_fields()
252        self._clean_form()
253        self._post_clean()
254
255    def _field_data_value(self, field: Field, html_name: str) -> Any:
256        if hasattr(self, f"parse_{html_name}"):
257            # Allow custom parsing from form data/files at the form level
258            return getattr(self, f"parse_{html_name}")()
259
260        if self.is_json_request:
261            return field.value_from_json_data(self.data, self.files, html_name)
262        else:
263            return field.value_from_form_data(self.data, self.files, html_name)
264
265    def _clean_fields(self) -> None:
266        for name, bf in self._bound_items():
267            field = bf.field
268
269            value = self._field_data_value(bf.field, bf.html_name)
270
271            try:
272                if isinstance(field, FileField):
273                    value = field.clean(value, bf.initial)
274                else:
275                    value = field.clean(value)
276                self.cleaned_data[name] = value
277                if hasattr(self, f"clean_{name}"):
278                    value = getattr(self, f"clean_{name}")()
279                    self.cleaned_data[name] = value
280            except ValidationError as e:
281                self.add_error(name, e)
282
283    def _clean_form(self) -> None:
284        try:
285            cleaned_data = self.clean()
286        except ValidationError as e:
287            self.add_error(None, e)
288        else:
289            if cleaned_data is not None:
290                self.cleaned_data = cleaned_data
291
292    def _post_clean(self) -> None:
293        """
294        An internal hook for performing additional cleaning after form cleaning
295        is complete. Used for model validation in model forms.
296        """
297        pass
298
299    def clean(self) -> dict[str, Any]:
300        """
301        Hook for doing any extra form-wide cleaning after Field.clean() has been
302        called on every field. Any ValidationError raised by this method will
303        not be associated with a particular field; it will have a special-case
304        association with the field named '__all__'.
305        """
306        return self.cleaned_data
307
308    @cached_property
309    def changed_data(self) -> list[str]:
310        return [name for name, bf in self._bound_items() if bf._has_changed()]
311
312    def get_initial_for_field(self, field: Field, field_name: str) -> Any:
313        """
314        Return initial data for field on form. Use initial data from the form
315        or the field, in that order. Evaluate callable values.
316        """
317        value = self.initial.get(field_name, field.initial)
318        if callable(value):
319            value = value()
320        return value
321
322
class Form(BaseForm, metaclass=DeclarativeFieldsMetaclass):
    """
    A collection of Fields, plus their associated data.

    All of the form behavior lives in BaseForm. This subclass exists only to
    attach DeclarativeFieldsMetaclass, which turns Field attributes written
    in the class body into ``base_fields``. That declarative syntax is purely
    a convenience, since BaseForm on its own has no way of designating
    ``self.fields``.
    """
329    # to define a form using declarative syntax.
330    # BaseForm itself has no way of designating self.fields.