1from __future__ import annotations
2
3from enum import Enum
4from types import NoneType
5from typing import Any
6
7from plain.exceptions import ValidationError
8from plain.models.exceptions import FieldError
9from plain.models.expressions import Exists, ExpressionList, F, OrderBy
10from plain.models.indexes import IndexExpression
11from plain.models.lookups import Exact
12from plain.models.query_utils import Q
13from plain.models.sql.query import Query
14
# Names exported by ``from <this module> import *``.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
16
17
class BaseConstraint:
    """Common base for model constraints.

    Holds the constraint name together with the optional error code and
    message used when validation fails. The SQL hooks and ``validate()``
    raise ``NotImplementedError`` here and are meant to be overridden.
    """

    # Message template used when no custom message is supplied; ``%(name)s``
    # is substituted with the constraint name.
    default_violation_error_message = 'Constraint "%(name)s" is violated.'
    violation_error_code: str | None = None
    violation_error_message: str | None = None

    def __init__(
        self,
        *,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store the name and the optional violation code/message.

        Args:
            name: Name of the constraint.
            violation_error_code: Error code attached to validation errors.
                When omitted, the class-level value stays in effect.
            violation_error_message: Message template for validation errors.
                When omitted, ``default_violation_error_message`` is used.
        """
        self.name = name
        # Only shadow the class-level code when one was actually provided.
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint is built from expressions (never, here)."""
        return False

    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
        """SQL hook; must be provided by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model: Any, schema_editor: Any) -> str:
        """SQL hook; must be provided by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model: Any, schema_editor: Any) -> str:
        """SQL hook; must be provided by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(
        self, model: Any, instance: Any, exclude: set[str] | None = None
    ) -> None:
        """Validation hook; must be provided by subclasses."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self) -> str:
        """Return the violation message with ``%(name)s`` filled in."""
        template = self.violation_error_message
        return template % {"name": self.name}  # type: ignore[operator]

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, args, kwargs)`` sufficient to rebuild the constraint.

        The dotted path has ``plain.models.constraints`` shortened to
        ``plain.models``. The message is included only when it differs from
        the default, and the code only when it is set.
        """
        cls = self.__class__
        dotted = ".".join((cls.__module__, cls.__name__))
        path = dotted.replace("plain.models.constraints", "plain.models")

        kwargs: dict[str, Any] = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        code = self.violation_error_code
        if code is not None:
            kwargs["violation_error_code"] = code
        return (path, (), kwargs)

    def clone(self) -> BaseConstraint:
        """Return a new instance built from this constraint's ``deconstruct()``."""
        _path, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
75
76
class CheckConstraint(BaseConstraint):
    """Constraint defined by a boolean condition stored in ``check``."""

    def __init__(
        self,
        *,
        check: Q,
        name: str,
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Store ``check`` and delegate the remaining options to the base class.

        Raises:
            TypeError: If ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(self, model: Any, schema_editor: Any) -> str:
        """Compile ``check`` to SQL with its parameters quoted and inlined."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        sql, params = where.as_sql(query.get_compiler(), schema_editor.connection)
        quoted_params = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted_params

    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_check_sql`` for this check."""
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_create_check_sql``."""
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_delete_check_sql``."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(
        self, model: Any, instance: Any, exclude: set[str] | None = None
    ) -> None:
        """Evaluate ``check`` against the instance's field values.

        Raises:
            ValidationError: If the check evaluates as not satisfied.
        """
        field_values = instance._get_field_value_map(
            meta=model._model_meta, exclude=exclude
        )
        try:
            if Q(self.check).check(field_values):
                return
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )
        except FieldError:
            # The check could not be evaluated against these values; skip it.
            return

    def __repr__(self) -> str:
        code = self.violation_error_code
        message = self.violation_error_message
        code_part = "" if code is None else f" violation_error_code={code!r}"
        # The default message is omitted to keep the repr short.
        if message is None or message == self.default_violation_error_message:
            message_part = ""
        else:
            message_part = f" violation_error_message={message!r}"
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Extend the base ``deconstruct()`` result with the ``check`` kwarg."""
        path, args, kwargs = super().deconstruct()
        kwargs.update(check=self.check)
        return path, args, kwargs
159
160
class Deferrable(Enum):
    """Values accepted for ``UniqueConstraint``'s ``deferrable`` option."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self) -> str:
        # Compact ``Class.MEMBER`` form instead of Enum's default
        # ``<Class.MEMBER: value>`` (a similar format was proposed for
        # Python 3.10).
        return "{}.{}".format(type(self).__qualname__, self.name)
168
169
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over either named fields or expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given.
    String expressions are wrapped in ``F()``. ``fields``, ``include`` and
    ``opclasses`` are stored as tuples, so equality, ``repr()`` and
    ``deconstruct()`` do not depend on whether the caller passed a list or a
    tuple.
    """

    def __init__(
        self,
        *expressions: Any,
        fields: tuple[str, ...] | list[str] = (),
        name: str | None = None,
        condition: Q | None = None,
        deferrable: Deferrable | None = None,
        include: tuple[str, ...] | list[str] | None = None,
        opclasses: tuple[str, ...] | list[str] = (),
        violation_error_code: str | None = None,
        violation_error_message: str | None = None,
    ) -> None:
        """Validate the option combination and store the normalized options.

        Raises:
            ValueError: If ``name`` is empty; if neither or both of
                ``fields``/``expressions`` are given; if ``condition``,
                ``deferrable``, ``include`` or ``opclasses`` has the wrong
                type; if ``deferrable`` is combined with ``condition``,
                ``include``, ``opclasses`` or expressions; if ``opclasses``
                is combined with expressions; or if ``opclasses`` and
                ``fields`` differ in length.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Normalized like ``fields``/``include`` so that a list and a tuple
        # with the same items describe an equal constraint.
        self.opclasses = tuple(opclasses)
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,  # type: ignore[arg-type]
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self) -> bool:
        """Whether the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model: Any, schema_editor: Any) -> str | None:
        """Compile ``condition`` to SQL with quoted parameters inlined.

        Returns ``None`` when no condition is set.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model: Any, schema_editor: Any) -> Any:
        """Wrap the expressions as index expressions resolved against ``model``.

        Returns ``None`` when the constraint has no expressions.
        """
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def _get_sql_options(self, model: Any, schema_editor: Any) -> dict[str, Any]:
        """Build the keyword options shared by the three SQL hooks."""
        return {
            "condition": self._get_condition_sql(model, schema_editor),
            "deferrable": self.deferrable,
            "include": [
                model._model_meta.get_field(field_name).column
                for field_name in self.include
            ],
            "opclasses": self.opclasses,
            "expressions": self._get_index_expressions(model, schema_editor),
        }

    def constraint_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            **self._get_sql_options(model, schema_editor),
        )

    def create_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_create_unique_sql``."""
        fields = [model._model_meta.get_field(field_name) for field_name in self.fields]
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            **self._get_sql_options(model, schema_editor),
        )

    def remove_sql(self, model: Any, schema_editor: Any) -> str:
        """Return the result of the schema editor's ``_delete_unique_sql``."""
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            **self._get_sql_options(model, schema_editor),
        )

    def __repr__(self) -> str:
        message = self.violation_error_message
        show_message = (
            message is not None and message != self.default_violation_error_message
        )
        # Unset or empty options are left out of the repr.
        parts = [
            f" fields={self.fields!r}" if self.fields else "",
            f" expressions={self.expressions!r}" if self.expressions else "",
            f" name={self.name!r}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            f" include={self.include!r}" if self.include else "",
            f" opclasses={self.opclasses!r}" if self.opclasses else "",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            f" violation_error_message={message!r}" if show_message else "",
        ]
        return f"<{self.__class__.__qualname__}:{''.join(parts)}>"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """Return ``(path, expressions, kwargs)`` to rebuild the constraint.

        Only options with a truthy value are added to ``kwargs``.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(
        self, model: Any, instance: Any, exclude: set[str] | None = None
    ) -> None:
        """Check ``instance`` against existing rows for a uniqueness conflict.

        Validation is skipped when one of the constraint's fields (or an
        ``F()`` reference inside its expressions) is listed in ``exclude``,
        or when one of the constraint's fields is ``None`` on the instance.

        Raises:
            ValidationError: If a conflicting row exists.
        """
        queryset = model.query
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._model_meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each F(field) to the instance's value so every expression
            # can be compared against its value-substituted counterpart.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._model_meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)

        # An already-saved instance must not conflict with its own row.
        instance_id = instance.id
        if not instance._state.adding and instance_id is not None:
            queryset = queryset.exclude(id=instance_id)

        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility. The message is built for the model
                # paired with this constraint by get_constraints(), which is
                # why the loop variable must not reuse the ``model`` name.
                for owner_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    owner_model, self.fields
                                ),
                            )
        else:
            against = instance._get_field_value_map(
                meta=model._model_meta, exclude=exclude
            )
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                # Deliberate best-effort: presumably raised when the condition
                # references a field missing from ``against`` (e.g. excluded)
                # -- TODO confirm. The constraint is skipped in that case.
                pass