1"""Database functions that do comparisons or type conversions."""
2
3from __future__ import annotations
4
5from typing import TYPE_CHECKING, Any
6
7from plain.postgres.expressions import Func, Value
8from plain.postgres.fields import Field, TextField
9from plain.postgres.fields.json import JSONField
10
11if TYPE_CHECKING:
12 from plain.postgres.connection import DatabaseConnection
13 from plain.postgres.sql.compiler import SQLCompiler
14
15
class Cast(Func):
    """
    Convert an expression to the database type of ``output_field``.

    The SQL is emitted with PostgreSQL's ``(expr)::type`` shorthand instead of
    the standard ``CAST(expr AS type)`` spelling.
    """

    # The template below does not reference %(function)s; the name is kept on
    # the class regardless (presumably for introspection/repr by the base class).
    function = "CAST"
    # PostgreSQL's :: operator reads more naturally than CAST(... AS ...).
    template = "(%(expressions)s)::%(db_type)s"

    def __init__(self, expression: Any, output_field: Field) -> None:
        super().__init__(expression, output_field=output_field)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the cast.

        The target type is taken from ``output_field.cast_db_type()`` and
        exposed to the template as ``db_type``. Any caller-supplied
        ``db_type`` is overridden.
        """
        context = {**extra_context, "db_type": self.output_field.cast_db_type()}
        return super().as_sql(
            compiler,
            connection,
            function,
            template,
            arg_joiner,
            **context,
        )
39
40
class Coalesce(Func):
    """
    Evaluate to the first argument that is not NULL, scanning left to right.

    Compiles to SQL ``COALESCE(expr1, expr2, ...)``.
    """

    function = "COALESCE"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        # Fewer than two arguments is rejected up front.
        if len(expressions) <= 1:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self) -> Any:
        """
        Combine the source expressions' ``empty_result_set_value`` COALESCE-style.

        The first source whose value is not None wins. ``NotImplemented`` is
        not None, so it also stops the scan and is propagated unchanged.
        Returns None when every source yields None. Sources after the first
        hit are never inspected.
        """
        candidates = (
            expression.empty_result_set_value
            for expression in self.get_source_expressions()
        )
        return next((value for value in candidates if value is not None), None)
58
59
class Greatest(Func):
    """
    Return the largest of two or more expressions (SQL ``GREATEST``).

    PostgreSQL ignores NULL arguments here: the result is the largest of the
    non-NULL expressions (NULL only when every argument is NULL).
    """

    function = "GREATEST"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        # Fewer than two arguments is rejected up front.
        if len(expressions) <= 1:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)
74
75
class JSONObject(Func):
    """
    Build a JSON object from keyword arguments (``key=expression``).

    Compiles to ``JSONB_BUILD_OBJECT(key1, value1, key2, value2, ...)``, where
    each key is the keyword name wrapped in ``Value`` and each value is the
    corresponding expression. The result is typed as a ``JSONField``.
    """

    function = "JSONB_BUILD_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields: Any) -> None:
        # Flatten {k1: v1, k2: v2} into [Value(k1), v1, Value(k2), v2, ...],
        # preserving keyword order.
        flattened = [
            item for key, value in fields.items() for item in (Value(key), value)
        ]
        super().__init__(*flattened)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile the call with every key argument cast to text.

        PostgreSQL requires the keys to be cast to text. Presumably this is
        because JSONB_BUILD_OBJECT takes variadic "any" arguments, so a bound
        key parameter has no inferable type; confirm if relying on it. The
        rewrite happens on a copy so this expression is left unmodified.
        """
        clone = self.copy()
        # Keys sit at even positions, values at odd positions.
        clone.set_source_expressions(
            [
                expression if position % 2 else Cast(expression, TextField())
                for position, expression in enumerate(clone.get_source_expressions())
            ]
        )
        # Skip JSONObject.as_sql on the clone to avoid re-entering this method.
        return super(JSONObject, clone).as_sql(
            compiler, connection, function, template, arg_joiner, **extra_context
        )
107
108
class Least(Func):
    """
    Return the smallest of two or more expressions (SQL ``LEAST``).

    PostgreSQL ignores NULL arguments here: the result is the smallest of the
    non-NULL expressions (NULL only when every argument is NULL).
    """

    function = "LEAST"

    def __init__(self, *expressions: Any, **extra: Any) -> None:
        # Fewer than two arguments is rejected up front.
        if len(expressions) <= 1:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)
123
124
class NullIf(Func):
    """
    SQL ``NULLIF(expr1, expr2)``: NULL when the two are equal, else ``expr1``.

    ``arity`` pins the argument count to two.
    """

    arity = 2
    function = "NULLIF"