Plain is headed towards 1.0! Subscribe for development updates →

  1from plain.models.lookups import (
  2    Exact,
  3    GreaterThan,
  4    GreaterThanOrEqual,
  5    In,
  6    IsNull,
  7    LessThan,
  8    LessThanOrEqual,
  9)
 10
 11
 12class MultiColSource:
 13    contains_aggregate = False
 14    contains_over_clause = False
 15
 16    def __init__(self, alias, targets, sources, field):
 17        self.targets, self.sources, self.field, self.alias = (
 18            targets,
 19            sources,
 20            field,
 21            alias,
 22        )
 23        self.output_field = self.field
 24
 25    def __repr__(self):
 26        return f"{self.__class__.__name__}({self.alias}, {self.field})"
 27
 28    def relabeled_clone(self, relabels):
 29        return self.__class__(
 30            relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
 31        )
 32
 33    def get_lookup(self, lookup):
 34        return self.output_field.get_lookup(lookup)
 35
 36    def resolve_expression(self, *args, **kwargs):
 37        return self
 38
 39
 40def get_normalized_value(value, lhs):
 41    from plain.models import Model
 42
 43    if isinstance(value, Model):
 44        if value.pk is None:
 45            raise ValueError("Model instances passed to related filters must be saved.")
 46        value_list = []
 47        sources = lhs.output_field.path_infos[-1].target_fields
 48        for source in sources:
 49            while not isinstance(value, source.model) and source.remote_field:
 50                source = source.remote_field.model._meta.get_field(
 51                    source.remote_field.field_name
 52                )
 53            try:
 54                value_list.append(getattr(value, source.attname))
 55            except AttributeError:
 56                # A case like Restaurant.objects.filter(place=restaurant_instance),
 57                # where place is a OneToOneField and the primary key of Restaurant.
 58                return (value.pk,)
 59        return tuple(value_list)
 60    if not isinstance(value, tuple):
 61        return (value,)
 62    return value
 63
 64
 65class RelatedIn(In):
 66    def get_prep_lookup(self):
 67        if not isinstance(self.lhs, MultiColSource):
 68            if self.rhs_is_direct_value():
 69                # If we get here, we are dealing with single-column relations.
 70                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
 71                # We need to run the related field's get_prep_value(). Consider
 72                # case ForeignKey to IntegerField given value 'abc'. The
 73                # ForeignKey itself doesn't have validation for non-integers,
 74                # so we must run validation using the target field.
 75                if hasattr(self.lhs.output_field, "path_infos"):
 76                    # Run the target field's get_prep_value. We can safely
 77                    # assume there is only one as we don't get to the direct
 78                    # value branch otherwise.
 79                    target_field = self.lhs.output_field.path_infos[-1].target_fields[
 80                        -1
 81                    ]
 82                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
 83            elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
 84                self.lhs.field.target_field, "primary_key", False
 85            ):
 86                if (
 87                    getattr(self.lhs.output_field, "primary_key", False)
 88                    and self.lhs.output_field.model == self.rhs.model
 89                ):
 90                    # A case like
 91                    # Restaurant.objects.filter(place__in=restaurant_qs), where
 92                    # place is a OneToOneField and the primary key of
 93                    # Restaurant.
 94                    target_field = self.lhs.field.name
 95                else:
 96                    target_field = self.lhs.field.target_field.name
 97                self.rhs.set_values([target_field])
 98        return super().get_prep_lookup()
 99
100    def as_sql(self, compiler, connection):
101        if isinstance(self.lhs, MultiColSource):
102            # For multicolumn lookups we need to build a multicolumn where clause.
103            # This clause is either a SubqueryConstraint (for values that need
104            # to be compiled to SQL) or an OR-combined list of
105            # (col1 = val1 AND col2 = val2 AND ...) clauses.
106            from plain.models.sql.where import (
107                AND,
108                OR,
109                SubqueryConstraint,
110                WhereNode,
111            )
112
113            root_constraint = WhereNode(connector=OR)
114            if self.rhs_is_direct_value():
115                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
116                for value in values:
117                    value_constraint = WhereNode()
118                    for source, target, val in zip(
119                        self.lhs.sources, self.lhs.targets, value
120                    ):
121                        lookup_class = target.get_lookup("exact")
122                        lookup = lookup_class(
123                            target.get_col(self.lhs.alias, source), val
124                        )
125                        value_constraint.add(lookup, AND)
126                    root_constraint.add(value_constraint, OR)
127            else:
128                root_constraint.add(
129                    SubqueryConstraint(
130                        self.lhs.alias,
131                        [target.column for target in self.lhs.targets],
132                        [source.name for source in self.lhs.sources],
133                        self.rhs,
134                    ),
135                    AND,
136                )
137            return root_constraint.as_sql(compiler, connection)
138        return super().as_sql(compiler, connection)
139
140
class RelatedLookupMixin:
    """Adapt a scalar lookup (exact, lt, gt, isnull, ...) to relational fields.

    Meant to precede a concrete lookup class in the MRO. It relies on that
    class for ``lhs``, ``rhs``, ``prepare_rhs``, ``lookup_name`` and
    ``rhs_is_direct_value()``.
    """

    def get_prep_lookup(self):
        # Multi-column sources and expression values are left untouched.
        if isinstance(self.lhs, MultiColSource) or hasattr(
            self.rhs, "resolve_expression"
        ):
            return super().get_prep_lookup()

        # Single-column relation with a plain value: reduce it (possibly a
        # model instance) to the value of the related column.
        self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
        # The relation field does not validate on its own (e.g. a ForeignKey
        # to an IntegerField given 'abc'), so use the target field's
        # get_prep_value(). Only one target field can exist on this branch.
        if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
            final_hop = self.lhs.output_field.path_infos[-1]
            self.rhs = final_hop.target_fields[-1].get_prep_value(self.rhs)
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if not isinstance(self.lhs, MultiColSource):
            return super().as_sql(compiler, connection)

        # Multi-column relations only support direct values here.
        assert self.rhs_is_direct_value()
        self.rhs = get_normalized_value(self.rhs, self.lhs)
        from plain.models.sql.where import AND, WhereNode

        lhs = self.lhs
        clause = WhereNode()
        # Apply the same lookup to each (target column, value) pair, ANDed.
        for tgt, src, item in zip(lhs.targets, lhs.sources, self.rhs):
            lookup_cls = tgt.get_lookup(self.lookup_name)
            clause.add(lookup_cls(tgt.get_col(lhs.alias, src), item), AND)
        return clause.as_sql(compiler, connection)
176
177
class RelatedExact(RelatedLookupMixin, Exact):
    """``Exact`` lookup adapted to relational fields by RelatedLookupMixin."""


class RelatedLessThan(RelatedLookupMixin, LessThan):
    """``LessThan`` lookup adapted to relational fields by RelatedLookupMixin."""


class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    """``GreaterThan`` lookup adapted to relational fields by RelatedLookupMixin."""


class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    """``GreaterThanOrEqual`` lookup adapted to relational fields."""


class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    """``LessThanOrEqual`` lookup adapted to relational fields."""


class RelatedIsNull(RelatedLookupMixin, IsNull):
    """``IsNull`` lookup adapted to relational fields by RelatedLookupMixin."""