Plain is headed towards 1.0! Subscribe for development updates →

  1"""Database functions that do comparisons or type conversions."""
  2from plain.models.db import NotSupportedError
  3from plain.models.expressions import Func, Value
  4from plain.models.fields import TextField
  5from plain.models.fields.json import JSONField
  6from plain.utils.regex_helper import _lazy_re_compile
  7
  8
  9class Cast(Func):
 10    """Coerce an expression to a new field type."""
 11
 12    function = "CAST"
 13    template = "%(function)s(%(expressions)s AS %(db_type)s)"
 14
 15    def __init__(self, expression, output_field):
 16        super().__init__(expression, output_field=output_field)
 17
 18    def as_sql(self, compiler, connection, **extra_context):
 19        extra_context["db_type"] = self.output_field.cast_db_type(connection)
 20        return super().as_sql(compiler, connection, **extra_context)
 21
 22    def as_sqlite(self, compiler, connection, **extra_context):
 23        db_type = self.output_field.db_type(connection)
 24        if db_type in {"datetime", "time"}:
 25            # Use strftime as datetime/time don't keep fractional seconds.
 26            template = "strftime(%%s, %(expressions)s)"
 27            sql, params = super().as_sql(
 28                compiler, connection, template=template, **extra_context
 29            )
 30            format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
 31            params.insert(0, format_string)
 32            return sql, params
 33        elif db_type == "date":
 34            template = "date(%(expressions)s)"
 35            return super().as_sql(
 36                compiler, connection, template=template, **extra_context
 37            )
 38        return self.as_sql(compiler, connection, **extra_context)
 39
 40    def as_mysql(self, compiler, connection, **extra_context):
 41        template = None
 42        output_type = self.output_field.get_internal_type()
 43        # MySQL doesn't support explicit cast to float.
 44        if output_type == "FloatField":
 45            template = "(%(expressions)s + 0.0)"
 46        # MariaDB doesn't support explicit cast to JSON.
 47        elif output_type == "JSONField" and connection.mysql_is_mariadb:
 48            template = "JSON_EXTRACT(%(expressions)s, '$')"
 49        return self.as_sql(compiler, connection, template=template, **extra_context)
 50
 51    def as_postgresql(self, compiler, connection, **extra_context):
 52        # CAST would be valid too, but the :: shortcut syntax is more readable.
 53        # 'expressions' is wrapped in parentheses in case it's a complex
 54        # expression.
 55        return self.as_sql(
 56            compiler,
 57            connection,
 58            template="(%(expressions)s)::%(db_type)s",
 59            **extra_context,
 60        )
 61
 62
 63class Coalesce(Func):
 64    """Return, from left to right, the first non-null expression."""
 65
 66    function = "COALESCE"
 67
 68    def __init__(self, *expressions, **extra):
 69        if len(expressions) < 2:
 70            raise ValueError("Coalesce must take at least two expressions")
 71        super().__init__(*expressions, **extra)
 72
 73    @property
 74    def empty_result_set_value(self):
 75        for expression in self.get_source_expressions():
 76            result = expression.empty_result_set_value
 77            if result is NotImplemented or result is not None:
 78                return result
 79        return None
 80
 81
 82class Collate(Func):
 83    function = "COLLATE"
 84    template = "%(expressions)s %(function)s %(collation)s"
 85    # Inspired from
 86    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
 87    collation_re = _lazy_re_compile(r"^[\w\-]+$")
 88
 89    def __init__(self, expression, collation):
 90        if not (collation and self.collation_re.match(collation)):
 91            raise ValueError("Invalid collation name: %r." % collation)
 92        self.collation = collation
 93        super().__init__(expression)
 94
 95    def as_sql(self, compiler, connection, **extra_context):
 96        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
 97        return super().as_sql(compiler, connection, **extra_context)
 98
 99
class Greatest(Func):
    """
    Return the largest of two or more expressions.

    Null handling is database-specific: PostgreSQL skips nulls and returns the
    largest non-null value, whereas MySQL, Oracle, and SQLite return null when
    any argument is null.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with SQLite's multi-argument MAX() function."""
        return super().as_sqlite(
            compiler,
            connection,
            function="MAX",
            **extra_context,
        )
119
120
class JSONObject(Func):
    """
    Build a JSON object from keyword arguments: ``JSONObject(key=expr, ...)``.

    Source expressions are stored as a flat, alternating sequence of
    ``Value(key), expr`` pairs, in the order the keywords were given.
    """

    function = "JSON_OBJECT"
    output_field = JSONField()

    def __init__(self, **fields):
        flattened = [
            part for key, value in fields.items() for part in (Value(key), value)
        ]
        super().__init__(*flattened)

    def as_sql(self, compiler, connection, **extra_context):
        supported = connection.features.has_json_object_function
        if supported:
            return super().as_sql(compiler, connection, **extra_context)
        raise NotSupportedError(
            "JSONObject() is not supported on this database backend."
        )

    def as_postgresql(self, compiler, connection, **extra_context):
        # Work on a clone so this expression's sources are left untouched.
        clone = self.copy()
        rewritten = []
        for position, source in enumerate(clone.get_source_expressions()):
            # Even positions hold the keys; cast those to text and pass the
            # values (odd positions) through unchanged.
            if position % 2 == 0:
                rewritten.append(Cast(source, TextField()))
            else:
                rewritten.append(source)
        clone.set_source_expressions(rewritten)
        # Dispatch to the parent class's as_sql on the clone, bypassing
        # JSONObject.as_sql and therefore its has_json_object_function check.
        return super(JSONObject, clone).as_sql(
            compiler,
            connection,
            function="JSONB_BUILD_OBJECT",
            **extra_context,
        )
152
153
class Least(Func):
    """
    Return the smallest of two or more expressions.

    Null handling is database-specific: PostgreSQL skips nulls and returns the
    smallest non-null value, whereas MySQL, Oracle, and SQLite return null
    when any argument is null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) <= 1:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with SQLite's multi-argument MIN() function."""
        return super().as_sqlite(
            compiler,
            connection,
            function="MIN",
            **extra_context,
        )
173
174
175class NullIf(Func):
176    function = "NULLIF"
177    arity = 2