Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Query subclasses which provide extra functionality beyond simple data retrieval.
  3"""
  4
  5from plain.exceptions import FieldError
  6from plain.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
  7from plain.models.sql.query import Query
  8
  9__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
 10
 11
 12class DeleteQuery(Query):
 13    """A DELETE SQL query."""
 14
 15    compiler = "SQLDeleteCompiler"
 16
 17    def do_query(self, table, where, using):
 18        self.alias_map = {table: self.alias_map[table]}
 19        self.where = where
 20        cursor = self.get_compiler(using).execute_sql(CURSOR)
 21        if cursor:
 22            with cursor:
 23                return cursor.rowcount
 24        return 0
 25
 26    def delete_batch(self, pk_list, using):
 27        """
 28        Set up and execute delete queries for all the objects in pk_list.
 29
 30        More than one physical query may be executed if there are a
 31        lot of values in pk_list.
 32        """
 33        # number of objects deleted
 34        num_deleted = 0
 35        field = self.get_meta().pk
 36        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
 37            self.clear_where()
 38            self.add_filter(
 39                f"{field.attname}__in",
 40                pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
 41            )
 42            num_deleted += self.do_query(
 43                self.get_meta().db_table, self.where, using=using
 44            )
 45        return num_deleted
 46
 47
 48class UpdateQuery(Query):
 49    """An UPDATE SQL query."""
 50
 51    compiler = "SQLUpdateCompiler"
 52
 53    def __init__(self, *args, **kwargs):
 54        super().__init__(*args, **kwargs)
 55        self._setup_query()
 56
 57    def _setup_query(self):
 58        """
 59        Run on initialization and at the end of chaining. Any attributes that
 60        would normally be set in __init__() should go here instead.
 61        """
 62        self.values = []
 63        self.related_ids = None
 64        self.related_updates = {}
 65
 66    def clone(self):
 67        obj = super().clone()
 68        obj.related_updates = self.related_updates.copy()
 69        return obj
 70
 71    def update_batch(self, pk_list, values, using):
 72        self.add_update_values(values)
 73        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
 74            self.clear_where()
 75            self.add_filter(
 76                "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
 77            )
 78            self.get_compiler(using).execute_sql(NO_RESULTS)
 79
 80    def add_update_values(self, values):
 81        """
 82        Convert a dictionary of field name to value mappings into an update
 83        query. This is the entry point for the public update() method on
 84        querysets.
 85        """
 86        values_seq = []
 87        for name, val in values.items():
 88            field = self.get_meta().get_field(name)
 89            direct = (
 90                not (field.auto_created and not field.concrete) or not field.concrete
 91            )
 92            model = field.model._meta.concrete_model
 93            if not direct or (field.is_relation and field.many_to_many):
 94                raise FieldError(
 95                    "Cannot update model field %r (only non-relations and "
 96                    "foreign keys permitted)." % field
 97                )
 98            if model is not self.get_meta().concrete_model:
 99                self.add_related_update(model, field, val)
100                continue
101            values_seq.append((field, model, val))
102        return self.add_update_fields(values_seq)
103
104    def add_update_fields(self, values_seq):
105        """
106        Append a sequence of (field, model, value) triples to the internal list
107        that will be used to generate the UPDATE query. Might be more usefully
108        called add_update_targets() to hint at the extra information here.
109        """
110        for field, model, val in values_seq:
111            if hasattr(val, "resolve_expression"):
112                # Resolve expressions here so that annotations are no longer needed
113                val = val.resolve_expression(self, allow_joins=False, for_save=True)
114            self.values.append((field, model, val))
115
116    def add_related_update(self, model, field, value):
117        """
118        Add (name, value) to an update query for an ancestor model.
119
120        Update are coalesced so that only one update query per ancestor is run.
121        """
122        self.related_updates.setdefault(model, []).append((field, None, value))
123
124    def get_related_updates(self):
125        """
126        Return a list of query objects: one for each update required to an
127        ancestor model. Each query will have the same filtering conditions as
128        the current query but will only update a single table.
129        """
130        if not self.related_updates:
131            return []
132        result = []
133        for model, values in self.related_updates.items():
134            query = UpdateQuery(model)
135            query.values = values
136            if self.related_ids is not None:
137                query.add_filter("pk__in", self.related_ids[model])
138            result.append(query)
139        return result
140
141
class InsertQuery(Query):
    """An INSERT SQL query."""

    compiler = "SQLInsertCompiler"

    def __init__(
        self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Populated by insert_values().
        self.fields = []
        self.objs = []
        # Initialized here as well so the attribute exists even before
        # insert_values() is called; False matches insert_values()'s default.
        self.raw = False
        self.on_conflict = on_conflict
        self.update_fields = update_fields or []
        self.unique_fields = unique_fields or []

    def insert_values(self, fields, objs, raw=False):
        """
        Record what to insert: the target ``fields``, the ``objs`` supplying
        the values, and the ``raw`` flag for the compiler.
        """
        self.fields = fields
        self.objs = objs
        self.raw = raw
159
160
class AggregateQuery(Query):
    """
    Query that selects from another query used as its FROM clause, returning
    only the requested elements.

    Compiled by SQLAggregateCompiler.
    """

    compiler = "SQLAggregateCompiler"

    def __init__(self, model, inner_query):
        # The wrapped query is attached before the base initializer runs;
        # that ordering is kept because Query.__init__ is not visible here.
        self.inner_query = inner_query
        super().__init__(model)