1from __future__ import annotations
2
3import datetime
4import decimal
5import uuid
6from collections.abc import Callable, Iterable
7from functools import cached_property, lru_cache
8from typing import TYPE_CHECKING, Any
9
10from plain import models
11from plain.models.aggregates import Aggregate, Avg, StdDev, Sum, Variance
12from plain.models.backends.base.operations import BaseDatabaseOperations
13from plain.models.backends.utils import CursorWrapper
14from plain.models.constants import OnConflict
15from plain.models.db import DatabaseError, NotSupportedError
16from plain.models.exceptions import FieldError
17from plain.models.expressions import Col, ResolvableExpression
18from plain.utils import timezone
19from plain.utils.dateparse import parse_date, parse_datetime, parse_time
20
21if TYPE_CHECKING:
22 from plain.models.backends.base.base import BaseDatabaseWrapper
23 from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
24 from plain.models.fields import Field
25
26
class DatabaseOperations(BaseDatabaseOperations):
    """SQLite implementations of the database operations hooks."""

    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
    connection: SQLiteDatabaseWrapper

    cast_char_field_without_max_length = "text"
    # Date and datetime fields are both cast to TEXT.
    cast_data_types = dict(DateField="TEXT", DateTimeField="TEXT")
    explain_prefix = "EXPLAIN QUERY PLAN"
    # Datatype values that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(("null", "false", "true"))
40
41 def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
42 """
43 SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
44 999 variables per query.
45
46 If there's only a single field to insert, the limit is 500
47 (SQLITE_MAX_COMPOUND_SELECT).
48 """
49 if len(fields) == 1:
50 return 500
51 elif len(fields) > 1:
52 return self.connection.features.max_query_params // len(fields)
53 else:
54 return len(objs)
55
56 def check_expression_support(self, expression: Any) -> None:
57 bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
58 bad_aggregates = (Sum, Avg, Variance, StdDev)
59 if isinstance(expression, bad_aggregates):
60 for expr in expression.get_source_expressions():
61 try:
62 output_field = expr.output_field
63 except (AttributeError, FieldError):
64 # Not every subexpression has an output_field which is fine
65 # to ignore.
66 pass
67 else:
68 if isinstance(output_field, bad_fields):
69 raise NotSupportedError(
70 "You cannot use Sum, Avg, StdDev, and Variance "
71 "aggregations on date/time fields in sqlite3 "
72 "since date/time is saved as text."
73 )
74 if (
75 isinstance(expression, Aggregate)
76 and expression.distinct
77 and len(expression.source_expressions) > 1
78 ):
79 raise NotSupportedError(
80 "SQLite doesn't support DISTINCT on aggregate functions "
81 "accepting multiple arguments."
82 )
83
84 def date_extract_sql(
85 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
86 ) -> tuple[str, tuple[Any, ...]]:
87 """
88 Support EXTRACT with a user-defined function plain_date_extract()
89 that's registered in connect(). Use single quotes because this is a
90 string and could otherwise cause a collision with a field name.
91 """
92 return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
93
94 def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
95 """
96 Given a cursor object that has just performed an INSERT...RETURNING
97 statement into a table, return the list of returned data.
98 """
99 return cursor.fetchall()
100
101 def format_for_duration_arithmetic(self, sql: str) -> str:
102 """Do nothing since formatting is handled in the custom function."""
103 return sql
104
105 def date_trunc_sql(
106 self,
107 lookup_type: str,
108 sql: str,
109 params: list[Any] | tuple[Any, ...],
110 tzname: str | None = None,
111 ) -> tuple[str, tuple[Any, ...]]:
112 return f"plain_date_trunc(%s, {sql}, %s, %s)", (
113 lookup_type.lower(),
114 *params,
115 *self._convert_tznames_to_sql(tzname),
116 )
117
118 def time_trunc_sql(
119 self,
120 lookup_type: str,
121 sql: str,
122 params: list[Any] | tuple[Any, ...],
123 tzname: str | None = None,
124 ) -> tuple[str, tuple[Any, ...]]:
125 return f"plain_time_trunc(%s, {sql}, %s, %s)", (
126 lookup_type.lower(),
127 *params,
128 *self._convert_tznames_to_sql(tzname),
129 )
130
131 def _convert_tznames_to_sql(
132 self, tzname: str | None
133 ) -> tuple[str | None, str | None]:
134 if tzname:
135 return tzname, self.connection.timezone_name
136 return None, None
137
138 def datetime_cast_date_sql(
139 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
140 ) -> tuple[str, tuple[Any, ...]]:
141 return f"plain_datetime_cast_date({sql}, %s, %s)", (
142 *params,
143 *self._convert_tznames_to_sql(tzname),
144 )
145
146 def datetime_cast_time_sql(
147 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
148 ) -> tuple[str, tuple[Any, ...]]:
149 return f"plain_datetime_cast_time({sql}, %s, %s)", (
150 *params,
151 *self._convert_tznames_to_sql(tzname),
152 )
153
154 def datetime_extract_sql(
155 self,
156 lookup_type: str,
157 sql: str,
158 params: list[Any] | tuple[Any, ...],
159 tzname: str | None,
160 ) -> tuple[str, tuple[Any, ...]]:
161 return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
162 lookup_type.lower(),
163 *params,
164 *self._convert_tznames_to_sql(tzname),
165 )
166
167 def datetime_trunc_sql(
168 self,
169 lookup_type: str,
170 sql: str,
171 params: list[Any] | tuple[Any, ...],
172 tzname: str | None,
173 ) -> tuple[str, tuple[Any, ...]]:
174 return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
175 lookup_type.lower(),
176 *params,
177 *self._convert_tznames_to_sql(tzname),
178 )
179
180 def time_extract_sql(
181 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
182 ) -> tuple[str, tuple[Any, ...]]:
183 return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
184
185 def pk_default_value(self) -> str:
186 return "NULL"
187
188 def _quote_params_for_last_executed_query(
189 self, params: list[Any] | tuple[Any, ...]
190 ) -> tuple[Any, ...]:
191 """
192 Only for last_executed_query! Don't use this to execute SQL queries!
193 """
194 # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
195 # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
196 # number of return values, default = 2000). Since Python's sqlite3
197 # module doesn't expose the get_limit() C API, assume the default
198 # limits are in effect and split the work in batches if needed.
199 BATCH_SIZE = 999
200 if len(params) > BATCH_SIZE:
201 results = ()
202 for index in range(0, len(params), BATCH_SIZE):
203 chunk = params[index : index + BATCH_SIZE]
204 results += self._quote_params_for_last_executed_query(chunk)
205 return results
206
207 sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
208 # Bypass Plain's wrappers and use the underlying sqlite3 connection
209 # to avoid logging this query - it would trigger infinite recursion.
210 cursor = self.connection.connection.cursor()
211 # Native sqlite3 cursors cannot be used as context managers.
212 try:
213 return cursor.execute(sql, params).fetchone()
214 finally:
215 cursor.close()
216
217 def last_executed_query(
218 self,
219 cursor: CursorWrapper,
220 sql: str,
221 params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
222 ) -> str:
223 # Python substitutes parameters in Modules/_sqlite/cursor.c with:
224 # bind_parameters(state, self->statement, parameters);
225 # Unfortunately there is no way to reach self->statement from Python,
226 # so we quote and substitute parameters manually.
227 if params:
228 if isinstance(params, list | tuple):
229 params = self._quote_params_for_last_executed_query(params)
230 else:
231 values = tuple(params.values())
232 values = self._quote_params_for_last_executed_query(values)
233 params = dict(zip(params, values))
234 return sql % params
235 # For consistency with SQLiteCursorWrapper.execute(), just return sql
236 # when there are no parameters. See #13648 and #17158.
237 else:
238 return sql
239
240 def quote_name(self, name: str) -> str:
241 if name.startswith('"') and name.endswith('"'):
242 return name # Quoting once is enough.
243 return f'"{name}"'
244
245 def no_limit_value(self) -> int:
246 return -1
247
    def __references_graph(self, table_name: str) -> list[str]:
        """
        Return ``table_name`` together with the names of the sqlite_master
        entries that reference it, directly or transitively.

        The ``tables`` CTE is seeded with ``table_name`` and joins back onto
        itself: a sqlite_master row is added when its ``sql`` column matches
        a case-insensitive "references <name> (" pattern for a name that is
        already in ``tables``. The name may be wrapped in single or double
        quotes.
        """
        # NOTE(review): tables.name is concatenated into the regular
        # expression as-is, so regex metacharacters in a table name are not
        # escaped - confirm this is acceptable for the names in use.
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        # Seed name, then the regex fragments placed before and after each
        # candidate table name.
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            # Each row has a single column: the name.
            return [row[0] for row in results.fetchall()]
266
267 @cached_property
268 def _references_graph(self) -> Callable[[str], list[str]]:
269 # 512 is large enough to fit the ~330 tables (as of this writing) in
270 # Plain's test suite.
271 return lru_cache(maxsize=512)(self.__references_graph)
272
273 def adapt_datetimefield_value(
274 self, value: datetime.datetime | Any | None
275 ) -> str | Any | None:
276 if value is None:
277 return None
278
279 # Expression values are adapted by the database.
280 if isinstance(value, ResolvableExpression):
281 return value
282
283 # SQLite doesn't support tz-aware datetimes
284 if timezone.is_aware(value):
285 value = timezone.make_naive(value, self.connection.timezone)
286
287 return str(value)
288
289 def adapt_timefield_value(
290 self, value: datetime.time | Any | None
291 ) -> str | Any | None:
292 if value is None:
293 return None
294
295 # Expression values are adapted by the database.
296 if isinstance(value, ResolvableExpression):
297 return value
298
299 # SQLite doesn't support tz-aware datetimes
300 if isinstance(value, datetime.time) and value.utcoffset() is not None:
301 raise ValueError("SQLite backend does not support timezone-aware times.")
302
303 return str(value)
304
305 def get_db_converters(self, expression: Any) -> list[Any]:
306 converters = super().get_db_converters(expression)
307 internal_type = expression.output_field.get_internal_type()
308 if internal_type == "DateTimeField":
309 converters.append(self.convert_datetimefield_value)
310 elif internal_type == "DateField":
311 converters.append(self.convert_datefield_value)
312 elif internal_type == "TimeField":
313 converters.append(self.convert_timefield_value)
314 elif internal_type == "DecimalField":
315 converters.append(self.get_decimalfield_converter(expression))
316 elif internal_type == "UUIDField":
317 converters.append(self.convert_uuidfield_value)
318 elif internal_type == "BooleanField":
319 converters.append(self.convert_booleanfield_value)
320 return converters
321
322 def convert_datetimefield_value(
323 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
324 ) -> datetime.datetime | None:
325 if value is not None:
326 if not isinstance(value, datetime.datetime):
327 value = parse_datetime(value)
328 if value is not None and not timezone.is_aware(value):
329 value = timezone.make_aware(value, self.connection.timezone)
330 return value
331
332 def convert_datefield_value(
333 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
334 ) -> datetime.date | None:
335 if value is not None:
336 if not isinstance(value, datetime.date):
337 value = parse_date(value)
338 return value
339
340 def convert_timefield_value(
341 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
342 ) -> datetime.time | None:
343 if value is not None:
344 if not isinstance(value, datetime.time):
345 value = parse_time(value)
346 return value
347
348 def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
349 # SQLite stores only 15 significant digits. Digits coming from
350 # float inaccuracy must be removed.
351 create_decimal = decimal.Context(prec=15).create_decimal_from_float
352 if isinstance(expression, Col):
353 quantize_value = decimal.Decimal(1).scaleb(
354 -expression.output_field.decimal_places
355 )
356
357 def converter(
358 value: Any, expression: Any, connection: BaseDatabaseWrapper
359 ) -> decimal.Decimal | None:
360 if value is not None:
361 return create_decimal(value).quantize(
362 quantize_value, context=expression.output_field.context
363 )
364 return None
365
366 else:
367
368 def converter(
369 value: Any, expression: Any, connection: BaseDatabaseWrapper
370 ) -> decimal.Decimal | None:
371 if value is not None:
372 return create_decimal(value)
373 return None
374
375 return converter
376
377 def convert_uuidfield_value(
378 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
379 ) -> uuid.UUID | None:
380 if value is not None:
381 value = uuid.UUID(value)
382 return value
383
384 def convert_booleanfield_value(
385 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
386 ) -> bool | Any:
387 return bool(value) if value in (1, 0) else value
388
389 def bulk_insert_sql(
390 self, fields: list[Any], placeholder_rows: list[list[str]]
391 ) -> str:
392 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
393 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
394 return f"VALUES {values_sql}"
395
396 def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
397 # SQLite doesn't have a ^ operator, so use the user-defined POWER
398 # function that's registered in connect().
399 if connector == "^":
400 return "POWER({})".format(",".join(sub_expressions))
401 elif connector == "#":
402 return "BITXOR({})".format(",".join(sub_expressions))
403 return super().combine_expression(connector, sub_expressions)
404
405 def combine_duration_expression(
406 self, connector: str, sub_expressions: list[str]
407 ) -> str:
408 if connector not in ["+", "-", "*", "/"]:
409 raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
410 fn_params = [f"'{connector}'"] + sub_expressions
411 if len(fn_params) > 3:
412 raise ValueError("Too many params for timedelta operations.")
413 return "plain_format_dtdelta({})".format(", ".join(fn_params))
414
415 def integer_field_range(self, internal_type: str) -> tuple[int, int]:
416 # SQLite doesn't enforce any integer constraints, but sqlite3 supports
417 # integers up to 64 bits.
418 if internal_type in [
419 "PositiveBigIntegerField",
420 "PositiveIntegerField",
421 "PositiveSmallIntegerField",
422 ]:
423 return (0, 9223372036854775807)
424 return (-9223372036854775808, 9223372036854775807)
425
426 def subtract_temporals(
427 self,
428 internal_type: str,
429 lhs: tuple[str, list[Any] | tuple[Any, ...]],
430 rhs: tuple[str, list[Any] | tuple[Any, ...]],
431 ) -> tuple[str, tuple[Any, ...]]:
432 lhs_sql, lhs_params = lhs
433 rhs_sql, rhs_params = rhs
434 params = (*lhs_params, *rhs_params)
435 if internal_type == "TimeField":
436 return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
437 return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
438
439 def insert_statement(self, on_conflict: Any = None) -> str:
440 if on_conflict == OnConflict.IGNORE:
441 return "INSERT OR IGNORE INTO"
442 return super().insert_statement(on_conflict=on_conflict)
443
444 def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
445 # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
446 if not fields:
447 return "", ()
448 columns = [
449 f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
450 for field in fields
451 ]
452 return "RETURNING {}".format(", ".join(columns)), ()
453
454 def on_conflict_suffix_sql(
455 self,
456 fields: list[Field],
457 on_conflict: Any,
458 update_fields: Iterable[str],
459 unique_fields: Iterable[str],
460 ) -> str:
461 if (
462 on_conflict == OnConflict.UPDATE
463 and self.connection.features.supports_update_conflicts_with_target
464 ):
465 return "ON CONFLICT({}) DO UPDATE SET {}".format(
466 ", ".join(map(self.quote_name, unique_fields)),
467 ", ".join(
468 [
469 f"{field} = EXCLUDED.{field}"
470 for field in map(self.quote_name, update_fields)
471 ]
472 ),
473 )
474 return super().on_conflict_suffix_sql(
475 fields,
476 on_conflict,
477 update_fields,
478 unique_fields,
479 )