1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from enum import Enum
  5from types import NoneType
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.exceptions import ValidationError
  9from plain.models.exceptions import FieldError
 10from plain.models.expressions import (
 11    Exists,
 12    ExpressionList,
 13    F,
 14    OrderBy,
 15    ReplaceableExpression,
 16)
 17from plain.models.indexes import IndexExpression
 18from plain.models.lookups import Exact
 19from plain.models.query_utils import Q
 20from plain.models.sql.query import Query
 21
 22if TYPE_CHECKING:
 23    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 24    from plain.models.backends.ddl_references import Statement
 25    from plain.models.base import Model
 26
# Public names of this module (also what ``from ... import *`` exposes).
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 28
 29
 30class BaseConstraint(ABC):
 31    default_violation_error_message = 'Constraint "%(name)s" is violated.'
 32    violation_error_code: str | None = None
 33    violation_error_message: str | None = None
 34
 35    def __init__(
 36        self,
 37        *,
 38        name: str,
 39        violation_error_code: str | None = None,
 40        violation_error_message: str | None = None,
 41    ) -> None:
 42        self.name = name
 43        if violation_error_code is not None:
 44            self.violation_error_code = violation_error_code
 45        if violation_error_message is not None:
 46            self.violation_error_message = violation_error_message
 47        else:
 48            self.violation_error_message = self.default_violation_error_message
 49
 50    @property
 51    def contains_expressions(self) -> bool:
 52        return False
 53
 54    @abstractmethod
 55    def constraint_sql(
 56        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 57    ) -> str | None: ...
 58
 59    @abstractmethod
 60    def create_sql(
 61        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 62    ) -> str | Statement | None: ...
 63
 64    @abstractmethod
 65    def remove_sql(
 66        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 67    ) -> str | Statement | None: ...
 68
 69    @abstractmethod
 70    def validate(
 71        self, model: type[Model], instance: Model, exclude: set[str] | None = None
 72    ) -> None: ...
 73
 74    def get_violation_error_message(self) -> str:
 75        assert self.violation_error_message is not None
 76        return self.violation_error_message % {"name": self.name}
 77
 78    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 79        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 80        path = path.replace("plain.models.constraints", "plain.models")
 81        kwargs: dict[str, Any] = {"name": self.name}
 82        if (
 83            self.violation_error_message is not None
 84            and self.violation_error_message != self.default_violation_error_message
 85        ):
 86            kwargs["violation_error_message"] = self.violation_error_message
 87        if self.violation_error_code is not None:
 88            kwargs["violation_error_code"] = self.violation_error_code
 89        return (path, (), kwargs)
 90
 91    def clone(self) -> BaseConstraint:
 92        _, args, kwargs = self.deconstruct()
 93        return self.__class__(*args, **kwargs)
 94
 95
 96class CheckConstraint(BaseConstraint):
 97    def __init__(
 98        self,
 99        *,
100        check: Q,
101        name: str,
102        violation_error_code: str | None = None,
103        violation_error_message: str | None = None,
104    ) -> None:
105        self.check = check
106        if not getattr(check, "conditional", False):
107            raise TypeError(
108                "CheckConstraint.check must be a Q instance or boolean expression."
109            )
110        super().__init__(
111            name=name,
112            violation_error_code=violation_error_code,
113            violation_error_message=violation_error_message,
114        )
115
116    def _get_check_sql(
117        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
118    ) -> str:
119        query = Query(model=model, alias_cols=False)
120        where = query.build_where(self.check)
121        compiler = query.get_compiler()
122        sql, params = where.as_sql(compiler, schema_editor.connection)
123        return sql % tuple(schema_editor.quote_value(p) for p in params)
124
125    def constraint_sql(
126        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
127    ) -> str:
128        check = self._get_check_sql(model, schema_editor)
129        return schema_editor._check_sql(self.name, check)
130
131    def create_sql(
132        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
133    ) -> Statement | None:
134        check = self._get_check_sql(model, schema_editor)
135        return schema_editor._create_check_sql(model, self.name, check)
136
137    def remove_sql(
138        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
139    ) -> Statement | None:
140        return schema_editor._delete_check_sql(model, self.name)
141
142    def validate(
143        self, model: type[Model], instance: Model, exclude: set[str] | None = None
144    ) -> None:
145        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
146        try:
147            if not Q(self.check).check(against):
148                raise ValidationError(
149                    self.get_violation_error_message(), code=self.violation_error_code
150                )
151        except FieldError:
152            pass
153
154    def __repr__(self) -> str:
155        return "<{}: check={} name={}{}{}>".format(
156            self.__class__.__qualname__,
157            self.check,
158            repr(self.name),
159            (
160                ""
161                if self.violation_error_code is None
162                else f" violation_error_code={self.violation_error_code!r}"
163            ),
164            (
165                ""
166                if self.violation_error_message is None
167                or self.violation_error_message == self.default_violation_error_message
168                else f" violation_error_message={self.violation_error_message!r}"
169            ),
170        )
171
172    def __eq__(self, other: object) -> bool:
173        if isinstance(other, CheckConstraint):
174            return (
175                self.name == other.name
176                and self.check == other.check
177                and self.violation_error_code == other.violation_error_code
178                and self.violation_error_message == other.violation_error_message
179            )
180        return super().__eq__(other)
181
182    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
183        path, args, kwargs = super().deconstruct()
184        kwargs["check"] = self.check
185        return path, args, kwargs
186
187
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as ``Deferrable.DEFERRED`` instead of Enum's default
        # ``<Deferrable.DEFERRED: 'deferred'>`` form.
        return f"{type(self).__qualname__}.{self.name}"
195
196
class UniqueConstraint(BaseConstraint):
    """A UNIQUE constraint over model fields or expressions.

    Exactly one of ``fields`` (forward field names) or positional
    ``expressions`` (field-name strings or expressions) must be supplied.
    Optional modifiers are a ``condition`` (``Q``), a ``deferrable`` mode,
    covering ``include`` fields, and ``opclasses`` (one per field). Invalid
    combinations are rejected in ``__init__`` with ``ValueError``.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store normalized options.

        Raises:
            ValueError: if ``name`` is missing, neither or both of ``fields``
                and ``expressions`` are given, an option has the wrong type,
                ``opclasses`` and ``fields`` differ in length, or an
                incompatible combination (e.g. ``deferrable`` together with
                ``condition``) is requested.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # Bare strings are shorthand for field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the constraint is defined by expressions, not fields."""
        return bool(self.expressions)

    def _get_fields(self, model: type[Model]) -> list[Any]:
        """Return the forward field objects named in ``self.fields``."""
        return [
            model._model_meta.get_forward_field(field_name)
            for field_name in self.fields
        ]

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Return the ``.column`` of each field named in ``self.include``."""
        return [
            model._model_meta.get_forward_field(field_name).column
            for field_name in self.include
        ]

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``self.condition`` to SQL with parameters quoted inline.

        Returns None when there is no condition. Parameters are interpolated
        directly (quoted by ``schema_editor``) because the result is embedded
        in DDL.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Resolve ``self.expressions`` into an ExpressionList against ``model``.

        Returns None for ``fields``-based constraints. ``schema_editor`` is
        unused; it keeps the signature parallel to ``_get_condition_sql``.
        """
        if not self.expressions:
            return None
        index_expressions = [
            IndexExpression(expression) for expression in self.expressions
        ]
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the UNIQUE clause via ``schema_editor._unique_sql``."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._create_unique_sql``."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._delete_unique_sql``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; empty options are omitted."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ValidationError if saving ``instance`` would violate uniqueness.

        Validation is skipped (returns silently) when a constrained field or
        expression reference is in ``exclude``, when a ``fields`` value is
        None, or when evaluating ``condition`` raises FieldError. An
        existing (non-adding) instance is excluded from the duplicate search
        by its ``id``.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[operator]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Substitute the instance's own values for field references so
            # each expression can be compared against stored rows.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model,
                                    self.fields,
                                ),
                            )
                # ``self`` is not registered on the instance's models (e.g. a
                # clone(), which is a distinct object), so there is no model to
                # attribute the error to. A duplicate was still found: report it
                # with the generic message instead of silently passing.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # Best-effort: presumably the condition references a field
                # missing from ``against`` (e.g. excluded), so it can't be
                # evaluated here.
                pass