Plain is headed towards 1.0! Subscribe for development updates →

 1import sys
 2
 3from plain.models.fields import DecimalField, FloatField, IntegerField
 4from plain.models.functions import Cast
 5
 6
 7class FixDecimalInputMixin:
 8    def as_postgresql(self, compiler, connection, **extra_context):
 9        # Cast FloatField to DecimalField as PostgreSQL doesn't support the
10        # following function signatures:
11        # - LOG(double, double)
12        # - MOD(double, double)
13        output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
14        clone = self.copy()
15        clone.set_source_expressions(
16            [
17                Cast(expression, output_field)
18                if isinstance(expression.output_field, FloatField)
19                else expression
20                for expression in self.get_source_expressions()
21            ]
22        )
23        return clone.as_sql(compiler, connection, **extra_context)
24
25
class FixDurationInputMixin:
    """Adjust the SQL of duration-valued expressions on MySQL."""

    def as_mysql(self, compiler, connection, **extra_context):
        """Compile via the next ``as_sql`` in the MRO, casting duration results.

        When this expression's output field reports the internal type
        ``"DurationField"``, the generated SQL is wrapped in
        ``CAST(... AS SIGNED)``. Otherwise it is returned as-is. Query
        parameters are always passed through untouched.

        Returns:
            A ``(sql, params)`` tuple.
        """
        compiled_sql, compiled_params = super().as_sql(
            compiler, connection, **extra_context
        )
        is_duration = self.output_field.get_internal_type() == "DurationField"
        final_sql = "CAST(%s AS SIGNED)" % compiled_sql if is_duration else compiled_sql
        return final_sql, compiled_params
32
33
class NumericOutputFieldMixin:
    """Infer a numeric output field from the expression's source fields."""

    def _resolve_output_field(self):
        """Choose the output field based on the types of the source fields.

        Resolution order:
        - no source fields at all -> ``FloatField()``
        - any ``DecimalField`` source -> ``DecimalField()``
        - otherwise any ``IntegerField`` source -> ``FloatField()``
        - otherwise defer to the parent class's ``_resolve_output_field``.
        """
        fields = self.get_source_fields()
        if not fields:
            return FloatField()
        if any(isinstance(field, DecimalField) for field in fields):
            return DecimalField()
        if any(isinstance(field, IntegerField) for field in fields):
            return FloatField()
        return super()._resolve_output_field()