1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.expressions import Func, Value
6from plain.models.fields import CharField, IntegerField, TextField
7from plain.models.functions import Cast, Coalesce
8from plain.models.lookups import Transform
9
10if TYPE_CHECKING:
11 from plain.models.backends.base.base import BaseDatabaseWrapper
12 from plain.models.sql.compiler import SQLCompiler
13
14
class MySQLSHA2Mixin:
    """
    Render a SHA-2 digest on MySQL through ``SHA2(expression, bits)``.

    The bit length is derived from the concrete class's ``function`` name by
    dropping its first three characters (the ``"SHA"`` prefix), e.g.
    ``"SHA256"`` becomes ``256``.
    """

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # "SHA224" -> "224", "SHA512" -> "512", ...
        digest_bits = self.function[3:]
        sha2_template = f"SHA2(%(expressions)s, {digest_bits})"
        return super().as_sql(  # type: ignore[misc]
            compiler,
            connection,
            template=sha2_template,
            **extra_context,
        )
28
29
class PostgreSQLSHAMixin:
    """
    Render a SHA digest on PostgreSQL as a hex string.

    The SQL produced is ``ENCODE(DIGEST(expression, 'algorithm'), 'hex')``,
    where the algorithm name is the concrete class's ``function`` lowercased
    (e.g. ``"SHA1"`` -> ``'sha1'``). DIGEST() comes from PostgreSQL's
    pgcrypto extension, so that extension presumably has to be installed.
    """

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        algorithm = self.function.lower()
        return super().as_sql(  # type: ignore[misc]
            compiler,
            connection,
            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
            function=algorithm,
            **extra_context,
        )
44
45
class Chr(Transform):
    """
    Return the character corresponding to an integer character code.

    Registered as the ``chr`` lookup. The result is text even though the
    argument is an integer.
    """

    function = "CHR"
    lookup_name = "chr"
    # Declared explicitly, like Length/Ord do for their integer results:
    # otherwise the output type is inferred from the integer source
    # expression, so e.g. ``filter(code__chr="A")`` would try to prepare "A"
    # as an integer and the annotation would be typed as a number.
    output_field = CharField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # MySQL spells this CHAR(); ``USING utf16`` selects the character set
        # the integer is decoded in.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template="%(function)s(%(expressions)s USING utf16)",
            **extra_context,
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # SQLite spells this CHAR().
        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
71
72
class ConcatPair(Func):
    """
    Concatenate two arguments together. This is used by `Concat` because not
    all backend databases support more than two arguments.
    """

    function = "CONCAT"

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # SQLite concatenates with the ``||`` operator. A NULL on either side
        # makes the whole result NULL, so each argument is first wrapped in
        # COALESCE(arg, '').
        # The infix expression is parenthesized so that whatever SQL surrounds
        # it (e.g. a trailing COLLATE, which binds tighter than ``||`` in
        # SQLite) applies to the concatenation as a whole, not just to its
        # right-hand operand.
        coalesced = self.coalesce()
        return super(ConcatPair, coalesced).as_sql(
            compiler,
            connection,
            template="(%(expressions)s)",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Cast every argument to text before passing it to CONCAT().
        # NOTE(review): presumably this keeps PostgreSQL from having to infer
        # the type of untyped parameters; confirm before removing.
        c = self.copy()
        c.set_source_expressions(
            [
                Cast(expression, TextField())
                for expression in c.get_source_expressions()
            ]
        )
        return super(ConcatPair, c).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self) -> ConcatPair:
        """
        Return a copy whose arguments are each wrapped in COALESCE(arg, '').

        A NULL on either side results in NULL for the concatenation, so the
        copy substitutes an empty string for NULL arguments. ``self`` is not
        modified.
        """
        c = self.copy()
        c.set_source_expressions(
            [
                Coalesce(expression, Value(""))
                for expression in c.get_source_expressions()
            ]
        )
        return c
140
141
class Concat(Func):
    """
    Concatenate text fields together. Backends that result in an entire
    null expression when any arguments are null will wrap each argument in
    coalesce functions to ensure a non-null result.

    The arguments are folded into nested two-argument ``ConcatPair``
    expressions, and this wrapper renders that single nested expression as-is.
    """

    template = "%(expressions)s"
    function = None

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        """
        Fold ``expressions`` into right-nested ConcatPairs.

        ``(a, b, c, d)`` becomes
        ``ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))``.
        """
        # Start from the innermost pair (the last two items) and wrap the
        # remaining items around it from right to left.
        nested = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested)
        return nested
165
166
class Left(Func):
    """
    Return the first ``length`` characters of a string expression.

    On SQLite the call is rendered through the equivalent ``Substr``
    expression built by ``get_substr()``.
    """

    output_field = CharField()
    arity = 2
    function = "LEFT"

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the
            string; a plain (non-expression) value must be at least 1
        """
        is_plain_value = not hasattr(length, "resolve_expression")
        if is_plain_value and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Build ``Substr(expression, 1, length)`` from this call's arguments."""
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Value(1), count)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # SQLite has no LEFT(); compile the equivalent substring instead.
        substring = self.get_substr()
        return substring.as_sqlite(compiler, connection, **extra_context)
192
193
class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # MySQL's LENGTH() counts bytes; CHAR_LENGTH() counts characters.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
210
211
class Lower(Transform):
    """Convert the expression to lowercase with SQL ``LOWER()``."""

    lookup_name = "lower"
    function = "LOWER"
215
216
class LPad(Func):
    """
    Pad a string expression on the left with ``fill_text`` (a single space
    by default) to ``length`` characters, via SQL ``LPAD()``.
    """

    output_field = CharField()
    function = "LPAD"

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        # Only plain, non-None values can be validated up front; expressions
        # are evaluated by the database.
        is_plain_value = not hasattr(length, "resolve_expression")
        if is_plain_value and length is not None and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
231
232
class LTrim(Transform):
    """Strip leading spaces from the expression with SQL ``LTRIM()``."""

    lookup_name = "ltrim"
    function = "LTRIM"
236
237
class MD5(Transform):
    """Hash the expression with the database's ``MD5()`` function."""

    lookup_name = "md5"
    function = "MD5"
241
242
class Ord(Transform):
    """
    Return the integer character code of the expression's first character.

    Rendered as ASCII() by default, ORD() on MySQL and UNICODE() on SQLite.
    """

    output_field = IntegerField()
    lookup_name = "ord"
    function = "ASCII"

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )
263
264
class Repeat(Func):
    """Repeat a string expression ``number`` times via SQL ``REPEAT()``."""

    output_field = CharField()
    function = "REPEAT"

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        # Only plain, non-None values are validated here; expressions are
        # left for the database to evaluate.
        is_plain_value = not hasattr(number, "resolve_expression")
        if is_plain_value and number is not None and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
277
278
class Replace(Func):
    """
    Replace occurrences of ``text`` in ``expression`` with ``replacement``
    via SQL ``REPLACE()``. The default replacement is an empty string, which
    removes ``text``.
    """

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
286
287
class Reverse(Transform):
    """Reverse the expression's characters with SQL ``REVERSE()``."""

    lookup_name = "reverse"
    function = "REVERSE"
291
292
class Right(Left):
    """
    Return the last ``length`` characters of a string expression.

    Argument validation and the SQLite fallback (which compiles the
    ``Substr`` returned by ``get_substr()``) are inherited from ``Left``.
    """

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """
        Build ``Substr(expression, -length, length)``.

        A negative start position counts from the end of the string. The
        explicit length argument matters when ``length`` is an expression that
        evaluates to 0: ``SUBSTR(text, 0)`` returns the whole string, whereas
        ``SUBSTR(text, 0, 0)`` returns '' just like ``RIGHT(text, 0)``.
        """
        source = self.source_expressions[0]
        count = self.source_expressions[1]
        return Substr(source, count * Value(-1), count)
300
301
class RPad(LPad):
    """Pad on the right instead of the left, via SQL ``RPAD()``."""

    function = "RPAD"
304
305
class RTrim(Transform):
    """Strip trailing spaces from the expression with SQL ``RTRIM()``."""

    lookup_name = "rtrim"
    function = "RTRIM"
309
310
class SHA1(PostgreSQLSHAMixin, Transform):
    """
    SHA-1 digest of the expression.

    PostgreSQL uses the DIGEST()-based mixin; other backends get the default
    ``SHA1(...)`` rendering (there is no MySQL SHA2 mixin here).
    """

    lookup_name = "sha1"
    function = "SHA1"
314
315
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 digest of the expression (SHA2(..., 224) on MySQL)."""

    lookup_name = "sha224"
    function = "SHA224"
319
320
class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 digest of the expression (SHA2(..., 256) on MySQL)."""

    lookup_name = "sha256"
    function = "SHA256"
324
325
class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 digest of the expression (SHA2(..., 384) on MySQL)."""

    lookup_name = "sha384"
    function = "SHA384"
329
330
class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 digest of the expression (SHA2(..., 512) on MySQL)."""

    lookup_name = "sha512"
    function = "SHA512"
334
335
class StrIndex(Func):
    """
    Return a positive integer corresponding to the 1-indexed position of the
    first occurrence of a substring inside another string, or 0 if the
    substring is not found.
    """

    output_field = IntegerField()
    arity = 2
    function = "INSTR"

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # PostgreSQL has no INSTR(); STRPOS(string, substring) takes the same
        # argument order.
        return super().as_sql(
            compiler,
            connection,
            function="STRPOS",
            **extra_context,
        )
354
355
class Substr(Func):
    """
    Return a substring of a string expression, starting at the 1-indexed
    ``pos`` and optionally limited to ``length`` characters.
    """

    output_field = CharField()
    function = "SUBSTRING"

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return; when None the
            length argument is omitted from the SQL call entirely
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        if length is None:
            arguments = (expression, pos)
        else:
            arguments = (expression, pos, length)
        super().__init__(*arguments, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Use SQLite's SUBSTR() spelling.
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )
383
384
class Trim(Transform):
    """Strip leading and trailing spaces with SQL ``TRIM()``."""

    lookup_name = "trim"
    function = "TRIM"
388
389
class Upper(Transform):
    """Convert the expression to uppercase with SQL ``UPPER()``."""

    lookup_name = "upper"
    function = "UPPER"