Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Query subclasses which provide extra functionality beyond simple data retrieval.
  3"""
  4
  5from __future__ import annotations
  6
  7from typing import Any
  8
  9from plain.models.exceptions import FieldError
 10from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
 11from plain.models.sql.query import Query
 12
 13__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
 14
 15
 16class DeleteQuery(Query):
 17    """A DELETE SQL query."""
 18
 19    compiler = "SQLDeleteCompiler"
 20
 21    def do_query(self, table: str, where: Any) -> int:
 22        self.alias_map = {table: self.alias_map[table]}
 23        self.where = where
 24        cursor = self.get_compiler().execute_sql(CURSOR)
 25        if cursor:
 26            with cursor:
 27                return cursor.rowcount
 28        return 0
 29
 30    def delete_batch(self, id_list: list[Any]) -> int:
 31        """
 32        Set up and execute delete queries for all the objects in id_list.
 33
 34        More than one physical query may be executed if there are a
 35        lot of values in id_list.
 36        """
 37        # number of objects deleted
 38        num_deleted = 0
 39        field = self.get_model_meta().get_field("id")
 40        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
 41            self.clear_where()
 42            self.add_filter(
 43                f"{field.attname}__in",
 44                id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
 45            )
 46            num_deleted += self.do_query(self.model.model_options.db_table, self.where)
 47        return num_deleted
 48
 49
 50class UpdateQuery(Query):
 51    """An UPDATE SQL query."""
 52
 53    compiler = "SQLUpdateCompiler"
 54
 55    def __init__(self, *args: Any, **kwargs: Any) -> None:
 56        super().__init__(*args, **kwargs)
 57        self._setup_query()
 58
 59    def _setup_query(self) -> None:
 60        """
 61        Run on initialization and at the end of chaining. Any attributes that
 62        would normally be set in __init__() should go here instead.
 63        """
 64        self.values: list[tuple[Any, Any, Any]] = []
 65        self.related_ids: dict[Any, list[Any]] | None = None
 66        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}
 67
 68    def clone(self) -> UpdateQuery:
 69        obj = super().clone()
 70        obj.related_updates = self.related_updates.copy()
 71        return obj  # type: ignore[return-value]
 72
 73    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
 74        self.add_update_values(values)
 75        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
 76            self.clear_where()
 77            self.add_filter(
 78                "id__in", id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
 79            )
 80            self.get_compiler().execute_sql(NO_RESULTS)
 81
 82    def add_update_values(self, values: dict[str, Any]) -> list[tuple[Any, Any, Any]]:
 83        """
 84        Convert a dictionary of field name to value mappings into an update
 85        query. This is the entry point for the public update() method on
 86        querysets.
 87        """
 88        values_seq = []
 89        for name, val in values.items():
 90            field = self.get_model_meta().get_field(name)
 91            direct = (
 92                not (field.auto_created and not field.concrete) or not field.concrete
 93            )
 94            model = field.model
 95            if not direct or (field.is_relation and field.many_to_many):
 96                raise FieldError(
 97                    f"Cannot update model field {field!r} (only non-relations and "
 98                    "foreign keys permitted)."
 99                )
100            if model is not self.get_model_meta().model:
101                self.add_related_update(model, field, val)
102                continue
103            values_seq.append((field, model, val))
104        return self.add_update_fields(values_seq)
105
106    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
107        """
108        Append a sequence of (field, model, value) triples to the internal list
109        that will be used to generate the UPDATE query. Might be more usefully
110        called add_update_targets() to hint at the extra information here.
111        """
112        for field, model, val in values_seq:
113            if hasattr(val, "resolve_expression"):
114                # Resolve expressions here so that annotations are no longer needed
115                val = val.resolve_expression(self, allow_joins=False, for_save=True)
116            self.values.append((field, model, val))
117
118    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
119        """
120        Add (name, value) to an update query for an ancestor model.
121
122        Update are coalesced so that only one update query per ancestor is run.
123        """
124        self.related_updates.setdefault(model, []).append((field, None, value))
125
126    def get_related_updates(self) -> list[UpdateQuery]:
127        """
128        Return a list of query objects: one for each update required to an
129        ancestor model. Each query will have the same filtering conditions as
130        the current query but will only update a single table.
131        """
132        if not self.related_updates:
133            return []
134        result = []
135        for model, values in self.related_updates.items():
136            query = UpdateQuery(model)
137            query.values = values
138            if self.related_ids is not None:
139                query.add_filter("id__in", self.related_ids[model])
140            result.append(query)
141        return result
142
143
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self,
        *args: Any,
        on_conflict: str | None = None,
        update_fields: list[Any] | None = None,
        unique_fields: list[Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Create an empty INSERT query.

        ``on_conflict``, ``update_fields`` and ``unique_fields`` are stored as
        given, with ``None`` field lists normalized to empty lists. Remaining
        positional/keyword arguments are forwarded to ``Query.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.fields: list[Any] = []
        self.objs: list[Any] = []
        # Declared up front so the attribute always exists (and is typed),
        # even before insert_values() has been called.
        self.raw: bool = False
        self.on_conflict = on_conflict
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """Set the fields, the objects to insert, and the ``raw`` flag."""
        self.fields = fields
        self.objs = objs
        self.raw = raw
168
169
class AggregateQuery(Query):
    """
    Query that reads from another query.

    The wrapped ``inner_query`` is used as the source in the FROM clause, and
    only the elements in the provided list are selected from it.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model: Any, inner_query: Any) -> None:
        """Wrap ``inner_query`` and initialize the base query for ``model``."""
        # Assigned before Query.__init__ runs; keep this ordering.
        self.inner_query = inner_query
        super().__init__(model)