Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models.expressions import Func, Value
  2from plain.models.fields import CharField, IntegerField, TextField
  3from plain.models.functions import Cast, Coalesce
  4from plain.models.lookups import Transform
  5
  6
  7class MySQLSHA2Mixin:
  8    def as_mysql(self, compiler, connection, **extra_context):
  9        return super().as_sql(
 10            compiler,
 11            connection,
 12            template="SHA2(%%(expressions)s, %s)" % self.function[3:],
 13            **extra_context,
 14        )
 15
 16
 17class PostgreSQLSHAMixin:
 18    def as_postgresql(self, compiler, connection, **extra_context):
 19        return super().as_sql(
 20            compiler,
 21            connection,
 22            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
 23            function=self.function.lower(),
 24            **extra_context,
 25        )
 26
 27
 28class Chr(Transform):
 29    function = "CHR"
 30    lookup_name = "chr"
 31
 32    def as_mysql(self, compiler, connection, **extra_context):
 33        return super().as_sql(
 34            compiler,
 35            connection,
 36            function="CHAR",
 37            template="%(function)s(%(expressions)s USING utf16)",
 38            **extra_context,
 39        )
 40
 41    def as_sqlite(self, compiler, connection, **extra_context):
 42        return super().as_sql(compiler, connection, function="CHAR", **extra_context)
 43
 44
 45class ConcatPair(Func):
 46    """
 47    Concatenate two arguments together. This is used by `Concat` because not
 48    all backend databases support more than two arguments.
 49    """
 50
 51    function = "CONCAT"
 52
 53    def as_sqlite(self, compiler, connection, **extra_context):
 54        coalesced = self.coalesce()
 55        return super(ConcatPair, coalesced).as_sql(
 56            compiler,
 57            connection,
 58            template="%(expressions)s",
 59            arg_joiner=" || ",
 60            **extra_context,
 61        )
 62
 63    def as_postgresql(self, compiler, connection, **extra_context):
 64        copy = self.copy()
 65        copy.set_source_expressions(
 66            [
 67                Cast(expression, TextField())
 68                for expression in copy.get_source_expressions()
 69            ]
 70        )
 71        return super(ConcatPair, copy).as_sql(
 72            compiler,
 73            connection,
 74            **extra_context,
 75        )
 76
 77    def as_mysql(self, compiler, connection, **extra_context):
 78        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
 79        return super().as_sql(
 80            compiler,
 81            connection,
 82            function="CONCAT_WS",
 83            template="%(function)s('', %(expressions)s)",
 84            **extra_context,
 85        )
 86
 87    def coalesce(self):
 88        # null on either side results in null for expression, wrap with coalesce
 89        c = self.copy()
 90        c.set_source_expressions(
 91            [
 92                Coalesce(expression, Value(""))
 93                for expression in c.get_source_expressions()
 94            ]
 95        )
 96        return c
 97
 98
 99class Concat(Func):
100    """
101    Concatenate text fields together. Backends that result in an entire
102    null expression when any arguments are null will wrap each argument in
103    coalesce functions to ensure a non-null result.
104    """
105
106    function = None
107    template = "%(expressions)s"
108
109    def __init__(self, *expressions, **extra):
110        if len(expressions) < 2:
111            raise ValueError("Concat must take at least two expressions")
112        paired = self._paired(expressions)
113        super().__init__(paired, **extra)
114
115    def _paired(self, expressions):
116        # wrap pairs of expressions in successive concat functions
117        # exp = [a, b, c, d]
118        # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
119        if len(expressions) == 2:
120            return ConcatPair(*expressions)
121        return ConcatPair(expressions[0], self._paired(expressions[1:]))
122
123
class Left(Func):
    """Return the first ``length`` characters of a string expression."""

    output_field = CharField()
    arity = 2
    function = "LEFT"

    def __init__(self, expression, length, **extra):
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the string

        A literal (non-expression) ``length`` is validated here and must be at
        least 1; expression lengths are passed through unchecked.
        """
        is_expression = hasattr(length, "resolve_expression")
        if not is_expression and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self):
        # Equivalent SUBSTR call that starts at position 1.
        text, length = self.source_expressions[0], self.source_expressions[1]
        return Substr(text, Value(1), length)

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite is compiled through the Substr equivalent from get_substr().
        substr = self.get_substr()
        return substr.as_sqlite(compiler, connection, **extra_context)
144
145
class Length(Transform):
    """Count the characters in the expression."""

    output_field = IntegerField()
    lookup_name = "length"
    function = "LENGTH"

    def as_mysql(self, compiler, connection, **extra_context):
        # MySQL is rendered with CHAR_LENGTH so the result is a character count.
        return super().as_sql(
            compiler,
            connection,
            function="CHAR_LENGTH",
            **extra_context,
        )
157
158
class Lower(Transform):
    """Convert the expression to lower case with SQL ``LOWER``."""

    lookup_name = "lower"
    function = "LOWER"
162
163
class LPad(Func):
    """Render SQL ``LPAD(expression, length, fill_text)``; fill defaults to a space."""

    output_field = CharField()
    function = "LPAD"

    def __init__(self, expression, length, fill_text=Value(" "), **extra):
        # Only plain numbers are range-checked; None and expressions pass through.
        literal_length = length is not None and not hasattr(
            length, "resolve_expression"
        )
        if literal_length and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
176
177
class LTrim(Transform):
    """Apply SQL ``LTRIM`` (leading trim) to the expression."""

    lookup_name = "ltrim"
    function = "LTRIM"
181
182
class MD5(Transform):
    """Hash the expression with SQL ``MD5``."""

    lookup_name = "md5"
    function = "MD5"
186
187
class Ord(Transform):
    """
    Return the numeric character code of the expression.

    Rendered as ``ASCII`` by default, ``ORD`` on MySQL and ``UNICODE`` on
    SQLite.
    """

    output_field = IntegerField()
    lookup_name = "ord"
    function = "ASCII"

    def as_mysql(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="ORD", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sql(
            compiler, connection, function="UNICODE", **extra_context
        )
198
199
class Repeat(Func):
    """Render SQL ``REPEAT(expression, number)``."""

    output_field = CharField()
    function = "REPEAT"

    def __init__(self, expression, number, **extra):
        # Range-check plain numbers only; None and expressions are passed on.
        checkable = number is not None and not hasattr(
            number, "resolve_expression"
        )
        if checkable and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
212
213
class Replace(Func):
    """
    Render SQL ``REPLACE(expression, text, replacement)``.

    ``replacement`` defaults to an empty-string ``Value``.
    """

    function = "REPLACE"

    def __init__(self, expression, text, replacement=Value(""), **extra):
        arguments = (expression, text, replacement)
        super().__init__(*arguments, **extra)
219
220
class Reverse(Transform):
    """Reverse the expression's characters with SQL ``REVERSE``."""

    lookup_name = "reverse"
    function = "REVERSE"
224
225
class Right(Left):
    """
    Return the last ``length`` characters of a string expression.

    Construction-time validation and the SQLite code path are inherited from
    ``Left``; on SQLite the query is compiled from the ``Substr`` built by
    ``get_substr`` below.
    """

    function = "RIGHT"

    def get_substr(self):
        """
        Build the ``Substr`` equivalent of ``RIGHT(expression, length)``.

        The start position is ``-length`` so SUBSTR counts from the end of the
        string. ``length`` is also passed as the substring length. Without it,
        a length that evaluates to 0 at query time renders ``SUBSTR(x, 0)``,
        which SQLite treats as "the whole string" rather than the empty string
        ``RIGHT(x, 0)`` returns.
        """
        text = self.source_expressions[0]
        length = self.source_expressions[1]
        return Substr(text, length * Value(-1), length)
233
234
class RPad(LPad):
    """Right-padding counterpart of ``LPad``, with the same arguments and validation."""

    function = "RPAD"
237
238
class RTrim(Transform):
    """Apply SQL ``RTRIM`` (trailing trim) to the expression."""

    lookup_name = "rtrim"
    function = "RTRIM"
242
243
class SHA1(PostgreSQLSHAMixin, Transform):
    """
    SHA-1 hash of the expression.

    PostgreSQL uses the mixin's DIGEST rendering; other backends use the
    default Transform rendering of ``function``.
    """

    lookup_name = "sha1"
    function = "SHA1"
247
248
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-224 hash of the expression (MySQL/PostgreSQL via the SHA mixins)."""

    lookup_name = "sha224"
    function = "SHA224"
252
253
class SHA256(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-256 hash of the expression (MySQL/PostgreSQL via the SHA mixins)."""

    lookup_name = "sha256"
    function = "SHA256"
257
258
class SHA384(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-384 hash of the expression (MySQL/PostgreSQL via the SHA mixins)."""

    lookup_name = "sha384"
    function = "SHA384"
262
263
class SHA512(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
    """SHA-512 hash of the expression (MySQL/PostgreSQL via the SHA mixins)."""

    lookup_name = "sha512"
    function = "SHA512"
267
268
class StrIndex(Func):
    """
    Locate a substring inside another string.

    Evaluates to the 1-based position of the first occurrence of the
    substring, or 0 when the substring does not occur at all.
    """

    output_field = IntegerField()
    arity = 2
    function = "INSTR"

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL spells this function STRPOS.
        return super().as_sql(
            compiler, connection, function="STRPOS", **extra_context
        )
282
283
class Substr(Func):
    """Extract part of a string, starting at a 1-based position."""

    output_field = CharField()
    function = "SUBSTRING"

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return; when None the
            length argument is left out of the SQL call entirely
        """
        if not hasattr(pos, "resolve_expression") and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        arguments = (
            (expression, pos) if length is None else (expression, pos, length)
        )
        super().__init__(*arguments, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        # SQLite names this function SUBSTR.
        return super().as_sql(
            compiler, connection, function="SUBSTR", **extra_context
        )
304
305
class Trim(Transform):
    """Apply SQL ``TRIM`` (both ends) to the expression."""

    lookup_name = "trim"
    function = "TRIM"
309
310
class Upper(Transform):
    """Convert the expression to upper case with SQL ``UPPER``."""

    lookup_name = "upper"
    function = "UPPER"