from __future__ import annotations

import datetime
import decimal
import uuid
from collections.abc import Callable
from functools import cached_property, lru_cache
from typing import TYPE_CHECKING, Any

from plain import models
from plain.models.backends.base.operations import BaseDatabaseOperations
from plain.models.constants import OnConflict
from plain.models.db import DatabaseError, NotSupportedError
from plain.models.exceptions import FieldError
from plain.models.expressions import Col
from plain.utils import timezone
from plain.utils.dateparse import parse_date, parse_datetime, parse_time

if TYPE_CHECKING:
    from plain.models.backends.base.base import BaseDatabaseWrapper


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = "text"
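    # SQLite has no native date/datetime column types; these values are stored
    # as TEXT, so explicit casts target TEXT as well.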
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields: list[Any], objs: list[Any]) -> int:
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
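            # No fields means no bound parameters, so all objects fit in a
            # single batch.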
            return len(objs)

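    # Raise early for expressions SQLite can't evaluate correctly: date/time
    # aggregates over text-stored values, and DISTINCT aggregates that take
    # multiple arguments.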
    def check_expression_support(self, expression: Any) -> None:
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(
        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Support EXTRACT with a user-defined function plain_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return f"plain_date_extract(%s, {sql})", (lookup_type.lower(), *params)

    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql: str) -> str:
        """Do nothing since formatting is handled in the custom function."""
        return sql

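    # The trunc/extract helpers below delegate to user-defined plain_* SQL
    # functions (presumably registered in connect() alongside
    # plain_date_extract()), passing the lookup type and timezone names as
    # query parameters.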
    def date_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

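    # Pass both the requested timezone and the connection's timezone name to
    # the plain_* functions, or (None, None) when no conversion is requested.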
    def _convert_tznames_to_sql(
        self, tzname: str | None
    ) -> tuple[str | None, str | None]:
        if tzname:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(
        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(
        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None,
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None,
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(
        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
    ) -> tuple[str, tuple[Any, ...]]:
        return f"plain_time_extract(%s, {sql})", (lookup_type.lower(), *params)

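    # Inserting NULL into an INTEGER PRIMARY KEY column makes SQLite assign
    # the next rowid automatically.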
    def pk_default_value(self) -> str:
        return "NULL"

    def _quote_params_for_last_executed_query(
        self, params: list[Any] | tuple[Any, ...]
    ) -> tuple[Any, ...]:
        """
        Only for last_executed_query! Don't use this to execute SQL queries!
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Plain's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(
        self,
        cursor: Any,
        sql: str,
        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None,
    ) -> str:
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, list | tuple):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())  # type: ignore[union-attr]
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name: str) -> str:
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return f'"{name}"'

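    # SQLite treats a negative LIMIT as "no limit".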
    def no_limit_value(self) -> int:
        return -1

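    # Recursively collect every table that references table_name, directly or
    # transitively, by scanning the CREATE statements in sqlite_master for
    # REFERENCES clauses.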
    def __references_graph(self, table_name: str) -> list[str]:
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self) -> Callable[[str], list[str]]:
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Plain's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

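    # SQLite stores datetimes as text and doesn't support timezone-aware
    # values, so aware datetimes are converted to the connection's timezone
    # and made naive before serialization.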
    def adapt_datetimefield_value(
        self, value: datetime.datetime | Any | None
    ) -> str | Any | None:
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)

        return str(value)

    def adapt_timefield_value(
        self, value: datetime.time | Any | None
    ) -> str | Any | None:
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):  # type: ignore[arg-type]
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

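    # Register the converters that turn SQLite's text/integer storage back
    # into the appropriate Python objects on the way out of the database.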
    def get_db_converters(self, expression: Any) -> list[Any]:
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> datetime.datetime | None:
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if value is not None and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> datetime.date | None:
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> datetime.time | None:
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression: Any) -> Callable[..., Any]:
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(
                value: Any, expression: Any, connection: BaseDatabaseWrapper
            ) -> decimal.Decimal | None:
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )
                return None

        else:

            def converter(
                value: Any, expression: Any, connection: BaseDatabaseWrapper
            ) -> decimal.Decimal | None:
                if value is not None:
                    return create_decimal(value)
                return None

        return converter

    def convert_uuidfield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> uuid.UUID | None:
        if value is not None:
            value = uuid.UUID(value)
        return value

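    # SQLite stores booleans as the integers 0 and 1.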
    def convert_booleanfield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> bool | Any:
        return bool(value) if value in (1, 0) else value

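    # Multi-row inserts are rendered as a single VALUES clause:
    # VALUES (...), (...), ...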
    def bulk_insert_sql(
        self, fields: list[Any], placeholder_rows: list[list[str]]
    ) -> str:
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return f"VALUES {values_sql}"

    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == "^":
            return "POWER({})".format(",".join(sub_expressions))
        elif connector == "#":
            return "BITXOR({})".format(",".join(sub_expressions))
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(
        self, connector: str, sub_expressions: list[str]
    ) -> str:
        if connector not in ["+", "-", "*", "/"]:
            raise DatabaseError(f"Invalid connector for timedelta: {connector}.")
        fn_params = [f"'{connector}'"] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError("Too many params for timedelta operations.")
        return "plain_format_dtdelta({})".format(", ".join(fn_params))

    def integer_field_range(self, internal_type: str) -> tuple[int, int]:
        # SQLite doesn't enforce any integer constraints, but sqlite3 supports
        # integers up to 64 bits.
        if internal_type in [
            "PositiveBigIntegerField",
            "PositiveIntegerField",
            "PositiveSmallIntegerField",
        ]:
            return (0, 9223372036854775807)
        return (-9223372036854775808, 9223372036854775807)

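    # Subtraction of date/time values can't be done natively on text-stored
    # columns, so it's delegated to the user-defined plain_time_diff() /
    # plain_timestamp_diff() functions.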
    def subtract_temporals(
        self,
        internal_type: str,
        lhs: tuple[str, list[Any] | tuple[Any, ...]],
        rhs: tuple[str, list[Any] | tuple[Any, ...]],
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == "TimeField":
            return f"plain_time_diff({lhs_sql}, {rhs_sql})", params
        return f"plain_timestamp_diff({lhs_sql}, {rhs_sql})", params

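    # "INSERT OR IGNORE" is SQLite's syntax for skipping rows that would
    # violate a constraint.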
    def insert_statement(self, on_conflict: Any = None) -> str:
        if on_conflict == OnConflict.IGNORE:
            return "INSERT OR IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
        if not fields:
            return "", ()
        columns = [
            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
            for field in fields
        ]
        return "RETURNING {}".format(", ".join(columns)), ()

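    # SQLite's upsert syntax: EXCLUDED refers to the row that was proposed
    # for insertion.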
    def on_conflict_suffix_sql(
        self,
        fields: list[Any],
        on_conflict: Any,
        update_fields: list[Any],
        unique_fields: list[Any],
    ) -> str:
        if (
            on_conflict == OnConflict.UPDATE
            and self.connection.features.supports_update_conflicts_with_target
        ):
            return "ON CONFLICT({}) DO UPDATE SET {}".format(
                ", ".join(map(self.quote_name, unique_fields)),
                ", ".join(
                    [
                        f"{field} = EXCLUDED.{field}"
                        for field in map(self.quote_name, update_fields)
                    ]
                ),
            )
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )