1import json
2from functools import lru_cache, partial
3
4from plain.models.backends.base.operations import BaseDatabaseOperations
5from plain.models.backends.postgresql.psycopg_any import (
6 Inet,
7 Jsonb,
8 errors,
9 is_psycopg3,
10 mogrify,
11)
12from plain.models.backends.utils import split_tzname_delta
13from plain.models.constants import OnConflict
14from plain.runtime import settings
15from plain.utils.regex_helper import _lazy_re_compile
16
17
@lru_cache
def get_json_dumps(encoder):
    """Return a json.dumps callable for *encoder*, memoized per encoder class.

    With no encoder, plain ``json.dumps`` is returned; otherwise a partial
    bound to ``cls=encoder``.
    """
    return json.dumps if encoder is None else partial(json.dumps, cls=encoder)
23
24
class DatabaseOperations(BaseDatabaseOperations):
    """PostgreSQL-specific SQL generation and value adaptation."""

    # A CharField without max_length is cast to unbounded varchar.
    cast_char_field_without_max_length = "varchar"
    explain_prefix = "EXPLAIN"
    # Option names accepted in the parenthesized EXPLAIN (...) form; used by
    # explain_query_prefix() to separate them from base-handled options.
    explain_options = frozenset(
        [
            "ANALYZE",
            "BUFFERS",
            "COSTS",
            "SETTINGS",
            "SUMMARY",
            "TIMING",
            "VERBOSE",
            "WAL",
        ]
    )
    # SQL types used when explicitly casting auto fields.
    cast_data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "SmallAutoField": "smallint",
    }

    if is_psycopg3:
        from psycopg.types import numeric

        # Map integer field types to psycopg3 wrapper types so bound values
        # carry the matching PostgreSQL integer width (int2/int4/int8).
        integerfield_type_map = {
            "SmallIntegerField": numeric.Int2,
            "IntegerField": numeric.Int4,
            "BigIntegerField": numeric.Int8,
            "PositiveSmallIntegerField": numeric.Int2,
            "PositiveIntegerField": numeric.Int4,
            "PositiveBigIntegerField": numeric.Int8,
        }
57
58 def unification_cast_sql(self, output_field):
59 internal_type = output_field.get_internal_type()
60 if internal_type in (
61 "GenericIPAddressField",
62 "IPAddressField",
63 "TimeField",
64 "UUIDField",
65 ):
66 # PostgreSQL will resolve a union as type 'text' if input types are
67 # 'unknown'.
68 # https://www.postgresql.org/docs/current/typeconv-union-case.html
69 # These fields cannot be implicitly cast back in the default
70 # PostgreSQL configuration so we need to explicitly cast them.
71 # We must also remove components of the type within brackets:
72 # varchar(255) -> varchar.
73 return (
74 "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
75 )
76 return "%s"
77
    # EXTRACT's field name is interpolated directly into the SQL (it cannot be
    # passed as a query parameter), so date_extract_sql() validates it against
    # this pattern first.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
80
81 def date_extract_sql(self, lookup_type, sql, params):
82 # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
83 if lookup_type == "week_day":
84 # For consistency across backends, we return Sunday=1, Saturday=7.
85 return f"EXTRACT(DOW FROM {sql}) + 1", params
86 elif lookup_type == "iso_week_day":
87 return f"EXTRACT(ISODOW FROM {sql})", params
88 elif lookup_type == "iso_year":
89 return f"EXTRACT(ISOYEAR FROM {sql})", params
90
91 lookup_type = lookup_type.upper()
92 if not self._extract_format_re.fullmatch(lookup_type):
93 raise ValueError(f"Invalid lookup type: {lookup_type!r}")
94 return f"EXTRACT({lookup_type} FROM {sql})", params
95
96 def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
97 sql, params = self._convert_sql_to_tz(sql, params, tzname)
98 # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
99 return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
100
101 def _prepare_tzname_delta(self, tzname):
102 tzname, sign, offset = split_tzname_delta(tzname)
103 if offset:
104 sign = "-" if sign == "+" else "+"
105 return f"{tzname}{sign}{offset}"
106 return tzname
107
108 def _convert_sql_to_tz(self, sql, params, tzname):
109 if tzname and settings.USE_TZ:
110 tzname_param = self._prepare_tzname_delta(tzname)
111 return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
112 return sql, params
113
114 def datetime_cast_date_sql(self, sql, params, tzname):
115 sql, params = self._convert_sql_to_tz(sql, params, tzname)
116 return f"({sql})::date", params
117
118 def datetime_cast_time_sql(self, sql, params, tzname):
119 sql, params = self._convert_sql_to_tz(sql, params, tzname)
120 return f"({sql})::time", params
121
122 def datetime_extract_sql(self, lookup_type, sql, params, tzname):
123 sql, params = self._convert_sql_to_tz(sql, params, tzname)
124 if lookup_type == "second":
125 # Truncate fractional seconds.
126 return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
127 return self.date_extract_sql(lookup_type, sql, params)
128
129 def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
130 sql, params = self._convert_sql_to_tz(sql, params, tzname)
131 # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
132 return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
133
134 def time_extract_sql(self, lookup_type, sql, params):
135 if lookup_type == "second":
136 # Truncate fractional seconds.
137 return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
138 return self.date_extract_sql(lookup_type, sql, params)
139
140 def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
141 sql, params = self._convert_sql_to_tz(sql, params, tzname)
142 return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
143
144 def deferrable_sql(self):
145 return " DEFERRABLE INITIALLY DEFERRED"
146
147 def fetch_returned_insert_rows(self, cursor):
148 """
149 Given a cursor object that has just performed an INSERT...RETURNING
150 statement into a table, return the tuple of returned data.
151 """
152 return cursor.fetchall()
153
154 def lookup_cast(self, lookup_type, internal_type=None):
155 lookup = "%s"
156
157 if lookup_type == "isnull" and internal_type in (
158 "CharField",
159 "EmailField",
160 "TextField",
161 "CICharField",
162 "CIEmailField",
163 "CITextField",
164 ):
165 return "%s::text"
166
167 # Cast text lookups to text to allow things like filter(x__contains=4)
168 if lookup_type in (
169 "iexact",
170 "contains",
171 "icontains",
172 "startswith",
173 "istartswith",
174 "endswith",
175 "iendswith",
176 "regex",
177 "iregex",
178 ):
179 if internal_type in ("IPAddressField", "GenericIPAddressField"):
180 lookup = "HOST(%s)"
181 # RemovedInDjango51Warning.
182 elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
183 lookup = "%s::citext"
184 else:
185 lookup = "%s::text"
186
187 # Use UPPER(x) for case-insensitive lookups; it's faster.
188 if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
189 lookup = "UPPER(%s)" % lookup
190
191 return lookup
192
193 def no_limit_value(self):
194 return None
195
196 def prepare_sql_script(self, sql):
197 return [sql]
198
199 def quote_name(self, name):
200 if name.startswith('"') and name.endswith('"'):
201 return name # Quoting once is enough.
202 return '"%s"' % name
203
204 def compose_sql(self, sql, params):
205 return mogrify(sql, params, self.connection)
206
207 def set_time_zone_sql(self):
208 return "SELECT set_config('TimeZone', %s, false)"
209
210 def sequence_reset_by_name_sql(self, style, sequences):
211 # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
212 # to reset sequence indices
213 sql = []
214 for sequence_info in sequences:
215 table_name = sequence_info["table"]
216 # 'id' will be the case if it's an m2m using an autogenerated
217 # intermediate table (see BaseDatabaseIntrospection.sequence_list).
218 column_name = sequence_info["column"] or "id"
219 sql.append(
220 "{} setval(pg_get_serial_sequence('{}','{}'), 1, false);".format(
221 style.SQL_KEYWORD("SELECT"),
222 style.SQL_TABLE(self.quote_name(table_name)),
223 style.SQL_FIELD(column_name),
224 )
225 )
226 return sql
227
228 def tablespace_sql(self, tablespace, inline=False):
229 if inline:
230 return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
231 else:
232 return "TABLESPACE %s" % self.quote_name(tablespace)
233
234 def sequence_reset_sql(self, style, model_list):
235 from plain import models
236
237 output = []
238 qn = self.quote_name
239 for model in model_list:
240 # Use `coalesce` to set the sequence for each model to the max pk
241 # value if there are records, or 1 if there are none. Set the
242 # `is_called` property (the third argument to `setval`) to true if
243 # there are records (as the max pk value is already in use),
244 # otherwise set it to false. Use pg_get_serial_sequence to get the
245 # underlying sequence name from the table name and column name.
246
247 for f in model._meta.local_fields:
248 if isinstance(f, models.AutoField):
249 output.append(
250 "{} setval(pg_get_serial_sequence('{}','{}'), "
251 "coalesce(max({}), 1), max({}) {} null) {} {};".format(
252 style.SQL_KEYWORD("SELECT"),
253 style.SQL_TABLE(qn(model._meta.db_table)),
254 style.SQL_FIELD(f.column),
255 style.SQL_FIELD(qn(f.column)),
256 style.SQL_FIELD(qn(f.column)),
257 style.SQL_KEYWORD("IS NOT"),
258 style.SQL_KEYWORD("FROM"),
259 style.SQL_TABLE(qn(model._meta.db_table)),
260 )
261 )
262 # Only one AutoField is allowed per model, so don't bother
263 # continuing.
264 break
265 return output
266
267 def prep_for_iexact_query(self, x):
268 return x
269
270 def max_name_length(self):
271 """
272 Return the maximum length of an identifier.
273
274 The maximum length of an identifier is 63 by default, but can be
275 changed by recompiling PostgreSQL after editing the NAMEDATALEN
276 macro in src/include/pg_config_manual.h.
277
278 This implementation returns 63, but can be overridden by a custom
279 database backend that inherits most of its behavior from this one.
280 """
281 return 63
282
283 def distinct_sql(self, fields, params):
284 if fields:
285 params = [param for param_list in params for param in param_list]
286 return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
287 else:
288 return ["DISTINCT"], []
289
290 if is_psycopg3:
291
292 def last_executed_query(self, cursor, sql, params):
293 try:
294 return self.compose_sql(sql, params)
295 except errors.DataError:
296 return None
297
298 else:
299
300 def last_executed_query(self, cursor, sql, params):
301 # https://www.psycopg.org/docs/cursor.html#cursor.query
302 # The query attribute is a Psycopg extension to the DB API 2.0.
303 if cursor.query is not None:
304 return cursor.query.decode()
305 return None
306
307 def return_insert_columns(self, fields):
308 if not fields:
309 return "", ()
310 columns = [
311 "{}.{}".format(
312 self.quote_name(field.model._meta.db_table),
313 self.quote_name(field.column),
314 )
315 for field in fields
316 ]
317 return "RETURNING %s" % ", ".join(columns), ()
318
319 def bulk_insert_sql(self, fields, placeholder_rows):
320 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
321 values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
322 return "VALUES " + values_sql
323
324 if is_psycopg3:
325
326 def adapt_integerfield_value(self, value, internal_type):
327 if value is None or hasattr(value, "resolve_expression"):
328 return value
329 return self.integerfield_type_map[internal_type](value)
330
331 def adapt_datefield_value(self, value):
332 return value
333
334 def adapt_datetimefield_value(self, value):
335 return value
336
337 def adapt_timefield_value(self, value):
338 return value
339
340 def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
341 return value
342
343 def adapt_ipaddressfield_value(self, value):
344 if value:
345 return Inet(value)
346 return None
347
348 def adapt_json_value(self, value, encoder):
349 return Jsonb(value, dumps=get_json_dumps(encoder))
350
351 def subtract_temporals(self, internal_type, lhs, rhs):
352 if internal_type == "DateField":
353 lhs_sql, lhs_params = lhs
354 rhs_sql, rhs_params = rhs
355 params = (*lhs_params, *rhs_params)
356 return f"(interval '1 day' * ({lhs_sql} - {rhs_sql}))", params
357 return super().subtract_temporals(internal_type, lhs, rhs)
358
359 def explain_query_prefix(self, format=None, **options):
360 extra = {}
361 # Normalize options.
362 if options:
363 options = {
364 name.upper(): "true" if value else "false"
365 for name, value in options.items()
366 }
367 for valid_option in self.explain_options:
368 value = options.pop(valid_option, None)
369 if value is not None:
370 extra[valid_option] = value
371 prefix = super().explain_query_prefix(format, **options)
372 if format:
373 extra["FORMAT"] = format
374 if extra:
375 prefix += " (%s)" % ", ".join("{} {}".format(*i) for i in extra.items())
376 return prefix
377
378 def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
379 if on_conflict == OnConflict.IGNORE:
380 return "ON CONFLICT DO NOTHING"
381 if on_conflict == OnConflict.UPDATE:
382 return "ON CONFLICT({}) DO UPDATE SET {}".format(
383 ", ".join(map(self.quote_name, unique_fields)),
384 ", ".join(
385 [
386 f"{field} = EXCLUDED.{field}"
387 for field in map(self.quote_name, update_fields)
388 ]
389 ),
390 )
391 return super().on_conflict_suffix_sql(
392 fields,
393 on_conflict,
394 update_fields,
395 unique_fields,
396 )