1from __future__ import annotations
2
3from enum import Enum
4from types import NoneType
5from typing import TYPE_CHECKING, Any
6
7from plain.exceptions import ValidationError
8from plain.postgres.constants import LOOKUP_SEP
9from plain.postgres.ddl import (
10 build_include_sql,
11 compile_expression_sql,
12 compile_index_expressions_sql,
13 deferrable_sql,
14)
15from plain.postgres.dialect import quote_name
16from plain.postgres.exceptions import FieldError
17from plain.postgres.expressions import (
18 Exists,
19 F,
20 OrderBy,
21 ReplaceableExpression,
22)
23from plain.postgres.lookups import Exact
24from plain.postgres.query_utils import Q
25
26if TYPE_CHECKING:
27 from plain.postgres.base import Model
28
__all__ = [
    "BaseConstraint",
    "CheckConstraint",
    "Deferrable",
    "UniqueConstraint",
]


# Accepted forms for a constraint's custom `violation_error`: a plain message
# string, a dict of errors, a list of errors, or a ready-made ValidationError.
ViolationError = (
    str
    | dict[str, Any]
    | list[Any]
    | ValidationError
)
33
34
class BaseConstraint:
    """Shared state and interface for named table constraints.

    Subclasses must implement `to_sql()` and `validate()`. This base class
    stores the constraint name and an optional custom violation error, and
    provides `deconstruct()` / `clone()` built from the constructor kwargs.
    """

    # Custom error used by validate(); None means a generic message naming
    # the constraint.
    violation_error: ViolationError | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error: ViolationError | None = None,
    ) -> None:
        self.name = name
        self.violation_error = violation_error

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions (never, here)."""
        return False

    def to_sql(self, model: type[Model]) -> str:
        """Return the DDL that creates this constraint on `model`'s table.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a to_sql() method"
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise a ValidationError if `instance` violates this constraint.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a validate() method"
        )

    def _build_violation_error(self) -> ValidationError:
        """Return a new ValidationError describing a violation of this constraint.

        A fresh instance is built on every call, even when `violation_error`
        is already a ValidationError. Re-raising one shared exception object
        keeps extending its `__traceback__`, which retains every frame it
        passed through. It would also share one mutable error object between
        unrelated validations.
        """
        if self.violation_error is None:
            return ValidationError(f'Constraint "{self.name}" is violated.')
        # ValidationError accepts another ValidationError (as well as str,
        # dict and list input) and copies its message/code/params or its
        # error_dict/error_list into the new instance.
        return ValidationError(self.violation_error)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return `(import_path, args, kwargs)` that recreate this constraint.

        The import path is shortened from `plain.postgres.constraints` to the
        `plain.postgres` package namespace.
        """
        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
        path = path.replace("plain.postgres.constraints", "plain.postgres")
        kwargs: dict[str, Any] = {"name": self.name}
        if self.violation_error is not None:
            kwargs["violation_error"] = self.violation_error
        return (path, (), kwargs)

    def clone(self) -> BaseConstraint:
        """Return a new constraint of the same class built from deconstruct()."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
81
82
class CheckConstraint(BaseConstraint):
    """A table CHECK constraint defined by a conditional expression.

    `check` is a Q object or any expression whose `conditional` attribute is
    truthy. `to_sql()` compiles it to SQL, and `validate()` evaluates it in
    Python against an instance's field values.
    """

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error: ViolationError | None = None,
    ) -> None:
        self.check = check
        # A CHECK clause needs a predicate; reject non-boolean expressions early.
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(name=name, violation_error=violation_error)

    def to_sql(self, model: type[Model], *, not_valid: bool = False) -> str:
        """Generate ALTER TABLE ADD CONSTRAINT CHECK SQL as a plain string.

        When `not_valid` is true, the statement is suffixed with NOT VALID.
        """
        check = compile_expression_sql(model, self.check)
        table = quote_name(model.model_options.db_table)
        name = quote_name(self.name)
        sql = f"ALTER TABLE {table} ADD CONSTRAINT {name} CHECK ({check})"
        if not_valid:
            sql += " NOT VALID"
        return sql

    def referenced_fields(self) -> set[str]:
        """Top-level model field names referenced by `self.check`.

        Walks lookup keys (`field__regex` → `field`), nested Q nodes, and
        F-expressions in values or other source expressions. Only the direct
        children of a Q node are treated as `(lookup, value)` pairs. Lookup
        values that are containers are searched element by element for
        F-expressions and never mistaken for lookup pairs. Examples are the
        2-tuple of a `__range` lookup and the sequence of an `__in` lookup.
        """
        fields: set[str] = set()

        def visit(node: Any) -> None:
            if isinstance(node, Q):
                for child in node.children:
                    if (
                        isinstance(child, tuple)
                        and len(child) == 2
                        and isinstance(child[0], str)
                    ):
                        # A `(lookup, value)` leaf of the Q tree.
                        lookup, value = child
                        fields.add(lookup.split(LOOKUP_SEP, 1)[0])
                        visit(value)
                    else:
                        visit(child)
            elif isinstance(node, F):
                fields.add(node.name.split(LOOKUP_SEP, 1)[0])
            elif isinstance(node, list | tuple | set | frozenset):
                # Container lookup values (`__in`, `__range`, ...): only look
                # for expressions inside them.
                for item in node:
                    visit(item)
            elif hasattr(node, "get_source_expressions"):
                for sub in node.get_source_expressions():
                    visit(sub)

        visit(self.check)
        return fields

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise the violation error if `instance` does not satisfy `self.check`.

        The check is evaluated in Python against the instance's field values
        (minus `exclude`). A FieldError raised during that evaluation is
        ignored, and the constraint is treated as not checkable here.
        """
        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
        # Skip the check entirely when any field referenced by `self.check` was
        # excluded — the in-Python pipeline can't resolve a missing field's
        # annotation, and surfacing a constraint violation here would just
        # duplicate the field-level error that caused the exclusion.
        if not self.referenced_fields().issubset(against):
            return
        try:
            if not Q(self.check).check(against):
                raise self._build_violation_error()
        except FieldError:
            pass

    def __repr__(self) -> str:
        return "<{}: check={} name={}{}>".format(
            self.__class__.__qualname__,
            self.check,
            repr(self.name),
            (
                ""
                if self.violation_error is None
                else f" violation_error={self.violation_error!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        """Compare name, check expression and violation error."""
        if isinstance(other, CheckConstraint):
            return (
                self.name == other.name
                and self.check == other.check
                and self.violation_error == other.violation_error
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the `check` expression."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs
174
175
class Deferrable(Enum):
    """Deferral mode accepted by `UniqueConstraint(deferrable=...)`."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as `Deferrable.DEFERRED` instead of Enum's default
        # `<Deferrable.DEFERRED: 'deferred'>`, so it reads like source code.
        cls_name = type(self).__qualname__
        return f"{cls_name}.{self.name}"
183
184
class UniqueConstraint(BaseConstraint):
    """A unique constraint (or unique index) over model fields or expressions.

    Exactly one of `fields` or positional `expressions` must be given. The
    constraint may also be partial (`condition`), include extra columns
    (`include`), use per-field operator classes (`opclasses`), or be
    deferrable. `__init__` rejects the combinations that cannot be expressed.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error: ViolationError | None = None,
    ) -> None:
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Store as a tuple, like `fields` and `include`, so that __eq__ and
        # deconstruct() don't depend on whether the caller passed a list.
        self.opclasses = tuple(opclasses)
        # Bare strings are shorthand for F() references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(name=name, violation_error=violation_error)

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions rather than fields."""
        return bool(self.expressions)

    @property
    def is_partial(self) -> bool:
        """Whether the constraint has a WHERE condition."""
        return self.condition is not None

    @property
    def index_only(self) -> bool:
        """Whether PostgreSQL can only store this as a unique index, not a constraint.

        PostgreSQL rejects ALTER TABLE ADD CONSTRAINT UNIQUE USING INDEX for
        partial indexes, expression indexes, and indexes with non-default
        operator classes.
        """
        return bool(self.condition or self.expressions or self.opclasses)

    def to_sql(self, model: type[Model], *, concurrently: bool = False) -> str:
        """Generate CREATE UNIQUE INDEX or ALTER TABLE ADD CONSTRAINT UNIQUE SQL.

        Emits CREATE UNIQUE INDEX CONCURRENTLY when `concurrently` is set.
        Otherwise it emits a plain CREATE UNIQUE INDEX when the constraint has
        a condition, include columns, opclasses or expressions. In every other
        case it emits ALTER TABLE ... ADD CONSTRAINT ... UNIQUE followed by
        `deferrable_sql(self.deferrable)`.
        """
        table = quote_name(model.model_options.db_table)
        name = quote_name(self.name)
        condition = (
            compile_expression_sql(model, self.condition)
            if self.condition is not None
            else None
        )

        if self.expressions:
            columns_sql = compile_index_expressions_sql(model, self.expressions)
        else:
            col_parts = [
                quote_name(model._model_meta.get_forward_field(field_name).column)
                for field_name in self.fields
            ]
            if self.opclasses:
                # __init__ guarantees exactly one opclass per field.
                col_parts = [
                    f"{col} {opclass}"
                    for col, opclass in zip(col_parts, self.opclasses)
                ]
            columns_sql = ", ".join(col_parts)

        include_sql = build_include_sql(model, self.include)
        condition_sql = f" WHERE ({condition})" if condition else ""

        if concurrently:
            return f"CREATE UNIQUE INDEX CONCURRENTLY {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
        elif condition or self.include or self.opclasses or self.expressions:
            return f"CREATE UNIQUE INDEX {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
        else:
            return f"ALTER TABLE {table} ADD CONSTRAINT {name} UNIQUE ({columns_sql}){deferrable_sql(self.deferrable)}"

    def to_attach_sql(self, model: type[Model]) -> str:
        """Generate ALTER TABLE ADD CONSTRAINT UNIQUE USING INDEX SQL.

        Used after creating the unique index concurrently to attach it
        as a named constraint.
        """
        table = quote_name(model.model_options.db_table)
        name = quote_name(self.name)
        sql = f"ALTER TABLE {table} ADD CONSTRAINT {name} UNIQUE USING INDEX {name}"
        sql += deferrable_sql(self.deferrable)
        return sql

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error is None
                else f" violation_error={self.violation_error!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        """Compare every constructor option, including the violation error."""
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error == other.violation_error
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return the import path, expressions as args, and non-empty options."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise a ValidationError if another row would clash with `instance`.

        Builds a queryset of other rows that have the same field values or the
        same expression values. The instance's own row is excluded when it is
        not being added. The method returns early when a constrained field is
        excluded, or when a field value is None. For conditional constraints,
        the condition is evaluated against the instance's field values, and a
        FieldError raised by that evaluation is ignored.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # ty: ignore[call-non-callable]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Substitute the instance's values for field references so each
            # expression can be compared against the stored rows.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                raise self._build_unique_violation(instance, model)
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise self._build_unique_violation(instance, model)
            except FieldError:
                pass

    def _build_unique_violation(
        self, instance: Model, model: type[Model]
    ) -> ValidationError:
        """Build the ValidationError for a unique violation.

        Single-field unique constraints route the error to that field via the
        dict form so it surfaces under the field rather than NON_FIELD_ERRORS.
        """
        single_field = self.fields[0] if len(self.fields) == 1 else None

        if self.violation_error is not None:
            err = self._build_violation_error()
            # Only auto-route flat errors. A ValidationError that already has
            # an error_dict (from dict-form input or a caller-built instance)
            # already declares its own field routing — don't override it.
            if single_field and not hasattr(err, "error_dict"):
                return ValidationError({single_field: [err]})
            return err

        if self.fields:
            err = instance.unique_error_message(model, self.fields)
            if single_field:
                return ValidationError({single_field: [err]})
            return err
        return ValidationError(f'Constraint "{self.name}" is violated.')