1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.expressions import Func, ResolvableExpression, Value
6from plain.models.fields import CharField, IntegerField, TextField
7from plain.models.functions import Cast, Coalesce
8from plain.models.lookups import Transform
9
10if TYPE_CHECKING:
11 from plain.models.postgres.wrapper import DatabaseWrapper
12 from plain.models.sql.compiler import SQLCompiler
13
14
class SHAMixin(Transform):
    """Shared SQL rendering for the SHA-family hashing transforms.

    Concrete subclasses set ``function`` to the algorithm name. Hashing is
    done with PostgreSQL's pgcrypto extension.
    """

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render ``ENCODE(DIGEST(<expressions>, '<algorithm>'), 'hex')``.

        The ``function``, ``template`` and ``arg_joiner`` arguments are part
        of the signature but are not consulted: the template is fixed and the
        algorithm name is always the lower-cased ``self.function``.
        """
        algorithm = self.function
        # Subclasses must define ``function``; this also narrows the type.
        assert algorithm is not None
        digest_template = "ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')"
        return super().as_sql(
            compiler,
            connection,
            template=digest_template,
            function=algorithm.lower(),
            **extra_context,
        )
35
36
class Chr(Transform):
    """Transform mapped to the SQL ``CHR`` function (lookup name ``chr``)."""

    lookup_name = "chr"
    function = "CHR"
40
41
class ConcatPair(Func):
    """Concatenate two arguments together."""

    function = "CONCAT"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the pair as SQL with both operands cast to text.

        ``function``, ``template`` and ``arg_joiner`` are forwarded to the
        parent implementation so that caller-supplied overrides are honored
        rather than silently discarded. Returns the ``(sql, params)`` pair
        produced by the parent ``as_sql``.
        """
        # PostgreSQL requires explicit cast to text for CONCAT.
        # Work on a copy so the casts do not leak into this expression.
        text_cast_pair = self.copy()
        text_cast_pair.set_source_expressions(
            [
                Cast(expression, TextField())
                for expression in text_cast_pair.get_source_expressions()
            ]
        )
        return super(ConcatPair, text_cast_pair).as_sql(
            compiler,
            connection,
            function=function,
            template=template,
            arg_joiner=arg_joiner,
            **extra_context,
        )

    def coalesce(self) -> ConcatPair:
        """Return a copy whose operands are each wrapped in ``Coalesce(..., "")``."""
        # null on either side results in null for expression, wrap with coalesce
        c = self.copy()
        c.set_source_expressions(
            [
                Coalesce(expression, Value(""))
                for expression in c.get_source_expressions()
            ]
        )
        return c
80
81
class Concat(Func):
    """
    Concatenate two or more expressions by nesting them in ``ConcatPair``
    expressions.

    NOTE(review): earlier documentation said each argument is wrapped in a
    coalesce function; nothing in this class calls ``ConcatPair.coalesce()``,
    so confirm whether null handling is expected here.
    """

    function = None
    # The single source expression (the nested pair) is rendered as-is.
    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        """Build the nested pair from ``expressions``.

        Raises ValueError when fewer than two expressions are given.
        """
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        """Fold ``expressions`` into right-nested ``ConcatPair`` objects.

        [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        """
        *leading, second_last, last = expressions
        # Start from the innermost pair and wrap outwards, right to left.
        nested = ConcatPair(second_last, last)
        for expression in reversed(leading):
            nested = ConcatPair(expression, nested)
        return nested
104
105
class Left(Func):
    """Two-argument ``LEFT`` function with a character-field result."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        Raises ValueError when ``length`` is not a ResolvableExpression and is
        smaller than 1.
        """
        is_expression = isinstance(length, ResolvableExpression)
        # Only plain values can be validated up front.
        if not is_expression and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Return a ``Substr`` over the same operands, starting at position 1."""
        operands = self.source_expressions
        return Substr(operands[0], Value(1), operands[1])
123
124
class Length(Transform):
    """Number of characters in the expression, via SQL ``LENGTH``."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()
131
132
class Lower(Transform):
    """Transform mapped to the SQL ``LOWER`` function (lookup name ``lower``)."""

    lookup_name = "lower"
    function = "LOWER"
136
137
class LPad(Func):
    """``LPAD`` function with a character-field result."""

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        """Pass ``expression``, ``length`` and ``fill_text`` to ``LPAD``.

        ``fill_text`` defaults to a single-space ``Value``. Raises ValueError
        when ``length`` is a plain value (neither ``None`` nor a
        ResolvableExpression) below zero.
        """
        # None and expressions are skipped; only plain values are range-checked.
        is_plain_value = length is not None and not isinstance(
            length, ResolvableExpression
        )
        if is_plain_value and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
152
153
class LTrim(Transform):
    """Transform mapped to the SQL ``LTRIM`` function (lookup name ``ltrim``)."""

    lookup_name = "ltrim"
    function = "LTRIM"
157
158
class MD5(Transform):
    """Transform mapped to the SQL ``MD5`` function (lookup name ``md5``)."""

    lookup_name = "md5"
    function = "MD5"
162
163
class Ord(Transform):
    """Transform mapped to the SQL ``ASCII`` function (lookup name ``ord``).

    The result is typed as an integer field.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()
168
169
class Repeat(Func):
    """``REPEAT`` function with a character-field result."""

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        """Pass ``expression`` and ``number`` to ``REPEAT``.

        Raises ValueError when ``number`` is a plain value (neither ``None``
        nor a ResolvableExpression) below zero.
        """
        # None and expressions are skipped; only plain values are range-checked.
        is_plain_value = number is not None and not isinstance(
            number, ResolvableExpression
        )
        if is_plain_value and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
182
183
class Replace(Func):
    """``REPLACE`` function taking an expression, a text and a replacement."""

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        """Pass the three operands to ``REPLACE`` in the order given.

        ``replacement`` defaults to an empty-string ``Value``.
        """
        operands = (expression, text, replacement)
        super().__init__(*operands, **extra)
191
192
class Reverse(Transform):
    """Transform mapped to the SQL ``REVERSE`` function (lookup name ``reverse``)."""

    lookup_name = "reverse"
    function = "REVERSE"
196
197
class Right(Left):
    """``RIGHT`` counterpart of ``Left``; inherits its validation and arity."""

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """Return a ``Substr`` whose start position is the negated length.

        NOTE(review): only a start position is passed (no length argument);
        confirm that SUBSTRING with a negative start yields the trailing
        characters on the target database.
        """
        operands = self.source_expressions
        negated_length = operands[1] * Value(-1)
        return Substr(operands[0], negated_length)
205
206
class RPad(LPad):
    """``RPAD`` variant of ``LPad``; inherits its constructor and validation."""

    function = "RPAD"
209
210
class RTrim(Transform):
    """Transform mapped to the SQL ``RTRIM`` function (lookup name ``rtrim``)."""

    lookup_name = "rtrim"
    function = "RTRIM"
214
215
class SHA1(SHAMixin, Transform):
    """SHA1 hash transform (lookup name ``sha1``); SQL comes from ``SHAMixin``."""

    lookup_name = "sha1"
    function = "SHA1"
219
220
class SHA224(SHAMixin, Transform):
    """SHA224 hash transform (lookup name ``sha224``); SQL comes from ``SHAMixin``."""

    lookup_name = "sha224"
    function = "SHA224"
224
225
class SHA256(SHAMixin, Transform):
    """SHA256 hash transform (lookup name ``sha256``); SQL comes from ``SHAMixin``."""

    lookup_name = "sha256"
    function = "SHA256"
229
230
class SHA384(SHAMixin, Transform):
    """SHA384 hash transform (lookup name ``sha384``); SQL comes from ``SHAMixin``."""

    lookup_name = "sha384"
    function = "SHA384"
234
235
class SHA512(SHAMixin, Transform):
    """SHA512 hash transform (lookup name ``sha512``); SQL comes from ``SHAMixin``."""

    lookup_name = "sha512"
    function = "SHA512"
239
240
class StrIndex(Func):
    """
    Locate a substring inside another string.

    The result is the 1-indexed position of the first occurrence as a
    positive integer, or 0 when the substring is not found.
    """

    arity = 2
    output_field = IntegerField()
    # PostgreSQL uses STRPOS instead of INSTR.
    function = "STRPOS"
252
253
class Substr(Func):
    """``SUBSTRING`` function with a character-field result."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raises ValueError when ``pos`` is not a ResolvableExpression and is
        smaller than 1.
        """
        # Only plain values can be validated up front.
        if not isinstance(pos, ResolvableExpression) and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        # The length operand is omitted entirely when not supplied.
        if length is None:
            expressions = [expression, pos]
        else:
            expressions = [expression, pos, length]
        super().__init__(*expressions, **extra)
273
274
class Trim(Transform):
    """Transform mapped to the SQL ``TRIM`` function (lookup name ``trim``)."""

    lookup_name = "trim"
    function = "TRIM"
278
279
class Upper(Transform):
    """Transform mapped to the SQL ``UPPER`` function (lookup name ``upper``)."""

    lookup_name = "upper"
    function = "UPPER"