Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import itertools
  4import math
  5from functools import cached_property
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.models.exceptions import EmptyResultSet, FullResultSet
  9from plain.models.expressions import Expression, Func, Value
 10from plain.models.fields import (
 11    BooleanField,
 12    CharField,
 13    DateTimeField,
 14    Field,
 15    IntegerField,
 16    UUIDField,
 17)
 18from plain.models.query_utils import RegisterLookupMixin
 19from plain.utils.datastructures import OrderedSet
 20from plain.utils.hashable import make_hashable
 21
 22if TYPE_CHECKING:
 23    from plain.models.backends.base.base import BaseDatabaseWrapper
 24    from plain.models.sql.compiler import SQLCompiler
 25
 26
 27class Lookup(Expression):
 28    lookup_name: str | None = None
 29    prepare_rhs: bool = True
 30    can_use_none_as_rhs: bool = False
 31
 32    def __init__(self, lhs: Any, rhs: Any):
 33        self.lhs, self.rhs = lhs, rhs
 34        self.rhs = self.get_prep_lookup()
 35        self.lhs = self.get_prep_lhs()
 36        if hasattr(self.lhs, "get_bilateral_transforms"):
 37            bilateral_transforms = self.lhs.get_bilateral_transforms()  # type: ignore[attr-defined]
 38        else:
 39            bilateral_transforms = []
 40        if bilateral_transforms:
 41            # Warn the user as soon as possible if they are trying to apply
 42            # a bilateral transformation on a nested QuerySet: that won't work.
 43            from plain.models.sql.query import Query  # avoid circular import
 44
 45            if isinstance(rhs, Query):
 46                raise NotImplementedError(
 47                    "Bilateral transformations on nested querysets are not implemented."
 48                )
 49        self.bilateral_transforms = bilateral_transforms
 50
 51    def apply_bilateral_transforms(self, value: Any) -> Any:
 52        for transform in self.bilateral_transforms:
 53            value = transform(value)
 54        return value
 55
 56    def __repr__(self) -> str:
 57        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
 58
 59    def batch_process_rhs(
 60        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
 61    ) -> tuple[list[str], list[Any]]:
 62        if rhs is None:
 63            rhs = self.rhs
 64        if self.bilateral_transforms:
 65            sqls, sqls_params = [], []
 66            for p in rhs:
 67                value = Value(p, output_field=self.lhs.output_field)  # type: ignore[attr-defined]
 68                value = self.apply_bilateral_transforms(value)
 69                value = value.resolve_expression(compiler.query)  # type: ignore[attr-defined]
 70                sql, sql_params = compiler.compile(value)
 71                sqls.append(sql)
 72                sqls_params.extend(sql_params)
 73        else:
 74            _, params = self.get_db_prep_lookup(rhs, connection)
 75            sqls, sqls_params = ["%s"] * len(params), params
 76        return sqls, sqls_params
 77
 78    def get_source_expressions(self) -> list[Any]:
 79        if self.rhs_is_direct_value():
 80            return [self.lhs]
 81        return [self.lhs, self.rhs]
 82
 83    def set_source_expressions(self, new_exprs: list[Any]) -> None:
 84        if len(new_exprs) == 1:
 85            self.lhs = new_exprs[0]
 86        else:
 87            self.lhs, self.rhs = new_exprs
 88
 89    def get_prep_lookup(self) -> Any:
 90        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
 91            return self.rhs
 92        if hasattr(self.lhs, "output_field"):
 93            if hasattr(self.lhs.output_field, "get_prep_value"):  # type: ignore[attr-defined]
 94                return self.lhs.output_field.get_prep_value(self.rhs)  # type: ignore[attr-defined]
 95        elif self.rhs_is_direct_value():
 96            return Value(self.rhs)
 97        return self.rhs
 98
 99    def get_prep_lhs(self) -> Any:
100        if hasattr(self.lhs, "resolve_expression"):
101            return self.lhs
102        return Value(self.lhs)
103
104    def get_db_prep_lookup(
105        self, value: Any, connection: BaseDatabaseWrapper
106    ) -> tuple[str, list[Any]]:
107        return ("%s", [value])
108
109    def process_lhs(
110        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
111    ) -> tuple[str, list[Any]]:
112        lhs = lhs or self.lhs
113        if hasattr(lhs, "resolve_expression"):
114            lhs = lhs.resolve_expression(compiler.query)  # type: ignore[attr-defined]
115        sql, params = compiler.compile(lhs)
116        if isinstance(lhs, Lookup):
117            # Wrapped in parentheses to respect operator precedence.
118            sql = f"({sql})"
119        return sql, list(params)
120
121    def process_rhs(
122        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
123    ) -> tuple[str, list[Any]]:
124        value = self.rhs
125        if self.bilateral_transforms:
126            if self.rhs_is_direct_value():
127                # Do not call get_db_prep_lookup here as the value will be
128                # transformed before being used for lookup
129                value = Value(value, output_field=self.lhs.output_field)  # type: ignore[attr-defined]
130            value = self.apply_bilateral_transforms(value)
131            value = value.resolve_expression(compiler.query)  # type: ignore[attr-defined]
132        if hasattr(value, "as_sql"):
133            sql, params = compiler.compile(value)
134            # Ensure expression is wrapped in parentheses to respect operator
135            # precedence but avoid double wrapping as it can be misinterpreted
136            # on some backends (e.g. subqueries on SQLite).
137            if sql and sql[0] != "(":
138                sql = f"({sql})"
139            return sql, list(params)
140        else:
141            return self.get_db_prep_lookup(value, connection)
142
143    def rhs_is_direct_value(self) -> bool:
144        return not hasattr(self.rhs, "as_sql")
145
146    def get_group_by_cols(self) -> list[Any]:
147        cols = []
148        for source in self.get_source_expressions():
149            cols.extend(source.get_group_by_cols())  # type: ignore[attr-defined]
150        return cols
151
152    @cached_property
153    def output_field(self) -> BooleanField:
154        return BooleanField()
155
156    @property
157    def identity(self) -> tuple[type[Lookup], Any, Any]:
158        return self.__class__, self.lhs, self.rhs
159
160    def __eq__(self, other: object) -> bool:
161        if not isinstance(other, Lookup):
162            return NotImplemented
163        return self.identity == other.identity
164
165    def __hash__(self) -> int:
166        return hash(make_hashable(self.identity))
167
168    def resolve_expression(
169        self,
170        query: Any = None,
171        allow_joins: bool = True,
172        reuse: Any = None,
173        summarize: bool = False,
174        for_save: bool = False,
175    ) -> Lookup:
176        c = self.copy()  # type: ignore[attr-defined]
177        c.is_summary = summarize
178        c.lhs = self.lhs.resolve_expression(  # type: ignore[attr-defined]
179            query, allow_joins, reuse, summarize, for_save
180        )
181        if hasattr(self.rhs, "resolve_expression"):
182            c.rhs = self.rhs.resolve_expression(  # type: ignore[attr-defined]
183                query, allow_joins, reuse, summarize, for_save
184            )
185        return c
186
187    def select_format(
188        self, compiler: SQLCompiler, sql: str, params: list[Any]
189    ) -> tuple[str, list[Any]]:
190        # Wrap filters with a CASE WHEN expression if a database backend
191        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
192        # BY list.
193        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
194            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
195        return sql, params
196
197
class Transform(RegisterLookupMixin, Func):
    """
    A single-argument Func that can itself serve as the lhs of further
    lookups and transforms.

    RegisterLookupMixin() is listed first so that get_lookup() and
    get_transform() look at this class before falling back to output_field.
    """

    # When True, the transform is also applied to the rhs of a lookup.
    bilateral: bool = False
    arity: int = 1

    @property
    def lhs(self) -> Any:
        """The one source expression this transform wraps."""
        sources = self.get_source_expressions()
        return sources[0]

    def get_bilateral_transforms(self) -> list[type[Transform]]:
        """
        Return the bilateral transform classes along the lhs chain.

        Classes from deeper in the chain come first. This class is appended
        last when it is bilateral.
        """
        inner = self.lhs
        collected = (
            inner.get_bilateral_transforms()  # type: ignore[attr-defined]
            if hasattr(inner, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            collected.append(self.__class__)
        return collected
219
220
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>``.

    The operator comes from the backend's ``connection.operators`` table.
    The lhs is wrapped in the backend's field and lookup casts.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        sql, lhs_params = super().process_lhs(compiler, connection, lhs)  # type: ignore[misc]
        output_field = self.lhs.output_field  # type: ignore[attr-defined]
        internal_type = output_field.get_internal_type()
        db_type = output_field.db_type(connection=connection)
        # The field cast is applied first; the lookup-specific cast wraps it.
        sql = connection.ops.field_cast_sql(db_type, internal_type) % sql
        sql = connection.ops.lookup_cast(self.lookup_name, internal_type) % sql  # type: ignore[arg-type]
        return sql, list(lhs_params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        operator_sql = self.get_rhs_op(connection, rhs_sql)
        return f"{lhs_sql} {operator_sql}", params

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Format ``rhs`` into this lookup's entry in ``connection.operators``."""
        template = connection.operators[self.lookup_name]  # type: ignore[index]
        return template % rhs
245
246
class FieldGetDbPrepValueMixin:
    """
    Run lookup values through the lhs field's get_db_prep_value().

    For relational fields, the ``target_field`` of the output field is
    preferred when it provides get_db_prep_value().
    """

    # When True, ``value`` is an iterable and each item is prepared separately.
    get_db_prep_lookup_value_is_iterable: bool = False
    lhs: Any
    rhs: Any

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        output_field = self.lhs.output_field  # type: ignore[attr-defined]
        target = getattr(output_field, "target_field", None)
        prep = getattr(target, "get_db_prep_value", None) or output_field.get_db_prep_value
        if self.get_db_prep_lookup_value_is_iterable:
            prepared = [prep(item, connection, prepared=True) for item in value]
        else:
            prepared = [prep(value, connection, prepared=True)]
        return "%s", prepared
273
274
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Run get_db_prep_value() on each value of an iterable rhs.

    Items of the iterable may be expressions. These are compiled to SQL
    alongside the plain values.
    """

    get_db_prep_lookup_value_is_iterable: bool = True
    prepare_rhs: bool

    def get_prep_lookup(self) -> Any:
        """
        Prepare each plain item of the rhs with the lhs field's get_prep_value().

        This only happens when ``prepare_rhs`` is set. Expression items, and
        an rhs that is itself an expression, are left untouched.
        """
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        prepared_values = []
        for rhs_value in self.rhs:
            if hasattr(rhs_value, "resolve_expression"):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):  # type: ignore[attr-defined]
                rhs_value = self.lhs.output_field.get_prep_value(rhs_value)  # type: ignore[attr-defined]
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if self.rhs_is_direct_value():  # type: ignore[attr-defined]
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)  # type: ignore[attr-defined]
        else:
            return super().process_rhs(compiler, connection)  # type: ignore[misc]

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, params)`` for one rhs item.

        If ``param`` is an expression, it is resolved and compiled, and its
        own SQL and params replace the placeholder. Otherwise ``(sql, [param])``
        is returned.
        """
        params: list[Any] = [param]
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)  # type: ignore[attr-defined]
        if hasattr(param, "as_sql"):
            sql, compiled_params = compiler.compile(param)
            params = list(compiled_params)
        return sql, params

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[tuple[str, ...], tuple[Any, ...]]:
        """
        Return ``(sqls, params)`` for an iterable rhs.

        Expression items are compiled in place. An empty batch yields
        ``((), ())``.
        """
        sqls, params = super().batch_process_rhs(compiler, connection, rhs)  # type: ignore[misc]
        # The params list may contain expressions which compile to a
        # sql/param pair. Pair each placeholder with its param and replace
        # them with the result of compiling the param where applicable.
        pairs = [
            self.resolve_expression_parameter(compiler, connection, sql, param)
            for sql, param in zip(sqls, params)
        ]
        if not pairs:
            # zip(*[]) produces nothing to unpack; report an empty batch
            # explicitly instead of raising an opaque ValueError.
            return (), ()
        sql_parts, param_groups = zip(*pairs)
        return sql_parts, tuple(itertools.chain.from_iterable(param_groups))
339
340
class PostgresOperatorLookup(Lookup):
    """Lookup rendered on PostgreSQL as ``<lhs> <postgres_operator> <rhs>``."""

    # SQL operator placed between the compiled lhs and rhs; set by subclasses.
    postgres_operator: str | None = None

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, (*lhs_params, *rhs_params)
353
354
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """
    Equality lookup, registered as ``exact``.

    A Query rhs must be limited to one row. If it selects no explicit
    fields, it is made to select ``id``.
    """

    lookup_name: str = "exact"

    def get_prep_lookup(self) -> Any:
        from plain.models.sql.query import Query  # avoid circular import

        subquery = self.rhs
        if isinstance(subquery, Query):
            if not subquery.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not subquery.has_select_fields:
                subquery.clear_select_clause()
                subquery.add_fields(["id"])
        return super().get_prep_lookup()  # type: ignore[misc]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        # A conditional lhs compared to a literal bool becomes "WHERE expr" /
        # "WHERE NOT expr" instead of "WHERE expr = True", when the backend
        # allows it.
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        lhs_sql, params = self.process_lhs(compiler, connection)
        template = "%s" if self.rhs else "NOT %s"
        return template % lhs_sql, params
391
392
@Field.register_lookup
class IExact(BuiltinLookup):
    """
    Lookup registered as ``iexact``.

    The first rhs parameter is adapted by the backend's
    prep_for_iexact_query().
    """

    lookup_name: str = "iexact"
    prepare_rhs: bool = False

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        if len(params) > 0:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return sql, params
405
406
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Registered as ``gt``; operator SQL comes from ``connection.operators``."""

    lookup_name: str = "gt"


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Registered as ``gte``; operator SQL comes from ``connection.operators``."""

    lookup_name: str = "gte"


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Registered as ``lt``; operator SQL comes from ``connection.operators``."""

    lookup_name: str = "lt"


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Registered as ``lte``; operator SQL comes from ``connection.operators``."""

    lookup_name: str = "lte"
425
426
class IntegerFieldOverflow:
    """
    Short-circuit an integer comparison whose rhs lies outside the range the
    backend reports for the lhs field's internal type.

    ``underflow_exception`` is raised for a value below the minimum, and
    ``overflow_exception`` for one above the maximum. Subclasses swap in
    FullResultSet where every row would match.
    """

    underflow_exception: type[Exception] = EmptyResultSet
    overflow_exception: type[Exception] = EmptyResultSet
    lhs: Any
    rhs: Any

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        value = self.rhs
        if not isinstance(value, int):
            return super().process_rhs(compiler, connection)  # type: ignore[misc]
        internal_type = self.lhs.output_field.get_internal_type()  # type: ignore[attr-defined]
        lower, upper = connection.ops.integer_field_range(internal_type)
        if lower is not None and value < lower:
            raise self.underflow_exception
        if upper is not None and value > upper:
            raise self.overflow_exception
        return super().process_rhs(compiler, connection)  # type: ignore[misc]
447
448
class IntegerFieldFloatRounding:
    """
    Let float query values work against an IntegerField by rounding them up.

    For an integer column, ``x >= f`` equals ``x >= ceil(f)`` and ``x < f``
    equals ``x < ceil(f)``. Without this, the decimal portion of the float
    would simply be discarded.
    """

    rhs: Any

    def get_prep_lookup(self) -> Any:
        value = self.rhs
        if isinstance(value, float):
            self.rhs = math.ceil(value)
        return super().get_prep_lookup()  # type: ignore[misc]
461
462
@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` for integer fields; an out-of-range int rhs raises EmptyResultSet."""
466
467
@IntegerField.register_lookup
class IntegerGreaterThan(IntegerFieldOverflow, GreaterThan):
    """
    ``gt`` lookup for integer fields.

    An int rhs below the field's range raises FullResultSet. An int rhs
    above the range raises EmptyResultSet.
    """

    underflow_exception = FullResultSet

    def get_prep_lookup(self) -> Any:
        # For an integer column, ``x > f`` is equivalent to ``x > floor(f)``.
        # Round explicitly: otherwise the float is left to the field's
        # get_prep_value(), which presumably truncates toward zero via int()
        # (NOTE(review): confirm). That is wrong for negative values:
        # ``x > -1.5`` would become ``x > -1`` and drop ``x == -1``.
        # Rounding first also lets the overflow check see the real bound.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()  # type: ignore[misc]
471
472
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """
    ``gte`` for integer fields.

    A float rhs is rounded up. An int rhs below the field's range raises
    FullResultSet.
    """

    underflow_exception = FullResultSet
478
479
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """
    ``lt`` for integer fields.

    A float rhs is rounded up. An int rhs above the field's range raises
    FullResultSet.
    """

    overflow_exception = FullResultSet
483
484
@IntegerField.register_lookup
class IntegerLessThanOrEqual(IntegerFieldOverflow, LessThanOrEqual):
    """
    ``lte`` lookup for integer fields.

    An int rhs above the field's range raises FullResultSet. An int rhs
    below the range raises EmptyResultSet.
    """

    overflow_exception = FullResultSet

    def get_prep_lookup(self) -> Any:
        # For an integer column, ``x <= f`` is equivalent to ``x <= floor(f)``.
        # Round explicitly: otherwise the float is left to the field's
        # get_prep_value(), which presumably truncates toward zero via int()
        # (NOTE(review): confirm). That is wrong for negative values:
        # ``x <= -1.5`` would become ``x <= -1`` and wrongly include ``-1``.
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()  # type: ignore[misc]
488
489
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    Membership lookup, registered as ``in``.

    The rhs is either an iterable of values/expressions or a Query. For a
    Query, ordering is cleared and, unless it selects explicit fields, it is
    made to select ``id``.
    """

    lookup_name: str = "in"

    def get_prep_lookup(self) -> Any:
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["id"])
        return super().get_prep_lookup()  # type: ignore[misc]

    def _direct_rhs_values(self) -> OrderedSet | list[Any]:
        """
        Return the direct-value rhs with None removed.

        When every item is hashable, duplicates are also collapsed with order
        preserved. Otherwise only None is filtered out.
        """
        # Remove None from the list as NULL is never equal to anything.
        try:
            values = OrderedSet(self.rhs)
            values.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            values = [r for r in self.rhs if r is not None]
        return values

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[str, tuple[Any, ...]]:
        """
        Return ``("(<sql>, <sql>, ...)", params)`` for a direct rhs.

        Raises EmptyResultSet when no non-None values remain.
        """
        if self.rhs_is_direct_value():  # type: ignore[attr-defined]
            rhs = self._direct_rhs_values()
            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)  # type: ignore[attr-defined]
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)  # type: ignore[misc]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        return f"IN {rhs}"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()  # type: ignore[attr-defined]
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)  # type: ignore[misc]

    def split_parameter_list_as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size()`` entries per IN list.

        This is a special case for databases which limit the number of
        elements which can appear in an 'IN' clause.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        # Filter values exactly as process_rhs() does so both code paths
        # drop None and duplicates and treat an all-None list as empty.
        rhs_values = self._direct_rhs_values()
        if not rhs_values:
            raise EmptyResultSet
        rhs, rhs_params = self.batch_process_rhs(compiler, connection, rhs_values)  # type: ignore[attr-defined]
        in_clause_elements = ["("]
        params = []
        for offset in range(0, len(rhs_params), max_in_list_size):  # type: ignore[arg-type]
            if offset > 0:
                in_clause_elements.append(" OR ")
            in_clause_elements.append(f"{lhs} IN (")
            params.extend(lhs_params)
            sqls = rhs[offset : offset + max_in_list_size]  # type: ignore[operator]
            sqls_params = rhs_params[offset : offset + max_in_list_size]  # type: ignore[index,operator]
            param_group = ", ".join(sqls)
            in_clause_elements.append(param_group)
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
563
564
class PatternLookup(BuiltinLookup):
    """
    Base class for LIKE-style lookups.

    For a direct Python value, the wildcards are added to the parameter
    itself through ``param_pattern``. For an SQL rhs, or one transformed
    in SQL, they are added in SQL through the backend's ``pattern_ops``
    template.
    """

    # %-format template applied to the escaped value ("%%" yields a literal "%").
    param_pattern: str = "%%%s%%"
    prepare_rhs: bool = False
    bilateral_transforms: list[Any]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # A Python value already carries its wildcards (see process_rhs), so
        # the plain operator suffices. An rhs that is an SQL reference (for
        # example another column) or is transformed in SQL needs the
        # wildcard added in SQL, producing something like:
        #     col LIKE othercol || '%%'
        if not (hasattr(self.rhs, "as_sql") or self.bilateral_transforms):
            return super().get_rhs_op(connection, rhs)
        template = connection.pattern_ops[self.lookup_name].format(  # type: ignore[index]
            connection.pattern_esc  # type: ignore[attr-defined]
        )
        return template.format(rhs)

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        wrap_in_python = (
            self.rhs_is_direct_value() and params and not self.bilateral_transforms  # type: ignore[attr-defined]
        )
        if wrap_in_python:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return sql, params
597
598
@Field.register_lookup
class Contains(PatternLookup):
    """Registered as ``contains``; a direct value is wrapped as ``%value%``."""

    lookup_name: str = "contains"


@Field.register_lookup
class IContains(Contains):
    """Registered as ``icontains``; same ``%value%`` wrapping as Contains."""

    lookup_name: str = "icontains"


@Field.register_lookup
class StartsWith(PatternLookup):
    """Registered as ``startswith``; a direct value is wrapped as ``value%``."""

    lookup_name: str = "startswith"
    param_pattern: str = "%s%%"


@Field.register_lookup
class IStartsWith(StartsWith):
    """Registered as ``istartswith``; same ``value%`` wrapping as StartsWith."""

    lookup_name: str = "istartswith"


@Field.register_lookup
class EndsWith(PatternLookup):
    """Registered as ``endswith``; a direct value is wrapped as ``%value``."""

    lookup_name: str = "endswith"
    param_pattern: str = "%%%s"


@Field.register_lookup
class IEndsWith(EndsWith):
    """Registered as ``iendswith``; same ``%value`` wrapping as EndsWith."""

    lookup_name: str = "iendswith"
629
630
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """Registered as ``range``; renders ``BETWEEN <first> AND <second>``."""

    lookup_name: str = "range"

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # For a direct rhs this is the sequence of SQL fragments produced by
        # batch_process_rhs(), not a single string.
        low, high = rhs[0], rhs[1]
        return f"BETWEEN {low} AND {high}"
637
638
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``field__isnull=True/False``, rendered as ``IS NULL`` / ``IS NOT NULL``."""

    lookup_name: str = "isnull"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        lhs_sql, params = self.process_lhs(compiler, connection)
        null_test = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{lhs_sql} {null_test}", params
656
657
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    Lookup registered as ``regex``.

    Uses ``connection.operators`` when it has an entry for this lookup name.
    Otherwise it falls back to the backend's regex_lookup() template.
    """

    lookup_name: str = "regex"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if self.lookup_name not in connection.operators:  # type: ignore[operator]
            lhs_sql, lhs_params = self.process_lhs(compiler, connection)
            rhs_sql, rhs_params = self.process_rhs(compiler, connection)
            template = connection.ops.regex_lookup(self.lookup_name)
            return template % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return super().as_sql(compiler, connection)  # type: ignore[misc]
673
674
@Field.register_lookup
class IRegex(Regex):
    """Registered as ``iregex``; same SQL generation as Regex under its own lookup name."""

    lookup_name: str = "iregex"
678
679
class YearLookup(Lookup):
    """
    Base class for comparisons against an extracted year.

    ``self.lhs`` is the extract transform and ``self.lhs.lhs`` is the
    originating column. For a direct rhs, the comparison is rewritten against
    the column using the year's start/end bounds, so that indexes can be
    used.
    """

    def year_lookup_bounds(
        self, connection: BaseDatabaseWrapper, year: int
    ) -> list[str | Any | None]:
        """Return the backend-computed ``(start, finish)`` bounds for ``year``."""
        from plain.models.functions import ExtractIsoYear

        use_iso_year = isinstance(self.lhs, ExtractIsoYear)
        source_field = self.lhs.lhs.output_field  # type: ignore[attr-defined]
        if isinstance(source_field, DateTimeField):
            compute_bounds = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            compute_bounds = connection.ops.year_lookup_bounds_for_date_field
        return compute_bounds(year, iso_year=use_iso_year)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        # Skip the extract and compile the originating column (self.lhs.lhs)
        # directly, so an index on it can be used.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)  # type: ignore[attr-defined]
        rhs_sql, _ = self.process_rhs(compiler, connection)
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return f"{lhs_sql} {rhs_sql}", params

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Return the comparison SQL placed after the column for a direct rhs."""
        template = connection.operators[self.lookup_name]  # type: ignore[index]
        return template % rhs

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, ...]:
        """Return the bound value(s) used as params for a direct rhs."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
723
724
class YearExact(YearLookup, Exact):
    """Year equality: the column must lie between the year's start and finish bounds."""

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # The compiled rhs is not used; both bounds are supplied as params.
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, Any]:
        return (start, finish)


class YearGt(YearLookup, GreaterThan):
    """Year ``gt``: the column is compared with the year's finish bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    """Year ``gte``: the column is compared with the year's start bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLt(YearLookup, LessThan):
    """Year ``lt``: the column is compared with the year's start bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    """Year ``lte``: the column is compared with the year's finish bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)
751
752
class UUIDTextMixin:
    """
    Remove hyphens from the rhs when filtering a UUIDField on a backend that
    has no native UUID column type.
    """

    rhs: Any

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if connection.features.has_native_uuid_field:
            sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
            return sql, params

        from plain.models.functions import Replace

        # self.rhs itself is replaced, not just a local copy: later steps such
        # as PatternLookup.get_rhs_op() inspect self.rhs and must see an
        # expression.
        source = Value(self.rhs) if self.rhs_is_direct_value() else self.rhs  # type: ignore[attr-defined]
        self.rhs = Replace(source, Value("-"), Value(""), output_field=CharField())
        sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        return sql, params
774
775
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` for UUIDField; hyphens are stripped on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` for UUIDField; hyphens are stripped on non-native-UUID backends."""