1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.expressions import Func, ResolvableExpression, Value
6from plain.models.fields import CharField, IntegerField, TextField
7from plain.models.functions import Cast, Coalesce
8from plain.models.lookups import Transform
9
10if TYPE_CHECKING:
11 from plain.models.backends.base.base import BaseDatabaseWrapper
12 from plain.models.sql.compiler import SQLCompiler
13
14
class MySQLSHA2Mixin(Transform):
    """
    Render SHA-2 family hashes with MySQL's two-argument ``SHA2()`` function.

    The subclass's ``function`` name (e.g. ``"SHA256"``) is sliced after its
    first three characters, and the remainder (``"256"``) is passed to
    ``SHA2`` as its second argument.
    """

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        function_name = self.function
        assert function_name is not None
        # Drop the leading three characters (the "SHA" prefix for every
        # subclass in this module) to obtain the digest length.
        digest_bits = function_name[3:]
        mysql_template = f"SHA2(%(expressions)s, {digest_bits})"
        return super().as_sql(
            compiler,
            connection,
            template=mysql_template,
            **extra_context,
        )
31
32
class PostgreSQLSHAMixin(Transform):
    """
    Render SHA hashes on PostgreSQL as a hex-encoded ``DIGEST()`` call.

    The lower-cased ``function`` name (e.g. ``"sha1"``) is substituted as the
    algorithm argument of ``DIGEST`` and the result is wrapped in
    ``ENCODE(..., 'hex')``.
    """

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        function_name = self.function
        assert function_name is not None
        # NOTE(review): DIGEST() comes from PostgreSQL's pgcrypto extension;
        # confirm it is installed wherever these hashes are used.
        pg_template = "ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')"
        return super().as_sql(
            compiler,
            connection,
            template=pg_template,
            function=function_name.lower(),
            **extra_context,
        )
50
51
class Chr(Transform):
    """
    Wrap the expression in the SQL ``CHR`` function (character from a code).

    MySQL renders ``CHAR(<expr> USING utf16)`` and SQLite renders ``CHAR``.
    Exposed under the lookup name ``chr``.
    """

    lookup_name = "chr"
    function = "CHR"

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        mysql_template = "%(function)s(%(expressions)s USING utf16)"
        return super().as_sql(
            compiler,
            connection,
            template=mysql_template,
            function="CHAR",
            **extra_context,
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            **extra_context,
        )
77
78
class ConcatPair(Func):
    """
    Concatenate exactly two arguments.

    ``Concat`` builds a right-nested chain of these pairs because not every
    database backend accepts more than two arguments for concatenation.
    """

    function = "CONCAT"

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # Join with the || operator, compiling a copy whose operands are
        # wrapped in COALESCE(..., '') so a NULL side does not null the result.
        return super(ConcatPair, self.coalesce()).as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # Compile a copy in which every operand is cast to a TextField; the
        # original expression is left untouched.
        as_text = self.copy()
        operands = as_text.get_source_expressions()
        as_text.set_source_expressions(
            [Cast(operand, TextField()) for operand in operands]
        )
        return super(ConcatPair, as_text).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # CONCAT_WS with an empty separator skips NULL arguments instead of
        # returning NULL for the whole expression.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self) -> ConcatPair:
        """Return a copy whose operands are each wrapped in ``COALESCE(x, '')``."""
        # A NULL on either side would otherwise make the whole result NULL.
        wrapped = self.copy()
        operands = wrapped.get_source_expressions()
        wrapped.set_source_expressions(
            [Coalesce(operand, Value("")) for operand in operands]
        )
        return wrapped
146
147
class Concat(Func):
    """
    Concatenate text fields together.

    The arguments are folded into a right-nested chain of ``ConcatPair``
    expressions. Backends whose concatenation would turn an entire expression
    NULL when any argument is NULL wrap each argument in ``COALESCE`` so the
    result stays non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        # Build the chain from the right-hand end inward:
        #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        nested = ConcatPair(expressions[-2], expressions[-1])
        for head in reversed(expressions[:-2]):
            nested = ConcatPair(head, nested)
        return nested
171
172
class Left(Func):
    """Return the first ``length`` characters of a string expression (``LEFT``)."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return; a plain (non-expression)
            value must be at least 1
        """
        if not isinstance(length, ResolvableExpression) and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Build the equivalent ``Substr`` expression, starting at position 1."""
        sources = self.source_expressions
        text, count = sources[0], sources[1]
        return Substr(text, Value(1), count)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # On SQLite, compile the equivalent Substr expression instead of LEFT().
        substring = self.get_substr()
        return substring.as_sqlite(compiler, connection, **extra_context)
198
199
class Length(Transform):
    """
    Return the number of characters in the expression.

    Renders ``LENGTH`` by default and ``CHAR_LENGTH`` on MySQL.
    """

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
216
217
class Lower(Transform):
    """Apply the SQL ``LOWER`` function; exposed as the ``lower`` lookup."""

    lookup_name = "lower"
    function = "LOWER"
221
222
class LPad(Func):
    """
    Left-pad ``expression`` to ``length`` using ``fill_text`` (``LPAD``).

    ``fill_text`` defaults to a single space. A plain ``length`` value may be
    ``None`` but must not be negative.
    """

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        is_literal = not isinstance(length, ResolvableExpression)
        if is_literal and length is not None and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
237
238
class LTrim(Transform):
    """Apply the SQL ``LTRIM`` function; exposed as the ``ltrim`` lookup."""

    lookup_name = "ltrim"
    function = "LTRIM"
242
243
class MD5(Transform):
    """Hash the expression with the SQL ``MD5`` function (``md5`` lookup)."""

    lookup_name = "md5"
    function = "MD5"
247
248
class Ord(Transform):
    """
    Return the numeric code of the expression's first character.

    Renders ``ASCII`` by default, ``ORD`` on MySQL and ``UNICODE`` on SQLite.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )
269
270
class Repeat(Func):
    """
    Repeat ``expression`` ``number`` times via the SQL ``REPEAT`` function.

    A plain ``number`` value may be ``None`` but must not be negative.
    """

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        is_literal = not isinstance(number, ResolvableExpression)
        if is_literal and number is not None and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
283
284
class Replace(Func):
    """
    Replace occurrences of ``text`` in ``expression`` with ``replacement``.

    Renders the SQL ``REPLACE`` function; ``replacement`` defaults to the
    empty string, so by default ``text`` is removed.
    """

    function = "REPLACE"

    def __init__(
        self,
        expression: Any,
        text: Any,
        replacement: Any = Value(""),
        **extra: Any,
    ) -> None:
        super().__init__(expression, text, replacement, **extra)
292
293
class Reverse(Transform):
    """Apply the SQL ``REVERSE`` function; exposed as the ``reverse`` lookup."""

    lookup_name = "reverse"
    function = "REVERSE"
297
298
class Right(Left):
    """
    Return the last ``length`` characters of a string expression (``RIGHT``).

    Validation and the SQLite fallback are inherited from ``Left``; only the
    substring construction differs.
    """

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """Build a ``Substr`` whose start position is ``-length`` (counted from the end)."""
        sources = self.source_expressions
        text, count = sources[0], sources[1]
        return Substr(text, count * Value(-1))
306
307
class RPad(LPad):
    """Right-pad counterpart of ``LPad``; renders the SQL ``RPAD`` function."""

    function = "RPAD"
310
311
class RTrim(Transform):
    """Apply the SQL ``RTRIM`` function; exposed as the ``rtrim`` lookup."""

    lookup_name = "rtrim"
    function = "RTRIM"
315
316
class SHA1(PostgreSQLSHAMixin, Transform):
    """
    SHA-1 hash of the expression (``sha1`` lookup).

    PostgreSQL uses the DIGEST-based rendering from ``PostgreSQLSHAMixin``;
    other backends use the default rendering of ``function``.
    """

    lookup_name = "sha1"
    function = "SHA1"
320
321
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """
    SHA-224 hash of the expression (``sha224`` lookup).

    MySQL renders ``SHA2(..., 224)`` and PostgreSQL a hex-encoded ``DIGEST``
    via the mixins; other backends use the default rendering of ``function``.
    """

    lookup_name = "sha224"
    function = "SHA224"
325
326
class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """
    SHA-256 hash of the expression (``sha256`` lookup).

    MySQL renders ``SHA2(..., 256)`` and PostgreSQL a hex-encoded ``DIGEST``
    via the mixins; other backends use the default rendering of ``function``.
    """

    lookup_name = "sha256"
    function = "SHA256"
330
331
class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """
    SHA-384 hash of the expression (``sha384`` lookup).

    MySQL renders ``SHA2(..., 384)`` and PostgreSQL a hex-encoded ``DIGEST``
    via the mixins; other backends use the default rendering of ``function``.
    """

    lookup_name = "sha384"
    function = "SHA384"
335
336
class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """
    SHA-512 hash of the expression (``sha512`` lookup).

    MySQL renders ``SHA2(..., 512)`` and PostgreSQL a hex-encoded ``DIGEST``
    via the mixins; other backends use the default rendering of ``function``.
    """

    lookup_name = "sha512"
    function = "SHA512"
340
341
class StrIndex(Func):
    """
    Return the 1-indexed position of the first occurrence of a substring.

    The result is 0 when the substring is not found. Renders ``INSTR`` by
    default and ``STRPOS`` on PostgreSQL.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="STRPOS",
            **extra_context,
        )
360
361
class Substr(Func):
    """
    Extract part of a string (``SUBSTRING``; ``SUBSTR`` on SQLite).

    The optional ``length`` argument is only emitted when it is not ``None``.
    """

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        if not isinstance(pos, ResolvableExpression) and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        arguments = (
            [expression, pos] if length is None else [expression, pos, length]
        )
        super().__init__(*arguments, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )
389
390
class Trim(Transform):
    """Apply the SQL ``TRIM`` function; exposed as the ``trim`` lookup."""

    lookup_name = "trim"
    function = "TRIM"
394
395
class Upper(Transform):
    """Apply the SQL ``UPPER`` function; exposed as the ``upper`` lookup."""

    lookup_name = "upper"
    function = "UPPER"