1from __future__ import annotations
2
3from collections.abc import Sequence
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func, Value
7from plain.models.fields import Field, FloatField, IntegerField
8from plain.models.functions import Cast
9from plain.models.functions.mixins import (
10 FixDecimalInputMixin,
11 NumericOutputFieldMixin,
12)
13from plain.models.lookups import Transform
14
15if TYPE_CHECKING:
16 from plain.models.backends.base.base import BaseDatabaseWrapper
17 from plain.models.sql.compiler import SQLCompiler
18
19
class Abs(Transform):
    """Absolute value: compiles to SQL ``ABS(<expression>)``.

    Transform lookup name: ``abs``. No numeric output-field mixin is applied.
    """

    lookup_name = "abs"
    function = "ABS"
23
24
class ACos(NumericOutputFieldMixin, Transform):
    """Inverse cosine: compiles to SQL ``ACOS(<expression>)``.

    Transform lookup name: ``acos``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "acos"
    function = "ACOS"
28
29
class ASin(NumericOutputFieldMixin, Transform):
    """Inverse sine: compiles to SQL ``ASIN(<expression>)``.

    Transform lookup name: ``asin``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "asin"
    function = "ASIN"
33
34
class ATan(NumericOutputFieldMixin, Transform):
    """Inverse tangent: compiles to SQL ``ATAN(<expression>)``.

    Transform lookup name: ``atan``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "atan"
    function = "ATAN"
38
39
class ATan2(NumericOutputFieldMixin, Func):
    """Two-argument inverse tangent: SQL ``ATAN2(y, x)``.

    Called as ``ATan2(y, x)``, returning the inverse tangent of ``y / x``.
    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    function = "ATAN2"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for SQLite, working around SpatiaLite < 5.0.0.

        On plain SQLite, or SpatiaLite 5.0.0 and newer, the standard argument
        order is compiled unchanged. Older SpatiaLite takes the arguments as
        ``ATan2(x, y)``, so a clone with the source expressions reversed (and
        integer arguments cast to float) is compiled instead.

        ``extra_context`` is forwarded to ``as_sql`` on both paths.
        """
        # `or` short-circuits, so spatial_version is only read on SpatiaLite.
        if not getattr(connection.ops, "spatialite", False) or getattr(
            connection.ops, "spatial_version", (0, 0, 0)
        ) >= (5, 0, 0):
            # Forward extra_context here as well; previously it was dropped on
            # this path while the SpatiaLite path below honored it.
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually ATan2(y, x), returning the inverse tangent
        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
        # Cast integers to float to avoid inconsistent/buggy behavior if the
        # arguments are mixed between integer and float or decimal.
        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, FloatField())
                if isinstance(expression.output_field, IntegerField)
                else expression
                for expression in self.get_source_expressions()[::-1]
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)
69
70
class Ceil(Transform):
    """Ceiling: compiles to SQL ``CEILING(<expression>)``.

    Transform lookup name: ``ceil`` (note the SQL function name differs).
    No numeric output-field mixin is applied.
    """

    lookup_name = "ceil"
    function = "CEILING"
74
75
class Cos(NumericOutputFieldMixin, Transform):
    """Cosine: compiles to SQL ``COS(<expression>)``.

    Transform lookup name: ``cos``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "cos"
    function = "COS"
79
80
class Cot(NumericOutputFieldMixin, Transform):
    """Cotangent: compiles to SQL ``COT(<expression>)``.

    Transform lookup name: ``cot``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "cot"
    function = "COT"
84
85
class Degrees(NumericOutputFieldMixin, Transform):
    """Radians-to-degrees conversion: SQL ``DEGREES(<expression>)``.

    Transform lookup name: ``degrees``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "degrees"
    function = "DEGREES"
89
90
class Exp(NumericOutputFieldMixin, Transform):
    """Exponential: compiles to SQL ``EXP(<expression>)``.

    Transform lookup name: ``exp``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "exp"
    function = "EXP"
94
95
class Floor(Transform):
    """Floor: compiles to SQL ``FLOOR(<expression>)``.

    Transform lookup name: ``floor``. No numeric output-field mixin is applied.
    """

    lookup_name = "floor"
    function = "FLOOR"
99
100
class Ln(NumericOutputFieldMixin, Transform):
    """Natural logarithm: compiles to SQL ``LN(<expression>)``.

    Transform lookup name: ``ln``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "ln"
    function = "LN"
104
105
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Logarithm with an explicit base: SQL ``LOG(b, x)``.

    Called as ``Log(b, x)``, returning the logarithm of ``x`` to the base
    ``b``. The output field is resolved by ``NumericOutputFieldMixin``;
    ``FixDecimalInputMixin`` is also mixed in (see that mixin for how it
    treats decimal arguments).
    """

    function = "LOG"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for SQLite, swapping the arguments on SpatiaLite.

        SpatiaLite defines the function as ``Log(x, b)``, so a clone with the
        source expressions reversed is compiled there. Plain SQLite uses the
        standard order. ``extra_context`` is forwarded on both paths.
        """
        if not getattr(connection.ops, "spatialite", False):
            # Forward extra_context here as well; previously it was dropped on
            # this path while the SpatiaLite path below honored it.
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b). Reverse the arguments
        # on a copy rather than mutating self.
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)
123
124
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Modulo of two expressions: compiles to SQL ``MOD(a, b)``.

    Takes exactly two arguments. The output field is resolved by
    ``NumericOutputFieldMixin``; ``FixDecimalInputMixin`` is also mixed in
    (see that mixin for how it treats decimal arguments).
    """

    arity = 2
    function = "MOD"
128
129
class Pi(NumericOutputFieldMixin, Func):
    """The constant pi: compiles to SQL ``PI()`` and takes no arguments.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    arity = 0
    function = "PI"
133
134
class Power(NumericOutputFieldMixin, Func):
    """Exponentiation: compiles to SQL ``POWER(base, exponent)``.

    Takes exactly two arguments. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    arity = 2
    function = "POWER"
138
139
class Radians(NumericOutputFieldMixin, Transform):
    """Degrees-to-radians conversion: SQL ``RADIANS(<expression>)``.

    Transform lookup name: ``radians``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "radians"
    function = "RADIANS"
143
144
class Random(NumericOutputFieldMixin, Func):
    """Random value: compiles to SQL ``RANDOM()`` and takes no arguments.

    MySQL and SQLite compile it as ``RAND()`` instead. The expression never
    contributes columns to GROUP BY.
    """

    arity = 0
    function = "RANDOM"

    def _as_rand(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile via ``Func.as_sql`` with the function name set to ``RAND``."""
        return super().as_sql(compiler, connection, function="RAND", **extra_context)

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for MySQL as ``RAND()``."""
        return self._as_rand(compiler, connection, **extra_context)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile for SQLite as ``RAND()``.

        NOTE(review): SQLite's built-in is ``RANDOM()``; this presumably relies
        on the SQLite backend registering a ``RAND`` function on the
        connection — confirm.
        """
        return self._as_rand(compiler, connection, **extra_context)

    def get_group_by_cols(self) -> list[Any]:
        """Return no GROUP BY columns for this expression."""
        # Presumably so a per-row random value never splits aggregation groups.
        return []
167
168
class Round(FixDecimalInputMixin, Transform):
    """Rounding: compiles to SQL ``ROUND(<expression>, <precision>)``.

    Transform lookup name: ``round``. ``precision`` defaults to ``0``. The
    output field is the same as the rounded expression's output field.
    """

    lookup_name = "round"
    function = "ROUND"
    # Transform defaults to arity=1; clearing it lets precision be passed.
    arity = None

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        super().__init__(expression, precision, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile for SQLite, rejecting a literal negative precision.

        Raises:
            ValueError: if the precision argument is a ``Value`` below zero.
        """
        sources = self.get_source_expressions()
        precision_expr = sources[1]
        # Only literal Values can be checked here; other expressions pass through.
        if isinstance(precision_expr, Value):
            if precision_expr.value < 0:
                raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self) -> Field:
        # Mirror the output field of the expression being rounded.
        rounded_expr, *_ = self.get_source_expressions()
        return rounded_expr.output_field
191
192
class Sign(Transform):
    """Sign of a number: compiles to SQL ``SIGN(<expression>)``.

    Transform lookup name: ``sign``. No numeric output-field mixin is applied.
    """

    lookup_name = "sign"
    function = "SIGN"
196
197
class Sin(NumericOutputFieldMixin, Transform):
    """Sine: compiles to SQL ``SIN(<expression>)``.

    Transform lookup name: ``sin``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "sin"
    function = "SIN"
201
202
class Sqrt(NumericOutputFieldMixin, Transform):
    """Square root: compiles to SQL ``SQRT(<expression>)``.

    Transform lookup name: ``sqrt``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "sqrt"
    function = "SQRT"
206
207
class Tan(NumericOutputFieldMixin, Transform):
    """Tangent: compiles to SQL ``TAN(<expression>)``.

    Transform lookup name: ``tan``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    lookup_name = "tan"
    function = "TAN"