1"""Database functions that do comparisons or type conversions."""
2from plain.models.db import NotSupportedError
3from plain.models.expressions import Func, Value
4from plain.models.fields import TextField
5from plain.models.fields.json import JSONField
6from plain.utils.regex_helper import _lazy_re_compile
7
8
class Cast(Func):
    """
    Convert ``expression`` to the database type of ``output_field``.

    The generic rendering is ``CAST(<expression> AS <type>)``; the
    ``as_<vendor>`` methods below substitute backend-specific SQL where the
    plain CAST form is not used.
    """

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        # ``db_type`` always comes from the output field, replacing any value
        # a caller may have placed in extra_context.
        extra_context.update(db_type=self.output_field.cast_db_type(connection))
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        target_type = self.output_field.db_type(connection)
        if target_type == "date":
            return super().as_sql(
                compiler,
                connection,
                template="date(%(expressions)s)",
                **extra_context,
            )
        if target_type not in {"datetime", "time"}:
            return self.as_sql(compiler, connection, **extra_context)
        # datetime/time don't keep fractional seconds, so render via
        # strftime() instead. ``%%s`` is expected to become the ``%s``
        # placeholder for the format string once the template is
        # interpolated, which is why the format string goes first in params.
        sql, params = super().as_sql(
            compiler,
            connection,
            template="strftime(%%s, %(expressions)s)",
            **extra_context,
        )
        if target_type == "time":
            pattern = "%H:%M:%f"
        else:
            pattern = "%Y-%m-%d %H:%M:%f"
        params.insert(0, pattern)
        return sql, params

    def as_mysql(self, compiler, connection, **extra_context):
        internal_type = self.output_field.get_internal_type()
        if internal_type == "FloatField":
            # MySQL doesn't support an explicit cast to float; coerce by
            # adding 0.0 instead.
            template = "(%(expressions)s + 0.0)"
        elif internal_type == "JSONField" and connection.mysql_is_mariadb:
            # MariaDB doesn't support an explicit cast to JSON.
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        else:
            # No override: keep the default rendering.
            template = None
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # Use PostgreSQL's ``expr::type`` shorthand rather than CAST(); the
        # parentheses keep compound expressions grouped correctly.
        pg_template = "(%(expressions)s)::%(db_type)s"
        return self.as_sql(compiler, connection, template=pg_template, **extra_context)
61
62
class Coalesce(Func):
    """Evaluate to the first of the given expressions (left to right) that is not NULL."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self):
        # Walk the arguments in order and report the first empty-result value
        # that isn't None. NotImplemented counts as "not None" and so also
        # ends the search; if every argument yields None, so does this.
        candidates = (
            expr.empty_result_set_value for expr in self.get_source_expressions()
        )
        return next((value for value in candidates if value is not None), None)
80
81
class Collate(Func):
    """
    Apply a collation to an expression: ``<expression> COLLATE <collation>``.

    The collation name is validated against a conservative identifier pattern
    and is passed through ``connection.ops.quote_name`` before it is placed in
    the SQL.

    Raises ValueError if ``collation`` is not a non-empty string consisting
    only of word characters and hyphens.
    """

    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    # Anchored with \A/\Z rather than ^/$: ``$`` also matches just before a
    # trailing newline, which let names such as "C\n" pass validation.
    collation_re = _lazy_re_compile(r"\A[\w\-]+\Z")

    def __init__(self, expression, collation):
        # Check the type first so that non-string input is reported as
        # ValueError instead of a TypeError from the regex engine. The
        # pattern requires at least one character, so "" is rejected too.
        if not (isinstance(collation, str) and self.collation_re.match(collation)):
            # f-string rather than "%" so tuple input can't break formatting.
            raise ValueError(f"Invalid collation name: {collation!r}.")
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        # A caller-supplied ``collation`` in extra_context takes precedence.
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)
98
99
class Greatest(Func):
    """
    Return the largest of the given expressions.

    NULL handling is database-specific: PostgreSQL skips NULL arguments and
    returns the largest non-NULL one, whereas MySQL, Oracle, and SQLite
    return NULL if any argument is NULL.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render as the MAX(...) function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
119
120
class JSONObject(Func):
    """Build a JSON object whose keys are the keyword names and whose values are the given expressions."""

    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        # Source expressions alternate: Value(key), value, Value(key), value, ...
        flattened = [
            item for key, value in fields.items() for item in (Value(key), value)
        ]
        super().__init__(*flattened)

    def as_sql(self, compiler, connection, **extra_context):
        if connection.features.has_json_object_function:
            return super().as_sql(compiler, connection, **extra_context)
        raise NotSupportedError(
            "JSONObject() is not supported on this database backend."
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Work on a copy instead of modifying self, wrapping every key (the
        # even positions) in a cast to text; values are left as-is.
        clone = self.copy()
        sources = clone.get_source_expressions()
        clone.set_source_expressions(
            [
                expr if position % 2 else Cast(expr, TextField())
                for position, expr in enumerate(sources)
            ]
        )
        # Skip JSONObject.as_sql (and its feature check) and render with
        # PostgreSQL's JSONB_BUILD_OBJECT instead of JSON_OBJECT.
        return super(JSONObject, clone).as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_OBJECT",
            **extra_context,
        )
152
153
class Least(Func):
    """
    Return the smallest of the given expressions.

    NULL handling is database-specific: PostgreSQL skips NULL arguments and
    returns the smallest non-NULL one, whereas MySQL, Oracle, and SQLite
    return NULL if any argument is NULL.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render as the MIN(...) function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
173
174
class NullIf(Func):
    """
    SQL ``NULLIF(expr1, expr2)``; takes exactly two expressions.

    Per standard SQL, this yields NULL when the two arguments are equal and
    the first argument otherwise.
    """

    arity = 2
    function = "NULLIF"