1from __future__ import annotations
  2
  3import ipaddress
  4import math
  5import re
  6from collections.abc import Callable
  7from pathlib import Path
  8from typing import TYPE_CHECKING, Any, cast
  9from urllib.parse import urlsplit, urlunsplit
 10
 11from plain.exceptions import ValidationError
 12from plain.utils.deconstruct import deconstructible
 13from plain.utils.encoding import punycode
 14from plain.utils.ipv6 import is_valid_ipv6_address
 15from plain.utils.regex_helper import _lazy_re_compile
 16from plain.utils.text import pluralize_lazy
 17
 18if TYPE_CHECKING:
 19    from plain.utils.functional import SimpleLazyObject
 20
# Values that are considered "empty": None, the empty string, and the empty
# list, tuple and dict.
# These values, if given to validate(), will trigger the self.required check
# (the code performing that check is not in this module's visible section).
EMPTY_VALUES = (None, "", [], (), {})
 23
 24
 25@deconstructible
 26class RegexValidator:
 27    regex: str | re.Pattern[str] | SimpleLazyObject = ""
 28    message = "Enter a valid value."
 29    code = "invalid"
 30    inverse_match = False
 31    flags = 0
 32
 33    def __init__(
 34        self,
 35        regex: str | re.Pattern[str] | None = None,
 36        message: str | None = None,
 37        code: str | None = None,
 38        inverse_match: bool | None = None,
 39        flags: int | None = None,
 40    ) -> None:
 41        # Only compile regex if explicitly provided or if class default needs compilation
 42        if regex is not None:
 43            regex_to_compile: str | re.Pattern[str] = regex
 44        elif isinstance(self.regex, str | re.Pattern):
 45            # Class-level regex is a string or pattern that needs compilation
 46            regex_to_compile = self.regex
 47        else:
 48            # Class-level regex is already compiled (e.g., in URL Validator subclass)
 49            # Don't recompile it
 50            regex_to_compile = None
 51
 52        if message is not None:
 53            self.message = message
 54        if code is not None:
 55            self.code = code
 56        if inverse_match is not None:
 57            self.inverse_match = inverse_match
 58        if flags is not None:
 59            self.flags = flags
 60
 61        # Only compile if we have a regex to compile
 62        if regex_to_compile is not None:
 63            if self.flags and not isinstance(regex_to_compile, str):
 64                raise TypeError(
 65                    "If the flags are set, regex must be a regular expression string."
 66                )
 67            self.regex = _lazy_re_compile(regex_to_compile, self.flags)
 68
 69    def __call__(self, value: Any) -> None:
 70        """
 71        Validate that the input contains (or does *not* contain, if
 72        inverse_match is True) a match for the regular expression.
 73        """
 74        # self.regex is always a SimpleLazyObject with search() after __init__
 75        regex_matches = cast(re.Pattern[str], self.regex).search(str(value))
 76        invalid_input = regex_matches if self.inverse_match else not regex_matches
 77        if invalid_input:
 78            raise ValidationError(self.message, code=self.code, params={"value": value})
 79
 80    def __eq__(self, other: object) -> bool:
 81        if not isinstance(other, RegexValidator):
 82            return NotImplemented
 83        self_regex = cast(re.Pattern[str], self.regex)
 84        other_regex = cast(re.Pattern[str], other.regex)
 85        return (
 86            self_regex.pattern == other_regex.pattern
 87            and self_regex.flags == other_regex.flags
 88            and (self.message == other.message)
 89            and (self.code == other.code)
 90            and (self.inverse_match == other.inverse_match)
 91        )
 92
 93
@deconstructible
class URLValidator(RegexValidator):
    """
    Validate that a value is a URL.

    A value passes when it is a ``str`` without tab/CR/LF characters, its
    scheme is listed in ``schemes``, it matches ``regex`` (directly, or after
    the network location has been converted with ``punycode``), any bracketed
    IPv6 host is a valid IPv6 address, and the hostname is present and at
    most 253 characters long.
    """

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    # A host is either hostname + optional sub-domains + TLD, or "localhost".
    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"

    # Full-URL pattern, anchored at both ends and case-insensitive.
    regex = _lazy_re_compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    message = "Enter a valid URL."
    # Schemes accepted by default; can be replaced per instance via __init__.
    schemes = ["http", "https", "ftp", "ftps"]
    # Characters that make a value invalid outright (tab, CR, LF).
    unsafe_chars = frozenset("\t\r\n")

    def __init__(self, schemes: list[str] | None = None, **kwargs: Any) -> None:
        """
        Args:
            schemes: Optional replacement for the class-level ``schemes`` list.
            **kwargs: Forwarded unchanged to ``RegexValidator.__init__``.
        """
        super().__init__(**kwargs)
        if schemes is not None:
            self.schemes = schemes

    def __call__(self, value: Any) -> None:
        """
        Raise ``ValidationError`` (with ``self.message`` / ``self.code``)
        unless ``value`` passes every URL check described on the class.
        """
        # Non-string values can never be valid URLs.
        if not isinstance(value, str):
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Reject any value containing a tab, carriage return or line feed.
        if self.unsafe_chars.intersection(value):
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Check if the scheme is valid.
        # Everything before the first "://" is compared, lowercased, against
        # the allowed schemes.
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Then check full URL
        try:
            splitted_url = urlsplit(value)
        except ValueError:
            # urlsplit() could not parse the value at all.
            raise ValidationError(self.message, code=self.code, params={"value": value})
        try:
            super().__call__(value)
        except ValidationError as e:
            # Trivial case failed. Try for possible IDN domain
            if value:
                scheme, netloc, path, query, fragment = splitted_url
                try:
                    netloc = punycode(netloc)  # IDN -> ACE
                except UnicodeError:  # invalid domain part
                    # Re-raise the original regex failure.
                    raise e
                # Re-run the regex check on the URL rebuilt with the converted
                # netloc; a failure here propagates to the caller.
                url = urlunsplit((scheme, netloc, path, query, fragment))
                super().__call__(url)
            else:
                raise
        else:
            # Now verify IPv6 in the netloc part
            # (only reached when the regex matched the value as given).
            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
            if host_match:
                potential_ip = host_match[1]
                try:
                    validate_ipv6_address(potential_ip)
                except ValidationError:
                    # Report the failure with this validator's own message/code
                    # instead of the IPv6-specific one.
                    raise ValidationError(
                        self.message, code=self.code, params={"value": value}
                    )

        # The maximum length of a full host name is 253 characters per RFC 1034
        # section 3.1. It's defined to be 255 bytes or less, but this includes
        # one byte for the length of the name and one byte for the trailing dot
        # that's used to indicate absolute names in DNS.
        # This check uses the hostname parsed from the original value, and a
        # missing hostname is rejected as well.
        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
            raise ValidationError(self.message, code=self.code, params={"value": value})
186
187
@deconstructible
class EmailValidator:
    """
    Validate that a value is an email address.

    The value is split on its last ``@``. The user part must match
    ``user_regex``; the domain part must be in ``domain_allowlist``, match
    ``domain_regex``, or be a bracketed IPv4/IPv6 literal. If the domain part
    fails, it is retried once after conversion with ``punycode``.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    # Domains accepted without running the domain pattern checks.
    domain_allowlist = ["localhost"]

    def __init__(
        self,
        message: str | None = None,
        code: str | None = None,
        allowlist: list[str] | None = None,
    ) -> None:
        """
        Args:
            message: Optional replacement for the class-level error message.
            code: Optional replacement for the class-level error code.
            allowlist: Optional replacement for ``domain_allowlist``.
        """
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value: Any) -> None:
        """
        Raise ``ValidationError`` unless ``value`` is a valid email address.

        Non-string values are rejected with ``ValidationError`` as well,
        rather than failing with ``TypeError``/``AttributeError`` from the
        string operations below (the same contract as ``URLValidator``).
        """
        if not isinstance(value, str) or not value or "@" not in value:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Split on the last "@" so a quoted user part may itself contain "@".
        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part in self.domain_allowlist or self.validate_domain_part(
            domain_part
        ):
            return

        # Try for possible IDN domain-part
        try:
            domain_part = punycode(domain_part)
        except UnicodeError:
            # Not convertible: fall through to the failure below.
            pass
        else:
            if self.validate_domain_part(domain_part):
                return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part: str) -> bool:
        """
        Return True when ``domain_part`` matches ``domain_regex`` or is a
        bracketed literal whose content is a valid IPv4 or IPv6 address.
        """
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
                return True
            except ValidationError:
                # An invalid address literal simply means "not a valid domain".
                pass
        return False

    def __eq__(self, other: object) -> bool:
        """Equal when allowlist, message and code all match."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
269
270
# Ready-to-use EmailValidator instance built with the class defaults
# (default message and code, and the ["localhost"] domain allowlist).
validate_email = EmailValidator()
272
273
def validate_ipv4_address(value: str, /) -> None:
    """
    Raise ValidationError (code "invalid") unless ``ipaddress.IPv4Address``
    accepts ``value``.
    """
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        error_message = "Enter a valid IPv4 address."
        raise ValidationError(error_message, code="invalid", params={"value": value})
281
282
def validate_ipv6_address(value: str, /) -> None:
    """
    Raise ValidationError (code "invalid") unless ``is_valid_ipv6_address``
    accepts ``value``.
    """
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
288
289
def validate_ipv46_address(value: str, /) -> None:
    """
    Accept ``value`` when it is a valid IPv4 or a valid IPv6 address;
    otherwise raise a single combined ValidationError (code "invalid").
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4: give the IPv6 check a chance before reporting failure.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            error_message = "Enter a valid IPv4 or IPv6 address."
            raise ValidationError(
                error_message, code="invalid", params={"value": value}
            )
302
303
# Lookup table used by ip_address_validators(): maps a lowercase protocol name
# to a pair of (list of validator callables, matching error message).
ip_address_validator_map: dict[str, tuple[list[Callable[[str], None]], str]] = {
    "both": ([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    "ipv4": ([validate_ipv4_address], "Enter a valid IPv4 address."),
    "ipv6": ([validate_ipv6_address], "Enter a valid IPv6 address."),
}
309
310
def ip_address_validators(
    protocol: str, unpack_ipv4: bool
) -> tuple[list[Callable[[str], None]], str]:
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    Args:
        protocol: "both", "ipv4" or "ipv6", matched case-insensitively.
        unpack_ipv4: Only allowed when ``protocol`` is "both".

    Returns:
        The ``(validators, error_message)`` pair registered for the protocol
        in ``ip_address_validator_map``.

    Raises:
        ValueError: if ``unpack_ipv4`` is set for a protocol other than
            "both", or if the protocol is unknown.
    """
    # Normalise once so the unpack_ipv4 guard and the table lookup agree on
    # case: previously "Both" passed the lookup but failed the guard.
    protocol_key = protocol.lower()
    if protocol_key != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[protocol_key]
    except KeyError:
        # The KeyError adds nothing to the message below, so drop it from the
        # traceback. The caller's original spelling is reported.
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        ) from None
328
329
def int_list_validator(
    sep: str = ",",
    message: str | None = None,
    code: str = "invalid",
    allow_negative: bool = False,
) -> RegexValidator:
    """
    Build a RegexValidator that accepts one or more runs of digits joined by
    ``sep`` (escaped, so it is matched literally).

    When ``allow_negative`` is true, each number may carry a leading "-".
    ``message`` and ``code`` are passed straight to the RegexValidator.
    """
    sign = "(-)?" if allow_negative else ""
    delimiter = re.escape(sep)
    pattern = _lazy_re_compile(rf"^{sign}\d+(?:{delimiter}{sign}\d+)*\Z")
    return RegexValidator(pattern, message=message, code=code)  # ty: ignore[invalid-argument-type]
343
344
# Ready-to-use validator from int_list_validator() with its defaults
# ("," separator, negative numbers not allowed); only the message is customised.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas.",
)
348
349
@deconstructible
class BaseValidator:
    """
    Base for validators that compare a (cleaned) value against a limit.

    Subclasses override ``compare`` (return True when the value is invalid)
    and optionally ``clean`` (derive the quantity to compare from the value).
    ``limit_value`` may be a callable, in which case it is called on every
    validation to obtain the current limit.
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value: Any, message: str | None = None) -> None:
        """Store the limit; a truthy ``message`` replaces the class default."""
        self.limit_value = limit_value
        if message:
            self.message = message

    def __call__(self, value: Any) -> None:
        """
        Raise ValidationError when ``compare(clean(value), limit)`` is true.

        The error params carry the resolved limit, the cleaned value
        (``show_value``) and the raw value.
        """
        cleaned = self.clean(value)
        limit = self.limit_value
        if callable(limit):
            limit = limit()
        if not self.compare(cleaned, limit):
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={"limit_value": limit, "show_value": cleaned, "value": value},
        )

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with the same limit, message and code."""
        if isinstance(other, self.__class__):
            return (
                self.limit_value == other.limit_value
                and self.message == other.message
                and self.code == other.code
            )
        return NotImplemented

    def compare(self, a: Any, b: Any) -> bool:
        """Return True when ``a`` (cleaned value) is invalid for limit ``b``."""
        return a is not b

    def clean(self, x: Any) -> Any:
        """Return the quantity to compare; the base class uses ``x`` itself."""
        return x
383
384
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values that are greater than the limit."""

    message = "Ensure this value is less than or equal to %(limit_value)s."
    code = "max_value"

    def compare(self, a: Any, b: Any) -> bool:
        # True means "invalid": the value ``a`` exceeds the limit ``b``.
        return a > b
392
393
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values that are less than the limit."""

    message = "Ensure this value is greater than or equal to %(limit_value)s."
    code = "min_value"

    def compare(self, a: Any, b: Any) -> bool:
        # True means "invalid": the value ``a`` is below the limit ``b``.
        return a < b
401
402
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the limit (the step size)."""

    message = "Ensure this value is a multiple of step size %(limit_value)s."
    code = "step_size"

    def compare(self, a: Any, b: Any) -> bool:
        # True means "invalid": the remainder of a / b is not within an
        # absolute tolerance of 1e-9 of zero. The tolerance absorbs
        # floating-point error in the remainder.
        return not math.isclose(math.remainder(a, b), 0, abs_tol=1e-9)
410
411
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is smaller than the limit."""

    # Singular and plural message forms, keyed on the "limit_value" param
    # (presumably selected by pluralize_lazy from that param's value; the
    # helper is defined outside this module).
    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "min_length"

    def compare(self, a: Any, b: Any) -> bool:
        # True means "invalid": the length ``a`` is below the limit ``b``.
        return a < b

    def clean(self, x: Any) -> int:
        # Compare the value's length rather than the value itself.
        return len(x)
428
429
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is greater than the limit."""

    # Singular and plural message forms, keyed on the "limit_value" param
    # (presumably selected by pluralize_lazy from that param's value; the
    # helper is defined outside this module).
    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "max_length"

    def compare(self, a: Any, b: Any) -> bool:
        # True means "invalid": the length ``a`` exceeds the limit ``b``.
        return a > b

    def clean(self, x: Any) -> int:
        # Compare the value's length rather than the value itself.
        return len(x)
446
447
@deconstructible
class DecimalValidator:
    """
    Check a decimal value against optional limits on its total number of
    digits and its number of decimal places, raising ValidationError for the
    first limit that is exceeded (or when the value is not a finite number).
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits: int | None, decimal_places: int | None) -> None:
        """Store the limits; ``None`` disables the corresponding check."""
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def _limit_error(self, code: str, limit: int, value: Any) -> ValidationError:
        """Build the ValidationError for the exceeded limit named by ``code``."""
        return ValidationError(
            self.messages[code],
            code=code,
            params={"max": limit, "value": value},
        )

    def __call__(self, value: Any) -> None:
        """
        Validate ``value`` (an object exposing ``as_tuple()``).

        Checks run in order: not-a-number/infinity, total digits, decimal
        places, then digits before the decimal point (only when both limits
        are configured).
        """
        coefficient, exponent = value.as_tuple()[1:]
        # Special exponents mark infinity ("F") and NaNs ("n", "N").
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        coefficient_length = len(coefficient)
        if exponent >= 0:
            # No fractional part. A positive exponent appends that many
            # trailing zeros, except for a bare zero coefficient.
            decimals = 0
            digits = coefficient_length
            if coefficient != (0,):
                digits += exponent
        else:
            # A negative exponent gives abs(exponent) decimal places. When that
            # exceeds the coefficient length, leading zeros after the decimal
            # point make up the difference, so it also becomes the digit count.
            decimals = abs(exponent)
            digits = max(coefficient_length, decimals)
        whole_digits = digits - decimals

        max_digits = self.max_digits
        decimal_places = self.decimal_places
        if max_digits is not None and digits > max_digits:
            raise self._limit_error("max_digits", max_digits, value)
        if decimal_places is not None and decimals > decimal_places:
            raise self._limit_error("max_decimal_places", decimal_places, value)
        if max_digits is not None and decimal_places is not None:
            whole_limit = max_digits - decimal_places
            if whole_digits > whole_limit:
                raise self._limit_error("max_whole_digits", whole_limit, value)

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with identical limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
534
535
@deconstructible
class FileExtensionValidator:
    """
    Validate that a file-like value's name ends in one of the allowed
    extensions (compared case-insensitively, without the leading dot).
    With ``allowed_extensions`` left as ``None``, every extension passes.
    """

    message = 'File extension "%(extension)s" is not allowed. Allowed extensions are: %(allowed_extensions)s.'
    code = "invalid_extension"

    def __init__(
        self,
        allowed_extensions: list[str] | None = None,
        message: str | None = None,
        code: str | None = None,
    ) -> None:
        """
        Args:
            allowed_extensions: Extensions to accept; stored lowercased.
            message: Optional replacement for the class-level message.
            code: Optional replacement for the class-level code.
        """
        self.allowed_extensions = (
            None
            if allowed_extensions is None
            else [extension.lower() for extension in allowed_extensions]
        )
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when ``value.name`` has a disallowed extension."""
        # The suffix is read before the allow-list is consulted, so
        # ``value.name`` is always accessed.
        file_suffix = Path(value.name).suffix
        extension = file_suffix[1:].lower()
        permitted = self.allowed_extensions
        if permitted is None or extension in permitted:
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={
                "extension": extension,
                "allowed_extensions": ", ".join(permitted),
                "value": value,
            },
        )

    def __eq__(self, other: object) -> bool:
        """Equal when class, allowed extensions, message and code all match."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
580
581
def get_available_image_extensions() -> list[str]:
    """
    Return the image extensions registered with Pillow, lowercased and
    without their first character, or an empty list when Pillow cannot be
    imported.
    """
    try:
        from PIL import Image  # ty: ignore[unresolved-import]
    except ImportError:
        return []

    Image.init()
    return [registered.lower()[1:] for registered in Image.EXTENSION]
590
591
def validate_image_file_extension(value: Any) -> None:
    """
    Validate ``value`` with a FileExtensionValidator restricted to the
    extensions reported by ``get_available_image_extensions()``.
    """
    image_validator = FileExtensionValidator(
        allowed_extensions=get_available_image_extensions()
    )
    return image_validator(value)
596
597
@deconstructible
class ProhibitNullCharactersValidator:
    """Reject values whose string form contains the null character."""

    message = "Null characters are not allowed."
    code = "null_characters_not_allowed"

    def __init__(self, message: str | None = None, code: str | None = None) -> None:
        """Optionally replace the class-level message and/or code."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when ``str(value)`` contains "\\x00"."""
        text = str(value)
        if "\x00" not in text:
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with the same message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code