1"""
2Query subclasses which provide extra functionality beyond simple data retrieval.
3"""
4
5from __future__ import annotations
6
7from typing import TYPE_CHECKING, Any
8
9from plain.models.exceptions import FieldError
10from plain.models.expressions import ResolvableExpression
11from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
12from plain.models.sql.query import Query
13
14if TYPE_CHECKING:
15 from plain.models.fields import Field
16
17
class DeleteQuery(Query):
    """A DELETE SQL query."""

    def do_query(self, table: str, where: Any) -> int:
        """
        Run a DELETE against ``table`` filtered by the ``where`` node.

        Returns the cursor's ``rowcount``, or 0 when the compiler hands back
        no cursor.
        """
        # Restrict the alias map to the single table being deleted from.
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        with cursor:
            return cursor.rowcount

    def delete_batch(self, id_list: list[Any]) -> int:
        """
        Delete every row whose id appears in ``id_list``.

        The ids are split into chunks of GET_ITERATOR_CHUNK_SIZE and one
        DELETE is executed per chunk. Returns the summed row counts.
        """
        assert self.model is not None, "DELETE requires a model"
        id_field = self.model._model_meta.get_forward_field("id")
        in_lookup = f"{id_field.attname}__in"
        db_table = self.model.model_options.db_table
        total_deleted = 0
        for start in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = id_list[start : start + GET_ITERATOR_CHUNK_SIZE]
            # Each chunk gets a fresh WHERE clause.
            self.clear_where()
            self.add_filter(in_lookup, chunk)
            total_deleted += self.do_query(db_table, self.where)
        return total_deleted
50
51
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self) -> None:
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples for this query's own table.
        self.values: list[tuple[Any, Any, Any]] = []
        # Ancestor model -> ids to filter on. Left as None here; assigned
        # outside this class (NOTE(review): presumably by the compiler).
        self.related_ids: dict[Any, list[Any]] | None = None
        # Ancestor model -> pending (field, None, value) triples.
        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}

    def clone(self) -> UpdateQuery:
        """
        Return a copy of this query.

        The pending-update lists are copied rather than shared, so adding
        updates to the clone never mutates the original (and vice versa).
        """
        obj = super().clone()
        obj.values = list(self.values)
        obj.related_updates = {
            model: list(pending) for model, pending in self.related_updates.items()
        }
        return obj

    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
        """
        Apply ``values`` to every row whose id is in ``id_list``.

        The ids are split into chunks of GET_ITERATOR_CHUNK_SIZE, and one
        UPDATE is executed per chunk.
        """
        self.add_update_values(values)
        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter(
                "id__in", id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values: dict[str, Any]) -> None:
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        Fields whose model is not this query's model are queued as related
        updates. Raises FieldError for reverse relations (auto-created,
        non-concrete fields) and for many-to-many fields.
        """
        assert self.model is not None, "UPDATE requires model metadata"
        # Kept function-local as in the original code (NOTE(review):
        # presumably to avoid an import cycle); imported once, not per field.
        from plain.models.fields.related import ManyToManyField

        meta = self.model._model_meta
        values_seq = []
        for name, val in values.items():
            field = meta.get_field(name)
            # Reverse relations are auto-created but not concrete; every
            # other field is "direct". The previous expression ended in
            # `or not field.concrete`, which made it always True, so reverse
            # relations were never rejected here.
            direct = not (field.auto_created and not field.concrete)
            if not direct or isinstance(field, ManyToManyField):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            model = field.model
            if model is not meta.model:
                self.add_related_update(model, field, val)
                continue
            values_seq.append((field, model, val))
        self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, val in values_seq:
            if isinstance(val, ResolvableExpression):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, val))

    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
        """
        Queue a (field, None, value) update for the ancestor ``model``.

        Updates are grouped per model, so get_related_updates() builds only
        one update query per ancestor.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self) -> list[UpdateQuery]:
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query will have the same filtering conditions as
        the current query but will only update a single table.
        """
        if not self.related_updates:
            return []
        result = []
        for model, values in self.related_updates.items():
            query = UpdateQuery(model)
            # Copy so that changes to the returned query's values cannot
            # mutate this query's pending related updates.
            query.values = list(values)
            if self.related_ids is not None:
                query.add_filter("id__in", self.related_ids[model])
            result.append(query)
        return result
147
148
class InsertQuery(Query):
    """
    An INSERT SQL query.

    SQL is produced through ``get_compiler().as_sql()``, which returns a list
    of SQL statements. ``__str__`` and ``sql_with_params`` are deliberately
    unsupported and raise NotImplementedError.
    """

    def __init__(
        self,
        *args: Any,
        on_conflict: str | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.fields: list[Field] = []
        self.objs: list[Any] = []
        # Defaulted here so the attribute exists even before insert_values()
        # is called; insert_values() overwrites it.
        self.raw: bool = False
        self.on_conflict = on_conflict
        self.update_fields: list[Field] = update_fields or []
        self.unique_fields: list[Field] = unique_fields or []

    def __str__(self) -> str:
        raise NotImplementedError(
            "InsertQuery does not support __str__(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def sql_with_params(self) -> Any:
        raise NotImplementedError(
            "InsertQuery does not support sql_with_params(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """Set the fields to insert, the objects supplying values, and ``raw``."""
        self.fields = fields
        self.objs = objs
        self.raw = raw
183
184
class AggregateQuery(Query):
    """
    Wrap another query and use it as the source of the FROM clause,
    selecting only the elements in the provided list.
    """

    def __init__(self, model: Any, inner_query: Any) -> None:
        # The wrapped query is stored before the base initializer runs.
        self.inner_query = inner_query
        super().__init__(model)