Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from enum import Enum
  5from types import NoneType
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.exceptions import ValidationError
  9from plain.models.exceptions import FieldError
 10from plain.models.expressions import Exists, ExpressionList, F, OrderBy
 11from plain.models.indexes import IndexExpression
 12from plain.models.lookups import Exact
 13from plain.models.query_utils import Q
 14from plain.models.sql.query import Query
 15
 16if TYPE_CHECKING:
 17    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 18    from plain.models.backends.ddl_references import Statement
 19    from plain.models.base import Model
 20    from plain.models.expressions import Combinable
 21
 22__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 23
 24
 25class BaseConstraint(ABC):
 26    default_violation_error_message = 'Constraint "%(name)s" is violated.'
 27    violation_error_code: str | None = None
 28    violation_error_message: str | None = None
 29
 30    def __init__(
 31        self,
 32        *,
 33        name: str,
 34        violation_error_code: str | None = None,
 35        violation_error_message: str | None = None,
 36    ) -> None:
 37        self.name = name
 38        if violation_error_code is not None:
 39            self.violation_error_code = violation_error_code
 40        if violation_error_message is not None:
 41            self.violation_error_message = violation_error_message
 42        else:
 43            self.violation_error_message = self.default_violation_error_message
 44
 45    @property
 46    def contains_expressions(self) -> bool:
 47        return False
 48
 49    @abstractmethod
 50    def constraint_sql(
 51        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 52    ) -> str | None: ...
 53
 54    @abstractmethod
 55    def create_sql(
 56        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 57    ) -> str | Statement | None: ...
 58
 59    @abstractmethod
 60    def remove_sql(
 61        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 62    ) -> str | Statement | None: ...
 63
 64    @abstractmethod
 65    def validate(
 66        self, model: type[Model], instance: Model, exclude: set[str] | None = None
 67    ) -> None: ...
 68
 69    def get_violation_error_message(self) -> str:
 70        return self.violation_error_message % {"name": self.name}  # type: ignore[operator]
 71
 72    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 73        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 74        path = path.replace("plain.models.constraints", "plain.models")
 75        kwargs: dict[str, Any] = {"name": self.name}
 76        if (
 77            self.violation_error_message is not None
 78            and self.violation_error_message != self.default_violation_error_message
 79        ):
 80            kwargs["violation_error_message"] = self.violation_error_message
 81        if self.violation_error_code is not None:
 82            kwargs["violation_error_code"] = self.violation_error_code
 83        return (path, (), kwargs)
 84
 85    def clone(self) -> BaseConstraint:
 86        _, args, kwargs = self.deconstruct()
 87        return self.__class__(*args, **kwargs)
 88
 89
 90class CheckConstraint(BaseConstraint):
 91    def __init__(
 92        self,
 93        *,
 94        check: Q,
 95        name: str,
 96        violation_error_code: str | None = None,
 97        violation_error_message: str | None = None,
 98    ) -> None:
 99        self.check = check
100        if not getattr(check, "conditional", False):
101            raise TypeError(
102                "CheckConstraint.check must be a Q instance or boolean expression."
103            )
104        super().__init__(
105            name=name,
106            violation_error_code=violation_error_code,
107            violation_error_message=violation_error_message,
108        )
109
110    def _get_check_sql(
111        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
112    ) -> str:
113        query = Query(model=model, alias_cols=False)
114        where = query.build_where(self.check)
115        compiler = query.get_compiler()
116        sql, params = where.as_sql(compiler, schema_editor.connection)
117        return sql % tuple(schema_editor.quote_value(p) for p in params)
118
119    def constraint_sql(
120        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
121    ) -> str:
122        check = self._get_check_sql(model, schema_editor)
123        return schema_editor._check_sql(self.name, check)
124
125    def create_sql(
126        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
127    ) -> Statement | None:
128        check = self._get_check_sql(model, schema_editor)
129        return schema_editor._create_check_sql(model, self.name, check)
130
131    def remove_sql(
132        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
133    ) -> Statement | None:
134        return schema_editor._delete_check_sql(model, self.name)
135
136    def validate(
137        self, model: type[Model], instance: Model, exclude: set[str] | None = None
138    ) -> None:
139        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
140        try:
141            if not Q(self.check).check(against):
142                raise ValidationError(
143                    self.get_violation_error_message(), code=self.violation_error_code
144                )
145        except FieldError:
146            pass
147
148    def __repr__(self) -> str:
149        return "<{}: check={} name={}{}{}>".format(
150            self.__class__.__qualname__,
151            self.check,
152            repr(self.name),
153            (
154                ""
155                if self.violation_error_code is None
156                else f" violation_error_code={self.violation_error_code!r}"
157            ),
158            (
159                ""
160                if self.violation_error_message is None
161                or self.violation_error_message == self.default_violation_error_message
162                else f" violation_error_message={self.violation_error_message!r}"
163            ),
164        )
165
166    def __eq__(self, other: object) -> bool:
167        if isinstance(other, CheckConstraint):
168            return (
169                self.name == other.name
170                and self.check == other.check
171                and self.violation_error_code == other.violation_error_code
172                and self.violation_error_message == other.violation_error_message
173            )
174        return super().__eq__(other)
175
176    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
177        path, args, kwargs = super().deconstruct()
178        kwargs["check"] = self.check
179        return path, args, kwargs
180
181
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Show members as ``Deferrable.DEFERRED`` (the way they are written in
        # code) instead of Enum's default ``<Deferrable.DEFERRED: 'deferred'>``.
        return "{}.{}".format(type(self).__qualname__, self.name)
189
190
class UniqueConstraint(BaseConstraint):
    """A UNIQUE constraint over model fields or over expressions.

    Exactly one of ``fields`` or the positional ``expressions`` must be given.
    The ``condition``, ``deferrable``, ``include`` and ``opclasses`` options
    are validated in ``__init__``, which rejects unsupported combinations.
    """

    def __init__(
        self,
        *expressions: str | Combinable,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Check the option combination and store the options.

        String ``expressions`` are wrapped in ``F()``; ``fields`` and
        ``include`` are stored as tuples; ``opclasses`` is stored as given.

        Raises:
            ValueError: if ``name`` is empty; if neither or both of ``fields``
                and ``expressions`` are given; if an option has the wrong
                type; if options are combined in an unsupported way; or if
                ``opclasses`` does not pair one-to-one with ``fields``.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # Bare strings are shorthand for F() field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,  # type: ignore[arg-type]
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the constraint is defined by expressions, not fields."""
        return bool(self.expressions)

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``self.condition`` to SQL with params quoted inline.

        Returns None when the constraint has no condition.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Return the column names of the fields listed in ``self.include``."""
        return [
            model._model_meta.get_field(field_name).column
            for field_name in self.include
        ]

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Wrap each expression in ``IndexExpression`` and resolve them.

        Returns None for fields-based constraints.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the UNIQUE clause via ``schema_editor._unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement adding this constraint to ``model``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement dropping this constraint from ``model``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; only set options are kept."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if another row matches ``instance``.

        Fields-based constraints filter on the instance's field values and
        skip validation if any field is in ``exclude`` or has a None value.
        Expression-based constraints skip validation if they reference an
        excluded field; otherwise they substitute the instance's values into
        the expressions. When the instance is not being added and has an id,
        its own row is excluded from the lookup. With a ``condition``, a
        violation is reported only when the condition holds for the instance
        and a matching row satisfying it exists. In that case a ``FieldError``
        skips validation.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) reference to the instance's current value so
            # the expressions can be compared against existing rows.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.fields:
                    # When this constraint object is registered on the
                    # instance's model, use unique_error_message() for
                    # backward compatibility.
                    for constraint_model, constraints in instance.get_constraints():
                        for constraint in constraints:
                            if constraint is self:
                                raise ValidationError(
                                    instance.unique_error_message(
                                        constraint_model,  # type: ignore[arg-type]
                                        self.fields,
                                    ),
                                )
                # Expression-based constraints, and fields-based ones not found
                # by identity above (e.g. a clone() or a standalone instance),
                # must still report the duplicate rather than silently pass.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass
473            except FieldError:
474                pass