1from __future__ import annotations
2
3import datetime
4import decimal
5import uuid
6from collections.abc import Callable, Iterable
7from functools import cached_property, lru_cache
8from typing import TYPE_CHECKING, Any
9
10from plain import models
11from plain.models.aggregates import Aggregate, Avg, StdDev, Sum, Variance
12from plain.models.backends.base.operations import BaseDatabaseOperations
13from plain.models.constants import OnConflict
14from plain.models.db import DatabaseError, NotSupportedError
15from plain.models.exceptions import FieldError
16from plain.models.expressions import Col
17from plain.utils import timezone
18from plain.utils.dateparse import parse_date, parse_datetime, parse_time
19
20if TYPE_CHECKING:
21 from plain.models.backends.base.base import BaseDatabaseWrapper
22 from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
23 from plain.models.fields import Field
24
25
class DatabaseOperations(BaseDatabaseOperations):
    """SQLite-specific SQL generation and value-conversion operations."""

    # Type checker hint: connection is always SQLiteDatabaseWrapper in this class
    connection: SQLiteDatabaseWrapper

    # SQLite has no bounded VARCHAR type, so unbounded char casts use "text".
    cast_char_field_without_max_length = "text"
    # Date/datetime values are stored as ISO text, so CAST targets TEXT.
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])
39
40 def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
41 """
42 SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
43 999 variables per query.
44
45 If there's only a single field to insert, the limit is 500
46 (SQLITE_MAX_COMPOUND_SELECT).
47 """
48 if len(fields) == 1:
49 return 500
50 elif len(fields) > 1:
51 return self.connection.features.max_query_params // len(fields)
52 else:
53 return len(objs)
54
55 def check_expression_support(self, expression: Any) -> None:
56 bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
57 bad_aggregates = (Sum, Avg, Variance, StdDev)
58 if isinstance(expression, bad_aggregates):
59 for expr in expression.get_source_expressions():
60 try:
61 output_field = expr.output_field
62 except (AttributeError, FieldError):
63 # Not every subexpression has an output_field which is fine
64 # to ignore.
65 pass
66 else:
67 if isinstance(output_field, bad_fields):
68 raise NotSupportedError(
69 "You cannot use Sum, Avg, StdDev, and Variance "
70 "aggregations on date/time fields in sqlite3 "
71 "since date/time is saved as text."
72 )
73 if (
74 isinstance(expression, Aggregate)
75 and expression.distinct
76 and len(expression.source_expressions) > 1
77 ):
78 raise NotSupportedError(
79 "SQLite doesn't support DISTINCT on aggregate functions "
80 "accepting multiple arguments."
81 )
82
83 def date_extract_sql(
84 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
85 ) -> tuple[str, tuple[Any, ...]]:
86 """
87 Support EXTRACT with a user-defined function plain_date_extract()
88 that's registered in connect(). Use single quotes because this is a
89 string and could otherwise cause a collision with a field name.
90 """
91 return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)
92
93 def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
94 """
95 Given a cursor object that has just performed an INSERT...RETURNING
96 statement into a table, return the list of returned data.
97 """
98 return cursor.fetchall()
99
100 def format_for_duration_arithmetic(self, sql: str) -> str:
101 """Do nothing since formatting is handled in the custom function."""
102 return sql
103
104 def date_trunc_sql(
105 self,
106 lookup_type: str,
107 sql: str,
108 params: list[Any] | tuple[Any, ...],
109 tzname: str | None = None,
110 ) -> tuple[str, tuple[Any, ...]]:
111 return f"plain_date_trunc(%s, {sql}, %s, %s)", (
112 lookup_type.lower(),
113 *params,
114 *self._convert_tznames_to_sql(tzname),
115 )
116
117 def time_trunc_sql(
118 self,
119 lookup_type: str,
120 sql: str,
121 params: list[Any] | tuple[Any, ...],
122 tzname: str | None = None,
123 ) -> tuple[str, tuple[Any, ...]]:
124 return f"plain_time_trunc(%s, {sql}, %s, %s)", (
125 lookup_type.lower(),
126 *params,
127 *self._convert_tznames_to_sql(tzname),
128 )
129
130 def _convert_tznames_to_sql(
131 self, tzname: str | None
132 ) -> tuple[str | None, str | None]:
133 if tzname:
134 return tzname, self.connection.timezone_name
135 return None, None
136
137 def datetime_cast_date_sql(
138 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
139 ) -> tuple[str, tuple[Any, ...]]:
140 return f"plain_datetime_cast_date({sql}, %s, %s)", (
141 *params,
142 *self._convert_tznames_to_sql(tzname),
143 )
144
145 def datetime_cast_time_sql(
146 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
147 ) -> tuple[str, tuple[Any, ...]]:
148 return f"plain_datetime_cast_time({sql}, %s, %s)", (
149 *params,
150 *self._convert_tznames_to_sql(tzname),
151 )
152
153 def datetime_extract_sql(
154 self,
155 lookup_type: str,
156 sql: str,
157 params: list[Any] | tuple[Any, ...],
158 tzname: str | None,
159 ) -> tuple[str, tuple[Any, ...]]:
160 return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
161 lookup_type.lower(),
162 *params,
163 *self._convert_tznames_to_sql(tzname),
164 )
165
166 def datetime_trunc_sql(
167 self,
168 lookup_type: str,
169 sql: str,
170 params: list[Any] | tuple[Any, ...],
171 tzname: str | None,
172 ) -> tuple[str, tuple[Any, ...]]:
173 return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
174 lookup_type.lower(),
175 *params,
176 *self._convert_tznames_to_sql(tzname),
177 )
178
179 def time_extract_sql(
180 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
181 ) -> tuple[str, tuple[Any, ...]]:
182 return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)
183
184 def pk_default_value(self) -> str:
185 return "NULL"
186
187 def _quote_params_for_last_executed_query(
188 self, params: list[Any] | tuple[Any, ...]
189 ) -> tuple[Any, ...]:
190 """
191 Only for last_executed_query! Don't use this to execute SQL queries!
192 """
193 # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
194 # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
195 # number of return values, default = 2000). Since Python's sqlite3
196 # module doesn't expose the get_limit() C API, assume the default
197 # limits are in effect and split the work in batches if needed.
198 BATCH_SIZE = 999
199 if len(params) > BATCH_SIZE:
200 results = ()
201 for index in range(0, len(params), BATCH_SIZE):
202 chunk = params[index : index + BATCH_SIZE]
203 results += self._quote_params_for_last_executed_query(chunk)
204 return results
205
206 sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
207 # Bypass Plain's wrappers and use the underlying sqlite3 connection
208 # to avoid logging this query - it would trigger infinite recursion.
209 cursor = self.connection.connection.cursor()
210 # Native sqlite3 cursors cannot be used as context managers.
211 try:
212 return cursor.execute(sql, params).fetchone()
213 finally:
214 cursor.close()
215
216 def last_executed_query(
217 self,
218 cursor: Any,
219 sql: str,
220 params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
221 ) -> str:
222 # Python substitutes parameters in Modules/_sqlite/cursor.c with:
223 # bind_parameters(state, self->statement, parameters);
224 # Unfortunately there is no way to reach self->statement from Python,
225 # so we quote and substitute parameters manually.
226 if params:
227 if isinstance(params, list | tuple):
228 params = self._quote_params_for_last_executed_query(params)
229 else:
230 values = tuple(params.values()) # type: ignore[union-attr]
231 values = self._quote_params_for_last_executed_query(values)
232 params = dict(zip(params, values))
233 return sql % params
234 # For consistency with SQLiteCursorWrapper.execute(), just return sql
235 # when there are no parameters. See #13648 and #17158.
236 else:
237 return sql
238
239 def quote_name(self, name: str) -> str:
240 if name.startswith('"') and name.endswith('"'):
241 return name # Quoting once is enough.
242 return f'"{name}"'
243
244 def no_limit_value(self) -> int:
245 return -1
246
247 def __references_graph(self, table_name: str) -> list[str]:
248 query = """
249 WITH tables AS (
250 SELECT %s name
251 UNION
252 SELECT sqlite_master.name
253 FROM sqlite_master
254 JOIN tables ON (sql REGEXP %s || tables.name || %s)
255 ) SELECT name FROM tables;
256 """
257 params = (
258 table_name,
259 r'(?i)\s+references\s+("|\')?',
260 r'("|\')?\s*\(',
261 )
262 with self.connection.cursor() as cursor:
263 results = cursor.execute(query, params)
264 return [row[0] for row in results.fetchall()]
265
266 @cached_property
267 def _references_graph(self) -> Callable[[str], list[str]]:
268 # 512 is large enough to fit the ~330 tables (as of this writing) in
269 # Plain's test suite.
270 return lru_cache(maxsize=512)(self.__references_graph)
271
272 def adapt_datetimefield_value(
273 self, value: datetime.datetime | Any | None
274 ) -> str | Any | None:
275 if value is None:
276 return None
277
278 # Expression values are adapted by the database.
279 if hasattr(value, "resolve_expression"):
280 return value
281
282 # SQLite doesn't support tz-aware datetimes
283 if timezone.is_aware(value):
284 value = timezone.make_naive(value, self.connection.timezone)
285
286 return str(value)
287
288 def adapt_timefield_value(
289 self, value: datetime.time | Any | None
290 ) -> str | Any | None:
291 if value is None:
292 return None
293
294 # Expression values are adapted by the database.
295 if hasattr(value, "resolve_expression"):
296 return value
297
298 # SQLite doesn't support tz-aware datetimes
299 if timezone.is_aware(value): # type: ignore[arg-type]
300 raise ValueError("SQLite backend does not support timezone-aware times.")
301
302 return str(value)
303
304 def get_db_converters(self, expression: Any) -> list[Any]:
305 converters = super().get_db_converters(expression)
306 internal_type = expression.output_field.get_internal_type()
307 if internal_type == "DateTimeField":
308 converters.append(self.convert_datetimefield_value)
309 elif internal_type == "DateField":
310 converters.append(self.convert_datefield_value)
311 elif internal_type == "TimeField":
312 converters.append(self.convert_timefield_value)
313 elif internal_type == "DecimalField":
314 converters.append(self.get_decimalfield_converter(expression))
315 elif internal_type == "UUIDField":
316 converters.append(self.convert_uuidfield_value)
317 elif internal_type == "BooleanField":
318 converters.append(self.convert_booleanfield_value)
319 return converters
320
321 def convert_datetimefield_value(
322 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
323 ) -> datetime.datetime | None:
324 if value is not None:
325 if not isinstance(value, datetime.datetime):
326 value = parse_datetime(value)
327 if value is not None and not timezone.is_aware(value):
328 value = timezone.make_aware(value, self.connection.timezone)
329 return value
330
331 def convert_datefield_value(
332 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
333 ) -> datetime.date | None:
334 if value is not None:
335 if not isinstance(value, datetime.date):
336 value = parse_date(value)
337 return value
338
339 def convert_timefield_value(
340 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
341 ) -> datetime.time | None:
342 if value is not None:
343 if not isinstance(value, datetime.time):
344 value = parse_time(value)
345 return value
346
347 def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
348 # SQLite stores only 15 significant digits. Digits coming from
349 # float inaccuracy must be removed.
350 create_decimal = decimal.Context(prec=15).create_decimal_from_float
351 if isinstance(expression, Col):
352 quantize_value = decimal.Decimal(1).scaleb(
353 -expression.output_field.decimal_places
354 )
355
356 def converter(
357 value: Any, expression: Any, connection: BaseDatabaseWrapper
358 ) -> decimal.Decimal | None:
359 if value is not None:
360 return create_decimal(value).quantize(
361 quantize_value, context=expression.output_field.context
362 )
363 return None
364
365 else:
366
367 def converter(
368 value: Any, expression: Any, connection: BaseDatabaseWrapper
369 ) -> decimal.Decimal | None:
370 if value is not None:
371 return create_decimal(value)
372 return None
373
374 return converter
375
376 def convert_uuidfield_value(
377 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
378 ) -> uuid.UUID | None:
379 if value is not None:
380 value = uuid.UUID(value)
381 return value
382
383 def convert_booleanfield_value(
384 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
385 ) -> bool | Any:
386 return bool(value) if value in (1, 0) else value
387
388 def bulk_insert_sql(
389 self, fields: list[Any], placeholder_rows: list[list[str]]
390 ) -> str:
391 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
392 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
393 return f"VALUES {values_sql}"
394
395 def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
396 # SQLite doesn't have a ^ operator, so use the user-defined POWER
397 # function that's registered in connect().
398 if connector == "^":
399 return "POWER({})".format(",".join(sub_expressions))
400 elif connector == "#":
401 return "BITXOR({})".format(",".join(sub_expressions))
402 return super().combine_expression(connector, sub_expressions)
403
404 def combine_duration_expression(
405 self, connector: str, sub_expressions: list[str]
406 ) -> str:
407 if connector not in ["+", "-", "*", "/"]:
408 raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
409 fn_params = [f"'{connector}'"] + sub_expressions
410 if len(fn_params) > 3:
411 raise ValueError("Too many params for timedelta operations.")
412 return "plain_format_dtdelta({})".format(", ".join(fn_params))
413
414 def integer_field_range(self, internal_type: str) -> tuple[int, int]:
415 # SQLite doesn't enforce any integer constraints, but sqlite3 supports
416 # integers up to 64 bits.
417 if internal_type in [
418 "PositiveBigIntegerField",
419 "PositiveIntegerField",
420 "PositiveSmallIntegerField",
421 ]:
422 return (0, 9223372036854775807)
423 return (-9223372036854775808, 9223372036854775807)
424
425 def subtract_temporals(
426 self,
427 internal_type: str,
428 lhs: tuple[str, list[Any] | tuple[Any, ...]],
429 rhs: tuple[str, list[Any] | tuple[Any, ...]],
430 ) -> tuple[str, tuple[Any, ...]]:
431 lhs_sql, lhs_params = lhs
432 rhs_sql, rhs_params = rhs
433 params = (*lhs_params, *rhs_params)
434 if internal_type == "TimeField":
435 return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
436 return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params
437
438 def insert_statement(self, on_conflict: Any = None) -> str:
439 if on_conflict == OnConflict.IGNORE:
440 return "INSERT OR IGNORE INTO"
441 return super().insert_statement(on_conflict=on_conflict)
442
443 def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
444 # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
445 if not fields:
446 return "", ()
447 columns = [
448 f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
449 for field in fields
450 ]
451 return "RETURNING {}".format(", ".join(columns)), ()
452
453 def on_conflict_suffix_sql(
454 self,
455 fields: list[Field],
456 on_conflict: Any,
457 update_fields: Iterable[str],
458 unique_fields: Iterable[str],
459 ) -> str:
460 if (
461 on_conflict == OnConflict.UPDATE
462 and self.connection.features.supports_update_conflicts_with_target
463 ):
464 return "ON CONFLICT({}) DO UPDATE SET {}".format(
465 ", ".join(map(self.quote_name, unique_fields)),
466 ", ".join(
467 [
468 f"{field} = EXCLUDED.{field}"
469 for field in map(self.quote_name, update_fields)
470 ]
471 ),
472 )
473 return super().on_conflict_suffix_sql(
474 fields,
475 on_conflict,
476 update_fields,
477 unique_fields,
478 )