Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Various data structures used in query construction.
  3
  4Factored out from plain.models.query to avoid making the main module very
  5large and/or so that they can be used by other modules without getting into
  6circular import difficulties.
  7"""
  8
  9import functools
 10import inspect
 11import logging
 12from collections import namedtuple
 13
 14from plain.exceptions import FieldError
 15from plain.models.constants import LOOKUP_SEP
 16from plain.models.db import DatabaseError, db_connection
 17from plain.utils import tree
 18
 19logger = logging.getLogger("plain.models")
 20
 21# PathInfo is used when converting lookups (fk__somecol). The contents
 22# describe the relation in Model terms (model Options and Fields for both
 23# sides of the relation. The join_field is the field backing the relation.
 24PathInfo = namedtuple(
 25    "PathInfo",
 26    "from_opts to_opts target_fields join_field m2m direct filtered_relation",
 27)
 28
 29
 30def subclasses(cls):
 31    yield cls
 32    for subclass in cls.__subclasses__():
 33        yield from subclasses(subclass)
 34
 35
 36class Q(tree.Node):
 37    """
 38    Encapsulate filters as objects that can then be combined logically (using
 39    `&` and `|`).
 40    """
 41
 42    # Connection types
 43    AND = "AND"
 44    OR = "OR"
 45    XOR = "XOR"
 46    default = AND
 47    conditional = True
 48
 49    def __init__(self, *args, _connector=None, _negated=False, **kwargs):
 50        super().__init__(
 51            children=[*args, *sorted(kwargs.items())],
 52            connector=_connector,
 53            negated=_negated,
 54        )
 55
 56    def _combine(self, other, conn):
 57        if getattr(other, "conditional", False) is False:
 58            raise TypeError(other)
 59        if not self:
 60            return other.copy()
 61        if not other and isinstance(other, Q):
 62            return self.copy()
 63
 64        obj = self.create(connector=conn)
 65        obj.add(self, conn)
 66        obj.add(other, conn)
 67        return obj
 68
 69    def __or__(self, other):
 70        return self._combine(other, self.OR)
 71
 72    def __and__(self, other):
 73        return self._combine(other, self.AND)
 74
 75    def __xor__(self, other):
 76        return self._combine(other, self.XOR)
 77
 78    def __invert__(self):
 79        obj = self.copy()
 80        obj.negate()
 81        return obj
 82
 83    def resolve_expression(
 84        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 85    ):
 86        # We must promote any new joins to left outer joins so that when Q is
 87        # used as an expression, rows aren't filtered due to joins.
 88        clause, joins = query._add_q(
 89            self,
 90            reuse,
 91            allow_joins=allow_joins,
 92            split_subq=False,
 93            check_filterable=False,
 94            summarize=summarize,
 95        )
 96        query.promote_joins(joins)
 97        return clause
 98
 99    def flatten(self):
100        """
101        Recursively yield this Q object and all subexpressions, in depth-first
102        order.
103        """
104        yield self
105        for child in self.children:
106            if isinstance(child, tuple):
107                # Use the lookup.
108                child = child[1]
109            if hasattr(child, "flatten"):
110                yield from child.flatten()
111            else:
112                yield child
113
114    def check(self, against):
115        """
116        Do a database query to check if the expressions of the Q instance
117        matches against the expressions.
118        """
119        # Avoid circular imports.
120        from plain.models.expressions import Value
121        from plain.models.fields import BooleanField
122        from plain.models.functions import Coalesce
123        from plain.models.sql import Query
124        from plain.models.sql.constants import SINGLE
125
126        query = Query(None)
127        for name, value in against.items():
128            if not hasattr(value, "resolve_expression"):
129                value = Value(value)
130            query.add_annotation(value, name, select=False)
131        query.add_annotation(Value(1), "_check")
132        # This will raise a FieldError if a field is missing in "against".
133        if db_connection.features.supports_comparing_boolean_expr:
134            query.add_q(Q(Coalesce(self, True, output_field=BooleanField())))
135        else:
136            query.add_q(self)
137        compiler = query.get_compiler()
138        try:
139            return compiler.execute_sql(SINGLE) is not None
140        except DatabaseError as e:
141            logger.warning("Got a database error calling check() on %r: %s", self, e)
142            return True
143
144    def deconstruct(self):
145        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
146        if path.startswith("plain.models.query_utils"):
147            path = path.replace("plain.models.query_utils", "plain.models")
148        args = tuple(self.children)
149        kwargs = {}
150        if self.connector != self.default:
151            kwargs["_connector"] = self.connector
152        if self.negated:
153            kwargs["_negated"] = True
154        return path, args, kwargs
155
156
class DeferredAttribute:
    """
    Descriptor for a field whose value is loaded lazily.

    The first read on an instance that has no value for the field triggers
    ``instance.refresh_from_db()`` for that single field.
    """

    def __init__(self, field):
        self.field = field

    def __get__(self, instance, cls=None):
        """
        Return the field value stored on ``instance``, loading it if absent.

        Accessed on the class itself (``instance is None``), return the
        descriptor.
        """
        if instance is None:
            return self
        attname = self.field.attname
        values = instance.__dict__
        if attname not in values:
            # Ask the instance to load just this field into its __dict__.
            instance.refresh_from_db(fields=[attname])
        return values[attname]
178
179
class class_or_instance_method:
    """
    Descriptor that dispatches on how it is accessed.

    Accessed on a class, it returns ``class_method`` with the owner class bound
    as the first argument. Accessed on an instance, it returns
    ``instance_method`` with the instance bound. Used by RegisterLookupMixin.
    """

    def __init__(self, class_method, instance_method):
        self.class_method = class_method
        self.instance_method = instance_method

    def __get__(self, instance, owner):
        if instance is None:
            target, bound_to = self.class_method, owner
        else:
            target, bound_to = self.instance_method, instance
        return functools.partial(target, bound_to)
194
195
class RegisterLookupMixin:
    """
    Mixin for registering Lookup/Transform classes by name and resolving them.

    Lookups can be registered on a class (``class_lookups``, merged across the
    MRO) or on a single instance (``instance_lookups``). Several names below
    are ``class_or_instance_method`` descriptors, so the same name dispatches
    to a class-level or an instance-level implementation depending on whether
    it is accessed on the class or on an instance.

    NOTE: the order of statements in this class body matters. Each raw
    function is first captured by ``class_or_instance_method`` and only then is
    its class-level name rebound to a ``classmethod``.
    """

    def _get_lookup(self, lookup_name):
        # Return the lookup/transform registered under ``lookup_name``, or
        # None. On an instance, get_lookups() resolves to
        # get_instance_lookups(self) via the class_or_instance_method below.
        return self.get_lookups().get(lookup_name, None)

    @functools.cache
    def get_class_lookups(cls):
        # Collect the ``class_lookups`` dict defined directly on each class in
        # the MRO (``__dict__`` ignores inherited attributes). merge_dicts()
        # gives precedence to classes earlier in the MRO. The result is
        # memoized by functools.cache and cleared by
        # _clear_cached_class_lookups() when class lookups are (un)registered.
        # NOTE(review): the returned dict is the cached object itself, so
        # callers should not mutate it.
        class_lookups = [
            parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
        ]
        return cls.merge_dicts(class_lookups)

    def get_instance_lookups(self):
        # Class-level lookups overlaid with this instance's own registrations
        # (instance entries win). Without instance lookups, the (cached)
        # class-level dict is returned as-is.
        class_lookups = self.get_class_lookups()
        if instance_lookups := getattr(self, "instance_lookups", None):
            return {**class_lookups, **instance_lookups}
        return class_lookups

    # ``Cls.get_lookups()`` -> get_class_lookups(Cls);
    # ``obj.get_lookups()`` -> get_instance_lookups(obj).
    # This captures the raw cached function before the name is rebound to a
    # classmethod on the next line.
    get_lookups = class_or_instance_method(get_class_lookups, get_instance_lookups)
    get_class_lookups = classmethod(get_class_lookups)

    def get_lookup(self, lookup_name):
        """
        Return the Lookup class registered under ``lookup_name``, or None.

        If nothing is registered here and this object has an ``output_field``,
        the search is delegated to ``output_field.get_lookup()``. A registered
        entry that is not a Lookup subclass yields None.
        """
        # Local import, presumably to avoid a circular import.
        from plain.models.lookups import Lookup

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_lookup(lookup_name)
        if found is not None and not issubclass(found, Lookup):
            return None
        return found

    def get_transform(self, lookup_name):
        """
        Return the Transform class registered under ``lookup_name``, or None.

        If nothing is registered here and this object has an ``output_field``,
        the search is delegated to ``output_field.get_transform()``. A
        registered entry that is not a Transform subclass yields None.
        """
        # Local import, presumably to avoid a circular import.
        from plain.models.lookups import Transform

        found = self._get_lookup(lookup_name)
        if found is None and hasattr(self, "output_field"):
            return self.output_field.get_transform(lookup_name)
        if found is not None and not issubclass(found, Transform):
            return None
        return found

    @staticmethod
    def merge_dicts(dicts):
        """
        Merge ``dicts`` into a new dict so that earlier dicts take precedence.

        The dicts are applied in reverse order, so e.g. merge_dicts([a, b])
        prefers the values in 'a' over those in 'b' for shared keys.
        """
        merged = {}
        for d in reversed(dicts):
            merged.update(d)
        return merged

    @classmethod
    def _clear_cached_class_lookups(cls):
        # Invalidate memoized get_class_lookups() results for cls and all of
        # its subclasses.
        # NOTE(review): every subclass resolves to the same functools.cache
        # wrapper, so the first cache_clear() already empties the whole cache.
        for subclass in subclasses(cls):
            subclass.get_class_lookups.cache_clear()

    def register_class_lookup(cls, lookup, lookup_name=None):
        # Register ``lookup`` on this class under ``lookup_name`` (default:
        # ``lookup.lookup_name``). A ``class_lookups`` dict is created on this
        # class itself rather than mutating one inherited from a parent.
        # Returning ``lookup`` allows this to be used as a decorator.
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        if "class_lookups" not in cls.__dict__:
            cls.class_lookups = {}
        cls.class_lookups[lookup_name] = lookup
        cls._clear_cached_class_lookups()
        return lookup

    def register_instance_lookup(self, lookup, lookup_name=None):
        # Register ``lookup`` on this instance only, under ``lookup_name``
        # (default: ``lookup.lookup_name``). Instance lookups are not cached,
        # so no invalidation is needed. Returns ``lookup``.
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        if "instance_lookups" not in self.__dict__:
            self.instance_lookups = {}
        self.instance_lookups[lookup_name] = lookup
        return lookup

    # ``Cls.register_lookup(...)`` registers on the class;
    # ``obj.register_lookup(...)`` registers on that instance only.
    register_lookup = class_or_instance_method(
        register_class_lookup, register_instance_lookup
    )
    register_class_lookup = classmethod(register_class_lookup)

    def _unregister_class_lookup(cls, lookup, lookup_name=None):
        """
        Remove given lookup from cls lookups. For use in tests only as it's
        not thread-safe.

        Raises KeyError if ``lookup_name`` is not registered directly on cls,
        and AttributeError if no ``class_lookups`` is reachable from cls.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del cls.class_lookups[lookup_name]
        cls._clear_cached_class_lookups()

    def _unregister_instance_lookup(self, lookup, lookup_name=None):
        """
        Remove given lookup from instance lookups. For use in tests only as
        it's not thread-safe.

        Raises KeyError if ``lookup_name`` is not registered on the instance.
        """
        if lookup_name is None:
            lookup_name = lookup.lookup_name
        del self.instance_lookups[lookup_name]

    # Same class-vs-instance dispatch as register_lookup above.
    _unregister_lookup = class_or_instance_method(
        _unregister_class_lookup, _unregister_instance_lookup
    )
    _unregister_class_lookup = classmethod(_unregister_class_lookup)
296    _unregister_class_lookup = classmethod(_unregister_class_lookup)
297
298
def select_related_descend(field, restricted, requested, select_mask, reverse=False):
    """
    Decide whether select_related() should follow ``field`` to its related model.

    Shared by the query construction code (compiler.get_related_selections())
    and the model instance creation code (compiler.klass_info).

    Arguments:
     * field - the relation field being considered
     * restricted - True when select_related() was given an explicit set of
       relations (the ``requested`` mapping) rather than following all of them
     * requested - the select_related() dictionary
     * select_mask - the dictionary of selected fields
     * reverse - True when checking a reverse select_related() relation

    Raise FieldError when a requested field is also deferred by ``select_mask``.
    """
    # Non-relation fields can never be descended into.
    if not field.remote_field:
        return False

    if not restricted:
        # Following every relation: skip the nullable ones.
        return not field.allow_null

    # Explicit select_related(): reverse relations are matched by their
    # related query name, forward ones by the field name.
    wanted_name = field.related_query_name() if reverse else field.name
    if wanted_name not in requested:
        return False

    if select_mask and field.name in requested and field not in select_mask:
        raise FieldError(
            f"Field {field.model._meta.object_name}.{field.name} cannot be both "
            "deferred and traversed using select_related at the same time."
        )
    return True
334
335
def refs_expression(lookup_parts, annotations):
    """
    Find the annotation referenced by a prefix of ``lookup_parts``.

    Annotation names may themselves contain LOOKUP_SEP, so every prefix is
    tried, shortest first. Return ``(annotation_name, remaining_parts)`` for
    the first prefix whose entry in ``annotations`` is truthy, or
    ``(None, ())`` if no prefix matches.
    """
    for split_at in range(1, len(lookup_parts) + 1):
        candidate = LOOKUP_SEP.join(lookup_parts[:split_at])
        if annotations.get(candidate):
            return candidate, lookup_parts[split_at:]
    return None, ()
347
348
def check_rel_lookup_compatibility(model, target_opts, field):
    """
    Check that ``model`` is compatible with ``target_opts``.

    Compatibility is OK if either:
      1) ``model`` and ``target_opts`` share the same concrete model, or
      2) ``field`` is a primary key and ``field.model`` shares the concrete
         model of ``model`` (see the comment below for why).

    Return a truthy value when compatible and a falsy one otherwise.
    """

    def check(opts):
        # Compare concrete models rather than the model classes themselves.
        return model._meta.concrete_model == opts.concrete_model

    # If the field is a primary key, then doing a query against the field's
    # model is ok, too. Consider the case:
    # class Restaurant(models.Model):
    #     place = OneToOneField(Place, primary_key=True):
    # Restaurant.query.filter(id__in=Restaurant.query.all()).
    # If we didn't have the primary key check, then id__in (== place__in) would
    # give Place's opts as the target opts, but Restaurant isn't compatible
    # with that. This logic applies only to primary keys, as when doing __in=qs,
    # we are going to turn this into __in=qs.values('id') later on.
    return check(target_opts) or (
        getattr(field, "primary_key", False) and check(field.model._meta)
    )
372
373
class FilteredRelation:
    """Specify custom filtering in the ON clause of SQL joins."""

    def __init__(self, relation_name, *, condition=None):
        """
        Create a filtered relation over ``relation_name``.

        ``condition`` must be a Q instance. When omitted (or None), a fresh
        empty ``Q()`` is created for this instance. The previous ``Q()``
        default was a single mutable object shared by every
        FilteredRelation built without a condition.

        Raise ValueError if ``relation_name`` is empty or ``condition`` is not
        a Q instance.
        """
        if not relation_name:
            raise ValueError("relation_name cannot be empty.")
        self.relation_name = relation_name
        self.alias = None
        if condition is None:
            condition = Q()
        if not isinstance(condition, Q):
            raise ValueError("condition argument must be a Q() instance.")
        self.condition = condition
        self.path = []

    def __eq__(self, other):
        # Equality ignores ``path``. Defining __eq__ without __hash__ makes
        # instances unhashable.
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.relation_name == other.relation_name
            and self.alias == other.alias
            and self.condition == other.condition
        )

    def clone(self):
        # The condition object is shared with the clone; ``path`` is copied.
        clone = FilteredRelation(self.relation_name, condition=self.condition)
        clone.alias = self.alias
        clone.path = self.path[:]
        return clone

    def resolve_expression(self, *args, **kwargs):
        """
        QuerySet.annotate() only accepts expression-like arguments
        (with a resolve_expression() method).
        """
        raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")

    def as_sql(self, compiler, connection):
        # Resolve the condition in Join.filtered_relation.
        query = compiler.query
        where = query.build_filtered_relation_q(self.condition, reuse=set(self.path))
        return compiler.compile(where)