Plain is headed towards 1.0! Subscribe for development updates →

 1from __future__ import annotations
 2
 3from typing import TYPE_CHECKING, Any
 4
 5from plain.models.exceptions import FieldError, FullResultSet
 6from plain.models.expressions import Col
 7from plain.models.sql import compiler
 8
 9if TYPE_CHECKING:
10    from plain.models.sql.compiler import SQLCompiler as BaseSQLCompiler
11
12
class SQLCompiler(compiler.SQLCompiler):
    """Base query compiler for this backend.

    Only the rendering of a subquery used as an ``IN`` condition is overridden.
    """

    def as_subquery_condition(
        self, alias: str, columns: list[str], compiler: BaseSQLCompiler
    ) -> tuple[str, tuple[Any, ...]]:
        """Render this query as the right-hand side of an ``IN`` test.

        Args:
            alias: Table alias that qualifies each column on the left-hand side.
            columns: Column names compared against the subquery's output.
            compiler: Compiler whose ``quote_name_unless_alias`` is used to
                quote ``alias``.

        Returns:
            A ``(sql, params)`` pair where ``sql`` has the form
            ``(<alias>.<col>, ...) IN (<subquery>)`` and ``params`` are the
            parameters produced by compiling this query.
        """
        # Look up both quoting callables before compiling, as the alias is
        # quoted by the outer compiler and the columns by this connection.
        quote_alias = compiler.quote_name_unless_alias
        quote_column = self.connection.ops.quote_name
        subquery_sql, subquery_params = self.as_sql()
        qualified_columns = [
            f"{quote_alias(alias)}.{quote_column(name)}" for name in columns
        ]
        left_hand_side = ", ".join(qualified_columns)
        return f"({left_hand_side}) IN ({subquery_sql})", subquery_params
27
28
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
    """Insert compiler: the base insert compiler mixed with ``SQLCompiler``.

    Adds no behavior of its own.
    """
31
32
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
    """Delete compiler that uses ``DELETE <alias> FROM ...`` for multi-table deletes."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> tuple[str, tuple[Any, ...]]:
        """Build the DELETE statement and its parameters.

        When several tables are involved, the non-standard
        ``DELETE <alias> FROM <tables> [WHERE ...]`` form is emitted instead of
        the default implementation's SQL, because MySQL/MariaDB produce a more
        efficient query plan for it than for a subquery.

        Args:
            with_limits: Not used by this override.
            with_col_aliases: Not used by this override.

        Returns:
            A ``(sql, params)`` pair.
        """
        where, having, qualify = self.query.where.split_having_qualify(
            must_group_by=self.query.group_by is not None
        )

        # Fall back to the default SQL for single-table deletes, and whenever
        # the filter involves aggregates or window functions: DELETE FROM
        # allows neither GROUP BY/HAVING nor the subquery wrapping needed to
        # emulate QUALIFY.
        if self.single_alias or having or qualify:
            return super().as_sql()

        initial_alias = self.query.get_initial_alias()
        assert initial_alias is not None, "DELETE query must have an initial alias"

        target = self.quote_name_unless_alias(initial_alias)
        from_sql, params = self.get_from_clause()
        pieces = [f"DELETE {target} FROM", *from_sql]

        if where is not None:
            try:
                where_sql, where_params = self.compile(where)
            except FullResultSet:
                # No WHERE clause is emitted when compiling the filter raises
                # FullResultSet.
                return " ".join(pieces), tuple(params)
            pieces.append(f"WHERE {where_sql}")
            params.extend(where_params)

        return " ".join(pieces), tuple(params)
63
64
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
    """Update compiler that appends ``ORDER BY`` to UPDATE statements when possible."""

    def as_sql(
        self, with_limits: bool = True, with_col_aliases: bool = False
    ) -> tuple[str, tuple[Any, ...]]:
        """Build the UPDATE statement, adding the query's ordering if usable.

        MySQL and MariaDB support ``UPDATE ... ORDER BY``. The ordering is
        appended only when every ordering column belongs to the model's own
        table and resolving the ordering does not raise ``FieldError``;
        otherwise the statement from the parent compiler is returned as is.

        Args:
            with_limits: Not used by this override.
            with_col_aliases: Not used by this override.

        Returns:
            A ``(sql, params)`` pair.
        """
        update_query, update_params = super().as_sql()
        if not self.query.order_by:
            return update_query, update_params

        assert self.query.model is not None, "UPDATE requires a model"
        own_table = self.query.model.model_options.db_table

        ordering_sql: list[str] = []
        ordering_params: list[Any] = []
        try:
            for resolved, (sql, params, _) in self.get_order_by():
                expression = resolved.expression
                if isinstance(expression, Col) and expression.alias != own_table:
                    # Joined fields cannot be used in the ORDER BY clause, so
                    # the whole ordering is dropped.
                    return update_query, update_params
                ordering_sql.append(sql)
                ordering_params.extend(params)
        except FieldError:
            # Ignore ordering if it contains annotations, because they're
            # removed in .update() and cannot be resolved.
            return update_query, update_params

        update_query += " ORDER BY " + ", ".join(ordering_sql)
        update_params += tuple(ordering_params)
        return update_query, update_params
94
95
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    """Aggregate compiler: the base aggregate compiler mixed with ``SQLCompiler``.

    Adds no behavior of its own.
    """