Plain is headed towards 1.0! Subscribe for development updates →

 1from __future__ import annotations
 2
 3import sys
 4from typing import TYPE_CHECKING, Any
 5
 6from plain.models.fields import DecimalField, FloatField, IntegerField
 7from plain.models.functions import Cast
 8
 9if TYPE_CHECKING:
10    from plain.models.backends.base.base import BaseDatabaseWrapper
11    from plain.models.sql.compiler import SQLCompiler
12
13
class FixDecimalInputMixin:
    """Work around PostgreSQL's missing float overloads for some functions.

    PostgreSQL doesn't support these function signatures:

    - ``LOG(double, double)``
    - ``MOD(double, double)``

    So, when compiling for PostgreSQL, every source expression whose
    ``output_field`` is a ``FloatField`` is wrapped in a ``Cast`` to a
    ``DecimalField``. All other source expressions are passed through as-is.
    """

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Return ``(sql, params)`` for a copy whose float inputs are cast to decimal.

        The rewritten source expressions are installed on ``self.copy()``
        rather than on ``self``. The copy's ``as_sql`` output is returned,
        with ``compiler``, ``connection`` and ``extra_context`` forwarded
        unchanged.
        """
        # A single cast target shared by every float input. decimal_places
        # follows sys.float_info.dig, the number of decimal digits a float
        # can represent faithfully. max_digits is fixed at 1000.
        decimal_target = DecimalField(
            decimal_places=sys.float_info.dig, max_digits=1000
        )
        patched = self.copy()

        new_sources = []
        for source in self.get_source_expressions():
            needs_cast = isinstance(source.output_field, FloatField)
            new_sources.append(Cast(source, decimal_target) if needs_cast else source)

        patched.set_source_expressions(new_sources)
        return patched.as_sql(compiler, connection, **extra_context)
36
37
class FixDurationInputMixin:
    """Coerce duration-typed results to a signed integer on MySQL."""

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Render with the next ``as_sql`` in the MRO, casting duration output.

        If this expression's output field reports ``"DurationField"`` as its
        internal type, the generated SQL is wrapped in ``CAST(... AS SIGNED)``.
        Otherwise the SQL is returned untouched. The parameters from the
        parent ``as_sql`` are always returned unchanged.
        """
        sql, params = super().as_sql(compiler, connection, **extra_context)  # type: ignore[misc]
        is_duration = self.output_field.get_internal_type() == "DurationField"  # type: ignore[attr-defined]
        return (f"CAST({sql} AS SIGNED)" if is_duration else sql), params
49
50
class NumericOutputFieldMixin:
    """Choose a numeric output field based on the types of the source fields."""

    def _resolve_output_field(self) -> DecimalField | FloatField:
        """Return the output field implied by ``self.get_source_fields()``.

        Precedence, first match wins:

        - any ``DecimalField`` source: a new ``DecimalField``;
        - any ``IntegerField`` source: a new ``FloatField``;
        - no source fields at all: a new ``FloatField``;
        - otherwise: the next ``_resolve_output_field`` in the MRO.
        """
        fields = self.get_source_fields()  # type: ignore[attr-defined]
        # (source type to look for, field class to return when found)
        for wanted, produced in ((DecimalField, DecimalField), (IntegerField, FloatField)):
            if any(isinstance(field, wanted) for field in fields):
                return produced()
        if not fields:
            return FloatField()
        return super()._resolve_output_field()  # type: ignore[misc]