Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Query subclasses which provide extra functionality beyond simple data retrieval.
  3"""
  4
  5from plain.exceptions import FieldError
  6from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
  7from plain.models.sql.query import Query
  8
  9__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
 10
 11
 12class DeleteQuery(Query):
 13    """A DELETE SQL query."""
 14
 15    compiler = "SQLDeleteCompiler"
 16
 17    def do_query(self, table, where):
 18        self.alias_map = {table: self.alias_map[table]}
 19        self.where = where
 20        cursor = self.get_compiler().execute_sql(CURSOR)
 21        if cursor:
 22            with cursor:
 23                return cursor.rowcount
 24        return 0
 25
 26    def delete_batch(self, pk_list):
 27        """
 28        Set up and execute delete queries for all the objects in pk_list.
 29
 30        More than one physical query may be executed if there are a
 31        lot of values in pk_list.
 32        """
 33        # number of objects deleted
 34        num_deleted = 0
 35        field = self.get_meta().pk
 36        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
 37            self.clear_where()
 38            self.add_filter(
 39                f"{field.attname}__in",
 40                pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
 41            )
 42            num_deleted += self.do_query(self.get_meta().db_table, self.where)
 43        return num_deleted
 44
 45
 46class UpdateQuery(Query):
 47    """An UPDATE SQL query."""
 48
 49    compiler = "SQLUpdateCompiler"
 50
 51    def __init__(self, *args, **kwargs):
 52        super().__init__(*args, **kwargs)
 53        self._setup_query()
 54
 55    def _setup_query(self):
 56        """
 57        Run on initialization and at the end of chaining. Any attributes that
 58        would normally be set in __init__() should go here instead.
 59        """
 60        self.values = []
 61        self.related_ids = None
 62        self.related_updates = {}
 63
 64    def clone(self):
 65        obj = super().clone()
 66        obj.related_updates = self.related_updates.copy()
 67        return obj
 68
 69    def update_batch(self, pk_list, values):
 70        self.add_update_values(values)
 71        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
 72            self.clear_where()
 73            self.add_filter(
 74                "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
 75            )
 76            self.get_compiler().execute_sql(NO_RESULTS)
 77
 78    def add_update_values(self, values):
 79        """
 80        Convert a dictionary of field name to value mappings into an update
 81        query. This is the entry point for the public update() method on
 82        querysets.
 83        """
 84        values_seq = []
 85        for name, val in values.items():
 86            field = self.get_meta().get_field(name)
 87            direct = (
 88                not (field.auto_created and not field.concrete) or not field.concrete
 89            )
 90            model = field.model._meta.concrete_model
 91            if not direct or (field.is_relation and field.many_to_many):
 92                raise FieldError(
 93                    f"Cannot update model field {field!r} (only non-relations and "
 94                    "foreign keys permitted)."
 95                )
 96            if model is not self.get_meta().concrete_model:
 97                self.add_related_update(model, field, val)
 98                continue
 99            values_seq.append((field, model, val))
100        return self.add_update_fields(values_seq)
101
102    def add_update_fields(self, values_seq):
103        """
104        Append a sequence of (field, model, value) triples to the internal list
105        that will be used to generate the UPDATE query. Might be more usefully
106        called add_update_targets() to hint at the extra information here.
107        """
108        for field, model, val in values_seq:
109            if hasattr(val, "resolve_expression"):
110                # Resolve expressions here so that annotations are no longer needed
111                val = val.resolve_expression(self, allow_joins=False, for_save=True)
112            self.values.append((field, model, val))
113
114    def add_related_update(self, model, field, value):
115        """
116        Add (name, value) to an update query for an ancestor model.
117
118        Update are coalesced so that only one update query per ancestor is run.
119        """
120        self.related_updates.setdefault(model, []).append((field, None, value))
121
122    def get_related_updates(self):
123        """
124        Return a list of query objects: one for each update required to an
125        ancestor model. Each query will have the same filtering conditions as
126        the current query but will only update a single table.
127        """
128        if not self.related_updates:
129            return []
130        result = []
131        for model, values in self.related_updates.items():
132            query = UpdateQuery(model)
133            query.values = values
134            if self.related_ids is not None:
135                query.add_filter("pk__in", self.related_ids[model])
136            result.append(query)
137        return result
138
139
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
    ):
        """
        Create an empty INSERT query.

        ``on_conflict``, ``update_fields`` and ``unique_fields`` are stored for
        the compiler; presumably they drive conflict handling (upserts) —
        confirm against SQLInsertCompiler. Missing field lists become ``[]``.
        """
        super().__init__(*args, **kwargs)
        self.fields = []
        self.objs = []
        # Default here so ``raw`` exists even before insert_values() is called;
        # previously it was only ever assigned inside insert_values().
        self.raw = False
        self.on_conflict = on_conflict
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(self, fields, objs, raw=False):
        """
        Record the fields to insert, the objects supplying the rows, and the
        ``raw`` flag, all of which are read later by the compiler.
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
157
158
class AggregateQuery(Query):
    """
    Select only the requested elements from another query, which is placed
    in the FROM clause as a subquery.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model, inner_query):
        # ``inner_query`` is bound before the base initializer runs; this order
        # is kept in case Query.__init__ (or something it calls) reads it.
        self.inner_query = inner_query
        super().__init__(model)