1"""
2Query subclasses which provide extra functionality beyond simple data retrieval.
3"""
4
5from __future__ import annotations
6
7from typing import Any
8
9from plain.models.exceptions import FieldError
10from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
11from plain.models.sql.query import Query
12
# Public query classes exported by this module.
__all__ = [
    "DeleteQuery",
    "UpdateQuery",
    "InsertQuery",
    "AggregateQuery",
]
14
15
class DeleteQuery(Query):
    """A DELETE SQL query."""

    compiler = "SQLDeleteCompiler"

    def do_query(self, table: str, where: Any) -> int:
        """
        Run a DELETE against ``table`` filtered by ``where``.

        The alias map is narrowed to the single entry for ``table`` before
        the statement is compiled and executed.

        Returns:
            The cursor's ``rowcount``, or 0 when no cursor is returned.
        """
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        # The cursor is used as a context manager so it is released after
        # its row count has been read.
        with cursor:
            return cursor.rowcount

    def delete_batch(self, id_list: list[Any]) -> int:
        """
        Delete every row whose ``id`` is contained in ``id_list``.

        The ids are consumed in slices of ``GET_ITERATOR_CHUNK_SIZE`` so that
        each physical DELETE carries a bounded ``IN (...)`` list; a large
        ``id_list`` therefore results in several statements.

        Returns:
            The total number of rows deleted across all statements.
        """
        deleted = 0
        id_field = self.get_model_meta().get_field("id")
        step = GET_ITERATOR_CHUNK_SIZE
        for start in range(0, len(id_list), step):
            chunk = id_list[start : start + step]
            # Each chunk starts from an empty WHERE clause.
            self.clear_where()
            self.add_filter(f"{id_field.attname}__in", chunk)
            deleted += self.do_query(self.model.model_options.db_table, self.where)
        return deleted
48
49
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    compiler = "SQLUpdateCompiler"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self) -> None:
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples targeting this query's own table.
        self.values: list[tuple[Any, Any, Any]] = []
        # Optional mapping of ancestor model -> ids used to restrict the
        # queries built by get_related_updates(); nothing in this class
        # assigns it, so it is presumably populated by the caller.
        self.related_ids: dict[Any, list[Any]] | None = None
        # Pending (field, None, value) triples grouped by the model that
        # actually defines the field.
        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}

    def clone(self) -> UpdateQuery:
        """
        Return a copy of this query.

        The per-model lists inside ``related_updates`` and the ``values``
        list are copied too, so appending to them on the clone (e.g. via
        add_related_update() or add_update_fields()) cannot leak into the
        original query, and vice versa.
        """
        obj = super().clone()
        obj.values = list(self.values)
        obj.related_updates = {
            model: list(updates) for model, updates in self.related_updates.items()
        }
        return obj  # type: ignore[return-value]

    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
        """
        Apply ``values`` to every row whose ``id`` is in ``id_list``.

        The ids are processed in slices of ``GET_ITERATOR_CHUNK_SIZE``, and
        one UPDATE statement is executed per slice.
        """
        self.add_update_values(values)
        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter(
                "id__in", id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values: dict[str, Any]) -> None:
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        Fields defined on this query's model are recorded through
        add_update_fields(); fields defined on a different model are queued
        through add_related_update().

        Raises:
            FieldError: if a name refers to a many-to-many field or to a
                reverse relation (an auto-created, non-concrete field).
        """
        values_seq = []
        for name, val in values.items():
            field = self.get_model_meta().get_field(name)
            # A field is "direct" when it is declared on the model rather than
            # auto-created as the reverse side of a relation. Reverse
            # relations are auto-created and non-concrete, so they are the
            # only fields this excludes.
            direct = not field.auto_created or field.concrete
            model = field.model
            if not direct or (field.is_relation and field.many_to_many):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            if model is not self.get_model_meta().model:
                # The field lives on another model; it needs its own UPDATE.
                self.add_related_update(model, field, val)
                continue
            values_seq.append((field, model, val))
        self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, val in values_seq:
            if hasattr(val, "resolve_expression"):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, val))

    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
        """
        Queue a (field, None, value) update for the ancestor ``model``.

        Updates are coalesced so that only one update query per ancestor is run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self) -> list[UpdateQuery]:
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.

        When ``related_ids`` is set, each query is restricted to the ids
        listed for its model.
        """
        result = []
        for model, updates in self.related_updates.items():
            query = UpdateQuery(model)
            # Copy so the returned query does not share its list with ours.
            query.values = list(updates)
            if self.related_ids is not None:
                query.add_filter("id__in", self.related_ids[model])
            result.append(query)
        return result
142
143
class InsertQuery(Query):
    """
    An INSERT SQL query.

    The rows to insert are supplied through insert_values(). The conflict
    handling options (``on_conflict``, ``update_fields``, ``unique_fields``)
    are only stored here; presumably they are consumed by the
    ``SQLInsertCompiler`` — confirm against the compiler.
    """

    compiler = "SQLInsertCompiler"

    def __init__(
        self,
        *args: Any,
        on_conflict: str | None = None,
        update_fields: list[Any] | None = None,
        unique_fields: list[Any] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Target fields and the objects providing their values; replaced by
        # insert_values().
        self.fields: list[Any] = []
        self.objs: list[Any] = []
        # Declared up front (matching insert_values()'s default) so the
        # attribute exists even if insert_values() has not been called yet.
        self.raw: bool = False
        self.on_conflict = on_conflict
        # None is normalized to an empty list for both options.
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """
        Record the fields to insert, the objects supplying their values, and
        the ``raw`` flag, replacing any previously recorded values.
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
168
169
class AggregateQuery(Query):
    """
    A query that uses another query (``inner_query``) as the source of its
    FROM clause, selecting only the requested elements from it.

    ``compiler`` names the SQL compiler class used for this query.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model: Any, inner_query: Any) -> None:
        # inner_query is assigned before the base initializer runs; keep this
        # order in case Query.__init__ (not visible here) relies on it —
        # TODO confirm.
        self.inner_query = inner_query
        super().__init__(model)