Plain is headed towards 1.0! Subscribe for development updates →

  1from enum import Enum
  2from types import NoneType
  3
  4from plain.exceptions import FieldError, ValidationError
  5from plain.models.expressions import Exists, ExpressionList, F, OrderBy
  6from plain.models.indexes import IndexExpression
  7from plain.models.lookups import Exact
  8from plain.models.query_utils import Q
  9from plain.models.sql.query import Query
 10
 11__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 12
 13
 14class BaseConstraint:
 15    default_violation_error_message = "Constraint “%(name)s” is violated."
 16    violation_error_code = None
 17    violation_error_message = None
 18
 19    def __init__(
 20        self, *, name, violation_error_code=None, violation_error_message=None
 21    ):
 22        self.name = name
 23        if violation_error_code is not None:
 24            self.violation_error_code = violation_error_code
 25        if violation_error_message is not None:
 26            self.violation_error_message = violation_error_message
 27        else:
 28            self.violation_error_message = self.default_violation_error_message
 29
 30    @property
 31    def contains_expressions(self):
 32        return False
 33
 34    def constraint_sql(self, model, schema_editor):
 35        raise NotImplementedError("This method must be implemented by a subclass.")
 36
 37    def create_sql(self, model, schema_editor):
 38        raise NotImplementedError("This method must be implemented by a subclass.")
 39
 40    def remove_sql(self, model, schema_editor):
 41        raise NotImplementedError("This method must be implemented by a subclass.")
 42
 43    def validate(self, model, instance, exclude=None):
 44        raise NotImplementedError("This method must be implemented by a subclass.")
 45
 46    def get_violation_error_message(self):
 47        return self.violation_error_message % {"name": self.name}
 48
 49    def deconstruct(self):
 50        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 51        path = path.replace("plain.models.constraints", "plain.models")
 52        kwargs = {"name": self.name}
 53        if (
 54            self.violation_error_message is not None
 55            and self.violation_error_message != self.default_violation_error_message
 56        ):
 57            kwargs["violation_error_message"] = self.violation_error_message
 58        if self.violation_error_code is not None:
 59            kwargs["violation_error_code"] = self.violation_error_code
 60        return (path, (), kwargs)
 61
 62    def clone(self):
 63        _, args, kwargs = self.deconstruct()
 64        return self.__class__(*args, **kwargs)
 65
 66
 67class CheckConstraint(BaseConstraint):
 68    def __init__(
 69        self, *, check, name, violation_error_code=None, violation_error_message=None
 70    ):
 71        self.check = check
 72        if not getattr(check, "conditional", False):
 73            raise TypeError(
 74                "CheckConstraint.check must be a Q instance or boolean expression."
 75            )
 76        super().__init__(
 77            name=name,
 78            violation_error_code=violation_error_code,
 79            violation_error_message=violation_error_message,
 80        )
 81
 82    def _get_check_sql(self, model, schema_editor):
 83        query = Query(model=model, alias_cols=False)
 84        where = query.build_where(self.check)
 85        compiler = query.get_compiler()
 86        sql, params = where.as_sql(compiler, schema_editor.connection)
 87        return sql % tuple(schema_editor.quote_value(p) for p in params)
 88
 89    def constraint_sql(self, model, schema_editor):
 90        check = self._get_check_sql(model, schema_editor)
 91        return schema_editor._check_sql(self.name, check)
 92
 93    def create_sql(self, model, schema_editor):
 94        check = self._get_check_sql(model, schema_editor)
 95        return schema_editor._create_check_sql(model, self.name, check)
 96
 97    def remove_sql(self, model, schema_editor):
 98        return schema_editor._delete_check_sql(model, self.name)
 99
100    def validate(self, model, instance, exclude=None):
101        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
102        try:
103            if not Q(self.check).check(against):
104                raise ValidationError(
105                    self.get_violation_error_message(), code=self.violation_error_code
106                )
107        except FieldError:
108            pass
109
110    def __repr__(self):
111        return "<{}: check={} name={}{}{}>".format(
112            self.__class__.__qualname__,
113            self.check,
114            repr(self.name),
115            (
116                ""
117                if self.violation_error_code is None
118                else f" violation_error_code={self.violation_error_code!r}"
119            ),
120            (
121                ""
122                if self.violation_error_message is None
123                or self.violation_error_message == self.default_violation_error_message
124                else f" violation_error_message={self.violation_error_message!r}"
125            ),
126        )
127
128    def __eq__(self, other):
129        if isinstance(other, CheckConstraint):
130            return (
131                self.name == other.name
132                and self.check == other.check
133                and self.violation_error_code == other.violation_error_code
134                and self.violation_error_message == other.violation_error_message
135            )
136        return super().__eq__(other)
137
138    def deconstruct(self):
139        path, args, kwargs = super().deconstruct()
140        kwargs["check"] = self.check
141        return path, args, kwargs
142
143
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self):
        # Render as ``Deferrable.DEFERRED`` (no value), following the compact
        # enum repr format that was proposed for Python 3.10.
        return f"{type(self).__qualname__}.{self.name}"
149    def __repr__(self):
150        return f"{self.__class__.__qualname__}.{self._name_}"
151
152
class UniqueConstraint(BaseConstraint):
    """A unique constraint over model fields or expressions.

    Exactly one of ``fields`` (field names) or positional ``expressions``
    must be given; string expressions are wrapped in ``F()``. Optional:

    * ``condition`` - a ``Q`` limiting the rows the constraint applies to;
    * ``deferrable`` - a ``Deferrable`` member;
    * ``include`` - extra field names (list/tuple) passed to the schema editor;
    * ``opclasses`` - one operator class per entry in ``fields``.

    ``deferrable`` cannot be combined with ``condition``, ``include``,
    ``opclasses`` or expressions, and ``opclasses`` cannot be combined with
    expressions; invalid combinations raise ``ValueError``.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_code=None,
        violation_error_message=None,
    ):
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """True when the constraint is expression-based rather than fields-based."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Compile ``self.condition`` to SQL with params inlined, or None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """Wrap each expression in ``IndexExpression`` and resolve it for ``model``.

        Returns None for fields-based constraints.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(self, model, schema_editor):
        """Return the schema editor's unique-constraint SQL for this constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        """Return the schema editor's SQL that adds this unique constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        """Return the schema editor's SQL that drops this unique constraint."""
        condition = self._get_condition_sql(model, schema_editor)
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction; expressions become positional args."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None):
        """Raise ``ValidationError`` if another row would conflict with ``instance``.

        Validation is skipped (returns None) when:

        * a constrained field, or an ``F()`` reference in an expression, is in
          ``exclude``;
        * a fields-based constraint has a None value, since NULL != NULL in SQL;
        * evaluating a conditional constraint raises ``FieldError``.

        When ``instance`` is not being added, its own row (by primary key) is
        excluded from the duplicate search.
        """
        queryset = model._default_manager
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility, passing the model that
                # get_constraints() pairs with this constraint. The loop
                # variable is deliberately not named ``model`` so the
                # parameter is never rebound.
                for constraint_model, constraints in instance.get_constraints():
                    if any(constraint is self for constraint in constraints):
                        raise ValidationError(
                            instance.unique_error_message(
                                constraint_model, self.fields
                            ),
                        )
                # A conflicting row exists, but this exact constraint object is
                # not registered on the model (e.g. a clone() or a constraint
                # validated directly). Report the violation instead of
                # silently accepting the duplicate.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                # Violated only if the instance matches the condition AND
                # another row matching the condition has the same values.
                violated = (
                    self.condition & Exists(queryset.filter(self.condition))
                ).check(against)
            except FieldError:
                return
            if violated:
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
415            except FieldError:
416                pass