1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from enum import Enum
5from types import NoneType
6from typing import TYPE_CHECKING, Any
7
8from plain.exceptions import ValidationError
9from plain.models.exceptions import FieldError
10from plain.models.expressions import (
11 Exists,
12 ExpressionList,
13 F,
14 OrderBy,
15 ReplaceableExpression,
16)
17from plain.models.indexes import IndexExpression
18from plain.models.lookups import Exact
19from plain.models.query_utils import Q
20from plain.models.sql.query import Query
21
22if TYPE_CHECKING:
23 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
24 from plain.models.backends.ddl_references import Statement
25 from plain.models.base import Model
26
# Names exported by ``from <this module> import *``.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
28
29
class BaseConstraint(ABC):
    """Abstract base for model-level database constraints.

    Subclasses provide the DDL hooks (``constraint_sql``, ``create_sql``,
    ``remove_sql``) and the Python-side ``validate``. This base keeps the
    constraint name plus the violation error code/message, and implements
    ``deconstruct()`` / ``clone()`` in terms of the constructor keywords.
    """

    # Fallback message when none is passed; ``%(name)s`` is interpolated with
    # the constraint name by get_violation_error_message().
    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    # Class-level defaults; __init__ may shadow them with instance attributes.
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        self.name = name
        # Only shadow the class-level ``None`` when a code was supplied.
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        # The message is always stored on the instance, falling back to the
        # (possibly subclass-overridden) default template.
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions (never, at this level)."""
        return False

    @abstractmethod
    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the inline SQL fragment for this constraint."""

    @abstractmethod
    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement that adds this constraint to ``model``'s table."""

    @abstractmethod
    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement that drops this constraint from ``model``'s table."""

    @abstractmethod
    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` in Python; subclasses raise ``ValidationError`` on violation."""

    def get_violation_error_message(self) -> str:
        """Return the violation message with the constraint name interpolated."""
        template = self.violation_error_message
        assert template is not None
        return template % {"name": self.name}

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import_path, args, kwargs)`` that reconstruct this constraint.

        The import path is shortened from ``plain.models.constraints`` to
        ``plain.models``. The message is emitted only when it differs from the
        default template, and the code only when one is set.
        """
        cls = type(self)
        import_path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.models.constraints", "plain.models"
        )
        options: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            options["violation_error_message"] = message
        if self.violation_error_code is not None:
            options["violation_error_code"] = self.violation_error_code
        return (import_path, (), options)

    def clone(self) -> BaseConstraint:
        """Return a new, equivalent constraint rebuilt from ``deconstruct()``."""
        _import_path, positional, keywords = self.deconstruct()
        return type(self)(*positional, **keywords)
94
95
class CheckConstraint(BaseConstraint):
    """A database CHECK constraint built from a conditional expression.

    ``check`` is compiled to SQL for DDL and is also evaluated in Python by
    ``validate()`` against the instance's field values.
    """

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        self.check = check
        # Accept anything that advertises a truthy ``conditional`` attribute
        # (Q objects and boolean expressions).
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Compile ``self.check`` to SQL with its parameters inlined as literals."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        compiler = query.get_compiler()
        sql, params = where_node.as_sql(compiler, schema_editor.connection)
        # DDL cannot take bound parameters, so quote them into the SQL text.
        quoted_params = tuple(map(schema_editor.quote_value, params))
        return sql % quoted_params

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the inline CHECK clause for this constraint."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check_sql)

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this CHECK constraint."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check_sql)

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this CHECK constraint."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if ``instance``'s values fail ``self.check``.

        A ``FieldError`` raised while evaluating the check is swallowed and
        validation is skipped (presumably when the check references values
        missing from the field map, e.g. excluded fields — TODO confirm).
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            satisfied = Q(self.check).check(field_values)
        except FieldError:
            return
        if not satisfied:
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )

    def __repr__(self) -> str:
        suffix = ""
        if self.violation_error_code is not None:
            suffix += f" violation_error_code={self.violation_error_code!r}"
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            suffix += f" violation_error_message={message!r}"
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{suffix}>"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        return path, args, {**kwargs, "check": self.check}
186
187
class Deferrable(Enum):
    """Deferral mode of a constraint, forwarded to the schema editor as-is."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as ``Deferrable.MEMBER`` instead of Enum's default
        # ``<Deferrable.MEMBER: 'value'>`` form.
        owner = type(self).__qualname__
        return f"{owner}.{self.name}"
195
196
class UniqueConstraint(BaseConstraint):
    """A UNIQUE constraint over model fields or over expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given
    (string expressions are wrapped in ``F()``). ``condition``,
    ``deferrable``, ``include`` and ``opclasses`` are forwarded to the schema
    editor when generating DDL; incompatible combinations are rejected in
    ``__init__``.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store the constraint definition.

        Raises:
            ValueError: if ``name`` is missing; if neither or both of
                ``fields``/``expressions`` are given; if an option has the
                wrong type; or if an option is combined with one it cannot be
                used with (e.g. ``deferrable`` with ``condition``).
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        # Deferrable constraints cannot be combined with any of the options
        # below.
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        # opclasses pair positionally with fields.
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether this constraint is defined by expressions rather than fields."""
        return bool(self.expressions)

    def _get_fields(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.fields`` names to field objects on ``model``."""
        return [
            model._model_meta.get_forward_field(field_name)
            for field_name in self.fields
        ]

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Resolve ``self.include`` names to their database column names."""
        return [
            model._model_meta.get_forward_field(field_name).column
            for field_name in self.include
        ]

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``self.condition`` to SQL with parameters inlined, or ``None``."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # DDL cannot take bound parameters, so quote them into the SQL text.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Resolve ``self.expressions`` into an ExpressionList, or ``None``."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the inline UNIQUE clause for this constraint."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this UNIQUE constraint."""
        fields = self._get_fields(model)
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this UNIQUE constraint."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=tuple(self.opclasses) if self.opclasses else None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction; expressions become positional args."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if another row collides with ``instance``.

        Filters ``model.query`` down to rows whose ``fields`` (or
        ``expressions``) match the instance's values. The instance's own row
        is excluded when it is not being added. Without a ``condition``, any
        match is a violation. With one, a violation requires the instance to
        satisfy the condition and a matching row satisfying it to exist.

        Returns without checking when a constrained field is in ``exclude``.
        In the ``fields`` form, it also returns without checking when any
        constrained value is ``None``.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[call-non-callable]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Substitute the instance's values for F() references so each
            # expression can be compared against the stored rows.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model,
                                    self.fields,
                                ),
                            )
                # This exact constraint object is not registered on the
                # instance (e.g. a clone() or an ad-hoc constraint), so there
                # is no owning model for unique_error_message(). The duplicate
                # still exists, so report it rather than silently passing.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # Best-effort: skip when the condition cannot be evaluated
                # against the available field values.
                pass