1from __future__ import annotations
2
3import ipaddress
4import math
5import re
6from collections.abc import Callable
7from pathlib import Path
8from typing import TYPE_CHECKING, Any, cast
9from urllib.parse import urlsplit, urlunsplit
10
11from plain.exceptions import ValidationError
12from plain.utils.deconstruct import deconstructible
13from plain.utils.encoding import punycode
14from plain.utils.ipv6 import is_valid_ipv6_address
15from plain.utils.regex_helper import _lazy_re_compile
16from plain.utils.text import pluralize_lazy
17
18if TYPE_CHECKING:
19 from plain.utils.functional import SimpleLazyObject
20
# These values, if given to validate(), will trigger the self.required check.
# Kept as a tuple rather than a set because the list and dict members are
# unhashable; membership tests therefore compare by equality.
EMPTY_VALUES = (None, "", [], (), {})
23
24
@deconstructible
class RegexValidator:
    """
    Validate a value by searching its string form with a regular expression.

    The class attributes below are the defaults; every constructor argument
    that is not None overrides the matching attribute on the instance.
    """

    regex: str | re.Pattern[str] | SimpleLazyObject = ""
    message = "Enter a valid value."
    code = "invalid"
    inverse_match = False
    flags = 0

    def __init__(
        self,
        regex: str | re.Pattern[str] | None = None,
        message: str | None = None,
        code: str | None = None,
        inverse_match: bool | None = None,
        flags: int | None = None,
    ) -> None:
        """
        Apply the given overrides and prepare the pattern.

        Raises:
            TypeError: flags are in effect but the pattern being prepared is
                not a plain string.
        """
        overrides = (
            ("message", message),
            ("code", code),
            ("inverse_match", inverse_match),
            ("flags", flags),
        )
        for attr_name, given in overrides:
            if given is not None:
                setattr(self, attr_name, given)

        if regex is not None:
            # An explicit argument always wins over the class default.
            source: str | re.Pattern[str] = regex
        elif isinstance(self.regex, str | re.Pattern):
            # Class-level regex is a string or pattern that still needs to go
            # through _lazy_re_compile.
            source = self.regex
        else:
            # Class-level regex is already prepared (e.g. in the URLValidator
            # subclass); leave it untouched.
            return

        if self.flags and not isinstance(source, str):
            raise TypeError(
                "If the flags are set, regex must be a regular expression string."
            )
        self.regex = _lazy_re_compile(source, self.flags)

    def __call__(self, value: Any) -> None:
        """
        Raise ValidationError unless str(value) contains a match for the
        pattern (or, when inverse_match is true, if it does contain one).
        """
        # After __init__, self.regex is expected to expose search().
        pattern = cast(re.Pattern[str], self.regex)
        found = pattern.search(str(value))
        if self.inverse_match:
            rejected = found
        else:
            rejected = not found
        if rejected:
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Equal when pattern, flags, message, code and inverse_match agree."""
        if not isinstance(other, RegexValidator):
            return False
        mine = self.regex
        theirs = other.regex
        return (
            mine.pattern == theirs.pattern  # type: ignore[attr-defined]
            and mine.flags == theirs.flags  # type: ignore[attr-defined]
            and (self.message == other.message)
            and (self.code == other.code)
            and (self.inverse_match == other.inverse_match)
        )
89
90
@deconstructible
class URLValidator(RegexValidator):
    """
    Validate that a string is a URL whose scheme is one of ``schemes``.

    The value is checked for unsafe whitespace characters, for an allowed
    scheme, against the class-level pattern (with a punycode retry for IDN
    hosts), for a valid bracketed IPv6 literal, and for host name length.
    """

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"

    regex = _lazy_re_compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    message = "Enter a valid URL."
    schemes = ["http", "https", "ftp", "ftps"]
    unsafe_chars = frozenset("\t\r\n")

    def __init__(self, schemes: list[str] | None = None, **kwargs: Any) -> None:
        """Forward ``kwargs`` to RegexValidator and optionally replace ``schemes``."""
        super().__init__(**kwargs)
        if schemes is None:
            return
        self.schemes = schemes

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``value`` is not an acceptable URL string."""

        def invalid() -> ValidationError:
            # Every rejection in this method carries the same payload.
            return ValidationError(
                self.message, code=self.code, params={"value": value}
            )

        if not isinstance(value, str):
            raise invalid()
        if self.unsafe_chars.intersection(value):
            raise invalid()

        # The scheme is whatever precedes the first "://", compared lower-cased.
        declared_scheme = value.partition("://")[0].lower()
        if declared_scheme not in self.schemes:
            raise invalid()

        # Then check the full URL.
        try:
            parts = urlsplit(value)
        except ValueError:
            raise invalid()

        try:
            super().__call__(value)
        except ValidationError as first_failure:
            # Trivial case failed. Try for a possible IDN domain.
            if not value:
                raise
            url_scheme, netloc, path, query, fragment = parts
            try:
                ace_netloc = punycode(netloc)  # IDN -> ACE
            except UnicodeError:  # invalid domain part
                raise first_failure
            super().__call__(
                urlunsplit((url_scheme, ace_netloc, path, query, fragment))
            )
        else:
            # The pattern only loosely matches bracketed hosts, so verify the
            # IPv6 literal in the netloc part properly.
            bracketed = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", parts.netloc)
            if bracketed:
                try:
                    validate_ipv6_address(bracketed[1])
                except ValidationError:
                    raise invalid()

        # The maximum length of a full host name is 253 characters per RFC 1034
        # section 3.1. It's defined to be 255 bytes or less, but this includes
        # one byte for the length of the name and one byte for the trailing dot
        # that's used to indicate absolute names in DNS.
        host = parts.hostname
        if host is None or len(host) > 253:
            raise invalid()
183
184
@deconstructible
class EmailValidator:
    """
    Validate that a value is a string of the form ``user@domain``.

    The user part must match ``user_regex``. The domain part must be listed in
    ``domain_allowlist``, match ``domain_regex``, or be a bracketed IPv4/IPv6
    literal; a punycode-encoded form of the domain is tried as a fallback.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]

    def __init__(
        self,
        message: str | None = None,
        code: str | None = None,
        allowlist: list[str] | None = None,
    ) -> None:
        """Override the class-level message, code and/or domain allowlist."""
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value: Any) -> None:
        """
        Raise ValidationError if ``value`` is not an acceptable email address.

        Non-string input is rejected with ValidationError as well, matching
        how URLValidator treats non-strings.
        """
        # The isinstance guard comes first: without it a truthy non-string
        # (an int, or a list containing "@") would fail with TypeError or
        # AttributeError below instead of a validation error.
        if not isinstance(value, str) or not value or "@" not in value:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Split on the last "@" so a quoted user part may itself contain "@".
        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part in self.domain_allowlist or self.validate_domain_part(
            domain_part
        ):
            return None

        # Try for possible IDN domain-part.
        try:
            ace_domain = punycode(domain_part)
        except UnicodeError:
            pass
        else:
            if self.validate_domain_part(ace_domain):
                return None
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part: str) -> bool:
        """
        Return True if ``domain_part`` matches ``domain_regex`` or is a
        bracketed literal holding a valid IPv4 or IPv6 address.
        """
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if not literal_match:
            return False
        try:
            validate_ipv46_address(literal_match[1])
        except ValidationError:
            return False
        return True

    def __eq__(self, other: object) -> bool:
        """Equal when allowlist, message and code all agree."""
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
266
267
# Module-level instance built with no arguments, so it uses the class-level
# message, code and domain allowlist.
validate_email = EmailValidator()
269
270
def validate_ipv4_address(value: str) -> None:
    """
    Raise ValidationError unless ``ipaddress.IPv4Address`` accepts ``value``.
    """
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        failure = ValidationError(
            "Enter a valid IPv4 address.", code="invalid", params={"value": value}
        )
        raise failure
278
279
def validate_ipv6_address(value: str) -> None:
    """
    Raise ValidationError unless ``is_valid_ipv6_address`` accepts ``value``.
    """
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
285
286
def validate_ipv46_address(value: str) -> None:
    """
    Accept ``value`` if it validates as IPv4 or, failing that, as IPv6.

    Raises:
        ValidationError: neither check accepted the value.
    """
    try:
        validate_ipv4_address(value)
    except ValidationError:
        # Not IPv4; give IPv6 a chance before reporting a combined error.
        try:
            validate_ipv6_address(value)
        except ValidationError:
            neither = ValidationError(
                "Enter a valid IPv4 or IPv6 address.",
                code="invalid",
                params={"value": value},
            )
            raise neither
299
300
# Maps a protocol name to a (validators, error message) pair. It is consulted
# by ip_address_validators(), which looks the protocol up in lower case.
ip_address_validator_map = {
    "both": ([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    "ipv4": ([validate_ipv4_address], "Enter a valid IPv4 address."),
    "ipv6": ([validate_ipv6_address], "Enter a valid IPv6 address."),
}
306
307
def ip_address_validators(
    protocol: str, unpack_ipv4: bool
) -> tuple[list[Callable[[str], None]], str]:
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    ``protocol`` is matched case-insensitively, both for the lookup and for
    the ``unpack_ipv4`` compatibility check.

    Returns:
        The ``(validators, error message)`` pair registered for the protocol
        in ``ip_address_validator_map``.

    Raises:
        ValueError: ``unpack_ipv4`` is set for a protocol other than "both",
            or the protocol is unknown.
    """
    # Normalise once so that the guard and the lookup agree: previously
    # "Both" was accepted by the lookup but rejected by the guard.
    key = protocol.lower() if isinstance(protocol, str) else protocol
    if key != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[key]
    except KeyError:
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        ) from None
325
326
def int_list_validator(
    sep: str = ",",
    message: str | None = None,
    code: str = "invalid",
    allow_negative: bool = False,
) -> RegexValidator:
    """
    Build a RegexValidator for digit runs joined by ``sep``.

    ``sep`` is escaped before being placed in the pattern. When
    ``allow_negative`` is true, each number may carry an optional leading "-".
    """
    neg = "(-)?" if allow_negative else ""
    escaped_sep = re.escape(sep)
    pattern = rf"^{neg}\d+(?:{escaped_sep}{neg}\d+)*\Z"
    regexp = _lazy_re_compile(pattern)
    return RegexValidator(regexp, message=message, code=code)  # type: ignore[arg-type]
340
341
# int_list_validator() with its defaults (comma separator, no leading "-"
# allowed) and a custom error message.
validate_comma_separated_integer_list = int_list_validator(
    message="Enter only digits separated by commas.",
)
345
346
@deconstructible
class BaseValidator:
    """
    Compare a cleaned value against a limit and reject it when ``compare``
    returns a true result.

    Subclasses customise behaviour through ``compare`` and ``clean``.
    ``limit_value`` may be a callable, in which case it is called on every
    validation to obtain the limit.
    """

    code = "limit_value"
    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."

    def __init__(self, limit_value: Any, message: str | None = None) -> None:
        """Store the limit; a truthy ``message`` replaces the class default."""
        if message:
            self.message = message
        self.limit_value = limit_value

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if the cleaned value fails the comparison."""
        shown = self.clean(value)
        limit = self.limit_value
        if callable(limit):
            limit = limit()
        if not self.compare(shown, limit):
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={"limit_value": limit, "show_value": shown, "value": value},
        )

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with equal limit, message, code."""
        if isinstance(other, self.__class__):
            return (
                self.limit_value == other.limit_value
                and self.message == other.message
                and self.code == other.code
            )
        return NotImplemented

    def compare(self, a: Any, b: Any) -> bool:
        """Return True when ``a`` (cleaned value) is invalid against ``b`` (limit)."""
        return a is not b

    def clean(self, x: Any) -> Any:
        """Return the form of ``x`` used for comparison; identity by default."""
        return x
380
381
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values that are greater than the limit."""

    code = "max_value"
    message = "Ensure this value is less than or equal to %(limit_value)s."

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when value ``a`` exceeds limit ``b``."""
        return a > b
389
390
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values that are less than the limit."""

    code = "min_value"
    message = "Ensure this value is greater than or equal to %(limit_value)s."

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when value ``a`` is below limit ``b``."""
        return a < b
398
399
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not a multiple of the step size (the limit)."""

    code = "step_size"
    message = "Ensure this value is a multiple of step size %(limit_value)s."

    def compare(self, a: Any, b: Any) -> bool:
        """
        Return True (invalid) when the remainder of ``a`` by step ``b`` is not
        within an absolute tolerance of 1e-9 of zero.
        """
        leftover = math.remainder(a, b)
        return not math.isclose(leftover, 0, abs_tol=1e-9)
407
408
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is below the limit."""

    code = "min_length"
    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character "
        "(it has %(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters "
        "(it has %(show_value)d).",
        "limit_value",
    )

    def clean(self, x: Any) -> int:
        """Compare on the length of the value rather than the value itself."""
        return len(x)

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length ``a`` is below limit ``b``."""
        return a < b
425
426
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is above the limit."""

    code = "max_length"
    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character "
        "(it has %(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters "
        "(it has %(show_value)d).",
        "limit_value",
    )

    def clean(self, x: Any) -> int:
        """Compare on the length of the value rather than the value itself."""
        return len(x)

    def compare(self, a: Any, b: Any) -> bool:
        """Return True (invalid) when length ``a`` exceeds limit ``b``."""
        return a > b
443
444
@deconstructible
class DecimalValidator:
    """
    Validate that the input does not exceed the maximum number of digits
    expected, otherwise raise ValidationError.

    Three limits are enforced when configured: total digits, decimal places,
    and digits before the decimal point (``max_digits - decimal_places``).
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit "
            "before the decimal point.",
            "Ensure that there are no more than %(max)s digits "
            "before the decimal point.",
            "max",
        ),
    }

    def __init__(self, max_digits: int | None, decimal_places: int | None) -> None:
        """Store the limits; None disables the corresponding check."""
        self.decimal_places = decimal_places
        self.max_digits = max_digits

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if ``value`` breaks any configured limit."""
        _sign, digit_tuple, exponent = value.as_tuple()
        # Non-numeric exponents mark special values such as NaN and infinity.
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        stored_digits = len(digit_tuple)
        if exponent >= 0:
            decimals = 0
            # A positive exponent adds that many trailing zeros, except for
            # a plain zero coefficient.
            if digit_tuple == (0,):
                digits = stored_digits
            else:
                digits = stored_digits + exponent
        else:
            # When the negative exponent outnumbers the stored digits, the
            # value has leading zeros after the decimal point, so the exponent
            # alone determines the digit count.
            decimals = abs(exponent)
            digits = max(stored_digits, decimals)
        whole_digits = digits - decimals

        max_digits = self.max_digits
        decimal_places = self.decimal_places

        if max_digits is not None and digits > max_digits:
            raise ValidationError(
                self.messages["max_digits"],
                code="max_digits",
                params={"max": max_digits, "value": value},
            )
        if decimal_places is not None and decimals > decimal_places:
            raise ValidationError(
                self.messages["max_decimal_places"],
                code="max_decimal_places",
                params={"max": decimal_places, "value": value},
            )
        if max_digits is None or decimal_places is None:
            return
        max_whole = max_digits - decimal_places
        if whole_digits > max_whole:
            raise ValidationError(
                self.messages["max_whole_digits"],
                code="max_whole_digits",
                params={"max": max_whole, "value": value},
            )

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with the same two limits."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
531
532
@deconstructible
class FileExtensionValidator:
    """
    Validate the extension of ``value.name`` against a list of allowed
    extensions, compared in lower case. ``allowed_extensions=None`` accepts
    any extension.
    """

    message = 'File extension "%(extension)s" is not allowed. Allowed extensions are: %(allowed_extensions)s.'
    code = "invalid_extension"

    def __init__(
        self,
        allowed_extensions: list[str] | None = None,
        message: str | None = None,
        code: str | None = None,
    ) -> None:
        """Store the lower-cased allowed extensions and optional overrides."""
        self.allowed_extensions = (
            None
            if allowed_extensions is None
            else [item.lower() for item in allowed_extensions]
        )
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if the extension of ``value.name`` is not allowed."""
        # suffix includes the leading dot; drop it before comparing.
        extension = Path(value.name).suffix[1:].lower()
        allowed = self.allowed_extensions
        if allowed is None or extension in allowed:
            return
        raise ValidationError(
            self.message,
            code=self.code,
            params={
                "extension": extension,
                "allowed_extensions": ", ".join(allowed),
                "value": value,
            },
        )

    def __eq__(self, other: object) -> bool:
        """Equal when class, allowed extensions, message and code agree."""
        if not isinstance(other, self.__class__):
            return False
        return (
            self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
577
578
def get_available_image_extensions() -> list[str]:
    """
    Return the lower-cased, dot-less extensions registered with PIL, or an
    empty list when PIL cannot be imported.
    """
    try:
        from PIL import Image  # type: ignore[import-untyped]
    except ImportError:
        return []

    Image.init()
    extensions = []
    for dotted in Image.EXTENSION:
        extensions.append(dotted.lower()[1:])
    return extensions
587
588
def validate_image_file_extension(value: Any) -> None:
    """
    Validate ``value`` with a FileExtensionValidator limited to the extensions
    reported by get_available_image_extensions().
    """
    image_extensions = get_available_image_extensions()
    validator = FileExtensionValidator(allowed_extensions=image_extensions)
    return validator(value)
593
594
@deconstructible
class ProhibitNullCharactersValidator:
    """Validate that the string doesn't contain the null character."""

    code = "null_characters_not_allowed"
    message = "Null characters are not allowed."

    def __init__(self, message: str | None = None, code: str | None = None) -> None:
        """Override the class-level message and/or code when given."""
        for attr_name, given in (("message", message), ("code", code)):
            if given is not None:
                setattr(self, attr_name, given)

    def __call__(self, value: Any) -> None:
        """Raise ValidationError if str(value) contains "\\x00"."""
        if "\x00" not in str(value):
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other: object) -> bool:
        """Equal to instances of the same class with equal message and code."""
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code