1from __future__ import annotations
2
3import ipaddress
4import math
5import re
6from collections.abc import Callable
7from pathlib import Path
8from typing import TYPE_CHECKING, Any, cast
9from urllib.parse import urlsplit, urlunsplit
10
11from plain.exceptions import ValidationError
12from plain.utils.deconstruct import deconstructible
13from plain.utils.encoding import punycode
14from plain.utils.ipv6 import is_valid_ipv6_address
15from plain.utils.regex_helper import _lazy_re_compile
16from plain.utils.text import pluralize_lazy
17
18if TYPE_CHECKING:
19 from plain.utils.functional import SimpleLazyObject
20
# These values, if given to validate(), will trigger the self.required check.
# NOTE(review): nothing in this part of the module reads EMPTY_VALUES; the
# validate()/required logic mentioned above presumably lives in the code that
# imports this constant — confirm against callers.
EMPTY_VALUES = (None, "", [], (), {})
23
24
@deconstructible
class RegexValidator:
    """Validate a value by searching its string form with a regular expression.

    By default the value is rejected when the pattern is *not* found; with
    ``inverse_match`` set, it is rejected when the pattern *is* found.
    Constructor arguments left as ``None`` fall back to the class attributes.
    """

    regex: str | re.Pattern[str] | SimpleLazyObject = ""
    message = "Enter a valid value."
    code = "invalid"
    inverse_match = False
    flags = 0

    def __init__(
        self,
        regex: str | re.Pattern[str] | None = None,
        message: str | None = None,
        code: str | None = None,
        inverse_match: bool | None = None,
        flags: int | None = None,
    ) -> None:
        """Store per-instance overrides and lazily compile the pattern.

        Raises:
            TypeError: if flags are in effect but the pattern to compile is
                not a plain string.
        """
        # Per-instance overrides; anything left as None keeps the class value.
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if inverse_match is not None:
            self.inverse_match = inverse_match
        if flags is not None:
            self.flags = flags

        # Pick what has to be compiled: an explicit argument wins, then a
        # class-level string/pattern. Any other class-level value (e.g. the
        # lazily compiled regex on URLValidator) is left untouched.
        pattern_source: str | re.Pattern[str] | None
        if regex is not None:
            pattern_source = regex
        elif isinstance(self.regex, (str, re.Pattern)):
            pattern_source = self.regex
        else:
            pattern_source = None

        if pattern_source is None:
            return

        if self.flags and not isinstance(pattern_source, str):
            raise TypeError(
                "If the flags are set, regex must be a regular expression string."
            )
        self.regex = _lazy_re_compile(pattern_source, self.flags)

    def __call__(self, value: Any) -> None:
        """
        Raise ValidationError unless ``str(value)`` contains a match for the
        pattern (or, when inverse_match is True, if it does contain one).
        """
        # After __init__, self.regex is an object exposing search().
        found = cast(re.Pattern[str], self.regex).search(str(value))
        # Invalid when "a match was found" coincides with "matches are forbidden".
        if bool(found) == bool(self.inverse_match):
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Compare pattern text, flags, message, code and inverse_match."""
        if not isinstance(other, RegexValidator):
            return NotImplemented
        mine = cast(re.Pattern[str], self.regex)
        theirs = cast(re.Pattern[str], other.regex)
        return (
            mine.pattern == theirs.pattern
            and mine.flags == theirs.flags
            and (self.message == other.message)
            and (self.code == other.code)
            and (self.inverse_match == other.inverse_match)
        )
92
93
@deconstructible
class URLValidator(RegexValidator):
    """Validate that a string is a URL with an allowed scheme and a valid host.

    A value is rejected when it is not a ``str``, is longer than
    ``max_length``, contains a tab/CR/LF character, uses a scheme that is not
    listed in ``schemes``, cannot be split by ``urlsplit``, does not match the
    URL pattern (also after converting the netloc with ``punycode``), has a
    bracketed host that is not a valid IPv6 address, or has a missing host or
    one longer than 253 characters.
    """

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"

    regex = _lazy_re_compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    message = "Enter a valid URL."
    schemes = ["http", "https", "ftp", "ftps"]
    unsafe_chars = frozenset("\t\r\n")
    # Longest value that is handed to the regex. The host part of the pattern
    # has nested quantifiers over domain labels, so unbounded input can make
    # matching very slow; subclasses may override this limit.
    max_length = 2048

    def __init__(self, schemes: list[str] | None = None, **kwargs: Any) -> None:
        """Optionally override the allowed schemes.

        Remaining keyword arguments are forwarded to RegexValidator.__init__.
        """
        super().__init__(**kwargs)
        if schemes is not None:
            self.schemes = schemes

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``value`` is not an acceptable URL."""
        # Reject non-strings and over-long input before any regex work.
        if not isinstance(value, str) or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        if self.unsafe_chars.intersection(value):
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Check if the scheme is valid.
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Then check full URL
        try:
            splitted_url = urlsplit(value)
        except ValueError:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        try:
            super().__call__(value)
        except ValidationError as e:
            # Trivial case failed. Try for possible IDN domain
            if value:
                scheme, netloc, path, query, fragment = splitted_url
                try:
                    netloc = punycode(netloc)  # IDN -> ACE
                except UnicodeError:  # invalid domain part
                    raise e
                url = urlunsplit((scheme, netloc, path, query, fragment))
                super().__call__(url)
            else:
                raise
        else:
            # Now verify IPv6 in the netloc part
            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
            if host_match:
                potential_ip = host_match[1]
                try:
                    validate_ipv6_address(potential_ip)
                except ValidationError:
                    raise ValidationError(
                        self.message, code=self.code, params={"value": value}
                    )

        # The maximum length of a full host name is 253 characters per RFC 1034
        # section 3.1. It's defined to be 255 bytes or less, but this includes
        # one byte for the length of the name and one byte for the trailing dot
        # that's used to indicate absolute names in DNS.
        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
            raise ValidationError(self.message, code=self.code, params={"value": value})
186
187
@deconstructible
class EmailValidator:
    """Validate that a value looks like an email address.

    The value is split at its last ``@``. The local part must match
    ``user_regex``; the domain part must be in ``domain_allowlist``, match
    ``domain_regex``, or be a bracketed IPv4/IPv6 literal. A domain that fails
    is retried once after conversion with ``punycode``.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]
    # The maximum length of an email is 320 characters per RFC 3696 section 3.
    # Checking it first also bounds the input handed to the regexes above,
    # whose repeated label groups can be slow on very long values.
    max_length = 320

    def __init__(
        self,
        message: str | None = None,
        code: str | None = None,
        allowlist: list[str] | None = None,
    ) -> None:
        """Override the message, error code and/or domain allowlist."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``value`` is not an acceptable address."""
        if not value or "@" not in value or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part in self.domain_allowlist or self.validate_domain_part(
            domain_part
        ):
            return None

        # Try for possible IDN domain-part
        try:
            ace_domain = punycode(domain_part)
        except UnicodeError:
            pass
        else:
            if self.validate_domain_part(ace_domain):
                return None
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part: str) -> bool:
        """Return True for a matching domain name or a valid bracketed IP literal."""
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
            except ValidationError:
                # Not a valid address literal; fall through to False.
                pass
            else:
                return True
        return False

    def __eq__(self, other: object) -> bool:
        """Compare the allowlist, message and code of two validators."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
269
270
# Module-level EmailValidator instance created with no arguments, so it uses
# the class-level message, code and domain allowlist.
validate_email = EmailValidator()
272
273
def validate_ipv4_address(value: str, /) -> None:
    """Raise ValidationError unless ``ipaddress.IPv4Address`` accepts ``value``."""
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        error = ValidationError(
            "Enter a valid IPv4 address.", code="invalid", params={"value": value}
        )
        raise error
    else:
        return None
281
282
def validate_ipv6_address(value: str, /) -> None:
    """Raise ValidationError unless ``is_valid_ipv6_address`` accepts ``value``."""
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
288
289
def validate_ipv46_address(value: str, /) -> None:
    """Raise ValidationError unless ``value`` passes the IPv4 or the IPv6 check.

    The IPv4 validator runs first; the IPv6 validator is only consulted when
    the IPv4 one rejects the value.
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4: give IPv6 a chance before reporting a combined error.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            error = ValidationError(
                "Enter a valid IPv4 or IPv6 address.",
                code="invalid",
                params={"value": value},
            )
            raise error
302
303
# Maps a lower-case protocol name to a ([validators], error message) pair.
# ip_address_validators() looks protocols up in this table.
ip_address_validator_map: dict[str, tuple[list[Callable[[str], None]], str]] = {
    "both": ([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    "ipv4": ([validate_ipv4_address], "Enter a valid IPv4 address."),
    "ipv6": ([validate_ipv6_address], "Enter a valid IPv6 address."),
}
309
310
def ip_address_validators(
    protocol: str, unpack_ipv4: bool
) -> tuple[list[Callable[[str], None]], str]:
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    The protocol name is matched case-insensitively, both for the
    ``unpack_ipv4`` check and for the table lookup.

    Returns:
        The ``([validators], error message)`` pair registered for the protocol
        in ``ip_address_validator_map``.

    Raises:
        ValueError: if ``unpack_ipv4`` is set for a protocol other than
            "both", or if the protocol is not in the table.
    """
    # Normalise once so the guard below and the lookup agree on casing.
    normalized = protocol.lower()
    if normalized != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[normalized]
    except KeyError:
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        )
328
329
def int_list_validator(
    sep: str = ",",
    message: str | None = None,
    code: str = "invalid",
    allow_negative: bool = False,
) -> RegexValidator:
    """Build a RegexValidator for digit groups joined by ``sep``.

    ``sep`` is regex-escaped before use. With ``allow_negative`` each group
    may be preceded by an optional minus sign.
    """
    sign = "(-)?" if allow_negative else ""
    delimiter = re.escape(sep)
    pattern = rf"^{sign}\d+(?:{delimiter}{sign}\d+)*\Z"
    regexp = _lazy_re_compile(pattern)
    return RegexValidator(regexp, message=message, code=code)  # ty: ignore[invalid-argument-type]
343
344
# Validator built by int_list_validator() with its default separator (",")
# and allow_negative=False, overriding only the error message.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas.",
)
348
349
@deconstructible
class BaseValidator:
    """Compare a (cleaned) value against a limit and reject it on a hit.

    Subclasses customise ``clean()`` (what is compared) and ``compare()``
    (returns True when the value must be rejected). ``limit_value`` may be a
    callable, in which case it is called on every validation.
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value: Any, message: str | None = None) -> None:
        """Remember the limit and, if a truthy message is given, override it."""
        self.limit_value = limit_value
        if message:
            self.message = message

    def clean(self, x: Any) -> Any:
        """Return the representation of ``x`` that is compared to the limit."""
        return x

    def compare(self, a: Any, b: Any) -> bool:
        """Return True when cleaned value ``a`` is invalid for limit ``b``."""
        return a is not b

    def __call__(self, value: Any) -> None:
        """Raise ValidationError when ``compare()`` flags the cleaned value."""
        cleaned = self.clean(value)
        # A callable limit is resolved at validation time.
        if callable(self.limit_value):
            limit = self.limit_value()
        else:
            limit = self.limit_value

        if not self.compare(cleaned, limit):
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={"limit_value": limit, "show_value": cleaned, "value": value},
        )

    def __eq__(self, other: object) -> bool:
        """Equal when ``other`` is an instance of this class with the same settings."""
        if isinstance(other, self.__class__):
            return (
                self.limit_value == other.limit_value
                and self.message == other.message
                and self.code == other.code
            )
        return NotImplemented
383
384
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values that are greater than the limit value."""

    message = "Ensure this value is less than or equal to %(limit_value)s."
    code = "max_value"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when value ``a`` exceeds limit ``b``."""
        return a > b
392
393
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values that are less than the limit value."""

    message = "Ensure this value is greater than or equal to %(limit_value)s."
    code = "min_value"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when value ``a`` is below limit ``b``."""
        return a < b
401
402
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the limit value (the step)."""

    message = "Ensure this value is a multiple of step size %(limit_value)s."
    code = "step_size"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when ``a`` is not a multiple of step ``b``."""
        # math.remainder() gives the signed remainder closest to zero; it is
        # compared with 0 using an absolute tolerance of 1e-9.
        return not math.isclose(math.remainder(a, b), 0, abs_tol=1e-9)
410
411
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is below the limit value."""

    # Singular/plural variants are selected on the "limit_value" parameter.
    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "min_length"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length ``a`` is below limit ``b``."""
        return a < b

    def clean(self, x: Any) -> int:
        """Return the length of ``x``; this is what gets compared."""
        return len(x)
428
429
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose ``len()`` exceeds the limit value."""

    # Singular/plural variants are selected on the "limit_value" parameter.
    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )
    code = "max_length"

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length ``a`` exceeds limit ``b``."""
        return a > b

    def clean(self, x: Any) -> int:
        """Return the length of ``x``; this is what gets compared."""
        return len(x)
446
447
@deconstructible
class DecimalValidator:
    """
    Check a decimal's total digits, decimal places and whole-number digits
    against the configured maxima and raise ValidationError when one is
    exceeded. A limit of None disables the corresponding check.
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits: int | None, decimal_places: int | None) -> None:
        """Store the limits on total digits and on decimal places."""
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def _limit_error(self, key: str, limit: int, value: Any) -> ValidationError:
        """Build the error for an exceeded limit; ``key`` is message key and code."""
        return ValidationError(
            self.messages[key],
            code=key,
            params={"max": limit, "value": value},
        )

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``value`` breaks one of the digit limits."""
        _sign, digit_tuple, exponent = value.as_tuple()
        # Non-numeric exponents mark special values (infinity / NaN variants).
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        coefficient_len = len(digit_tuple)
        if exponent >= 0:
            decimals = 0
            # A positive exponent adds that many trailing zeros, except for
            # a plain zero coefficient.
            if digit_tuple == (0,):
                digits = coefficient_len
            else:
                digits = coefficient_len + exponent
        else:
            # Every place implied by the negative exponent is a decimal place.
            # When there are more such places than coefficient digits, the
            # missing ones are leading zeros after the decimal point, so the
            # digit count is whichever of the two is larger.
            decimals = abs(exponent)
            digits = decimals if decimals > coefficient_len else coefficient_len
        whole_digits = digits - decimals

        max_digits = self.max_digits
        decimal_places = self.decimal_places

        if max_digits is not None and digits > max_digits:
            raise self._limit_error("max_digits", max_digits, value)
        if decimal_places is not None and decimals > decimal_places:
            raise self._limit_error("max_decimal_places", decimal_places, value)
        if max_digits is not None and decimal_places is not None:
            whole_limit = max_digits - decimal_places
            if whole_digits > whole_limit:
                raise self._limit_error("max_whole_digits", whole_limit, value)

    def __eq__(self, other: object) -> bool:
        """Equal when ``other`` is of this class and has the same two limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
534
535
@deconstructible
class FileExtensionValidator:
    """Reject files whose name has an extension outside an allowed list.

    Comparison is case-insensitive: allowed extensions are lower-cased on
    construction and the file's extension is lower-cased before the check.
    With ``allowed_extensions`` left as None every extension is accepted.
    """

    message = 'File extension "%(extension)s" is not allowed. Allowed extensions are: %(allowed_extensions)s.'
    code = "invalid_extension"

    def __init__(
        self,
        allowed_extensions: list[str] | None = None,
        message: str | None = None,
        code: str | None = None,
    ) -> None:
        """Store the lower-cased allowed extensions and optional overrides."""
        self.allowed_extensions = (
            None
            if allowed_extensions is None
            else [ext.lower() for ext in allowed_extensions]
        )
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if the suffix of ``value.name`` is not allowed."""
        # Suffix of the file name without its leading dot, lower-cased.
        extension = Path(value.name).suffix[1:].lower()
        allowed = self.allowed_extensions
        if allowed is None or extension in allowed:
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={
                "extension": extension,
                "allowed_extensions": ", ".join(allowed),
                "value": value,
            },
        )

    def __eq__(self, other: object) -> bool:
        """Equal when ``other`` is of this class with identical settings."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
580
581
def get_available_image_extensions() -> list[str]:
    """Return the extensions registered with PIL, lower-cased and without
    their first character; return an empty list when PIL cannot be imported.
    """
    try:
        from PIL import Image  # ty: ignore[unresolved-import]
    except ImportError:
        return []

    Image.init()
    extensions: list[str] = []
    for registered in Image.EXTENSION:
        extensions.append(registered.lower()[1:])
    return extensions
590
591
def validate_image_file_extension(value: Any) -> None:
    """Validate ``value`` with a FileExtensionValidator limited to the
    extensions reported by get_available_image_extensions().
    """
    image_extensions = get_available_image_extensions()
    validator = FileExtensionValidator(allowed_extensions=image_extensions)
    return validator(value)
596
597
@deconstructible
class ProhibitNullCharactersValidator:
    """Reject values whose string form contains a null character."""

    message = "Null characters are not allowed."
    code = "null_characters_not_allowed"

    def __init__(self, message: str | None = None, code: str | None = None) -> None:
        """Optionally override the default message and error code."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``str(value)`` contains a null character."""
        text = str(value)
        if "\x00" not in text:
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Equal when ``other`` is of this class with the same message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code