Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from enum import Enum
  4from types import NoneType
  5from typing import Any
  6
  7from plain.exceptions import ValidationError
  8from plain.models.exceptions import FieldError
  9from plain.models.expressions import Exists, ExpressionList, F, OrderBy
 10from plain.models.indexes import IndexExpression
 11from plain.models.lookups import Exact
 12from plain.models.query_utils import Q
 13from plain.models.sql.query import Query
 14
 15__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 16
 17
 18class BaseConstraint:
 19    default_violation_error_message = 'Constraint "%(name)s" is violated.'
 20    violation_error_code: str | None = None
 21    violation_error_message: str | None = None
 22
 23    def __init__(
 24        self,
 25        *,
 26        name: str,
 27        violation_error_code: str | None = None,
 28        violation_error_message: str | None = None,
 29    ) -> None:
 30        self.name = name
 31        if violation_error_code is not None:
 32            self.violation_error_code = violation_error_code
 33        if violation_error_message is not None:
 34            self.violation_error_message = violation_error_message
 35        else:
 36            self.violation_error_message = self.default_violation_error_message
 37
 38    @property
 39    def contains_expressions(self) -> bool:
 40        return False
 41
 42    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
 43        raise NotImplementedError("This method must be implemented by a subclass.")
 44
 45    def create_sql(self, model: Any, schema_editor: Any) -> str:
 46        raise NotImplementedError("This method must be implemented by a subclass.")
 47
 48    def remove_sql(self, model: Any, schema_editor: Any) -> str:
 49        raise NotImplementedError("This method must be implemented by a subclass.")
 50
 51    def validate(
 52        self, model: Any, instance: Any, exclude: set[str] | None = None
 53    ) -> None:
 54        raise NotImplementedError("This method must be implemented by a subclass.")
 55
 56    def get_violation_error_message(self) -> str:
 57        return self.violation_error_message % {"name": self.name}  # type: ignore[operator]
 58
 59    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 60        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 61        path = path.replace("plain.models.constraints", "plain.models")
 62        kwargs: dict[str, Any] = {"name": self.name}
 63        if (
 64            self.violation_error_message is not None
 65            and self.violation_error_message != self.default_violation_error_message
 66        ):
 67            kwargs["violation_error_message"] = self.violation_error_message
 68        if self.violation_error_code is not None:
 69            kwargs["violation_error_code"] = self.violation_error_code
 70        return (path, (), kwargs)
 71
 72    def clone(self) -> BaseConstraint:
 73        _, args, kwargs = self.deconstruct()
 74        return self.__class__(*args, **kwargs)
 75
 76
 77class CheckConstraint(BaseConstraint):
 78    def __init__(
 79        self,
 80        *,
 81        check: Q,
 82        name: str,
 83        violation_error_code: str | None = None,
 84        violation_error_message: str | None = None,
 85    ) -> None:
 86        self.check = check
 87        if not getattr(check, "conditional", False):
 88            raise TypeError(
 89                "CheckConstraint.check must be a Q instance or boolean expression."
 90            )
 91        super().__init__(
 92            name=name,
 93            violation_error_code=violation_error_code,
 94            violation_error_message=violation_error_message,
 95        )
 96
 97    def _get_check_sql(self, model: Any, schema_editor: Any) -> str:
 98        query = Query(model=model, alias_cols=False)
 99        where = query.build_where(self.check)
100        compiler = query.get_compiler()
101        sql, params = where.as_sql(compiler, schema_editor.connection)
102        return sql % tuple(schema_editor.quote_value(p) for p in params)
103
104    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
105        check = self._get_check_sql(model, schema_editor)
106        return schema_editor._check_sql(self.name, check)
107
108    def create_sql(self, model: Any, schema_editor: Any) -> str:
109        check = self._get_check_sql(model, schema_editor)
110        return schema_editor._create_check_sql(model, self.name, check)
111
112    def remove_sql(self, model: Any, schema_editor: Any) -> str:
113        return schema_editor._delete_check_sql(model, self.name)
114
115    def validate(
116        self, model: Any, instance: Any, exclude: set[str] | None = None
117    ) -> None:
118        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
119        try:
120            if not Q(self.check).check(against):
121                raise ValidationError(
122                    self.get_violation_error_message(), code=self.violation_error_code
123                )
124        except FieldError:
125            pass
126
127    def __repr__(self) -> str:
128        return "<{}: check={} name={}{}{}>".format(
129            self.__class__.__qualname__,
130            self.check,
131            repr(self.name),
132            (
133                ""
134                if self.violation_error_code is None
135                else f" violation_error_code={self.violation_error_code!r}"
136            ),
137            (
138                ""
139                if self.violation_error_message is None
140                or self.violation_error_message == self.default_violation_error_message
141                else f" violation_error_message={self.violation_error_message!r}"
142            ),
143        )
144
145    def __eq__(self, other: object) -> bool:
146        if isinstance(other, CheckConstraint):
147            return (
148                self.name == other.name
149                and self.check == other.check
150                and self.violation_error_code == other.violation_error_code
151                and self.violation_error_message == other.violation_error_message
152            )
153        return super().__eq__(other)
154
155    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
156        path, args, kwargs = super().deconstruct()
157        kwargs["check"] = self.check
158        return path, args, kwargs
159
160
class Deferrable(Enum):
    """Values accepted for the ``deferrable`` argument of UniqueConstraint."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self) -> str:
        """Render as ``ClassName.MEMBER`` rather than the default Enum repr."""
        return "{}.{}".format(type(self).__qualname__, self.name)
168
169
class UniqueConstraint(BaseConstraint):
    """Constraint requiring a set of fields or expressions to be unique.

    Exactly one of ``*expressions`` or ``fields`` must be given. ``fields``,
    ``include`` and ``opclasses`` are stored as tuples, so constraints built
    from lists and from tuples compare equal and do not alias the caller's
    mutable sequences.
    """

    def __init__(
        self,
        *expressions: Any,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the arguments and store them in normalized form.

        Args:
            *expressions: Expressions to keep unique; strings are wrapped
                in ``F()``. Mutually exclusive with ``fields``.
            fields: Names of the fields to keep unique.
            name: Constraint name (required despite the ``None`` default).
            condition: Optional ``Q`` restricting the constraint.
            deferrable: Optional ``Deferrable`` value.
            include: Optional field names to include.
            opclasses: Optional operator classes, one per entry in
                ``fields``.
            violation_error_code: Passed through to ``BaseConstraint``.
            violation_error_message: Passed through to ``BaseConstraint``.

        Raises:
            ValueError: If the name is missing, an argument has the wrong
                type, or an unsupported combination of options is given.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalize like fields/include: a list and a tuple with the same
        # items must compare equal in __eq__, and later mutation of the
        # caller's list must not change this constraint.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,  # type: ignore[arg-type]
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model: Any, schema_editor: Any) -> str | None:
        """Compile ``condition`` to SQL with quoted params, or return None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model: Any, schema_editor: Any) -> Any:
        """Wrap and resolve ``expressions`` against *model*, or return None."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def _get_include_columns(self, model: Any) -> list[Any]:
        """Return the ``column`` of each ``include`` field looked up on *model*."""
        return [
            model._model_meta.get_field(field_name).column
            for field_name in self.include
        ]

    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_create_unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_delete_unique_sql``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        """Debug representation listing only the options that are set."""
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        """Compare every stored option with another UniqueConstraint."""
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; unset options are omitted."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: Any, instance: Any, exclude: set[str] | None = None
    ) -> None:
        """Check *instance* against existing rows of *model*.

        Validation is skipped when any constrained field is in ``exclude``
        or, for field-based constraints, when any constrained value is None.

        Raises:
            ValidationError: If a conflicting row exists.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        # An already-saved instance must not conflict with its own row.
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility. Prefer the model class that the
                # instance reports as owning this constraint.
                for owner_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(owner_model, self.fields),
                            )
                # The constraint is not registered on the instance (e.g. a
                # clone validated directly). A duplicate was still found, so
                # report it against the model passed in instead of silently
                # returning.
                raise ValidationError(
                    instance.unique_error_message(model, self.fields),
                )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # The condition could not be evaluated against the available
                # field values; treat the constraint as not violated.
                pass