import datetime
import decimal
import json
from importlib import import_module

import sqlparse

from plain.models.backends import utils
from plain.models.db import NotSupportedError
from plain.utils import timezone
from plain.utils.encoding import force_str


class BaseDatabaseOperations:
    """
    Encapsulate backend-specific differences, such as the way a backend
    performs ordering or calculates the ID of a recently-inserted row.
    """

    compiler_module = "plain.models.sql.compiler"

    # Integer field safe ranges by `internal_type` as documented
    # in docs/ref/models/fields.txt.
    integer_field_ranges = {
        "SmallIntegerField": (-32768, 32767),
        "IntegerField": (-2147483648, 2147483647),
        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
        "PositiveBigIntegerField": (0, 9223372036854775807),
        "PositiveSmallIntegerField": (0, 32767),
        "PositiveIntegerField": (0, 2147483647),
        "PrimaryKeyField": (-9223372036854775808, 9223372036854775807),
    }
    set_operators = {
        "union": "UNION",
        "intersection": "INTERSECT",
        "difference": "EXCEPT",
    }
    # Mapping of Field.get_internal_type() (typically the model field's class
    # name) to the data type to use for the Cast() function, if different from
    # DatabaseWrapper.data_types.
    cast_data_types = {}
    # CharField data type if the max_length argument isn't provided.
    cast_char_field_without_max_length = None

    # Start and end points for window expressions.
    PRECEDING = "PRECEDING"
    FOLLOWING = "FOLLOWING"
    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
    CURRENT_ROW = "CURRENT ROW"

    # Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
    explain_prefix = None

    def __init__(self, connection):
        self.connection = connection
        self._cache = None

    def autoinc_sql(self, table, column):
        """
        Return any SQL needed to support auto-incrementing primary keys, or
        None if no SQL is necessary.

        This SQL is executed when a table is created.
        """
        return None

    def bulk_batch_size(self, fields, objs):
        """
        Return the maximum allowed batch size for the backend. `fields` are
        the fields to be inserted in the batch and `objs` contains all the
        objects to be inserted.
73 """
74 return len(objs)
75
76 def format_for_duration_arithmetic(self, sql):
77 raise NotImplementedError(
78 "subclasses of BaseDatabaseOperations may require a "
79 "format_for_duration_arithmetic() method."
80 )
81
82 def unification_cast_sql(self, output_field):
83 """
84 Given a field instance, return the SQL that casts the result of a union
85 to that type. The resulting string should contain a '%s' placeholder
86 for the expression being cast.
87 """
88 return "%s"
89
90 def date_extract_sql(self, lookup_type, sql, params):
91 """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        extracts a value from the given date field.
94 """
95 raise NotImplementedError(
96 "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
97 "method"
98 )
99
100 def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
101 """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        truncates the given date or datetime field to a date object with only
        the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
            "method."
        )

    def datetime_cast_date_sql(self, sql, params, tzname):
        """
        Return the SQL to cast a datetime value to a date value.
117 """
118 raise NotImplementedError(
119 "subclasses of BaseDatabaseOperations may require a "
120 "datetime_cast_date_sql() method."
121 )
122
123 def datetime_cast_time_sql(self, sql, params, tzname):
124 """
        Return the SQL to cast a datetime value to a time value.
126 """
127 raise NotImplementedError(
128 "subclasses of BaseDatabaseOperations may require a "
129 "datetime_cast_time_sql() method"
130 )
131
132 def datetime_extract_sql(self, lookup_type, sql, params, tzname):
133 """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that extracts a value from the given
        datetime field.
137 """
138 raise NotImplementedError(
139 "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
140 "method"
141 )
142
143 def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
144 """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that truncates the given datetime field to a
        datetime object with only the given specificity.
148 """
149 raise NotImplementedError(
150 "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
151 "method"
152 )
153
154 def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
155 """
        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
        that truncates the given time or datetime field to a time object with
        only the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """
        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
        that extracts a value from the given time field.
171 """
172 return self.date_extract_sql(lookup_type, sql, params)
173
174 def deferrable_sql(self):
175 """
176 Return the SQL to make a constraint "initially deferred" during a
177 CREATE TABLE statement.
178 """
179 return ""
180
181 def distinct_sql(self, fields, params):
182 """
183 Return an SQL DISTINCT clause which removes duplicate rows from the
184 result set. If any fields are given, only check the given fields for
185 duplicates.
186 """
187 if fields:
188 raise NotSupportedError(
189 "DISTINCT ON fields is not supported by this database backend"
190 )
191 else:
192 return ["DISTINCT"], []
193
194 def fetch_returned_insert_columns(self, cursor, returning_params):
195 """
196 Given a cursor object that has just performed an INSERT...RETURNING
197 statement into a table, return the newly created data.
198 """
199 return cursor.fetchone()
200
201 def field_cast_sql(self, db_type, internal_type):
202 """
203 Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
204 (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
205 it in a WHERE statement. The resulting string should contain a '%s'
206 placeholder for the column being searched against.
207 """
208 return "%s"
209
210 def force_no_ordering(self):
211 """
212 Return a list used in the "ORDER BY" clause to force no ordering at
213 all. Return an empty list to include nothing in the ordering.
214 """
215 return []
216
217 def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
218 """
219 Return the FOR UPDATE SQL clause to lock rows for an update operation.
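
        For example (illustrative only; the exact clause depends on the
        arguments), for_update_sql(nowait=True, of=("self",)) returns
        "FOR UPDATE OF self NOWAIT".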
220 """
221 return "FOR{} UPDATE{}{}{}".format(
222 " NO KEY" if no_key else "",
223 " OF {}".format(", ".join(of)) if of else "",
224 " NOWAIT" if nowait else "",
225 " SKIP LOCKED" if skip_locked else "",
226 )
227
228 def _get_limit_offset_params(self, low_mark, high_mark):
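        # Convert queryset slicing marks into a (limit, offset) pair: a bounded
        # slice yields an explicit limit, a lower bound without an upper bound
        # falls back to the backend's no_limit_value(), and no marks at all
        # mean no LIMIT clause is needed.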
        offset = low_mark or 0
        if high_mark is not None:
            return (high_mark - offset), offset
        elif offset:
            return self.connection.ops.no_limit_value(), offset
        return None, offset

    def limit_offset_sql(self, low_mark, high_mark):
        """Return LIMIT/OFFSET SQL clause."""
        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
        return " ".join(
            sql
            for sql in (
                ("LIMIT %d" % limit) if limit else None,  # noqa: UP031
                ("OFFSET %d" % offset) if offset else None,  # noqa: UP031
            )
            if sql
        )

    def last_executed_query(self, cursor, sql, params):
        """
        Return a string of the query last executed by the given cursor, with
        placeholders replaced with actual values.

        `sql` is the raw query containing placeholders and `params` is the
        sequence of parameters. These are used by default, but this method
        exists for database backends to provide a better implementation
        according to their own quoting schemes.
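
        For example, this default implementation would return something like
        "QUERY = 'SELECT %s' - PARAMS = ('x',)" (illustrative only).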
257 """
258
259 # Convert params to contain string values.
260 def to_string(s):
261 return force_str(s, strings_only=True, errors="replace")
262
263 if isinstance(params, list | tuple):
264 u_params = tuple(to_string(val) for val in params)
265 elif params is None:
266 u_params = ()
267 else:
268 u_params = {to_string(k): to_string(v) for k, v in params.items()}
269
270 return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
271
272 def last_insert_id(self, cursor, table_name, pk_name):
273 """
274 Given a cursor object that has just performed an INSERT statement into
275 a table that has an auto-incrementing ID, return the newly created ID.
276
277 `pk_name` is the name of the primary-key column.
278 """
279 return cursor.lastrowid
280
281 def lookup_cast(self, lookup_type, internal_type=None):
282 """
283 Return the string to use in a query when performing lookups
284 ("contains", "like", etc.). It should contain a '%s' placeholder for
285 the column being searched against.
286 """
287 return "%s"
288
289 def max_in_list_size(self):
290 """
291 Return the maximum number of items that can be passed in a single 'IN'
292 list condition, or None if the backend does not impose a limit.
293 """
294 return None
295
296 def max_name_length(self):
297 """
298 Return the maximum length of table and column names, or None if there
299 is no limit.
300 """
301 return None
302
303 def no_limit_value(self):
304 """
        Return the value to use for the LIMIT when we want "LIMIT infinity".
        Return None if the limit clause can be omitted in this case.
307 """
308 raise NotImplementedError(
309 "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
310 )
311
312 def pk_default_value(self):
313 """
314 Return the value to use during an INSERT statement to specify that
315 the field should use its default value.
316 """
317 return "DEFAULT"
318
319 def prepare_sql_script(self, sql):
320 """
321 Take an SQL script that may contain multiple lines and return a list
322 of statements to feed to successive cursor.execute() calls.
323
324 Since few databases are able to process raw SQL scripts in a single
325 cursor.execute() call and PEP 249 doesn't talk about this use case,
326 the default implementation is conservative.
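
        For example, "BEGIN; SELECT 1; COMMIT;" would typically be split into
        ["BEGIN;", "SELECT 1;", "COMMIT;"] (illustrative; the exact result
        depends on sqlparse).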
327 """
328 return [
329 sqlparse.format(statement, strip_comments=True)
330 for statement in sqlparse.split(sql)
331 if statement
332 ]
333
334 def return_insert_columns(self, fields):
335 """
336 For backends that support returning columns as part of an insert query,
337 return the SQL and params to append to the INSERT query. The returned
338 fragment should contain a format string to hold the appropriate column.
339 """
340 pass
341
342 def compiler(self, compiler_name):
343 """
344 Return the SQLCompiler class corresponding to the given name,
345 in the namespace corresponding to the `compiler_module` attribute
346 on this backend.
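
        For example, compiler("SQLCompiler") returns the SQLCompiler class
        from plain.models.sql.compiler (or from whatever module a backend
        sets as `compiler_module`).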
347 """
348 if self._cache is None:
349 self._cache = import_module(self.compiler_module)
350 return getattr(self._cache, compiler_name)
351
352 def quote_name(self, name):
353 """
354 Return a quoted version of the given table, index, or column name. Do
355 not quote the given name if it's already been quoted.
356 """
357 raise NotImplementedError(
358 "subclasses of BaseDatabaseOperations may require a quote_name() method"
359 )
360
361 def regex_lookup(self, lookup_type):
362 """
363 Return the string to use in a query when performing regular expression
364 lookups (using "regex" or "iregex"). It should contain a '%s'
365 placeholder for the column being searched against.
366
367 If the feature is not supported (or part of it is not supported), raise
368 NotImplementedError.
369 """
370 raise NotImplementedError(
371 "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
372 )
373
374 def savepoint_create_sql(self, sid):
375 """
376 Return the SQL for starting a new savepoint. Only required if the
377 "uses_savepoints" feature is True. The "sid" parameter is a string
378 for the savepoint id.
379 """
380 return f"SAVEPOINT {self.quote_name(sid)}"
381
382 def savepoint_commit_sql(self, sid):
383 """
384 Return the SQL for committing the given savepoint.
385 """
386 return f"RELEASE SAVEPOINT {self.quote_name(sid)}"
387
388 def savepoint_rollback_sql(self, sid):
389 """
390 Return the SQL for rolling back the given savepoint.
391 """
392 return f"ROLLBACK TO SAVEPOINT {self.quote_name(sid)}"
393
394 def set_time_zone_sql(self):
395 """
396 Return the SQL that will set the connection's time zone.
397
398 Return '' if the backend doesn't support time zones.
399 """
400 return ""
401
402 def prep_for_like_query(self, x):
403 """Prepare a value for use in a LIKE query."""
        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")

    # Same as prep_for_like_query(), but called for "iexact" matches, which
    # need not necessarily be implemented using "LIKE" in the backend.
    prep_for_iexact_query = prep_for_like_query

    def validate_autopk_value(self, value):
        """
        Certain backends do not accept some values for "serial" fields
        (for example zero in MySQL). Raise a ValueError if the value is
        invalid, otherwise return the validated value.
        """
        return value

    def adapt_unknown_value(self, value):
        """
        Transform a value to something compatible with the backend driver.

        This method only depends on the type of the value. It's designed for
        cases where the target type isn't known, such as .raw() SQL queries.
        As a consequence it may not work perfectly in all circumstances.
        """
        if isinstance(value, datetime.datetime):  # must be before date
            return self.adapt_datetimefield_value(value)
        elif isinstance(value, datetime.date):
            return self.adapt_datefield_value(value)
        elif isinstance(value, datetime.time):
            return self.adapt_timefield_value(value)
        elif isinstance(value, decimal.Decimal):
            return self.adapt_decimalfield_value(value)
        else:
            return value

    def adapt_integerfield_value(self, value, internal_type):
        return value

    def adapt_datefield_value(self, value):
        """
        Transform a date value to an object compatible with what is expected
        by the backend driver for date columns.
        """
        if value is None:
            return None
        return str(value)

    def adapt_datetimefield_value(self, value):
        """
        Transform a datetime value to an object compatible with what is expected
        by the backend driver for datetime columns.
        """
        if value is None:
            return None
        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Transform a time value to an object compatible with what is expected
        by the backend driver for time columns.
        """
        if value is None:
            return None
        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        if timezone.is_aware(value):
            raise ValueError("Plain does not support timezone-aware times.")
        return str(value)

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """
        Transform a decimal.Decimal value to an object compatible with what is
        expected by the backend driver for decimal (numeric) columns.
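
        For example, Decimal("1.23456") with max_digits=5 and decimal_places=3
        would be formatted as "1.235" (illustrative; rounding is delegated to
        utils.format_number()).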
481 """
482 return utils.format_number(value, max_digits, decimal_places)
483
484 def adapt_ipaddressfield_value(self, value):
485 """
486 Transform a string representation of an IP address into the expected
487 type for the backend driver.
488 """
489 return value or None
490
491 def adapt_json_value(self, value, encoder):
492 return json.dumps(value, cls=encoder)
493
494 def year_lookup_bounds_for_date_field(self, value, iso_year=False):
495 """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
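
        For example, value=2024 with iso_year=False would produce bounds
        corresponding to 2024-01-01 and 2024-12-31 (illustrative; the values
        are passed through adapt_datefield_value()).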
502 """
503 if iso_year:
504 first = datetime.date.fromisocalendar(value, 1, 1)
505 second = datetime.date.fromisocalendar(
506 value + 1, 1, 1
507 ) - datetime.timedelta(days=1)
508 else:
509 first = datetime.date(value, 1, 1)
510 second = datetime.date(value, 12, 31)
511 first = self.adapt_datefield_value(first)
512 second = self.adapt_datefield_value(second)
513 return [first, second]
514
515 def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
516 """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateTimeField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
        """
        if iso_year:
            first = datetime.datetime.fromisocalendar(value, 1, 1)
            second = datetime.datetime.fromisocalendar(
                value + 1, 1, 1
            ) - datetime.timedelta(microseconds=1)
        else:
            first = datetime.datetime(value, 1, 1)
            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

        # Make sure that datetimes are aware in the current timezone
        tz = timezone.get_current_timezone()
        first = timezone.make_aware(first, tz)
        second = timezone.make_aware(second, tz)

        first = self.adapt_datetimefield_value(first)
        second = self.adapt_datetimefield_value(second)
        return [first, second]

    def get_db_converters(self, expression):
        """
        Return a list of functions needed to convert field data.

        Some field types on some backends do not provide data in the correct
        format; this is the hook for converter functions.
548 """
549 return []
550
551 def convert_durationfield_value(self, value, expression, connection):
552 if value is not None:
553 return datetime.timedelta(0, 0, value)
554
555 def check_expression_support(self, expression):
556 """
557 Check that the backend supports the provided expression.
558
559 This is used on specific backends to rule out known expressions
560 that have problematic or nonexistent implementations. If the
561 expression has a known problem, the backend should raise
562 NotSupportedError.
563 """
564 pass
565
566 def conditional_expression_supported_in_where_clause(self, expression):
567 """
        Return True if the conditional expression is supported in the WHERE
        clause.
        """
        return True

    def combine_expression(self, connector, sub_expressions):
        """
        Combine a list of subexpressions into a single expression, using
        the provided connecting operator. This is required because operators
        can vary between backends (e.g., Oracle with %% and &) and between
        subexpression types (e.g., date expressions).
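
        For example, combine_expression("+", ["col1", "col2"]) returns
        "col1 + col2" with this default implementation.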
579 """
580 conn = f" {connector} "
581 return conn.join(sub_expressions)
582
583 def combine_duration_expression(self, connector, sub_expressions):
584 return self.combine_expression(connector, sub_expressions)
585
586 def binary_placeholder_sql(self, value):
587 """
588 Some backends require special syntax to insert binary content (MySQL
589 for example uses '_binary %s').
590 """
591 return "%s"
592
593 def modify_insert_params(self, placeholder, params):
594 """
595 Allow modification of insert parameters. Needed for Oracle Spatial
596 backend due to #10888.
597 """
598 return params
599
600 def integer_field_range(self, internal_type):
601 """
602 Given an integer field internal type (e.g. 'PositiveIntegerField'),
603 return a tuple of the (min_value, max_value) form representing the
604 range of the column type bound to the field.
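
        For example, integer_field_range("PositiveSmallIntegerField") returns
        (0, 32767) with the default integer_field_ranges mapping.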
605 """
606 return self.integer_field_ranges[internal_type]
607
608 def subtract_temporals(self, internal_type, lhs, rhs):
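        # Default behavior: emit a plain "(lhs - rhs)" expression when the
        # backend reports supports_temporal_subtraction; backends with other
        # duration semantics are expected to override this method.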
        if self.connection.features.supports_temporal_subtraction:
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
        raise NotSupportedError(
            f"This backend does not support {internal_type} subtraction."
        )

    def window_frame_start(self, start):
        if isinstance(start, int):
            if start < 0:
                return "%d %s" % (abs(start), self.PRECEDING)  # noqa: UP031
            elif start == 0:
                return self.CURRENT_ROW
        elif start is None:
            return self.UNBOUNDED_PRECEDING
        raise ValueError(
            f"start argument must be a negative integer, zero, or None, but got '{start}'."
        )

    def window_frame_end(self, end):
        if isinstance(end, int):
            if end == 0:
                return self.CURRENT_ROW
            elif end > 0:
                return "%d %s" % (end, self.FOLLOWING)  # noqa: UP031
        elif end is None:
            return self.UNBOUNDED_FOLLOWING
        raise ValueError(
            f"end argument must be a positive integer, zero, or None, but got '{end}'."
        )

    def window_frame_rows_start_end(self, start=None, end=None):
        """
        Return SQL for start and end points in an OVER clause window frame.
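
        For example, start=-3 and end=None would typically yield
        ("3 PRECEDING", "UNBOUNDED FOLLOWING") (illustrative).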
644 """
645 if not self.connection.features.supports_over_clause:
646 raise NotSupportedError("This backend does not support window expressions.")
647 return self.window_frame_start(start), self.window_frame_end(end)
648
649 def window_frame_range_start_end(self, start=None, end=None):
650 start_, end_ = self.window_frame_rows_start_end(start, end)
651 features = self.connection.features
652 if features.only_supports_unbounded_with_preceding_and_following and (
653 (start and start < 0) or (end and end > 0)
654 ):
655 raise NotSupportedError(
656 f"{self.connection.display_name} only supports UNBOUNDED together with PRECEDING and "
657 "FOLLOWING."
658 )
659 return start_, end_
660
661 def explain_query_prefix(self, format=None, **options):
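        # Validate the requested EXPLAIN format and any extra options against
        # what the backend's features report as supported, then return the
        # backend's EXPLAIN prefix (None here; a backend might use, for
        # example, "EXPLAIN" or "EXPLAIN QUERY PLAN").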
        if not self.connection.features.supports_explaining_query_execution:
            raise NotSupportedError(
                "This backend does not support explaining query execution."
            )
        if format:
            supported_formats = self.connection.features.supported_explain_formats
            normalized_format = format.upper()
            if normalized_format not in supported_formats:
                msg = f"{normalized_format} is not a recognized format."
                if supported_formats:
                    msg += " Allowed formats: {}".format(
                        ", ".join(sorted(supported_formats))
                    )
                else:
                    msg += (
                        f" {self.connection.display_name} does not support any formats."
                    )
                raise ValueError(msg)
        if options:
            raise ValueError(
                "Unknown options: {}".format(", ".join(sorted(options.keys())))
            )
        return self.explain_prefix

    def insert_statement(self, on_conflict=None):
        return "INSERT INTO"

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        return ""