Plain is headed towards 1.0! Subscribe for development updates →

  1import ipaddress
  2import math
  3import re
  4from pathlib import Path
  5from urllib.parse import urlsplit, urlunsplit
  6
  7from plain.exceptions import ValidationError
  8from plain.utils.deconstruct import deconstructible
  9from plain.utils.encoding import punycode
 10from plain.utils.ipv6 import is_valid_ipv6_address
 11from plain.utils.regex_helper import _lazy_re_compile
 12from plain.utils.text import pluralize_lazy
 13
# Values that count as "empty": if one of these is given to validate(), it
# will trigger the self.required check.
EMPTY_VALUES = (None, "", [], (), {})
 16
 17
 18@deconstructible
 19class RegexValidator:
 20    regex = ""
 21    message = "Enter a valid value."
 22    code = "invalid"
 23    inverse_match = False
 24    flags = 0
 25
 26    def __init__(
 27        self, regex=None, message=None, code=None, inverse_match=None, flags=None
 28    ):
 29        if regex is not None:
 30            self.regex = regex
 31        if message is not None:
 32            self.message = message
 33        if code is not None:
 34            self.code = code
 35        if inverse_match is not None:
 36            self.inverse_match = inverse_match
 37        if flags is not None:
 38            self.flags = flags
 39        if self.flags and not isinstance(self.regex, str):
 40            raise TypeError(
 41                "If the flags are set, regex must be a regular expression string."
 42            )
 43
 44        self.regex = _lazy_re_compile(self.regex, self.flags)
 45
 46    def __call__(self, value):
 47        """
 48        Validate that the input contains (or does *not* contain, if
 49        inverse_match is True) a match for the regular expression.
 50        """
 51        regex_matches = self.regex.search(str(value))
 52        invalid_input = regex_matches if self.inverse_match else not regex_matches
 53        if invalid_input:
 54            raise ValidationError(self.message, code=self.code, params={"value": value})
 55
 56    def __eq__(self, other):
 57        return (
 58            isinstance(other, RegexValidator)
 59            and self.regex.pattern == other.regex.pattern
 60            and self.regex.flags == other.regex.flags
 61            and (self.message == other.message)
 62            and (self.code == other.code)
 63            and (self.inverse_match == other.inverse_match)
 64        )
 65
 66
 67@deconstructible
 68class URLValidator(RegexValidator):
 69    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).
 70
 71    # IP patterns
 72    ipv4_re = (
 73        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
 74        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
 75    )
 76    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)
 77
 78    # Host patterns
 79    hostname_re = (
 80        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
 81    )
 82    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
 83    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
 84    tld_re = (
 85        r"\."  # dot
 86        r"(?!-)"  # can't start with a dash
 87        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
 88        r"|xn--[a-z0-9]{1,59})"  # or punycode label
 89        r"(?<!-)"  # can't end with a dash
 90        r"\.?"  # may have a trailing dot
 91    )
 92    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"
 93
 94    regex = _lazy_re_compile(
 95        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
 96        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
 97        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
 98        r"(?::[0-9]{1,5})?"  # port
 99        r"(?:[/?#][^\s]*)?"  # resource path
100        r"\Z",
101        re.IGNORECASE,
102    )
103    message = "Enter a valid URL."
104    schemes = ["http", "https", "ftp", "ftps"]
105    unsafe_chars = frozenset("\t\r\n")
106
107    def __init__(self, schemes=None, **kwargs):
108        super().__init__(**kwargs)
109        if schemes is not None:
110            self.schemes = schemes
111
112    def __call__(self, value):
113        if not isinstance(value, str):
114            raise ValidationError(self.message, code=self.code, params={"value": value})
115        if self.unsafe_chars.intersection(value):
116            raise ValidationError(self.message, code=self.code, params={"value": value})
117        # Check if the scheme is valid.
118        scheme = value.split("://")[0].lower()
119        if scheme not in self.schemes:
120            raise ValidationError(self.message, code=self.code, params={"value": value})
121
122        # Then check full URL
123        try:
124            splitted_url = urlsplit(value)
125        except ValueError:
126            raise ValidationError(self.message, code=self.code, params={"value": value})
127        try:
128            super().__call__(value)
129        except ValidationError as e:
130            # Trivial case failed. Try for possible IDN domain
131            if value:
132                scheme, netloc, path, query, fragment = splitted_url
133                try:
134                    netloc = punycode(netloc)  # IDN -> ACE
135                except UnicodeError:  # invalid domain part
136                    raise e
137                url = urlunsplit((scheme, netloc, path, query, fragment))
138                super().__call__(url)
139            else:
140                raise
141        else:
142            # Now verify IPv6 in the netloc part
143            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
144            if host_match:
145                potential_ip = host_match[1]
146                try:
147                    validate_ipv6_address(potential_ip)
148                except ValidationError:
149                    raise ValidationError(
150                        self.message, code=self.code, params={"value": value}
151                    )
152
153        # The maximum length of a full host name is 253 characters per RFC 1034
154        # section 3.1. It's defined to be 255 bytes or less, but this includes
155        # one byte for the length of the name and one byte for the trailing dot
156        # that's used to indicate absolute names in DNS.
157        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
158            raise ValidationError(self.message, code=self.code, params={"value": value})
159
160
@deconstructible
class EmailValidator:
    """
    Validate an email address by checking the part before the last ``@``
    against ``user_regex`` and the part after it as a domain name, an
    allowlisted domain, or a bracketed IP literal.

    ``message``, ``code`` and the domain allowlist can be overridden per
    instance. Every failure raises ValidationError.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]
    # RFC 3696 section 3 allows 64 characters for the local part and 255 for
    # the domain, i.e. 320 in total. Longer input is rejected before any regex
    # runs, which bounds the work done on caller-supplied strings.
    max_length = 320

    def __init__(self, message=None, code=None, allowlist=None):
        """Override the default message, code and/or domain allowlist."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value):
        """Raise ValidationError unless ``value`` is an acceptable email address."""
        if not value or "@" not in value or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Split on the last "@": a quoted local part may itself contain "@".
        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part not in self.domain_allowlist and not self.validate_domain_part(
            domain_part
        ):
            # Try for possible IDN domain-part
            try:
                domain_part = punycode(domain_part)
            except UnicodeError:
                pass
            else:
                if self.validate_domain_part(domain_part):
                    return
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part):
        """
        Return True when ``domain_part`` matches ``domain_regex`` or is a
        bracketed literal holding a valid IPv4/IPv6 address, else False.
        """
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
                return True
            except ValidationError:
                pass
        return False

    def __eq__(self, other):
        """Equal when allowlist, message and code all agree."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
236
237
# Shared EmailValidator instance built with the class defaults (message, code
# and domain allowlist).
validate_email = EmailValidator()
239
240
def validate_ipv4_address(value):
    """
    Raise ValidationError unless ``ipaddress.IPv4Address`` accepts ``value``.

    Returns None on success.
    """
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        # Raised from inside the handler so the original error stays attached
        # as context.
        failure = ValidationError(
            "Enter a valid IPv4 address.", code="invalid", params={"value": value}
        )
        raise failure
248
249
def validate_ipv6_address(value):
    """Raise ValidationError unless ``is_valid_ipv6_address`` accepts ``value``."""
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
255
256
def validate_ipv46_address(value):
    """
    Accept ``value`` if it passes either the IPv4 or the IPv6 validator.

    IPv4 is tried first; only when both validators reject the value is a
    ValidationError naming both protocols raised.
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4: fall back to IPv6 before giving up.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            combined_failure = ValidationError(
                "Enter a valid IPv4 or IPv6 address.",
                code="invalid",
                params={"value": value},
            )
            raise combined_failure
269
270
# Maps a lowercase protocol name to a (validators, error message) pair.
# Looked up by ip_address_validators().
ip_address_validator_map = {
    "both": ([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    "ipv4": ([validate_ipv4_address], "Enter a valid IPv4 address."),
    "ipv6": ([validate_ipv6_address], "Enter a valid IPv6 address."),
}
276
277
def ip_address_validators(protocol, unpack_ipv4):
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    ``protocol`` is matched case-insensitively against the keys of
    ``ip_address_validator_map``; the matching (validators, message) entry is
    returned.

    Raises ValueError when ``unpack_ipv4`` is requested for a protocol other
    than "both", or when the protocol is unknown.
    """
    # Normalise once so the unpack_ipv4 guard and the lookup agree on case:
    # "Both" must be treated the same as "both" in both places.
    normalized = protocol.lower() if isinstance(protocol, str) else protocol
    if normalized != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[normalized]
    except (KeyError, TypeError):
        # TypeError covers unhashable non-string protocols.
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        )
293
294
def int_list_validator(sep=",", message=None, code="invalid", allow_negative=False):
    """
    Build a RegexValidator for strings of digit runs joined by ``sep``.

    ``sep`` is escaped, so it is always matched literally. With
    ``allow_negative`` each number may carry a leading minus sign.
    """
    sign = "(-)?" if allow_negative else ""
    separator = re.escape(sep)
    pattern = rf"^{sign}\d+(?:{separator}{sign}\d+)*\Z"
    regexp = _lazy_re_compile(pattern)
    return RegexValidator(regexp, message=message, code=code)
303
304
# Validator built by int_list_validator() with its default separator (",")
# and a message specific to comma-separated digits.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas.",
)
308
309
@deconstructible
class BaseValidator:
    """
    Compare a (cleaned) value against ``limit_value``.

    Subclasses customise behaviour by overriding ``compare`` (return True
    when the value is *invalid*) and ``clean`` (transform the value before
    comparison). ``limit_value`` may be a callable, in which case it is
    called on every validation.
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value, message=None):
        """Store the limit and, if a truthy ``message`` is given, override the default."""
        self.limit_value = limit_value
        if message:
            self.message = message

    def __call__(self, value):
        """Raise ValidationError when ``compare(clean(value), limit)`` is true."""
        cleaned = self.clean(value)
        limit_value = (
            self.limit_value() if callable(self.limit_value) else self.limit_value
        )
        params = {"limit_value": limit_value, "show_value": cleaned, "value": value}
        if self.compare(cleaned, limit_value):
            raise ValidationError(self.message, code=self.code, params=params)

    def __eq__(self, other):
        """Equal to another instance of the same class with the same limit, message and code."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.limit_value == other.limit_value
            and self.message == other.message
            and self.code == other.code
        )

    def compare(self, a, b):
        """
        Return True when cleaned value ``a`` violates limit ``b``.

        The base rule is "value must equal the limit". Equality is used rather
        than identity so that an equal but distinct object (e.g. two separately
        created large ints or strings) is not rejected.
        """
        return a != b

    def clean(self, x):
        """Return the value to compare; the base implementation is the identity."""
        return x
343
344
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values that are greater than ``limit_value``."""

    message = "Ensure this value is less than or equal to %(limit_value)s."
    code = "max_value"

    def compare(self, a, b):
        """Return True (invalid) when ``a`` is above the limit ``b``."""
        above_limit = a > b
        return above_limit
352
353
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values that are less than ``limit_value``."""

    message = "Ensure this value is greater than or equal to %(limit_value)s."
    code = "min_value"

    def compare(self, a, b):
        """Return True (invalid) when ``a`` is below the limit ``b``."""
        below_limit = a < b
        return below_limit
361
362
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the step size ``limit_value``."""

    message = "Ensure this value is a multiple of step size %(limit_value)s."
    code = "step_size"

    def compare(self, a, b):
        """
        Return True (invalid) when ``a`` is not a multiple of ``b``.

        The remainder is compared to zero with an absolute tolerance of 1e-9
        to absorb floating-point noise.
        """
        leftover = math.remainder(a, b)
        on_step = math.isclose(leftover, 0, abs_tol=1e-9)
        return not on_step
370
371
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is smaller than ``limit_value``."""

    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "min_length"

    def compare(self, a, b):
        """Return True (invalid) when length ``a`` is below the minimum ``b``."""
        too_short = a < b
        return too_short

    def clean(self, x):
        """Compare on the length of the value, not the value itself."""
        length = len(x)
        return length
388
389
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is larger than ``limit_value``."""

    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "max_length"

    def compare(self, a, b):
        """Return True (invalid) when length ``a`` is above the maximum ``b``."""
        too_long = a > b
        return too_long

    def clean(self, x):
        """Compare on the length of the value, not the value itself."""
        length = len(x)
        return length
406
407
@deconstructible
class DecimalValidator:
    """
    Check a decimal value against a total-digit limit and a decimal-place
    limit, raising ValidationError when a limit is exceeded. Either limit may
    be None, which disables the checks that depend on it.
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits, decimal_places):
        """Store the two limits; None means "no limit"."""
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def __call__(self, value):
        """
        Count total, fractional and whole digits of ``value`` (via its
        ``as_tuple()``) and raise ValidationError for the first limit exceeded.
        """
        _sign, digit_tuple, exponent = value.as_tuple()
        # A non-numeric exponent marker means the value is not a finite number.
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        significant = len(digit_tuple)
        if exponent >= 0:
            decimals = 0
            # A positive exponent adds that many trailing zeros, except for a
            # plain zero which stays a single digit.
            digits = significant if digit_tuple == (0,) else significant + exponent
        else:
            # The exponent fixes the number of decimal places. When it is
            # larger than the number of stored digits, the gap is filled with
            # leading zeros after the decimal point, so the total digit count
            # is the larger of the two.
            decimals = abs(exponent)
            digits = max(significant, decimals)
        whole_digits = digits - decimals

        digit_cap = self.max_digits
        places_cap = self.decimal_places

        if digit_cap is not None and digits > digit_cap:
            raise ValidationError(
                self.messages["max_digits"],
                code="max_digits",
                params={"max": digit_cap, "value": value},
            )
        if places_cap is not None and decimals > places_cap:
            raise ValidationError(
                self.messages["max_decimal_places"],
                code="max_decimal_places",
                params={"max": places_cap, "value": value},
            )
        # The whole-digit limit is only defined when both limits are set.
        if digit_cap is None or places_cap is None:
            return
        whole_cap = digit_cap - places_cap
        if whole_digits > whole_cap:
            raise ValidationError(
                self.messages["max_whole_digits"],
                code="max_whole_digits",
                params={"max": whole_cap, "value": value},
            )

    def __eq__(self, other):
        """Equal to another instance of the same class with the same two limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
494
495
@deconstructible
class FileExtensionValidator:
    """
    Validate that the extension of ``value.name`` is one of
    ``allowed_extensions`` (compared case-insensitively).

    When ``allowed_extensions`` is None, every extension is accepted.
    """

    # The extension is wrapped in typographic quotes (U+201C / U+201D).
    message = "File extension “%(extension)s” is not allowed. Allowed extensions are: %(allowed_extensions)s."
    code = "invalid_extension"

    def __init__(self, allowed_extensions=None, message=None, code=None):
        """Store the lowercased allowed extensions and optional message/code overrides."""
        if allowed_extensions is not None:
            allowed_extensions = [
                allowed_extension.lower() for allowed_extension in allowed_extensions
            ]
        self.allowed_extensions = allowed_extensions
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value):
        """Raise ValidationError when the extension of ``value.name`` is not allowed."""
        # suffix includes the leading dot; drop it and compare in lowercase.
        extension = Path(value.name).suffix[1:].lower()
        if (
            self.allowed_extensions is not None
            and extension not in self.allowed_extensions
        ):
            raise ValidationError(
                self.message,
                code=self.code,
                params={
                    "extension": extension,
                    "allowed_extensions": ", ".join(self.allowed_extensions),
                    "value": value,
                },
            )

    def __eq__(self, other):
        """Equal to another instance of the same class with the same extensions, message and code."""
        return (
            isinstance(other, self.__class__)
            and self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
535
536
def get_available_image_extensions():
    """
    Return the lowercase, dot-less file extensions registered in
    ``PIL.Image.EXTENSION``, or an empty list when PIL cannot be imported.
    """
    try:
        from PIL import Image
    except ImportError:
        return []

    Image.init()
    extensions = []
    for registered in Image.EXTENSION:
        # Lowercase first, then strip the leading character (the dot).
        extensions.append(registered.lower()[1:])
    return extensions
545
546
def validate_image_file_extension(value):
    """
    Validate ``value`` with a FileExtensionValidator restricted to the
    extensions reported by get_available_image_extensions().
    """
    image_extensions = get_available_image_extensions()
    validator = FileExtensionValidator(allowed_extensions=image_extensions)
    return validator(value)
551
552
@deconstructible
class ProhibitNullCharactersValidator:
    """Validate that the string doesn't contain the null character."""

    message = "Null characters are not allowed."
    code = "null_characters_not_allowed"

    def __init__(self, message=None, code=None):
        """Override the default message and/or code when a non-None value is given."""
        for attr_name, supplied in (("message", message), ("code", code)):
            if supplied is not None:
                setattr(self, attr_name, supplied)

    def __call__(self, value):
        """Raise ValidationError when the string form of ``value`` contains a null character."""
        text = str(value)
        if "\x00" not in text:
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other):
        """Equal to another instance of the same class with the same message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code