Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from enum import Enum
  5from types import NoneType
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.exceptions import ValidationError
  9from plain.models.exceptions import FieldError
 10from plain.models.expressions import (
 11    Exists,
 12    ExpressionList,
 13    F,
 14    OrderBy,
 15    ReplaceableExpression,
 16)
 17from plain.models.indexes import IndexExpression
 18from plain.models.lookups import Exact
 19from plain.models.query_utils import Q
 20from plain.models.sql.query import Query
 21
 22if TYPE_CHECKING:
 23    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 24    from plain.models.backends.ddl_references import Statement
 25    from plain.models.base import Model
 26
 27__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 28
 29
 30class BaseConstraint(ABC):
 31    default_violation_error_message = 'Constraint "%(name)s" is violated.'
 32    violation_error_code: str | None = None
 33    violation_error_message: str | None = None
 34
 35    def __init__(
 36        self,
 37        *,
 38        name: str,
 39        violation_error_code: str | None = None,
 40        violation_error_message: str | None = None,
 41    ) -> None:
 42        self.name = name
 43        if violation_error_code is not None:
 44            self.violation_error_code = violation_error_code
 45        if violation_error_message is not None:
 46            self.violation_error_message = violation_error_message
 47        else:
 48            self.violation_error_message = self.default_violation_error_message
 49
 50    @property
 51    def contains_expressions(self) -> bool:
 52        return False
 53
 54    @abstractmethod
 55    def constraint_sql(
 56        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 57    ) -> str | None: ...
 58
 59    @abstractmethod
 60    def create_sql(
 61        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 62    ) -> str | Statement | None: ...
 63
 64    @abstractmethod
 65    def remove_sql(
 66        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 67    ) -> str | Statement | None: ...
 68
 69    @abstractmethod
 70    def validate(
 71        self, model: type[Model], instance: Model, exclude: set[str] | None = None
 72    ) -> None: ...
 73
 74    def get_violation_error_message(self) -> str:
 75        assert self.violation_error_message is not None
 76        return self.violation_error_message % {"name": self.name}
 77
 78    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 79        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 80        path = path.replace("plain.models.constraints", "plain.models")
 81        kwargs: dict[str, Any] = {"name": self.name}
 82        if (
 83            self.violation_error_message is not None
 84            and self.violation_error_message != self.default_violation_error_message
 85        ):
 86            kwargs["violation_error_message"] = self.violation_error_message
 87        if self.violation_error_code is not None:
 88            kwargs["violation_error_code"] = self.violation_error_code
 89        return (path, (), kwargs)
 90
 91    def clone(self) -> BaseConstraint:
 92        _, args, kwargs = self.deconstruct()
 93        return self.__class__(*args, **kwargs)
 94
 95
 96class CheckConstraint(BaseConstraint):
 97    def __init__(
 98        self,
 99        *,
100        check: Q,
101        name: str,
102        violation_error_code: str | None = None,
103        violation_error_message: str | None = None,
104    ) -> None:
105        self.check = check
106        if not getattr(check, "conditional", False):
107            raise TypeError(
108                "CheckConstraint.check must be a Q instance or boolean expression."
109            )
110        super().__init__(
111            name=name,
112            violation_error_code=violation_error_code,
113            violation_error_message=violation_error_message,
114        )
115
116    def _get_check_sql(
117        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
118    ) -> str:
119        query = Query(model=model, alias_cols=False)
120        where = query.build_where(self.check)
121        compiler = query.get_compiler()
122        sql, params = where.as_sql(compiler, schema_editor.connection)
123        return sql % tuple(schema_editor.quote_value(p) for p in params)
124
125    def constraint_sql(
126        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
127    ) -> str:
128        check = self._get_check_sql(model, schema_editor)
129        return schema_editor._check_sql(self.name, check)
130
131    def create_sql(
132        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
133    ) -> Statement | None:
134        check = self._get_check_sql(model, schema_editor)
135        return schema_editor._create_check_sql(model, self.name, check)
136
137    def remove_sql(
138        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
139    ) -> Statement | None:
140        return schema_editor._delete_check_sql(model, self.name)
141
142    def validate(
143        self, model: type[Model], instance: Model, exclude: set[str] | None = None
144    ) -> None:
145        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
146        try:
147            if not Q(self.check).check(against):
148                raise ValidationError(
149                    self.get_violation_error_message(), code=self.violation_error_code
150                )
151        except FieldError:
152            pass
153
154    def __repr__(self) -> str:
155        return "<{}: check={} name={}{}{}>".format(
156            self.__class__.__qualname__,
157            self.check,
158            repr(self.name),
159            (
160                ""
161                if self.violation_error_code is None
162                else f" violation_error_code={self.violation_error_code!r}"
163            ),
164            (
165                ""
166                if self.violation_error_message is None
167                or self.violation_error_message == self.default_violation_error_message
168                else f" violation_error_message={self.violation_error_message!r}"
169            ),
170        )
171
172    def __eq__(self, other: object) -> bool:
173        if isinstance(other, CheckConstraint):
174            return (
175                self.name == other.name
176                and self.check == other.check
177                and self.violation_error_code == other.violation_error_code
178                and self.violation_error_message == other.violation_error_message
179            )
180        return super().__eq__(other)
181
182    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
183        path, args, kwargs = super().deconstruct()
184        kwargs["check"] = self.check
185        return path, args, kwargs
186
187
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``.

    The selected member is forwarded unchanged to the schema editor.
    """

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as ``Deferrable.DEFERRED`` instead of Enum's default
        # ``<Deferrable.DEFERRED: 'deferred'>`` (a format proposed for 3.10).
        return f"{type(self).__qualname__}.{self.name}"
195
196
class UniqueConstraint(BaseConstraint):
    """Require the values of ``fields`` (or of ``expressions``) to be unique.

    Exactly one of ``fields`` or positional ``expressions`` must be given.
    ``condition`` restricts the constraint to rows matching a ``Q``;
    ``include``, ``opclasses`` and ``deferrable`` are forwarded to the schema
    editor.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store normalized values.

        String ``expressions`` are wrapped in ``F()``. ``fields``, ``include``
        and ``opclasses`` are stored as tuples.

        Raises:
            ValueError: if ``name`` is missing; if neither or both of
                ``fields``/``expressions`` are given; if an option has the wrong
                type; if ``deferrable`` is combined with ``condition``,
                ``include``, ``opclasses`` or ``expressions``; if ``opclasses``
                is combined with ``expressions``; or if ``opclasses`` and
                ``fields`` differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Stored as a tuple, like ``fields``/``include``, so __eq__ and
        # deconstruct() don't depend on whether a list or a tuple was passed.
        self.opclasses = tuple(opclasses)
        # Bare strings are shorthand for field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        return bool(self.expressions)

    def _get_fields(self, model: type[Model]) -> list[Any]:
        """Return the forward field objects named in ``self.fields``."""
        return [
            model._model_meta.get_forward_field(field_name)
            for field_name in self.fields
        ]

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Return the column names of the fields named in ``self.include``."""
        return [
            model._model_meta.get_forward_field(field_name).column
            for field_name in self.include
        ]

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``self.condition`` to SQL with params inlined, or None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Resolve ``self.expressions`` as IndexExpressions for ``model``, or None."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the inline UNIQUE clause from ``schema_editor._unique_sql``."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this constraint to ``model``."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this constraint from ``model``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction; expressions become positional args."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if another row conflicts with ``instance``.

        Returns without checking when a constrained field is in ``exclude``, or
        when a field-based constraint sees a ``None`` value. For field-based
        constraints without a condition, and with the default violation
        message, the error comes from ``instance.unique_error_message()``.
        Otherwise the constraint's own violation message and code are used.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[call-non-callable]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                # Match rows where the expression equals its value computed
                # from the instance's field values.
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        instance_id = instance.id
        # For an instance that is not being added (presumably already saved),
        # leave its own row out so it cannot conflict with itself.
        if not instance._state.adding and instance_id is not None:
            queryset = queryset.exclude(id=instance_id)
        if not self.condition:
            if not queryset.exists():
                return
            if (
                self.fields
                and self.violation_error_message == self.default_violation_error_message
            ):
                # Keep unique_error_message() as the default message for
                # field-based constraints. Fall back to ``model`` when this exact
                # object isn't registered on the instance (e.g. a clone) rather
                # than silently accepting the duplicate.
                constraint_model = model
                for owner, constraints in instance.get_constraints():
                    if any(constraint is self for constraint in constraints):
                        constraint_model = owner
                        break
                raise ValidationError(
                    instance.unique_error_message(constraint_model, self.fields),
                )
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # NOTE(review): presumably raised when the condition references
                # a field missing from ``against`` (e.g. excluded); skipped.
                pass
487            except FieldError:
488                pass