1from __future__ import annotations
  2
  3from enum import Enum
  4from types import NoneType
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.exceptions import ValidationError
  8from plain.postgres.exceptions import FieldError
  9from plain.postgres.expressions import (
 10    Exists,
 11    ExpressionList,
 12    F,
 13    OrderBy,
 14    ReplaceableExpression,
 15)
 16from plain.postgres.indexes import IndexExpression
 17from plain.postgres.lookups import Exact
 18from plain.postgres.query_utils import Q
 19from plain.postgres.sql.query import Query
 20
 21if TYPE_CHECKING:
 22    from plain.postgres.base import Model
 23    from plain.postgres.schema import DatabaseSchemaEditor, Statement
 24
 25__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 26
 27
 28class BaseConstraint:
 29    default_violation_error_message = 'Constraint "%(name)s" is violated.'
 30    violation_error_code: str | None = None
 31    violation_error_message: str | None = None
 32
 33    def __init__(
 34        self,
 35        *,
 36        name: str,
 37        violation_error_code: str | None = None,
 38        violation_error_message: str | None = None,
 39    ) -> None:
 40        self.name = name
 41        if violation_error_code is not None:
 42            self.violation_error_code = violation_error_code
 43        if violation_error_message is not None:
 44            self.violation_error_message = violation_error_message
 45        else:
 46            self.violation_error_message = self.default_violation_error_message
 47
 48    @property
 49    def contains_expressions(self) -> bool:
 50        return False
 51
 52    def constraint_sql(
 53        self, model: type[Model], schema_editor: DatabaseSchemaEditor
 54    ) -> str | None:
 55        raise NotImplementedError(
 56            "subclasses of BaseConstraint must provide a constraint_sql() method"
 57        )
 58
 59    def create_sql(
 60        self, model: type[Model], schema_editor: DatabaseSchemaEditor
 61    ) -> str | Statement | None:
 62        raise NotImplementedError(
 63            "subclasses of BaseConstraint must provide a create_sql() method"
 64        )
 65
 66    def remove_sql(
 67        self, model: type[Model], schema_editor: DatabaseSchemaEditor
 68    ) -> str | Statement | None:
 69        raise NotImplementedError(
 70            "subclasses of BaseConstraint must provide a remove_sql() method"
 71        )
 72
 73    def validate(
 74        self, model: type[Model], instance: Model, exclude: set[str] | None = None
 75    ) -> None:
 76        raise NotImplementedError(
 77            "subclasses of BaseConstraint must provide a validate() method"
 78        )
 79
 80    def get_violation_error_message(self) -> str:
 81        assert self.violation_error_message is not None
 82        return self.violation_error_message % {"name": self.name}
 83
 84    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 85        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 86        path = path.replace("plain.postgres.constraints", "plain.postgres")
 87        kwargs: dict[str, Any] = {"name": self.name}
 88        if (
 89            self.violation_error_message is not None
 90            and self.violation_error_message != self.default_violation_error_message
 91        ):
 92            kwargs["violation_error_message"] = self.violation_error_message
 93        if self.violation_error_code is not None:
 94            kwargs["violation_error_code"] = self.violation_error_code
 95        return (path, (), kwargs)
 96
 97    def clone(self) -> BaseConstraint:
 98        _, args, kwargs = self.deconstruct()
 99        return self.__class__(*args, **kwargs)
100
101
class CheckConstraint(BaseConstraint):
    """Constraint expressed as a boolean condition.

    ``check`` must expose a truthy ``conditional`` attribute (the error
    message names a ``Q`` instance or boolean expression).
    """

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` and forward the remaining options to the base class.

        Raises:
            TypeError: if ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Compile ``self.check`` to SQL with its parameters quoted inline."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        compiler = query.get_compiler()
        template, params = where_node.as_sql(compiler, schema_editor.connection)
        # Substitute each placeholder with the schema editor's quoted literal.
        return template % tuple(map(schema_editor.quote_value, params))

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Return the result of ``schema_editor._check_sql()`` for this check."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check_sql)

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the result of ``schema_editor._create_check_sql()``."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check_sql)

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the schema editor's delete statement for this constraint."""
        delete_template = schema_editor.sql_delete_check
        return schema_editor._delete_constraint_sql(delete_template, model, self.name)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Evaluate the check against the instance's field values.

        ``exclude`` is forwarded to ``instance._get_field_value_map()``.

        Raises:
            ValidationError: if the check evaluates falsy.
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            if Q(self.check).check(field_values):
                return
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )
        except FieldError:
            # A FieldError during evaluation is deliberately ignored; the
            # constraint is then treated as not violated.
            return

    def __repr__(self) -> str:
        """Debug representation; code/message appear only when customised."""
        code_part = (
            ""
            if self.violation_error_code is None
            else f" violation_error_code={self.violation_error_code!r}"
        )
        message = self.violation_error_message
        if message is None or message == self.default_violation_error_message:
            message_part = ""
        else:
            message_part = f" violation_error_message={message!r}"
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        """Compare name, check, and violation code/message with another check constraint."""
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        kwargs.update(check=self.check)
        return path, args, kwargs
194
195
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # Renders as ``Deferrable.MEMBER``; a similar format was proposed for
    # Python 3.10.
    def __repr__(self) -> str:
        return "{}.{}".format(type(self).__qualname__, self.name)
203
204
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over either field names or expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given.
    String expressions are wrapped in ``F()`` at construction time.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store normalised attributes.

        Raises:
            ValueError: for a missing name, missing or conflicting
                fields/expressions, wrongly typed options, a ``deferrable``
                combined with condition/include/opclasses/expressions,
                ``opclasses`` combined with expressions, or an ``opclasses``
                length that differs from ``fields``.
        """
        # NOTE: the order of these checks decides which error is reported
        # when several rules are broken at once; keep it stable.
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not (expressions or fields):
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if condition is not None and not isinstance(condition, Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if deferrable is not None and not isinstance(deferrable, Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if include is not None and not isinstance(include, (list, tuple)):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(opclasses) != len(fields):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include or ())
        # Stored as given (list or tuple); converted to a tuple where used.
        self.opclasses = opclasses
        normalised = [
            F(item) if isinstance(item, str) else item for item in expressions
        ]
        self.expressions = tuple(normalised)
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """``True`` when the constraint was defined with expressions."""
        return True if self.expressions else False

    def _get_condition_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Compile ``self.condition`` to SQL, or return ``None`` if unset."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.condition)
        compiler = query.get_compiler()
        template, params = where_node.as_sql(compiler, schema_editor.connection)
        # Substitute each placeholder with the schema editor's quoted literal.
        return template % tuple(map(schema_editor.quote_value, params))

    def _get_index_expressions(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Any:
        """Return the resolved expression list, or ``None`` without expressions."""
        if not self.expressions:
            return None
        wrapped = [IndexExpression(expression) for expression in self.expressions]
        expression_list = ExpressionList(*wrapped)
        return expression_list.resolve_expression(Query(model, alias_cols=False))

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Return the result of ``schema_editor._unique_sql()`` for this constraint."""
        field_objects = [
            model._model_meta.get_forward_field(name) for name in self.fields
        ]
        include_columns = [
            model._model_meta.get_forward_field(name).column for name in self.include
        ]
        options: dict[str, Any] = {
            "condition": self._get_condition_sql(model, schema_editor),
            "expressions": self._get_index_expressions(model, schema_editor),
            "deferrable": self.deferrable,
            "include": include_columns,
            "opclasses": tuple(self.opclasses) if self.opclasses else None,
        }
        return schema_editor._unique_sql(model, field_objects, self.name, **options)

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the result of ``schema_editor._create_unique_sql()``."""
        field_objects = [
            model._model_meta.get_forward_field(name) for name in self.fields
        ]
        include_columns = [
            model._model_meta.get_forward_field(name).column for name in self.include
        ]
        options: dict[str, Any] = {
            "condition": self._get_condition_sql(model, schema_editor),
            "expressions": self._get_index_expressions(model, schema_editor),
            "deferrable": self.deferrable,
            "include": include_columns,
            "opclasses": tuple(self.opclasses) if self.opclasses else None,
        }
        return schema_editor._create_unique_sql(
            model, field_objects, self.name, **options
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the result of ``schema_editor._delete_unique_sql()``."""
        condition_sql = self._get_condition_sql(model, schema_editor)
        include_columns = [
            model._model_meta.get_forward_field(name).column for name in self.include
        ]
        options: dict[str, Any] = {
            "condition": condition_sql,
            "expressions": self._get_index_expressions(model, schema_editor),
            "deferrable": self.deferrable,
            "include": include_columns,
            "opclasses": tuple(self.opclasses) if self.opclasses else None,
        }
        return schema_editor._delete_unique_sql(model, self.name, **options)

    def __repr__(self) -> str:
        """Debug representation listing only the options that are set."""
        pieces: list[str] = []
        if self.fields:
            pieces.append(f" fields={self.fields!r}")
        if self.expressions:
            pieces.append(f" expressions={self.expressions!r}")
        pieces.append(f" name={self.name!r}")
        if self.condition is not None:
            pieces.append(f" condition={self.condition}")
        if self.deferrable is not None:
            pieces.append(f" deferrable={self.deferrable!r}")
        if self.include:
            pieces.append(f" include={self.include!r}")
        if self.opclasses:
            pieces.append(f" opclasses={self.opclasses!r}")
        if self.violation_error_code is not None:
            pieces.append(f" violation_error_code={self.violation_error_code!r}")
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            pieces.append(f" violation_error_message={message!r}")
        return "<{}:{}>".format(self.__class__.__qualname__, "".join(pieces))

    def __eq__(self, other: object) -> bool:
        """Compare every stored option with another unique constraint."""
        if not isinstance(other, UniqueConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.fields == other.fields
            and self.condition == other.condition
            and self.deferrable == other.deferrable
            and self.include == other.include
            and self.opclasses == other.opclasses
            and self.expressions == other.expressions
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; only truthy options are kept."""
        path, _, kwargs = super().deconstruct()
        for option in ("fields", "condition", "deferrable", "include", "opclasses"):
            value = getattr(self, option)
            if value:
                kwargs[option] = value
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Look for an existing row that conflicts with ``instance``.

        Validation is skipped when a constrained field is listed in
        ``exclude`` or holds ``None``, or when an ``F`` reference inside the
        expressions names an excluded field.

        Raises:
            ValidationError: if a conflicting row is found.
        """
        candidates = model.query
        if self.fields:
            filters = {}
            for name in self.fields:
                if exclude and name in exclude:
                    return
                field = model._model_meta.get_forward_field(name)
                value = getattr(instance, field.attname)
                if value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                filters[field.name] = value
            candidates = candidates.filter(**filters)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        flattened = expression.flatten()  # type: ignore[operator]
                        if any(
                            isinstance(node, F) and node.name in exclude
                            for node in flattened
                        ):
                            return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            field_values = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            replacements: dict[Any, Any] = {
                F(name): value for name, value in field_values.items()
            }
            lookups = []
            for expression in self.expressions:
                # Ignore ordering.
                target = (
                    expression.expression
                    if isinstance(expression, OrderBy)
                    else expression
                )
                lookups.append(
                    Exact(target, target.replace_expressions(replacements))
                )
            candidates = candidates.filter(*lookups)
        # An already-saved instance must not conflict with its own row.
        instance_id = instance.id
        if not instance._state.adding and instance_id is not None:
            candidates = candidates.exclude(id=instance_id)
        if self.condition:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                conflict = self.condition & Exists(candidates.filter(self.condition))
                if conflict.check(against):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # A FieldError during evaluation is deliberately ignored.
                pass
            return
        if not candidates.exists():
            return
        if self.expressions:
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )
        # When fields are defined, use the unique_error_message() for
        # backward compatibility.
        for constraint_model, constraints in instance.get_constraints():
            if any(constraint is self for constraint in constraints):
                raise ValidationError(
                    instance.unique_error_message(
                        constraint_model,
                        self.fields,
                    ),
                )