Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import ipaddress
  4import math
  5import re
  6from collections.abc import Callable
  7from pathlib import Path
  8from typing import TYPE_CHECKING, Any, cast
  9from urllib.parse import urlsplit, urlunsplit
 10
 11from plain.exceptions import ValidationError
 12from plain.utils.deconstruct import deconstructible
 13from plain.utils.encoding import punycode
 14from plain.utils.ipv6 import is_valid_ipv6_address
 15from plain.utils.regex_helper import _lazy_re_compile
 16from plain.utils.text import pluralize_lazy
 17
 18if TYPE_CHECKING:
 19    from plain.utils.functional import SimpleLazyObject
 20
# Values considered "empty": None, the empty string, and the empty list,
# tuple, and dict.
# These values, if given to validate(), will trigger the self.required check.
# NOTE(review): validate() and self.required live outside this module —
# confirm against the callers that import EMPTY_VALUES.
EMPTY_VALUES = (None, "", [], (), {})
 23
 24
 25@deconstructible
 26class RegexValidator:
 27    regex: str | re.Pattern[str] | SimpleLazyObject = ""
 28    message = "Enter a valid value."
 29    code = "invalid"
 30    inverse_match = False
 31    flags = 0
 32
 33    def __init__(
 34        self,
 35        regex: str | re.Pattern[str] | None = None,
 36        message: str | None = None,
 37        code: str | None = None,
 38        inverse_match: bool | None = None,
 39        flags: int | None = None,
 40    ) -> None:
 41        # Only compile regex if explicitly provided or if class default needs compilation
 42        if regex is not None:
 43            regex_to_compile: str | re.Pattern[str] = regex
 44        elif isinstance(self.regex, str | re.Pattern):
 45            # Class-level regex is a string or pattern that needs compilation
 46            regex_to_compile = self.regex
 47        else:
 48            # Class-level regex is already compiled (e.g., in URL Validator subclass)
 49            # Don't recompile it
 50            regex_to_compile = None  # type: ignore[assignment]
 51
 52        if message is not None:
 53            self.message = message
 54        if code is not None:
 55            self.code = code
 56        if inverse_match is not None:
 57            self.inverse_match = inverse_match
 58        if flags is not None:
 59            self.flags = flags
 60
 61        # Only compile if we have a regex to compile
 62        if regex_to_compile is not None:
 63            if self.flags and not isinstance(regex_to_compile, str):
 64                raise TypeError(
 65                    "If the flags are set, regex must be a regular expression string."
 66                )
 67            self.regex = _lazy_re_compile(regex_to_compile, self.flags)
 68
 69    def __call__(self, value: Any) -> None:
 70        """
 71        Validate that the input contains (or does *not* contain, if
 72        inverse_match is True) a match for the regular expression.
 73        """
 74        # self.regex is always a SimpleLazyObject with search() after __init__
 75        regex_matches = cast(re.Pattern[str], self.regex).search(str(value))
 76        invalid_input = regex_matches if self.inverse_match else not regex_matches
 77        if invalid_input:
 78            raise ValidationError(self.message, code=self.code, params={"value": value})
 79
 80    def __eq__(self, other: object) -> bool:
 81        return (
 82            isinstance(other, RegexValidator)
 83            and self.regex.pattern == other.regex.pattern  # type: ignore[attr-defined]
 84            and self.regex.flags == other.regex.flags  # type: ignore[attr-defined]
 85            and (self.message == other.message)
 86            and (self.code == other.code)
 87            and (self.inverse_match == other.inverse_match)
 88        )
 89
 90
 91@deconstructible
 92class URLValidator(RegexValidator):
 93    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).
 94
 95    # IP patterns
 96    ipv4_re = (
 97        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
 98        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
 99    )
100    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)
101
102    # Host patterns
103    hostname_re = (
104        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
105    )
106    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
107    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
108    tld_re = (
109        r"\."  # dot
110        r"(?!-)"  # can't start with a dash
111        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
112        r"|xn--[a-z0-9]{1,59})"  # or punycode label
113        r"(?<!-)"  # can't end with a dash
114        r"\.?"  # may have a trailing dot
115    )
116    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"
117
118    regex = _lazy_re_compile(
119        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
120        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
121        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
122        r"(?::[0-9]{1,5})?"  # port
123        r"(?:[/?#][^\s]*)?"  # resource path
124        r"\Z",
125        re.IGNORECASE,
126    )
127    message = "Enter a valid URL."
128    schemes = ["http", "https", "ftp", "ftps"]
129    unsafe_chars = frozenset("\t\r\n")
130
131    def __init__(self, schemes: list[str] | None = None, **kwargs: Any) -> None:
132        super().__init__(**kwargs)
133        if schemes is not None:
134            self.schemes = schemes
135
136    def __call__(self, value: Any) -> None:
137        if not isinstance(value, str):
138            raise ValidationError(self.message, code=self.code, params={"value": value})
139        if self.unsafe_chars.intersection(value):
140            raise ValidationError(self.message, code=self.code, params={"value": value})
141        # Check if the scheme is valid.
142        scheme = value.split("://")[0].lower()
143        if scheme not in self.schemes:
144            raise ValidationError(self.message, code=self.code, params={"value": value})
145
146        # Then check full URL
147        try:
148            splitted_url = urlsplit(value)
149        except ValueError:
150            raise ValidationError(self.message, code=self.code, params={"value": value})
151        try:
152            super().__call__(value)
153        except ValidationError as e:
154            # Trivial case failed. Try for possible IDN domain
155            if value:
156                scheme, netloc, path, query, fragment = splitted_url
157                try:
158                    netloc = punycode(netloc)  # IDN -> ACE
159                except UnicodeError:  # invalid domain part
160                    raise e
161                url = urlunsplit((scheme, netloc, path, query, fragment))
162                super().__call__(url)
163            else:
164                raise
165        else:
166            # Now verify IPv6 in the netloc part
167            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
168            if host_match:
169                potential_ip = host_match[1]
170                try:
171                    validate_ipv6_address(potential_ip)
172                except ValidationError:
173                    raise ValidationError(
174                        self.message, code=self.code, params={"value": value}
175                    )
176
177        # The maximum length of a full host name is 253 characters per RFC 1034
178        # section 3.1. It's defined to be 255 bytes or less, but this includes
179        # one byte for the length of the name and one byte for the trailing dot
180        # that's used to indicate absolute names in DNS.
181        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
182            raise ValidationError(self.message, code=self.code, params={"value": value})
183
184
@deconstructible
class EmailValidator:
    """
    Validate that a value is a string shaped like an email address.

    The value is split at its last "@"; the user part must match user_regex
    and the domain part must either be in domain_allowlist or pass
    validate_domain_part() (directly, or after punycode encoding).
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]

    def __init__(
        self,
        message: str | None = None,
        code: str | None = None,
        allowlist: list[str] | None = None,
    ) -> None:
        """Override the class-level message, code and domain allowlist when given."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value: Any) -> None:
        """
        Raise ValidationError unless value is a valid email address.

        Non-string input is rejected with ValidationError as well, so callers
        never see a TypeError/AttributeError from the string operations below.
        """
        # A non-str value cannot be an address; the empty string is covered
        # by the "@" test.
        if not isinstance(value, str) or "@" not in value:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part not in self.domain_allowlist and not self.validate_domain_part(
            domain_part
        ):
            # Try for possible IDN domain-part
            try:
                domain_part = punycode(domain_part)
            except UnicodeError:
                pass
            else:
                if self.validate_domain_part(domain_part):
                    return None
            raise ValidationError(self.message, code=self.code, params={"value": value})
        return None

    def validate_domain_part(self, domain_part: str) -> bool:
        """
        Return True when domain_part matches domain_regex, or is a bracketed
        literal whose content is a valid IPv4 or IPv6 address.
        """
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
            except ValidationError:
                # Not an IP literal after all; fall through to False.
                pass
            else:
                return True
        return False

    def __eq__(self, other: object) -> bool:
        """Compare the domain allowlist, message and code."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
266
267
# Module-level EmailValidator instance built with no arguments, so it uses
# the class-level message, code and domain allowlist.
validate_email = EmailValidator()
269
270
def validate_ipv4_address(value: str) -> None:
    """Raise ValidationError unless ipaddress.IPv4Address accepts value."""
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        failure = ValidationError(
            "Enter a valid IPv4 address.", code="invalid", params={"value": value}
        )
        raise failure
    return None
278
279
def validate_ipv6_address(value: str) -> None:
    """Raise ValidationError when is_valid_ipv6_address() rejects value."""
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
285
286
def validate_ipv46_address(value: str) -> None:
    """
    Accept value when it passes the IPv4 check or, failing that, the IPv6
    check; raise a combined ValidationError when both reject it.
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4: give the IPv6 validator a chance before giving up.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            combined_failure = ValidationError(
                "Enter a valid IPv4 or IPv6 address.",
                code="invalid",
                params={"value": value},
            )
            raise combined_failure
299
300
# Maps a protocol name to a (list of validators, error message) pair.
# ip_address_validators() looks entries up by the lower-cased protocol name.
ip_address_validator_map = {
    "both": ([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    "ipv4": ([validate_ipv4_address], "Enter a valid IPv4 address."),
    "ipv6": ([validate_ipv6_address], "Enter a valid IPv6 address."),
}
306
307
def ip_address_validators(
    protocol: str, unpack_ipv4: bool
) -> tuple[list[Callable[[str], None]], str]:
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    The protocol name is matched case-insensitively, both for the lookup and
    for the `unpack_ipv4` restriction.

    Returns the (validators, error message) pair registered for the protocol
    in ip_address_validator_map.

    Raises ValueError when `unpack_ipv4` is set for a protocol other than
    "both", or when the protocol is unknown.
    """
    # Normalize once so the guard and the lookup agree on spellings such as
    # "Both" or "IPv4".
    normalized_protocol = protocol.lower()
    if normalized_protocol != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[normalized_protocol]
    except KeyError:
        # The KeyError adds nothing for the caller; report the value as given.
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        ) from None
325
326
def int_list_validator(
    sep: str = ",",
    message: str | None = None,
    code: str = "invalid",
    allow_negative: bool = False,
) -> RegexValidator:
    """
    Build a RegexValidator accepting one or more digit runs joined by sep.

    sep is escaped before being placed in the pattern; with allow_negative
    each number may carry an optional leading "-".
    """
    optional_minus = "(-)?" if allow_negative else ""
    escaped_sep = re.escape(sep)
    pattern = r"^{neg}\d+(?:{sep}{neg}\d+)*\Z".format(
        neg=optional_minus,
        sep=escaped_sep,
    )
    lazy_pattern = _lazy_re_compile(pattern)
    return RegexValidator(lazy_pattern, message=message, code=code)  # type: ignore[arg-type]
340
341
# Ready-made validator from int_list_validator() using its defaults (comma
# separator, no negative numbers) and overriding only the error message.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas.",
)
345
346
@deconstructible
class BaseValidator:
    """
    Base for validators that compare a cleaned value against a limit.

    Subclasses customize clean() (what is compared) and compare() (which
    returns True when the value must be rejected).
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value: Any, message: str | None = None) -> None:
        """Store the limit (a value or a callable) and an optional message."""
        self.limit_value = limit_value
        # A falsy message keeps the class-level default.
        if message:
            self.message = message

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when compare() flags the cleaned value."""
        show_value = self.clean(value)
        # A callable limit is evaluated on every validation.
        if callable(self.limit_value):
            limit = self.limit_value()
        else:
            limit = self.limit_value
        error_params = {"limit_value": limit, "show_value": show_value, "value": value}
        if not self.compare(show_value, limit):
            return
        raise ValidationError(self.message, code=self.code, params=error_params)

    def __eq__(self, other: object) -> bool:
        """Equal when other is an instance of this class with the same limit, message and code."""
        if isinstance(other, self.__class__):
            return (
                self.limit_value == other.limit_value
                and self.message == other.message
                and self.code == other.code
            )
        return NotImplemented

    def compare(self, a: Any, b: Any) -> bool:
        """Return True when a (cleaned value) violates b (limit); here: not the same object."""
        return a is not b

    def clean(self, x: Any) -> Any:
        """Return the value to compare; the base class passes x through unchanged."""
        return x
380
381
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values greater than the limit."""

    message = "Ensure this value is less than or equal to %(limit_value)s."
    code = "max_value"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when a exceeds b."""
        exceeds_limit = a > b
        return exceeds_limit
389
390
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values smaller than the limit."""

    message = "Ensure this value is greater than or equal to %(limit_value)s."
    code = "min_value"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when a is below b."""
        below_limit = a < b
        return below_limit
398
399
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the step size (the limit)."""

    message = "Ensure this value is a multiple of step size %(limit_value)s."
    code = "step_size"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when a's remainder modulo b is not close to zero."""
        leftover = math.remainder(a, b)
        # Compare with an absolute tolerance rather than for exact equality.
        on_step = math.isclose(leftover, 0, abs_tol=1e-9)
        return not on_step
407
408
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose len() is smaller than the limit."""

    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "min_length"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length a is below limit b."""
        too_short = a < b
        return too_short

    def clean(self, x: Any) -> int:
        """Compare on the length of x rather than on x itself."""
        length = len(x)
        return length
425
426
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose len() is larger than the limit."""

    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "max_length"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length a exceeds limit b."""
        too_long = a > b
        return too_long

    def clean(self, x: Any) -> int:
        """Compare on the length of x rather than on x itself."""
        length = len(x)
        return length
443
444
@deconstructible
class DecimalValidator:
    """
    Check a decimal value against limits on its total digits, its decimal
    places, and its digits before the decimal point; raise ValidationError
    when a limit is exceeded.
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits: int | None, decimal_places: int | None) -> None:
        """Store the limits; a limit of None disables the corresponding check."""
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def __call__(self, value: Any) -> None:
        """Raise ValidationError for non-numbers and for values exceeding a limit."""
        coefficient, exponent = value.as_tuple()[1:]
        # Non-numeric exponents ("F", "n", "N") mark values that are not plain numbers.
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        coefficient_len = len(coefficient)
        if exponent >= 0:
            # A positive exponent adds that many trailing zeros, except for
            # a plain zero coefficient.
            trailing_zeros = exponent if coefficient != (0,) else 0
            digits = coefficient_len + trailing_zeros
            decimals = 0
        else:
            # When the negative exponent is larger in magnitude than the
            # coefficient, it consumes every digit and adds leading zeros
            # after the decimal point, so it alone fixes the digit count.
            decimals = abs(exponent)
            digits = max(coefficient_len, decimals)
        whole_digits = digits - decimals

        # The whole-digit limit only exists when both limits are configured.
        if self.max_digits is not None and self.decimal_places is not None:
            max_whole_digits = self.max_digits - self.decimal_places
        else:
            max_whole_digits = None

        # Checked in this order; the first exceeded limit is reported.
        limit_checks = (
            ("max_digits", digits, self.max_digits),
            ("max_decimal_places", decimals, self.decimal_places),
            ("max_whole_digits", whole_digits, max_whole_digits),
        )
        for error_code, actual, limit in limit_checks:
            if limit is not None and actual > limit:
                raise ValidationError(
                    self.messages[error_code],
                    code=error_code,
                    params={"max": limit, "value": value},
                )

    def __eq__(self, other: object) -> bool:
        """Equal when other is an instance of this class with the same two limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
531
532
@deconstructible
class FileExtensionValidator:
    """
    Validate the extension of value.name against a list of allowed
    extensions, compared in lower case. With no list, nothing is rejected.
    """

    message = 'File extension "%(extension)s" is not allowed. Allowed extensions are: %(allowed_extensions)s.'
    code = "invalid_extension"

    def __init__(
        self,
        allowed_extensions: list[str] | None = None,
        message: str | None = None,
        code: str | None = None,
    ) -> None:
        """Store the lower-cased allowed extensions and any message/code override."""
        self.allowed_extensions = (
            None
            if allowed_extensions is None
            else [extension.lower() for extension in allowed_extensions]
        )
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when the extension of value.name is not allowed."""
        # Suffix of the name without its leading dot, lower-cased.
        extension = Path(value.name).suffix[1:].lower()
        allowed = self.allowed_extensions
        if allowed is None or extension in allowed:
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={
                "extension": extension,
                "allowed_extensions": ", ".join(allowed),
                "value": value,
            },
        )

    def __eq__(self, other: object) -> bool:
        """Equal when other is an instance of this class with the same extensions, message and code."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
577
578
def get_available_image_extensions() -> list[str]:
    """
    Return the lower-cased extensions (first character stripped) registered
    in PIL's Image.EXTENSION, or an empty list when PIL cannot be imported.
    """
    try:
        from PIL import Image  # type: ignore[import-untyped]
    except ImportError:
        return []

    Image.init()
    extensions = [dotted_ext.lower()[1:] for dotted_ext in Image.EXTENSION]
    return extensions
587
588
def validate_image_file_extension(value: Any) -> None:
    """
    Validate value with a FileExtensionValidator limited to the extensions
    returned by get_available_image_extensions().
    """
    image_extensions = get_available_image_extensions()
    validator = FileExtensionValidator(allowed_extensions=image_extensions)
    return validator(value)
593
594
@deconstructible
class ProhibitNullCharactersValidator:
    """Reject values whose string form contains the null character."""

    message = "Null characters are not allowed."
    code = "null_characters_not_allowed"

    def __init__(self, message: str | None = None, code: str | None = None) -> None:
        """Override the class-level message and code when given."""
        for attr_name, override in (("message", message), ("code", code)):
            if override is not None:
                setattr(self, attr_name, override)

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when str(value) contains "\\x00"."""
        if "\x00" not in str(value):
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Equal when other is an instance of this class with the same message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code