1from __future__ import annotations
2
3import itertools
4import math
5from abc import ABC, abstractmethod
6from collections.abc import Sequence
7from functools import cached_property
8from typing import TYPE_CHECKING, Any
9
10from plain.models.exceptions import EmptyResultSet, FullResultSet
11from plain.models.expressions import Expression, Func, ResolvableExpression, Value
12from plain.models.fields import (
13 BooleanField,
14 CharField,
15 DateTimeField,
16 Field,
17 IntegerField,
18 UUIDField,
19)
20from plain.models.query_utils import RegisterLookupMixin
21from plain.utils.datastructures import OrderedSet
22from plain.utils.hashable import make_hashable
23
24if TYPE_CHECKING:
25 from plain.models.backends.base.base import BaseDatabaseWrapper
26 from plain.models.sql.compiler import SQLCompiler
27
28
class Lookup(Expression):
    """
    Base class for a boolean comparison between ``lhs`` and ``rhs``.

    The constructor prepares both sides (``get_prep_lookup()`` for the rhs,
    ``get_prep_lhs()`` for the lhs) and records any bilateral transforms
    found on the lhs. Concrete subclasses set ``lookup_name`` and provide the
    SQL generation (for example via BuiltinLookup.as_sql()).
    """

    # Registered name of the lookup; subclasses use it to index backend
    # operator tables (e.g. BuiltinLookup.get_rhs_op()).
    lookup_name: str | None = None
    # When False, get_prep_lookup() returns the rhs untouched instead of
    # passing it through the lhs output field's get_prep_value().
    prepare_rhs: bool = True
    # Not referenced in this module; presumably consulted by query-building
    # code elsewhere to decide whether ``None`` is an acceptable rhs — verify.
    can_use_none_as_rhs: bool = False
    lhs: Any
    rhs: Any

    def __init__(self, lhs: Any, rhs: Any):
        """
        Store and prepare both sides of the lookup.

        Raises:
            NotImplementedError: if the lhs carries bilateral transforms and
                the rhs is a nested ``Query``.
        """
        self.lhs, self.rhs = lhs, rhs
        # get_prep_lookup() reads self.lhs.output_field, so here it sees the
        # lhs as passed in (before get_prep_lhs() wraps non-expressions).
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from plain.models.sql.query import Query  # avoid circular import

            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value: Any) -> Any:
        """Wrap ``value`` in each bilateral transform class, in order."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[list[str], list[Any]]:
        """
        Compile an iterable rhs into one SQL fragment per item.

        ``rhs`` defaults to ``self.rhs`` (also when ``None`` is passed). With
        bilateral transforms, each item is wrapped in Value(), transformed,
        resolved and compiled. Otherwise there is one ``%s`` fragment per
        parameter returned by ``get_db_prep_lookup()``.

        Returns:
            ``(sql_fragments, flat_params)``.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls: list[str] = []
            sqls_params: list[Any] = []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls = ["%s"] * len(params)
            sqls_params = list(params)
        return sqls, sqls_params

    def get_source_expressions(self) -> list[Any]:
        """Return ``[lhs]``, plus ``rhs`` when the rhs is an SQL expression."""
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, exprs: Sequence[Any]) -> None:
        """
        Inverse of get_source_expressions(): a single expression replaces
        ``lhs`` only; two expressions replace ``lhs`` and ``rhs``.
        """
        exprs_list = list(exprs)
        if len(exprs_list) == 1:
            self.lhs = exprs_list[0]
        else:
            self.lhs, self.rhs = exprs_list

    def get_prep_lookup(self) -> Any:
        """
        Prepare the rhs for use in the query.

        Expressions, and any rhs when ``prepare_rhs`` is False, are returned
        unchanged. Otherwise the lhs output field's ``get_prep_value()`` is
        applied when it exists. When the lhs has no (truthy) ``output_field``,
        a direct-value rhs is wrapped in Value().
        """
        if not self.prepare_rhs or isinstance(self.rhs, ResolvableExpression):
            return self.rhs
        if output_field := getattr(self.lhs, "output_field", None):
            if get_prep_value := getattr(output_field, "get_prep_value", None):
                return get_prep_value(self.rhs)
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self) -> Any:
        """Wrap a non-expression lhs in Value() so it can be compiled."""
        if isinstance(self.lhs, ResolvableExpression):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Default database preparation: one ``%s`` placeholder, with ``value``
        passed through unchanged as its only parameter. Subclasses such as
        FieldGetDbPrepValueMixin override this.
        """
        return ("%s", [value])

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        """
        Compile ``lhs`` (default: ``self.lhs``) to ``(sql, params)``.

        Resolvable expressions are resolved against ``compiler.query`` first,
        and a nested Lookup is parenthesized.
        """
        # NOTE(review): truthiness test, so any falsy ``lhs`` argument falls
        # back to self.lhs.
        lhs = lhs or self.lhs
        if isinstance(lhs, ResolvableExpression):
            lhs = lhs.resolve_expression(compiler.query)
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, list(params)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """
        Compile the rhs to ``(sql, params)``.

        A direct value is first turned into an expression when bilateral
        transforms apply. Anything with ``as_sql`` is then compiled and
        parenthesized (unless it already starts with "("). Otherwise
        ``get_db_prep_lookup()`` supplies the placeholder and params.
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = f"({sql})"
            return sql, list(params)
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self) -> bool:
        """True when the rhs is a plain value rather than an object with as_sql()."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self) -> list[Any]:
        """Collect GROUP BY columns from every source expression."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    @cached_property
    def output_field(self) -> BooleanField:
        """Lookups always evaluate to a boolean."""
        return BooleanField()

    @property
    def identity(self) -> tuple[type[Lookup], Any, Any]:
        """Tuple used for equality and hashing: ``(class, lhs, rhs)``."""
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self) -> int:
        # make_hashable() is needed because the rhs may be unhashable (e.g.
        # the list built by FieldGetDbPrepValueIterableMixin).
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Lookup:
        """
        Return a copy with ``lhs`` (and an expression ``rhs``) resolved
        against ``query``. A non-expression rhs is carried over unchanged.
        """
        c = self.copy()
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if isinstance(self.rhs, ResolvableExpression):
            c.rhs = self.rhs.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: Sequence[Any]
    ) -> tuple[str, Sequence[Any]]:
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
203
204
class Transform(RegisterLookupMixin, Func):
    """
    Single-argument function (``arity = 1``) usable as a lookup transform.

    RegisterLookupMixin() comes first in the bases so that get_lookup() and
    get_transform() consult this class before falling back to output_field.
    When ``bilateral`` is True the transform class is reported by
    get_bilateral_transforms(), so lookups also apply it to their rhs.
    """

    bilateral: bool = False
    arity: int = 1

    @property
    def lhs(self) -> Any:
        """The first source expression, i.e. the transformed argument."""
        expressions = self.get_source_expressions()
        return expressions[0]

    def get_bilateral_transforms(self) -> list[type[Transform]]:
        """
        Return the bilateral transform classes of the whole chain, innermost
        first, ending with this class when it is bilateral itself.
        """
        inner = self.lhs
        transforms = (
            inner.get_bilateral_transforms()
            if hasattr(inner, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            transforms.append(self.__class__)
        return transforms
226
227
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>``, where the operator template comes
    from ``connection.operators[lookup_name]`` and the lhs is wrapped in the
    backend's field and lookup casts.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        """Compile the lhs, then apply the backend's field cast and lookup cast."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        compiled_sql, compiled_params = super().process_lhs(compiler, connection, lhs)
        field = self.lhs.output_field
        internal_type = field.get_internal_type()
        column_type = field.db_type(connection=connection)
        ops = connection.ops
        casted = ops.field_cast_sql(column_type, internal_type) % compiled_sql
        casted = ops.lookup_cast(self.lookup_name, internal_type) % casted
        return casted, list(compiled_params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Join the compiled lhs with the operator applied to the compiled rhs."""
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        operator_sql = self.get_rhs_op(connection, rhs_sql)
        return f"{lhs_sql} {operator_sql}", [*lhs_params, *rhs_params]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """Interpolate the rhs SQL into this lookup's backend operator template."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        operator_template = connection.operators[self.lookup_name]
        return operator_template % rhs
258
259
class FieldGetDbPrepValueMixin(Lookup):
    """
    Some lookups require Field.get_db_prep_value() to be called on their
    inputs.

    When the lhs output field exposes a ``target_field`` (relational fields),
    that field's ``get_db_prep_value()`` is preferred.
    """

    get_db_prep_lookup_value_is_iterable: bool = False
    lhs: Any
    rhs: Any

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Return a ``%s`` placeholder plus the database-prepared value(s).

        With ``get_db_prep_lookup_value_is_iterable`` set, every item of
        ``value`` is prepared; otherwise ``value`` is prepared as one param.
        """
        output_field = self.lhs.output_field
        target = getattr(output_field, "target_field", None)
        prepare = (
            getattr(target, "get_db_prep_value", None)
            or output_field.get_db_prep_value
        )
        if self.get_db_prep_lookup_value_is_iterable:
            prepared = [prepare(item, connection, prepared=True) for item in value]
        else:
            prepared = [prepare(value, connection, prepared=True)]
        return "%s", prepared
286
287
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Some lookups require Field.get_db_prep_value() to be called on each value
    in an iterable.

    Items of the iterable may also be expressions. Those are compiled in
    place of the ``%s`` placeholder that would otherwise carry them.
    """

    get_db_prep_lookup_value_is_iterable: bool = True
    prepare_rhs: bool

    def get_prep_lookup(self) -> Any:
        """
        Prepare each item of an iterable rhs individually.

        An expression rhs is returned unchanged. Within the iterable,
        expression items are kept as-is. Other items go through the lhs
        output field's ``get_prep_value()`` when ``prepare_rhs`` is set and
        that method exists.
        """
        if isinstance(self.rhs, ResolvableExpression):
            return self.rhs
        prepared_values = []
        for rhs_value in self.rhs:
            if isinstance(rhs_value, ResolvableExpression):
                # An expression will be handled by the database but can coexist
                # alongside real values.
                pass
            elif self.prepare_rhs:
                if output_field := getattr(self.lhs, "output_field", None):
                    if get_prep_value := getattr(output_field, "get_prep_value", None):
                        rhs_value = get_prep_value(rhs_value)
            prepared_values.append(rhs_value)
        return prepared_values

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Batch-compile a direct iterable rhs; defer to Lookup for expressions."""
        if self.rhs_is_direct_value():
            # rhs should be an iterable of values. Use batch_process_rhs()
            # to prepare/transform those values.
            return self.batch_process_rhs(compiler, connection)
        else:
            return super().process_rhs(compiler, connection)

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, [param])``, or the compiled SQL/params of ``param``
        when it is (or resolves to) something with ``as_sql``.
        """
        params: list[Any] = [param]
        if isinstance(param, ResolvableExpression):
            param = param.resolve_expression(compiler.query)
        if hasattr(param, "as_sql"):
            sql, compiled_params = compiler.compile(param)
            params = list(compiled_params)
        return sql, params

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[list[str], list[Any]]:
        """
        Compile an iterable rhs to ``(sql_fragments, flat_params)``.

        An empty iterable yields ``([], [])``, as in the base implementation.
        """
        sqls, params = super().batch_process_rhs(compiler, connection, rhs)
        if self.bilateral_transforms:
            # The base class already compiled every item through the
            # transforms. Its flat params list is not 1:1 with the fragments
            # when a transform emits zero or several params per item, so
            # pairing them below would drop or misplace parameters.
            return sqls, params
        # Without transforms there is exactly one "%s" per param. A param may
        # itself be an expression that compiles to its own sql/params, so
        # pair each fragment with its param and let it be replaced.
        resolved_sqls: list[str] = []
        param_groups: list[list[Any]] = []
        for fragment, param in zip(sqls, params):
            fragment, fragment_params = self.resolve_expression_parameter(
                compiler, connection, fragment, param
            )
            resolved_sqls.append(fragment)
            param_groups.append(fragment_params)
        return resolved_sqls, list(itertools.chain.from_iterable(param_groups))
354
355
class PostgresOperatorLookup(Lookup):
    """Lookup defined by operators on PostgreSQL."""

    # Infix SQL operator placed between the compiled lhs and rhs.
    postgres_operator: str | None = None

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Render ``<lhs> <postgres_operator> <rhs>`` with a tuple of params."""
        left_sql, left_params = self.process_lhs(compiler, connection)
        right_sql, right_params = self.process_rhs(compiler, connection)
        sql = f"{left_sql} {self.postgres_operator} {right_sql}"
        return sql, (*left_params, *right_params)
368
369
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Equality lookup (``exact``)."""

    lookup_name: str = "exact"

    def get_prep_lookup(self) -> Any:
        """
        Validate a subquery rhs before normal preparation.

        A Query rhs must be limited to one row. When it selects no explicit
        fields, it is narrowed to the ``id`` column.

        Raises:
            ValueError: if a Query rhs is not limited to one result.
        """
        from plain.models.sql.query import Query  # avoid circular import

        subquery = self.rhs
        if isinstance(subquery, Query):
            if not subquery.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not subquery.has_select_fields:
                subquery.clear_select_clause()
                subquery.add_fields(["id"])
        return super().get_prep_lookup()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render ``<lhs>`` / ``NOT <lhs>`` instead of ``<lhs> = <bool>`` when the
        lhs is a conditional expression the backend accepts in WHERE.
        """
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)
        condition_sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            return condition_sql, params
        return f"NOT {condition_sql}", params
406
407
@Field.register_lookup
class IExact(BuiltinLookup):
    """Case-insensitive equality lookup (``iexact``)."""

    lookup_name: str = "iexact"
    prepare_rhs: bool = False

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Let the backend adjust the first param for an iexact comparison."""
        rhs, params = super().process_rhs(compiler, connection)
        if isinstance(rhs, str) and params:
            params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs, params
423
424
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Strictly-greater-than lookup (``gt``)."""

    lookup_name: str = "gt"
428
429
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than-or-equal lookup (``gte``)."""

    lookup_name: str = "gte"
433
434
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Strictly-less-than lookup (``lt``)."""

    lookup_name: str = "lt"
438
439
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than-or-equal lookup (``lte``)."""

    lookup_name: str = "lte"
443
444
class IntegerFieldOverflow:
    """
    Mixin that short-circuits integer lookups whose int rhs lies outside the
    backend's range for the lhs field's internal type.

    ``underflow_exception`` is raised below the minimum and
    ``overflow_exception`` above the maximum. Subclasses choose
    EmptyResultSet or FullResultSet per comparison.
    """

    underflow_exception: type[Exception] = EmptyResultSet
    overflow_exception: type[Exception] = EmptyResultSet
    lhs: Any
    rhs: Any

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        value = self.rhs
        if isinstance(value, int):
            internal_type = self.lhs.output_field.get_internal_type()
            lowest, highest = connection.ops.integer_field_range(internal_type)
            # A None bound means the backend reports no limit on that side.
            if lowest is not None and value < lowest:
                raise self.underflow_exception
            if highest is not None and value > highest:
                raise self.overflow_exception
        return super().process_rhs(compiler, connection)  # type: ignore[misc]
465
466
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField. Without this, the
    decimal portion of the float would always be discarded.

    A float rhs is rounded up, which keeps ``gte`` and ``lt`` exact for
    integers: ``x >= 1.5`` iff ``x >= 2`` and ``x < 1.5`` iff ``x < 2``.
    """

    rhs: Any

    def get_prep_lookup(self) -> Any:
        value = self.rhs
        if isinstance(value, float):
            self.rhs = math.ceil(value)
        return super().get_prep_lookup()  # type: ignore[misc]
479
480
@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` for integer fields; an out-of-range rhs matches nothing."""
484
485
@IntegerField.register_lookup
class IntegerGreaterThan(IntegerFieldOverflow, GreaterThan):
    """
    ``gt`` for integer fields.

    An rhs below the backend minimum raises FullResultSet (no in-range value
    can fail the comparison). An rhs above the maximum keeps the inherited
    EmptyResultSet.
    """

    underflow_exception = FullResultSet
489
490
@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """
    ``gte`` for integer fields, rounding a float rhs up.

    An rhs below the backend minimum raises FullResultSet; one above the
    maximum keeps the inherited EmptyResultSet.
    """

    underflow_exception = FullResultSet
496
497
@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """
    ``lt`` for integer fields, rounding a float rhs up.

    An rhs above the backend maximum raises FullResultSet; one below the
    minimum keeps the inherited EmptyResultSet.
    """

    overflow_exception = FullResultSet
501
502
@IntegerField.register_lookup
class IntegerLessThanOrEqual(IntegerFieldOverflow, LessThanOrEqual):
    """
    ``lte`` for integer fields.

    An rhs above the backend maximum raises FullResultSet; one below the
    minimum keeps the inherited EmptyResultSet.
    """

    overflow_exception = FullResultSet
506
507
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``lhs IN (...)`` against an iterable of values or a subquery."""

    lookup_name: str = "in"

    def get_prep_lookup(self) -> Any:
        """
        Prepare the rhs. A Query rhs has its ordering cleared and, when it
        selects no explicit fields, is narrowed to the ``id`` column.
        """
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["id"])
        return super().get_prep_lookup()

    def _direct_rhs_values(self) -> list[Any]:
        """
        Return the direct-value rhs without ``None`` (and without duplicates
        when every item is hashable), preserving order.

        NULL is never equal to anything, and keeping it would make
        ``NOT (col IN (..., NULL))`` evaluate to NULL for every row.

        Raises:
            EmptyResultSet: if no values remain.
        """
        try:
            values = OrderedSet(self.rhs)
            values.discard(None)
        except TypeError:  # Unhashable items in self.rhs
            values = [value for value in self.rhs if value is not None]
        if not values:
            raise EmptyResultSet
        return list(values)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Render a direct-value rhs as a parenthesized placeholder list."""
        if self.rhs_is_direct_value():
            values = self._direct_rhs_values()
            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, values)
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """``rhs`` is already parenthesized (placeholder list or subquery)."""
        return f"IN {rhs}"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Split into OR-ed IN clauses when the backend caps the list size."""
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)

    def split_parameter_list_as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size`` values per IN clause.

        Raises:
            EmptyResultSet: if the rhs holds only ``None`` values.
        """
        # This is a special case for databases which limit the number of
        # elements which can appear in an 'IN' clause.
        max_in_list_size = connection.ops.max_in_list_size()
        assert max_in_list_size is not None
        lhs, lhs_params = self.process_lhs(compiler, connection)
        # Same None/duplicate handling as process_rhs(), so large lists behave
        # like small ones.
        values = self._direct_rhs_values()
        in_clause_elements = ["("]
        params: list[Any] = []
        for offset in range(0, len(values), max_in_list_size):
            if offset > 0:
                in_clause_elements.append(" OR ")
            chunk = values[offset : offset + max_in_list_size]
            # Compile each chunk on its own so SQL fragments and params stay
            # aligned even when an item compiles to zero or several params.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, chunk)
            in_clause_elements.append(f"{lhs} IN (")
            params.extend(lhs_params)
            in_clause_elements.append(", ".join(sqls))
            in_clause_elements.append(")")
            params.extend(sqls_params)
        in_clause_elements.append(")")
        return "".join(in_clause_elements), params
582
583
class PatternLookup(BuiltinLookup):
    """
    Base for LIKE-style lookups.

    For a plain Python rhs the wildcard pattern (``param_pattern``) is
    applied to the escaped parameter value. For an SQL rhs (or when
    bilateral transforms apply), the backend's ``pattern_ops`` template adds
    the wildcards in SQL instead.
    """

    param_pattern: str = "%%%s%%"
    prepare_rhs: bool = False
    bilateral_transforms: list[Any]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        # For startswith the goal is e.g. ``col LIKE %s`` with the param
        # ``'thevalue%'``. A Python value gets its wildcard in Python (see
        # process_rhs()). A reference to another column or a transformed
        # value needs the wildcard in SQL, e.g. ``col LIKE othercol || '%%'``,
        # which is what the backend's pattern_ops template provides.
        rhs_is_sql = hasattr(self.rhs, "as_sql") or self.bilateral_transforms
        if not rhs_is_sql:
            return super().get_rhs_op(connection, rhs)
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        template = connection.pattern_ops[self.lookup_name]
        pattern = template.format(connection.pattern_esc)
        return pattern.format(rhs)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Escape a direct-value param for LIKE and wrap it in param_pattern."""
        rhs, params = super().process_rhs(compiler, connection)
        wants_python_pattern = (
            isinstance(rhs, str)
            and self.rhs_is_direct_value()
            and params
            and not self.bilateral_transforms
        )
        if wants_python_pattern:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return rhs, params
622
623
@Field.register_lookup
class Contains(PatternLookup):
    """Substring match; a direct value becomes the parameter ``%<value>%``."""

    lookup_name: str = "contains"
627
628
@Field.register_lookup
class IContains(Contains):
    """
    Case-insensitive variant of Contains, selected through the
    ``icontains`` entries of the backend operator tables.
    """

    lookup_name: str = "icontains"
632
633
@Field.register_lookup
class StartsWith(PatternLookup):
    """Prefix match; a direct value becomes the parameter ``<value>%``."""

    lookup_name: str = "startswith"
    param_pattern: str = "%s%%"
638
639
@Field.register_lookup
class IStartsWith(StartsWith):
    """
    Case-insensitive variant of StartsWith, selected through the
    ``istartswith`` entries of the backend operator tables.
    """

    lookup_name: str = "istartswith"
643
644
@Field.register_lookup
class EndsWith(PatternLookup):
    """Suffix match; a direct value becomes the parameter ``%<value>``."""

    lookup_name: str = "endswith"
    param_pattern: str = "%%%s"
649
650
@Field.register_lookup
class IEndsWith(EndsWith):
    """
    Case-insensitive variant of EndsWith, selected through the
    ``iendswith`` entries of the backend operator tables.
    """

    lookup_name: str = "iendswith"
654
655
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``lhs BETWEEN <low> AND <high>`` for a two-item iterable rhs."""

    lookup_name: str = "range"

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str | list[str]) -> str:
        """
        Build the BETWEEN clause from the per-item SQL fragments.

        Raises:
            ValueError: if the rhs does not hold exactly two values. Otherwise
                extra values would leave unused params (a database error)
                and a single value would fail with an opaque IndexError.
        """
        # A direct-value rhs is compiled to one SQL fragment per item.
        assert isinstance(rhs, list), f"Range lookup expects list, got {type(rhs)}"
        if len(rhs) != 2:
            raise ValueError(
                "The range lookup requires exactly two values (a lower and an "
                f"upper bound), got {len(rhs)}."
            )
        lower, upper = rhs
        return f"BETWEEN {lower} AND {upper}"
664
665
@Field.register_lookup
class IsNull(BuiltinLookup):
    """``IS NULL`` / ``IS NOT NULL`` test driven by a boolean rhs."""

    lookup_name: str = "isnull"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Raises:
            ValueError: if the rhs is not a bool.
        """
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        column_sql, params = self.process_lhs(compiler, connection)
        predicate = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{column_sql} {predicate}", params
683
684
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    Regular-expression match. Uses the backend operator table when it has
    an entry for this lookup, else ``connection.ops.regex_lookup()``.
    """

    lookup_name: str = "regex"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if self.lookup_name not in connection.operators:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            rhs, rhs_params = self.process_rhs(compiler, connection)
            template = connection.ops.regex_lookup(self.lookup_name)
            return template % (lhs, rhs), lhs_params + rhs_params
        return super().as_sql(compiler, connection)
700
701
@Field.register_lookup
class IRegex(Regex):
    """Case-insensitive variant of Regex (``iregex``)."""

    lookup_name: str = "iregex"
705
706
class YearLookup(Lookup, ABC):
    """
    Base for comparisons against an extracted year.

    ``self.lhs`` is expected to be the year-extract expression and
    ``self.lhs.lhs`` the underlying date/datetime column. For a direct-value
    rhs, the extract is skipped and the column is compared with the year's
    bounds, so indexes on the column remain usable.
    """

    def year_lookup_bounds(
        self, connection: BaseDatabaseWrapper, year: int
    ) -> list[str | Any | None]:
        """Return the backend's bounds for ``year`` (ISO year if applicable)."""
        from plain.models.functions import ExtractIsoYear

        use_iso_year = isinstance(self.lhs, ExtractIsoYear)
        ops = connection.ops
        if isinstance(self.lhs.lhs.output_field, DateTimeField):
            compute_bounds = ops.year_lookup_bounds_for_datetime_field
        else:
            compute_bounds = ops.year_lookup_bounds_for_date_field
        return compute_bounds(year, iso_year=use_iso_year)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, Sequence[Any]]:
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)
        # Compare the originating column (self.lhs.lhs) with the bounds
        # instead of extracting the year from every row.
        column_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)
        # The rhs params are discarded: the bound params replace them.
        operand_sql, _ = self.process_rhs(compiler, connection)
        assert isinstance(operand_sql, str), f"Expected str, got {type(operand_sql)}"
        operand_sql = self.get_direct_rhs_sql(connection, operand_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return f"{column_sql} {operand_sql}", params

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Interpolate ``rhs`` into this lookup's backend operator template."""
        assert self.lookup_name is not None, (
            "lookup_name must be set on Lookup subclass"
        )
        template = connection.operators[self.lookup_name]
        return template % rhs

    @abstractmethod
    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, ...]:
        """Select which of the year's bounds become query params."""
753
754
class YearExact(YearLookup, Exact):
    """Exact year: the column must fall between the year's start and finish."""

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # The compiled rhs is not used; both bounds are passed as params.
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, Any]:
        bounds = (start, finish)
        return bounds
761
762
class YearGt(YearLookup, GreaterThan):
    """After the year: compare against the year's finish bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        bounds = (finish,)
        return bounds
766
767
class YearGte(YearLookup, GreaterThanOrEqual):
    """In or after the year: compare against the year's start bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        bounds = (start,)
        return bounds
771
772
class YearLt(YearLookup, LessThan):
    """Before the year: compare against the year's start bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        bounds = (start,)
        return bounds
776
777
class YearLte(YearLookup, LessThanOrEqual):
    """In or before the year: compare against the year's finish bound."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        bounds = (finish,)
        return bounds
781
782
class UUIDTextMixin(Lookup):
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.
    """

    rhs: Any
    # Set once self.rhs has been wrapped in the hyphen-stripping REPLACE().
    _uuid_hyphens_stripped: bool = False

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """
        On backends without a native UUID type, replace the rhs with
        ``REPLACE(rhs, '-', '')`` before compiling it.

        The rewrite is stored on ``self.rhs``, which get_rhs_op() also
        inspects. process_rhs() can run again for the same lookup (e.g. when
        a query is compiled more than once), so the wrap is applied only
        once instead of nesting another REPLACE() on every call.
        """
        if (
            not connection.features.has_native_uuid_field
            and not self._uuid_hyphens_stripped
        ):
            from plain.models.functions import Replace

            if self.rhs_is_direct_value():
                self.rhs = Value(self.rhs)
            self.rhs = Replace(
                self.rhs, Value("-"), Value(""), output_field=CharField()
            )
            self._uuid_hyphens_stripped = True
        return super().process_rhs(compiler, connection)
803
804
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` on a UUIDField, ignoring hyphens on non-native backends."""
808
809
@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` on a UUIDField, ignoring hyphens on non-native backends."""
813
814
@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` on a UUIDField, ignoring hyphens on non-native backends."""
818
819
@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` on a UUIDField, ignoring hyphens on non-native backends."""
823
824
@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` on a UUIDField, ignoring hyphens on non-native backends."""
828
829
@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` on a UUIDField, ignoring hyphens on non-native backends."""
833
834
@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` on a UUIDField, ignoring hyphens on non-native backends."""