1import ipaddress
2import math
3import re
4from pathlib import Path
5from urllib.parse import urlsplit, urlunsplit
6
7from plain.exceptions import ValidationError
8from plain.utils.deconstruct import deconstructible
9from plain.utils.encoding import punycode
10from plain.utils.ipv6 import is_valid_ipv6_address
11from plain.utils.regex_helper import _lazy_re_compile
12from plain.utils.text import pluralize_lazy
13
# Values treated as "empty" input. Field code elsewhere passes values to
# validate(); if the value is one of these, that field's self.required check
# is triggered (the check itself is not in this module).
# NOTE: if tested with ``value in EMPTY_VALUES`` the comparison is ==, so 0
# and False are NOT considered empty here, while any empty list/tuple/dict is.
EMPTY_VALUES = (None, "", [], (), {})
16
17
@deconstructible
class RegexValidator:
    """
    Reject values whose string form fails a regular-expression test.

    By default a value is valid when ``regex`` matches somewhere inside
    ``str(value)``. With ``inverse_match`` the logic flips: any match makes
    the value invalid. Every constructor argument left as ``None`` falls back
    to the class-level default, so subclasses can configure the validator by
    overriding class attributes instead of passing arguments.
    """

    regex = ""
    message = "Enter a valid value."
    code = "invalid"
    inverse_match = False
    flags = 0

    def __init__(
        self, regex=None, message=None, code=None, inverse_match=None, flags=None
    ):
        overrides = {
            "regex": regex,
            "message": message,
            "code": code,
            "inverse_match": inverse_match,
            "flags": flags,
        }
        for attr_name, supplied in overrides.items():
            if supplied is not None:
                setattr(self, attr_name, supplied)

        # Flags can only be applied while compiling a pattern string; an
        # already-compiled pattern carries its own flags.
        if self.flags and not isinstance(self.regex, str):
            raise TypeError(
                "If the flags are set, regex must be a regular expression string."
            )

        self.regex = _lazy_re_compile(self.regex, self.flags)

    def __call__(self, value):
        """
        Raise ValidationError unless ``str(value)`` contains a match (or,
        when ``inverse_match`` is set, unless it contains no match).
        """
        found = self.regex.search(str(value)) is not None
        must_match = not self.inverse_match
        if found != must_match:
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other):
        # Defining __eq__ without __hash__ leaves instances unhashable.
        if not isinstance(other, RegexValidator):
            return False
        return (
            self.regex.pattern == other.regex.pattern
            and self.regex.flags == other.regex.flags
            and (self.message == other.message)
            and (self.code == other.code)
            and (self.inverse_match == other.inverse_match)
        )
65
66
@deconstructible
class URLValidator(RegexValidator):
    """
    Validate absolute URLs whose scheme is listed in ``schemes``.

    The structural check is a regex over the following parts:
    - scheme
    - optional ``user:pass@``
    - host: an IPv4 literal, a bracketed IPv6 literal, a DNS name or
      ``localhost``
    - optional port
    - optional path/query/fragment

    If that check fails, the netloc is retried after IDN -> punycode
    conversion. Bracketed IPv6 hosts are re-validated strictly, and host
    names longer than 253 characters are rejected.
    """

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"

    regex = _lazy_re_compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    message = "Enter a valid URL."
    schemes = ["http", "https", "ftp", "ftps"]
    unsafe_chars = frozenset("\t\r\n")
    # Inputs longer than this are rejected before any regex runs. This bounds
    # the backtracking work the host patterns can be forced into by hostile
    # input (e.g. a huge number of domain labels).
    max_length = 2048

    def __init__(self, schemes=None, **kwargs):
        """
        :param schemes: allowed lower-case schemes; ``None`` keeps the class
            default.
        :param kwargs: forwarded to RegexValidator (message, code, ...).
        """
        super().__init__(**kwargs)
        if schemes is not None:
            self.schemes = schemes

    def __call__(self, value):
        """Raise ValidationError unless *value* is an acceptable URL string."""
        if not isinstance(value, str) or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Tabs and newlines are rejected outright; urlsplit() may strip them
        # silently, which would make the checks below see a different URL.
        if self.unsafe_chars.intersection(value):
            raise ValidationError(self.message, code=self.code, params={"value": value})
        # Check if the scheme is valid.
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Then check full URL
        try:
            splitted_url = urlsplit(value)
        except ValueError:
            raise ValidationError(self.message, code=self.code, params={"value": value})
        try:
            super().__call__(value)
        except ValidationError as e:
            # Trivial case failed. Try for possible IDN domain
            if value:
                scheme, netloc, path, query, fragment = splitted_url
                try:
                    netloc = punycode(netloc)  # IDN -> ACE
                except UnicodeError:  # invalid domain part
                    raise e
                url = urlunsplit((scheme, netloc, path, query, fragment))
                super().__call__(url)
            else:
                raise
        else:
            # Now verify IPv6 in the netloc part; the regex above only checks
            # for bracketed hex/colon/dot characters.
            host_match = re.search(r"^\[(.+)\](?::[0-9]{1,5})?$", splitted_url.netloc)
            if host_match:
                potential_ip = host_match[1]
                try:
                    validate_ipv6_address(potential_ip)
                except ValidationError:
                    raise ValidationError(
                        self.message, code=self.code, params={"value": value}
                    )

        # The maximum length of a full host name is 253 characters per RFC 1034
        # section 3.1. It's defined to be 255 bytes or less, but this includes
        # one byte for the length of the name and one byte for the trailing dot
        # that's used to indicate absolute names in DNS.
        if splitted_url.hostname is None or len(splitted_url.hostname) > 253:
            raise ValidationError(self.message, code=self.code, params={"value": value})
159
160
@deconstructible
class EmailValidator:
    """
    Validate an email address of the form ``local-part@domain``.

    The local part must be a dot-atom or a quoted string. The domain is
    accepted if it is in ``domain_allowlist``, or if it matches
    ``domain_regex`` (directly or after IDN -> punycode conversion), or if it
    is a bracketed IPv4/IPv6 literal.
    """

    message = "Enter a valid email address."
    code = "invalid"
    user_regex = _lazy_re_compile(
        # dot-atom
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # quoted-string
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])'
        r'*"\Z)',
        re.IGNORECASE,
    )
    domain_regex = _lazy_re_compile(
        # max length for domain name labels is 63 characters per RFC 1034
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )
    literal_regex = _lazy_re_compile(
        # literal form, ipv4 or ipv6 address (SMTP 4.1.3)
        r"\[([A-F0-9:.]+)\]\Z",
        re.IGNORECASE,
    )
    domain_allowlist = ["localhost"]
    # Longest accepted address. RFC 3696 section 3 allows 64 characters for
    # the local part, 1 for "@" and 255 for the domain. Checking this first
    # also bounds the work domain_regex can be forced into by hostile input
    # with very many labels.
    max_length = 320

    def __init__(self, message=None, code=None, allowlist=None):
        """
        :param message: error message; ``None`` keeps the class default.
        :param code: error code; ``None`` keeps the class default.
        :param allowlist: domains accepted verbatim; ``None`` keeps
            ``["localhost"]``.
        """
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code
        if allowlist is not None:
            self.domain_allowlist = allowlist

    def __call__(self, value):
        """Raise ValidationError unless *value* is an acceptable address."""
        if not value or "@" not in value or len(value) > self.max_length:
            raise ValidationError(self.message, code=self.code, params={"value": value})

        # Split on the LAST "@": a quoted local part may itself contain "@".
        user_part, domain_part = value.rsplit("@", 1)

        if not self.user_regex.match(user_part):
            raise ValidationError(self.message, code=self.code, params={"value": value})

        if domain_part not in self.domain_allowlist and not self.validate_domain_part(
            domain_part
        ):
            # Try for possible IDN domain-part
            try:
                domain_part = punycode(domain_part)
            except UnicodeError:
                pass
            else:
                if self.validate_domain_part(domain_part):
                    return
            raise ValidationError(self.message, code=self.code, params={"value": value})

    def validate_domain_part(self, domain_part):
        """Return True if *domain_part* is a DNS name or a bracketed IP literal."""
        if self.domain_regex.match(domain_part):
            return True

        literal_match = self.literal_regex.match(domain_part)
        if literal_match:
            ip_address = literal_match[1]
            try:
                validate_ipv46_address(ip_address)
                return True
            except ValidationError:
                pass
        return False

    def __eq__(self, other):
        return (
            isinstance(other, EmailValidator)
            and (self.domain_allowlist == other.domain_allowlist)
            and (self.message == other.message)
            and (self.code == other.code)
        )
236
237
# Shared EmailValidator instance with default settings (message, code, and
# domain allowlist ["localhost"]), usable directly as a validator callable.
validate_email = EmailValidator()
239
240
def validate_ipv4_address(value):
    """
    Raise ValidationError unless *value* parses as an IPv4 address.

    Parsing is delegated to ``ipaddress.IPv4Address``; any ValueError it
    raises (including AddressValueError) is reported as invalid input.
    """
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        error_message = "Enter a valid IPv4 address."
        raise ValidationError(error_message, code="invalid", params={"value": value})
248
249
def validate_ipv6_address(value):
    """Raise ValidationError unless ``is_valid_ipv6_address`` accepts *value*."""
    if is_valid_ipv6_address(value):
        return
    raise ValidationError(
        "Enter a valid IPv6 address.", code="invalid", params={"value": value}
    )
255
256
def validate_ipv46_address(value):
    """
    Accept *value* if it is a valid IPv4 or IPv6 address.

    IPv4 is tried first and IPv6 only if that fails. When both reject the
    value, one combined ValidationError is raised in place of the individual
    errors.
    """
    for check in (validate_ipv4_address, validate_ipv6_address):
        try:
            check(value)
        except ValidationError:
            continue
        return
    raise ValidationError(
        "Enter a valid IPv4 or IPv6 address.",
        code="invalid",
        params={"value": value},
    )
269
270
# Lower-case protocol name -> (list of validator callables, error message).
# Key order is significant: it is shown to users listing supported protocols.
ip_address_validator_map = dict(
    both=([validate_ipv46_address], "Enter a valid IPv4 or IPv6 address."),
    ipv4=([validate_ipv4_address], "Enter a valid IPv4 address."),
    ipv6=([validate_ipv6_address], "Enter a valid IPv6 address."),
)
276
277
def ip_address_validators(protocol, unpack_ipv4):
    """
    Depending on the given parameters, return the appropriate validators for
    the GenericIPAddressField.

    :param protocol: "both", "ipv4" or "ipv6", matched case-insensitively.
    :param unpack_ipv4: only allowed together with protocol "both".
    :returns: the ``(validators, message)`` pair from
        ``ip_address_validator_map``.
    :raises ValueError: if ``unpack_ipv4`` is combined with a single-family
        protocol, or if ``protocol`` is unknown.
    """
    # Normalise once so the unpack_ipv4 guard and the map lookup agree on
    # case ("Both"/"BOTH" behave like "both" in both places). Non-strings are
    # left as-is and end up as "unknown protocol" ValueErrors.
    normalized = protocol.lower() if isinstance(protocol, str) else protocol
    if normalized != "both" and unpack_ipv4:
        raise ValueError(
            "You can only use `unpack_ipv4` if `protocol` is set to 'both'"
        )
    try:
        return ip_address_validator_map[normalized]
    except KeyError:
        raise ValueError(
            f"The protocol '{protocol}' is unknown. Supported: {list(ip_address_validator_map)}"
        ) from None
293
294
def int_list_validator(sep=",", message=None, code="invalid", allow_negative=False):
    """
    Build a RegexValidator that accepts ``sep``-separated lists of integers.

    :param sep: literal separator between integers (regex-escaped here).
    :param message: error message; ``None`` keeps RegexValidator's default.
    :param code: error code attached to the ValidationError.
    :param allow_negative: allow an optional leading ``-`` on each integer.
    """
    sign = "(-)?" if allow_negative else ""
    number = sign + r"\d+"
    pattern = "^" + number + "(?:" + re.escape(sep) + number + ")*" + r"\Z"
    return RegexValidator(_lazy_re_compile(pattern), message=message, code=code)
303
304
# Accepts strings like "1,22,333": non-negative integers joined by commas.
validate_comma_separated_integer_list = int_list_validator(
    sep=",",
    allow_negative=False,
    message="Enter only digits separated by commas.",
)
308
309
@deconstructible
class BaseValidator:
    """
    Compare a cleaned value against a (possibly callable) limit.

    Subclasses customise two hooks:
    - ``clean``: how the value is reduced before comparing, e.g. to its
      length.
    - ``compare``: returns True when the value is *invalid*.

    ``limit_value`` may be a callable, in which case it is evaluated on every
    validation call.
    """

    message = "Ensure this value is %(limit_value)s (it is %(show_value)s)."
    code = "limit_value"

    def __init__(self, limit_value, message=None):
        self.limit_value = limit_value
        # Any falsy message (None or "") keeps the class-level default.
        if message:
            self.message = message

    def __call__(self, value):
        """Raise ValidationError when ``compare(clean(value), limit)`` is True."""
        cleaned = self.clean(value)
        limit_value = (
            self.limit_value() if callable(self.limit_value) else self.limit_value
        )
        params = {"limit_value": limit_value, "show_value": cleaned, "value": value}
        if self.compare(cleaned, limit_value):
            raise ValidationError(self.message, code=self.code, params=params)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.limit_value == other.limit_value
            and self.message == other.message
            and self.code == other.code
        )

    def compare(self, a, b):
        """Return True (invalid) when *a* is not equal to the limit *b*."""
        # Equality, not identity: equal values are often distinct objects
        # (strings or large ints built at runtime). An ``is not`` test would
        # wrongly reject them.
        return a != b

    def clean(self, x):
        """Return *x* unchanged; subclasses may reduce it (e.g. to a length)."""
        return x
343
344
@deconstructible
class MaxValueValidator(BaseValidator):
    """Reject values strictly greater than ``limit_value``."""

    code = "max_value"
    message = "Ensure this value is less than or equal to %(limit_value)s."

    def compare(self, a, b):
        # Invalid exactly when the cleaned value exceeds the limit.
        return a > b
352
353
@deconstructible
class MinValueValidator(BaseValidator):
    """Reject values strictly less than ``limit_value``."""

    code = "min_value"
    message = "Ensure this value is greater than or equal to %(limit_value)s."

    def compare(self, a, b):
        # Invalid exactly when the cleaned value falls below the limit.
        return a < b
361
362
@deconstructible
class StepValueValidator(BaseValidator):
    """Reject values that are not (approximately) a multiple of ``limit_value``."""

    code = "step_size"
    message = "Ensure this value is a multiple of step size %(limit_value)s."

    def compare(self, a, b):
        # math.remainder works in floats, so allow a tiny absolute tolerance
        # around zero to absorb rounding error.
        leftover = math.remainder(a, b)
        is_multiple = math.isclose(leftover, 0, abs_tol=1e-9)
        return not is_multiple
370
371
@deconstructible
class MinLengthValidator(BaseValidator):
    """Reject values whose ``len()`` is below ``limit_value``."""

    code = "min_length"
    message = pluralize_lazy(
        "Ensure this value has at least %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at least %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )

    def clean(self, x):
        # Compare lengths, not the values themselves.
        return len(x)

    def compare(self, a, b):
        return a < b
388
389
@deconstructible
class MaxLengthValidator(BaseValidator):
    """Reject values whose ``len()`` exceeds ``limit_value``."""

    code = "max_length"
    message = pluralize_lazy(
        "Ensure this value has at most %(limit_value)d character (it has "
        "%(show_value)d).",
        "Ensure this value has at most %(limit_value)d characters (it has "
        "%(show_value)d).",
        "limit_value",
    )

    def clean(self, x):
        # Compare lengths, not the values themselves.
        return len(x)

    def compare(self, a, b):
        return a > b
406
407
@deconstructible
class DecimalValidator:
    """
    Validate that a Decimal fits the configured digit limits.

    ``max_digits`` bounds the total digit count and ``decimal_places`` bounds
    the digits after the point; ValidationError is raised when either is
    exceeded. Either limit may be ``None`` to skip its check. The
    whole-digit check runs only when both limits are set.
    """

    messages = {
        "invalid": "Enter a number.",
        "max_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit in total.",
            "Ensure that there are no more than %(max)s digits in total.",
            "max",
        ),
        "max_decimal_places": pluralize_lazy(
            "Ensure that there are no more than %(max)s decimal place.",
            "Ensure that there are no more than %(max)s decimal places.",
            "max",
        ),
        "max_whole_digits": pluralize_lazy(
            "Ensure that there are no more than %(max)s digit before the decimal "
            "point.",
            "Ensure that there are no more than %(max)s digits before the decimal "
            "point.",
            "max",
        ),
    }

    def __init__(self, max_digits, decimal_places):
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def __call__(self, value):
        _sign, digit_tuple, exponent = value.as_tuple()

        # Decimal.as_tuple() reports non-finite values with a string exponent:
        # "F" (infinity), "n" (NaN) or "N" (signalling NaN).
        if exponent in {"F", "n", "N"}:
            raise ValidationError(
                self.messages["invalid"], code="invalid", params={"value": value}
            )

        if exponent < 0:
            decimals = -exponent
            # A negative exponent larger than the digit count implies leading
            # zeros after the point (e.g. 0.001), and those count as digits.
            digits = max(len(digit_tuple), decimals)
        else:
            decimals = 0
            digits = len(digit_tuple)
            # A positive exponent adds that many trailing zeros, except for
            # zero itself.
            if digit_tuple != (0,):
                digits += exponent
        whole_digits = digits - decimals

        max_digits = self.max_digits
        max_places = self.decimal_places

        if max_digits is not None and digits > max_digits:
            raise ValidationError(
                self.messages["max_digits"],
                code="max_digits",
                params={"max": max_digits, "value": value},
            )
        if max_places is not None and decimals > max_places:
            raise ValidationError(
                self.messages["max_decimal_places"],
                code="max_decimal_places",
                params={"max": max_places, "value": value},
            )
        if max_digits is not None and max_places is not None:
            max_whole = max_digits - max_places
            if whole_digits > max_whole:
                raise ValidationError(
                    self.messages["max_whole_digits"],
                    code="max_whole_digits",
                    params={"max": max_whole, "value": value},
                )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return (
            self.max_digits == other.max_digits
            and self.decimal_places == other.decimal_places
        )
494
495
@deconstructible
class FileExtensionValidator:
    """
    Reject uploaded files whose extension is not in ``allowed_extensions``.

    The extension is taken from ``value.name`` (last suffix, without the dot)
    and compared case-insensitively. When ``allowed_extensions`` is ``None``,
    every extension is accepted.
    """

    # Curly quotes restored; the literal previously contained mojibake
    # ("โ") where U+201C/U+201D were intended.
    message = (
        "File extension \u201c%(extension)s\u201d is not allowed. "
        "Allowed extensions are: %(allowed_extensions)s."
    )
    code = "invalid_extension"

    def __init__(self, allowed_extensions=None, message=None, code=None):
        """
        :param allowed_extensions: iterable of extensions without the leading
            dot; stored lower-cased. ``None`` allows everything.
        :param message: error message; ``None`` keeps the class default.
        :param code: error code; ``None`` keeps the class default.
        """
        if allowed_extensions is not None:
            allowed_extensions = [
                allowed_extension.lower() for allowed_extension in allowed_extensions
            ]
        self.allowed_extensions = allowed_extensions
        if message is not None:
            self.message = message
        if code is not None:
            self.code = code

    def __call__(self, value):
        """Raise ValidationError if the file's extension is not allowed."""
        # Path.suffix includes the leading dot ("" when there is none).
        extension = Path(value.name).suffix[1:].lower()
        if (
            self.allowed_extensions is not None
            and extension not in self.allowed_extensions
        ):
            raise ValidationError(
                self.message,
                code=self.code,
                params={
                    "extension": extension,
                    "allowed_extensions": ", ".join(self.allowed_extensions),
                    "value": value,
                },
            )

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__)
            and self.allowed_extensions == other.allowed_extensions
            and self.message == other.message
            and self.code == other.code
        )
535
536
def get_available_image_extensions():
    """
    Return the image file extensions Pillow registers.

    Extensions are lower-cased and have their leading character (the dot)
    removed. An empty list is returned when Pillow is not installed.
    """
    try:
        from PIL import Image
    except ImportError:
        return []
    Image.init()
    extensions = []
    for suffix in Image.EXTENSION:
        extensions.append(suffix.lower()[1:])
    return extensions
545
546
def validate_image_file_extension(value):
    """Validate that *value*'s file extension is one Pillow can open."""
    validator = FileExtensionValidator(
        allowed_extensions=get_available_image_extensions()
    )
    return validator(value)
551
552
@deconstructible
class ProhibitNullCharactersValidator:
    """Reject any value whose string form contains the null character."""

    code = "null_characters_not_allowed"
    message = "Null characters are not allowed."

    def __init__(self, message=None, code=None):
        # Arguments left as None keep the class-level defaults.
        for attr_name, override in (("message", message), ("code", code)):
            if override is not None:
                setattr(self, attr_name, override)

    def __call__(self, value):
        text = str(value)
        if "\x00" not in text:
            return
        raise ValidationError(self.message, code=self.code, params={"value": value})

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.message == other.message and self.code == other.code