Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import itertools
  4import math
  5from abc import ABC, abstractmethod
  6from collections.abc import Sequence
  7from functools import cached_property
  8from typing import TYPE_CHECKING, Any
  9
 10from plain.models.exceptions import EmptyResultSet, FullResultSet
 11from plain.models.expressions import Expression, Func, ResolvableExpression, Value
 12from plain.models.fields import (
 13    BooleanField,
 14    CharField,
 15    DateTimeField,
 16    Field,
 17    IntegerField,
 18    UUIDField,
 19)
 20from plain.models.query_utils import RegisterLookupMixin
 21from plain.utils.datastructures import OrderedSet
 22from plain.utils.hashable import make_hashable
 23
 24if TYPE_CHECKING:
 25    from plain.models.backends.base.base import BaseDatabaseWrapper
 26    from plain.models.sql.compiler import SQLCompiler
 27
 28
class Lookup(Expression):
    """
    Base class for a boolean ``<lhs> <op> <rhs>`` comparison in a query.

    Both operands are prepared in ``__init__``: ``rhs`` through
    ``get_prep_lookup()`` and ``lhs`` through ``get_prep_lhs()``. Subclasses
    provide ``lookup_name`` and the SQL rendering (``as_sql()``); this class
    supplies the shared helpers that compile each side.
    """

    # Registered name of the lookup (e.g. "exact"); None on abstract bases.
    lookup_name: str | None = None
    # When False, get_prep_lookup() returns rhs untouched.
    prepare_rhs: bool = True
    # Not read anywhere in this module; presumably consulted by the query
    # builder when rhs is None — verify against callers.
    can_use_none_as_rhs: bool = False
    lhs: Any
    rhs: Any

    def __init__(self, lhs: Any, rhs: Any):
        """Store and prepare both operands and collect bilateral transforms.

        Raises:
            NotImplementedError: if the lhs carries bilateral transforms and
                ``rhs`` is a nested ``Query``.
        """
        self.lhs, self.rhs = lhs, rhs
        # rhs is prepared before lhs is wrapped: get_prep_lookup() inspects
        # self.lhs (its output_field), so this ordering is observable.
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from plain.models.sql.query import Query  # avoid circular import

            # Note: checks the rhs argument as passed in, not the prepared
            # self.rhs.
            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        # Transform classes that process_rhs()/batch_process_rhs() also apply
        # to the rhs.
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value: Any) -> Any:
        """Wrap ``value`` in each bilateral transform, innermost first."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[list[str], list[Any]]:
        """Compile each item of an iterable rhs (``self.rhs`` when ``rhs`` is None).

        Returns a list of SQL fragments and a flat list of params. With
        bilateral transforms, every item is wrapped in ``Value`` (typed as the
        lhs output field), transformed, resolved and compiled. Otherwise the
        items go through ``get_db_prep_lookup()`` and each resulting param gets
        a ``"%s"`` placeholder.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls: list[str] = []
            sqls_params: list[Any] = []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls = ["%s"] * len(params)
            sqls_params = list(params)
        return sqls, sqls_params

    def get_source_expressions(self) -> list[Any]:
        """Return ``[lhs]``, plus ``rhs`` when rhs is an SQL expression."""
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """Inverse of get_source_expressions(): one item sets lhs, two set both.

        Any other length raises ValueError from the tuple unpack.
        """
        exprs_list = list(exprs)
        if len(exprs_list) == 1:
            self.lhs = exprs_list[0]
        else:
            self.lhs, self.rhs = exprs_list

    def get_prep_lookup(self) -> Any:
        """Return the prepared rhs.

        - ``prepare_rhs`` is False or rhs is an expression: rhs unchanged.
        - lhs has a (truthy) ``output_field``: its ``get_prep_value(rhs)`` when
          that method exists, else rhs unchanged.
        - lhs has no output field and rhs is a direct value: ``Value(rhs)``.
        """
        if not self.prepare_rhs or isinstance(self.rhs, ResolvableExpression):
            return self.rhs
        if output_field := getattr(self.lhs, "output_field", None):
            if get_prep_value := getattr(output_field, "get_prep_value", None):
                return get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self) -> Any:
        """Return lhs, wrapping a non-expression value in ``Value``."""
        if isinstance(self.lhs, ResolvableExpression):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Default DB preparation: a single placeholder with ``value`` as-is."""
        return ("%s", [value])

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        """Compile ``lhs`` (``self.lhs`` when not given or falsy) to SQL and params."""
        lhs = lhs or self.lhs
        if isinstance(lhs, ResolvableExpression):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, list(params)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the rhs.

        Expressions (including a direct value wrapped for bilateral transforms)
        are compiled and parenthesized; plain values go through
        ``get_db_prep_lookup()``.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = f"({sql})"
            return sql, list(params)
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self) -> bool:
        """True when rhs is a plain value rather than something with ``as_sql``."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self) -> list[Any]:
        """Collect GROUP BY columns from every source expression."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    @cached_property
    def output_field(self) -> BooleanField:
        # A lookup always evaluates to a boolean.
        return BooleanField()

    @property
    def identity(self) -> tuple[type[Lookup], Any, Any]:
        """Tuple used for equality and hashing: (class, lhs, rhs)."""
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self) -> int:
        # make_hashable() is applied so an unhashable rhs (e.g. a list) can
        # still participate in the hash.
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Lookup:
        """Return a copy with lhs resolved, and rhs too if it is an expression."""
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if isinstance(self.rhs, ResolvableExpression):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
203
204
class Transform(RegisterLookupMixin, Func):
    """
    Single-argument function that can be chained onto a field in a lookup
    path.

    RegisterLookupMixin is listed first in the bases so that get_lookup() and
    get_transform() look at the transform itself before falling back to its
    output_field.
    """

    # When True, the transform is also applied to the lookup's rhs.
    bilateral: bool = False
    arity: int = 1

    @property
    def lhs(self) -> Any:
        """The expression this transform wraps (its first source expression)."""
        sources = self.get_source_expressions()
        return sources[0]

    def get_bilateral_transforms(self) -> list[type[Transform]]:
        """Return the bilateral transform classes in the chain, innermost first."""
        inner = self.lhs
        collected = (
            inner.get_bilateral_transforms()
            if hasattr(inner, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            collected.append(self.__class__)
        return collected
226
227
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <connection.operators[lookup_name] % rhs>``.

    The compiled lhs is wrapped in the backend's field cast and then its
    lookup cast before being used.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        """Compile the lhs and apply the backend field/lookup casts to it."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        sql, params = super().process_lhs(compiler, connection, lhs)
        field = self.lhs.output_field
        internal_type = field.get_internal_type()
        db_type = field.db_type(connection=connection)
        ops = connection.ops
        # Field-level cast first, then the lookup-specific cast around it.
        sql = ops.field_cast_sql(db_type, internal_type) % sql
        sql = ops.lookup_cast(self.lookup_name, internal_type) % sql
        return sql, list(params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render ``<lhs sql> <operator applied to rhs sql>`` with all params."""
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        return f"{lhs_sql} {self.get_rhs_op(connection, rhs_sql)}", params

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """Interpolate ``rhs`` into the backend operator template for this lookup."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        template = connection.operators[self.lookup_name]
        return template % rhs
258
259
class FieldGetDbPrepValueMixin(Lookup):
    """
    Lookup mixin that runs rhs values through the lhs field's
    ``get_db_prep_value()`` (with ``prepared=True``).

    For relational fields, the ``target_field``'s converter is preferred when
    present. Set ``get_db_prep_lookup_value_is_iterable`` to convert each item
    of an iterable value instead of the value as a whole.
    """

    get_db_prep_lookup_value_is_iterable: bool = False
    lhs: Any
    rhs: Any

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Return ``("%s", converted)`` where ``converted`` is a list of params."""
        lhs_field = self.lhs.output_field
        target = getattr(lhs_field, "target_field", None)
        convert = (
            getattr(target, "get_db_prep_value", None) or lhs_field.get_db_prep_value
        )
        if self.get_db_prep_lookup_value_is_iterable:
            converted = [convert(item, connection, prepared=True) for item in value]
        else:
            converted = [convert(value, connection, prepared=True)]
        return "%s", converted
286
287
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.

    Items of the rhs iterable may be plain values or expressions; expressions
    are compiled and their SQL replaces the corresponding placeholder.
    """

    get_db_prep_lookup_value_is_iterable: bool = True
    prepare_rhs: bool

    def get_prep_lookup(self) -> Any:
        """Prepare each item of the rhs iterable.

        An expression rhs is returned unchanged. Otherwise a new list is
        returned where non-expression items went through the lhs field's
        ``get_prep_value()`` (when ``prepare_rhs`` is set and the lhs exposes
        one) and expression items are kept as-is.
        """
        if isinstance(self.rhs, ResolvableExpression):
            return self.rhs
        prepared_values = []
        for rhs_value in self.rhs:
            if isinstance(rhs_value, ResolvableExpression):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs:
                if output_field := getattr(self.lhs, "output_field", None):
                    if get_prep_value := getattr(output_field, "get_prep_value", None):
                        rhs_value = get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Batch-compile a direct iterable rhs; defer expressions to the base."""
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """Return ``(sql, [param])``, or the compiled SQL/params if ``param``
        is (or resolves to) something with ``as_sql``."""
        params: list[Any] = [param]
        if isinstance(param, ResolvableExpression):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            sql, compiled_params = compiler.compile(param)
            params = list(compiled_params)
        return sql, params

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[list[str], list[Any]]:
        """Compile the rhs items, replacing expression params with their SQL.

        Returns ``([], [])`` for an empty rhs.
        """
        sqls, params = super().batch_process_rhs(compiler, connection, rhs)
        if not sqls:
            # Nothing to compile. Without this guard the unpacking below would
            # receive zero sequences and raise an opaque ValueError.
            return [], []
        # The params list may contain expressions which compile to a
        # sql/param pair. Pair each placeholder with its param and replace
        # both with the result of compiling the param when applicable.
        resolved = [
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(sqls, params)
        ]
        resolved_sqls = [sql for sql, _ in resolved]
        resolved_params = list(
            itertools.chain.from_iterable(group for _, group in resolved)
        )
        return resolved_sqls, resolved_params
354
355
class PostgresOperatorLookup(Lookup):
    """
    Lookup rendered on PostgreSQL as ``<lhs> <postgres_operator> <rhs>``.

    Subclasses set ``postgres_operator`` to the operator text.
    """

    postgres_operator: str | None = None

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Join both compiled sides with the operator; params form a tuple."""
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, (*lhs_params, *rhs_params)
368
369
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """
    ``exact`` equality lookup.

    A ``Query`` rhs must be limited to one row; if it has no explicit select
    list it is made to select ``id``. A boolean rhs against a conditional lhs
    is rendered as the bare condition (or its negation) when the backend
    allows it.
    """

    lookup_name: str = "exact"

    def get_prep_lookup(self) -> Any:
        """Validate/normalize a subquery rhs, then prepare it as usual.

        Raises:
            ValueError: if the rhs ``Query`` is not limited to one result.
        """
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            if not self.rhs.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["id"])
        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Emit ``WHERE cond`` / ``WHERE NOT cond`` instead of ``cond = bool``
        when possible; otherwise fall back to the regular rendering."""
        bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not bare_condition:
            return super().as_sql(compiler, connection)
        lhs_sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            return lhs_sql, params
        return f"NOT {lhs_sql}", params
406
407
@Field.register_lookup
class IExact(BuiltinLookup):
    """
    ``iexact`` lookup.

    The rhs is not run through the field's ``get_prep_value()``; a single
    direct param is passed through the backend's
    ``prep_for_iexact_query()`` instead.
    """

    lookup_name: str = "iexact"
    prepare_rhs: bool = False

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the rhs, then adapt its first param for iexact matching."""
        rhs, params = super().process_rhs(compiler, connection)
        # Only a single-fragment (str) rhs with params has a value to adapt.
        if isinstance(rhs, str) and params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
423
424
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gt`` lookup: SQL comes from ``connection.operators["gt"]``; the rhs
    goes through the field's ``get_db_prep_value()``."""

    lookup_name: str = "gt"


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``gte`` lookup: SQL comes from ``connection.operators["gte"]``; the rhs
    goes through the field's ``get_db_prep_value()``."""

    lookup_name: str = "gte"


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lt`` lookup: SQL comes from ``connection.operators["lt"]``; the rhs
    goes through the field's ``get_db_prep_value()``."""

    lookup_name: str = "lt"


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """``lte`` lookup: SQL comes from ``connection.operators["lte"]``; the rhs
    goes through the field's ``get_db_prep_value()``."""

    lookup_name: str = "lte"
443
444
class IntegerFieldOverflow:
    """
    Lookup mixin that short-circuits when an int rhs lies outside the
    backend's range for the lhs field type.

    An rhs below the range raises ``underflow_exception`` and one above it
    raises ``overflow_exception``; both default to ``EmptyResultSet``.
    Subclasses override one of them (e.g. IntegerGreaterThan uses
    ``FullResultSet`` on underflow).
    """

    underflow_exception: type[Exception] = EmptyResultSet
    overflow_exception: type[Exception] = EmptyResultSet
    lhs: Any
    rhs: Any

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Range-check an int rhs, then defer to the next process_rhs()."""
        value = self.rhs
        if not isinstance(value, int):
            return super().process_rhs(compiler, connection)  # type: ignore[misc]
        internal_type = self.lhs.output_field.get_internal_type()
        low, high = connection.ops.integer_field_range(internal_type)
        # A None bound means the backend imposes no limit on that side.
        if low is not None and value < low:
            raise self.underflow_exception
        if high is not None and value > high:
            raise self.overflow_exception
        return super().process_rhs(compiler, connection)  # type: ignore[misc]
465
466
class IntegerFieldFloatRounding:
    """
    Lookup mixin letting a float rhs work against an IntegerField.

    The float is rounded up with ``math.ceil`` before normal preparation, so
    its fractional part is not simply discarded. Rounding up is correct for
    ``gte`` and ``lt`` on integers (``x >= 2.5`` is ``x >= 3``; ``x < 2.5`` is
    ``x < 3``).
    """

    rhs: Any

    def get_prep_lookup(self) -> Any:
        """Round a float rhs up to an int, then prepare it as usual."""
        value = self.rhs
        if isinstance(value, float):
            self.rhs = math.ceil(value)
        return super().get_prep_lookup()  # type: ignore[misc]
479
480
@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` for IntegerField: an int rhs outside the backend's range
    raises IntegerFieldOverflow's default ``EmptyResultSet`` on either side."""

    pass
484
485
@IntegerField.register_lookup
class IntegerGreaterThan(IntegerFieldOverflow, GreaterThan):
    """
    ``gt`` lookup for IntegerField.

    An int rhs below the backend's integer range raises ``FullResultSet``;
    above it, IntegerFieldOverflow's default ``EmptyResultSet`` applies.
    A float rhs is rounded down before preparation.
    """

    underflow_exception = FullResultSet

    def get_prep_lookup(self) -> Any:
        """Floor a float rhs, then prepare it as usual."""
        # On integers, ``x > 2.5`` is ``x > 2`` and ``x > -2.5`` is ``x > -3``,
        # i.e. the float must be floored. Otherwise the float reaches the
        # field's get_prep_value(); if that converts with int() (truncation
        # toward zero — assumed, not visible here), a negative rhs such as
        # -2.5 becomes -2 and wrongly excludes -2 from the results.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()
489
490
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """``gte`` for IntegerField: a float rhs is rounded up; an int rhs below
    the backend range raises ``FullResultSet``, above it ``EmptyResultSet``."""

    underflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """``lt`` for IntegerField: a float rhs is rounded up; an int rhs above
    the backend range raises ``FullResultSet``, below it ``EmptyResultSet``."""

    overflow_exception = FullResultSet
501
502
@IntegerField.register_lookup
class IntegerLessThanOrEqual(IntegerFieldOverflow, LessThanOrEqual):
    """
    ``lte`` lookup for IntegerField.

    An int rhs above the backend's integer range raises ``FullResultSet``;
    below it, IntegerFieldOverflow's default ``EmptyResultSet`` applies.
    A float rhs is rounded down before preparation.
    """

    overflow_exception = FullResultSet

    def get_prep_lookup(self) -> Any:
        """Floor a float rhs, then prepare it as usual."""
        # On integers, ``x <= 2.5`` is ``x <= 2`` and ``x <= -2.5`` is
        # ``x <= -3``, i.e. the float must be floored. Otherwise the float
        # reaches the field's get_prep_value(); if that converts with int()
        # (truncation toward zero — assumed, not visible here), a negative rhs
        # such as -2.5 becomes -2 and wrongly includes -2 in the results.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()
506
507
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    ``<lhs> IN (...)`` lookup.

    The rhs is either an iterable of values/expressions or a ``Query`` used
    as a subquery. When the backend reports a ``max_in_list_size`` and a
    direct rhs is longer than that, the clause is split into OR-ed ``IN``
    groups.
    """

    lookup_name: str = "in"

    def get_prep_lookup(self) -> Any:
        """Normalize a subquery rhs, then prepare values via the mixin.

        A ``Query`` rhs has its ordering cleared and, when it has no explicit
        select list, is made to select ``id``.
        """
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["id"])
        return super().get_prep_lookup()

    def _direct_values(self) -> Any:
        """Return the direct rhs values with ``None`` removed.

        When every item is hashable, duplicates are dropped too (via
        OrderedSet); otherwise a plain filtered list is returned.
        """
        # Remove None from the list as NULL is never equal to anything.
        try:
            values = OrderedSet(self.rhs)
            values.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            values = [r for r in self.rhs if r is not None]
        return values

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Render a direct rhs as ``(%s, %s, ...)``; defer subqueries to the base.

        Raises:
            EmptyResultSet: if no non-None values remain.
        """
        if self.rhs_is_direct_value():
            rhs = self._direct_values()

            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        return f"IN {rhs}"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Split overly long direct lists when the backend limits IN size."""
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render ``(lhs IN (...) OR lhs IN (...) ...)`` in size-limited groups.

        This is a special case for databases which limit the number of
        elements which can appear in an 'IN' clause.

        Raises:
            EmptyResultSet: if no non-None values remain.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        assert max_in_list_size is not None
        lhs, lhs_params = self.process_lhs(compiler, connection)
        # Same None filtering as process_rhs() so both paths agree.
        values = list(self._direct_values())
        if not values:
            raise EmptyResultSet
        in_clause_elements = ["("]
        params: list[Any] = []
        for offset in range(0, len(values), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(" OR ")
            # Compile each chunk of *values* on its own. Chunking the already
            # flattened params would misalign SQL fragments and params whenever
            # an item (e.g. an expression) compiles to zero or several params.
            sqls, sqls_params = self.batch_process_rhs(
                compiler, connection, values[offset : offset + max_in_list_size]
            )
            in_clause_elements.append(f"{lhs} IN (")
            params.extend(lhs_params)
            in_clause_elements.append(", ".join(sqls))
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
582
583
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups (contains / startswith / endswith).

    For a direct value, ``param_pattern`` adds the ``%`` wildcards in Python
    after the backend's ``prep_for_like_query()`` escaping. When the rhs is
    SQL (an expression, or bilateral transforms are in play), the wildcards
    are added in SQL via ``connection.pattern_ops`` instead.
    """

    param_pattern: str = "%%%s%%"
    prepare_rhs: bool = False
    bilateral_transforms: list[Any]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """Return the operator SQL, using pattern_ops when rhs is SQL.

        Example for startswith: a Python value yields ``col LIKE %s`` with
        the param ``'thevalue%'`` built in Python, whereas a column reference
        needs something like ``col LIKE othercol || '%%'`` so the wildcard is
        appended by the database.
        """
        rhs_is_sql = hasattr(self.rhs, "as_sql") or self.bilateral_transforms
        if not rhs_is_sql:
            return super().get_rhs_op(connection, rhs)
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        template = connection.pattern_ops[self.lookup_name].format(
            connection.pattern_esc
        )
        return template.format(rhs)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the rhs and, for a direct value, wrap it in wildcards."""
        rhs, params = super().process_rhs(compiler, connection)
        wrap_value = (
            isinstance(rhs, str)
            and self.rhs_is_direct_value()
            and params
            and not self.bilateral_transforms
        )
        if wrap_value:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return rhs, params
622
623
@Field.register_lookup
class Contains(PatternLookup):
    """``contains``: a direct value is wrapped as ``%value%`` (the inherited
    ``param_pattern``)."""

    lookup_name: str = "contains"


@Field.register_lookup
class IContains(Contains):
    """``icontains``: same wildcard pattern as Contains; the case-insensitive
    behavior presumably comes from the backend's entries for this
    lookup_name (operators / pattern_ops / lookup_cast) — not visible here."""

    lookup_name: str = "icontains"


@Field.register_lookup
class StartsWith(PatternLookup):
    """``startswith``: a direct value is wrapped as ``value%``."""

    lookup_name: str = "startswith"
    param_pattern: str = "%s%%"


@Field.register_lookup
class IStartsWith(StartsWith):
    """``istartswith``: same wildcard pattern as StartsWith; case handling is
    presumably backend-defined for this lookup_name."""

    lookup_name: str = "istartswith"


@Field.register_lookup
class EndsWith(PatternLookup):
    """``endswith``: a direct value is wrapped as ``%value``."""

    lookup_name: str = "endswith"
    param_pattern: str = "%%%s"


@Field.register_lookup
class IEndsWith(EndsWith):
    """``iendswith``: same wildcard pattern as EndsWith; case handling is
    presumably backend-defined for this lookup_name."""

    lookup_name: str = "iendswith"
654
655
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``<lhs> BETWEEN low AND high``; the rhs is expected to hold two items."""

    lookup_name: str = "range"

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """Build the BETWEEN clause from the first two compiled rhs fragments."""
        # Range lookup always receives a list of two elements from process_rhs
        assert isinstance(rhs, list), f"Range lookup expects list, got {type(rhs)}"
        low = rhs[0]
        high = rhs[1]
        return f"BETWEEN {low} AND {high}"
664
665
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``isnull``: ``True`` renders ``IS NULL``, ``False`` renders ``IS NOT NULL``."""

    lookup_name: str = "isnull"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render the NULL test.

        Raises:
            ValueError: if the rhs is not a bool.
        """
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        sql, params = self.process_lhs(compiler, connection)
        predicate = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{sql} {predicate}", params
683
684
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    ``regex`` lookup.

    Uses the backend operator when ``connection.operators`` has one for this
    lookup name, otherwise the SQL template from ``ops.regex_lookup()``.
    """

    lookup_name: str = "regex"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Render via the operator table or the backend's regex template."""
        if self.lookup_name in connection.operators:
            return super().as_sql(compiler, connection)
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        template = connection.ops.regex_lookup(self.lookup_name)
        return template % (lhs_sql, rhs_sql), lhs_params + rhs_params
700
701
@Field.register_lookup
class IRegex(Regex):
    """``iregex``: Regex rendering keyed on the ``iregex`` lookup name;
    case-insensitivity is presumably provided by the backend's operator or
    ``regex_lookup("iregex")`` template — not visible here."""

    lookup_name: str = "iregex"
705
706
class YearLookup(Lookup, ABC):
    """
    Base for comparisons whose lhs is a year-extract transform.

    For a direct rhs, the extract is skipped: the originating date/datetime
    expression (``self.lhs.lhs``) is compared against the year's bounds, so
    indexes on that column remain usable. Subclasses choose which bound(s)
    to pass via ``get_bound_params()``.
    """

    def year_lookup_bounds(
        self, connection: BaseDatabaseWrapper, year: int
    ) -> list[str | Any | None]:
        """Ask the backend for the (start, finish) bounds of ``year``.

        ISO-year bounds are requested when the lhs is ``ExtractIsoYear``;
        the datetime or date variant is chosen from the source field type.
        """
        from plain.models.functions import ExtractIsoYear

        is_iso = isinstance(self.lhs, ExtractIsoYear)
        source_field = self.lhs.lhs.output_field
        if isinstance(source_field, DateTimeField):
            bounds_for = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            bounds_for = connection.ops.year_lookup_bounds_for_date_field
        return bounds_for(year, iso_year=is_iso)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        """Compare the raw column to year bounds, or fall back to the extract."""
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # Compile the originating field instead of the extract expression.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        rhs_sql, _ = self.process_rhs(compiler, connection)
        # rhs_sql should be a string for year lookups
        assert isinstance(rhs_sql, str), f"Expected str, got {type(rhs_sql)}"
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        # The year itself is not a param; the chosen bound(s) replace it.
        params.extend(self.get_bound_params(start, finish))
        return f"{lhs_sql} {rhs_sql}", params

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Interpolate ``rhs`` into the backend operator for this lookup."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        template = connection.operators[self.lookup_name]
        return template % rhs

    @abstractmethod
    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, ...]: ...
753
754
class YearExact(YearLookup, Exact):
    """Direct year equality rendered as ``BETWEEN start AND finish``."""

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # The compiled rhs is ignored; both bounds are supplied as params.
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, Any]:
        return (start, finish)


class YearGt(YearLookup, GreaterThan):
    """``year > N`` compared as ``column > finish`` (end bound of year N)."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    """``year >= N`` compared as ``column >= start`` (start bound of year N)."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLt(YearLookup, LessThan):
    """``year < N`` compared as ``column < start`` (start bound of year N)."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    """``year <= N`` compared as ``column <= finish`` (end bound of year N)."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)
781
782
class UUIDTextMixin(Lookup):
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.
    """

    rhs: Any

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """On backends without a native UUID field, wrap rhs in
        ``Replace(rhs, "-", "")`` (typed CharField), then defer to the next
        process_rhs() in the MRO.

        NOTE(review): this reassigns ``self.rhs`` instead of using a local.
        Later steps depend on that: BuiltinLookup.as_sql() calls get_rhs_op()
        after process_rhs(), and PatternLookup.get_rhs_op() checks
        ``hasattr(self.rhs, "as_sql")`` to switch to SQL-side wildcards. But
        every further call wraps another Replace() around the previous one.
        Confirm a lookup instance is not compiled more than once, or that the
        nesting is acceptable.
        """
        if not connection.features.has_native_uuid_field:
            from plain.models.functions import Replace

            if self.rhs_is_direct_value():
                self.rhs = Value(self.rhs)
            self.rhs = Replace(
                self.rhs, Value("-"), Value(""), output_field=CharField()
            )
        return super().process_rhs(compiler, connection)
803
804
# Text-style lookups for UUIDField. UUIDTextMixin comes first in each MRO so
# its process_rhs() (hyphen stripping on non-native-UUID backends) runs before
# the underlying lookup's own process_rhs().


@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` on UUIDField with hyphen stripping."""

    pass


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` on UUIDField with hyphen stripping."""

    pass