1from __future__ import annotations
2
3import sys
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func
7from plain.models.fields import DecimalField, Field, FloatField, IntegerField
8from plain.models.functions import Cast
9
10if TYPE_CHECKING:
11 from plain.models.postgres.wrapper import DatabaseWrapper
12 from plain.models.sql.compiler import SQLCompiler
13
14
class FixDecimalInputMixin(Func):
    """
    Mixin for Func subclasses that need to convert FloatField to DecimalField.

    PostgreSQL doesn't support the following function signatures:
    - LOG(double, double)
    - MOD(double, double)

    Before compiling, every source expression whose output field is a
    FloatField is wrapped in a Cast to a DecimalField. The function is then
    called with decimal arguments instead of double ones.
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile this function, casting FloatField arguments to DecimalField.

        Args:
            compiler: The SQL compiler driving compilation.
            connection: The database connection wrapper.
            function: Optional override for the SQL function name.
            template: Optional override for the SQL template.
            arg_joiner: Optional override for the argument separator.
            **extra_context: Additional template context.

        All overrides are forwarded unchanged to the parent ``as_sql``.

        Returns:
            The ``(sql, params)`` pair produced by the parent ``as_sql``.
        """
        # sys.float_info.dig is the number of decimal digits a float can
        # represent faithfully, so the cast keeps the float's precision.
        # The large max_digits presumably leaves room for the integer part.
        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)

        # Rewrite the sources on a copy so this expression is not mutated.
        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, output_field)
                if isinstance(expression.output_field, FloatField)
                else expression
                for expression in self.get_source_expressions()
            ]
        )
        # Two-argument super() is needed because the call is bound to the
        # clone, not to self. Forward the caller's function/template/
        # arg_joiner overrides so they are honored rather than dropped.
        return super(FixDecimalInputMixin, clone).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )
47
48
class NumericOutputFieldMixin(Func):
    """
    Mixin for Func subclasses whose result is numeric.

    The output field is picked from the source fields in priority order:

    - any DecimalField source yields a DecimalField;
    - otherwise, any IntegerField source yields a FloatField;
    - otherwise, the parent class's resolution is used.

    FloatField is the fallback when there are no source fields or when the
    parent resolution is falsy (e.g. None).
    """

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        fields = self.get_source_fields()

        # Ordered (source type -> result type) table; the first match wins.
        for source_type, result_type in (
            (DecimalField, DecimalField),
            (IntegerField, FloatField),
        ):
            if any(isinstance(field, source_type) for field in fields):
                return result_type()

        inherited = super()._resolve_output_field() if fields else None
        return inherited or FloatField()