1import datetime
2import decimal
3import uuid
4from functools import lru_cache
5
6from plain import models
7from plain.exceptions import FieldError
8from plain.models.backends.base.operations import BaseDatabaseOperations
9from plain.models.constants import OnConflict
10from plain.models.db import DatabaseError, NotSupportedError
11from plain.models.expressions import Col
12from plain.runtime import settings
13from plain.utils import timezone
14from plain.utils.dateparse import parse_date, parse_datetime, parse_time
15from plain.utils.functional import cached_property
16
17
class DatabaseOperations(BaseDatabaseOperations):
    """
    SQLite implementation of the backend's database operations.

    SQLite lacks several native facilities (date/time extraction and
    truncation, duration arithmetic, ``^``/``#`` operators), so many methods
    here emit calls to user-defined ``plain_*``/``POWER``/``BITXOR`` SQL
    functions instead, and date/time, decimal, UUID and boolean values are
    adapted and converted in Python.
    """

    cast_char_field_without_max_length = "text"
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # Datatypes that cannot be extracted with JSON_EXTRACT() on SQLite; use
    # JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields, objs):
        """
        Return how many objects can be inserted in a single query.

        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT). With several fields, the connection's
        ``max_query_params`` is divided evenly between them. With no fields,
        all objects fit in one batch.
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)

    def check_expression_support(self, expression):
        """
        Reject aggregate usages that SQLite cannot evaluate correctly.

        Raises:
            NotSupportedError: for Sum/Avg/Variance/StdDev over date/time
                fields (stored as text on SQLite), or for a DISTINCT
                aggregate that takes more than one argument.
        """
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Support EXTRACT with a user-defined function plain_date_extract()
        that's registered in connect().

        The lower-cased lookup type is passed as a bound query parameter
        rather than interpolated into the SQL text, so it cannot collide with
        a field name or alter the statement.
        """
        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Truncate a date expression via the plain_date_trunc() SQL function."""
        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Truncate a time expression via the plain_time_trunc() SQL function."""
        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def _convert_tznames_to_sql(self, tzname):
        """
        Return the (source tz, connection tz) parameter pair for the plain_*
        date/time SQL functions.

        Both are None unless ``tzname`` is given and ``settings.USE_TZ`` is
        enabled.
        """
        if tzname and settings.USE_TZ:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Cast a datetime expression to a date via plain_datetime_cast_date()."""
        return f"plain_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Cast a datetime expression to a time via plain_datetime_cast_time()."""
        return f"plain_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """Extract a datetime component via plain_datetime_extract()."""
        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """Truncate a datetime expression via plain_datetime_trunc()."""
        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """Extract a time component via plain_time_extract()."""
        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)

    def pk_default_value(self):
        """Return the SQL used for a primary key value left to the database."""
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!

        Return ``params`` as SQL literals, quoted by SQLite's own QUOTE().
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Plain's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        """
        Return ``sql`` with ``params`` substituted in as quoted literals.

        Intended for debugging/logging only. Both sequence and mapping params
        are supported.
        """
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, list | tuple):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        """Wrap ``name`` in double quotes unless it is already quoted."""
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        """Return the LIMIT value SQLite treats as 'no limit'."""
        return -1

    def __references_graph(self, table_name):
        """
        Return ``table_name`` plus the names of tables whose schema SQL
        references it, followed transitively through the self-joining CTE.

        NOTE(review): relies on a REGEXP SQL function being available on the
        connection (SQLite has none built in); presumably the backend
        registers it when connecting — confirm.
        """
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        # Wrapping the bound method per instance (via cached_property) keeps
        # the LRU cache scoped to this operations object instead of a
        # class-level cache that would keep every instance alive.
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sequence_reset_by_name_sql(self, style, sequences):
        """
        Return SQL that resets the sqlite_sequence counters to 0 for the
        tables named in ``sequences``.

        Each item of ``sequences`` is a mapping with a "table" key. Table
        names are embedded as single-quoted SQL literals, so any embedded
        single quote is doubled to keep the literal well-formed.
        """
        if not sequences:
            return []
        table_literals = ", ".join(
            "'%s'" % str(sequence_info["table"]).replace("'", "''")
            for sequence_info in sequences
        )
        return [
            "{} {} {} {} = 0 {} {} {} ({});".format(
                style.SQL_KEYWORD("UPDATE"),
                style.SQL_TABLE(self.quote_name("sqlite_sequence")),
                style.SQL_KEYWORD("SET"),
                style.SQL_FIELD(self.quote_name("seq")),
                style.SQL_KEYWORD("WHERE"),
                style.SQL_FIELD(self.quote_name("name")),
                style.SQL_KEYWORD("IN"),
                table_literals,
            ),
        ]

    def adapt_datetimefield_value(self, value):
        """
        Convert a datetime into the string stored by SQLite.

        Aware datetimes are made naive in the connection's timezone when
        USE_TZ is enabled; otherwise they raise ValueError. None and
        expressions are returned unchanged.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError(
                    "SQLite backend does not support timezone-aware datetimes when "
                    "USE_TZ is False."
                )

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Convert a time into the string stored by SQLite.

        Aware times raise ValueError. None and expressions are returned
        unchanged.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        """
        Return the base converters plus an SQLite-specific converter chosen
        by the expression's output field internal type.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        """
        Parse a stored datetime string; when USE_TZ is enabled, make naive
        results aware in the connection's timezone.
        """
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        """Parse a stored date string into a ``datetime.date``."""
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        """Parse a stored time string into a ``datetime.time``."""
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        """
        Return a converter turning stored numbers into ``Decimal``.

        For plain column references the result is also quantized to the
        field's ``decimal_places``.
        """
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )

        else:

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)

        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        """Parse a stored UUID string into a ``uuid.UUID``."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        """Map stored 1/0 to True/False; pass any other value through."""
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Return a multi-row ``VALUES (...), (...)`` clause."""
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return f"VALUES {values_sql}"

    def combine_expression(self, connector, sub_expressions):
        """Combine sub-expressions with ``connector``."""
        # SQLite has no ^ (power) or # (bitwise XOR) operators, so map them
        # to the user-defined POWER and BITXOR functions. POWER is registered
        # in connect(); BITXOR presumably is too.
        if connector == "^":
            return "POWER(%s)" % ",".join(sub_expressions)
        elif connector == "#":
            return "BITXOR(%s)" % ",".join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        """
        Combine duration operands via the plain_format_dtdelta() SQL
        function, passing the connector as a quoted string literal.

        Raises:
            DatabaseError: if ``connector`` is not one of + - * /.
            ValueError: if more than two sub-expressions are given.
        """
        if connector not in ["+", "-", "*", "/"]:
            raise DatabaseError("Invalid connector for timedelta: %s." % connector)
        # Unpack instead of list concatenation so tuples (or any iterable) of
        # sub-expressions work as well as lists.
        fn_params = ["'%s'" % connector, *sub_expressions]
        if len(fn_params) > 3:
            raise ValueError("Too many params for timedelta operations.")
        return "plain_format_dtdelta(%s)" % ", ".join(fn_params)

    def integer_field_range(self, internal_type):
        """Return the (min, max) integer range for ``internal_type``."""
        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
        # integers up to 64 bits.
        if internal_type in [
            "PositiveBigIntegerField",
            "PositiveIntegerField",
            "PositiveSmallIntegerField",
        ]:
            return (0, 9223372036854775807)
        return (-9223372036854775808, 9223372036854775807)

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return SQL and params subtracting ``rhs`` from ``lhs``.

        Uses plain_time_diff() for TimeField, otherwise plain_timestamp_diff().
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == "TimeField":
            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params

    def insert_statement(self, on_conflict=None):
        """Return the INSERT keyword phrase, using SQLite's OR IGNORE form."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT OR IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def return_insert_columns(self, fields):
        """
        Return a ``RETURNING`` clause for ``fields`` (table-qualified) and its
        params, or ``("", ())`` when there are no fields.
        """
        # INSERT...RETURNING requires SQLite >= 3.35. The clause is emitted
        # unconditionally here. NOTE(review): callers are presumably gated on
        # the backend's feature flags — confirm.
        if not fields:
            return "", ()
        columns = [
            "{}.{}".format(
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """
        Return an ``ON CONFLICT(...) DO UPDATE SET ...`` clause for upserts
        when the backend supports conflict targets; otherwise defer to the
        base implementation.
        """
        if (
            on_conflict == OnConflict.UPDATE
            and self.connection.features.supports_update_conflicts_with_target
        ):
            return "ON CONFLICT({}) DO UPDATE SET {}".format(
                ", ".join(map(self.quote_name, unique_fields)),
                ", ".join(
                    [
                        f"{field} = EXCLUDED.{field}"
                        for field in map(self.quote_name, update_fields)
                    ]
                ),
            )
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )