1from plain.models.expressions import Func, Value
2from plain.models.fields import CharField, IntegerField, TextField
3from plain.models.functions import Cast, Coalesce
4from plain.models.lookups import Transform
5
6
class MySQLSHA2Mixin:
    """Render a SHA-2 family hash on MySQL as ``SHA2(<expressions>, <bits>)``.

    The digest size is taken from ``self.function`` after its first three
    characters, e.g. ``"SHA256"`` yields ``256``.
    """

    def as_mysql(self, compiler, connection, **extra_context):
        digest_bits = self.function[3:]
        sql_template = f"SHA2(%(expressions)s, {digest_bits})"
        return super().as_sql(
            compiler,
            connection,
            template=sql_template,
            **extra_context,
        )
15
16
class PostgreSQLSHAMixin:
    """Render a SHA hash on PostgreSQL as a hex-encoded ``DIGEST()`` call.

    The algorithm name passed to ``DIGEST`` is ``self.function`` lowercased.
    """

    def as_postgresql(self, compiler, connection, **extra_context):
        digest_sql = "DIGEST(%(expressions)s, '%(function)s')"
        algorithm = self.function.lower()
        return super().as_sql(
            compiler,
            connection,
            template=f"ENCODE({digest_sql}, 'hex')",
            function=algorithm,
            **extra_context,
        )
26
27
class Chr(Transform):
    """
    Convert a numeric character code into the corresponding character.

    Rendered as ``CHR(...)`` by default, ``CHAR(... USING utf16)`` on MySQL
    and ``CHAR(...)`` on SQLite.
    """

    function = "CHR"
    lookup_name = "chr"
    # The argument is a numeric code but the result is text. Declare the
    # output type explicitly (mirroring ``Ord``, which declares IntegerField)
    # so it is not inferred from the numeric source expression.
    output_field = CharField()

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="CHAR",
            template="%(function)s(%(expressions)s USING utf16)",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
43
44
class ConcatPair(Func):
    """
    Concatenate exactly two arguments. ``Concat`` nests these pairs because
    not every database backend accepts more than two arguments.
    """

    function = "CONCAT"

    def as_sqlite(self, compiler, connection, **extra_context):
        # Join the coalesced operands with the || operator.
        null_safe = self.coalesce()
        return super(ConcatPair, null_safe).as_sql(
            compiler,
            connection,
            template="%(expressions)s",
            arg_joiner=" || ",
            **extra_context,
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Cast every operand to text before concatenating.
        as_text = self.copy()
        casted_sources = []
        for source in as_text.get_source_expressions():
            casted_sources.append(Cast(source, TextField()))
        as_text.set_source_expressions(casted_sources)
        return super(ConcatPair, as_text).as_sql(
            compiler,
            connection,
            **extra_context,
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # CONCAT_WS with an empty separator ignores NULL operands.
        return super().as_sql(
            compiler,
            connection,
            function="CONCAT_WS",
            template="%(function)s('', %(expressions)s)",
            **extra_context,
        )

    def coalesce(self):
        """Return a copy whose operands are each wrapped in ``COALESCE(x, '')``.

        A NULL on either side would otherwise make the whole expression NULL.
        """
        wrapped = self.copy()
        wrapped.set_source_expressions(
            [Coalesce(source, Value("")) for source in wrapped.get_source_expressions()]
        )
        return wrapped
97
98
class Concat(Func):
    """
    Concatenate text fields together. Backends that turn the whole expression
    NULL when any argument is NULL wrap each argument in COALESCE so the
    result stays non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions):
        # Fold the arguments from the right into nested two-argument pairs:
        #   [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        nested = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested)
        return nested
122
123
class Left(Func):
    """Return the first ``length`` characters of a string (SQL ``LEFT``)."""

    output_field = CharField()
    arity = 2
    function = "LEFT"

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string
        """
        # Only literal lengths can be validated before the query runs.
        if not hasattr(length, "resolve_expression") and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        """Return an equivalent ``Substr`` expression that starts at position 1."""
        source, length = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Value(1), length)

    def as_sqlite(self, compiler, connection, **extra_context):
        # On SQLite, compile the equivalent Substr expression instead.
        substring = self.get_substr()
        return substring.as_sqlite(compiler, connection, **extra_context)
144
145
class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    output_field = IntegerField()
    function = "LENGTH"

    def as_mysql(self, compiler, connection, **extra_context):
        # CHAR_LENGTH counts characters, which matches this class's contract.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
157
158
class Lower(Transform):
    """Apply SQL ``LOWER`` to the expression (lookup name ``lower``)."""

    lookup_name = "lower"
    function = "LOWER"
162
163
class LPad(Func):
    """Wrap the SQL ``LPAD`` function: pad ``expression`` to ``length`` with ``fill_text``."""

    output_field = CharField()
    function = "LPAD"

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        # Only literal, non-None lengths are range-checked here; expressions
        # and None are passed through unchanged.
        is_literal = not hasattr(length, "resolve_expression") and length is not None
        if is_literal and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
176
177
class LTrim(Transform):
    """Apply SQL ``LTRIM`` to the expression (lookup name ``ltrim``)."""

    lookup_name = "ltrim"
    function = "LTRIM"
181
182
class MD5(Transform):
    """Apply SQL ``MD5`` to the expression (lookup name ``md5``)."""

    lookup_name = "md5"
    function = "MD5"
186
187
class Ord(Transform):
    """
    Return the integer code of the expression's character.

    Rendered as ``ASCII`` by default, ``ORD`` on MySQL and ``UNICODE`` on SQLite.
    """

    lookup_name = "ord"
    output_field = IntegerField()
    function = "ASCII"

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )
198
199
class Repeat(Func):
    """Wrap the SQL ``REPEAT`` function: repeat ``expression`` ``number`` times."""

    output_field = CharField()
    function = "REPEAT"

    def __init__(self, expression, number, **extra):
        # Literal counts are validated eagerly; expressions and None are
        # left for the database to evaluate.
        if not hasattr(number, "resolve_expression"):
            if number is not None and number < 0:
                raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
212
213
class Replace(Func):
    """Wrap SQL ``REPLACE``: substitute ``replacement`` for ``text`` in ``expression``.

    ``replacement`` defaults to an empty-string value.
    """

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
219
220
class Reverse(Transform):
    """Apply SQL ``REVERSE`` to the expression (lookup name ``reverse``)."""

    lookup_name = "reverse"
    function = "REVERSE"
224
225
class Right(Left):
    """Return the last ``length`` characters of a string (SQL ``RIGHT``)."""

    function = "RIGHT"

    def get_substr(self):
        """
        Return an equivalent ``Substr`` expression that takes the last
        ``length`` characters.

        The start position is ``-length``, which counts from the end of the
        string. ``length`` is also passed explicitly as the substring length.
        Without it, a length expression that evaluates to 0 gives
        ``SUBSTR(x, 0)``, which returns the whole string on SQLite instead of
        ``''``. ``Left.__init__`` only rejects literal lengths below 1, so
        expression lengths can reach this case.
        """
        source, length = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, length * Value(-1), length)
233
234
class RPad(LPad):
    """Like ``LPad`` but rendered with the SQL ``RPAD`` function."""

    function = "RPAD"
237
238
class RTrim(Transform):
    """Apply SQL ``RTRIM`` to the expression (lookup name ``rtrim``)."""

    lookup_name = "rtrim"
    function = "RTRIM"
242
243
class SHA1(PostgreSQLSHAMixin, Transform):
    """SHA-1 hash of the expression; PostgreSQL rendering comes from the mixin."""

    lookup_name = "sha1"
    function = "SHA1"
247
248
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash of the expression; MySQL/PostgreSQL rendering via mixins."""

    lookup_name = "sha224"
    function = "SHA224"
252
253
class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash of the expression; MySQL/PostgreSQL rendering via mixins."""

    lookup_name = "sha256"
    function = "SHA256"
257
258
class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash of the expression; MySQL/PostgreSQL rendering via mixins."""

    lookup_name = "sha384"
    function = "SHA384"
262
263
class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash of the expression; MySQL/PostgreSQL rendering via mixins."""

    lookup_name = "sha512"
    function = "SHA512"
267
268
class StrIndex(Func):
    """
    Return the 1-indexed position of the first occurrence of a substring
    inside another string as a positive integer, or 0 when the substring
    does not occur.
    """

    output_field = IntegerField()
    arity = 2
    function = "INSTR"

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL spells this function STRPOS.
        return super().as_sql(
            compiler,
            connection,
            function="STRPOS",
            **extra_context,
        )
282
283
class Substr(Func):
    """Extract part of a string (SQL ``SUBSTRING``; ``SUBSTR`` on SQLite)."""

    output_field = CharField()
    function = "SUBSTRING"

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return
        """
        # Only a literal start position can be validated up front.
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        if length is None:
            arguments = (expression, pos)
        else:
            arguments = (expression, pos, length)
        super().__init__(*arguments, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler,
            connection,
            function="SUBSTR",
            **extra_context,
        )
304
305
class Trim(Transform):
    """Apply SQL ``TRIM`` to the expression (lookup name ``trim``)."""

    lookup_name = "trim"
    function = "TRIM"
309
310
class Upper(Transform):
    """Apply SQL ``UPPER`` to the expression (lookup name ``upper``)."""

    lookup_name = "upper"
    function = "UPPER"