v0.146.0
  1from __future__ import annotations
  2
  3from enum import Enum
  4from types import NoneType
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.exceptions import ValidationError
  8from plain.postgres.constants import LOOKUP_SEP
  9from plain.postgres.ddl import (
 10    build_include_sql,
 11    compile_expression_sql,
 12    compile_index_expressions_sql,
 13    deferrable_sql,
 14)
 15from plain.postgres.dialect import quote_name
 16from plain.postgres.exceptions import FieldError
 17from plain.postgres.expressions import (
 18    Exists,
 19    F,
 20    OrderBy,
 21    ReplaceableExpression,
 22)
 23from plain.postgres.lookups import Exact
 24from plain.postgres.query_utils import Q
 25
 26if TYPE_CHECKING:
 27    from plain.postgres.base import Model
 28
 29__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
 30
 31
 32ViolationError = str | dict[str, Any] | list[Any] | ValidationError
 33
 34
 35class BaseConstraint:
 36    violation_error: ViolationError | None = None
 37
 38    def __init__(
 39        self,
 40        *,
 41        name: str,
 42        violation_error: ViolationError | None = None,
 43    ) -> None:
 44        self.name = name
 45        self.violation_error = violation_error
 46
 47    @property
 48    def contains_expressions(self) -> bool:
 49        return False
 50
 51    def to_sql(self, model: type[Model]) -> str:
 52        raise NotImplementedError(
 53            "subclasses of BaseConstraint must provide a to_sql() method"
 54        )
 55
 56    def validate(
 57        self, model: type[Model], instance: Model, exclude: set[str] | None = None
 58    ) -> None:
 59        raise NotImplementedError(
 60            "subclasses of BaseConstraint must provide a validate() method"
 61        )
 62
 63    def _build_violation_error(self) -> ValidationError:
 64        if self.violation_error is None:
 65            return ValidationError(f'Constraint "{self.name}" is violated.')
 66        if isinstance(self.violation_error, ValidationError):
 67            return self.violation_error
 68        return ValidationError(self.violation_error)
 69
 70    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 71        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
 72        path = path.replace("plain.postgres.constraints", "plain.postgres")
 73        kwargs: dict[str, Any] = {"name": self.name}
 74        if self.violation_error is not None:
 75            kwargs["violation_error"] = self.violation_error
 76        return (path, (), kwargs)
 77
 78    def clone(self) -> BaseConstraint:
 79        _, args, kwargs = self.deconstruct()
 80        return self.__class__(*args, **kwargs)
 81
 82
 83class CheckConstraint(BaseConstraint):
 84    def __init__(
 85        self,
 86        *,
 87        check: Q,
 88        name: str,
 89        violation_error: ViolationError | None = None,
 90    ) -> None:
 91        self.check = check
 92        if not getattr(check, "conditional", False):
 93            raise TypeError(
 94                "CheckConstraint.check must be a Q instance or boolean expression."
 95            )
 96        super().__init__(name=name, violation_error=violation_error)
 97
 98    def to_sql(self, model: type[Model], *, not_valid: bool = False) -> str:
 99        """Generate ALTER TABLE ADD CONSTRAINT CHECK SQL as a plain string."""
100        check = compile_expression_sql(model, self.check)
101        table = quote_name(model.model_options.db_table)
102        name = quote_name(self.name)
103        sql = f"ALTER TABLE {table} ADD CONSTRAINT {name} CHECK ({check})"
104        if not_valid:
105            sql += " NOT VALID"
106        return sql
107
108    def referenced_fields(self) -> set[str]:
109        """Top-level model field names referenced by `self.check`.
110
111        Walks lookup keys (`field__regex` → `field`), nested Q nodes, and
112        F-expressions in values or other source expressions.
113        """
114        fields: set[str] = set()
115
116        def visit(node: Any) -> None:
117            if isinstance(node, Q):
118                for child in node.children:
119                    visit(child)
120            elif isinstance(node, tuple) and len(node) == 2:
121                lookup, value = node
122                fields.add(lookup.split(LOOKUP_SEP, 1)[0])
123                visit(value)
124            elif isinstance(node, F):
125                fields.add(node.name.split(LOOKUP_SEP, 1)[0])
126            elif hasattr(node, "get_source_expressions"):
127                for sub in node.get_source_expressions():
128                    visit(sub)
129
130        visit(self.check)
131        return fields
132
133    def validate(
134        self, model: type[Model], instance: Model, exclude: set[str] | None = None
135    ) -> None:
136        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
137        # Skip the check entirely when any field referenced by `self.check` was
138        # excluded — the in-Python pipeline can't resolve a missing field's
139        # annotation, and surfacing a constraint violation here would just
140        # duplicate the field-level error that caused the exclusion.
141        if not self.referenced_fields().issubset(against):
142            return
143        try:
144            if not Q(self.check).check(against):
145                raise self._build_violation_error()
146        except FieldError:
147            pass
148
149    def __repr__(self) -> str:
150        return "<{}: check={} name={}{}>".format(
151            self.__class__.__qualname__,
152            self.check,
153            repr(self.name),
154            (
155                ""
156                if self.violation_error is None
157                else f" violation_error={self.violation_error!r}"
158            ),
159        )
160
161    def __eq__(self, other: object) -> bool:
162        if isinstance(other, CheckConstraint):
163            return (
164                self.name == other.name
165                and self.check == other.check
166                and self.violation_error == other.violation_error
167            )
168        return super().__eq__(other)
169
170    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
171        path, args, kwargs = super().deconstruct()
172        kwargs["check"] = self.check
173        return path, args, kwargs
174
175
class Deferrable(Enum):
    """Deferral mode of a UniqueConstraint, rendered to SQL by `deferrable_sql()`."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Show the qualified member name (e.g. "Deferrable.DEFERRED") instead
        # of Enum's default "<Deferrable.DEFERRED: 'deferred'>".
        return "{}.{}".format(type(self).__qualname__, self.name)
183
184
class UniqueConstraint(BaseConstraint):
    """A UNIQUE constraint over model fields or expressions.

    `to_sql()` emits ALTER TABLE ... ADD CONSTRAINT ... UNIQUE for plain
    field lists, and CREATE UNIQUE INDEX when a condition, INCLUDE columns,
    opclasses or expressions are involved (or when building concurrently).
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error: ViolationError | None = None,
    ) -> None:
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalize to a tuple like `fields` and `include`, so equality and
        # deconstruction don't depend on whether a list or tuple was passed.
        self.opclasses = tuple(opclasses)
        # Bare strings are shorthand for F() field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(name=name, violation_error=violation_error)

    @property
    def contains_expressions(self) -> bool:
        return bool(self.expressions)

    @property
    def is_partial(self) -> bool:
        return self.condition is not None

    @property
    def index_only(self) -> bool:
        """Whether PostgreSQL can only store this as a unique index, not a constraint.

        PostgreSQL rejects ALTER TABLE ADD CONSTRAINT UNIQUE USING INDEX for
        partial indexes, expression indexes, and indexes with non-default
        operator classes.
        """
        return bool(self.condition or self.expressions or self.opclasses)

    def to_sql(self, model: type[Model], *, concurrently: bool = False) -> str:
        """Generate CREATE UNIQUE INDEX or ALTER TABLE ADD CONSTRAINT UNIQUE SQL.

        With `concurrently=True` a CREATE UNIQUE INDEX CONCURRENTLY statement is
        always returned; `to_attach_sql()` can then attach it as a constraint.
        """
        table = quote_name(model.model_options.db_table)
        name = quote_name(self.name)
        condition = (
            compile_expression_sql(model, self.condition)
            if self.condition is not None
            else None
        )

        if self.expressions:
            columns_sql = compile_index_expressions_sql(model, self.expressions)
        else:
            col_parts = []
            for i, field_name in enumerate(self.fields):
                field = model._model_meta.get_forward_field(field_name)
                col = quote_name(field.column)
                if self.opclasses:
                    # __init__ guarantees one opclass per field.
                    col = f"{col} {self.opclasses[i]}"
                col_parts.append(col)
            columns_sql = ", ".join(col_parts)

        include_sql = build_include_sql(model, self.include)
        condition_sql = f" WHERE ({condition})" if condition else ""

        if concurrently:
            return f"CREATE UNIQUE INDEX CONCURRENTLY {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
        elif condition or self.include or self.opclasses or self.expressions:
            return f"CREATE UNIQUE INDEX {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
        else:
            return f"ALTER TABLE {table} ADD CONSTRAINT {name} UNIQUE ({columns_sql}){deferrable_sql(self.deferrable)}"

    def to_attach_sql(self, model: type[Model]) -> str:
        """Generate ALTER TABLE ADD CONSTRAINT UNIQUE USING INDEX SQL.

        Used after creating the unique index concurrently to attach it
        as a named constraint.
        """
        table = quote_name(model.model_options.db_table)
        name = quote_name(self.name)
        sql = f"ALTER TABLE {table} ADD CONSTRAINT {name} UNIQUE USING INDEX {name}"
        sql += deferrable_sql(self.deferrable)
        return sql

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error is None
                else f" violation_error={self.violation_error!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error == other.violation_error
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction; expressions become positional args."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        # Compare with None rather than truthiness so an empty Q() condition
        # survives a deconstruct()/clone() round trip.
        if self.condition is not None:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise a ValidationError if another row already matches `instance`.

        The check is skipped when a constrained field is excluded or (for
        field lists) holds NULL, and when a condition cannot be evaluated in
        Python (FieldError).
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # ty: ignore[call-non-callable]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Substitute the instance's values for F() references so each
            # expression can be compared against the stored rows.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        # An existing row must not conflict with itself.
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                raise self._build_unique_violation(instance, model)
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            # Violated only when the instance itself satisfies the condition
            # AND a matching row that also satisfies it exists.
            try:
                violated = (
                    self.condition & Exists(queryset.filter(self.condition))
                ).check(against)
            except FieldError:
                # Best-effort: the condition can't be evaluated in Python.
                return
            if violated:
                raise self._build_unique_violation(instance, model)

    def _build_unique_violation(
        self, instance: Model, model: type[Model]
    ) -> ValidationError:
        """Build the ValidationError for a unique violation.

        Single-field unique constraints route the error to that field via the
        dict form so it surfaces under the field rather than NON_FIELD_ERRORS.
        """
        single_field = self.fields[0] if len(self.fields) == 1 else None

        if self.violation_error is not None:
            err = self._build_violation_error()
            # Only auto-route flat errors. A ValidationError that already has
            # an error_dict (from dict-form input or a caller-built instance)
            # already declares its own field routing — don't override it.
            if single_field and not hasattr(err, "error_dict"):
                return ValidationError({single_field: [err]})
            return err

        if self.fields:
            err = instance.unique_error_message(model, self.fields)
            if single_field:
                return ValidationError({single_field: [err]})
            return err
        return ValidationError(f'Constraint "{self.name}" is violated.')
438        return ValidationError(f'Constraint "{self.name}" is violated.')