1from __future__ import annotations
2
3from collections.abc import Sequence
4from types import NoneType
5from typing import TYPE_CHECKING, Any, Self, cast
6
7from plain.models.backends.utils import names_digest, split_identifier
8from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
9from plain.models.functions import Collate
10from plain.models.query_utils import Q
11from plain.models.sql import Query
12from plain.utils.functional import partition
13
14if TYPE_CHECKING:
15 from plain.models.backends.base.base import BaseDatabaseWrapper
16 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
17 from plain.models.backends.ddl_references import Statement
18 from plain.models.base import Model
19 from plain.models.expressions import Expression
20 from plain.models.sql.compiler import SQLCompiler
21
# Names exported by a star-import of this module: only ``Index`` is listed;
# ``IndexExpression`` (defined further down in this file) is not included.
__all__ = ["Index"]
23
24
25class Index:
26 suffix = "idx"
27 # The max length of the name of the index (restricted to 30 for
28 # cross-database compatibility with Oracle)
29 max_name_length = 30
30
31 def __init__(
32 self,
33 *expressions: Any,
34 fields: tuple[str, ...] | list[str] = (),
35 name: str | None = None,
36 opclasses: tuple[str, ...] | list[str] = (),
37 condition: Q | None = None,
38 include: tuple[str, ...] | list[str] | None = None,
39 ) -> None:
40 if opclasses and not name:
41 raise ValueError("An index must be named to use opclasses.")
42 if not isinstance(condition, NoneType | Q):
43 raise ValueError("Index.condition must be a Q instance.")
44 if condition and not name:
45 raise ValueError("An index must be named to use condition.")
46 if not isinstance(fields, list | tuple):
47 raise ValueError("Index.fields must be a list or tuple.")
48 if not isinstance(opclasses, list | tuple):
49 raise ValueError("Index.opclasses must be a list or tuple.")
50 if not expressions and not fields:
51 raise ValueError(
52 "At least one field or expression is required to define an index."
53 )
54 if expressions and fields:
55 raise ValueError(
56 "Index.fields and expressions are mutually exclusive.",
57 )
58 if expressions and not name:
59 raise ValueError("An index must be named to use expressions.")
60 if expressions and opclasses:
61 raise ValueError(
62 "Index.opclasses cannot be used with expressions. Use "
63 "a custom OpClass() instead."
64 )
65 if opclasses and len(fields) != len(opclasses):
66 raise ValueError(
67 "Index.fields and Index.opclasses must have the same number of "
68 "elements."
69 )
70 if fields and not all(isinstance(field, str) for field in fields):
71 raise ValueError("Index.fields must contain only strings with field names.")
72 if include and not name:
73 raise ValueError("A covering index must be named.")
74 if not isinstance(include, NoneType | list | tuple):
75 raise ValueError("Index.include must be a list or tuple.")
76 self.fields = list(fields)
77 # A list of 2-tuple with the field name and ordering ('' or 'DESC').
78 self.fields_orders = [
79 (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
80 for field_name in self.fields
81 ]
82 self.name = name or ""
83 self.opclasses: tuple[str, ...] = tuple(opclasses)
84 self.condition = condition
85 self.include = tuple(include) if include else ()
86 self.expressions: tuple[Expression, ...] = tuple(
87 F(expression) if isinstance(expression, str) else expression
88 for expression in expressions
89 )
90
91 @property
92 def contains_expressions(self) -> bool:
93 return bool(self.expressions)
94
95 def _get_condition_sql(
96 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
97 ) -> str | None:
98 if self.condition is None:
99 return None
100 query = Query(model=model, alias_cols=False)
101 where = query.build_where(self.condition)
102 compiler = query.get_compiler()
103 sql, params = where.as_sql(compiler, schema_editor.connection)
104 return sql % tuple(schema_editor.quote_value(p) for p in params)
105
106 def create_sql(
107 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
108 ) -> Statement:
109 include = [
110 model._model_meta.get_forward_field(field_name).column
111 for field_name in self.include
112 ]
113 condition = self._get_condition_sql(model, schema_editor)
114 if self.expressions:
115 index_expressions = []
116 for expression in self.expressions:
117 index_expression = IndexExpression(expression)
118 index_expression.set_wrapper_classes(schema_editor.connection)
119 index_expressions.append(index_expression)
120 expressions = ExpressionList(*index_expressions).resolve_expression(
121 Query(model, alias_cols=False),
122 )
123 fields = None
124 col_suffixes = ()
125 else:
126 fields = [
127 model._model_meta.get_forward_field(field_name)
128 for field_name, _ in self.fields_orders
129 ]
130 if schema_editor.connection.features.supports_index_column_ordering:
131 col_suffixes = tuple(order[1] for order in self.fields_orders)
132 else:
133 col_suffixes = ("",) * len(self.fields_orders)
134 expressions = None
135 return schema_editor._create_index_sql(
136 model,
137 fields=fields,
138 name=self.name,
139 col_suffixes=col_suffixes,
140 opclasses=self.opclasses,
141 condition=condition,
142 include=include,
143 expressions=expressions,
144 **kwargs,
145 )
146
147 def remove_sql(
148 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
149 ) -> Statement:
150 return schema_editor._delete_index_sql(model, self.name, **kwargs)
151
152 def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
153 path = f"{self.__class__.__module__}.{self.__class__.__name__}"
154 path = path.replace("plain.models.indexes", "plain.models")
155 kwargs = {"name": self.name}
156 if self.fields:
157 kwargs["fields"] = self.fields
158 if self.opclasses:
159 kwargs["opclasses"] = self.opclasses
160 if self.condition:
161 kwargs["condition"] = self.condition
162 if self.include:
163 kwargs["include"] = self.include
164 return (path, self.expressions, kwargs)
165
166 def clone(self) -> Index:
167 """Create a copy of this Index."""
168 _, args, kwargs = self.deconstruct()
169 return self.__class__(*args, **kwargs)
170
171 def set_name_with_model(self, model: type[Model]) -> None:
172 """
173 Generate a unique name for the index.
174
175 The name is divided into 3 parts - table name (12 chars), field name
176 (8 chars) and unique hash + suffix (10 chars). Each part is made to
177 fit its size by truncating the excess length.
178 """
179 _, table_name = split_identifier(model.model_options.db_table)
180 column_names = [
181 model._model_meta.get_forward_field(field_name).column
182 for field_name, order in self.fields_orders
183 ]
184 column_names_with_order = [
185 (("-%s" if order else "%s") % column_name)
186 for column_name, (field_name, order) in zip(
187 column_names, self.fields_orders
188 )
189 ]
190 # The length of the parts of the name is based on the default max
191 # length of 30 characters.
192 hash_data = [table_name] + column_names_with_order + [self.suffix]
193 self.name = "{}_{}_{}".format(
194 table_name[:11],
195 column_names[0][:7],
196 f"{names_digest(*hash_data, length=6)}_{self.suffix}",
197 )
198 if len(self.name) > self.max_name_length:
199 raise ValueError(
200 "Index too long for multiple database support. Is self.suffix "
201 "longer than 3 characters?"
202 )
203 if self.name[0] == "_" or self.name[0].isdigit():
204 self.name = f"D{self.name[1:]}"
205
206 def __repr__(self) -> str:
207 return "<{}:{}{}{}{}{}{}>".format(
208 self.__class__.__qualname__,
209 "" if not self.fields else f" fields={repr(self.fields)}",
210 "" if not self.expressions else f" expressions={repr(self.expressions)}",
211 "" if not self.name else f" name={repr(self.name)}",
212 "" if self.condition is None else f" condition={self.condition}",
213 "" if not self.include else f" include={repr(self.include)}",
214 "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
215 )
216
217 def __eq__(self, other: object) -> bool:
218 if isinstance(other, Index):
219 return self.deconstruct() == other.deconstruct()
220 return NotImplemented
221
222
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements."""

    template = "%(expressions)s"
    # Wrapper expression types, in the order they are nested (outermost first).
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(
        self, connection: BaseDatabaseWrapper | None = None
    ) -> None:
        """
        Drop ``Collate`` from this instance's wrapper classes when the
        connection reports ``collate_as_index_expression``.

        Does nothing when no connection is given or the feature flag is falsy.
        """
        if not connection or not connection.features.collate_as_index_expression:
            return
        # Some databases (e.g. MySQL) treat COLLATE as an indexed expression.
        self.wrapper_classes = tuple(
            candidate for candidate in self.wrapper_classes if candidate is not Collate
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Rebuild the source expressions as ordered wrappers around the root
        expression, then resolve through the parent implementation.

        Raises:
            ValueError: if a wrapper type is used more than once, or if the
                wrappers are not the topmost nodes of the expression.
        """

        def wrapper_names() -> str:
            # Comma-separated wrapper class names, used in the error messages.
            return ", ".join(
                [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
            )

        flattened = list(self.flatten())
        # Split expressions and wrappers.
        plain_nodes, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        found_types = [type(found) for found in wrappers]
        if len(set(found_types)) != len(found_types):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(wrapper_names())
            )
        # The wrappers must directly follow the first flattened node.
        if flattened[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    wrapper_names()
                )
            )
        # Wrap expressions in parentheses if they are not column references.
        root = plain_nodes[1]
        resolved_root = root.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if wrappers:
            # Order wrappers and work on copies so the originals stay untouched.
            in_order = sorted(
                wrappers,
                key=lambda found: self.wrapper_classes.index(type(found)),
            )
            copies = [found.copy() for found in in_order]
            # Nest each wrapper around the next one; the deepest wrapper
            # receives the root expression.
            chain = [*copies, root]
            for outer, inner in zip(chain, chain[1:]):
                outer.set_source_expressions([inner])
            topmost = copies[0]
        else:
            # Use the root expression, if there are no wrappers.
            topmost = root
        self.set_source_expressions([topmost])
        # Cast needed because super() returns parent's Self type, not subclass's Self
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return cast(Self, resolved)

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile for SQLite by delegating directly to ``as_sql``."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)