1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain.models.expressions import Func, ResolvableExpression, Value
  6from plain.models.fields import CharField, IntegerField, TextField
  7from plain.models.functions import Cast, Coalesce
  8from plain.models.lookups import Transform
  9
 10if TYPE_CHECKING:
 11    from plain.models.postgres.wrapper import DatabaseWrapper
 12    from plain.models.sql.compiler import SQLCompiler
 13
 14
 15class SHAMixin(Transform):
 16    """Base class for SHA hashing using PostgreSQL's pgcrypto extension."""
 17
 18    def as_sql(
 19        self,
 20        compiler: SQLCompiler,
 21        connection: DatabaseWrapper,
 22        function: str | None = None,
 23        template: str | None = None,
 24        arg_joiner: str | None = None,
 25        **extra_context: Any,
 26    ) -> tuple[str, list[Any]]:
 27        assert self.function is not None
 28        return super().as_sql(
 29            compiler,
 30            connection,
 31            template="ENCODE(DIGEST(%(expressions)s, '%(function)s'), 'hex')",
 32            function=self.function.lower(),
 33            **extra_context,
 34        )
 35
 36
 37class Chr(Transform):
 38    function = "CHR"
 39    lookup_name = "chr"
 40
 41
 42class ConcatPair(Func):
 43    """Concatenate two arguments together."""
 44
 45    function = "CONCAT"
 46
 47    def as_sql(
 48        self,
 49        compiler: SQLCompiler,
 50        connection: DatabaseWrapper,
 51        function: str | None = None,
 52        template: str | None = None,
 53        arg_joiner: str | None = None,
 54        **extra_context: Any,
 55    ) -> tuple[str, list[Any]]:
 56        # PostgreSQL requires explicit cast to text for CONCAT.
 57        copy = self.copy()
 58        copy.set_source_expressions(
 59            [
 60                Cast(expression, TextField())
 61                for expression in copy.get_source_expressions()
 62            ]
 63        )
 64        return super(ConcatPair, copy).as_sql(
 65            compiler,
 66            connection,
 67            **extra_context,
 68        )
 69
 70    def coalesce(self) -> ConcatPair:
 71        # null on either side results in null for expression, wrap with coalesce
 72        c = self.copy()
 73        c.set_source_expressions(
 74            [
 75                Coalesce(expression, Value(""))
 76                for expression in c.get_source_expressions()
 77            ]
 78        )
 79        return c
 80
 81
 82class Concat(Func):
 83    """
 84    Concatenate text fields together. Wraps each argument in coalesce
 85    functions to ensure a non-null result.
 86    """
 87
 88    function = None
 89    template = "%(expressions)s"
 90
 91    def __init__(self, *expressions: Any, **extra: Any) -> None:
 92        if len(expressions) < 2:
 93            raise ValueError("Concat must take at least two expressions")
 94        paired = self._paired(expressions)
 95        super().__init__(paired, **extra)
 96
 97    def _paired(self, expressions: tuple[Any, ...]) -> ConcatPair:
 98        # wrap pairs of expressions in successive concat functions
 99        # exp = [a, b, c, d]
100        # -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d))))
101        if len(expressions) == 2:
102            return ConcatPair(*expressions)
103        return ConcatPair(expressions[0], self._paired(expressions[1:]))
104
105
class Left(Func):
    """Return the leftmost ``length`` characters of a string via SQL ``LEFT()``."""

    function = "LEFT"
    arity = 2
    output_field = CharField()

    def __init__(self, expression: Any, length: Any, **extra: Any) -> None:
        """
        expression: the name of a field, or an expression returning a string
        length: the number of characters to return from the start of the
            string; a literal (non-expression) value below 1 raises ValueError
        """
        is_literal = not isinstance(length, ResolvableExpression)
        if is_literal and length < 1:
            raise ValueError("'length' must be greater than 0.")
        super().__init__(expression, length, **extra)

    def get_substr(self) -> Substr:
        """Build the equivalent ``SUBSTRING(expression, 1, length)`` expression."""
        source, count = self.source_expressions[0], self.source_expressions[1]
        return Substr(source, Value(1), count)
123
124
class Length(Transform):
    """Return the number of characters in the expression.

    Compiles to SQL ``LENGTH()`` with an integer output; lookup name
    ``length``.
    """

    lookup_name = "length"
    function = "LENGTH"
    output_field = IntegerField()
131
132
class Lower(Transform):
    """Lower-case the expression with SQL ``LOWER()``; lookup name ``lower``."""

    lookup_name = "lower"
    function = "LOWER"
136
137
class LPad(Func):
    """Pad the start of a string to ``length`` with ``fill_text`` via ``LPAD()``."""

    function = "LPAD"
    output_field = CharField()

    def __init__(
        self, expression: Any, length: Any, fill_text: Any = Value(" "), **extra: Any
    ) -> None:
        """
        expression: the string expression to pad
        length: target length; a literal (non-expression, non-None) value
            below 0 raises ValueError
        fill_text: padding text, a single space by default
        """
        has_literal_length = (
            not isinstance(length, ResolvableExpression) and length is not None
        )
        if has_literal_length and length < 0:
            raise ValueError("'length' must be greater or equal to 0.")
        super().__init__(expression, length, fill_text, **extra)
152
153
class LTrim(Transform):
    """Strip leading characters with SQL ``LTRIM()``; lookup name ``ltrim``."""

    lookup_name = "ltrim"
    function = "LTRIM"
157
158
class MD5(Transform):
    """Hash the expression with SQL ``MD5()``; lookup name ``md5``."""

    lookup_name = "md5"
    function = "MD5"
162
163
class Ord(Transform):
    """Return the integer code of the first character via SQL ``ASCII()``.

    Exposed under the lookup name ``ord``.
    """

    lookup_name = "ord"
    function = "ASCII"
    output_field = IntegerField()
168
169
class Repeat(Func):
    """Repeat a string ``number`` times with SQL ``REPEAT()``."""

    function = "REPEAT"
    output_field = CharField()

    def __init__(self, expression: Any, number: Any, **extra: Any) -> None:
        """
        expression: the string expression to repeat
        number: repetition count; a literal (non-expression, non-None) value
            below 0 raises ValueError
        """
        has_literal_number = (
            not isinstance(number, ResolvableExpression) and number is not None
        )
        if has_literal_number and number < 0:
            raise ValueError("'number' must be greater or equal to 0.")
        super().__init__(expression, number, **extra)
182
183
class Replace(Func):
    """Replace occurrences of ``text`` with ``replacement`` via SQL ``REPLACE()``."""

    function = "REPLACE"

    def __init__(
        self, expression: Any, text: Any, replacement: Any = Value(""), **extra: Any
    ) -> None:
        """
        expression: the string expression to search
        text: the substring to look for
        replacement: what to substitute; an empty string by default, which
            removes every occurrence of ``text``
        """
        super().__init__(expression, text, replacement, **extra)
191
192
class Reverse(Transform):
    """Reverse the string with SQL ``REVERSE()``; lookup name ``reverse``."""

    lookup_name = "reverse"
    function = "REVERSE"
196
197
class Right(Left):
    """Return the rightmost ``length`` characters of a string via SQL ``RIGHT()``.

    Argument validation and ``output_field`` are inherited from ``Left``.
    """

    function = "RIGHT"

    def get_substr(self) -> Substr:
        """Build a ``SUBSTRING()`` expression equivalent to ``RIGHT()``.

        PostgreSQL's SUBSTRING does not count a negative start position from
        the end of the string: ``SUBSTRING('hello', -2)`` returns ``'hello'``,
        not ``'lo'``. The start position is therefore computed explicitly as
        ``LENGTH(expression) - length + 1``. A computed start below 1 yields
        the whole string, which matches ``RIGHT()`` when ``length`` exceeds
        the string's length.
        """
        expression, length = self.source_expressions[0], self.source_expressions[1]
        return Substr(expression, Length(expression) - length + Value(1))
205
206
class RPad(LPad):
    """Pad the end of a string via SQL ``RPAD()``; arguments as for ``LPad``."""

    function = "RPAD"
209
210
class RTrim(Transform):
    """Strip trailing characters with SQL ``RTRIM()``; lookup name ``rtrim``."""

    lookup_name = "rtrim"
    function = "RTRIM"
214
215
class SHA1(SHAMixin, Transform):
    """Hex-encoded SHA-1 digest via pgcrypto ``DIGEST``; lookup name ``sha1``."""

    lookup_name = "sha1"
    function = "SHA1"
219
220
class SHA224(SHAMixin, Transform):
    """Hex-encoded SHA-224 digest via pgcrypto ``DIGEST``; lookup name ``sha224``."""

    lookup_name = "sha224"
    function = "SHA224"
224
225
class SHA256(SHAMixin, Transform):
    """Hex-encoded SHA-256 digest via pgcrypto ``DIGEST``; lookup name ``sha256``."""

    lookup_name = "sha256"
    function = "SHA256"
229
230
class SHA384(SHAMixin, Transform):
    """Hex-encoded SHA-384 digest via pgcrypto ``DIGEST``; lookup name ``sha384``."""

    lookup_name = "sha384"
    function = "SHA384"
234
235
class SHA512(SHAMixin, Transform):
    """Hex-encoded SHA-512 digest via pgcrypto ``DIGEST``; lookup name ``sha512``."""

    lookup_name = "sha512"
    function = "SHA512"
239
240
class StrIndex(Func):
    """
    Locate a substring: return the 1-based position of the first occurrence
    of the second argument inside the first, or 0 when it does not occur.

    Takes exactly two arguments and produces an integer.
    """

    # PostgreSQL spells this STRPOS rather than INSTR.
    function = "STRPOS"
    output_field = IntegerField()
    arity = 2
252
253
class Substr(Func):
    """Extract part of a string with SQL ``SUBSTRING()``."""

    function = "SUBSTRING"
    output_field = CharField()

    def __init__(
        self, expression: Any, pos: Any, length: Any = None, **extra: Any
    ) -> None:
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer; a literal
            value below 1 raises ValueError
        length: an optional number of characters to return; when None, no
            length argument is emitted
        """
        if not isinstance(pos, ResolvableExpression) and pos < 1:
            raise ValueError("'pos' must be greater than 0")
        arguments = (expression, pos) if length is None else (expression, pos, length)
        super().__init__(*arguments, **extra)
273
274
class Trim(Transform):
    """Strip surrounding characters with SQL ``TRIM()``; lookup name ``trim``."""

    lookup_name = "trim"
    function = "TRIM"
278
279
class Upper(Transform):
    """Upper-case the expression with SQL ``UPPER()``; lookup name ``upper``."""

    lookup_name = "upper"
    function = "UPPER"