1from __future__ import annotations
 2
 3import sys
 4from typing import TYPE_CHECKING, Any
 5
 6from plain.models.expressions import Func
 7from plain.models.fields import DecimalField, Field, FloatField, IntegerField
 8from plain.models.functions import Cast
 9
10if TYPE_CHECKING:
11    from plain.models.backends.base.base import BaseDatabaseWrapper
12    from plain.models.sql.compiler import SQLCompiler
13
14
class FixDecimalInputMixin(Func):
    """Func mixin that casts FloatField inputs to DecimalField when compiling for PostgreSQL."""

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile a copy of this expression with its float inputs cast to decimal.

        Every source expression whose ``output_field`` is a ``FloatField`` is
        wrapped in a ``Cast`` to a ``DecimalField``; all other source
        expressions are passed through unchanged. The original expression is
        not modified: the substitution is applied to a copy, and the result
        of that copy's ``as_sql()`` is returned.
        """
        # PostgreSQL has no LOG(double, double) or MOD(double, double)
        # signature, so float arguments must reach those functions as
        # decimals instead.
        decimal_field = DecimalField(
            max_digits=1000,
            decimal_places=sys.float_info.dig,
        )

        patched = self.copy()

        converted = []
        for source in self.get_source_expressions():
            if isinstance(source.output_field, FloatField):
                source = Cast(source, decimal_field)
            converted.append(source)

        patched.set_source_expressions(converted)
        return patched.as_sql(compiler, connection, **extra_context)
40
41
class FixDurationInputMixin(Func):
    """Func mixin that wraps DurationField results in ``CAST(... AS SIGNED)`` for MySQL."""

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile via the parent ``as_sql()`` and cast duration output to SIGNED.

        Returns the parent's ``(sql, params)`` pair untouched unless this
        expression's output field reports the internal type
        ``"DurationField"``; in that case the SQL is wrapped in a
        ``CAST(... AS SIGNED)`` and the params are returned as-is.
        """
        base_sql, params = super().as_sql(compiler, connection, **extra_context)

        # Anything that is not a duration needs no special treatment.
        if self.output_field.get_internal_type() != "DurationField":
            return base_sql, params

        return f"CAST({base_sql} AS SIGNED)", params
53
54
class NumericOutputFieldMixin(Func):
    """Func mixin that derives a numeric output field from the source fields."""

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        """Choose the output field based on the types of the source fields.

        Resolution order:

        1. Any ``DecimalField`` among the source fields -> ``DecimalField()``.
        2. Otherwise, any ``IntegerField`` among them -> ``FloatField()``.
        3. Otherwise, if there are source fields, whatever the parent
           ``_resolve_output_field()`` returns, provided it is truthy.
        4. In every remaining case (no source fields, or a falsy parent
           result) -> ``FloatField()``.
        """
        fields = self.get_source_fields()

        # A decimal input takes precedence over everything else.
        for candidate in fields:
            if isinstance(candidate, DecimalField):
                return DecimalField()

        # Integer inputs yield a float output.
        for candidate in fields:
            if isinstance(candidate, IntegerField):
                return FloatField()

        if not fields:
            return FloatField()

        # Defer to the parent resolution; fall back to a float field when
        # it produces nothing usable.
        return super()._resolve_output_field() or FloatField()