1from __future__ import annotations
2
3import itertools
4import math
5from functools import cached_property
6from typing import TYPE_CHECKING, Any
7
8from plain.models.exceptions import EmptyResultSet, FullResultSet
9from plain.models.expressions import Expression, Func, Value
10from plain.models.fields import (
11 BooleanField,
12 CharField,
13 DateTimeField,
14 Field,
15 IntegerField,
16 UUIDField,
17)
18from plain.models.query_utils import RegisterLookupMixin
19from plain.utils.datastructures import OrderedSet
20from plain.utils.hashable import make_hashable
21
22if TYPE_CHECKING:
23 from plain.models.backends.base.base import BaseDatabaseWrapper
24 from plain.models.sql.compiler import SQLCompiler
25
26
class Lookup(Expression):
    """
    Base class for a boolean filter condition comparing ``lhs`` with ``rhs``.

    The rhs is prepared with get_prep_lookup() and the lhs is wrapped with
    get_prep_lhs() when the lookup is constructed. SQL is produced from
    process_lhs() and process_rhs(); subclasses customize value handling by
    overriding hooks such as get_prep_lookup(), get_db_prep_lookup() and
    rhs_is_direct_value(), which the methods below call.
    """

    # Registered name of the lookup (e.g. "exact"); set by subclasses.
    lookup_name: str | None = None
    # When False, get_prep_lookup() returns rhs unchanged.
    prepare_rhs: bool = True
    # NOTE(review): not read in this file; presumably checked by callers that
    # build lookups from a None rhs — confirm.
    can_use_none_as_rhs: bool = False

    def __init__(self, lhs: Any, rhs: Any):
        # Order matters: rhs is prepared while self.lhs is still the raw lhs
        # argument, and only afterwards is lhs wrapped by get_prep_lhs().
        self.lhs, self.rhs = lhs, rhs
        self.rhs = self.get_prep_lookup()
        self.lhs = self.get_prep_lhs()
        # Bilateral transforms found on the lhs chain are also applied to the
        # rhs values in process_rhs() / batch_process_rhs().
        if hasattr(self.lhs, "get_bilateral_transforms"):
            bilateral_transforms = self.lhs.get_bilateral_transforms()  # type: ignore[attr-defined]
        else:
            bilateral_transforms = []
        if bilateral_transforms:
            # Warn the user as soon as possible if they are trying to apply
            # a bilateral transformation on a nested QuerySet: that won't work.
            from plain.models.sql.query import Query  # avoid circular import

            # Checks the original rhs argument, not the prepared self.rhs.
            if isinstance(rhs, Query):
                raise NotImplementedError(
                    "Bilateral transformations on nested querysets are not implemented."
                )
        self.bilateral_transforms = bilateral_transforms

    def apply_bilateral_transforms(self, value: Any) -> Any:
        """Wrap ``value`` in each bilateral transform class, in list order."""
        for transform in self.bilateral_transforms:
            value = transform(value)
        return value

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[list[str], list[Any]]:
        """
        Compile each item of an iterable rhs (``self.rhs`` when ``rhs`` is None).

        Returns a list of SQL fragments and a flat list of params. With
        bilateral transforms, every item is wrapped in a Value using the lhs
        output field, transformed, resolved and compiled. Otherwise
        get_db_prep_lookup() is applied to the whole iterable, and one "%s"
        placeholder is emitted per returned param.
        """
        if rhs is None:
            rhs = self.rhs
        if self.bilateral_transforms:
            sqls, sqls_params = [], []
            for p in rhs:
                value = Value(p, output_field=self.lhs.output_field)  # type: ignore[attr-defined]
                value = self.apply_bilateral_transforms(value)
                value = value.resolve_expression(compiler.query)  # type: ignore[attr-defined]
                sql, sql_params = compiler.compile(value)
                sqls.append(sql)
                sqls_params.extend(sql_params)
        else:
            _, params = self.get_db_prep_lookup(rhs, connection)
            sqls, sqls_params = ["%s"] * len(params), params
        return sqls, sqls_params

    def get_source_expressions(self) -> list[Any]:
        """Return [lhs] for a direct-value rhs, otherwise [lhs, rhs]."""
        if self.rhs_is_direct_value():
            return [self.lhs]
        return [self.lhs, self.rhs]

    def set_source_expressions(self, new_exprs: list[Any]) -> None:
        """Inverse of get_source_expressions(): one expression replaces lhs,
        two replace lhs and rhs."""
        if len(new_exprs) == 1:
            self.lhs = new_exprs[0]
        else:
            self.lhs, self.rhs = new_exprs

    def get_prep_lookup(self) -> Any:
        """
        Return the prepared rhs.

        Expressions (anything with resolve_expression), and any rhs when
        prepare_rhs is False, are returned unchanged. Otherwise the lhs output
        field's get_prep_value() is used when available. If lhs has no
        output_field at all, a direct value is wrapped in Value.
        """
        if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
            return self.rhs
        if hasattr(self.lhs, "output_field"):
            if hasattr(self.lhs.output_field, "get_prep_value"):  # type: ignore[attr-defined]
                return self.lhs.output_field.get_prep_value(self.rhs)  # type: ignore[attr-defined]
        elif self.rhs_is_direct_value():
            return Value(self.rhs)
        return self.rhs

    def get_prep_lhs(self) -> Any:
        """Wrap a plain lhs value in Value; expressions are returned unchanged."""
        if hasattr(self.lhs, "resolve_expression"):
            return self.lhs
        return Value(self.lhs)

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """Default: a single "%s" placeholder with ``value`` as its only param."""
        return ("%s", [value])

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        """Compile ``lhs`` (defaults to self.lhs) into SQL and a list of params."""
        # Any falsy argument (including None) falls back to self.lhs.
        lhs = lhs or self.lhs
        if hasattr(lhs, "resolve_expression"):
            lhs = lhs.resolve_expression(compiler.query)  # type: ignore[attr-defined]
        sql, params = compiler.compile(lhs)
        if isinstance(lhs, Lookup):
            # Wrapped in parentheses to respect operator precedence.
            sql = f"({sql})"
        return sql, list(params)

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Compile the rhs into SQL and params.

        Compilable values (after any bilateral transforms) are compiled and
        parenthesized. Direct values go through get_db_prep_lookup().
        """
        value = self.rhs
        if self.bilateral_transforms:
            if self.rhs_is_direct_value():
                # Do not call get_db_prep_lookup here as the value will be
                # transformed before being used for lookup
                value = Value(value, output_field=self.lhs.output_field)  # type: ignore[attr-defined]
            value = self.apply_bilateral_transforms(value)
            value = value.resolve_expression(compiler.query)  # type: ignore[attr-defined]
        if hasattr(value, "as_sql"):
            sql, params = compiler.compile(value)
            # Ensure expression is wrapped in parentheses to respect operator
            # precedence but avoid double wrapping as it can be misinterpreted
            # on some backends (e.g. subqueries on SQLite).
            if sql and sql[0] != "(":
                sql = f"({sql})"
            return sql, list(params)
        else:
            return self.get_db_prep_lookup(value, connection)

    def rhs_is_direct_value(self) -> bool:
        """True when rhs is a plain value rather than something with as_sql()."""
        return not hasattr(self.rhs, "as_sql")

    def get_group_by_cols(self) -> list[Any]:
        """Collect GROUP BY columns from every source expression."""
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())  # type: ignore[attr-defined]
        return cols

    @cached_property
    def output_field(self) -> BooleanField:
        """Lookups always evaluate to a boolean."""
        return BooleanField()

    @property
    def identity(self) -> tuple[type[Lookup], Any, Any]:
        """Tuple used for equality and hashing: (class, lhs, rhs)."""
        return self.__class__, self.lhs, self.rhs

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Lookup):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self) -> int:
        # rhs can be a list (e.g. the prepared values of an iterable lookup),
        # so the identity goes through make_hashable() before hashing.
        return hash(make_hashable(self.identity))

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Lookup:
        """Return a copy with lhs (and rhs, if resolvable) resolved against query."""
        c = self.copy()  # type: ignore[attr-defined]
        c.is_summary = summarize
        c.lhs = self.lhs.resolve_expression(  # type: ignore[attr-defined]
            query, allow_joins, reuse, summarize, for_save
        )
        if hasattr(self.rhs, "resolve_expression"):
            c.rhs = self.rhs.resolve_expression(  # type: ignore[attr-defined]
                query, allow_joins, reuse, summarize, for_save
            )
        return c

    def select_format(
        self, compiler: SQLCompiler, sql: str, params: list[Any]
    ) -> tuple[str, list[Any]]:
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        if not compiler.connection.features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params
196
197
class Transform(RegisterLookupMixin, Func):
    """
    A single-argument Func that can also act as a step in a lookup path.

    RegisterLookupMixin() is first so that get_lookup() and get_transform()
    first examine self and then check output_field.

    Transforms with ``bilateral = True`` are reported by
    get_bilateral_transforms() so lookups can apply them to rhs values too.
    """

    bilateral: bool = False
    arity: int = 1

    @property
    def lhs(self) -> Any:
        """The transformed expression, i.e. the first source expression."""
        sources = self.get_source_expressions()
        return sources[0]

    def get_bilateral_transforms(self) -> list[type[Transform]]:
        """
        Return bilateral transform classes from the innermost lhs outward,
        ending with this class when it is itself bilateral.
        """
        inner = self.lhs
        collected = (
            inner.get_bilateral_transforms()  # type: ignore[attr-defined]
            if hasattr(inner, "get_bilateral_transforms")
            else []
        )
        if self.bilateral:
            collected.append(self.__class__)
        return collected
219
220
class BuiltinLookup(Lookup):
    """
    Lookup rendered as ``<lhs> <operator>``, where the operator template comes
    from ``connection.operators[lookup_name]``.

    The lhs SQL is wrapped in the backend's field cast and lookup cast.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Any = None
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_lhs(compiler, connection, lhs)  # type: ignore[misc]
        field = self.lhs.output_field  # type: ignore[attr-defined]
        internal_type = field.get_internal_type()
        field_cast = connection.ops.field_cast_sql(
            field.db_type(connection=connection), internal_type
        )
        sql = field_cast % sql
        lookup_cast = connection.ops.lookup_cast(self.lookup_name, internal_type)  # type: ignore[arg-type]
        return lookup_cast % sql, list(params)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        lhs_sql, params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        params.extend(rhs_params)
        return f"{lhs_sql} {self.get_rhs_op(connection, rhs_sql)}", params

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Interpolate the rhs SQL into the backend's operator template."""
        operator_template = connection.operators[self.lookup_name]  # type: ignore[index]
        return operator_template % rhs
245
246
class FieldGetDbPrepValueMixin:
    """
    Mixin for lookups whose rhs must be converted with
    Field.get_db_prep_value(prepared=True).

    With ``get_db_prep_lookup_value_is_iterable`` set, each item of an
    iterable rhs is converted instead of the rhs as a whole.
    """

    get_db_prep_lookup_value_is_iterable: bool = False
    lhs: Any
    rhs: Any

    def get_db_prep_lookup(
        self, value: Any, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        output_field = self.lhs.output_field  # type: ignore[attr-defined]
        # Relational fields expose the field they point at as 'target_field';
        # prefer that field's conversion when it has one.
        target = getattr(output_field, "target_field", None)
        convert = (
            getattr(target, "get_db_prep_value", None)
            or output_field.get_db_prep_value
        )
        if self.get_db_prep_lookup_value_is_iterable:
            converted = [convert(item, connection, prepared=True) for item in value]
        else:
            converted = [convert(value, connection, prepared=True)]
        return "%s", converted
273
274
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
    """
    Mixin for lookups whose rhs is an iterable of values, each of which must
    be converted with Field.get_db_prep_value().

    Items that are expressions are kept as-is and compiled to SQL in
    batch_process_rhs(), so expressions and plain values can be mixed.
    """

    get_db_prep_lookup_value_is_iterable: bool = True
    prepare_rhs: bool

    def get_prep_lookup(self) -> Any:
        """Return the rhs expression unchanged, or a list of prepared items."""
        if hasattr(self.rhs, "resolve_expression"):
            return self.rhs

        def prepare(item: Any) -> Any:
            if hasattr(item, "resolve_expression"):
                # An expression will be handled by the database but can
                # coexist alongside real values.
                return item
            if self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):  # type: ignore[attr-defined]
                return self.lhs.output_field.get_prep_value(item)  # type: ignore[attr-defined]
            return item

        return [prepare(item) for item in self.rhs]

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if not self.rhs_is_direct_value():  # type: ignore[attr-defined]
            return super().process_rhs(compiler, connection)  # type: ignore[misc]
        # A direct rhs is an iterable of values; batch_process_rhs() prepares
        # and transforms each of them.
        return self.batch_process_rhs(compiler, connection)  # type: ignore[attr-defined]

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """
        Return ``(sql, [param])`` for a plain param. If the param resolves to
        something compilable, return its compiled SQL and params instead.
        """
        original = param
        if hasattr(param, "resolve_expression"):
            param = param.resolve_expression(compiler.query)  # type: ignore[attr-defined]
        if hasattr(param, "as_sql"):
            compiled_sql, compiled_params = compiler.compile(param)
            return compiled_sql, list(compiled_params)
        return sql, [original]

    def batch_process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, rhs: Any = None
    ) -> tuple[tuple[str, ...], tuple[Any, ...]]:
        """
        Pair each SQL fragment with its param, compile any param that is an
        expression, and return the fragments plus the flattened params.
        """
        fragments, raw_params = super().batch_process_rhs(compiler, connection, rhs)  # type: ignore[misc]
        resolved = [
            self.resolve_expression_parameter(compiler, connection, fragment, param)
            for fragment, param in zip(fragments, raw_params)
        ]
        new_fragments, param_groups = zip(*resolved)
        return new_fragments, tuple(itertools.chain.from_iterable(param_groups))
339
340
class PostgresOperatorLookup(Lookup):
    """
    Lookup rendered on PostgreSQL as ``<lhs> <postgres_operator> <rhs>``.

    Subclasses set ``postgres_operator`` to the SQL operator to emit.
    """

    postgres_operator: str | None = None

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"{lhs_sql} {self.postgres_operator} {rhs_sql}"
        return sql, (*lhs_params, *rhs_params)
353
354
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
    """
    Equality lookup.

    A Query rhs must be limited to one result. If it selects no fields, it is
    made to select ``id``.
    """

    lookup_name: str = "exact"

    def get_prep_lookup(self) -> Any:
        from plain.models.sql.query import Query  # avoid circular import

        subquery = self.rhs
        if isinstance(subquery, Query):
            if not subquery.has_limit_one():
                raise ValueError(
                    "The QuerySet value for an exact lookup must be limited to "
                    "one result using slicing."
                )
            if not subquery.has_select_fields:
                subquery.clear_select_clause()
                subquery.add_fields(["id"])
        return super().get_prep_lookup()  # type: ignore[misc]

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        # Avoid comparison against direct rhs if lhs is a boolean value. That
        # turns "boolfield__exact=True" into "WHERE boolean_field" instead of
        # "WHERE boolean_field = True" when allowed.
        use_bare_condition = (
            isinstance(self.rhs, bool)
            and getattr(self.lhs, "conditional", False)
            and connection.ops.conditional_expression_supported_in_where_clause(
                self.lhs
            )
        )
        if not use_bare_condition:
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        lhs_sql, params = self.process_lhs(compiler, connection)
        if self.rhs:
            return lhs_sql, params
        return f"NOT {lhs_sql}", params
391
392
@Field.register_lookup
class IExact(BuiltinLookup):
    """
    Case-insensitive equality (operator from ``connection.operators['iexact']``).

    The rhs is not run through the field's get_prep_value(). The first param
    is adapted with ``connection.ops.prep_for_iexact_query()``.
    """

    lookup_name: str = "iexact"
    prepare_rhs: bool = False

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs_sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        if not params:
            return rhs_sql, params
        params[0] = connection.ops.prep_for_iexact_query(params[0])
        return rhs_sql, params
405
406
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Strictly-greater-than comparison (``connection.operators['gt']``)."""

    lookup_name: str = "gt"


@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Greater-than-or-equal comparison (``connection.operators['gte']``)."""

    lookup_name: str = "gte"


@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Strictly-less-than comparison (``connection.operators['lt']``)."""

    lookup_name: str = "lt"


@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
    """Less-than-or-equal comparison (``connection.operators['lte']``)."""

    lookup_name: str = "lte"
425
426
class IntegerFieldOverflow:
    """
    Short-circuit integer comparisons whose rhs lies outside the range the
    backend reports for the column type.

    An int rhs below the minimum raises ``underflow_exception``; one above
    the maximum raises ``overflow_exception``. Both are raised from
    process_rhs() instead of compiling SQL. Both default to EmptyResultSet;
    the integer lookup subclasses switch one of them to FullResultSet.
    """

    underflow_exception: type[Exception] = EmptyResultSet
    overflow_exception: type[Exception] = EmptyResultSet
    lhs: Any
    rhs: Any

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        value = self.rhs
        if not isinstance(value, int):
            return super().process_rhs(compiler, connection)  # type: ignore[misc]
        internal_type = self.lhs.output_field.get_internal_type()  # type: ignore[attr-defined]
        lower, upper = connection.ops.integer_field_range(internal_type)
        if lower is not None and value < lower:
            raise self.underflow_exception
        if upper is not None and value > upper:
            raise self.overflow_exception
        return super().process_rhs(compiler, connection)  # type: ignore[misc]
447
448
class IntegerFieldFloatRounding:
    """
    Allow floats to work as query values for IntegerField.

    Without this mixin, the fractional part of a float would always be
    discarded. A float rhs is instead rounded up with math.ceil() before
    normal preparation. For an integer column x, ``x >= v`` is equivalent to
    ``x >= ceil(v)``, and ``x < v`` to ``x < ceil(v)``.
    """

    rhs: Any

    def get_prep_lookup(self) -> Any:
        value = self.rhs
        if isinstance(value, float):
            self.rhs = math.ceil(value)
        return super().get_prep_lookup()  # type: ignore[misc]
461
462
class IntegerFieldFloatFloorRounding:
    """
    Allow floats to work as query values for IntegerField ``gt``/``lte``.

    For an integer column x and a float v, ``x > v`` is equivalent to
    ``x > floor(v)``, and ``x <= v`` to ``x <= floor(v)``. The value is
    floored explicitly rather than left to an integer conversion such as
    int(). int() truncates toward zero, which gives the wrong bound for
    negative values: int(-3.5) == -3, while floor(-3.5) == -4, so
    ``x > -3`` would wrongly exclude -3.
    """

    rhs: Any

    def get_prep_lookup(self) -> Any:
        if isinstance(self.rhs, float):
            self.rhs = math.floor(self.rhs)
        return super().get_prep_lookup()  # type: ignore[misc]


@IntegerField.register_lookup
class IntegerFieldExact(IntegerFieldOverflow, Exact):
    """``exact`` on IntegerField; an out-of-range int rhs raises EmptyResultSet."""


@IntegerField.register_lookup
class IntegerGreaterThan(
    IntegerFieldOverflow, IntegerFieldFloatFloorRounding, GreaterThan
):
    """
    ``gt`` on IntegerField.

    A float rhs is floored. An int rhs below the column range raises
    FullResultSet; one above it raises EmptyResultSet.
    """

    underflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerGreaterThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatRounding, GreaterThanOrEqual
):
    """
    ``gte`` on IntegerField.

    A float rhs is rounded up. An int rhs below the column range raises
    FullResultSet; one above it raises EmptyResultSet.
    """

    underflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerLessThan(IntegerFieldOverflow, IntegerFieldFloatRounding, LessThan):
    """
    ``lt`` on IntegerField.

    A float rhs is rounded up. An int rhs above the column range raises
    FullResultSet; one below it raises EmptyResultSet.
    """

    overflow_exception = FullResultSet


@IntegerField.register_lookup
class IntegerLessThanOrEqual(
    IntegerFieldOverflow, IntegerFieldFloatFloorRounding, LessThanOrEqual
):
    """
    ``lte`` on IntegerField.

    A float rhs is floored. An int rhs above the column range raises
    FullResultSet; one below it raises EmptyResultSet.
    """

    overflow_exception = FullResultSet
488
489
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """
    ``lhs IN (...)`` lookup.

    The rhs is either an iterable of values/expressions or a Query used as a
    subquery. None (and, when hashable, duplicate) values are dropped from a
    direct rhs; if nothing remains, EmptyResultSet is raised. When the
    backend reports a maximum IN-list size, longer lists are rendered as
    several OR-ed IN clauses.
    """

    lookup_name: str = "in"

    def get_prep_lookup(self) -> Any:
        """
        Prepare the rhs.

        A Query rhs has its ordering cleared and is made to select ``id``
        when it selects no fields explicitly.
        """
        from plain.models.sql.query import Query  # avoid circular import

        if isinstance(self.rhs, Query):
            # Ordering has no effect on IN (...) membership.
            self.rhs.clear_ordering(clear_default=True)
            if not self.rhs.has_select_fields:
                self.rhs.clear_select_clause()
                self.rhs.add_fields(["id"])
        return super().get_prep_lookup()  # type: ignore[misc]

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[str, tuple[Any, ...]]:
        """
        Compile the rhs into ``("(sql, sql, ...)", params)``.

        Raises:
            EmptyResultSet: if a direct rhs has no non-None values.
        """
        if self.rhs_is_direct_value():  # type: ignore[attr-defined]
            # Remove None from the list as NULL is never equal to anything.
            try:
                rhs = OrderedSet(self.rhs)
                rhs.discard(None)
            except TypeError:  # Unhashable items in self.rhs
                rhs = [r for r in self.rhs if r is not None]

            if not rhs:
                raise EmptyResultSet

            # rhs should be an iterable; use batch_process_rhs() to
            # prepare/transform those values.
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)  # type: ignore[attr-defined]
            placeholder = "(" + ", ".join(sqls) + ")"
            return (placeholder, sqls_params)
        return super().process_rhs(compiler, connection)  # type: ignore[misc]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        return f"IN {rhs}"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        max_in_list_size = connection.ops.max_in_list_size()
        if (
            self.rhs_is_direct_value()  # type: ignore[attr-defined]
            and max_in_list_size
            and len(self.rhs) > max_in_list_size
        ):
            return self.split_parameter_list_as_sql(compiler, connection)
        return super().as_sql(compiler, connection)  # type: ignore[misc]

    def split_parameter_list_as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        """
        Render ``(lhs IN (...) OR lhs IN (...) ...)`` with at most
        ``max_in_list_size`` values per IN list.

        This is a special case for databases which limit the number of
        elements which can appear in an 'IN' clause. Each chunk of rhs values
        is compiled on its own, so its SQL fragments and params stay aligned
        even when an item is an expression compiling to zero or several
        params.
        """
        max_in_list_size = connection.ops.max_in_list_size()
        lhs, lhs_params = self.process_lhs(compiler, connection)
        values = list(self.rhs)
        in_clauses: list[str] = []
        params: list[Any] = []
        for offset in range(0, len(values), max_in_list_size):
            chunk = values[offset : offset + max_in_list_size]
            sqls, sqls_params = self.batch_process_rhs(compiler, connection, chunk)  # type: ignore[attr-defined]
            in_clauses.append(f"{lhs} IN ({', '.join(sqls)})")
            # The lhs is repeated in every clause, so are its params.
            params.extend(lhs_params)
            params.extend(sqls_params)
        return "(" + " OR ".join(in_clauses) + ")", params
563
564
class PatternLookup(BuiltinLookup):
    """
    Base class for LIKE-style lookups.

    A direct Python rhs is escaped with ``connection.ops.prep_for_like_query()``
    and wrapped with ``param_pattern`` in Python. An SQL rhs (an expression, or
    any rhs when bilateral transforms apply) gets its wildcards from the
    backend's ``pattern_ops`` template instead.
    """

    param_pattern: str = "%%%s%%"
    prepare_rhs: bool = False
    bilateral_transforms: list[Any]

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # For startswith, a Python value becomes
        #     col LIKE %s, ['thevalue%']
        # with the wildcard added in process_rhs(). A reference to another
        # column (or a transformed value) needs the wildcard added in SQL:
        #     col LIKE othercol || '%%'
        needs_sql_pattern = hasattr(self.rhs, "as_sql") or self.bilateral_transforms
        if not needs_sql_pattern:
            return super().get_rhs_op(connection, rhs)
        template = connection.pattern_ops[self.lookup_name]  # type: ignore[index]
        pattern = template.format(connection.pattern_esc)  # type: ignore[attr-defined]
        return pattern.format(rhs)

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs_sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        wrap_in_python = (
            self.rhs_is_direct_value() and params and not self.bilateral_transforms  # type: ignore[attr-defined]
        )
        if wrap_in_python:
            escaped = connection.ops.prep_for_like_query(params[0])
            params[0] = self.param_pattern % escaped
        return rhs_sql, params
597
598
@Field.register_lookup
class Contains(PatternLookup):
    """Substring match; a direct rhs value becomes the pattern ``%value%``."""

    lookup_name: str = "contains"


@Field.register_lookup
class IContains(Contains):
    """Case-insensitive variant of ``contains`` (the backend's 'icontains' operator)."""

    lookup_name: str = "icontains"


@Field.register_lookup
class StartsWith(PatternLookup):
    """Prefix match; a direct rhs value becomes the pattern ``value%``."""

    lookup_name: str = "startswith"
    param_pattern: str = "%s%%"


@Field.register_lookup
class IStartsWith(StartsWith):
    """Case-insensitive variant of ``startswith`` (the backend's 'istartswith' operator)."""

    lookup_name: str = "istartswith"


@Field.register_lookup
class EndsWith(PatternLookup):
    """Suffix match; a direct rhs value becomes the pattern ``%value``."""

    lookup_name: str = "endswith"
    param_pattern: str = "%%%s"


@Field.register_lookup
class IEndsWith(EndsWith):
    """Case-insensitive variant of ``endswith`` (the backend's 'iendswith' operator)."""

    lookup_name: str = "iendswith"
629
630
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
    """``lhs BETWEEN low AND high``; expects a two-item iterable rhs."""

    lookup_name: str = "range"

    def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        # For a direct rhs, ``rhs`` is the sequence of per-item SQL fragments
        # from batch_process_rhs(); only the first two are rendered.
        low, high = rhs[0], rhs[1]
        return f"BETWEEN {low} AND {high}"
637
638
@Field.register_lookup
class IsNull(BuiltinLookup):
    """
    ``lhs IS NULL`` (rhs True) or ``lhs IS NOT NULL`` (rhs False).

    Raises:
        ValueError: if rhs is not a bool.
    """

    lookup_name: str = "isnull"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if not isinstance(self.rhs, bool):
            raise ValueError(
                "The QuerySet value for an isnull lookup must be True or False."
            )
        lhs_sql, params = self.process_lhs(compiler, connection)
        null_check = "IS NULL" if self.rhs else "IS NOT NULL"
        return f"{lhs_sql} {null_check}", params
656
657
@Field.register_lookup
class Regex(BuiltinLookup):
    """
    Regular-expression match.

    Uses ``connection.operators`` when the backend defines an operator for
    this lookup name. Otherwise it fills the full SQL template returned by
    ``connection.ops.regex_lookup()``.
    """

    lookup_name: str = "regex"
    prepare_rhs: bool = False

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if self.lookup_name in connection.operators:  # type: ignore[operator]
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        template = connection.ops.regex_lookup(self.lookup_name)
        return template % (lhs_sql, rhs_sql), lhs_params + rhs_params


@Field.register_lookup
class IRegex(Regex):
    """Case-insensitive variant of ``regex`` (the backend's 'iregex' operator/template)."""

    lookup_name: str = "iregex"
678
679
class YearLookup(Lookup):
    """
    Base for comparisons on an extracted year, where ``self.lhs`` extracts
    the year from ``self.lhs.lhs``.

    For a direct rhs, the extract is skipped and the originating field is
    compared against year bounds from the backend, so indexes can be used.
    Subclasses choose which bound(s) to use via get_bound_params().
    """

    def year_lookup_bounds(
        self, connection: BaseDatabaseWrapper, year: int
    ) -> list[str | Any | None]:
        """Return the backend's (start, finish) bounds for ``year``."""
        from plain.models.functions import ExtractIsoYear

        iso_year = isinstance(self.lhs, ExtractIsoYear)
        source_field = self.lhs.lhs.output_field  # type: ignore[attr-defined]
        if isinstance(source_field, DateTimeField):
            compute_bounds = connection.ops.year_lookup_bounds_for_datetime_field
        else:
            compute_bounds = connection.ops.year_lookup_bounds_for_date_field
        return compute_bounds(year, iso_year=iso_year)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if not self.rhs_is_direct_value():
            return super().as_sql(compiler, connection)  # type: ignore[misc]
        # Avoid the extract operation: compare the originating field
        # (self.lhs.lhs) directly against the year bounds.
        lhs_sql, params = self.process_lhs(compiler, connection, self.lhs.lhs)  # type: ignore[attr-defined]
        rhs_sql, _ = self.process_rhs(compiler, connection)
        rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
        start, finish = self.year_lookup_bounds(connection, self.rhs)
        params.extend(self.get_bound_params(start, finish))
        return f"{lhs_sql} {rhs_sql}", params

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        """Operator SQL used against the bound params (backend operator by default)."""
        operator_template = connection.operators[self.lookup_name]  # type: ignore[index]
        return operator_template % rhs

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, ...]:
        """Return the bound value(s) to compare against; subclasses implement."""
        raise NotImplementedError(
            "subclasses of YearLookup must provide a get_bound_params() method"
        )
723
724
class YearExact(YearLookup, Exact):
    """Year equality rendered as ``field BETWEEN start AND finish``."""

    def get_direct_rhs_sql(self, connection: BaseDatabaseWrapper, rhs: str) -> str:
        return "BETWEEN %s AND %s"

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any, Any]:
        return start, finish


class YearGt(YearLookup, GreaterThan):
    """``year > N`` compares the field against the upper bound ``finish``."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)


class YearGte(YearLookup, GreaterThanOrEqual):
    """``year >= N`` compares the field against the lower bound ``start``."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLt(YearLookup, LessThan):
    """``year < N`` compares the field against the lower bound ``start``."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (start,)


class YearLte(YearLookup, LessThanOrEqual):
    """``year <= N`` compares the field against the upper bound ``finish``."""

    def get_bound_params(self, start: Any, finish: Any) -> tuple[Any]:
        return (finish,)
751
752
class UUIDTextMixin:
    """
    Strip hyphens from a value when filtering a UUIDField on backends without
    a native datatype for UUID.

    On such backends, the rhs is wrapped in ``Replace(rhs, '-', '')`` before
    normal rhs processing.
    """

    rhs: Any

    def process_rhs(
        self, qn: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if connection.features.has_native_uuid_field:
            sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
            return sql, params

        from plain.models.functions import Replace

        # self.rhs is rebound (not a local) on purpose: later steps such as
        # PatternLookup.get_rhs_op() inspect self.rhs and must see the
        # Replace() expression.
        if self.rhs_is_direct_value():  # type: ignore[attr-defined]
            self.rhs = Value(self.rhs)
        self.rhs = Replace(self.rhs, Value("-"), Value(""), output_field=CharField())
        sql, params = super().process_rhs(qn, connection)  # type: ignore[misc]
        return sql, params
774
775
@UUIDField.register_lookup
class UUIDIExact(UUIDTextMixin, IExact):
    """``iexact`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDContains(UUIDTextMixin, Contains):
    """``contains`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIContains(UUIDTextMixin, IContains):
    """``icontains`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDStartsWith(UUIDTextMixin, StartsWith):
    """``startswith`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIStartsWith(UUIDTextMixin, IStartsWith):
    """``istartswith`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDEndsWith(UUIDTextMixin, EndsWith):
    """``endswith`` on UUIDField, ignoring hyphens on non-native-UUID backends."""


@UUIDField.register_lookup
class UUIDIEndsWith(UUIDTextMixin, IEndsWith):
    """``iendswith`` on UUIDField, ignoring hyphens on non-native-UUID backends."""