Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.expressions import Func, Value
  7from plain.models.fields import Field, FloatField, IntegerField
  8from plain.models.functions import Cast
  9from plain.models.functions.mixins import (
 10    FixDecimalInputMixin,
 11    NumericOutputFieldMixin,
 12)
 13from plain.models.lookups import Transform
 14
 15if TYPE_CHECKING:
 16    from plain.models.backends.base.base import BaseDatabaseWrapper
 17    from plain.models.sql.compiler import SQLCompiler
 18
 19
 20class Abs(Transform):
 21    function = "ABS"
 22    lookup_name = "abs"
 23
 24
 25class ACos(NumericOutputFieldMixin, Transform):
 26    function = "ACOS"
 27    lookup_name = "acos"
 28
 29
 30class ASin(NumericOutputFieldMixin, Transform):
 31    function = "ASIN"
 32    lookup_name = "asin"
 33
 34
 35class ATan(NumericOutputFieldMixin, Transform):
 36    function = "ATAN"
 37    lookup_name = "atan"
 38
 39
 40class ATan2(NumericOutputFieldMixin, Func):
 41    function = "ATAN2"
 42    arity = 2
 43
 44    def as_sqlite(
 45        self,
 46        compiler: SQLCompiler,
 47        connection: BaseDatabaseWrapper,
 48        **extra_context: Any,
 49    ) -> tuple[str, list[Any]]:
 50        if not getattr(connection.ops, "spatialite", False) or getattr(
 51            connection.ops, "spatial_version", (0, 0, 0)
 52        ) >= (5, 0, 0):
 53            return self.as_sql(compiler, connection)
 54        # This function is usually ATan2(y, x), returning the inverse tangent
 55        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
 56        # Cast integers to float to avoid inconsistent/buggy behavior if the
 57        # arguments are mixed between integer and float or decimal.
 58        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
 59        clone = self.copy()
 60        clone.set_source_expressions(
 61            [
 62                Cast(expression, FloatField())
 63                if isinstance(expression.output_field, IntegerField)
 64                else expression
 65                for expression in self.get_source_expressions()[::-1]
 66            ]
 67        )
 68        return clone.as_sql(compiler, connection, **extra_context)
 69
 70
 71class Ceil(Transform):
 72    function = "CEILING"
 73    lookup_name = "ceil"
 74
 75
 76class Cos(NumericOutputFieldMixin, Transform):
 77    function = "COS"
 78    lookup_name = "cos"
 79
 80
 81class Cot(NumericOutputFieldMixin, Transform):
 82    function = "COT"
 83    lookup_name = "cot"
 84
 85
 86class Degrees(NumericOutputFieldMixin, Transform):
 87    function = "DEGREES"
 88    lookup_name = "degrees"
 89
 90
 91class Exp(NumericOutputFieldMixin, Transform):
 92    function = "EXP"
 93    lookup_name = "exp"
 94
 95
 96class Floor(Transform):
 97    function = "FLOOR"
 98    lookup_name = "floor"
 99
100
class Ln(NumericOutputFieldMixin, Transform):
    """Natural logarithm, compiled as SQL ``LN``; lookup name ``ln``.

    The output field is chosen by ``NumericOutputFieldMixin``.
    """

    lookup_name = "ln"
    function = "LN"
104
105
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Logarithm with an explicit base, compiled as SQL ``LOG``.

    Called as ``Log(base, x)`` with exactly two expressions (``arity = 2``).
    Decimal inputs are handled by ``FixDecimalInputMixin``. The output field
    is chosen by ``NumericOutputFieldMixin``.
    """

    function = "LOG"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for SQLite, swapping arguments when SpatiaLite is in use.

        ``extra_context`` is forwarded to ``as_sql`` on both paths.
        """
        if not getattr(connection.ops, "spatialite", False):
            # Forward extra_context just like the SpatiaLite path below, so
            # caller overrides (e.g. template/function) are not silently lost.
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b).
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)
123
124
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Remainder of the first argument divided by the second (SQL ``MOD``).

    Takes exactly two expressions. Decimal inputs are handled by
    ``FixDecimalInputMixin``. The output field is chosen by
    ``NumericOutputFieldMixin``.
    """

    arity = 2
    function = "MOD"
128
129
class Pi(NumericOutputFieldMixin, Func):
    """The constant pi, compiled as SQL ``PI()``; takes no arguments."""

    arity = 0
    function = "PI"
133
134
class Power(NumericOutputFieldMixin, Func):
    """First argument raised to the power of the second (SQL ``POWER``).

    Takes exactly two expressions. The output field is chosen by
    ``NumericOutputFieldMixin``.
    """

    arity = 2
    function = "POWER"
138
139
class Radians(NumericOutputFieldMixin, Transform):
    """Degrees-to-radians conversion, compiled as SQL ``RADIANS``.

    Lookup name ``radians``. The output field is chosen by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "radians"
    function = "RADIANS"
143
144
class Random(NumericOutputFieldMixin, Func):
    """A random value, compiled as SQL ``RANDOM()`` by default.

    MySQL and SQLite compile it as ``RAND()`` instead. It contributes no
    GROUP BY columns.
    """

    arity = 0
    function = "RANDOM"

    def _compile_as_rand(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile via the parent ``as_sql`` with the function name forced to ``RAND``."""
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return self._compile_as_rand(compiler, connection, **extra_context)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # NOTE(review): emits RAND() rather than the default RANDOM() on
        # SQLite; presumably a RAND function is registered by the SQLite
        # backend elsewhere -- confirm.
        return self._compile_as_rand(compiler, connection, **extra_context)

    def get_group_by_cols(self) -> list[Any]:
        # Contributes nothing to GROUP BY.
        return []
167
168
class Round(FixDecimalInputMixin, Transform):
    """Round an expression with SQL ``ROUND``; lookup name ``round``.

    Accepts an optional ``precision`` (default ``0``) as a second argument.
    On SQLite a literal negative precision is rejected with ``ValueError``.
    The output field is taken from the rounded expression.
    """

    lookup_name = "round"
    function = "ROUND"
    # Transform fixes arity at 1; clearing it lets the precision argument through.
    arity = None

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        super().__init__(expression, precision, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile for SQLite, rejecting a literal negative precision first.

        Raises:
            ValueError: if the precision argument is a ``Value`` below zero.
        """
        sources = self.get_source_expressions()
        precision = sources[1]
        is_negative_literal = isinstance(precision, Value) and precision.value < 0
        if is_negative_literal:
            raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self) -> Field:
        # The result has the same field type as the expression being rounded.
        return self.get_source_expressions()[0].output_field
191
192
class Sign(Transform):
    """Sign of an expression, compiled as SQL ``SIGN``; lookup name ``sign``."""

    lookup_name = "sign"
    function = "SIGN"
196
197
class Sin(NumericOutputFieldMixin, Transform):
    """Sine, compiled as SQL ``SIN``; lookup name ``sin``.

    The output field is chosen by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sin"
    function = "SIN"
201
202
class Sqrt(NumericOutputFieldMixin, Transform):
    """Square root, compiled as SQL ``SQRT``; lookup name ``sqrt``.

    The output field is chosen by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sqrt"
    function = "SQRT"
206
207
class Tan(NumericOutputFieldMixin, Transform):
    """Tangent, compiled as SQL ``TAN``; lookup name ``tan``.

    The output field is chosen by ``NumericOutputFieldMixin``.
    """

    lookup_name = "tan"
    function = "TAN"