1"""
  2PostgreSQL-specific SQL generation functions.
  3
  4All functions in this module are stateless - they don't depend on connection state.
  5"""
  6
  7from __future__ import annotations
  8
  9import datetime
 10import ipaddress
 11import json
 12from collections.abc import Callable, Iterable
 13from functools import lru_cache, partial
 14from typing import TYPE_CHECKING, Any
 15
 16from psycopg.types import numeric
 17from psycopg.types.json import Jsonb
 18
 19from plain.models.constants import OnConflict
 20from plain.models.db import NotSupportedError
 21from plain.models.postgres.utils import split_tzname_delta
 22from plain.utils import timezone
 23from plain.utils.regex_helper import _lazy_re_compile
 24
 25if TYPE_CHECKING:
 26    from plain.models.fields import Field
 27
 28
# Integer field safe ranges by internal_type. Values are the inclusive
# (min, max) bounds of the matching PostgreSQL column type (smallint/integer/bigint).
INTEGER_FIELD_RANGES: dict[str, tuple[int, int]] = {
    "SmallIntegerField": (-32768, 32767),
    "IntegerField": (-2147483648, 2147483647),
    "BigIntegerField": (-9223372036854775808, 9223372036854775807),
    "PositiveBigIntegerField": (0, 9223372036854775807),
    "PositiveSmallIntegerField": (0, 32767),
    "PositiveIntegerField": (0, 2147483647),
    "PrimaryKeyField": (-9223372036854775808, 9223372036854775807),
}

# Mapping of Field.get_internal_type() to the data type for Cast().
CAST_DATA_TYPES: dict[str, str] = {
    "PrimaryKeyField": "bigint",
}

# CharField data type when max_length isn't provided.
CAST_CHAR_FIELD_WITHOUT_MAX_LENGTH: str | None = "varchar"

# Start and end points for window expressions (SQL window-frame keywords
# consumed by window_frame_start()/window_frame_end()).
PRECEDING: str = "PRECEDING"
FOLLOWING: str = "FOLLOWING"
UNBOUNDED_PRECEDING: str = "UNBOUNDED " + PRECEDING
UNBOUNDED_FOLLOWING: str = "UNBOUNDED " + FOLLOWING
CURRENT_ROW: str = "CURRENT ROW"

# Prefix for EXPLAIN queries.
EXPLAIN_PREFIX: str = "EXPLAIN"
# Boolean EXPLAIN options accepted by explain_query_prefix().
EXPLAIN_OPTIONS: frozenset[str] = frozenset(
    [
        "ANALYZE",
        "BUFFERS",
        "COSTS",
        "SETTINGS",
        "SUMMARY",
        "TIMING",
        "VERBOSE",
        "WAL",
    ]
)
# Values are compared uppercased; see explain_query_prefix().
SUPPORTED_EXPLAIN_FORMATS: set[str] = {"JSON", "TEXT", "XML", "YAML"}

# PostgreSQL integer type mapping for psycopg. Values wrap Python ints,
# presumably so psycopg binds the correctly sized integer type — confirm
# against psycopg.types.numeric docs.
INTEGERFIELD_TYPE_MAP: dict[str, type] = {
    "SmallIntegerField": numeric.Int2,
    "IntegerField": numeric.Int4,
    "BigIntegerField": numeric.Int8,
    "PositiveSmallIntegerField": numeric.Int2,
    "PositiveIntegerField": numeric.Int4,
    "PositiveBigIntegerField": numeric.Int8,
}

# Maximum length of an identifier (63 by default in PostgreSQL).
MAX_NAME_LENGTH: int = 63

# Value to use during INSERT to specify that a field should use its default value.
PK_DEFAULT_VALUE: str = "DEFAULT"

# SQL clause to make a constraint "initially deferred" during CREATE TABLE.
DEFERRABLE_SQL: str = " DEFERRABLE INITIALLY DEFERRED"

# EXTRACT format validation pattern: uppercased lookup types must match this
# before being interpolated into SQL (see date_extract_sql()).
_EXTRACT_FORMAT_RE = _lazy_re_compile(r"[A-Z_]+")
 92
 93
 94# ##### Data type mappings (from constants.py) #####
 95
 96
 97def _get_varchar_column(data: dict[str, Any]) -> str:
 98    if data["max_length"] is None:
 99        return "varchar"
100    return "varchar({max_length})".format(**data)
101
102
# Maps Field objects to their associated PostgreSQL column types.
# Column-type strings can contain format strings interpolated against Field.__dict__.
# CharField maps to a callable because its type depends on max_length.
DATA_TYPES: dict[str, Any] = {
    "PrimaryKeyField": "bigint",
    "BinaryField": "bytea",
    "BooleanField": "boolean",
    "CharField": _get_varchar_column,
    "DateField": "date",
    "DateTimeField": "timestamp with time zone",
    "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
    "DurationField": "interval",
    "FloatField": "double precision",
    "IntegerField": "integer",
    "BigIntegerField": "bigint",
    "GenericIPAddressField": "inet",
    "JSONField": "jsonb",
    "PositiveBigIntegerField": "bigint",
    "PositiveIntegerField": "integer",
    "PositiveSmallIntegerField": "smallint",
    "SmallIntegerField": "smallint",
    "TextField": "text",
    "TimeField": "time",
    "UUIDField": "uuid",
}

# Check constraints for fields that need them. Interpolated with the
# quoted column name via %(column)s.
DATA_TYPE_CHECK_CONSTRAINTS: dict[str, str] = {
    "PositiveBigIntegerField": '"%(column)s" >= 0',
    "PositiveIntegerField": '"%(column)s" >= 0',
    "PositiveSmallIntegerField": '"%(column)s" >= 0',
}

# Suffix applied to column definitions (e.g., for identity columns).
DATA_TYPES_SUFFIX: dict[str, str] = {
    "PrimaryKeyField": "GENERATED BY DEFAULT AS IDENTITY",
}

# SQL operators for lookups. The case-insensitive variants rely on
# lookup_cast() wrapping the column side in UPPER() as well.
OPERATORS: dict[str, str] = {
    "exact": "= %s",
    "iexact": "= UPPER(%s)",
    "contains": "LIKE %s",
    "icontains": "LIKE UPPER(%s)",
    "regex": "~ %s",
    "iregex": "~* %s",
    "gt": "> %s",
    "gte": ">= %s",
    "lt": "< %s",
    "lte": "<= %s",
    "startswith": "LIKE %s",
    "endswith": "LIKE %s",
    "istartswith": "LIKE UPPER(%s)",
    "iendswith": "LIKE UPPER(%s)",
}

# SQL pattern for escaping special characters in LIKE clauses.
# Used when the right-hand side isn't a raw string (e.g., an expression).
PATTERN_ESC: str = (
    r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)

# Pattern operators for non-literal LIKE lookups. '%%' is a doubled percent,
# presumably because these templates later pass through %-interpolation.
PATTERN_OPS: dict[str, str] = {
    "contains": "LIKE '%%' || {} || '%%'",
    "icontains": "LIKE '%%' || UPPER({}) || '%%'",
    "startswith": "LIKE {} || '%%'",
    "istartswith": "LIKE UPPER({}) || '%%'",
    "endswith": "LIKE '%%' || {}",
    "iendswith": "LIKE '%%' || UPPER({})",
}
173
174
@lru_cache
def get_json_dumps(
    encoder: type[json.JSONEncoder] | None,
) -> Callable[..., str]:
    """Return a json.dumps callable for the encoder, caching one per class."""
    return json.dumps if encoder is None else partial(json.dumps, cls=encoder)
182
183
def quote_name(name: str) -> str:
    """
    Return a quoted version of the given table, index, or column name.
    Does not quote the given name if it's already been quoted.
    """
    already_quoted = name.startswith('"') and name.endswith('"')
    return name if already_quoted else f'"{name}"'
192
193
def date_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', or 'day', return the SQL that
    extracts a value from the given date field field_name.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
    # Special cases needing custom EXTRACT fields; DOW is 0=Sunday..6=Saturday
    # in PostgreSQL, shifted by one to yield 1=Sunday..7=Saturday.
    special_cases = {
        "week_day": f"EXTRACT(DOW FROM {sql}) + 1",
        "iso_week_day": f"EXTRACT(ISODOW FROM {sql})",
        "iso_year": f"EXTRACT(ISOYEAR FROM {sql})",
    }
    if lookup_type in special_cases:
        return special_cases[lookup_type], params

    upper = lookup_type.upper()
    # Validate before interpolating into SQL to avoid injection via lookup_type.
    if not _EXTRACT_FORMAT_RE.fullmatch(upper):
        raise ValueError(f"Invalid lookup type: {upper!r}")
    return f"EXTRACT({upper} FROM {sql})", params
214
215
def _prepare_tzname_delta(tzname: str) -> str:
    """Return tzname with any UTC-offset suffix's sign flipped.

    NOTE(review): the sign inversion appears intentional (POSIX-style TZ
    offsets use the opposite sign convention) — confirm against
    split_tzname_delta's contract.
    """
    name, sign, offset = split_tzname_delta(tzname)
    if not offset:
        return name
    flipped = "-" if sign == "+" else "+"
    return f"{name}{flipped}{offset}"
222
223
224def _convert_sql_to_tz(
225    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
226) -> tuple[str, list[Any] | tuple[Any, ...]]:
227    if tzname:
228        tzname_param = _prepare_tzname_delta(tzname)
229        return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
230    return sql, params
231
232
def date_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', or 'day', return the SQL that
    truncates the given date or datetime field field_name to a date object
    with only the given specificity.

    If `tzname` is provided, the given value is truncated in a specific timezone.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    # lookup_type is bound as a parameter, so it needs no validation here.
    return f"DATE_TRUNC(%s, {tz_sql})", (lookup_type, *tz_params)
249
250
def datetime_cast_date_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL to cast a datetime value to date value."""
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    return f"({tz_sql})::date", tz_params
257
258
def datetime_cast_time_sql(
    sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return the SQL to cast a datetime value to time value."""
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    return f"({tz_sql})::time", tz_params
265
266
def datetime_extract_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
    'second', return the SQL that extracts a value from the given
    datetime field field_name.
    """
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    if lookup_type != "second":
        return date_extract_sql(lookup_type, tz_sql, tz_params)
    # Truncate to whole seconds first so fractional seconds are dropped.
    return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {tz_sql}))", ("second", *tz_params)
283
284
def datetime_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
    'second', return the SQL that truncates the given datetime field
    field_name to a datetime object with only the given specificity.
    """
    # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    # lookup_type is bound as a parameter, so it needs no validation here.
    return f"DATE_TRUNC(%s, {tz_sql})", (lookup_type, *tz_params)
299
300
def time_extract_sql(
    lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """
    Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
    that extracts a value from the given time field field_name.
    """
    if lookup_type != "second":
        return date_extract_sql(lookup_type, sql, params)
    # Truncate to whole seconds first so fractional seconds are dropped.
    return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
312
313
def time_trunc_sql(
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, tuple[Any, ...]]:
    """
    Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
    that truncates the given time or datetime field field_name to a time
    object with only the given specificity.

    If `tzname` is provided, the given value is truncated in a specific timezone.
    """
    tz_sql, tz_params = _convert_sql_to_tz(sql, params, tzname)
    # Cast back to ::time since DATE_TRUNC returns a timestamp.
    return f"DATE_TRUNC(%s, {tz_sql})::time", (lookup_type, *tz_params)
329
330
def distinct_sql(
    fields: list[str], params: list[Any] | tuple[Any, ...]
) -> tuple[list[str], list[Any]]:
    """
    Return an SQL DISTINCT clause which removes duplicate rows from the
    result set. If any fields are given, only check the given fields for
    duplicates.
    """
    if not fields:
        return ["DISTINCT"], []
    # params arrives as a sequence of per-field param groups; flatten it.
    flat_params = [param for group in params for param in group]
    field_list = ", ".join(fields)
    return [f"DISTINCT ON ({field_list})"], flat_params
344
345
def for_update_sql(
    nowait: bool = False,
    skip_locked: bool = False,
    of: tuple[str, ...] = (),
    no_key: bool = False,
) -> str:
    """Return the FOR UPDATE SQL clause to lock rows for an update operation."""
    clause = "FOR NO KEY UPDATE" if no_key else "FOR UPDATE"
    if of:
        clause += " OF " + ", ".join(of)
    if nowait:
        clause += " NOWAIT"
    if skip_locked:
        clause += " SKIP LOCKED"
    return clause
359
360
def limit_offset_sql(low_mark: int | None, high_mark: int | None) -> str:
    """Return the LIMIT/OFFSET SQL clause for the given slice bounds.

    `low_mark` is the slice start (becomes OFFSET) and `high_mark` the slice
    end (LIMIT is the difference). Returns an empty string when neither
    clause is needed.
    """
    offset = low_mark or 0
    limit = high_mark - offset if high_mark is not None else None
    clauses = []
    # Use `is not None` rather than truthiness: a slice like [5:5] yields
    # limit == 0 and must emit "LIMIT 0" (no rows) instead of no clause
    # at all (all rows).
    if limit is not None:
        clauses.append(f"LIMIT {limit}")
    if offset:
        clauses.append(f"OFFSET {offset}")
    return " ".join(clauses)
376
377
def lookup_cast(lookup_type: str, internal_type: str | None = None) -> str:
    """
    Return the string to use in a query when performing lookups
    ("contains", "like", etc.). It should contain a '%s' placeholder for
    the column being searched against.
    """
    if lookup_type == "isnull" and internal_type in (
        "CharField",
        "EmailField",
        "TextField",
    ):
        return "%s::text"

    text_pattern_lookups = {
        "iexact",
        "contains",
        "icontains",
        "startswith",
        "istartswith",
        "endswith",
        "iendswith",
        "regex",
        "iregex",
    }
    lookup = "%s"
    # Cast text lookups to text to allow things like filter(x__contains=4);
    # inet columns need HOST() to compare as text.
    if lookup_type in text_pattern_lookups:
        if internal_type == "GenericIPAddressField":
            lookup = "HOST(%s)"
        else:
            lookup = "%s::text"

    # Use UPPER(x) for case-insensitive lookups; it's faster.
    if lookup_type in {"iexact", "icontains", "istartswith", "iendswith"}:
        lookup = f"UPPER({lookup})"

    return lookup
415
416
def return_insert_columns(fields: list[Field]) -> tuple[str, tuple[Any, ...]]:
    """Return the RETURNING clause SQL and params to append to an INSERT query."""
    if not fields:
        return "", ()
    qualified = ", ".join(
        f"{quote_name(field.model.model_options.db_table)}.{quote_name(field.column)}"
        for field in fields
    )
    return f"RETURNING {qualified}", ()
426
427
def bulk_insert_sql(fields: list[Field], placeholder_rows: list[list[str]]) -> str:
    """Return the SQL for bulk inserting rows (one parenthesized tuple per row)."""
    rows = ", ".join("({})".format(", ".join(row)) for row in placeholder_rows)
    return f"VALUES {rows}"
433
434
def regex_lookup(lookup_type: str) -> str:
    """
    Return the string to use in a query when performing regular expression
    lookups (using "regex" or "iregex").
    """
    # PostgreSQL uses ~ for regex and ~* for case-insensitive regex.
    return "%s ~ %s" if lookup_type == "regex" else "%s ~* %s"
444
445
def prep_for_like_query(x: str) -> str:
    """Escape LIKE metacharacters (backslash, %, _) in x for use as a literal."""
    # Single-pass translation; mapping backslash alongside % and _ avoids
    # the double-escaping that naive sequential replaces could cause.
    escapes = str.maketrans({"\\": "\\\\", "%": r"\%", "_": r"\_"})
    return str(x).translate(escapes)
449
450
def adapt_integerfield_value(
    value: int | Any | None, internal_type: str
) -> int | Any | None:
    """Wrap value in the psycopg sized-integer type matching internal_type.

    None and expression objects pass through untouched.
    """
    # Imported locally to avoid a circular import at module load time —
    # TODO(review): confirm that's why the import lives here.
    from plain.models.expressions import ResolvableExpression

    if value is None or isinstance(value, ResolvableExpression):
        return value
    wrapper = INTEGERFIELD_TYPE_MAP[internal_type]
    return wrapper(value)
459
460
def adapt_ipaddressfield_value(
    value: str | None,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
    """
    Transform a string representation of an IP address into the expected
    type for the backend driver.

    Falsy values (None, empty string) map to None.
    """
    return ipaddress.ip_address(value) if value else None
471
472
def adapt_json_value(value: Any, encoder: type[json.JSONEncoder] | None) -> Jsonb:
    """Wrap value for a jsonb parameter, serialized with the given encoder."""
    dumps = get_json_dumps(encoder)
    return Jsonb(value, dumps=dumps)
475
476
def year_lookup_bounds_for_date_field(
    value: int, iso_year: bool = False
) -> list[datetime.date]:
    """
    Return a two-elements list with the lower and upper bound to be used
    with a BETWEEN operator to query a DateField value using a year lookup.

    `value` is an int, containing the looked-up year.
    If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
    """
    if iso_year:
        lower = datetime.date.fromisocalendar(value, 1, 1)
        # Day before the next ISO year's first day.
        next_year_start = datetime.date.fromisocalendar(value + 1, 1, 1)
        upper = next_year_start - datetime.timedelta(days=1)
    else:
        lower = datetime.date(value, 1, 1)
        upper = datetime.date(value, 12, 31)
    return [lower, upper]
496
497
def year_lookup_bounds_for_datetime_field(
    value: int, iso_year: bool = False
) -> list[datetime.datetime]:
    """
    Return a two-elements list with the lower and upper bound to be used
    with a BETWEEN operator to query a DateTimeField value using a year lookup.

    `value` is an int, containing the looked-up year.
    If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
    """
    if iso_year:
        lower = datetime.datetime.fromisocalendar(value, 1, 1)
        # One microsecond before the next ISO year begins.
        upper = datetime.datetime.fromisocalendar(
            value + 1, 1, 1
        ) - datetime.timedelta(microseconds=1)
    else:
        lower = datetime.datetime(value, 1, 1)
        upper = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)

    # Make sure that datetimes are aware in the current timezone.
    tz = timezone.get_current_timezone()
    return [timezone.make_aware(lower, tz), timezone.make_aware(upper, tz)]
522
523
def combine_expression(connector: str, sub_expressions: list[str]) -> str:
    """
    Combine a list of subexpressions into a single expression, using
    the provided connecting operator.
    """
    return f" {connector} ".join(sub_expressions)
531
532
def subtract_temporals(
    internal_type: str,
    lhs: tuple[str, list[Any] | tuple[Any, ...]],
    rhs: tuple[str, list[Any] | tuple[Any, ...]],
) -> tuple[str, tuple[Any, ...]]:
    """Return SQL and params for lhs - rhs of the given temporal type."""
    lhs_sql, lhs_params = lhs
    rhs_sql, rhs_params = rhs
    combined_params = (*lhs_params, *rhs_params)
    if internal_type == "DateField":
        # date - date yields an integer day count; scale it into an interval.
        return f"(interval '1 day' * ({lhs_sql} - {rhs_sql}))", combined_params
    # Other temporal types subtract natively to an interval.
    return f"({lhs_sql} - {rhs_sql})", combined_params
545
546
def window_frame_start(start: int | None) -> str:
    """Return the frame-start SQL keyword(s) for a window frame bound.

    None -> unbounded; 0 -> current row; negative int -> N PRECEDING.
    Anything else raises ValueError.
    """
    if start is None:
        return UNBOUNDED_PRECEDING
    if isinstance(start, int):
        if start == 0:
            return CURRENT_ROW
        if start < 0:
            return f"{abs(start)} {PRECEDING}"
    raise ValueError(
        f"start argument must be a negative integer, zero, or None, but got '{start}'."
    )
558
559
def window_frame_end(end: int | None) -> str:
    """Return the frame-end SQL keyword(s) for a window frame bound.

    None -> unbounded; 0 -> current row; positive int -> N FOLLOWING.
    Anything else raises ValueError.
    """
    if end is None:
        return UNBOUNDED_FOLLOWING
    if isinstance(end, int):
        if end == 0:
            return CURRENT_ROW
        if end > 0:
            return f"{end} {FOLLOWING}"
    raise ValueError(
        f"end argument must be a positive integer, zero, or None, but got '{end}'."
    )
571
572
def window_frame_rows_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """Return SQL for start and end points in an OVER clause window frame."""
    frame_start = window_frame_start(start)
    frame_end = window_frame_end(end)
    return frame_start, frame_end
578
579
def window_frame_range_start_end(
    start: int | None = None, end: int | None = None
) -> tuple[str, str]:
    """Return RANGE frame bounds, rejecting offsets PostgreSQL can't handle."""
    bounds = window_frame_rows_start_end(start, end)
    # PostgreSQL only supports UNBOUNDED with PRECEDING/FOLLOWING in RANGE mode.
    has_offset = (start and start < 0) or (end and end > 0)
    if has_offset:
        raise NotSupportedError(
            "PostgreSQL only supports UNBOUNDED together with PRECEDING and FOLLOWING."
        )
    return bounds
590
591
def explain_query_prefix(format: str | None = None, **options: Any) -> str:
    """Return the EXPLAIN prefix SQL for a query.

    `format` is an output format name (case-insensitive; validated against
    SUPPORTED_EXPLAIN_FORMATS). Keyword options are boolean EXPLAIN flags
    (validated against EXPLAIN_OPTIONS). Raises ValueError on an unknown
    format or unknown options.

    Note: the `format` parameter name shadows the builtin but is part of
    the public interface, so it is kept.
    """
    extra: dict[str, str] = {}
    if options:
        # EXPLAIN options are boolean flags; normalize names to uppercase
        # and values to SQL booleans.
        options = {
            name.upper(): "true" if value else "false"
            for name, value in options.items()
        }
        for valid_option in EXPLAIN_OPTIONS:
            value = options.pop(valid_option, None)
            if value is not None:
                extra[valid_option] = value
    if format:
        normalized_format = format.upper()
        if normalized_format not in SUPPORTED_EXPLAIN_FORMATS:
            msg = "{} is not a recognized format. Allowed formats: {}".format(
                normalized_format, ", ".join(sorted(SUPPORTED_EXPLAIN_FORMATS))
            )
            raise ValueError(msg)
        # Emit the normalized form so the SQL matches what was validated
        # (previously the caller's original casing leaked through, e.g.
        # "FORMAT json" next to uppercased options).
        extra["FORMAT"] = normalized_format
    if options:
        # Anything left over was not a recognized EXPLAIN option.
        raise ValueError(
            "Unknown options: {}".format(", ".join(sorted(options.keys())))
        )
    prefix = EXPLAIN_PREFIX
    if extra:
        prefix += " ({})".format(", ".join("{} {}".format(*i) for i in extra.items()))
    return prefix
620
621
def on_conflict_suffix_sql(
    fields: list[Field],
    on_conflict: OnConflict | None,
    update_fields: Iterable[str],
    unique_fields: Iterable[str],
) -> str:
    """Return the ON CONFLICT suffix for an INSERT, or "" when not applicable."""
    if on_conflict == OnConflict.IGNORE:
        return "ON CONFLICT DO NOTHING"
    if on_conflict != OnConflict.UPDATE:
        return ""
    conflict_target = ", ".join(quote_name(name) for name in unique_fields)
    # Each updated column takes its value from the row that failed to insert.
    assignments = ", ".join(
        f"{column} = EXCLUDED.{column}" for column in map(quote_name, update_fields)
    )
    return f"ON CONFLICT({conflict_target}) DO UPDATE SET {assignments}"