1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from enum import Enum
5from types import NoneType
6from typing import TYPE_CHECKING, Any
7
8from plain.exceptions import ValidationError
9from plain.models.exceptions import FieldError
10from plain.models.expressions import (
11 Exists,
12 ExpressionList,
13 F,
14 OrderBy,
15 ReplaceableExpression,
16)
17from plain.models.indexes import IndexExpression
18from plain.models.lookups import Exact
19from plain.models.query_utils import Q
20from plain.models.sql.query import Query
21
22if TYPE_CHECKING:
23 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
24 from plain.models.backends.ddl_references import Statement
25 from plain.models.base import Model
26
# Public names of this module (what ``from ... import *`` exports).
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
28
29
class BaseConstraint(ABC):
    """Abstract base class for named model constraints.

    Concrete subclasses provide the SQL used to inline, create and drop the
    constraint, plus ``validate()`` for Python-side checking of an instance.
    This base stores the constraint ``name`` and the optional violation
    error code/message reported when validation fails.
    """

    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        self.name = name
        # Only shadow the class-level ``violation_error_code`` when one is given.
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        # The message is always stored on the instance, defaulting to the
        # class-level template.
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined with expressions (never, here)."""
        return False

    @abstractmethod
    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return SQL for this constraint as part of a table definition."""

    @abstractmethod
    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement that adds this constraint to ``model``'s table."""

    @abstractmethod
    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement that drops this constraint from ``model``'s table."""

    @abstractmethod
    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` against the constraint, raising on violation."""

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` filled in."""
        template = self.violation_error_message
        assert template is not None
        return template % {"name": self.name}

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import path, args, kwargs)`` that rebuild this constraint.

        The path is shortened from ``plain.models.constraints`` to
        ``plain.models``; the message and code are only included when they
        differ from the defaults.
        """
        cls = type(self)
        path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.models.constraints", "plain.models"
        )
        kwargs: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        code = self.violation_error_code
        if code is not None:
            kwargs["violation_error_code"] = code
        return (path, (), kwargs)

    def clone(self) -> BaseConstraint:
        """Return a new instance of the same class rebuilt via ``deconstruct()``."""
        _path, positional, keywords = self.deconstruct()
        return type(self)(*positional, **keywords)
94
95
class CheckConstraint(BaseConstraint):
    """Constraint requiring a boolean condition (``check``) to hold.

    ``check`` must be conditional — a ``Q`` object or a boolean expression —
    otherwise construction raises ``TypeError``.
    """

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Compile ``check`` to SQL with parameters inlined as quoted literals."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        sql, params = where_node.as_sql(query.get_compiler(), schema_editor.connection)
        quoted = tuple(map(schema_editor.quote_value, params))
        return sql % quoted

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the CHECK clause produced by ``schema_editor._check_sql``."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check_sql)

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this check constraint to ``model``."""
        check_sql = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check_sql)

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this check constraint from ``model``."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if ``instance``'s values fail ``check``.

        A ``FieldError`` while evaluating the check is swallowed and the
        instance is treated as valid (best-effort validation).
        """
        against = instance._get_field_value_map(meta=model._model_meta, exclude=exclude)
        try:
            satisfied = Q(self.check).check(against)
        except FieldError:
            return
        if not satisfied:
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )

    def __repr__(self) -> str:
        parts = [f"check={self.check}", f"name={self.name!r}"]
        if self.violation_error_code is not None:
            parts.append(f"violation_error_code={self.violation_error_code!r}")
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            parts.append(f"violation_error_message={message!r}")
        return f"<{type(self).__qualname__}: {' '.join(parts)}>"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        return path, args, {**kwargs, "check": self.check}
186
187
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Show the qualified member name (``Deferrable.DEFERRED``) instead of
        # Enum's default ``<Deferrable.DEFERRED: 'deferred'>``.
        return f"{type(self).__qualname__}.{self.name}"
195
196
class UniqueConstraint(BaseConstraint):
    """Require a set of fields or expressions to be unique.

    Exactly one of ``fields`` (field names) or positional ``expressions``
    (strings are wrapped in ``F()``) must be given. ``condition`` (a ``Q``)
    restricts the constraint to matching rows; ``include`` names extra fields
    whose columns are passed to the schema editor as ``include``;
    ``opclasses`` gives one operator class per entry in ``fields``;
    ``deferrable`` is a ``Deferrable`` member. ``__init__`` rejects the option
    combinations that are not supported.
    """

    expressions: tuple[ReplaceableExpression, ...]

    def __init__(
        self,
        *expressions: str | ReplaceableExpression,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the options and store them in normalized form.

        Raises:
            ValueError: if ``name`` is missing; if neither or both of
                ``fields`` and ``expressions`` are given; if ``condition``,
                ``deferrable``, ``include`` or ``opclasses`` has the wrong
                type; if an unsupported option combination is used; or if
                ``opclasses`` and ``fields`` differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        # Type checks run before the combination checks below so that a
        # wrongly-typed argument is reported as such, rather than surfacing as
        # a misleading "cannot be deferred" / "cannot be used with" error.
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Stored as a tuple, like ``fields`` and ``include``, so that __eq__
        # and deconstruct() don't depend on whether a list or tuple was passed.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is defined by expressions instead of fields."""
        return bool(self.expressions)

    def _resolve_fields(self, model: type[Model]) -> list[Any]:
        """Return the forward field objects named in ``self.fields``."""
        meta = model._model_meta
        return [meta.get_forward_field(field_name) for field_name in self.fields]

    def _include_columns(self, model: type[Model]) -> list[Any]:
        """Return the column names of the fields named in ``self.include``."""
        meta = model._model_meta
        return [
            meta.get_forward_field(field_name).column for field_name in self.include
        ]

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``condition`` to SQL with inlined quoted params, or ``None``."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Resolve ``expressions`` against ``model``, or ``None`` if there are none.

        Each expression is wrapped in ``IndexExpression`` and the resulting
        ``ExpressionList`` is resolved against a fresh ``Query`` for ``model``.
        """
        if not self.expressions:
            return None
        index_expressions = [IndexExpression(expr) for expr in self.expressions]
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the UNIQUE clause produced by ``schema_editor._unique_sql``."""
        fields = self._resolve_fields(model)
        include = self._include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that adds this unique constraint to ``model``."""
        fields = self._resolve_fields(model)
        include = self._include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement that drops this unique constraint from ``model``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the non-default options.

        Expressions are returned as positional args; other options are only
        included when truthy.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if another row already has ``instance``'s values.

        Validation is skipped (returns silently) when a constrained field or
        an ``F()`` referenced by an expression is in ``exclude``, or when a
        constrained field's value is ``None``. When ``instance`` is not being
        added, its own row (by ``id``) is excluded from the lookup. With a
        ``condition``, a ``FieldError`` during evaluation is swallowed.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_forward_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():  # type: ignore[operator]
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) to the instance's current value so expressions
            # can be evaluated against this instance.
            replacements: dict[Any, Any] = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            # Don't let an existing row collide with itself.
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model,
                                    self.fields,
                                ),
                            )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                # Violation only if the instance itself matches the condition
                # AND a conflicting row that also matches it exists.
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass