1"""
2PostgreSQL-specific SQL generation functions.
3
4All functions in this module are stateless - they don't depend on connection state.
5"""
6
7from __future__ import annotations
8
9import datetime
10import ipaddress
11import json
12from collections.abc import Callable, Iterable
13from functools import lru_cache, partial
14from typing import TYPE_CHECKING, Any
15
16from psycopg.types import numeric
17from psycopg.types.json import Jsonb
18
19from plain.models.constants import OnConflict
20from plain.models.db import NotSupportedError
21from plain.models.postgres.utils import split_tzname_delta
22from plain.utils import timezone
23from plain.utils.regex_helper import _lazy_re_compile
24
25if TYPE_CHECKING:
26 from plain.models.fields import Field
27
28
29# Integer field safe ranges by internal_type.
30INTEGER_FIELD_RANGES: dict[str, tuple[int, int]] = {
31 "SmallIntegerField": (-32768, 32767),
32 "IntegerField": (-2147483648, 2147483647),
33 "BigIntegerField": (-9223372036854775808, 9223372036854775807),
34 "PositiveBigIntegerField": (0, 9223372036854775807),
35 "PositiveSmallIntegerField": (0, 32767),
36 "PositiveIntegerField": (0, 2147483647),
37 "PrimaryKeyField": (-9223372036854775808, 9223372036854775807),
38}
39
40# Mapping of Field.get_internal_type() to the data type for Cast().
41CAST_DATA_TYPES: dict[str, str] = {
42 "PrimaryKeyField": "bigint",
43}
44
45# CharField data type when max_length isn't provided.
46CAST_CHAR_FIELD_WITHOUT_MAX_LENGTH: str | None = "varchar"
47
48# Start and end points for window expressions.
49PRECEDING: str = "PRECEDING"
50FOLLOWING: str = "FOLLOWING"
51UNBOUNDED_PRECEDING: str = "UNBOUNDED " + PRECEDING
52UNBOUNDED_FOLLOWING: str = "UNBOUNDED " + FOLLOWING
53CURRENT_ROW: str = "CURRENT ROW"
54
55# Prefix for EXPLAIN queries.
56EXPLAIN_PREFIX: str = "EXPLAIN"
57EXPLAIN_OPTIONS = frozenset(
58 [
59 "ANALYZE",
60 "BUFFERS",
61 "COSTS",
62 "SETTINGS",
63 "SUMMARY",
64 "TIMING",
65 "VERBOSE",
66 "WAL",
67 ]
68)
69SUPPORTED_EXPLAIN_FORMATS: set[str] = {"JSON", "TEXT", "XML", "YAML"}
70
71# PostgreSQL integer type mapping for psycopg.
72INTEGERFIELD_TYPE_MAP = {
73 "SmallIntegerField": numeric.Int2,
74 "IntegerField": numeric.Int4,
75 "BigIntegerField": numeric.Int8,
76 "PositiveSmallIntegerField": numeric.Int2,
77 "PositiveIntegerField": numeric.Int4,
78 "PositiveBigIntegerField": numeric.Int8,
79}
80
81# Maximum length of an identifier (63 by default in PostgreSQL).
82MAX_NAME_LENGTH: int = 63
83
84# Value to use during INSERT to specify that a field should use its default value.
85PK_DEFAULT_VALUE: str = "DEFAULT"
86
87# SQL clause to make a constraint "initially deferred" during CREATE TABLE.
88DEFERRABLE_SQL: str = " DEFERRABLE INITIALLY DEFERRED"
89
90# EXTRACT format validation pattern.
91_EXTRACT_FORMAT_RE = _lazy_re_compile(r"[A-Z_]+")
92
93
94# ##### Data type mappings (from constants.py) #####
95
96
97def _get_varchar_column(data: dict[str, Any]) -> str:
98 if data["max_length"] is None:
99 return "varchar"
100 return "varchar({max_length})".format(**data)
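# For illustration, the varchar helper above maps e.g.
# {"max_length": 120} -> "varchar(120)" and {"max_length": None} -> "varchar".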


# Maps Field objects to their associated PostgreSQL column types.
# Column-type strings can contain format strings interpolated against Field.__dict__.
DATA_TYPES: dict[str, Any] = {
    "PrimaryKeyField": "bigint",
    "BinaryField": "bytea",
    "BooleanField": "boolean",
    "CharField": _get_varchar_column,
    "DateField": "date",
    "DateTimeField": "timestamp with time zone",
    "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
    "DurationField": "interval",
    "FloatField": "double precision",
    "IntegerField": "integer",
    "BigIntegerField": "bigint",
    "GenericIPAddressField": "inet",
    "JSONField": "jsonb",
    "PositiveBigIntegerField": "bigint",
    "PositiveIntegerField": "integer",
    "PositiveSmallIntegerField": "smallint",
    "SmallIntegerField": "smallint",
    "TextField": "text",
    "TimeField": "time",
    "UUIDField": "uuid",
}

# Check constraints for fields that need them.
DATA_TYPE_CHECK_CONSTRAINTS: dict[str, str] = {
    "PositiveBigIntegerField": '"%(column)s" >= 0',
    "PositiveIntegerField": '"%(column)s" >= 0',
    "PositiveSmallIntegerField": '"%(column)s" >= 0',
}

# Suffix applied to column definitions (e.g., for identity columns).
DATA_TYPES_SUFFIX: dict[str, str] = {
    "PrimaryKeyField": "GENERATED BY DEFAULT AS IDENTITY",
}

# SQL operators for lookups.
OPERATORS: dict[str, str] = {
    "exact": "= %s",
    "iexact": "= UPPER(%s)",
    "contains": "LIKE %s",
    "icontains": "LIKE UPPER(%s)",
    "regex": "~ %s",
    "iregex": "~* %s",
    "gt": "> %s",
    "gte": ">= %s",
    "lt": "< %s",
    "lte": "<= %s",
    "startswith": "LIKE %s",
    "endswith": "LIKE %s",
    "istartswith": "LIKE UPPER(%s)",
    "iendswith": "LIKE UPPER(%s)",
}

# SQL pattern for escaping special characters in LIKE clauses.
# Used when the right-hand side isn't a raw string (e.g., an expression).
PATTERN_ESC = (
    r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)

# Pattern operators for non-literal LIKE lookups.
PATTERN_OPS: dict[str, str] = {
    "contains": "LIKE '%%' || {} || '%%'",
    "icontains": "LIKE '%%' || UPPER({}) || '%%'",
    "startswith": "LIKE {} || '%%'",
    "istartswith": "LIKE UPPER({}) || '%%'",
    "endswith": "LIKE '%%' || {}",
    "iendswith": "LIKE '%%' || UPPER({})",
}
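# For illustration (assuming the caller escapes the right-hand side first, e.g. via
# PATTERN_ESC): PATTERN_OPS entries are .format()-ed with the rhs SQL, so
# PATTERN_OPS["startswith"].format("%s") -> "LIKE %s || '%%'".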


@lru_cache
def get_json_dumps(
    encoder: type[json.JSONEncoder] | None,
) -> Callable[..., str]:
    if encoder is None:
        return json.dumps
    return partial(json.dumps, cls=encoder)


def quote_name(name: str) -> str:
    """
    Return a quoted version of the given table, index, or column name.
    Does not quote the given name if it's already been quoted.
    """
    if name.startswith('"') and name.endswith('"'):
        return name  # Quoting once is enough.
    return f'"{name}"'
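# For illustration: quote_name("user") -> '"user"', while quote_name('"user"') is
# returned unchanged.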


def date_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', or 'day', return the SQL that
    extracts that value from the given date expression `sql`.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    if lookup_type == "week_day":
        # PostgreSQL DOW returns 0=Sunday, 6=Saturday; we return 1=Sunday, 7=Saturday.
        return f"EXTRACT(DOW FROM {sql}) + 1", params
    elif lookup_type == "iso_week_day":
        return f"EXTRACT(ISODOW FROM {sql})", params
    elif lookup_type == "iso_year":
        return f"EXTRACT(ISOYEAR FROM {sql})", params

    lookup_type = lookup_type.upper()
    if not _EXTRACT_FORMAT_RE.fullmatch(lookup_type):
        raise ValueError(f"Invalid lookup type: {lookup_type!r}")
    return f"EXTRACT({lookup_type} FROM {sql})", params
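# For illustration: date_extract_sql("month", '"created_at"', []) returns
# ('EXTRACT(MONTH FROM "created_at")', []).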


def _prepare_tzname_delta(tzname: str) -> str:
    tzname, sign, offset = split_tzname_delta(tzname)
    if offset:
        sign = "-" if sign == "+" else "+"
        return f"{tzname}{sign}{offset}"
    return tzname


def _convert_sql_to_tz(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    if tzname:
        tzname_param = _prepare_tzname_delta(tzname)
        return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
    return sql, params
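# For illustration (assuming split_tzname_delta("UTC") yields no offset):
# _convert_sql_to_tz('"created_at"', [], "UTC") returns
# ('"created_at" AT TIME ZONE %s', ('UTC',)). The helper above flips an explicit
# offset's sign, since PostgreSQL reads fixed offsets in POSIX style, with the
# sign reversed relative to ISO 8601.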


def date_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', or 'day', return the SQL that
    truncates the given date or datetime expression `sql` to a date with
    only the given specificity.

    If `tzname` is provided, the given value is truncated in a specific timezone.
    """
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)


def datetime_cast_date_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL to cast a datetime value to a date value."""
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    return f"({sql})::date", params


def datetime_cast_time_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL to cast a datetime value to a time value."""
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    return f"({sql})::time", params


def datetime_extract_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
    'second', return the SQL that extracts that value from the given
    datetime expression `sql`.
    """
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    if lookup_type == "second":
        # Truncate fractional seconds.
        return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
    return date_extract_sql(lookup_type, sql, params)


def datetime_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
    'second', return the SQL that truncates the given datetime expression
    `sql` to a datetime with only the given specificity.
    """
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)


def time_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
    that extracts that value from the given time expression `sql`.
    """
    if lookup_type == "second":
        # Truncate fractional seconds.
        return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
    return date_extract_sql(lookup_type, sql, params)


def time_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
    that truncates the given time or datetime expression `sql` to a time
    with only the given specificity.

    If `tzname` is provided, the given value is truncated in a specific timezone.
    """
    sql, params = _convert_sql_to_tz(sql, params, tzname)
    return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)


def distinct_sql(
    fields: list[str], params: list[Any] | tuple[Any, ...]
) -> tuple[list[str], list[Any]]:
    """
    Return an SQL DISTINCT clause which removes duplicate rows from the
    result set. If any fields are given, only check the given fields for
    duplicates.
    """
    if fields:
        params = [param for param_list in params for param in param_list]
        return (["DISTINCT ON ({})".format(", ".join(fields))], params)
    else:
        return ["DISTINCT"], []
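# For illustration: distinct_sql(['"name"', '"email"'], []) returns
# (['DISTINCT ON ("name", "email")'], []), while distinct_sql([], []) returns
# (["DISTINCT"], []).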


def for_update_sql(
    nowait: bool = False,
    skip_locked: bool = False,
    of: tuple[str, ...] = (),
    no_key: bool = False,
) -> str:
    """Return the FOR UPDATE SQL clause to lock rows for an update operation."""
    return "FOR{} UPDATE{}{}{}".format(
        " NO KEY" if no_key else "",
        " OF {}".format(", ".join(of)) if of else "",
        " NOWAIT" if nowait else "",
        " SKIP LOCKED" if skip_locked else "",
    )
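# For illustration: for_update_sql(skip_locked=True, of=("self",)) returns
# "FOR UPDATE OF self SKIP LOCKED", and for_update_sql(no_key=True) returns
# "FOR NO KEY UPDATE".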


def limit_offset_sql(low_mark: int | None, high_mark: int | None) -> str:
    """Return the LIMIT/OFFSET SQL clause."""
    offset = low_mark or 0
    if high_mark is not None:
        limit = high_mark - offset
    else:
        limit = None
    return " ".join(
        sql
        for sql in (
            ("LIMIT %d" % limit) if limit else None,  # noqa: UP031
            ("OFFSET %d" % offset) if offset else None,  # noqa: UP031
        )
        if sql
    )
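# For illustration: limit_offset_sql(5, 15) returns "LIMIT 10 OFFSET 5", and
# limit_offset_sql(None, 10) returns "LIMIT 10".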


def lookup_cast(lookup_type: str, internal_type: str | None = None) -> str:
    """
    Return the string to use in a query when performing lookups
    ("contains", "like", etc.). It should contain a '%s' placeholder for
    the column being searched against.
    """
    lookup = "%s"

    if lookup_type == "isnull" and internal_type in (
        "CharField",
        "EmailField",
        "TextField",
    ):
        return "%s::text"

    # Cast text lookups to text to allow things like filter(x__contains=4)
    if lookup_type in (
        "iexact",
        "contains",
        "icontains",
        "startswith",
        "istartswith",
        "endswith",
        "iendswith",
        "regex",
        "iregex",
    ):
        if internal_type == "GenericIPAddressField":
            lookup = "HOST(%s)"
        else:
            lookup = "%s::text"

    # Use UPPER(x) for case-insensitive lookups; it's faster.
    if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
        lookup = f"UPPER({lookup})"

    return lookup
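# For illustration: lookup_cast("icontains") returns "UPPER(%s::text)", and
# lookup_cast("contains", internal_type="GenericIPAddressField") returns "HOST(%s)".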


def return_insert_columns(fields: list[Field]) -> tuple[str, tuple[Any, ...]]:
    """Return the RETURNING clause SQL and params to append to an INSERT query."""
    if not fields:
        return "", ()
    columns = [
        f"{quote_name(field.model.model_options.db_table)}.{quote_name(field.column)}"
        for field in fields
    ]
    return "RETURNING {}".format(", ".join(columns)), ()


def bulk_insert_sql(fields: list[Field], placeholder_rows: list[list[str]]) -> str:
    """Return the SQL for bulk inserting rows."""
    placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
    values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
    return "VALUES " + values_sql
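# For illustration: bulk_insert_sql(fields, [["%s", "%s"], ["%s", "%s"]]) returns
# "VALUES (%s, %s), (%s, %s)" (the `fields` argument is not used here).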


def regex_lookup(lookup_type: str) -> str:
    """
    Return the string to use in a query when performing regular expression
    lookups (using "regex" or "iregex").
    """
    # PostgreSQL uses ~ for regex and ~* for case-insensitive regex.
    if lookup_type == "regex":
        return "%s ~ %s"
    return "%s ~* %s"


def prep_for_like_query(x: str) -> str:
    """Prepare a value for use in a LIKE query."""
    return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")


def adapt_integerfield_value(
    value: int | Any | None, internal_type: str
) -> int | Any | None:
    from plain.models.expressions import ResolvableExpression

    if value is None or isinstance(value, ResolvableExpression):
        return value
    return INTEGERFIELD_TYPE_MAP[internal_type](value)


def adapt_ipaddressfield_value(
    value: str | None,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """
    Transform a string representation of an IP address into the expected
    type for the backend driver.
    """
    if value:
        return ipaddress.ip_address(value)
    return None


def adapt_json_value(value: Any, encoder: type[json.JSONEncoder] | None) -> Jsonb:
    return Jsonb(value, dumps=get_json_dumps(encoder))


def year_lookup_bounds_for_date_field(
    value: int, iso_year: bool = False
) -> list[datetime.date]:
    """
    Return a two-element list with the lower and upper bound to be used
    with a BETWEEN operator to query a DateField value using a year lookup.

    `value` is an int containing the looked-up year.
    If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
    """
    if iso_year:
        first = datetime.date.fromisocalendar(value, 1, 1)
        second = datetime.date.fromisocalendar(value + 1, 1, 1) - datetime.timedelta(
            days=1
        )
    else:
        first = datetime.date(value, 1, 1)
        second = datetime.date(value, 12, 31)
    return [first, second]
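# For illustration: year_lookup_bounds_for_date_field(2024) returns
# [datetime.date(2024, 1, 1), datetime.date(2024, 12, 31)].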


def year_lookup_bounds_for_datetime_field(
    value: int, iso_year: bool = False
) -> list[datetime.datetime]:
    """
    Return a two-element list with the lower and upper bound to be used
    with a BETWEEN operator to query a DateTimeField value using a year lookup.

    `value` is an int containing the looked-up year.
    If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
    """
    if iso_year:
        first = datetime.datetime.fromisocalendar(value, 1, 1)
        second = datetime.datetime.fromisocalendar(
            value + 1, 1, 1
        ) - datetime.timedelta(microseconds=1)
    else:
        first = datetime.datetime(value, 1, 1)
        second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

    # Make sure the datetimes are aware in the current timezone.
    tz = timezone.get_current_timezone()
    first = timezone.make_aware(first, tz)
    second = timezone.make_aware(second, tz)
    return [first, second]


def combine_expression(connector: str, sub_expressions: list[str]) -> str:
    """
    Combine a list of subexpressions into a single expression, using
    the provided connecting operator.
    """
    conn = f" {connector} "
    return conn.join(sub_expressions)


def subtract_temporals(
    internal_type: str,
    lhs: tuple[str, list[Any] | tuple[Any, ...]],
    rhs: tuple[str, list[Any] | tuple[Any, ...]],
) -> tuple[str, tuple[Any, ...]]:
    lhs_sql, lhs_params = lhs
    rhs_sql, rhs_params = rhs
    params = (*lhs_params, *rhs_params)
    if internal_type == "DateField":
        return f"(interval '1 day' * ({lhs_sql} - {rhs_sql}))", params
    # Use native temporal subtraction.
    return f"({lhs_sql} - {rhs_sql})", params
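# For illustration: subtract_temporals("DateField", ('"end"', []), ('"start"', []))
# returns the SQL (interval '1 day' * ("end" - "start")) with the combined params;
# PostgreSQL date subtraction yields integer days, so it is promoted to an interval.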


def window_frame_start(start: int | None) -> str:
    if isinstance(start, int):
        if start < 0:
            return "%d %s" % (abs(start), PRECEDING)  # noqa: UP031
        elif start == 0:
            return CURRENT_ROW
    elif start is None:
        return UNBOUNDED_PRECEDING
    raise ValueError(
        f"start argument must be a negative integer, zero, or None, but got '{start}'."
    )


def window_frame_end(end: int | None) -> str:
    if isinstance(end, int):
        if end == 0:
            return CURRENT_ROW
        elif end > 0:
            return "%d %s" % (end, FOLLOWING)  # noqa: UP031
    elif end is None:
        return UNBOUNDED_FOLLOWING
    raise ValueError(
        f"end argument must be a positive integer, zero, or None, but got '{end}'."
    )


def window_frame_rows_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """Return SQL for start and end points in an OVER clause window frame."""
    return window_frame_start(start), window_frame_end(end)
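# For illustration: window_frame_rows_start_end(-3, None) returns
# ("3 PRECEDING", "UNBOUNDED FOLLOWING"), i.e. a frame of
# ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING.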


def window_frame_range_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    start_, end_ = window_frame_rows_start_end(start, end)
    # PostgreSQL only supports UNBOUNDED with PRECEDING/FOLLOWING
    if (start and start < 0) or (end and end > 0):
        raise NotSupportedError(
            "PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING."
        )
    return start_, end_


def explain_query_prefix(format: str | None = None, **options: Any) -> str:
    extra = {}
    # Normalize options.
    if options:
        options = {
            name.upper(): "true" if value else "false"
            for name, value in options.items()
        }
        for valid_option in EXPLAIN_OPTIONS:
            value = options.pop(valid_option, None)
            if value is not None:
                extra[valid_option] = value
    if format:
        normalized_format = format.upper()
        if normalized_format not in SUPPORTED_EXPLAIN_FORMATS:
            msg = "{} is not a recognized format. Allowed formats: {}".format(
                normalized_format, ", ".join(sorted(SUPPORTED_EXPLAIN_FORMATS))
            )
            raise ValueError(msg)
        extra["FORMAT"] = format
    if options:
        raise ValueError(
            "Unknown options: {}".format(", ".join(sorted(options.keys())))
        )
    prefix = EXPLAIN_PREFIX
    if extra:
        prefix += " ({})".format(", ".join("{} {}".format(*i) for i in extra.items()))
    return prefix
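# For illustration: explain_query_prefix("json", analyze=True) returns something like
# "EXPLAIN (ANALYZE true, FORMAT json)"; option order follows the EXPLAIN_OPTIONS scan,
# and the format keyword keeps the caller's casing.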


def on_conflict_suffix_sql(
    fields: list[Field],
    on_conflict: OnConflict | None,
    update_fields: Iterable[str],
    unique_fields: Iterable[str],
) -> str:
    if on_conflict == OnConflict.IGNORE:
        return "ON CONFLICT DO NOTHING"
    if on_conflict == OnConflict.UPDATE:
        return "ON CONFLICT({}) DO UPDATE SET {}".format(
            ", ".join(map(quote_name, unique_fields)),
            ", ".join(
                [
                    f"{field} = EXCLUDED.{field}"
                    for field in map(quote_name, update_fields)
                ]
            ),
        )
    return ""
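# For illustration: on_conflict_suffix_sql([], OnConflict.UPDATE, ["name"], ["id"])
# returns 'ON CONFLICT("id") DO UPDATE SET "name" = EXCLUDED."name"', and the
# IGNORE variant returns "ON CONFLICT DO NOTHING".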