1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.expressions import Func, Value
6from plain.models.fields import Field, FloatField, IntegerField
7from plain.models.functions import Cast
8from plain.models.functions.mixins import (
9 FixDecimalInputMixin,
10 NumericOutputFieldMixin,
11)
12from plain.models.lookups import Transform
13
14if TYPE_CHECKING:
15 from plain.models.backends.base.base import BaseDatabaseWrapper
16 from plain.models.sql.compiler import SQLCompiler
17
18
class Abs(Transform):
    """Absolute value of an expression, rendered with SQL ``ABS``."""

    lookup_name = "abs"
    function = "ABS"
22
23
class ACos(NumericOutputFieldMixin, Transform):
    """Arc cosine of an expression, rendered with SQL ``ACOS``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "acos"
    function = "ACOS"
27
28
class ASin(NumericOutputFieldMixin, Transform):
    """Arc sine of an expression, rendered with SQL ``ASIN``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "asin"
    function = "ASIN"
32
33
class ATan(NumericOutputFieldMixin, Transform):
    """Arc tangent of an expression, rendered with SQL ``ATAN``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "atan"
    function = "ATAN"
37
38
class ATan2(NumericOutputFieldMixin, Func):
    """Two-argument arc tangent ``ATan2(y, x)``, rendered with SQL ``ATAN2``.

    Returns the inverse tangent of ``y / x``. The output field is resolved by
    ``NumericOutputFieldMixin``.
    """

    function = "ATAN2"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile for SQLite, working around argument order on old SpatiaLite.

        When the backend ops are not SpatiaLite, or SpatiaLite is >= 5.0.0,
        the expression compiles unchanged. On older SpatiaLite the two
        arguments are swapped and integer arguments are cast to float.

        Returns:
            The ``(sql, params)`` pair produced by ``as_sql``.
        """
        is_spatialite = getattr(connection.ops, "spatialite", False)
        if (
            not is_spatialite
            or connection.ops.spatial_version >= (5, 0, 0)  # type: ignore[attr-defined]
        ):
            # Forward extra_context so caller overrides (template, function,
            # ...) apply on this path exactly as on the SpatiaLite path below.
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually ATan2(y, x), returning the inverse tangent
        # of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
        # Cast integers to float to avoid inconsistent/buggy behavior if the
        # arguments are mixed between integer and float or decimal.
        # https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
        clone = self.copy()
        clone.set_source_expressions(
            [
                Cast(expression, FloatField())
                if isinstance(expression.output_field, IntegerField)
                else expression
                for expression in self.get_source_expressions()[::-1]
            ]
        )
        return clone.as_sql(compiler, connection, **extra_context)
68
69
class Ceil(Transform):
    """Ceiling of an expression, rendered with SQL ``CEILING``."""

    lookup_name = "ceil"
    function = "CEILING"
73
74
class Cos(NumericOutputFieldMixin, Transform):
    """Cosine of an expression, rendered with SQL ``COS``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "cos"
    function = "COS"
78
79
class Cot(NumericOutputFieldMixin, Transform):
    """Cotangent of an expression, rendered with SQL ``COT``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "cot"
    function = "COT"
83
84
class Degrees(NumericOutputFieldMixin, Transform):
    """Convert an angle expression to degrees using SQL ``DEGREES``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "degrees"
    function = "DEGREES"
88
89
class Exp(NumericOutputFieldMixin, Transform):
    """Exponential of an expression, rendered with SQL ``EXP``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "exp"
    function = "EXP"
93
94
class Floor(Transform):
    """Floor of an expression, rendered with SQL ``FLOOR``."""

    lookup_name = "floor"
    function = "FLOOR"
98
99
class Ln(NumericOutputFieldMixin, Transform):
    """Natural logarithm of an expression, rendered with SQL ``LN``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "ln"
    function = "LN"
103
104
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Logarithm ``Log(b, x)`` of ``x`` to the base ``b`` (SQL ``LOG``).

    Input handling is adjusted by ``FixDecimalInputMixin`` and the output
    field is resolved by ``NumericOutputFieldMixin``.
    """

    function = "LOG"
    arity = 2

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile for SQLite, swapping the arguments on SpatiaLite.

        Returns:
            The ``(sql, params)`` pair produced by ``as_sql``.
        """
        if not getattr(connection.ops, "spatialite", False):
            # Forward extra_context so caller overrides apply on this path
            # exactly as on the SpatiaLite path below.
            return self.as_sql(compiler, connection, **extra_context)
        # This function is usually Log(b, x) returning the logarithm of x to
        # the base b, but on SpatiaLite it's Log(x, b).
        clone = self.copy()
        clone.set_source_expressions(self.get_source_expressions()[::-1])
        return clone.as_sql(compiler, connection, **extra_context)
122
123
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
    """Remainder of the first argument divided by the second (SQL ``MOD``).

    Input handling is adjusted by ``FixDecimalInputMixin`` and the output
    field is resolved by ``NumericOutputFieldMixin``.
    """

    arity = 2
    function = "MOD"
127
128
class Pi(NumericOutputFieldMixin, Func):
    """The constant pi, rendered with the zero-argument SQL ``PI`` function.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    arity = 0
    function = "PI"
132
133
class Power(NumericOutputFieldMixin, Func):
    """First argument raised to the power of the second (SQL ``POWER``).

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    arity = 2
    function = "POWER"
137
138
class Radians(NumericOutputFieldMixin, Transform):
    """Convert an angle expression to radians using SQL ``RADIANS``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "radians"
    function = "RADIANS"
142
143
class Random(NumericOutputFieldMixin, Func):
    """Random number expression; ``RANDOM()`` by default, ``RAND()`` on MySQL/SQLite.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    arity = 0
    function = "RANDOM"

    def get_group_by_cols(self) -> list[Any]:
        """Contribute no columns to a GROUP BY clause."""
        return []

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile as ``RAND()`` for SQLite.

        NOTE(review): SQLite's built-in is ``random()``; emitting ``RAND``
        presumably relies on a function the backend registers on the
        connection — confirm against the SQLite backend.
        """
        return super().as_sql(
            compiler,
            connection,
            function="RAND",
            **extra_context,
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile as ``RAND()`` for MySQL."""
        return super().as_sql(
            compiler,
            connection,
            function="RAND",
            **extra_context,
        )
166
167
class Round(FixDecimalInputMixin, Transform):
    """Round an expression with SQL ``ROUND(expression, precision)``.

    ``precision`` defaults to 0. The output field is the output field of the
    rounded expression.
    """

    function = "ROUND"
    lookup_name = "round"
    arity = None  # Override Transform's arity=1 to enable passing precision.

    def __init__(self, expression: Any, precision: int = 0, **extra: Any) -> None:
        super().__init__(expression, precision, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile for SQLite, rejecting a literal negative precision.

        Raises:
            ValueError: if ``precision`` is a literal ``Value`` below zero.
        """
        precision = self.get_source_expressions()[1]
        # Only literal values can be checked here. A literal NULL precision is
        # passed through to SQL instead of being compared with ``<``, which
        # would raise TypeError in Python.
        if (
            isinstance(precision, Value)
            and precision.value is not None
            and precision.value < 0
        ):
            raise ValueError("SQLite does not support negative precision.")
        return super().as_sqlite(compiler, connection, **extra_context)

    def _resolve_output_field(self) -> Field:
        """Use the output field of the expression being rounded."""
        source = self.get_source_expressions()[0]
        return source.output_field
190
191
class Sign(Transform):
    """Sign of an expression, rendered with SQL ``SIGN``."""

    lookup_name = "sign"
    function = "SIGN"
195
196
class Sin(NumericOutputFieldMixin, Transform):
    """Sine of an expression, rendered with SQL ``SIN``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sin"
    function = "SIN"
200
201
class Sqrt(NumericOutputFieldMixin, Transform):
    """Square root of an expression, rendered with SQL ``SQRT``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "sqrt"
    function = "SQRT"
205
206
class Tan(NumericOutputFieldMixin, Transform):
    """Tangent of an expression, rendered with SQL ``TAN``.

    The output field is resolved by ``NumericOutputFieldMixin``.
    """

    lookup_name = "tan"
    function = "TAN"