1from __future__ import annotations
2
3import sys
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func
7from plain.models.fields import DecimalField, Field, FloatField, IntegerField
8from plain.models.functions import Cast
9
10if TYPE_CHECKING:
11 from plain.models.backends.base.base import BaseDatabaseWrapper
12 from plain.models.sql.compiler import SQLCompiler
13
14
class FixDecimalInputMixin(Func):
    """Mixin for Func subclasses that must feed PostgreSQL numeric, not float, arguments.

    On PostgreSQL, every source expression whose output field is a
    ``FloatField`` is wrapped in a ``Cast`` to a ``DecimalField`` before the
    function is compiled.
    """

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile this function for PostgreSQL with float arguments cast to decimal.

        PostgreSQL lacks these signatures, so float inputs are converted first:

        - ``LOG(double, double)``
        - ``MOD(double, double)``

        The casts are installed on a copy (``self.copy()``), and that copy is
        what gets compiled; ``self``'s own source expressions are not replaced.

        Args:
            compiler: The SQL compiler driving compilation.
            connection: The database connection wrapper.
            **extra_context: Forwarded unchanged to ``as_sql``.

        Returns:
            The ``(sql, params)`` pair produced by the copy's ``as_sql``.
        """
        # Scale = sys.float_info.dig, the number of decimal digits a Python
        # float can represent faithfully. One field instance is shared by all casts.
        decimal_target = DecimalField(
            decimal_places=sys.float_info.dig, max_digits=1000
        )

        rewritten = self.copy()

        converted_sources = []
        for source in self.get_source_expressions():
            if isinstance(source.output_field, FloatField):
                source = Cast(source, decimal_target)
            converted_sources.append(source)

        rewritten.set_source_expressions(converted_sources)
        return rewritten.as_sql(compiler, connection, **extra_context)
40
41
class FixDurationInputMixin(Func):
    """Mixin for Func subclasses whose MySQL SQL is cast to SIGNED for duration output."""

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for MySQL, casting the result when the output is a duration.

        The SQL is produced by the next ``as_sql`` in the MRO. It is then wrapped
        as ``CAST(<sql> AS SIGNED)`` only when this expression's output field
        reports the internal type ``"DurationField"``. Params are never altered.

        Args:
            compiler: The SQL compiler driving compilation.
            connection: The database connection wrapper.
            **extra_context: Forwarded unchanged to ``as_sql``.

        Returns:
            An ``(sql, params)`` pair.
        """
        # Calls the generic as_sql (not as_mysql) of the parent in the MRO.
        base_sql, base_params = super().as_sql(compiler, connection, **extra_context)

        # NOTE(review): presumably durations are integer-valued on MySQL, hence
        # the SIGNED cast — confirm against the backend's duration handling.
        wants_signed = self.output_field.get_internal_type() == "DurationField"
        if not wants_signed:
            return base_sql, base_params
        return f"CAST({base_sql} AS SIGNED)", base_params
53
54
class NumericOutputFieldMixin(Func):
    """Mixin that infers a numeric output field from the function's source fields.

    Resolution order:

    1. If any source field is a ``DecimalField``, use ``DecimalField()``.
    2. Otherwise, if any source field is an ``IntegerField``, use ``FloatField()``.
    3. Otherwise, defer to the parent class's ``_resolve_output_field``. Fall back
       to ``FloatField()`` when there are no source fields or when the parent
       returns a falsy result (e.g. ``None``).
    """

    def _resolve_output_field(self) -> DecimalField | FloatField | Field:
        fields = self.get_source_fields()

        if any(isinstance(field, DecimalField) for field in fields):
            return DecimalField()
        if any(isinstance(field, IntegerField) for field in fields):
            return FloatField()

        # The parent is only consulted when there is something to resolve from.
        inherited = super()._resolve_output_field() if fields else None
        return inherited or FloatField()