1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from enum import Enum
5from types import NoneType
6from typing import TYPE_CHECKING, Any
7
8from plain.exceptions import ValidationError
9from plain.models.exceptions import FieldError
10from plain.models.expressions import Exists, ExpressionList, F, OrderBy
11from plain.models.indexes import IndexExpression
12from plain.models.lookups import Exact
13from plain.models.query_utils import Q
14from plain.models.sql.query import Query
15
16if TYPE_CHECKING:
17 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
18 from plain.models.backends.ddl_references import Statement
19 from plain.models.base import Model
20 from plain.models.expressions import Combinable
21
# Public names exported by ``from <this module> import *``.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
23
24
class BaseConstraint(ABC):
    """Abstract base class for named, model-level database constraints.

    Subclasses produce the SQL used to declare, create and drop the
    constraint through a schema editor, and implement ``validate()`` to check
    a model instance in Python. Every constraint carries an optional error
    code and a message template that may reference the constraint name as
    ``%(name)s``.
    """

    # Template used when the caller does not supply a violation message.
    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    # Class-level fallbacks; ``__init__`` may override them per instance.
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store the constraint name and its violation error settings.

        Args:
            name: Name of the constraint.
            violation_error_code: Overrides the class-level code only when
                not None.
            violation_error_message: Message template; when None, the
                ``default_violation_error_message`` is used instead.
        """
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is expression-based (False unless overridden)."""
        return False

    @abstractmethod
    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return this constraint's SQL fragment, built via ``schema_editor``."""

    @abstractmethod
    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement (from ``schema_editor``) creating this constraint."""

    @abstractmethod
    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | Statement | None:
        """Return the statement (from ``schema_editor``) removing this constraint."""

    @abstractmethod
    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` when ``instance`` violates the constraint.

        ``exclude`` names fields that should not take part in validation.
        """

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` substituted."""
        template = self.violation_error_message
        return template % {"name": self.name}  # type: ignore[operator]

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(import_path, args, kwargs)`` that rebuild this constraint.

        The import path is shortened from ``plain.models.constraints`` to
        ``plain.models``. Keyword arguments equal to their defaults are
        omitted.
        """
        cls = type(self)
        import_path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.models.constraints", "plain.models"
        )
        options: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            options["violation_error_message"] = message
        if self.violation_error_code is not None:
            options["violation_error_code"] = self.violation_error_code
        return (import_path, (), options)

    def clone(self) -> BaseConstraint:
        """Return a new, equivalent constraint built from ``deconstruct()``."""
        _import_path, positional, keywords = self.deconstruct()
        return type(self)(*positional, **keywords)
88
89
class CheckConstraint(BaseConstraint):
    """A CHECK constraint whose condition is a ``Q`` or boolean expression.

    The condition is compiled to SQL for schema operations and evaluated in
    Python against an instance's field values by ``validate()``.
    """

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` plus the base constraint settings.

        Raises:
            TypeError: If ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Compile ``check`` to SQL with its parameters quoted inline."""
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        sql, params = where_node.as_sql(query.get_compiler(), schema_editor.connection)
        quoted_params = tuple(map(schema_editor.quote_value, params))
        return sql % quoted_params

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str:
        """Return the CHECK fragment produced by ``schema_editor._check_sql``."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._create_check_sql``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._delete_check_sql``."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if ``instance`` does not satisfy ``check``.

        The field values come from ``instance._get_field_value_map()`` with
        ``exclude`` removed. A ``FieldError`` during evaluation skips the
        check; presumably it signals that ``check`` refers to a field that is
        missing from those values (e.g. excluded).
        """
        values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            if Q(self.check).check(values):
                return
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )
        except FieldError:
            pass

    def __repr__(self) -> str:
        code_part = (
            ""
            if self.violation_error_code is None
            else f" violation_error_code={self.violation_error_code!r}"
        )
        message = self.violation_error_message
        message_part = (
            ""
            if message is None or message == self.default_violation_error_message
            else f" violation_error_message={message!r}"
        )
        return (
            f"<{type(self).__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        return path, args, {**kwargs, "check": self.check}
180
181
class Deferrable(Enum):
    """Deferral mode accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Render as ``Deferrable.DEFERRED`` instead of Enum's default
        # ``<Deferrable.DEFERRED: 'deferred'>`` form.
        return f"{type(self).__qualname__}.{self.name}"
189
190
class UniqueConstraint(BaseConstraint):
    """A uniqueness constraint over model fields or over expressions.

    Exactly one of the positional ``expressions`` or the ``fields`` keyword
    must be given; string expressions are wrapped in ``F()``. ``deferrable``
    cannot be combined with ``condition``, ``include``, ``opclasses`` or
    expressions, and ``opclasses`` cannot be combined with expressions.
    """

    def __init__(
        self,
        *expressions: str | Combinable,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store normalized values.

        Args:
            *expressions: Expressions (or field-name strings, wrapped in
                ``F()``) whose combined values must be unique.
            fields: Field names whose combined values must be unique.
            name: Constraint name; required.
            condition: Optional ``Q``; validation only applies when the
                instance and the conflicting rows match it.
            deferrable: Optional ``Deferrable`` mode.
            include: Field names whose columns are passed to the schema
                editor as ``include``.
            opclasses: Operator class names, one per entry in ``fields``.
            violation_error_code: Code attached to ``ValidationError``s that
                use the generic violation message.
            violation_error_message: Message template; ``%(name)s`` is
                replaced by the constraint name.

        Raises:
            ValueError: If arguments are missing, mutually exclusive or of
                the wrong type.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalize to a tuple like ``fields``/``include`` so that __eq__,
        # deconstruct() and repr don't depend on list-vs-tuple input, and the
        # caller's (possibly mutable) list isn't shared with the constraint.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,  # type: ignore[arg-type]
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """True when the constraint is defined by expressions, not fields."""
        return bool(self.expressions)

    def _get_condition_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Compile ``condition`` to SQL with quoted params, or None if unset."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_include_columns(self, model: type[Model]) -> list[Any]:
        """Return the column of each field named in ``include``."""
        return [
            model._model_meta.get_field(field_name).column
            for field_name in self.include
        ]

    def _get_index_expressions(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Any:
        """Resolve ``expressions`` into an ``ExpressionList``, or None if unset."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def constraint_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> str | None:
        """Return the fragment produced by ``schema_editor._unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def create_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._create_unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def remove_sql(
        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
    ) -> Statement | None:
        """Return the statement from ``schema_editor._delete_unique_sql``."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses or None,
            expressions=expressions,
        )

    def __repr__(self) -> str:
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)``; unset options are omitted."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def _filter_on_fields(
        self, model: type[Model], instance: Model, exclude: set[str] | None
    ) -> Any:
        """Return rows sharing ``instance``'s ``fields`` values, or None to skip."""
        lookup_kwargs: dict[str, Any] = {}
        for field_name in self.fields:
            if exclude and field_name in exclude:
                return None
            field = model._model_meta.get_field(field_name)
            lookup_value = getattr(instance, field.attname)
            if lookup_value is None:
                # A composite constraint containing NULL value cannot cause
                # a violation since NULL != NULL in SQL.
                return None
            lookup_kwargs[field.name] = lookup_value
        return model.query.filter(**lookup_kwargs)

    def _filter_on_expressions(
        self, model: type[Model], instance: Model, exclude: set[str] | None
    ) -> Any:
        """Return rows whose ``expressions`` equal ``instance``'s, or None to skip."""
        # Ignore constraints that reference excluded fields.
        if exclude:
            for expression in self.expressions:
                if hasattr(expression, "flatten"):
                    if any(
                        isinstance(expr, F) and expr.name in exclude
                        for expr in expression.flatten()
                    ):
                        return None
                elif isinstance(expression, F) and expression.name in exclude:
                    return None
        # Substitute each F(field) with the instance's value so every
        # expression can be compared against its stored counterpart.
        replacements = {
            F(field): value
            for field, value in instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            ).items()
        }
        lookups = []
        for expr in self.expressions:
            # Ignore ordering.
            if isinstance(expr, OrderBy):
                expr = expr.expression
            lookups.append(Exact(expr, expr.replace_expressions(replacements)))
        return model.query.filter(*lookups)

    def validate(
        self, model: type[Model], instance: Model, exclude: set[str] | None = None
    ) -> None:
        """Raise ``ValidationError`` if another row duplicates ``instance``.

        Validation is skipped when a constrained field is in ``exclude``, when
        any constrained field value is None (fields form), or when evaluating
        ``condition`` raises ``FieldError``. A saved instance's own row is
        excluded from the duplicate search.

        Raises:
            ValidationError: For field-based constraints registered on the
                model, with ``instance.unique_error_message()``; otherwise
                with ``get_violation_error_message()`` and
                ``violation_error_code``.
        """
        if self.fields:
            queryset = self._filter_on_fields(model, instance, exclude)
        else:
            queryset = self._filter_on_expressions(model, instance, exclude)
        # Compare with None explicitly: truth-testing a queryset could run it.
        if queryset is None:
            return
        model_class_id = instance.id
        if not instance._state.adding and model_class_id is not None:
            queryset = queryset.exclude(id=model_class_id)
        if not self.condition:
            if not queryset.exists():
                return
            if self.fields:
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    if any(constraint is self for constraint in constraints):
                        raise ValidationError(
                            instance.unique_error_message(
                                constraint_model,  # type: ignore[arg-type]
                                self.fields,
                            ),
                        )
            # Expression-based constraints, and field-based constraints that
            # aren't registered on the model (e.g. validated standalone), must
            # still report the duplicate rather than pass silently.
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )
        against = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            # Only a violation if the instance itself matches the condition
            # and a conflicting row that also matches it exists.
            if (self.condition & Exists(queryset.filter(self.condition))).check(
                against
            ):
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        except FieldError:
            pass