1from __future__ import annotations
2
3from enum import Enum
4from types import NoneType
5from typing import TYPE_CHECKING, Any
6
7from plain.exceptions import ValidationError
8from plain.postgres.exceptions import FieldError
9from plain.postgres.expressions import (
10 Exists,
11 ExpressionList,
12 F,
13 OrderBy,
14 ReplaceableExpression,
15)
16from plain.postgres.indexes import IndexExpression
17from plain.postgres.lookups import Exact
18from plain.postgres.query_utils import Q
19from plain.postgres.sql.query import Query
20
21if TYPE_CHECKING:
22 from plain.postgres.base import Model
23 from plain.postgres.schema import DatabaseSchemaEditor, Statement
24
25__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
26
27
class BaseConstraint:
    """Common base class for model constraints.

    Holds the constraint ``name`` together with the optional error code and
    error message used when reporting a violation. The SQL hooks
    (``constraint_sql``, ``create_sql``, ``remove_sql``) and ``validate`` all
    raise ``NotImplementedError`` here; subclasses supply them.
    """

    # Template used when no explicit message is given; ``%(name)s`` is filled
    # in with the constraint name by get_violation_error_message().
    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store the name and the optional violation code/message.

        A ``None`` message falls back to ``default_violation_error_message``;
        a ``None`` code leaves the class-level default in place.
        """
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is built from expressions (never, here)."""
        return False

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a constraint_sql() method"
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a create_sql() method"
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | Statement | None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a remove_sql() method"
        )

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError(
            "subclasses of BaseConstraint must provide a validate() method"
        )

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` interpolated."""
        template = self.violation_error_message
        assert template is not None
        return template % {"name": self.name}

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, args, kwargs)`` sufficient to rebuild the constraint.

        The message is included only when it differs from the default, and
        the code only when it is set.
        """
        cls = self.__class__
        dotted = f"{cls.__module__}.{cls.__name__}"
        # Expose the shorter package-level import path.
        path = dotted.replace("plain.postgres.constraints", "plain.postgres")

        kwargs: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        code = self.violation_error_code
        if code is not None:
            kwargs["violation_error_code"] = code
        return (path, (), kwargs)

    def clone(self) -> BaseConstraint:
        """Build a new instance of the same class from deconstruct() output."""
        _path, positional, keywords = self.deconstruct()
        return self.__class__(*positional, **keywords)
100
101
class CheckConstraint(BaseConstraint):
    """Constraint expressed as a boolean condition stored in ``check``."""

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` and reject values not flagged as ``conditional``.

        Raises:
            TypeError: if ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Compile ``check`` to SQL with its parameters quoted inline."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        compiler = query.get_compiler()
        template, params = where_node.as_sql(compiler, schema_editor.connection)
        # Parameters are interpolated into the SQL text via quote_value().
        quoted = tuple(schema_editor.quote_value(param) for param in params)
        return template % quoted

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str:
        """Return the check SQL produced by the schema editor's ``_check_sql``."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this check constraint to ``model``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this check constraint."""
        template = schema_editor.sql_delete_check
        return schema_editor._delete_constraint_sql(template, model, self.name)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Evaluate ``check`` against the instance's field values.

        Raises ``ValidationError`` when the check evaluates falsy. A
        ``FieldError`` raised during evaluation is ignored, so the constraint
        is skipped in that case.
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            passed = Q(self.check).check(field_values)
            if not passed:
                raise ValidationError(
                    self.get_violation_error_message(), code=self.violation_error_code
                )
        except FieldError:
            pass

    def __repr__(self) -> str:
        # The code and message parts are only shown when customised.
        if self.violation_error_code is None:
            code_part = ""
        else:
            code_part = f" violation_error_code={self.violation_error_code!r}"

        message = self.violation_error_message
        if message is None or message == self.default_violation_error_message:
            message_part = ""
        else:
            message_part = f" violation_error_message={message!r}"

        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        """Compare name, check, error code and error message."""
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs
194
195
class Deferrable(Enum):
    """Values accepted by the ``deferrable`` argument of ``UniqueConstraint``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self) -> str:
        # Render as "Deferrable.DEFERRED" (class qualname + member name).
        return ".".join((type(self).__qualname__, self.name))
203
204
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over named fields or over expressions.

    Exactly one of ``fields`` (keyword) or ``*expressions`` (positional) must
    be supplied. String expressions are wrapped in ``F()``. ``fields``,
    ``include`` and ``opclasses`` are stored as tuples so that list- and
    tuple-based declarations compare equal.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the argument combination and store normalized values.

        Raises:
            ValueError: if ``name`` is empty; if neither or both of
                ``fields``/``expressions`` are given; if an argument has the
                wrong type; if ``deferrable`` is combined with ``condition``,
                ``include``, ``opclasses`` or expressions; if ``opclasses`` is
                combined with expressions; or if ``opclasses`` and ``fields``
                differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalize to a tuple, like ``fields`` and ``include``; otherwise
        # __eq__() would treat opclasses=["x"] and opclasses=("x",) as
        # different constraints.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the constraint was declared with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Compile ``condition`` to SQL with quoted parameters, or None."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Any:
        """Wrap each expression in IndexExpression and resolve them as a list.

        Returns None when the constraint has no expressions.
        """
        if not self.expressions:
            return None
        index_expressions = [
            IndexExpression(expression) for expression in self.expressions
        ]
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def _get_field_objects(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.fields`` names to the model's forward field objects."""
        return [
            model._model_meta.get_forward_field(field_name)
            for field_name in self.fields
        ]

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.include`` names to the columns of their fields."""
        return [
            model._model_meta.get_forward_field(field_name).column
            for field_name in self.include
        ]

    def constraint_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> str | None:
        """Return the constraint SQL built by ``schema_editor._unique_sql``."""
        fields = self._get_field_objects(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement built by ``schema_editor._create_unique_sql``."""
        fields = self._get_field_objects(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: DatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement built by ``schema_editor._delete_unique_sql``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        # Optional parts are omitted when unset or left at their defaults.
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={self.fields!r}",
            "" if not self.expressions else f" expressions={self.expressions!r}",
            f" name={self.name!r}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={self.include!r}",
            "" if not self.opclasses else f" opclasses={self.opclasses!r}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        """Compare every stored attribute, including error code and message."""
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; unset options are omitted."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` against existing rows and raise on a conflict.

        Validation is skipped (returns without error) when any constrained
        field or referenced ``F()`` name is in ``exclude``, or when a
        constrained field value is None. A ``FieldError`` raised while
        evaluating a conditional constraint is ignored.

        Raises:
            ValidationError: if a conflicting row exists.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[operator]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) to the instance's value so every expression
            # can be compared against its own evaluation on this instance.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        # Do not let an already-saved instance conflict with its own row.
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                # NOTE(review): this path does not use violation_error_code /
                # violation_error_message, and raises nothing if ``self`` is
                # not found among instance.get_constraints().
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model,
                                    self.fields,
                                ),
                            )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                # Violated only if the instance itself matches the condition
                # and a conflicting row matching the condition exists.
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass