Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.expressions import Func, Value
  6from plain.models.fields import Field, FloatField, IntegerField
  7from plain.models.functions import Cast
  8from plain.models.functions.mixins import (
  9    FixDecimalInputMixin,
 10    NumericOutputFieldMixin,
 11)
 12from plain.models.lookups import Transform
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.base.base import BaseDatabaseWrapper
 16    from plain.models.sql.compiler import SQLCompiler
 17
 18
 19class Abs(Transform):
 20    function = "ABS"
 21    lookup_name = "abs"
 22
 23
 24class ACos(NumericOutputFieldMixin, Transform):
 25    function = "ACOS"
 26    lookup_name = "acos"
 27
 28
 29class ASin(NumericOutputFieldMixin, Transform):
 30    function = "ASIN"
 31    lookup_name = "asin"
 32
 33
 34class ATan(NumericOutputFieldMixin, Transform):
 35    function = "ATAN"
 36    lookup_name = "atan"
 37
 38
 39class ATan2(NumericOutputFieldMixin, Func):
 40    function = "ATAN2"
 41    arity = 2
 42
 43    def as_sqlite(
 44        self,
 45        compiler: SQLCompiler,
 46        connection: BaseDatabaseWrapper,
 47        **extra_context: Any,
 48    ) -> tuple[str, tuple[Any, ...]]:
 49        if not getattr(
 50            connection.ops, "spatialite", False
 51        ) or connection.ops.spatial_version >= (5, 0, 0):  # type: ignore[attr-defined]
 52            return self.as_sql(compiler, connection)
 53        # This function is usually ATan2(y, x), returning the inverse tangent
 54        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
 55        # Cast integers to float to avoid inconsistent/buggy behavior if the
 56        # arguments are mixed between integer and float or decimal.
 57        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
 58        clone = self.copy()
 59        clone.set_source_expressions(
 60            [
 61                Cast(expression, FloatField())
 62                if isinstance(expression.output_field, IntegerField)
 63                else expression
 64                for expression in self.get_source_expressions()[::-1]
 65            ]
 66        )
 67        return clone.as_sql(compiler, connection, **extra_context)
 68
 69
 70class Ceil(Transform):
 71    function = "CEILING"
 72    lookup_name = "ceil"
 73
 74
 75class Cos(NumericOutputFieldMixin, Transform):
 76    function = "COS"
 77    lookup_name = "cos"
 78
 79
 80class Cot(NumericOutputFieldMixin, Transform):
 81    function = "COT"
 82    lookup_name = "cot"
 83
 84
 85class Degrees(NumericOutputFieldMixin, Transform):
 86    function = "DEGREES"
 87    lookup_name = "degrees"
 88
 89
 90class Exp(NumericOutputFieldMixin, Transform):
 91    function = "EXP"
 92    lookup_name = "exp"
 93
 94
 95class Floor(Transform):
 96    function = "FLOOR"
 97    lookup_name = "floor"
 98
 99
class Ln(NumericOutputFieldMixin, Transform):
    """Natural logarithm of one expression, rendered as SQL ``LN`` (lookup ``ln``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "ln"
    function = "LN"
103
104
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Logarithm with an explicit base, rendered as SQL ``LOG``.

    ``Log(b, x)`` is the logarithm of ``x`` to the base ``b``. Takes exactly
    two expressions. On SpatiaLite the argument order is swapped at compile
    time (see ``as_sqlite``).
    """

    function = "LOG"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile for SQLite, reversing the arguments under SpatiaLite.

        ``extra_context`` is forwarded to ``as_sql`` on both paths so that
        overrides such as ``template`` behave the same with or without
        SpatiaLite.
        """
        if not getattr(connection.ops, "spatialite", False):
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b).
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)
122
123
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """SQL ``MOD`` applied to exactly two expressions."""

    arity = 2
    function = "MOD"
127
128
class Pi(NumericOutputFieldMixin, Func):
    """The SQL ``PI()`` constant; takes no arguments."""

    arity = 0
    function = "PI"
132
133
class Power(NumericOutputFieldMixin, Func):
    """SQL ``POWER`` applied to exactly two expressions (base, exponent)."""

    arity = 2
    function = "POWER"
137
138
class Radians(NumericOutputFieldMixin, Transform):
    """Degrees-to-radians conversion, rendered as SQL ``RADIANS`` (lookup ``radians``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "radians"
    function = "RADIANS"
142
143
class Random(NumericOutputFieldMixin, Func):
    """Zero-argument random-number function.

    Rendered as ``RANDOM()`` by default, and as ``RAND()`` on the MySQL and
    SQLite backends.
    """

    arity = 0
    function = "RANDOM"

    def _as_rand(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Shared by the backends below that spell the function RAND().
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return self._as_rand(compiler, connection, **extra_context)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return self._as_rand(compiler, connection, **extra_context)

    def get_group_by_cols(self) -> list[Any]:
        # Contributes no GROUP BY columns (presumably so a per-row random
        # value does not split groups).
        no_columns: list[Any] = []
        return no_columns
166
167
class Round(FixDecimalInputMixin, Transform):
    """SQL ``ROUND`` of an expression to ``precision`` decimal places.

    Registered as the ``round`` lookup. Unlike a plain ``Transform`` it
    takes a second source expression (the precision), and its output field
    is the same as its first argument's.
    """

    lookup_name = "round"
    function = "ROUND"
    # Transform declares arity=1; lift it so the precision can be passed.
    arity = None

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        super().__init__(expression, precision, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile for SQLite, rejecting a literal negative precision.

        Raises:
            ValueError: if the precision is a ``Value`` below zero.
        """
        sources = self.get_source_expressions()
        precision_expr = sources[1]
        is_negative_literal = (
            isinstance(precision_expr, Value) and precision_expr.value < 0
        )
        if is_negative_literal:
            raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self) -> Field:
        # Rounding keeps the type of the value being rounded.
        first_source = self.get_source_expressions()[0]
        return first_source.output_field
190
191
class Sign(Transform):
    """Sign of one expression, rendered as SQL ``SIGN`` (lookup ``sign``)."""

    lookup_name = "sign"
    function = "SIGN"
195
196
class Sin(NumericOutputFieldMixin, Transform):
    """Sine of one expression, rendered as SQL ``SIN`` (lookup ``sin``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sin"
    function = "SIN"
200
201
class Sqrt(NumericOutputFieldMixin, Transform):
    """Square root of one expression, rendered as SQL ``SQRT`` (lookup ``sqrt``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sqrt"
    function = "SQRT"
205
206
class Tan(NumericOutputFieldMixin, Transform):
    """Tangent of one expression, rendered as SQL ``TAN`` (lookup ``tan``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "tan"
    function = "TAN"