Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.expressions import Func, ResolvableExpression, Value
  6from plain.models.fields import CharField, IntegerField, TextField
  7from plain.models.functions import Cast, Coalesce
  8from plain.models.lookups import Transform
  9
 10if TYPE_CHECKING:
 11    from plain.models.backends.base.base import BaseDatabaseWrapper
 12    from plain.models.sql.compiler import SQLCompiler
 13
 14
 15class MySQLSHA2Mixin(Transform):
 16    """Mixin for Transform subclasses that implement SHA2 hashing on MySQL."""
 17
 18    def as_mysql(
 19        self,
 20        compiler: SQLCompiler,
 21        connection: BaseDatabaseWrapper,
 22        **extra_context: Any,
 23    ) -> tuple[str, list[Any]]:
 24        assert self.function is not None
 25        return super().as_sql(
 26            compiler,
 27            connection,
 28            template=f"SHA2(%(expressions)s, {self.function[3:]})",
 29            **extra_context,
 30        )
 31
 32
 33class PostgreSQLSHAMixin(Transform):
 34    """Mixin for Transform subclasses that implement SHA hashing on PostgreSQL."""
 35
 36    def as_postgresql(
 37        self,
 38        compiler: SQLCompiler,
 39        connection: BaseDatabaseWrapper,
 40        **extra_context: Any,
 41    ) -> tuple[str, list[Any]]:
 42        assert self.function is not None
 43        return super().as_sql(
 44            compiler,
 45            connection,
 46            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
 47            function=self.function.lower(),
 48            **extra_context,
 49        )
 50
 51
 52class Chr(Transform):
 53    function = "CHR"
 54    lookup_name = "chr"
 55
 56    def as_mysql(
 57        self,
 58        compiler: SQLCompiler,
 59        connection: BaseDatabaseWrapper,
 60        **extra_context: Any,
 61    ) -> tuple[str, list[Any]]:
 62        return super().as_sql(
 63            compiler,
 64            connection,
 65            function="CHAR",
 66            template="%(function)s(%(expressions)s USING utf16)",
 67            **extra_context,
 68        )
 69
 70    def as_sqlite(
 71        self,
 72        compiler: SQLCompiler,
 73        connection: BaseDatabaseWrapper,
 74        **extra_context: Any,
 75    ) -> tuple[str, list[Any]]:
 76        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
 77
 78
 79class ConcatPair(Func):
 80    """
 81    Concatenate two arguments together. This is used by `Concat` because not
 82    all backend databases support more than two arguments.
 83    """
 84
 85    function = "CONCAT"
 86
 87    def as_sqlite(
 88        self,
 89        compiler: SQLCompiler,
 90        connection: BaseDatabaseWrapper,
 91        **extra_context: Any,
 92    ) -> tuple[str, list[Any]]:
 93        coalesced = self.coalesce()
 94        return super(ConcatPair, coalesced).as_sql(
 95            compiler,
 96            connection,
 97            template="%(expressions)s",
 98            arg_joiner=" || ",
 99            **extra_context,
100        )
101
102    def as_postgresql(
103        self,
104        compiler: SQLCompiler,
105        connection: BaseDatabaseWrapper,
106        **extra_context: Any,
107    ) -> tuple[str, list[Any]]:
108        copy = self.copy()
109        copy.set_source_expressions(
110            [
111                Cast(expression, TextField())
112                for expression in copy.get_source_expressions()
113            ]
114        )
115        return super(ConcatPair, copy).as_sql(
116            compiler,
117            connection,
118            **extra_context,
119        )
120
121    def as_mysql(
122        self,
123        compiler: SQLCompiler,
124        connection: BaseDatabaseWrapper,
125        **extra_context: Any,
126    ) -> tuple[str, list[Any]]:
127        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
128        return super().as_sql(
129            compiler,
130            connection,
131            function="CONCAT_WS",
132            template="%(function)s('', %(expressions)s)",
133            **extra_context,
134        )
135
136    def coalesce(self) -> ConcatPair:
137        # null on either side results in null for expression, wrap with coalesce
138        c = self.copy()
139        c.set_source_expressions(
140            [
141                Coalesce(expression, Value(""))
142                for expression in c.get_source_expressions()
143            ]
144        )
145        return c
146
147
class Concat(Func):
    """
    Concatenate two or more text expressions.

    The arguments are folded into a right-nested chain of ``ConcatPair``
    nodes. On backends where a NULL argument would make the whole expression
    NULL, those nodes wrap each argument in COALESCE so the result stays
    non-null.
    """

    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        if len(expressions) < 2:
            raise ValueError("Concat must take at least two expressions")
        super().__init__(self._paired(expressions), **extra)

    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
        # Build the chain from the right:
        # [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        chain = ConcatPair(expressions[-2], expressions[-1])
        for expression in reversed(expressions[:-2]):
            chain = ConcatPair(expression, chain)
        return chain
171
172
class Left(Func):
    """Return the first ``length`` characters of a string expression."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        Raises ValueError if a literal (non-expression) ``length`` is below 1.
        """
        is_literal = not isinstance(length, ResolvableExpression)
        if is_literal and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        # Equivalent of SUBSTRING(expression, 1, length); used by as_sqlite.
        source = self.source_expressions[0]
        length = self.source_expressions[1]
        return Substr(source, Value(1), length)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)
198
199
class Length(Transform):
    """Return the number of characters in the expression."""

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # MySQL's LENGTH() counts bytes, whereas CHAR_LENGTH() counts characters.
        mysql_function = "CHAR_LENGTH"
        return super().as_sql(
            compiler, connection, function=mysql_function, **extra_context
        )
216
217
class Lower(Transform):
    """Render the expression through SQL ``LOWER()`` (lookup: ``lower``)."""

    lookup_name = "lower"
    function = "LOWER"
221
222
class LPad(Func):
    """
    Render ``LPAD(expression, length, fill_text)``.

    ``fill_text`` defaults to a single space. A literal ``length`` must be 0
    or greater. ``None`` and expression lengths are passed through unchecked.
    """

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        is_literal = (
            length is not None and not isinstance(length, ResolvableExpression)
        )
        if is_literal and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
237
238
class LTrim(Transform):
    """Render the expression through SQL ``LTRIM()`` (lookup: ``ltrim``)."""

    lookup_name = "ltrim"
    function = "LTRIM"
242
243
class MD5(Transform):
    """Render the expression through SQL ``MD5()`` (lookup: ``md5``)."""

    lookup_name = "md5"
    function = "MD5"
247
248
class Ord(Transform):
    """
    Return the integer code of the expression's first character.

    The function is ASCII() by default, ORD() on MySQL, and UNICODE() on SQLite.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        mysql_function = "ORD"
        return super().as_sql(
            compiler, connection, function=mysql_function, **extra_context
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        sqlite_function = "UNICODE"
        return super().as_sql(
            compiler, connection, function=sqlite_function, **extra_context
        )
269
270
class Repeat(Func):
    """
    Render ``REPEAT(expression, number)``.

    A literal ``number`` must be 0 or greater. ``None`` and expression values
    are passed through unchecked.
    """

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        is_literal = (
            number is not None and not isinstance(number, ResolvableExpression)
        )
        if is_literal and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
283
284
class Replace(Func):
    """
    Render ``REPLACE(expression, text, replacement)``.

    ``replacement`` defaults to an empty string, so by default occurrences of
    ``text`` are removed.
    """

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
292
293
class Reverse(Transform):
    """Render the expression through SQL ``REVERSE()`` (lookup: ``reverse``)."""

    lookup_name = "reverse"
    function = "REVERSE"
297
298
class Right(Left):
    """Return the last ``length`` characters of a string expression."""

    function = "RIGHT"

    def get_substr(self) -> Substr:
        # Used by the SQLite rendering inherited from Left. A negative start
        # position makes SQLite's SUBSTR count from the end of the string, so
        # SUBSTR(expression, -length) yields the trailing characters.
        source = self.source_expressions[0]
        negative_length = self.source_expressions[1] * Value(-1)
        return Substr(source, negative_length)
306
307
class RPad(LPad):
    """Render ``RPAD(...)``; arguments and validation are inherited from ``LPad``."""

    function = "RPAD"
310
311
class RTrim(Transform):
    """Render the expression through SQL ``RTRIM()`` (lookup: ``rtrim``)."""

    lookup_name = "rtrim"
    function = "RTRIM"
315
316
class SHA1(PostgreSQLSHAMixin, Transform):
    """SHA-1 hash of the expression (lookup: ``sha1``)."""

    lookup_name = "sha1"
    function = "SHA1"


class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash of the expression (lookup: ``sha224``)."""

    lookup_name = "sha224"
    function = "SHA224"


class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash of the expression (lookup: ``sha256``)."""

    lookup_name = "sha256"
    function = "SHA256"


class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash of the expression (lookup: ``sha384``)."""

    lookup_name = "sha384"
    function = "SHA384"


class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash of the expression (lookup: ``sha512``)."""

    lookup_name = "sha512"
    function = "SHA512"
340
341
class StrIndex(Func):
    """
    Return the 1-indexed position of the first occurrence of a substring
    inside another string, or 0 when the substring does not occur.

    The function is INSTR() by default and STRPOS() on PostgreSQL.
    """

    function = "INSTR"
    arity = 2
    output_field = IntegerField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        postgres_function = "STRPOS"
        return super().as_sql(
            compiler, connection, function=postgres_function, **extra_context
        )
360
361
class Substr(Func):
    """Return a substring of a string expression, starting at a 1-based position."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raises ValueError if a literal (non-expression) ``pos`` is below 1.
        ``length`` is omitted from the SQL call when it is None.
        """
        is_literal = not isinstance(pos, ResolvableExpression)
        if is_literal and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        arguments = (expression, pos) if length is None else (expression, pos, length)
        super().__init__(*arguments, **extra)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        sqlite_function = "SUBSTR"
        return super().as_sql(
            compiler, connection, function=sqlite_function, **extra_context
        )
389
390
class Trim(Transform):
    """Render the expression through SQL ``TRIM()`` (lookup: ``trim``)."""

    lookup_name = "trim"
    function = "TRIM"
394
395
class Upper(Transform):
    """Render the expression through SQL ``UPPER()`` (lookup: ``upper``)."""

    lookup_name = "upper"
    function = "UPPER"