1from __future__ import annotations
  2
  3import decimal
  4from collections.abc import Callable, Sequence
  5from functools import cached_property
  6from typing import TYPE_CHECKING, Any
  7
  8from psycopg.types import numeric
  9
 10from plain import exceptions, validators
 11from plain.preflight import PreflightResult
 12
 13from .base import NOT_PROVIDED, DefaultableField
 14
 15if TYPE_CHECKING:
 16    from plain.postgres.connection import DatabaseConnection
 17
 18
 19class FloatField(DefaultableField[float]):
 20    db_type_sql = "double precision"
 21    empty_strings_allowed = False
 22
 23    def get_prep_value(self, value: Any) -> Any:
 24        value = super().get_prep_value(value)
 25        if value is None:
 26            return None
 27        try:
 28            return float(value)
 29        except (TypeError, ValueError) as e:
 30            raise e.__class__(
 31                f"Field '{self.name}' expected a number but got {value!r}.",
 32            ) from e
 33
 34    def to_python(self, value: Any) -> float | None:
 35        if value is None:
 36            return value
 37        try:
 38            return float(value)
 39        except (TypeError, ValueError):
 40            raise exceptions.ValidationError(
 41                '"%(value)s" value must be a float.',
 42                code="invalid",
 43                params={"value": value},
 44            )
 45
 46
 47class IntegerField(DefaultableField[int]):
 48    db_type_sql = "integer"
 49    integer_range: tuple[int, int] = (-2147483648, 2147483647)
 50    psycopg_type: type = numeric.Int4
 51    empty_strings_allowed = False
 52
 53    @cached_property
 54    def validators(self) -> list[Callable[..., Any]]:
 55        # These validators can't be added at field initialization time since
 56        # they're based on values retrieved from the database connection.
 57        validators_ = super().validators
 58        min_value, max_value = self.integer_range
 59        if min_value is not None and not any(
 60            (
 61                isinstance(validator, validators.MinValueValidator)
 62                and (
 63                    validator.limit_value()
 64                    if callable(validator.limit_value)
 65                    else validator.limit_value
 66                )
 67                >= min_value
 68            )
 69            for validator in validators_
 70        ):
 71            validators_.append(validators.MinValueValidator(min_value))
 72        if max_value is not None and not any(
 73            (
 74                isinstance(validator, validators.MaxValueValidator)
 75                and (
 76                    validator.limit_value()
 77                    if callable(validator.limit_value)
 78                    else validator.limit_value
 79                )
 80                <= max_value
 81            )
 82            for validator in validators_
 83        ):
 84            validators_.append(validators.MaxValueValidator(max_value))
 85        return validators_
 86
 87    def get_prep_value(self, value: Any) -> Any:
 88        value = super().get_prep_value(value)
 89        if value is None:
 90            return None
 91        try:
 92            return int(value)
 93        except (TypeError, ValueError) as e:
 94            raise e.__class__(
 95                f"Field '{self.name}' expected a number but got {value!r}.",
 96            ) from e
 97
 98    def get_db_prep_value(
 99        self, value: Any, connection: DatabaseConnection, prepared: bool = False
100    ) -> Any:
101        from plain.postgres.expressions import ResolvableExpression
102
103        value = super().get_db_prep_value(value, connection, prepared)
104        if value is None or isinstance(value, ResolvableExpression):
105            return value
106        return self.psycopg_type(value)
107
108    def to_python(self, value: Any) -> int | None:
109        if value is None:
110            return value
111        try:
112            return int(value)
113        except (TypeError, ValueError):
114            raise exceptions.ValidationError(
115                '"%(value)s" value must be an integer.',
116                code="invalid",
117                params={"value": value},
118            )
119
120
class BigIntegerField(IntegerField):
    """Signed integer field stored in a PostgreSQL ``bigint`` column."""

    psycopg_type = numeric.Int8
    # Inclusive bounds of a signed 64-bit integer.
    integer_range = (-(2**63), 2**63 - 1)
    db_type_sql = "bigint"
125
126
class SmallIntegerField(IntegerField):
    """Signed integer field stored in a PostgreSQL ``smallint`` column."""

    psycopg_type = numeric.Int2
    # Inclusive bounds of a signed 16-bit integer.
    integer_range = (-(2**15), 2**15 - 1)
    db_type_sql = "smallint"
131
132
class DecimalField(DefaultableField[decimal.Decimal]):
    """Fixed-point decimal field backed by a PostgreSQL ``numeric`` column.

    ``max_digits`` and ``decimal_places`` are referenced by the ``db_type_sql``
    template, passed to ``DecimalValidator``, and ``max_digits`` sets the
    precision of ``context``. Both are required; ``preflight`` reports them
    when missing or invalid.
    """

    db_type_sql = "numeric(%(max_digits)s,%(decimal_places)s)"
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        max_digits: int | None = None,
        decimal_places: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        default: Any = NOT_PROVIDED,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        self.max_digits, self.decimal_places = max_digits, decimal_places
        super().__init__(
            required=required,
            allow_null=allow_null,
            default=default,
            validators=validators,
        )

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Run the base checks plus the digit-configuration checks.

        The combined ``decimal_places <= max_digits`` check runs only when
        each attribute is individually valid, so one misconfiguration does not
        produce a cascade of results.
        """
        errors = super().preflight(**kwargs)

        digits_errors = [
            *self._check_decimal_places(),
            *self._check_max_digits(),
        ]
        if not digits_errors:
            errors.extend(self._check_decimal_places_and_max_digits())
        else:
            errors.extend(digits_errors)
        return errors

    def _check_decimal_places(self) -> list[PreflightResult]:
        """Report a missing ``decimal_places`` or one that is not an int >= 0."""
        if self.decimal_places is None:
            return [
                PreflightResult(
                    fix="DecimalFields must define a 'decimal_places' attribute.",
                    obj=self,
                    id="fields.decimalfield_missing_decimal_places",
                )
            ]
        try:
            decimal_places = int(self.decimal_places)
            if decimal_places < 0:
                raise ValueError()
        # TypeError covers values int() cannot handle at all (e.g. a list),
        # which should be reported rather than crash preflight.
        except (TypeError, ValueError):
            return [
                PreflightResult(
                    fix="'decimal_places' must be a non-negative integer.",
                    obj=self,
                    id="fields.decimalfield_invalid_decimal_places",
                )
            ]
        return []

    def _check_max_digits(self) -> list[PreflightResult]:
        """Report a missing ``max_digits`` or one that is not an int > 0."""
        if self.max_digits is None:
            return [
                PreflightResult(
                    fix="DecimalFields must define a 'max_digits' attribute.",
                    obj=self,
                    id="fields.decimalfield_missing_max_digits",
                )
            ]
        try:
            max_digits = int(self.max_digits)
            if max_digits <= 0:
                raise ValueError()
        # TypeError covers values int() cannot handle at all (e.g. a list).
        except (TypeError, ValueError):
            return [
                PreflightResult(
                    fix="'max_digits' must be a positive integer.",
                    obj=self,
                    id="fields.decimalfield_invalid_max_digits",
                )
            ]
        return []

    def _check_decimal_places_and_max_digits(self) -> list[PreflightResult]:
        """Report ``decimal_places`` exceeding ``max_digits``."""
        if self.decimal_places is None or self.max_digits is None:
            return []
        # Compare the int-coerced values so this agrees with the individual
        # checks above (which accept anything int() accepts); comparing raw
        # values could raise on mixed types or compare strings lexically.
        try:
            decimal_places = int(self.decimal_places)
            max_digits = int(self.max_digits)
        except (TypeError, ValueError):
            # Invalid values are reported by the individual checks.
            return []
        if decimal_places > max_digits:
            return [
                PreflightResult(
                    fix="'max_digits' must be greater or equal to 'decimal_places'.",
                    obj=self,
                    id="fields.decimalfield_decimal_places_exceeds_max_digits",
                )
            ]
        return []

    @cached_property
    def validators(self) -> list[Callable[..., Any]]:
        """Return the base validators plus a ``DecimalValidator`` for this field."""
        return super().validators + [
            validators.DecimalValidator(self.max_digits, self.decimal_places)
        ]

    @cached_property
    def context(self) -> decimal.Context:
        """Decimal context with precision ``max_digits``, used by ``to_python``
        when converting floats."""
        return decimal.Context(prec=self.max_digits)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with ``max_digits`` and
        ``decimal_places`` when they are set."""
        name, path, args, kwargs = super().deconstruct()
        if self.max_digits is not None:
            kwargs["max_digits"] = self.max_digits
        if self.decimal_places is not None:
            kwargs["decimal_places"] = self.decimal_places
        return name, path, args, kwargs

    def to_python(self, value: Any) -> decimal.Decimal | None:
        """Convert *value* to ``decimal.Decimal``.

        ``None`` is returned as-is. Floats are converted through ``context``,
        so they are rounded to ``max_digits`` significant digits; other values
        go through ``decimal.Decimal``. Raises ``ValidationError`` (code
        ``"invalid"``) for unparseable or non-finite (NaN/Infinity) values.
        """
        if value is None:
            return value
        try:
            if isinstance(value, float):
                decimal_value = self.context.create_decimal_from_float(value)
            else:
                decimal_value = decimal.Decimal(value)
        except (decimal.InvalidOperation, TypeError, ValueError):
            raise exceptions.ValidationError(
                '"%(value)s" value must be a decimal number.',
                code="invalid",
                params={"value": value},
            )
        if not decimal_value.is_finite():
            raise exceptions.ValidationError(
                '"%(value)s" value must be a decimal number.',
                code="invalid",
                params={"value": value},
            )
        return decimal_value

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Return the prepared ``Decimal``; no connection-specific adaptation
        is applied here."""
        if not prepared:
            value = self.get_prep_value(value)
        return value

    def get_prep_value(self, value: Any) -> Any:
        """Apply the base preparation, then normalize through ``to_python``."""
        value = super().get_prep_value(value)
        return self.to_python(value)