1import datetime
2import decimal
3import json
4from importlib import import_module
5
6import sqlparse
7
8from plain.models.backends import utils
9from plain.models.db import NotSupportedError
10from plain.runtime import settings
11from plain.utils import timezone
12from plain.utils.encoding import force_str
13
14
15class BaseDatabaseOperations:
16 """
17 Encapsulate backend-specific differences, such as the way a backend
18 performs ordering or calculates the ID of a recently-inserted row.
19 """
20
    # Dotted path of the module that compiler() imports in order to look up
    # SQL compiler classes by name.
    compiler_module = "plain.models.sql.compiler"

    # Integer field safe ranges by `internal_type` as documented
    # in docs/ref/models/fields.txt. Each value is a (min_value, max_value)
    # tuple; integer_field_range() reads this mapping.
    integer_field_ranges = {
        "SmallIntegerField": (-32768, 32767),
        "IntegerField": (-2147483648, 2147483647),
        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
        "PositiveBigIntegerField": (0, 9223372036854775807),
        "PositiveSmallIntegerField": (0, 32767),
        "PositiveIntegerField": (0, 2147483647),
        "SmallAutoField": (-32768, 32767),
        "AutoField": (-2147483648, 2147483647),
        "BigAutoField": (-9223372036854775808, 9223372036854775807),
    }
    # Set operation names mapped to the SQL keyword used for each of them.
    set_operators = {
        "union": "UNION",
        "intersection": "INTERSECT",
        "difference": "EXCEPT",
    }
    # Mapping of Field.get_internal_type() (typically the model field's class
    # name) to the data type to use for the Cast() function, if different from
    # DatabaseWrapper.data_types.
    cast_data_types = {}
    # CharField data type if the max_length argument isn't provided.
    cast_char_field_without_max_length = None

    # Start and end points for window expressions. The window_frame_start()
    # and window_frame_end() methods build their results from these keywords.
    PRECEDING = "PRECEDING"
    FOLLOWING = "FOLLOWING"
    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
    CURRENT_ROW = "CURRENT ROW"

    # Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
    # explain_query_prefix() returns this value unchanged.
    explain_prefix = None
57
58 def __init__(self, connection):
59 self.connection = connection
60 self._cache = None
61
62 def autoinc_sql(self, table, column):
63 """
64 Return any SQL needed to support auto-incrementing primary keys, or
65 None if no SQL is necessary.
66
67 This SQL is executed when a table is created.
68 """
69 return None
70
71 def bulk_batch_size(self, fields, objs):
72 """
73 Return the maximum allowed batch size for the backend. The fields
74 are the fields going to be inserted in the batch, the objs contains
75 all the objects to be inserted.
76 """
77 return len(objs)
78
79 def format_for_duration_arithmetic(self, sql):
80 raise NotImplementedError(
81 "subclasses of BaseDatabaseOperations may require a "
82 "format_for_duration_arithmetic() method."
83 )
84
85 def cache_key_culling_sql(self):
86 """
87 Return an SQL query that retrieves the first cache key greater than the
88 n smallest.
89
90 This is used by the 'db' cache backend to determine where to start
91 culling.
92 """
93 cache_key = self.quote_name("cache_key")
94 return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s"
95
96 def unification_cast_sql(self, output_field):
97 """
98 Given a field instance, return the SQL that casts the result of a union
99 to that type. The resulting string should contain a '%s' placeholder
100 for the expression being cast.
101 """
102 return "%s"
103
104 def date_extract_sql(self, lookup_type, sql, params):
105 """
106 Given a lookup_type of 'year', 'month', or 'day', return the SQL that
107 extracts a value from the given date field field_name.
108 """
109 raise NotImplementedError(
110 "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
111 "method"
112 )
113
114 def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
115 """
116 Given a lookup_type of 'year', 'month', or 'day', return the SQL that
117 truncates the given date or datetime field field_name to a date object
118 with only the given specificity.
119
120 If `tzname` is provided, the given value is truncated in a specific
121 timezone.
122 """
123 raise NotImplementedError(
124 "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
125 "method."
126 )
127
128 def datetime_cast_date_sql(self, sql, params, tzname):
129 """
130 Return the SQL to cast a datetime value to date value.
131 """
132 raise NotImplementedError(
133 "subclasses of BaseDatabaseOperations may require a "
134 "datetime_cast_date_sql() method."
135 )
136
137 def datetime_cast_time_sql(self, sql, params, tzname):
138 """
139 Return the SQL to cast a datetime value to time value.
140 """
141 raise NotImplementedError(
142 "subclasses of BaseDatabaseOperations may require a "
143 "datetime_cast_time_sql() method"
144 )
145
146 def datetime_extract_sql(self, lookup_type, sql, params, tzname):
147 """
148 Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
149 'second', return the SQL that extracts a value from the given
150 datetime field field_name.
151 """
152 raise NotImplementedError(
153 "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
154 "method"
155 )
156
157 def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
158 """
159 Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
160 'second', return the SQL that truncates the given datetime field
161 field_name to a datetime object with only the given specificity.
162 """
163 raise NotImplementedError(
164 "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
165 "method"
166 )
167
168 def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
169 """
170 Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
171 that truncates the given time or datetime field field_name to a time
172 object with only the given specificity.
173
174 If `tzname` is provided, the given value is truncated in a specific
175 timezone.
176 """
177 raise NotImplementedError(
178 "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
179 )
180
181 def time_extract_sql(self, lookup_type, sql, params):
182 """
183 Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
184 that extracts a value from the given time field field_name.
185 """
186 return self.date_extract_sql(lookup_type, sql, params)
187
188 def deferrable_sql(self):
189 """
190 Return the SQL to make a constraint "initially deferred" during a
191 CREATE TABLE statement.
192 """
193 return ""
194
195 def distinct_sql(self, fields, params):
196 """
197 Return an SQL DISTINCT clause which removes duplicate rows from the
198 result set. If any fields are given, only check the given fields for
199 duplicates.
200 """
201 if fields:
202 raise NotSupportedError(
203 "DISTINCT ON fields is not supported by this database backend"
204 )
205 else:
206 return ["DISTINCT"], []
207
208 def fetch_returned_insert_columns(self, cursor, returning_params):
209 """
210 Given a cursor object that has just performed an INSERT...RETURNING
211 statement into a table, return the newly created data.
212 """
213 return cursor.fetchone()
214
215 def field_cast_sql(self, db_type, internal_type):
216 """
217 Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
218 (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
219 it in a WHERE statement. The resulting string should contain a '%s'
220 placeholder for the column being searched against.
221 """
222 return "%s"
223
224 def force_no_ordering(self):
225 """
226 Return a list used in the "ORDER BY" clause to force no ordering at
227 all. Return an empty list to include nothing in the ordering.
228 """
229 return []
230
231 def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
232 """
233 Return the FOR UPDATE SQL clause to lock rows for an update operation.
234 """
235 return "FOR{} UPDATE{}{}{}".format(
236 " NO KEY" if no_key else "",
237 " OF %s" % ", ".join(of) if of else "",
238 " NOWAIT" if nowait else "",
239 " SKIP LOCKED" if skip_locked else "",
240 )
241
242 def _get_limit_offset_params(self, low_mark, high_mark):
243 offset = low_mark or 0
244 if high_mark is not None:
245 return (high_mark - offset), offset
246 elif offset:
247 return self.connection.ops.no_limit_value(), offset
248 return None, offset
249
250 def limit_offset_sql(self, low_mark, high_mark):
251 """Return LIMIT/OFFSET SQL clause."""
252 limit, offset = self._get_limit_offset_params(low_mark, high_mark)
253 return " ".join(
254 sql
255 for sql in (
256 ("LIMIT %d" % limit) if limit else None,
257 ("OFFSET %d" % offset) if offset else None,
258 )
259 if sql
260 )
261
262 def last_executed_query(self, cursor, sql, params):
263 """
264 Return a string of the query last executed by the given cursor, with
265 placeholders replaced with actual values.
266
267 `sql` is the raw query containing placeholders and `params` is the
268 sequence of parameters. These are used by default, but this method
269 exists for database backends to provide a better implementation
270 according to their own quoting schemes.
271 """
272
273 # Convert params to contain string values.
274 def to_string(s):
275 return force_str(s, strings_only=True, errors="replace")
276
277 if isinstance(params, list | tuple):
278 u_params = tuple(to_string(val) for val in params)
279 elif params is None:
280 u_params = ()
281 else:
282 u_params = {to_string(k): to_string(v) for k, v in params.items()}
283
284 return f"QUERY = {sql!r} - PARAMS = {u_params!r}"
285
286 def last_insert_id(self, cursor, table_name, pk_name):
287 """
288 Given a cursor object that has just performed an INSERT statement into
289 a table that has an auto-incrementing ID, return the newly created ID.
290
291 `pk_name` is the name of the primary-key column.
292 """
293 return cursor.lastrowid
294
295 def lookup_cast(self, lookup_type, internal_type=None):
296 """
297 Return the string to use in a query when performing lookups
298 ("contains", "like", etc.). It should contain a '%s' placeholder for
299 the column being searched against.
300 """
301 return "%s"
302
303 def max_in_list_size(self):
304 """
305 Return the maximum number of items that can be passed in a single 'IN'
306 list condition, or None if the backend does not impose a limit.
307 """
308 return None
309
310 def max_name_length(self):
311 """
312 Return the maximum length of table and column names, or None if there
313 is no limit.
314 """
315 return None
316
317 def no_limit_value(self):
318 """
319 Return the value to use for the LIMIT when we are wanting "LIMIT
320 infinity". Return None if the limit clause can be omitted in this case.
321 """
322 raise NotImplementedError(
323 "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
324 )
325
326 def pk_default_value(self):
327 """
328 Return the value to use during an INSERT statement to specify that
329 the field should use its default value.
330 """
331 return "DEFAULT"
332
333 def prepare_sql_script(self, sql):
334 """
335 Take an SQL script that may contain multiple lines and return a list
336 of statements to feed to successive cursor.execute() calls.
337
338 Since few databases are able to process raw SQL scripts in a single
339 cursor.execute() call and PEP 249 doesn't talk about this use case,
340 the default implementation is conservative.
341 """
342 return [
343 sqlparse.format(statement, strip_comments=True)
344 for statement in sqlparse.split(sql)
345 if statement
346 ]
347
348 def process_clob(self, value):
349 """
350 Return the value of a CLOB column, for backends that return a locator
351 object that requires additional processing.
352 """
353 return value
354
355 def return_insert_columns(self, fields):
356 """
357 For backends that support returning columns as part of an insert query,
358 return the SQL and params to append to the INSERT query. The returned
359 fragment should contain a format string to hold the appropriate column.
360 """
361 pass
362
363 def compiler(self, compiler_name):
364 """
365 Return the SQLCompiler class corresponding to the given name,
366 in the namespace corresponding to the `compiler_module` attribute
367 on this backend.
368 """
369 if self._cache is None:
370 self._cache = import_module(self.compiler_module)
371 return getattr(self._cache, compiler_name)
372
373 def quote_name(self, name):
374 """
375 Return a quoted version of the given table, index, or column name. Do
376 not quote the given name if it's already been quoted.
377 """
378 raise NotImplementedError(
379 "subclasses of BaseDatabaseOperations may require a quote_name() method"
380 )
381
382 def regex_lookup(self, lookup_type):
383 """
384 Return the string to use in a query when performing regular expression
385 lookups (using "regex" or "iregex"). It should contain a '%s'
386 placeholder for the column being searched against.
387
388 If the feature is not supported (or part of it is not supported), raise
389 NotImplementedError.
390 """
391 raise NotImplementedError(
392 "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
393 )
394
395 def savepoint_create_sql(self, sid):
396 """
397 Return the SQL for starting a new savepoint. Only required if the
398 "uses_savepoints" feature is True. The "sid" parameter is a string
399 for the savepoint id.
400 """
401 return "SAVEPOINT %s" % self.quote_name(sid)
402
403 def savepoint_commit_sql(self, sid):
404 """
405 Return the SQL for committing the given savepoint.
406 """
407 return "RELEASE SAVEPOINT %s" % self.quote_name(sid)
408
409 def savepoint_rollback_sql(self, sid):
410 """
411 Return the SQL for rolling back the given savepoint.
412 """
413 return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)
414
415 def set_time_zone_sql(self):
416 """
417 Return the SQL that will set the connection's time zone.
418
419 Return '' if the backend doesn't support time zones.
420 """
421 return ""
422
423 def sequence_reset_by_name_sql(self, style, sequences):
424 """
425 Return a list of the SQL statements required to reset sequences
426 passed in `sequences`.
427
428 The `style` argument is a Style object as returned by either
429 color_style() or no_style() in plain.internal.legacy.management.color.
430 """
431 return []
432
433 def sequence_reset_sql(self, style, model_list):
434 """
435 Return a list of the SQL statements required to reset sequences for
436 the given models.
437
438 The `style` argument is a Style object as returned by either
439 color_style() or no_style() in plain.internal.legacy.management.color.
440 """
441 return [] # No sequence reset required by default.
442
443 def start_transaction_sql(self):
444 """Return the SQL statement required to start a transaction."""
445 return "BEGIN;"
446
447 def end_transaction_sql(self, success=True):
448 """Return the SQL statement required to end a transaction."""
449 if not success:
450 return "ROLLBACK;"
451 return "COMMIT;"
452
453 def tablespace_sql(self, tablespace, inline=False):
454 """
455 Return the SQL that will be used in a query to define the tablespace.
456
457 Return '' if the backend doesn't support tablespaces.
458
459 If `inline` is True, append the SQL to a row; otherwise append it to
460 the entire CREATE TABLE or CREATE INDEX statement.
461 """
462 return ""
463
464 def prep_for_like_query(self, x):
465 """Prepare a value for use in a LIKE query."""
466 return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
467
468 # Same as prep_for_like_query(), but called for "iexact" matches, which
469 # need not necessarily be implemented using "LIKE" in the backend.
470 prep_for_iexact_query = prep_for_like_query
471
472 def validate_autopk_value(self, value):
473 """
474 Certain backends do not accept some values for "serial" fields
475 (for example zero in MySQL). Raise a ValueError if the value is
476 invalid, otherwise return the validated value.
477 """
478 return value
479
480 def adapt_unknown_value(self, value):
481 """
482 Transform a value to something compatible with the backend driver.
483
484 This method only depends on the type of the value. It's designed for
485 cases where the target type isn't known, such as .raw() SQL queries.
486 As a consequence it may not work perfectly in all circumstances.
487 """
488 if isinstance(value, datetime.datetime): # must be before date
489 return self.adapt_datetimefield_value(value)
490 elif isinstance(value, datetime.date):
491 return self.adapt_datefield_value(value)
492 elif isinstance(value, datetime.time):
493 return self.adapt_timefield_value(value)
494 elif isinstance(value, decimal.Decimal):
495 return self.adapt_decimalfield_value(value)
496 else:
497 return value
498
499 def adapt_integerfield_value(self, value, internal_type):
500 return value
501
502 def adapt_datefield_value(self, value):
503 """
504 Transform a date value to an object compatible with what is expected
505 by the backend driver for date columns.
506 """
507 if value is None:
508 return None
509 return str(value)
510
511 def adapt_datetimefield_value(self, value):
512 """
513 Transform a datetime value to an object compatible with what is expected
514 by the backend driver for datetime columns.
515 """
516 if value is None:
517 return None
518 # Expression values are adapted by the database.
519 if hasattr(value, "resolve_expression"):
520 return value
521
522 return str(value)
523
524 def adapt_timefield_value(self, value):
525 """
526 Transform a time value to an object compatible with what is expected
527 by the backend driver for time columns.
528 """
529 if value is None:
530 return None
531 # Expression values are adapted by the database.
532 if hasattr(value, "resolve_expression"):
533 return value
534
535 if timezone.is_aware(value):
536 raise ValueError("Plain does not support timezone-aware times.")
537 return str(value)
538
539 def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
540 """
541 Transform a decimal.Decimal value to an object compatible with what is
542 expected by the backend driver for decimal (numeric) columns.
543 """
544 return utils.format_number(value, max_digits, decimal_places)
545
546 def adapt_ipaddressfield_value(self, value):
547 """
548 Transform a string representation of an IP address into the expected
549 type for the backend driver.
550 """
551 return value or None
552
553 def adapt_json_value(self, value, encoder):
554 return json.dumps(value, cls=encoder)
555
556 def year_lookup_bounds_for_date_field(self, value, iso_year=False):
557 """
558 Return a two-elements list with the lower and upper bound to be used
559 with a BETWEEN operator to query a DateField value using a year
560 lookup.
561
562 `value` is an int, containing the looked-up year.
563 If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
564 """
565 if iso_year:
566 first = datetime.date.fromisocalendar(value, 1, 1)
567 second = datetime.date.fromisocalendar(
568 value + 1, 1, 1
569 ) - datetime.timedelta(days=1)
570 else:
571 first = datetime.date(value, 1, 1)
572 second = datetime.date(value, 12, 31)
573 first = self.adapt_datefield_value(first)
574 second = self.adapt_datefield_value(second)
575 return [first, second]
576
577 def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
578 """
579 Return a two-elements list with the lower and upper bound to be used
580 with a BETWEEN operator to query a DateTimeField value using a year
581 lookup.
582
583 `value` is an int, containing the looked-up year.
584 If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
585 """
586 if iso_year:
587 first = datetime.datetime.fromisocalendar(value, 1, 1)
588 second = datetime.datetime.fromisocalendar(
589 value + 1, 1, 1
590 ) - datetime.timedelta(microseconds=1)
591 else:
592 first = datetime.datetime(value, 1, 1)
593 second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
594 if settings.USE_TZ:
595 tz = timezone.get_current_timezone()
596 first = timezone.make_aware(first, tz)
597 second = timezone.make_aware(second, tz)
598 first = self.adapt_datetimefield_value(first)
599 second = self.adapt_datetimefield_value(second)
600 return [first, second]
601
602 def get_db_converters(self, expression):
603 """
604 Return a list of functions needed to convert field data.
605
606 Some field types on some backends do not provide data in the correct
607 format, this is the hook for converter functions.
608 """
609 return []
610
611 def convert_durationfield_value(self, value, expression, connection):
612 if value is not None:
613 return datetime.timedelta(0, 0, value)
614
615 def check_expression_support(self, expression):
616 """
617 Check that the backend supports the provided expression.
618
619 This is used on specific backends to rule out known expressions
620 that have problematic or nonexistent implementations. If the
621 expression has a known problem, the backend should raise
622 NotSupportedError.
623 """
624 pass
625
626 def conditional_expression_supported_in_where_clause(self, expression):
627 """
628 Return True, if the conditional expression is supported in the WHERE
629 clause.
630 """
631 return True
632
633 def combine_expression(self, connector, sub_expressions):
634 """
635 Combine a list of subexpressions into a single expression, using
636 the provided connecting operator. This is required because operators
637 can vary between backends (e.g., Oracle with %% and &) and between
638 subexpression types (e.g., date expressions).
639 """
640 conn = " %s " % connector
641 return conn.join(sub_expressions)
642
643 def combine_duration_expression(self, connector, sub_expressions):
644 return self.combine_expression(connector, sub_expressions)
645
646 def binary_placeholder_sql(self, value):
647 """
648 Some backends require special syntax to insert binary content (MySQL
649 for example uses '_binary %s').
650 """
651 return "%s"
652
653 def modify_insert_params(self, placeholder, params):
654 """
655 Allow modification of insert parameters. Needed for Oracle Spatial
656 backend due to #10888.
657 """
658 return params
659
660 def integer_field_range(self, internal_type):
661 """
662 Given an integer field internal type (e.g. 'PositiveIntegerField'),
663 return a tuple of the (min_value, max_value) form representing the
664 range of the column type bound to the field.
665 """
666 return self.integer_field_ranges[internal_type]
667
668 def subtract_temporals(self, internal_type, lhs, rhs):
669 if self.connection.features.supports_temporal_subtraction:
670 lhs_sql, lhs_params = lhs
671 rhs_sql, rhs_params = rhs
672 return f"({lhs_sql} - {rhs_sql})", (*lhs_params, *rhs_params)
673 raise NotSupportedError(
674 "This backend does not support %s subtraction." % internal_type
675 )
676
677 def window_frame_start(self, start):
678 if isinstance(start, int):
679 if start < 0:
680 return "%d %s" % (abs(start), self.PRECEDING)
681 elif start == 0:
682 return self.CURRENT_ROW
683 elif start is None:
684 return self.UNBOUNDED_PRECEDING
685 raise ValueError(
686 "start argument must be a negative integer, zero, or None, but got '%s'."
687 % start
688 )
689
690 def window_frame_end(self, end):
691 if isinstance(end, int):
692 if end == 0:
693 return self.CURRENT_ROW
694 elif end > 0:
695 return "%d %s" % (end, self.FOLLOWING)
696 elif end is None:
697 return self.UNBOUNDED_FOLLOWING
698 raise ValueError(
699 "end argument must be a positive integer, zero, or None, but got '%s'."
700 % end
701 )
702
703 def window_frame_rows_start_end(self, start=None, end=None):
704 """
705 Return SQL for start and end points in an OVER clause window frame.
706 """
707 if not self.connection.features.supports_over_clause:
708 raise NotSupportedError("This backend does not support window expressions.")
709 return self.window_frame_start(start), self.window_frame_end(end)
710
711 def window_frame_range_start_end(self, start=None, end=None):
712 start_, end_ = self.window_frame_rows_start_end(start, end)
713 features = self.connection.features
714 if features.only_supports_unbounded_with_preceding_and_following and (
715 (start and start < 0) or (end and end > 0)
716 ):
717 raise NotSupportedError(
718 "%s only supports UNBOUNDED together with PRECEDING and "
719 "FOLLOWING." % self.connection.display_name
720 )
721 return start_, end_
722
723 def explain_query_prefix(self, format=None, **options):
724 if not self.connection.features.supports_explaining_query_execution:
725 raise NotSupportedError(
726 "This backend does not support explaining query execution."
727 )
728 if format:
729 supported_formats = self.connection.features.supported_explain_formats
730 normalized_format = format.upper()
731 if normalized_format not in supported_formats:
732 msg = "%s is not a recognized format." % normalized_format
733 if supported_formats:
734 msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
735 else:
736 msg += (
737 f" {self.connection.display_name} does not support any formats."
738 )
739 raise ValueError(msg)
740 if options:
741 raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
742 return self.explain_prefix
743
744 def insert_statement(self, on_conflict=None):
745 return "INSERT INTO"
746
747 def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
748 return ""