import datetime
import decimal
import json
from importlib import import_module

import sqlparse

from plain.models.backends import utils
from plain.models.db import NotSupportedError
from plain.utils import timezone
from plain.utils.encoding import force_str


class BaseDatabaseOperations:
    """
    Encapsulate backend-specific differences, such as the way a backend
    performs ordering or calculates the ID of a recently-inserted row.
    """

    compiler_module = "plain.models.sql.compiler"

    # Integer field safe ranges by `internal_type` as documented
    # in docs/ref/models/fields.txt.
    integer_field_ranges = {
        "SmallIntegerField": (-32768, 32767),
        "IntegerField": (-2147483648, 2147483647),
        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
        "PositiveBigIntegerField": (0, 9223372036854775807),
        "PositiveSmallIntegerField": (0, 32767),
        "PositiveIntegerField": (0, 2147483647),
        "SmallAutoField": (-32768, 32767),
        "AutoField": (-2147483648, 2147483647),
        "BigAutoField": (-9223372036854775808, 9223372036854775807),
    }
    set_operators = {
        "union": "UNION",
        "intersection": "INTERSECT",
        "difference": "EXCEPT",
    }
    # Mapping of Field.get_internal_type() (typically the model field's class
    # name) to the data type to use for the Cast() function, if different from
    # DatabaseWrapper.data_types.
    cast_data_types = {}
    # CharField data type if the max_length argument isn't provided.
    cast_char_field_without_max_length = None

    # Start and end points for window expressions.
    PRECEDING = "PRECEDING"
    FOLLOWING = "FOLLOWING"
    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
    CURRENT_ROW = "CURRENT ROW"

    # Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
    explain_prefix = None

    def __init__(self, connection):
        self.connection = connection
        self._cache = None

    def autoinc_sql(self, table, column):
        """
        Return any SQL needed to support auto-incrementing primary keys, or
        None if no SQL is necessary.

        This SQL is executed when a table is created.
        """
        return None

    def bulk_batch_size(self, fields, objs):
        """
        Return the maximum allowed batch size for the backend. `fields` is the
        list of fields being inserted in the batch; `objs` contains all the
        objects to be inserted.
        """
        return len(objs)

    def format_for_duration_arithmetic(self, sql):
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "format_for_duration_arithmetic() method."
        )

    def unification_cast_sql(self, output_field):
        """
        Given a field instance, return the SQL that casts the result of a union
        to that type. The resulting string should contain a '%s' placeholder
        for the expression being cast.
        """
        return "%s"

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        extracts a value from the given date expression.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
            "method"
        )

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        truncates the given date or datetime expression to a date value with
        only the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
            "method."
        )

    def datetime_cast_date_sql(self, sql, params, tzname):
        """
        Return the SQL to cast a datetime value to a date value.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "datetime_cast_date_sql() method."
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        """
        Return the SQL to cast a datetime value to a time value.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "datetime_cast_time_sql() method"
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that extracts a value from the given
        datetime expression.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
            "method"
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that truncates the given datetime expression
        to a datetime value with only the given specificity.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
            "method"
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
        that truncates the given time or datetime expression to a time value
        with only the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """
        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
        that extracts a value from the given time expression.
        """
        return self.date_extract_sql(lookup_type, sql, params)

    def deferrable_sql(self):
        """
        Return the SQL to make a constraint "initially deferred" during a
        CREATE TABLE statement.
        """
        return ""

    def distinct_sql(self, fields, params):
        """
        Return an SQL DISTINCT clause which removes duplicate rows from the
        result set. If any fields are given, only check the given fields for
        duplicates.
        """
        if fields:
            raise NotSupportedError(
                "DISTINCT ON fields is not supported by this database backend"
            )
        else:
            return ["DISTINCT"], []

    def fetch_returned_insert_columns(self, cursor, returning_params):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the newly created data.
        """
        return cursor.fetchone()

    def field_cast_sql(self, db_type, internal_type):
        """
        Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
        (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
        it in a WHERE statement. The resulting string should contain a '%s'
        placeholder for the column being searched against.
        """
        return "%s"

    def force_no_ordering(self):
        """
        Return a list used in the "ORDER BY" clause to force no ordering at
        all. Return an empty list to include nothing in the ordering.
        """
        return []

    def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
        """
        Return the FOR UPDATE SQL clause to lock rows for an update operation.
        """
        return "FOR{} UPDATE{}{}{}".format(
            " NO KEY" if no_key else "",
            " OF {}".format(", ".join(of)) if of else "",
            " NOWAIT" if nowait else "",
            " SKIP LOCKED" if skip_locked else "",
        )
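
    # Illustrative examples (a sketch of the default implementation above;
    # subclasses may render the clause differently):
    #   for_update_sql()                              -> "FOR UPDATE"
    #   for_update_sql(nowait=True, of=("self",))     -> "FOR UPDATE OF self NOWAIT"
    #   for_update_sql(no_key=True, skip_locked=True) -> "FOR NO KEY UPDATE SKIP LOCKED"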

    def _get_limit_offset_params(self, low_mark, high_mark):
        offset = low_mark or 0
        if high_mark is not None:
            return (high_mark - offset), offset
        elif offset:
            return self.connection.ops.no_limit_value(), offset
        return None, offset

    def limit_offset_sql(self, low_mark, high_mark):
        """Return LIMIT/OFFSET SQL clause."""
        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
        return " ".join(
            sql
            for sql in (
                ("LIMIT %d" % limit) if limit else None,  # noqa: UP031
                ("OFFSET %d" % offset) if offset else None,  # noqa: UP031
            )
            if sql
        )
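
    # Illustrative examples (a sketch of the defaults above; `low_mark` and
    # `high_mark` are the queryset slice bounds):
    #   limit_offset_sql(0, 5)   -> "LIMIT 5"
    #   limit_offset_sql(10, 30) -> "LIMIT 20 OFFSET 10"
    #   limit_offset_sql(10, None) falls back to no_limit_value(), which is
    #   backend-specific.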

    def last_executed_query(self, cursor, sql, params):
        """
        Return a string of the query last executed by the given cursor, with
        placeholders replaced with actual values.

        `sql` is the raw query containing placeholders and `params` is the
        sequence of parameters. These are used by default, but this method
        exists for database backends to provide a better implementation
        according to their own quoting schemes.
        """

        # Convert params to contain string values.
        def to_string(s):
            return force_str(s, strings_only=True, errors="replace")

        if isinstance(params, list | tuple):
            u_params = tuple(to_string(val) for val in params)
        elif params is None:
            u_params = ()
        else:
            u_params = {to_string(k): to_string(v) for k, v in params.items()}

        return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
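
    # Illustrative example (a sketch; the exact PARAMS rendering depends on
    # how force_str() handles each value):
    #   last_executed_query(cursor, "SELECT %s", [42])
    #   -> "QUERY = 'SELECT %s' - PARAMS = (42,)"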

    def last_insert_id(self, cursor, table_name, pk_name):
        """
        Given a cursor object that has just performed an INSERT statement into
        a table that has an auto-incrementing ID, return the newly created ID.

        `pk_name` is the name of the primary-key column.
        """
        return cursor.lastrowid

    def lookup_cast(self, lookup_type, internal_type=None):
        """
        Return the string to use in a query when performing lookups
        ("contains", "like", etc.). It should contain a '%s' placeholder for
        the column being searched against.
        """
        return "%s"

    def max_in_list_size(self):
        """
        Return the maximum number of items that can be passed in a single 'IN'
        list condition, or None if the backend does not impose a limit.
        """
        return None

    def max_name_length(self):
        """
        Return the maximum length of table and column names, or None if there
        is no limit.
        """
        return None

    def no_limit_value(self):
        """
        Return the value to use for the LIMIT when we want "LIMIT infinity".
        Return None if the limit clause can be omitted in this case.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
        )

    def pk_default_value(self):
        """
        Return the value to use during an INSERT statement to specify that
        the field should use its default value.
        """
        return "DEFAULT"

    def prepare_sql_script(self, sql):
        """
        Take an SQL script that may contain multiple lines and return a list
        of statements to feed to successive cursor.execute() calls.

        Since few databases are able to process raw SQL scripts in a single
        cursor.execute() call and PEP 249 doesn't talk about this use case,
        the default implementation is conservative.
        """
        return [
            sqlparse.format(statement, strip_comments=True)
            for statement in sqlparse.split(sql)
            if statement
        ]
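
    # Illustrative example (a sketch; exact whitespace handling is up to
    # sqlparse):
    #   prepare_sql_script("SELECT 1; SELECT 2;")
    #   -> ["SELECT 1;", "SELECT 2;"]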

    def return_insert_columns(self, fields):
        """
        For backends that support returning columns as part of an insert query,
        return the SQL and params to append to the INSERT query. The returned
        fragment should contain a format string to hold the appropriate column.
        """
        pass

    def compiler(self, compiler_name):
        """
        Return the SQLCompiler class corresponding to the given name,
        in the namespace corresponding to the `compiler_module` attribute
        on this backend.
        """
        if self._cache is None:
            self._cache = import_module(self.compiler_module)
        return getattr(self._cache, compiler_name)
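
    # Illustrative usage (a sketch; assumes a class with the given name, e.g.
    # "SQLCompiler", is defined in the module named by compiler_module):
    #   compiler_class = connection.ops.compiler("SQLCompiler")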

    def quote_name(self, name):
        """
        Return a quoted version of the given table, index, or column name. Do
        not quote the given name if it's already been quoted.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a quote_name() method"
        )

    def regex_lookup(self, lookup_type):
        """
        Return the string to use in a query when performing regular expression
        lookups (using "regex" or "iregex"). It should contain a '%s'
        placeholder for the column being searched against.

        If the feature is not supported (or part of it is not supported), raise
        NotImplementedError.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
        )

    def savepoint_create_sql(self, sid):
        """
        Return the SQL for starting a new savepoint. Only required if the
        "uses_savepoints" feature is True. The "sid" parameter is a string
        for the savepoint id.
        """
        return f"SAVEPOINT {self.quote_name(sid)}"

    def savepoint_commit_sql(self, sid):
        """
        Return the SQL for committing the given savepoint.
        """
        return f"RELEASE SAVEPOINT {self.quote_name(sid)}"

    def savepoint_rollback_sql(self, sid):
        """
        Return the SQL for rolling back the given savepoint.
        """
        return f"ROLLBACK TO SAVEPOINT {self.quote_name(sid)}"
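
    # Illustrative examples (a sketch assuming a backend whose quote_name()
    # wraps identifiers in double quotes):
    #   savepoint_create_sql("s1")   -> 'SAVEPOINT "s1"'
    #   savepoint_commit_sql("s1")   -> 'RELEASE SAVEPOINT "s1"'
    #   savepoint_rollback_sql("s1") -> 'ROLLBACK TO SAVEPOINT "s1"'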

    def set_time_zone_sql(self):
        """
        Return the SQL that will set the connection's time zone.

        Return '' if the backend doesn't support time zones.
        """
        return ""

    def prep_for_like_query(self, x):
        """Prepare a value for use in a LIKE query."""
        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")

    # Same as prep_for_like_query(), but called for "iexact" matches, which
    # need not necessarily be implemented using "LIKE" in the backend.
    prep_for_iexact_query = prep_for_like_query
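
    # Illustrative example: prep_for_like_query("50%_off") returns the text
    # 50\%\_off (each "%" and "_" escaped with a backslash, and any literal
    # backslash in the input doubled), so user-supplied values are not
    # treated as LIKE wildcards.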

    def validate_autopk_value(self, value):
        """
        Certain backends do not accept some values for "serial" fields
        (for example zero in MySQL). Raise a ValueError if the value is
        invalid, otherwise return the validated value.
        """
        return value

    def adapt_unknown_value(self, value):
        """
        Transform a value to something compatible with the backend driver.

        This method only depends on the type of the value. It's designed for
        cases where the target type isn't known, such as .raw() SQL queries.
        As a consequence it may not work perfectly in all circumstances.
        """
        if isinstance(value, datetime.datetime):  # must be before date
            return self.adapt_datetimefield_value(value)
        elif isinstance(value, datetime.date):
            return self.adapt_datefield_value(value)
        elif isinstance(value, datetime.time):
            return self.adapt_timefield_value(value)
        elif isinstance(value, decimal.Decimal):
            return self.adapt_decimalfield_value(value)
        else:
            return value
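
    # Illustrative examples of the dispatch above (using the default
    # adapters below, which stringify temporal values):
    #   adapt_unknown_value(datetime.date(2024, 1, 31)) -> "2024-01-31"
    #   adapt_unknown_value(42)                         -> 42 (unchanged)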

    def adapt_integerfield_value(self, value, internal_type):
        return value

    def adapt_datefield_value(self, value):
        """
        Transform a date value to an object compatible with what is expected
        by the backend driver for date columns.
        """
        if value is None:
            return None
        return str(value)

    def adapt_datetimefield_value(self, value):
        """
        Transform a datetime value to an object compatible with what is expected
        by the backend driver for datetime columns.
        """
        if value is None:
            return None
        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Transform a time value to an object compatible with what is expected
        by the backend driver for time columns.
        """
        if value is None:
            return None
        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        if timezone.is_aware(value):
            raise ValueError("Plain does not support timezone-aware times.")
        return str(value)

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """
        Transform a decimal.Decimal value to an object compatible with what is
        expected by the backend driver for decimal (numeric) columns.
        """
        return utils.format_number(value, max_digits, decimal_places)

    def adapt_ipaddressfield_value(self, value):
        """
        Transform a string representation of an IP address into the expected
        type for the backend driver.
        """
        return value or None

    def adapt_json_value(self, value, encoder):
        return json.dumps(value, cls=encoder)

    def year_lookup_bounds_for_date_field(self, value, iso_year=False):
        """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
        """
        if iso_year:
            first = datetime.date.fromisocalendar(value, 1, 1)
            second = datetime.date.fromisocalendar(
                value + 1, 1, 1
            ) - datetime.timedelta(days=1)
        else:
            first = datetime.date(value, 1, 1)
            second = datetime.date(value, 12, 31)
        first = self.adapt_datefield_value(first)
        second = self.adapt_datefield_value(second)
        return [first, second]
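
    # Illustrative examples (a sketch using the default string adapter above):
    #   year_lookup_bounds_for_date_field(2024)
    #   -> ["2024-01-01", "2024-12-31"]
    #   year_lookup_bounds_for_date_field(2024, iso_year=True)
    #   -> ["2024-01-01", "2024-12-29"]  # ISO year 2024 ends on Sunday Dec 29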

    def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
        """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateTimeField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
        """
        if iso_year:
            first = datetime.datetime.fromisocalendar(value, 1, 1)
            second = datetime.datetime.fromisocalendar(
                value + 1, 1, 1
            ) - datetime.timedelta(microseconds=1)
        else:
            first = datetime.datetime(value, 1, 1)
            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

        # Make sure that datetimes are aware in the current timezone
        tz = timezone.get_current_timezone()
        first = timezone.make_aware(first, tz)
        second = timezone.make_aware(second, tz)

        first = self.adapt_datetimefield_value(first)
        second = self.adapt_datetimefield_value(second)
        return [first, second]

    def get_db_converters(self, expression):
        """
        Return a list of functions needed to convert field data.

        Some field types on some backends do not provide data in the correct
        format; this is the hook for converter functions.
        """
        return []

    def convert_durationfield_value(self, value, expression, connection):
        if value is not None:
            return datetime.timedelta(0, 0, value)

    def check_expression_support(self, expression):
        """
        Check that the backend supports the provided expression.

        This is used on specific backends to rule out known expressions
        that have problematic or nonexistent implementations. If the
        expression has a known problem, the backend should raise
        NotSupportedError.
        """
        pass

    def conditional_expression_supported_in_where_clause(self, expression):
        """
        Return True if the conditional expression is supported in the WHERE
        clause.
        """
        return True

    def combine_expression(self, connector, sub_expressions):
        """
        Combine a list of subexpressions into a single expression, using
        the provided connecting operator. This is required because operators
        can vary between backends (e.g., Oracle with %% and &) and between
        subexpression types (e.g., date expressions).
        """
        conn = f" {connector} "
        return conn.join(sub_expressions)
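
    # Illustrative example:
    #   combine_expression("+", ["price", "tax"]) -> "price + tax"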

    def combine_duration_expression(self, connector, sub_expressions):
        return self.combine_expression(connector, sub_expressions)

    def binary_placeholder_sql(self, value):
        """
        Some backends require special syntax to insert binary content (MySQL
        for example uses '_binary %s').
        """
        return "%s"

    def modify_insert_params(self, placeholder, params):
        """
        Allow modification of insert parameters. Needed for Oracle Spatial
        backend due to #10888.
        """
        return params

    def integer_field_range(self, internal_type):
        """
        Given an integer field internal type (e.g. 'PositiveIntegerField'),
        return a tuple of the (min_value, max_value) form representing the
        range of the column type bound to the field.
        """
        return self.integer_field_ranges[internal_type]

    def subtract_temporals(self, internal_type, lhs, rhs):
        if self.connection.features.supports_temporal_subtraction:
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
        raise NotSupportedError(
            f"This backend does not support {internal_type} subtraction."
        )
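
    # Illustrative example (a sketch, assuming the backend reports
    # supports_temporal_subtraction; `lhs` and `rhs` are (sql, params) pairs):
    #   subtract_temporals("DateTimeField", ("a.start", []), ("b.end", []))
    #   -> ("(a.start - b.end)", ())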

    def window_frame_start(self, start):
        if isinstance(start, int):
            if start < 0:
                return "%d %s" % (abs(start), self.PRECEDING)  # noqa: UP031
            elif start == 0:
                return self.CURRENT_ROW
        elif start is None:
            return self.UNBOUNDED_PRECEDING
        raise ValueError(
            f"start argument must be a negative integer, zero, or None, but got '{start}'."
        )

    def window_frame_end(self, end):
        if isinstance(end, int):
            if end == 0:
                return self.CURRENT_ROW
            elif end > 0:
                return "%d %s" % (end, self.FOLLOWING)  # noqa: UP031
        elif end is None:
            return self.UNBOUNDED_FOLLOWING
        raise ValueError(
            f"end argument must be a positive integer, zero, or None, but got '{end}'."
        )

    def window_frame_rows_start_end(self, start=None, end=None):
        """
        Return SQL for start and end points in an OVER clause window frame.
        """
        if not self.connection.features.supports_over_clause:
            raise NotSupportedError("This backend does not support window expressions.")
        return self.window_frame_start(start), self.window_frame_end(end)
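
    # Illustrative examples (requires a backend with supports_over_clause):
    #   window_frame_rows_start_end(start=-3, end=None)
    #   -> ("3 PRECEDING", "UNBOUNDED FOLLOWING")
    #   window_frame_rows_start_end(start=None, end=0)
    #   -> ("UNBOUNDED PRECEDING", "CURRENT ROW")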

    def window_frame_range_start_end(self, start=None, end=None):
        start_, end_ = self.window_frame_rows_start_end(start, end)
        features = self.connection.features
        if features.only_supports_unbounded_with_preceding_and_following and (
            (start and start < 0) or (end and end > 0)
        ):
            raise NotSupportedError(
                f"{self.connection.display_name} only supports UNBOUNDED together with PRECEDING and "
                "FOLLOWING."
            )
        return start_, end_

    def explain_query_prefix(self, format=None, **options):
        if not self.connection.features.supports_explaining_query_execution:
            raise NotSupportedError(
                "This backend does not support explaining query execution."
            )
        if format:
            supported_formats = self.connection.features.supported_explain_formats
            normalized_format = format.upper()
            if normalized_format not in supported_formats:
                msg = f"{normalized_format} is not a recognized format."
                if supported_formats:
                    msg += " Allowed formats: {}".format(
                        ", ".join(sorted(supported_formats))
                    )
                else:
                    msg += (
                        f" {self.connection.display_name} does not support any formats."
                    )
                raise ValueError(msg)
        if options:
            raise ValueError(
                "Unknown options: {}".format(", ".join(sorted(options.keys())))
            )
        return self.explain_prefix
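
    # Illustrative example (a sketch, assuming a subclass that sets
    # explain_prefix = "EXPLAIN" and advertises {"JSON"} in
    # supported_explain_formats):
    #   explain_query_prefix()              -> "EXPLAIN"
    #   explain_query_prefix(format="text") -> ValueError, since "TEXT" is not
    #   among the supported formats.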

    def insert_statement(self, on_conflict=None):
        return "INSERT INTO"

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        return ""