1from __future__ import annotations
2
3from collections.abc import Callable, Sequence
4from typing import TYPE_CHECKING, Any
5
6from plain import exceptions
7from plain.postgres.dialect import adapt_ipaddressfield_value
8from plain.preflight import PreflightResult
9from plain.utils.ipv6 import clean_ipv6_address
10from plain.validators import ip_address_validators
11
12from .base import NOT_PROVIDED, DefaultableField
13
14if TYPE_CHECKING:
15 from plain.postgres.connection import DatabaseConnection
16
17
class GenericIPAddressField(DefaultableField[str]):
    """Field storing an IPv4 or IPv6 address in an ``inet`` database column.

    Values containing ``":"`` are treated as IPv6 and normalized with
    ``clean_ipv6_address``; everything else is handled as a plain string.
    """

    db_type_sql = "inet"
    # Blank values are stored as NULL rather than "", which is why the
    # preflight check below requires allow_null=True when required=False.
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        protocol: str = "both",
        unpack_ipv4: bool = False,
        required: bool = True,
        allow_null: bool = False,
        default: Any = NOT_PROVIDED,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        """Configure the accepted address family and IPv4 unpacking.

        Args:
            protocol: Address-family selector forwarded to
                ``ip_address_validators`` (default ``"both"``; presumably it
                also accepts ``"ipv4"``/``"ipv6"`` -- confirm against that
                helper).
            unpack_ipv4: Forwarded to ``ip_address_validators`` and
                ``clean_ipv6_address``; presumably controls unpacking of
                IPv4-mapped IPv6 addresses -- confirm against those helpers.
            required: Forwarded to ``DefaultableField``.
            allow_null: Forwarded to ``DefaultableField``.
            default: Forwarded to ``DefaultableField``.
            validators: Extra validators forwarded to ``DefaultableField``.
        """
        self.unpack_ipv4 = unpack_ipv4
        self.protocol = protocol
        # Assigned before super().__init__() -- presumably so the base class
        # can see default_validators during its own initialization.
        (
            self.default_validators,
            self.invalid_error_message,
        ) = ip_address_validators(protocol, unpack_ipv4)
        super().__init__(
            required=required,
            allow_null=allow_null,
            default=default,
            validators=validators,
        )

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Return the base preflight results plus this field's null/required check."""
        return [
            *super().preflight(**kwargs),
            *self._check_required_and_null_values(),
        ]

    def _check_required_and_null_values(self) -> list[PreflightResult]:
        """Flag ``required=False`` combined with ``allow_null=False``.

        Blank values for this field are stored as NULL, so an optional field
        must also allow NULL.

        Returns:
            A single-item list describing the misconfiguration, or ``[]``.
        """
        if not getattr(self, "allow_null", False) and not getattr(
            self, "required", True
        ):
            return [
                PreflightResult(
                    fix="GenericIPAddressFields cannot have required=False if allow_null=False, "
                    "as blank values are stored as nulls.",
                    obj=self,
                    id="fields.generic_ip_field_null_blank_config",
                )
            ]
        return []

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with non-default field options.

        ``unpack_ipv4`` and ``protocol`` are added to ``kwargs`` only when
        they differ from their constructor defaults.
        """
        name, path, args, kwargs = super().deconstruct()
        if self.unpack_ipv4 is not False:
            kwargs["unpack_ipv4"] = self.unpack_ipv4
        if self.protocol != "both":
            kwargs["protocol"] = self.protocol
        return name, path, args, kwargs

    def to_python(self, value: Any) -> str | None:
        """Convert ``value`` to a stripped address string.

        Returns ``None`` for ``None``. Non-strings are coerced with ``str()``
        and surrounding whitespace is removed. Values containing ``":"`` are
        normalized with ``clean_ipv6_address``, which is given
        ``invalid_error_message`` for the ValidationError it can raise on
        malformed input.
        """
        if value is None:
            return None
        if not isinstance(value, str):
            value = str(value)
        value = value.strip()
        if ":" in value:
            return clean_ipv6_address(
                value, self.unpack_ipv4, self.invalid_error_message
            )
        return value

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` for the database.

        Runs ``get_prep_value`` unless ``prepared`` is true, then passes the
        result through ``adapt_ipaddressfield_value``. ``connection`` is not
        used here.
        """
        if not prepared:
            value = self.get_prep_value(value)
        return adapt_ipaddressfield_value(value)

    def get_prep_value(self, value: Any) -> Any:
        """Convert ``value`` to the string form used in queries.

        Returns ``None`` for ``None``. Any other value is coerced to ``str``,
        so ``ipaddress`` objects and other non-string inputs are accepted.
        Strings containing ``":"`` are normalized with ``clean_ipv6_address``
        when they parse. Otherwise the string is returned unchanged.
        """
        value = super().get_prep_value(value)
        if value is None:
            return None
        # Coerce before the membership test, mirroring to_python(): objects
        # such as ipaddress.IPv6Address or int do not support `":" in value`
        # and would raise TypeError.
        value = str(value)
        if ":" in value:
            try:
                return clean_ipv6_address(value, self.unpack_ipv4)
            except exceptions.ValidationError:
                # Unparseable input is passed through unchanged here;
                # presumably it is rejected by validators or the database.
                pass
        return value