1"""
  2Query subclasses which provide extra functionality beyond simple data retrieval.
  3"""
  4
  5from __future__ import annotations
  6
  7from typing import TYPE_CHECKING, Any
  8
  9from plain.models.exceptions import FieldError
 10from plain.models.expressions import ResolvableExpression
 11from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
 12from plain.models.sql.query import Query
 13
 14if TYPE_CHECKING:
 15    from plain.models.fields import Field
 16
 17
 18class DeleteQuery(Query):
 19    """A DELETE SQL query."""
 20
 21    def do_query(self, table: str, where: Any) -> int:
 22        self.alias_map = {table: self.alias_map[table]}
 23        self.where = where
 24        cursor = self.get_compiler().execute_sql(CURSOR)
 25        if cursor:
 26            with cursor:
 27                return cursor.rowcount
 28        return 0
 29
 30    def delete_batch(self, id_list: list[Any]) -> int:
 31        """
 32        Set up and execute delete queries for all the objects in id_list.
 33
 34        More than one physical query may be executed if there are a
 35        lot of values in id_list.
 36        """
 37        # number of objects deleted
 38        num_deleted = 0
 39        assert self.model is not None, "DELETE requires a model"
 40        meta = self.model._model_meta
 41        field = meta.get_forward_field("id")
 42        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
 43            self.clear_where()
 44            self.add_filter(
 45                f"{field.attname}__in",
 46                id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
 47            )
 48            num_deleted += self.do_query(self.model.model_options.db_table, self.where)
 49        return num_deleted
 50
 51
 52class UpdateQuery(Query):
 53    """An UPDATE SQL query."""
 54
 55    def __init__(self, *args: Any, **kwargs: Any) -> None:
 56        super().__init__(*args, **kwargs)
 57        self._setup_query()
 58
 59    def _setup_query(self) -> None:
 60        """
 61        Run on initialization and at the end of chaining. Any attributes that
 62        would normally be set in __init__() should go here instead.
 63        """
 64        self.values: list[tuple[Any, Any, Any]] = []
 65        self.related_ids: dict[Any, list[Any]] | None = None
 66        self.related_updates: dict[Any, list[tuple[Any, Any, Any]]] = {}
 67
 68    def clone(self) -> UpdateQuery:
 69        obj = super().clone()
 70        obj.related_updates = self.related_updates.copy()
 71        return obj
 72
 73    def update_batch(self, id_list: list[Any], values: dict[str, Any]) -> None:
 74        self.add_update_values(values)
 75        for offset in range(0, len(id_list), GET_ITERATOR_CHUNK_SIZE):
 76            self.clear_where()
 77            self.add_filter(
 78                "id__in", id_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
 79            )
 80            self.get_compiler().execute_sql(NO_RESULTS)
 81
 82    def add_update_values(self, values: dict[str, Any]) -> None:
 83        """
 84        Convert a dictionary of field name to value mappings into an update
 85        query. This is the entry point for the public update() method on
 86        querysets.
 87        """
 88
 89        assert self.model is not None, "UPDATE requires model metadata"
 90        meta = self.model._model_meta
 91        values_seq = []
 92        for name, val in values.items():
 93            field = meta.get_field(name)
 94            direct = (
 95                not (field.auto_created and not field.concrete) or not field.concrete
 96            )
 97            model = field.model
 98            from plain.models.fields.related import ManyToManyField
 99
100            if not direct or isinstance(field, ManyToManyField):
101                raise FieldError(
102                    f"Cannot update model field {field!r} (only non-relations and "
103                    "foreign keys permitted)."
104                )
105            if model is not meta.model:
106                self.add_related_update(model, field, val)
107                continue
108            values_seq.append((field, model, val))
109        return self.add_update_fields(values_seq)
110
111    def add_update_fields(self, values_seq: list[tuple[Any, Any, Any]]) -> None:
112        """
113        Append a sequence of (field, model, value) triples to the internal list
114        that will be used to generate the UPDATE query. Might be more usefully
115        called add_update_targets() to hint at the extra information here.
116        """
117        for field, model, val in values_seq:
118            if isinstance(val, ResolvableExpression):
119                # Resolve expressions here so that annotations are no longer needed
120                val = val.resolve_expression(self, allow_joins=False, for_save=True)
121            self.values.append((field, model, val))
122
123    def add_related_update(self, model: Any, field: Any, value: Any) -> None:
124        """
125        Add (name, value) to an update query for an ancestor model.
126
127        Update are coalesced so that only one update query per ancestor is run.
128        """
129        self.related_updates.setdefault(model, []).append((field, None, value))
130
131    def get_related_updates(self) -> list[UpdateQuery]:
132        """
133        Return a list of query objects: one for each update required to an
134        ancestor model. Each query will have the same filtering conditions as
135        the current query but will only update a single table.
136        """
137        if not self.related_updates:
138            return []
139        result = []
140        for model, values in self.related_updates.items():
141            query = UpdateQuery(model)
142            query.values = values
143            if self.related_ids is not None:
144                query.add_filter("id__in", self.related_ids[model])
145            result.append(query)
146        return result
147
148
class InsertQuery(Query):
    """
    An INSERT SQL query.

    Its compiler's ``as_sql()`` returns a list of SQL statements, so the
    single-statement helpers ``__str__()`` and ``sql_with_params()`` are
    deliberately unsupported.
    """

    def __init__(
        self,
        *args: Any,
        on_conflict: str | None = None,
        update_fields: list[Field] | None = None,
        unique_fields: list[Field] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.fields: list[Field] = []
        self.objs: list[Any] = []
        # Defined up front (matching insert_values()'s default) so the
        # attribute exists even before insert_values() has been called.
        self.raw: bool = False
        self.on_conflict = on_conflict
        self.update_fields: list[Field] = update_fields or []
        self.unique_fields: list[Field] = unique_fields or []

    def __str__(self) -> str:
        raise NotImplementedError(
            "InsertQuery does not support __str__(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def sql_with_params(self) -> Any:
        raise NotImplementedError(
            "InsertQuery does not support sql_with_params(). "
            "Use get_compiler().as_sql() which returns a list of SQL statements."
        )

    def insert_values(
        self, fields: list[Any], objs: list[Any], raw: bool = False
    ) -> None:
        """
        Set the fields and objects to insert.

        ``raw`` is stored on the query; how it is interpreted is up to the
        compiler (not visible here).
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
183
184
class AggregateQuery(Query):
    """
    Take another query as a parameter to the FROM clause.

    ``inner_query`` is stored on the instance; presumably the compiler
    renders it as a subquery in FROM (the compiler is not visible here —
    confirm there).
    """

    def __init__(self, model: Any, inner_query: Any) -> None:
        # Assigned before Query.__init__() runs. Keep this order: the reason is
        # not visible here, but base initialization may read ``inner_query``.
        self.inner_query = inner_query
        super().__init__(model)