1import datetime
2import decimal
3import uuid
4from functools import cached_property, lru_cache
5
6from plain import models
7from plain.exceptions import FieldError
8from plain.models.backends.base.operations import BaseDatabaseOperations
9from plain.models.constants import OnConflict
10from plain.models.db import DatabaseError, NotSupportedError
11from plain.models.expressions import Col
12from plain.utils import timezone
13from plain.utils.dateparse import parse_date, parse_datetime, parse_time
14
15
class DatabaseOperations(BaseDatabaseOperations):
    """
    SQLite implementation of the backend operations: SQL snippets for the
    date/time helper functions, expression support checks, adaptation of
    Python values before they are sent to the database, and conversion of
    values read back from it.
    """

    cast_char_field_without_max_length = "text"
    # Date and datetime values are stored as text on SQLite (see the
    # adapt_*field_value methods below), so casts to these types use TEXT.
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of JSON datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields, objs):
        """
        Return the maximum number of objects to handle per query in a bulk
        operation.

        SQLite limits the number of bound variables per query
        (SQLITE_LIMIT_VARIABLE_NUMBER; compile-time default 999 before
        SQLite 3.32.0 and 32766 since). The limit used here is
        ``connection.features.max_query_params``, split evenly across the
        fields of each row.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT). With no fields at all, every object
        fits in a single batch.
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)

    def check_expression_support(self, expression):
        """
        Raise NotSupportedError for expressions SQLite can't evaluate
        correctly:

        - Sum, Avg, Variance and StdDev over date/time fields, because those
          values are stored as text.
        - DISTINCT aggregates that take more than one argument.
        """
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Support EXTRACT with a user-defined function plain_date_extract()
        that's registered in connect().

        The lowercased lookup type is passed as a bound query parameter (the
        leading %s) rather than interpolated into the SQL text, so it cannot
        collide with a field name or alter the statement.
        """
        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Truncate a date expression with plain_date_trunc()."""
        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Truncate a time expression with plain_time_trunc()."""
        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def _convert_tznames_to_sql(self, tzname):
        """
        Return the two timezone parameters for the plain_* SQL functions.

        The result is ``(tzname, connection.timezone_name)`` when ``tzname``
        is truthy, otherwise ``(None, None)``. It fills the last two ``%s``
        placeholders in the SQL built by the *_trunc/*_cast/*_extract
        methods.
        """
        if tzname:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Cast a datetime expression to a date with plain_datetime_cast_date()."""
        return f"plain_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Cast a datetime expression to a time with plain_datetime_cast_time()."""
        return f"plain_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """Extract a component of a datetime with plain_datetime_extract()."""
        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """Truncate a datetime expression with plain_datetime_trunc()."""
        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """Extract a component of a time with plain_time_extract()."""
        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)

    def pk_default_value(self):
        """Return the SQL used as the default primary key value (NULL)."""
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!

        Return a tuple containing each value of ``params`` rendered as an
        SQL literal by SQLite's QUOTE() function.
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters; default 999 before SQLite 3.32.0) and
        # SQLITE_MAX_COLUMN (the number of return values, default = 2000).
        # Python's sqlite3 module only exposes the limit API
        # (Connection.getlimit()) from Python 3.11, so assume the
        # conservative default limits are in effect and split the work in
        # batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            # Accumulate into a list and convert once, instead of copying an
            # ever-growing tuple for every batch.
            results = []
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results.extend(self._quote_params_for_last_executed_query(chunk))
            return tuple(results)

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Plain's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        """
        Return ``sql`` with its parameters substituted as quoted SQL
        literals, for display/logging purposes only.

        Sequence params fill positional ``%s`` placeholders; mapping params
        fill named placeholders. With no params, ``sql`` is returned as-is.
        """
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, list | tuple):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        """
        Wrap ``name`` in double quotes unless it already starts and ends with
        one. Double quotes embedded inside ``name`` are not escaped.
        """
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return f'"{name}"'

    def no_limit_value(self):
        """Return the LIMIT value meaning "no limit" (negative on SQLite)."""
        return -1

    def __references_graph(self, table_name):
        """
        Return ``table_name`` plus the names of all tables whose CREATE
        statement in sqlite_master references it, following references
        transitively (the CTE refers to itself).

        NOTE(review): the REGEXP operator needs a user-defined regexp()
        function on the connection, presumably registered in connect() —
        confirm. Table names are concatenated into the regex unescaped, so
        names containing regex metacharacters may over- or under-match.
        """
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        """Per-instance LRU-cached wrapper around __references_graph()."""
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def adapt_datetimefield_value(self, value):
        """
        Prepare a datetime for storage as text.

        None and expressions pass through unchanged; aware datetimes are
        first made naive in the connection's timezone.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Prepare a time for storage as text.

        None and expressions pass through unchanged.

        Raises:
            ValueError: if ``value`` is timezone-aware.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        """
        Return the base converters plus a SQLite-specific converter chosen by
        the expression's output field internal type.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        """
        Parse a stored datetime (if not already a datetime) and make it aware
        in the connection's timezone. None passes through.
        """
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        """Parse a stored date unless it is already a date; None passes through."""
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        """Parse a stored time unless it is already a time; None passes through."""
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        """
        Return a converter that turns numeric values read from SQLite into
        Decimal, rounded to 15 significant digits. For column references
        (Col) the result is also quantized to the field's decimal_places.
        """
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )

        else:

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)

        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        """
        Build a uuid.UUID from the stored value; None passes through.

        NOTE(review): assumes the stored value is a string accepted by
        uuid.UUID() — confirm against the field's db type.
        """
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        """
        Map the integers 1/0 to True/False; any other value (including None)
        passes through unchanged.
        """
        # SQLite has no boolean storage class; booleans come back as 0/1.
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        """
        Build a multi-row ``VALUES (...), (...)`` clause from per-row lists of
        placeholder SQL. ``fields`` is not used by this backend.
        """
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return f"VALUES {values_sql}"

    def combine_expression(self, connector, sub_expressions):
        """
        Combine ``sub_expressions`` with ``connector``, mapping operators
        SQLite lacks onto function calls.
        """
        # SQLite has no ^ (power) or # (bitwise XOR) operators, so use the
        # user-defined POWER and BITXOR functions that are registered in
        # connect().
        if connector == "^":
            return "POWER({})".format(",".join(sub_expressions))
        elif connector == "#":
            return "BITXOR({})".format(",".join(sub_expressions))
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        """
        Combine duration operands through plain_format_dtdelta().

        Raises:
            DatabaseError: if ``connector`` is not one of + - * /.
            ValueError: if more than two sub-expressions are given.
        """
        if connector not in ["+", "-", "*", "/"]:
            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
        fn_params = [f"'{connector}'"] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError("Too many params for timedelta operations.")
        return "plain_format_dtdelta({})".format(", ".join(fn_params))

    def integer_field_range(self, internal_type):
        """Return the (min, max) integer range for ``internal_type``."""
        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
        # integers up to 64 bits.
        if internal_type in [
            "PositiveBigIntegerField",
            "PositiveIntegerField",
            "PositiveSmallIntegerField",
        ]:
            return (0, 9223372036854775807)
        return (-9223372036854775808, 9223372036854775807)

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return SQL and params for ``lhs - rhs`` on temporal values, using
        plain_time_diff() for TimeField and plain_timestamp_diff() otherwise.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == "TimeField":
            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params

    def insert_statement(self, on_conflict=None):
        """Use SQLite's INSERT OR IGNORE for OnConflict.IGNORE."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT OR IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def return_insert_columns(self, fields):
        """
        Return a ``RETURNING "table"."column", ...`` clause for ``fields``
        (empty when there are no fields) and an empty params tuple.
        """
        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
        if not fields:
            return "", ()
        columns = [
            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
            for field in fields
        ]
        return "RETURNING {}".format(", ".join(columns)), ()

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """
        Build an ``ON CONFLICT(...) DO UPDATE SET col = EXCLUDED.col`` suffix
        when updating on conflict is requested and supported with a target;
        otherwise defer to the base implementation.
        """
        if (
            on_conflict == OnConflict.UPDATE
            and self.connection.features.supports_update_conflicts_with_target
        ):
            return "ON CONFLICT({}) DO UPDATE SET {}".format(
                ", ".join(map(self.quote_name, unique_fields)),
                ", ".join(
                    [
                        f"{field} = EXCLUDED.{field}"
                        for field in map(self.quote_name, update_fields)
                    ]
                ),
            )
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )