1from __future__ import annotations
2
3import sys
4from typing import TYPE_CHECKING, Any
5
6from plain.models.fields import DecimalField, FloatField, IntegerField
7from plain.models.functions import Cast
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11 from plain.models.sql.compiler import SQLCompiler
12
13
class FixDecimalInputMixin:
    """Mixin that coerces float arguments to decimals when compiling for PostgreSQL."""

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile a copy of this expression whose float inputs are cast to decimal.

        PostgreSQL offers no ``LOG(double, double)`` or ``MOD(double, double)``
        signature. Any source expression whose ``output_field`` is a
        ``FloatField`` is therefore wrapped in a ``Cast`` to a ``DecimalField``.
        Other source expressions pass through untouched. The original
        expression is not modified; a copy is compiled instead.

        Returns:
            The ``(sql, params)`` pair produced by the copy's ``as_sql``.
        """
        # A single target field is shared by every cast. Its scale is the
        # number of decimal digits a float can faithfully hold
        # (sys.float_info.dig).
        numeric_target = DecimalField(
            decimal_places=sys.float_info.dig,
            max_digits=1000,
        )
        patched = self.copy()

        coerced_sources = []
        for source in self.get_source_expressions():
            if isinstance(source.output_field, FloatField):
                source = Cast(source, numeric_target)
            coerced_sources.append(source)

        patched.set_source_expressions(coerced_sources)
        return patched.as_sql(compiler, connection, **extra_context)
36
37
class FixDurationInputMixin:
    """Mixin that casts duration-typed results to a signed integer on MySQL."""

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile via the next class's ``as_sql`` and cast duration results.

        When ``self.output_field`` reports the internal type
        ``"DurationField"``, the compiled SQL is wrapped in
        ``CAST(... AS SIGNED)``. In every case the parameters are returned
        unchanged. (Presumably this makes MySQL yield an integer duration
        value; confirm against the MySQL backend.)
        """
        sql, params = super().as_sql(compiler, connection, **extra_context)  # type: ignore[misc]
        internal_type = self.output_field.get_internal_type()  # type: ignore[attr-defined]
        if internal_type != "DurationField":
            return sql, params
        return f"CAST({sql} AS SIGNED)", params
49
50
class NumericOutputFieldMixin:
    """Mixin that infers a numeric output field from the source fields."""

    def _resolve_output_field(self) -> DecimalField | FloatField:
        """Choose the output field from the types of ``get_source_fields()``.

        The rules are applied in this order:

        * no source fields -> ``FloatField()``
        * any ``DecimalField`` source -> ``DecimalField()``
        * any ``IntegerField`` source -> ``FloatField()``
        * otherwise defer to the next class's ``_resolve_output_field``.
        """
        fields = self.get_source_fields()  # type: ignore[attr-defined]
        if not fields:
            return FloatField()
        if any(isinstance(field, DecimalField) for field in fields):
            return DecimalField()
        if any(isinstance(field, IntegerField) for field in fields):
            return FloatField()
        return super()._resolve_output_field()  # type: ignore[misc]