1"""Database functions that do comparisons or type conversions."""
  2
  3from __future__ import annotations
  4
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.postgres.expressions import Func, Value
  8from plain.postgres.fields import Field, TextField
  9from plain.postgres.fields.json import JSONField
 10
 11if TYPE_CHECKING:
 12    from plain.postgres.connection import DatabaseConnection
 13    from plain.postgres.sql.compiler import SQLCompiler
 14
 15
 16class Cast(Func):
 17    """Coerce an expression to a new field type."""
 18
 19    function = "CAST"
 20    # PostgreSQL :: shortcut syntax is more readable than standard CAST().
 21    template = "(%(expressions)s)::%(db_type)s"
 22
 23    def __init__(self, expression: Any, output_field: Field) -> None:
 24        super().__init__(expression, output_field=output_field)
 25
 26    def as_sql(
 27        self,
 28        compiler: SQLCompiler,
 29        connection: DatabaseConnection,
 30        function: str | None = None,
 31        template: str | None = None,
 32        arg_joiner: str | None = None,
 33        **extra_context: Any,
 34    ) -> tuple[str, list[Any]]:
 35        extra_context["db_type"] = self.output_field.cast_db_type()
 36        return super().as_sql(
 37            compiler, connection, function, template, arg_joiner, **extra_context
 38        )
 39
 40
 41class Coalesce(Func):
 42    """Return, from left to right, the first non-null expression."""
 43
 44    function = "COALESCE"
 45
 46    def __init__(self, *expressions: Any, **extra: Any) -> None:
 47        if len(expressions) < 2:
 48            raise ValueError("Coalesce must take at least two expressions")
 49        super().__init__(*expressions, **extra)
 50
 51    @property
 52    def empty_result_set_value(self) -> Any:
 53        for expression in self.get_source_expressions():
 54            result = expression.empty_result_set_value
 55            if result is NotImplemented or result is not None:
 56                return result
 57        return None
 58
 59
 60class Greatest(Func):
 61    """
 62    Return the maximum expression.
 63
 64    If any expression is null the return value is database-specific:
 65    On PostgreSQL, the maximum not-null expression is returned.
 66    """
 67
 68    function = "GREATEST"
 69
 70    def __init__(self, *expressions: Any, **extra: Any) -> None:
 71        if len(expressions) < 2:
 72            raise ValueError("Greatest must take at least two expressions")
 73        super().__init__(*expressions, **extra)
 74
 75
 76class JSONObject(Func):
 77    # PostgreSQL uses JSONB_BUILD_OBJECT for JSON object construction.
 78    function = "JSONB_BUILD_OBJECT"
 79    output_field = JSONField()
 80
 81    def __init__(self, **fields: Any) -> None:
 82        expressions = []
 83        for key, value in fields.items():
 84            expressions.extend((Value(key), value))
 85        super().__init__(*expressions)
 86
 87    def as_sql(
 88        self,
 89        compiler: SQLCompiler,
 90        connection: DatabaseConnection,
 91        function: str | None = None,
 92        template: str | None = None,
 93        arg_joiner: str | None = None,
 94        **extra_context: Any,
 95    ) -> tuple[str, list[Any]]:
 96        # PostgreSQL requires keys to be cast to text.
 97        copy = self.copy()
 98        copy.set_source_expressions(
 99            [
100                Cast(expression, TextField()) if index % 2 == 0 else expression
101                for index, expression in enumerate(copy.get_source_expressions())
102            ]
103        )
104        return super(JSONObject, copy).as_sql(
105            compiler, connection, function, template, arg_joiner, **extra_context
106        )
107
108
class Least(Func):
    """
    Evaluate to the smallest of two or more expressions.

    On PostgreSQL, null arguments are skipped and the minimum of the
    non-null expressions is returned.
    """

    function = "LEAST"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        # Refuse to build LEAST with fewer than two arguments.
        minimum_arguments = 2
        if len(expressions) < minimum_arguments:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)
123
124
class NullIf(Func):
    """
    Wrap SQL ``NULLIF(value1, value2)``.

    In PostgreSQL this yields NULL when the two arguments are equal,
    otherwise the first argument.
    """

    # Declares exactly two source expressions.
    arity = 2
    function = "NULLIF"