1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain.models.exceptions import FieldError, FullResultSet
6from plain.models.expressions import Col
7from plain.models.sql import compiler
8
9if TYPE_CHECKING:
10 from plain.models.sql.compiler import SQLCompiler as BaseSQLCompiler
11
12
class SQLCompiler(compiler.SQLCompiler):
    """MySQL/MariaDB flavour of the generic SQL compiler."""

    def as_subquery_condition(
        self, alias: str, columns: list[str], compiler: BaseSQLCompiler
    ) -> tuple[str, tuple[Any, ...]]:
        """Render ``(alias.col1, alias.col2, ...) IN (<this query>)``.

        The left-hand column references are quoted with the *outer*
        compiler's alias-aware quoting for ``alias``.  Each column name is
        quoted with the connection's ``quote_name``.  The right-hand side is
        this compiler's own SQL.

        Returns:
            A ``(sql, params)`` pair. The params are exactly those produced
            by ``self.as_sql()``.
        """
        subquery_sql, subquery_params = self.as_sql()
        quote_alias = compiler.quote_name_unless_alias
        quote_column = self.connection.ops.quote_name
        column_refs = ", ".join(
            f"{quote_alias(alias)}.{quote_column(column)}" for column in columns
        )
        return f"({column_refs}) IN ({subquery_sql})", subquery_params
27
28
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
    """INSERT compiler mixing the generic insert logic with this module's
    MySQL ``SQLCompiler`` overrides; it defines nothing of its own."""
31
32
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
    def as_sql(self) -> tuple[str, tuple[Any, ...]]:
        """Compile a DELETE using MySQL's multi-table ``DELETE ... FROM`` form.

        When several tables are involved, this emits the non-standard
        ``DELETE <alias> FROM <from clause> [WHERE ...]`` syntax instead of
        the subquery-based SQL of the default implementation.  MySQL/MariaDB
        produce a more efficient query plan for it.  Single-alias deletes,
        and deletes filtering on aggregates or window functions, fall back
        to the base implementation.

        Returns:
            A ``(sql, params)`` pair with the params as a tuple.
        """
        must_group_by = self.query.group_by is not None
        where, having, qualify = self.query.where.split_having_qualify(
            must_group_by=must_group_by
        )
        # DELETE ... FROM allows no GROUP BY/HAVING clauses.  It also cannot
        # take the subquery wrapping needed to emulate QUALIFY.  Those cases,
        # and simple single-table deletes, use the default SQL.
        use_default_sql = self.single_alias or having or qualify
        if use_default_sql:
            return super().as_sql()

        # Resolve the target alias before building the FROM clause.  This
        # preserves the original evaluation order.
        target = self.quote_name_unless_alias(self.query.get_initial_alias())
        from_parts, params = self.get_from_clause()
        sql_parts = [f"DELETE {target} FROM", *from_parts]
        try:
            where_sql, where_params = self.compile(where)
        except FullResultSet:
            # No WHERE clause is emitted when the filter compiles to
            # FullResultSet (presumably "matches every row").
            pass
        else:
            sql_parts.append(f"WHERE {where_sql}")
            params.extend(where_params)
        return " ".join(sql_parts), tuple(params)
60
61
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
    def as_sql(self) -> tuple[str, tuple[Any, ...]]:
        """Compile the UPDATE statement, appending ``ORDER BY`` when possible.

        MySQL and MariaDB support the non-standard ``UPDATE ... ORDER BY``
        syntax, so the query's ordering is appended to the SQL produced by
        the base compiler.  The ordering is dropped entirely, and the plain
        UPDATE is returned, in two cases: any ordering term refers to a
        joined table, or resolving the ordering raises ``FieldError``.

        Returns:
            A ``(sql, params)`` pair. When the base compiler produced an
            empty statement, it is returned untouched so callers still see
            "nothing to execute" rather than a bare ``" ORDER BY ..."``.
        """
        update_query, update_params = super().as_sql()
        # Appending to an empty statement would turn "no query" into the
        # invalid fragment " ORDER BY ...".  NOTE(review): the base compiler
        # presumably returns "" when there are no values to update; confirm
        # against compiler.SQLUpdateCompiler.as_sql.
        if not update_query or not self.query.order_by:
            return update_query, update_params

        order_by_sql: list[str] = []
        order_by_params: list[Any] = []
        db_table = self.query.get_model_meta().db_table
        try:
            for resolved, (sql, params, _) in self.get_order_by():
                if (
                    isinstance(resolved.expression, Col)
                    and resolved.expression.alias != db_table
                ):
                    # Ignore ordering if it contains joined fields, because
                    # they cannot be used in the ORDER BY clause.
                    raise FieldError
                order_by_sql.append(sql)
                order_by_params.extend(params)
        except FieldError:
            # Ignore ordering if it contains annotations, because they're
            # removed in .update() and cannot be resolved.
            return update_query, update_params

        # Avoid emitting a dangling " ORDER BY " with no terms.
        if order_by_sql:
            update_query += " ORDER BY " + ", ".join(order_by_sql)
            update_params += tuple(order_by_params)
        return update_query, update_params
88
89
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    """Aggregate compiler mixing the generic aggregate logic with this
    module's MySQL ``SQLCompiler`` overrides; it defines nothing of its own."""