Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.expressions import Func, Value
  6from plain.models.fields import CharField, IntegerField, TextField
  7from plain.models.functions import Cast, Coalesce
  8from plain.models.lookups import Transform
  9
 10if TYPE_CHECKING:
 11    from plain.models.backends.base.base import BaseDatabaseWrapper
 12    from plain.models.sql.compiler import SQLCompiler
 13
 14
 15class MySQLSHA2Mixin:
 16    def as_mysql(
 17        self,
 18        compiler: SQLCompiler,
 19        connection: BaseDatabaseWrapper,
 20        **extra_context: Any,
 21    ) -> tuple[str, tuple[Any, ...]]:
 22        return super().as_sql(  # type: ignore[misc]
 23            compiler,
 24            connection,
 25            template=f"SHA2(%(expressions)s, {self.function[3:]})",
 26            **extra_context,
 27        )
 28
 29
 30class PostgreSQLSHAMixin:
 31    def as_postgresql(
 32        self,
 33        compiler: SQLCompiler,
 34        connection: BaseDatabaseWrapper,
 35        **extra_context: Any,
 36    ) -> tuple[str, tuple[Any, ...]]:
 37        return super().as_sql(  # type: ignore[misc]
 38            compiler,
 39            connection,
 40            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
 41            function=self.function.lower(),
 42            **extra_context,
 43        )
 44
 45
 46class Chr(Transform):
 47    function = "CHR"
 48    lookup_name = "chr"
 49
 50    def as_mysql(
 51        self,
 52        compiler: SQLCompiler,
 53        connection: BaseDatabaseWrapper,
 54        **extra_context: Any,
 55    ) -> tuple[str, tuple[Any, ...]]:
 56        return super().as_sql(
 57            compiler,
 58            connection,
 59            function="CHAR",
 60            template="%(function)s(%(expressions)s USING utf16)",
 61            **extra_context,
 62        )
 63
 64    def as_sqlite(
 65        self,
 66        compiler: SQLCompiler,
 67        connection: BaseDatabaseWrapper,
 68        **extra_context: Any,
 69    ) -> tuple[str, tuple[Any, ...]]:
 70        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
 71
 72
 73class ConcatPair(Func):
 74    """
 75    Concatenate two arguments together. This is used by `Concat` because not
 76    all backend databases support more than two arguments.
 77    """
 78
 79    function = "CONCAT"
 80
 81    def as_sqlite(
 82        self,
 83        compiler: SQLCompiler,
 84        connection: BaseDatabaseWrapper,
 85        **extra_context: Any,
 86    ) -> tuple[str, tuple[Any, ...]]:
 87        coalesced = self.coalesce()
 88        return super(ConcatPair, coalesced).as_sql(
 89            compiler,
 90            connection,
 91            template="%(expressions)s",
 92            arg_joiner=" || ",
 93            **extra_context,
 94        )
 95
 96    def as_postgresql(
 97        self,
 98        compiler: SQLCompiler,
 99        connection: BaseDatabaseWrapper,
100        **extra_context: Any,
101    ) -> tuple[str, tuple[Any, ...]]:
102        copy = self.copy()
103        copy.set_source_expressions(
104            [
105                Cast(expression, TextField())
106                for expression in copy.get_source_expressions()
107            ]
108        )
109        return super(ConcatPair, copy).as_sql(
110            compiler,
111            connection,
112            **extra_context,
113        )
114
115    def as_mysql(
116        self,
117        compiler: SQLCompiler,
118        connection: BaseDatabaseWrapper,
119        **extra_context: Any,
120    ) -> tuple[str, tuple[Any, ...]]:
121        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
122        return super().as_sql(
123            compiler,
124            connection,
125            function="CONCAT_WS",
126            template="%(function)s('', %(expressions)s)",
127            **extra_context,
128        )
129
130    def coalesce(self) -> ConcatPair:
131        # null on either side results in null for expression, wrap with coalesce
132        c = self.copy()
133        c.set_source_expressions(
134            [
135                Coalesce(expression, Value(""))
136                for expression in c.get_source_expressions()
137            ]
138        )
139        return c
140
141
class Concat(Func):
    """
    Concatenate text fields together.

    The arguments are folded into a right-nested chain of ``ConcatPair``
    expressions. On backends where any NULL argument would null out the whole
    expression, each argument is wrapped in a coalesce so the result is
    non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        """
        Fold ``expressions`` into right-nested ``ConcatPair`` objects.

        ``(a, b, c, d)`` becomes
        ``ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))``.
        """
        # Build the innermost pair from the last two expressions, then wrap
        # each earlier expression around it, walking backwards.
        nested = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            nested = ConcatPair(expression, nested)
        return nested
165
166
class Left(Func):
    """Return the first ``length`` characters of a string expression."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        Raises ValueError when a literal (non-expression) ``length`` is below 1.
        """
        is_literal = not hasattr(length, "resolve_expression")
        if is_literal and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Build the equivalent ``Substr(expression, 1, length)`` expression."""
        text = self.source_expressions[0]
        count = self.source_expressions[1]
        return Substr(text, Value(1), count)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # SQLite has no LEFT(); render the substring equivalent instead.
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)
192
193
class Length(Transform):
    """Return the number of characters in the expression."""

    function = "LENGTH"
    lookup_name = "length"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # MySQL's LENGTH() counts bytes; CHAR_LENGTH() counts characters.
        char_count_function = "CHAR_LENGTH"
        return super().as_sql(
            compiler,
            connection,
            function=char_count_function,
            **extra_context,
        )
210
211
class Lower(Transform):
    """Convert a string expression to lowercase with SQL ``LOWER()``."""

    lookup_name = "lower"
    function = "LOWER"
215
216
class LPad(Func):
    """
    Render SQL ``LPAD(expression, length, fill_text)``.

    ``fill_text`` defaults to a single space. A literal ``length`` must not be
    negative; ``None`` and expressions are passed through unchecked.
    """

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        is_literal_length = (
            length is not None and not hasattr(length, "resolve_expression")
        )
        if is_literal_length and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
231
232
class LTrim(Transform):
    """Strip leading whitespace from a string expression (SQL ``LTRIM()``)."""

    lookup_name = "ltrim"
    function = "LTRIM"
236
237
class MD5(Transform):
    """Hash a string expression with SQL ``MD5()``."""

    lookup_name = "md5"
    function = "MD5"
241
242
class Ord(Transform):
    """
    Return the numeric character code of a string expression.

    Rendered as ``ASCII()`` by default, ``ORD()`` on MySQL and ``UNICODE()``
    on SQLite.
    """

    function = "ASCII"
    lookup_name = "ord"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return super().as_sql(
            compiler,
            connection,
            function="ORD",
            **extra_context,
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        return super().as_sql(
            compiler,
            connection,
            function="UNICODE",
            **extra_context,
        )
263
264
class Repeat(Func):
    """
    Render SQL ``REPEAT(expression, number)``.

    A literal ``number`` must not be negative; ``None`` and expressions are
    passed through unchecked.
    """

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        is_literal_number = (
            number is not None and not hasattr(number, "resolve_expression")
        )
        if is_literal_number and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
277
278
class Replace(Func):
    """
    Render SQL ``REPLACE(expression, text, replacement)``.

    ``replacement`` defaults to the empty string, which removes ``text``.
    """

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
286
287
class Reverse(Transform):
    """Reverse a string expression (SQL ``REVERSE()``)."""

    lookup_name = "reverse"
    function = "REVERSE"
291
292
class Right(Left):
    """Return the last ``length`` characters of a string expression."""

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """Build ``Substr(expression, -length)`` for the inherited fallback."""
        # A negative start position counts back from the end of the string,
        # so no explicit substring length is needed.
        text = self.source_expressions[0]
        start = self.source_expressions[1] * Value(-1)
        return Substr(text, start)
300
301
class RPad(LPad):
    """Right-pad counterpart of ``LPad``, rendered as SQL ``RPAD()``."""

    function = "RPAD"
305
class RTrim(Transform):
    """Strip trailing whitespace from a string expression (SQL ``RTRIM()``)."""

    lookup_name = "rtrim"
    function = "RTRIM"
309
310
class SHA1(PostgreSQLSHAMixin, Transform):
    """Hash an expression with SHA-1 (``SHA1()``; ``DIGEST()`` on PostgreSQL)."""

    lookup_name = "sha1"
    function = "SHA1"
314
315
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash an expression with SHA-224 (``SHA2(..., 224)`` on MySQL)."""

    lookup_name = "sha224"
    function = "SHA224"
319
320
class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash an expression with SHA-256 (``SHA2(..., 256)`` on MySQL)."""

    lookup_name = "sha256"
    function = "SHA256"
324
325
class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash an expression with SHA-384 (``SHA2(..., 384)`` on MySQL)."""

    lookup_name = "sha384"
    function = "SHA384"
329
330
class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """Hash an expression with SHA-512 (``SHA2(..., 512)`` on MySQL)."""

    lookup_name = "sha512"
    function = "SHA512"
334
335
class StrIndex(Func):
    """
    Return the 1-indexed position of the first occurrence of a substring
    inside another string, or 0 when the substring does not occur.

    Rendered as ``INSTR(string, substring)`` by default and as ``STRPOS()`` on
    PostgreSQL.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        postgres_function = "STRPOS"
        return super().as_sql(
            compiler,
            connection,
            function=postgres_function,
            **extra_context,
        )
354
355
class Substr(Func):
    """Extract part of a string with SQL ``SUBSTRING()`` (``SUBSTR()`` on SQLite)."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raises ValueError when a literal (non-expression) ``pos`` is below 1.
        Expressions for ``pos`` are not checked.
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        # Omit the length argument entirely when it was not given.
        arguments = (
            (expression, pos) if length is None else (expression, pos, length)
        )
        super().__init__(*arguments, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        sqlite_function = "SUBSTR"
        return super().as_sql(
            compiler,
            connection,
            function=sqlite_function,
            **extra_context,
        )
383
384
class Trim(Transform):
    """Strip surrounding whitespace from a string expression (SQL ``TRIM()``)."""

    lookup_name = "trim"
    function = "TRIM"
388
389
class Upper(Transform):
    """Convert a string expression to uppercase with SQL ``UPPER()``."""

    lookup_name = "upper"
    function = "UPPER"