1"""
2Query subclasses which provide extra functionality beyond simple data retrieval.
3"""
4
5from plain.exceptions import FieldError
6from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
7from plain.models.sql.query import Query
8
# Names exported by ``from <this module> import *``.
__all__ = [
    "DeleteQuery",
    "UpdateQuery",
    "InsertQuery",
    "AggregateQuery",
]
10
11
class DeleteQuery(Query):
    """A DELETE SQL query."""

    compiler = "SQLDeleteCompiler"

    def do_query(self, table, where):
        """
        Run the delete against ``table`` using ``where`` as the filter.

        The alias map is narrowed to the single entry for ``table`` and the
        query's ``where`` is replaced before the compiler is executed.
        Return the cursor's ``rowcount``, or 0 if the compiler returned no
        cursor.
        """
        self.alias_map = {table: self.alias_map[table]}
        self.where = where
        cursor = self.get_compiler().execute_sql(CURSOR)
        if not cursor:
            return 0
        # The cursor is used as a context manager so it is released after the
        # row count has been read.
        with cursor:
            return cursor.rowcount

    def delete_batch(self, pk_list):
        """
        Delete every object whose primary key appears in pk_list.

        The keys are processed in slices of GET_ITERATOR_CHUNK_SIZE, so a
        long pk_list results in several physical queries. Return the summed
        row count of all of them.
        """
        pk_field = self.get_meta().pk
        total = 0
        for start in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
            chunk = pk_list[start : start + GET_ITERATOR_CHUNK_SIZE]
            # Each slice gets a fresh WHERE clause restricted to its own keys.
            self.clear_where()
            self.add_filter(f"{pk_field.attname}__in", chunk)
            total = total + self.do_query(self.get_meta().db_table, self.where)
        return total
44
45
class UpdateQuery(Query):
    """An UPDATE SQL query."""

    compiler = "SQLUpdateCompiler"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._setup_query()

    def _setup_query(self):
        """
        Run on initialization and at the end of chaining. Any attributes that
        would normally be set in __init__() should go here instead.
        """
        # (field, model, value) triples for the table of this query.
        self.values = []
        # When not None, indexed by model in get_related_updates() to restrict
        # ancestor updates to specific primary keys. Not assigned anywhere in
        # this class other than here.
        self.related_ids = None
        # Maps an ancestor model to its list of (field, None, value) triples.
        self.related_updates = {}

    def clone(self):
        """
        Return a copy of this query whose pending ancestor updates are
        independent of the original's.
        """
        obj = super().clone()
        # Copy the per-model lists as well as the dict: add_related_update()
        # appends to those lists in place, so sharing them would let a change
        # on the clone leak into the original (and vice versa).
        obj.related_updates = {
            model: list(updates) for model, updates in self.related_updates.items()
        }
        return obj

    def update_batch(self, pk_list, values):
        """
        Apply ``values`` to the rows whose primary key is in ``pk_list``.

        The keys are processed in slices of GET_ITERATOR_CHUNK_SIZE, and one
        query is executed per slice. Nothing is returned.
        """
        self.add_update_values(values)
        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
            self.clear_where()
            self.add_filter(
                "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
            )
            self.get_compiler().execute_sql(NO_RESULTS)

    def add_update_values(self, values):
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        Raise FieldError if a name refers to an auto-created non-concrete
        field or to a many-to-many relation. Values for fields that belong to
        a different concrete model are queued via add_related_update().
        """
        values_seq = []
        for name, val in values.items():
            field = self.get_meta().get_field(name)
            # The previous expression
            #   not (auto_created and not concrete) or not concrete
            # was always true, so this rejection could never trigger. Test the
            # intended condition directly instead.
            not_direct = field.auto_created and not field.concrete
            model = field.model._meta.concrete_model
            if not_direct or (field.is_relation and field.many_to_many):
                raise FieldError(
                    f"Cannot update model field {field!r} (only non-relations and "
                    "foreign keys permitted)."
                )
            if model is not self.get_meta().concrete_model:
                # The field lives on another concrete model; it is updated by
                # a separate query built in get_related_updates().
                self.add_related_update(model, field, val)
                continue
            values_seq.append((field, model, val))
        return self.add_update_fields(values_seq)

    def add_update_fields(self, values_seq):
        """
        Append a sequence of (field, model, value) triples to the internal list
        that will be used to generate the UPDATE query. Might be more usefully
        called add_update_targets() to hint at the extra information here.
        """
        for field, model, val in values_seq:
            if hasattr(val, "resolve_expression"):
                # Resolve expressions here so that annotations are no longer needed
                val = val.resolve_expression(self, allow_joins=False, for_save=True)
            self.values.append((field, model, val))

    def add_related_update(self, model, field, value):
        """
        Add (field, value) to an update query for an ancestor model.

        Updates are coalesced so that only one update query per ancestor is
        run.
        """
        self.related_updates.setdefault(model, []).append((field, None, value))

    def get_related_updates(self):
        """
        Return a list of query objects: one for each update required to an
        ancestor model. Each query only updates a single table and, when
        ``related_ids`` is set, is filtered to the primary keys recorded for
        that model.
        """
        if not self.related_updates:
            return []
        result = []
        for model, values in self.related_updates.items():
            query = UpdateQuery(model)
            query.values = values
            if self.related_ids is not None:
                query.add_filter("pk__in", self.related_ids[model])
            result.append(query)
        return result
138
139
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
    ):
        """
        Create an insert query with nothing to insert yet.

        ``on_conflict`` is stored as given. ``update_fields`` and
        ``unique_fields`` are stored as given, or as an empty list when
        omitted or falsy. Remaining arguments go to the base initializer.
        """
        super().__init__(*args, **kwargs)
        self.fields = []
        self.objs = []
        # Give ``raw`` a default alongside ``fields`` and ``objs`` so that it
        # can be read before insert_values() has been called.
        self.raw = False
        self.on_conflict = on_conflict
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(self, fields, objs, raw=False):
        """Record the fields, the objects and the ``raw`` flag for this insert."""
        self.fields = fields
        self.objs = objs
        self.raw = raw
157
158
class AggregateQuery(Query):
    """
    A query built on top of another query: the inner query is used as the
    FROM clause, and only the elements in the provided list are selected.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model, inner_query):
        """Keep a reference to ``inner_query``, then initialize for ``model``."""
        # The inner query is attached before the base initializer runs.
        self.inner_query = inner_query
        super().__init__(model)