1from enum import Enum
2from types import NoneType
3
4from plain.exceptions import FieldError, ValidationError
5from plain.models.expressions import Exists, ExpressionList, F, OrderBy
6from plain.models.indexes import IndexExpression
7from plain.models.lookups import Exact
8from plain.models.query_utils import Q
9from plain.models.sql.query import Query
10
# Public names exported by this module.
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
12
13
class BaseConstraint:
    """Base class for model constraints.

    Stores the constraint ``name`` together with the optional error code and
    message reported when the constraint is violated, and declares the hooks
    that concrete constraints must implement: ``constraint_sql``,
    ``create_sql``, ``remove_sql`` and ``validate``.
    """

    # Message template used when no custom message is given; ``%(name)s`` is
    # filled in with the constraint name by get_violation_error_message().
    default_violation_error_message = "Constraint “%(name)s” is violated."
    violation_error_code = None
    violation_error_message = None

    def __init__(
        self, *, name, violation_error_code=None, violation_error_message=None
    ):
        """Store the name and the optional violation code/message.

        Args:
            name: Name of the constraint.
            violation_error_code: Optional error code attached to validation
                errors. When omitted, the class-level value is kept.
            violation_error_message: Optional message template. When omitted,
                ``default_violation_error_message`` is used.
        """
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        if violation_error_message is not None:
            self.violation_error_message = violation_error_message
        else:
            self.violation_error_message = self.default_violation_error_message

    @property
    def contains_expressions(self):
        """Whether the constraint is built from expressions (never, here)."""
        return False

    def constraint_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None):
        """Subclass hook; the base implementation always raises."""
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` interpolated."""
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self):
        """Return a ``(path, args, kwargs)`` triple describing this instance.

        ``path`` is the dotted import path of the class. Classes defined in
        ``plain.models.constraints`` are reported under ``plain.models``
        instead. ``kwargs`` always contains ``name``; the violation message is
        included only when it differs from the default, and the violation code
        only when it is set.
        """
        cls = self.__class__
        module = cls.__module__
        internal_module = "plain.models.constraints"
        # Only rewrite the framework's own module (or a submodule of it). A
        # plain substring replace would also mangle unrelated module paths
        # that merely contain the same text.
        if module == internal_module or module.startswith(internal_module + "."):
            module = "plain.models" + module[len(internal_module) :]
        path = f"{module}.{cls.__name__}"

        kwargs = {"name": self.name}
        if (
            self.violation_error_message is not None
            and self.violation_error_message != self.default_violation_error_message
        ):
            kwargs["violation_error_message"] = self.violation_error_message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return (path, (), kwargs)

    def clone(self):
        """Return a new instance rebuilt from ``deconstruct()``."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
65
66
class CheckConstraint(BaseConstraint):
    """Constraint backed by a boolean condition (``check``).

    ``check`` must expose a truthy ``conditional`` attribute; the error
    message raised otherwise asks for a ``Q`` instance or boolean expression.
    """

    def __init__(
        self, *, check, name, violation_error_code=None, violation_error_message=None
    ):
        """Store ``check`` and delegate the remaining arguments to the base.

        Raises:
            TypeError: If ``check`` has no truthy ``conditional`` attribute.
        """
        self.check = check
        if not getattr(check, "conditional", False):
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(self, model, schema_editor):
        """Compile ``self.check`` to SQL with its parameters inlined."""
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.check)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # Parameters are quoted by the schema editor and interpolated directly
        # into the SQL text.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def constraint_sql(self, model, schema_editor):
        """Return the schema editor's SQL for this check constraint."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._check_sql(self.name, check)

    def create_sql(self, model, schema_editor):
        """Return the schema editor's SQL that creates this constraint."""
        check = self._get_check_sql(model, schema_editor)
        return schema_editor._create_check_sql(model, self.name, check)

    def remove_sql(self, model, schema_editor):
        """Return the schema editor's SQL that removes this constraint."""
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None):
        """Validate ``instance`` against the check condition.

        Raises:
            ValidationError: If the condition evaluates to false for the
                instance's field values.
        """
        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            satisfied = Q(self.check).check(against)
        except FieldError:
            # The condition could not be evaluated against the available
            # values (presumably it references an excluded field; verify
            # against Q.check), so the constraint is skipped, not failed.
            return
        if not satisfied:
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )

    def __repr__(self):
        return "<{}: check={} name={}{}{}>".format(
            self.__class__.__qualname__,
            self.check,
            repr(self.name),
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other):
        """Compare name, check and violation code/message with ``other``."""
        if isinstance(other, CheckConstraint):
            return (
                self.name == other.name
                and self.check == other.check
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, kwargs = super().deconstruct()
        kwargs["check"] = self.check
        return path, args, kwargs
142
143
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        """Return ``ClassName.MEMBER`` rather than the default enum repr."""
        return f"{self.__class__.__qualname__}.{self._name_}"
151
152
class UniqueConstraint(BaseConstraint):
    """Uniqueness constraint over a set of fields or a set of expressions.

    Exactly one of ``fields`` or positional ``expressions`` must be given.
    The constraint may additionally carry a ``condition`` (a ``Q``), a
    ``deferrable`` mode, ``include`` field names and ``opclasses``; the
    combinations rejected by ``__init__`` are listed there.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_code=None,
        violation_error_message=None,
    ):
        """Validate the arguments and store the constraint definition.

        Args:
            *expressions: Expressions to constrain; strings are wrapped in
                ``F``. Mutually exclusive with ``fields``.
            fields: Field names to constrain.
            name: Required, non-empty constraint name.
            condition: Optional ``Q`` restricting the constraint.
            deferrable: Optional ``Deferrable`` member. Cannot be combined
                with ``condition``, ``include``, ``opclasses`` or expressions.
            include: Optional list/tuple of additional field names.
            opclasses: List/tuple with one entry per field; not allowed with
                expressions.
            violation_error_code: Forwarded to ``BaseConstraint``.
            violation_error_message: Forwarded to ``BaseConstraint``.

        Raises:
            ValueError: If any of the argument rules above is broken.
        """
        # NOTE: the order of these checks decides which error is reported when
        # several rules are broken at once, so it is kept stable.
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        # Stored exactly as passed (list or tuple); not normalized.
        self.opclasses = opclasses
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Compile ``self.condition`` to SQL, or return None without one."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler()
        sql, params = where.as_sql(compiler, schema_editor.connection)
        # Parameters are quoted by the schema editor and interpolated directly
        # into the SQL text.
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """Return the resolved expression list, or None without expressions."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def _get_include_columns(self, model):
        """Return the column names of the ``include`` fields on ``model``."""
        return [model._meta.get_field(field_name).column for field_name in self.include]

    def constraint_sql(self, model, schema_editor):
        """Return the schema editor's SQL for this unique constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        """Return the schema editor's SQL that creates this constraint."""
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        """Return the schema editor's SQL that removes this constraint."""
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other):
        """Compare every stored definition attribute with ``other``."""
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Extend the base deconstruction with this class's arguments.

        Expressions are returned as positional args; the other attributes are
        added to ``kwargs`` only when they are set/non-empty.
        """
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None):
        """Check that ``instance`` does not collide with an existing row.

        Validation is skipped (returns None) when a constrained field or a
        referenced ``F`` name is in ``exclude``, or when any constrained field
        value is None.

        Raises:
            ValidationError: If another row matches the constrained values
                (and, when set, the condition).
        """
        queryset = model._default_manager
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None:
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Map each field reference to the instance's value so every
            # expression can be compared with its instance-side counterpart.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)

        # An already-saved instance must not collide with its own row.
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            queryset = queryset.exclude(pk=model_class_pk)

        if not self.condition:
            if not queryset.exists():
                return
            if self.expressions:
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
            # When fields are defined, use the unique_error_message() for
            # backward compatibility. It is built for the model that
            # get_constraints() pairs with this constraint.
            for owner_model, constraints in instance.get_constraints():
                if any(constraint is self for constraint in constraints):
                    raise ValidationError(
                        instance.unique_error_message(owner_model, self.fields),
                    )
            # A duplicate exists but this constraint is not listed by
            # instance.get_constraints(); report the violation with the
            # generic message instead of silently passing.
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )

        against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            violated = (
                self.condition & Exists(queryset.filter(self.condition))
            ).check(against)
        except FieldError:
            # The condition could not be evaluated against the available
            # values (presumably it references an excluded field; verify
            # against Q.check), so the constraint is skipped, not failed.
            return
        if violated:
            raise ValidationError(
                self.get_violation_error_message(),
                code=self.violation_error_code,
            )