1from __future__ import annotations
 2
 3import sys
 4from typing import TYPE_CHECKING, Any
 5
 6from plain.models.expressions import Func
 7from plain.models.fields import DecimalField, Field, FloatField, IntegerField
 8from plain.models.functions import Cast
 9
10if TYPE_CHECKING:
11    from plain.models.postgres.wrapper import DatabaseWrapper
12    from plain.models.sql.compiler import SQLCompiler
13
14
class FixDecimalInputMixin(Func):
    """
    Mixin for Func subclasses that need to convert FloatField to DecimalField.

    PostgreSQL doesn't support the following function signatures:
    - LOG(double, double)
    - MOD(double, double)

    Before SQL is generated, every source expression whose output field is a
    FloatField is wrapped in a Cast to a DecimalField. The function is then
    called with numeric arguments instead of double precision ones.
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile this function to SQL with float arguments cast to decimal.

        The casts are applied to a copy obtained from ``self.copy()``. The
        receiver's own source expressions are not replaced.

        Args:
            compiler: The SQL compiler driving compilation.
            connection: The database connection wrapper.
            function: Optional override of the SQL function name; forwarded
                to the parent ``as_sql``.
            template: Optional override of the SQL template; forwarded to
                the parent ``as_sql``.
            arg_joiner: Optional override of the argument separator;
                forwarded to the parent ``as_sql``.
            **extra_context: Extra template context, forwarded unchanged.

        Returns:
            The ``(sql, params)`` pair produced by the parent ``as_sql``.
        """
        # Keep as many fractional digits as a Python float can faithfully
        # represent (sys.float_info.dig), so the cast does not drop precision.
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)

        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, output_field)
                if isinstance(expression.output_field, FloatField)
                else expression
                for expression in self.get_source_expressions()
            ]
        )
        # The explicit two-argument super() dispatches to the next as_sql in
        # the MRO, but bound to the clone rather than to self. The caller's
        # function/template/arg_joiner overrides are passed along; previously
        # they were accepted here and silently dropped.
        return super(FixDecimalInputMixin, clone).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )
47
48
class NumericOutputFieldMixin(Func):
    """
    Mixin for Func subclasses whose output field is inferred as numeric.

    Resolution order, based on the source fields:
    - no source fields: FloatField
    - any DecimalField source: DecimalField
    - otherwise, any IntegerField source: FloatField
    - otherwise: whatever the parent resolution returns, or FloatField if
      that result is falsy
    """

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        fields = self.get_source_fields()

        # With no sources there is nothing to inspect or delegate to.
        if not fields:
            return FloatField()

        # Checked in order: a decimal source takes priority over an integer one.
        for source_type, result_type in (
            (DecimalField, DecimalField),
            (IntegerField, FloatField),
        ):
            if any(isinstance(field, source_type) for field in fields):
                return result_type()

        inherited = super()._resolve_output_field()
        return inherited if inherited else FloatField()