1from __future__ import annotations
2
3import decimal
4from collections.abc import Callable, Sequence
5from functools import cached_property
6from typing import TYPE_CHECKING, Any
7
8from psycopg.types import numeric
9
10from plain import exceptions, validators
11from plain.preflight import PreflightResult
12
13from .base import NOT_PROVIDED, DefaultableField
14
15if TYPE_CHECKING:
16 from plain.postgres.connection import DatabaseConnection
17
18
class FloatField(DefaultableField[float]):
    """Field holding a Python ``float``; its column type is ``double precision``."""

    db_type_sql = "double precision"
    empty_strings_allowed = False

    def get_prep_value(self, value: Any) -> Any:
        """Coerce ``value`` to ``float`` after the parent's preparation step.

        Returns:
            ``None`` when the prepared value is ``None``, otherwise
            ``float(value)``.

        Raises:
            TypeError, ValueError, OverflowError: when ``float()`` rejects the
                value. The original exception type is preserved; only the
                message is replaced with one naming this field.
        """
        value = super().get_prep_value(value)
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as e:
            # Re-raise the same exception class so existing ``except`` clauses
            # in callers keep matching, but with the field name in the message.
            raise e.__class__(
                f"Field '{self.name}' expected a number but got {value!r}.",
            ) from e

    def to_python(self, value: Any) -> float | None:
        """Convert ``value`` to ``float``, passing ``None`` through unchanged.

        Raises:
            exceptions.ValidationError: with code ``"invalid"`` when the value
                cannot be converted by ``float()``.
        """
        if value is None:
            return value
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers integers too large to be represented as a
            # float (e.g. 10**400); report them like any other invalid input.
            raise exceptions.ValidationError(
                '"%(value)s" value must be a float.',
                code="invalid",
                params={"value": value},
            )
45
46
class IntegerField(DefaultableField[int]):
    """Integer field using a 32-bit signed range and the ``integer`` column type.

    Subclasses select a different width by overriding ``db_type_sql``,
    ``integer_range`` and ``psycopg_type``.
    """

    db_type_sql = "integer"
    # Inclusive (min, max) bounds enforced by the implicit range validators.
    integer_range: tuple[int, int] = (-2147483648, 2147483647)
    # Wrapper applied to outgoing values in get_db_prep_value().
    psycopg_type: type = numeric.Int4
    empty_strings_allowed = False

    @cached_property
    def validators(self) -> list[Callable[..., Any]]:
        """Return the parent's validators plus range validators for ``integer_range``.

        A ``MinValueValidator`` / ``MaxValueValidator`` for the field's range
        is appended unless the list already holds one of that type whose limit
        is at least as strict as the range bound.
        """

        def resolve_limit(validator: Any) -> Any:
            # ``limit_value`` may be a plain value or a zero-argument callable.
            limit = validator.limit_value
            return limit() if callable(limit) else limit

        # The range validators are derived from the class-level
        # ``integer_range`` and added here, on first access.
        validators_ = super().validators
        min_value, max_value = self.integer_range

        has_stricter_min = any(
            isinstance(validator, validators.MinValueValidator)
            and resolve_limit(validator) >= min_value
            for validator in validators_
        )
        if min_value is not None and not has_stricter_min:
            validators_.append(validators.MinValueValidator(min_value))

        has_stricter_max = any(
            isinstance(validator, validators.MaxValueValidator)
            and resolve_limit(validator) <= max_value
            for validator in validators_
        )
        if max_value is not None and not has_stricter_max:
            validators_.append(validators.MaxValueValidator(max_value))
        return validators_

    def get_prep_value(self, value: Any) -> Any:
        """Coerce ``value`` to ``int`` after the parent's preparation step.

        Returns:
            ``None`` when the prepared value is ``None``, otherwise
            ``int(value)``.

        Raises:
            TypeError, ValueError, OverflowError: when ``int()`` rejects the
                value. The original exception type is preserved; only the
                message is replaced with one naming this field.
        """
        value = super().get_prep_value(value)
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as e:
            # Same exception class, clearer message that names the field.
            raise e.__class__(
                f"Field '{self.name}' expected a number but got {value!r}.",
            ) from e

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` for the database and wrap it in ``psycopg_type``.

        ``None`` and ``ResolvableExpression`` instances are returned as the
        parent produced them, without wrapping.
        """
        # Imported locally rather than at module level — presumably to avoid
        # a circular import; confirm before hoisting.
        from plain.postgres.expressions import ResolvableExpression

        value = super().get_db_prep_value(value, connection, prepared)
        if value is None or isinstance(value, ResolvableExpression):
            return value
        return self.psycopg_type(value)

    def to_python(self, value: Any) -> int | None:
        """Convert ``value`` to ``int``, passing ``None`` through unchanged.

        Raises:
            exceptions.ValidationError: with code ``"invalid"`` when the value
                cannot be converted by ``int()``.
        """
        if value is None:
            return value
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            # int(float("inf")) raises OverflowError while int(float("nan"))
            # raises ValueError; treat both as invalid input.
            raise exceptions.ValidationError(
                '"%(value)s" value must be an integer.',
                code="invalid",
                params={"value": value},
            )
119
120
class BigIntegerField(IntegerField):
    """``IntegerField`` variant using the ``bigint`` column type."""

    db_type_sql = "bigint"
    # Signed 64-bit bounds: -9223372036854775808 .. 9223372036854775807.
    integer_range = (-(2**63), 2**63 - 1)
    psycopg_type = numeric.Int8
125
126
class SmallIntegerField(IntegerField):
    """``IntegerField`` variant using the ``smallint`` column type."""

    db_type_sql = "smallint"
    # Signed 16-bit bounds: -32768 .. 32767.
    integer_range = (-(2**15), 2**15 - 1)
    psycopg_type = numeric.Int2
131
132
class DecimalField(DefaultableField[decimal.Decimal]):
    """Field holding a ``decimal.Decimal`` with a fixed number of digits.

    ``max_digits`` and ``decimal_places`` are interpolated into the
    ``numeric(max_digits, decimal_places)`` column type template and are
    checked by ``preflight()``.
    """

    db_type_sql = "numeric(%(max_digits)s,%(decimal_places)s)"
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        max_digits: int | None = None,
        decimal_places: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        default: Any = NOT_PROVIDED,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        """Store the precision settings, then defer to the parent initializer.

        Args:
            max_digits: Total number of digits allowed. ``preflight()``
                reports an error if this is left as ``None``.
            decimal_places: Number of digits after the decimal point.
                ``preflight()`` reports an error if this is left as ``None``.
            required, allow_null, default, validators: Forwarded unchanged to
                the parent class.
        """
        self.max_digits, self.decimal_places = max_digits, decimal_places
        super().__init__(
            required=required,
            allow_null=allow_null,
            default=default,
            validators=validators,
        )

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Run the parent's checks plus the precision checks for this field.

        The combined ``decimal_places`` vs ``max_digits`` comparison only runs
        when both individual checks passed, so a single misconfiguration is
        not reported twice.
        """
        errors = super().preflight(**kwargs)

        digits_errors = [
            *self._check_decimal_places(),
            *self._check_max_digits(),
        ]
        if digits_errors:
            errors.extend(digits_errors)
        else:
            errors.extend(self._check_decimal_places_and_max_digits())
        return errors

    def _check_decimal_places(self) -> list[PreflightResult]:
        """Report a missing, non-integer, or negative ``decimal_places``."""
        if self.decimal_places is None:
            return [
                PreflightResult(
                    fix="DecimalFields must define a 'decimal_places' attribute.",
                    obj=self,
                    id="fields.decimalfield_missing_decimal_places",
                )
            ]
        try:
            if int(self.decimal_places) < 0:
                raise ValueError()
        except (TypeError, ValueError):
            # TypeError: int() was given a type it cannot convert (e.g. a
            # list). Report it as invalid instead of crashing preflight.
            return [
                PreflightResult(
                    fix="'decimal_places' must be a non-negative integer.",
                    obj=self,
                    id="fields.decimalfield_invalid_decimal_places",
                )
            ]
        return []

    def _check_max_digits(self) -> list[PreflightResult]:
        """Report a missing, non-integer, or non-positive ``max_digits``."""
        if self.max_digits is None:
            return [
                PreflightResult(
                    fix="DecimalFields must define a 'max_digits' attribute.",
                    obj=self,
                    id="fields.decimalfield_missing_max_digits",
                )
            ]
        try:
            if int(self.max_digits) <= 0:
                raise ValueError()
        except (TypeError, ValueError):
            # TypeError: int() was given a type it cannot convert (e.g. a
            # list). Report it as invalid instead of crashing preflight.
            return [
                PreflightResult(
                    fix="'max_digits' must be a positive integer.",
                    obj=self,
                    id="fields.decimalfield_invalid_max_digits",
                )
            ]
        return []

    def _check_decimal_places_and_max_digits(self) -> list[PreflightResult]:
        """Report ``decimal_places`` exceeding ``max_digits``.

        Returns an empty list when either value is missing or not convertible
        to ``int``; those cases are reported by the individual checks.
        """
        if self.decimal_places is None or self.max_digits is None:
            return []
        try:
            # Compare the integer values, matching how the individual checks
            # interpret them. Comparing the raw attributes would order
            # int-like strings lexicographically ("2" > "10").
            decimal_places = int(self.decimal_places)
            max_digits = int(self.max_digits)
        except (TypeError, ValueError):
            return []
        if decimal_places > max_digits:
            return [
                PreflightResult(
                    fix="'max_digits' must be greater or equal to 'decimal_places'.",
                    obj=self,
                    id="fields.decimalfield_decimal_places_exceeds_max_digits",
                )
            ]
        return []

    @cached_property
    def validators(self) -> list[Callable[..., Any]]:
        """Return the parent's validators followed by a ``DecimalValidator``.

        The ``DecimalValidator`` is built from this field's ``max_digits`` and
        ``decimal_places``.
        """
        return super().validators + [
            validators.DecimalValidator(self.max_digits, self.decimal_places)
        ]

    @cached_property
    def context(self) -> decimal.Context:
        """Decimal context with precision ``max_digits``.

        ``to_python()`` uses it to convert floats.
        """
        return decimal.Context(prec=self.max_digits)

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the parent's deconstruction with the precision keyword arguments.

        ``max_digits`` and ``decimal_places`` are only included when set.
        """
        name, path, args, kwargs = super().deconstruct()
        if self.max_digits is not None:
            kwargs["max_digits"] = self.max_digits
        if self.decimal_places is not None:
            kwargs["decimal_places"] = self.decimal_places
        return name, path, args, kwargs

    def to_python(self, value: Any) -> decimal.Decimal | None:
        """Convert ``value`` to ``decimal.Decimal``, passing ``None`` through.

        Floats are converted through ``self.context`` (precision
        ``max_digits``); every other input goes to ``decimal.Decimal()``
        directly.

        Raises:
            exceptions.ValidationError: with code ``"invalid"`` when the value
                cannot be converted, or when the result is not finite
                (NaN or infinity).
        """
        if value is None:
            return value
        try:
            if isinstance(value, float):
                decimal_value = self.context.create_decimal_from_float(value)
            else:
                decimal_value = decimal.Decimal(value)
        except (decimal.InvalidOperation, TypeError, ValueError):
            raise exceptions.ValidationError(
                '"%(value)s" value must be a decimal number.',
                code="invalid",
                params={"value": value},
            )
        # Decimal accepts "NaN" and "Infinity"; reject those explicitly.
        if not decimal_value.is_finite():
            raise exceptions.ValidationError(
                '"%(value)s" value must be a decimal number.',
                code="invalid",
                params={"value": value},
            )
        return decimal_value

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Return ``value`` ready for the database.

        When ``prepared`` is false the value is first run through
        ``get_prep_value()``; otherwise it is returned untouched.
        ``connection`` is not used here.
        """
        if not prepared:
            value = self.get_prep_value(value)
        return value

    def get_prep_value(self, value: Any) -> Any:
        """Apply the parent's preparation, then convert with ``to_python()``.

        Because conversion goes through ``to_python()``, an invalid value
        raises ``exceptions.ValidationError``.
        """
        value = super().get_prep_value(value)
        return self.to_python(value)