1from __future__ import annotations
  2
  3from collections.abc import Callable, Sequence
  4from typing import TYPE_CHECKING, Any
  5
  6from plain import exceptions
  7from plain.postgres.dialect import adapt_ipaddressfield_value
  8from plain.preflight import PreflightResult
  9from plain.utils.ipv6 import clean_ipv6_address
 10from plain.validators import ip_address_validators
 11
 12from .base import NOT_PROVIDED, DefaultableField
 13
 14if TYPE_CHECKING:
 15    from plain.postgres.connection import DatabaseConnection
 16
 17
 18class GenericIPAddressField(DefaultableField[str]):
 19    db_type_sql = "inet"
 20    empty_strings_allowed = False
 21
 22    def __init__(
 23        self,
 24        *,
 25        protocol: str = "both",
 26        unpack_ipv4: bool = False,
 27        required: bool = True,
 28        allow_null: bool = False,
 29        default: Any = NOT_PROVIDED,
 30        validators: Sequence[Callable[..., Any]] = (),
 31    ):
 32        self.unpack_ipv4 = unpack_ipv4
 33        self.protocol = protocol
 34        (
 35            self.default_validators,
 36            self.invalid_error_message,
 37        ) = ip_address_validators(protocol, unpack_ipv4)
 38        super().__init__(
 39            required=required,
 40            allow_null=allow_null,
 41            default=default,
 42            validators=validators,
 43        )
 44
 45    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
 46        return [
 47            *super().preflight(**kwargs),
 48            *self._check_required_and_null_values(),
 49        ]
 50
 51    def _check_required_and_null_values(self) -> list[PreflightResult]:
 52        if not getattr(self, "allow_null", False) and not getattr(
 53            self, "required", True
 54        ):
 55            return [
 56                PreflightResult(
 57                    fix="GenericIPAddressFields cannot have required=False if allow_null=False, "
 58                    "as blank values are stored as nulls.",
 59                    obj=self,
 60                    id="fields.generic_ip_field_null_blank_config",
 61                )
 62            ]
 63        return []
 64
 65    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
 66        name, path, args, kwargs = super().deconstruct()
 67        if self.unpack_ipv4 is not False:
 68            kwargs["unpack_ipv4"] = self.unpack_ipv4
 69        if self.protocol != "both":
 70            kwargs["protocol"] = self.protocol
 71        return name, path, args, kwargs
 72
 73    def to_python(self, value: Any) -> str | None:
 74        if value is None:
 75            return None
 76        if not isinstance(value, str):
 77            value = str(value)
 78        value = value.strip()
 79        if ":" in value:
 80            return clean_ipv6_address(
 81                value, self.unpack_ipv4, self.invalid_error_message
 82            )
 83        return value
 84
 85    def get_db_prep_value(
 86        self, value: Any, connection: DatabaseConnection, prepared: bool = False
 87    ) -> Any:
 88        if not prepared:
 89            value = self.get_prep_value(value)
 90        return adapt_ipaddressfield_value(value)
 91
 92    def get_prep_value(self, value: Any) -> Any:
 93        value = super().get_prep_value(value)
 94        if value is None:
 95            return None
 96        if value and ":" in value:
 97            try:
 98                return clean_ipv6_address(value, self.unpack_ipv4)
 99            except exceptions.ValidationError:
100                pass
101        return str(value)