1from __future__ import annotations
2
3from collections.abc import Sequence
4from types import NoneType
5from typing import TYPE_CHECKING, Any, Self
6
7from plain.models.backends.utils import names_digest, split_identifier
8from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
9from plain.models.query_utils import Q
10from plain.models.sql import Query
11from plain.utils.functional import partition
12
13if TYPE_CHECKING:
14 from plain.models.backends.base.base import BaseDatabaseWrapper
15 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
16 from plain.models.backends.ddl_references import Statement
17 from plain.models.base import Model
18 from plain.models.expressions import Expression
19 from plain.models.sql.compiler import SQLCompiler
20
# Public names of this module: only ``Index`` is listed, so a star-import
# does not expose ``IndexExpression``.
__all__ = ["Index"]
22
23
24class Index:
25 suffix = "idx"
26 # The max length of the name of the index (restricted to 30 for
27 # cross-database compatibility with Oracle)
28 max_name_length = 30
29
30 def __init__(
31 self,
32 *expressions: Any,
33 fields: tuple[str, ...] | list[str] = (),
34 name: str | None = None,
35 opclasses: tuple[str, ...] | list[str] = (),
36 condition: Q | None = None,
37 include: tuple[str, ...] | list[str] | None = None,
38 ) -> None:
39 if opclasses and not name:
40 raise ValueError("An index must be named to use opclasses.")
41 if not isinstance(condition, NoneType | Q):
42 raise ValueError("Index.condition must be a Q instance.")
43 if condition and not name:
44 raise ValueError("An index must be named to use condition.")
45 if not isinstance(fields, list | tuple):
46 raise ValueError("Index.fields must be a list or tuple.")
47 if not isinstance(opclasses, list | tuple):
48 raise ValueError("Index.opclasses must be a list or tuple.")
49 if not expressions and not fields:
50 raise ValueError(
51 "At least one field or expression is required to define an index."
52 )
53 if expressions and fields:
54 raise ValueError(
55 "Index.fields and expressions are mutually exclusive.",
56 )
57 if expressions and not name:
58 raise ValueError("An index must be named to use expressions.")
59 if expressions and opclasses:
60 raise ValueError(
61 "Index.opclasses cannot be used with expressions. Use "
62 "a custom OpClass() instead."
63 )
64 if opclasses and len(fields) != len(opclasses):
65 raise ValueError(
66 "Index.fields and Index.opclasses must have the same number of "
67 "elements."
68 )
69 if fields and not all(isinstance(field, str) for field in fields):
70 raise ValueError("Index.fields must contain only strings with field names.")
71 if include and not name:
72 raise ValueError("A covering index must be named.")
73 if not isinstance(include, NoneType | list | tuple):
74 raise ValueError("Index.include must be a list or tuple.")
75 self.fields = list(fields)
76 # A list of 2-tuple with the field name and ordering ('' or 'DESC').
77 self.fields_orders = [
78 (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
79 for field_name in self.fields
80 ]
81 self.name = name or ""
82 self.opclasses: tuple[str, ...] = tuple(opclasses)
83 self.condition = condition
84 self.include = tuple(include) if include else ()
85 self.expressions: tuple[Expression, ...] = tuple(
86 F(expression) if isinstance(expression, str) else expression
87 for expression in expressions
88 )
89
90 @property
91 def contains_expressions(self) -> bool:
92 return bool(self.expressions)
93
94 def _get_condition_sql(
95 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
96 ) -> str | None:
97 if self.condition is None:
98 return None
99 query = Query(model=model, alias_cols=False)
100 where = query.build_where(self.condition)
101 compiler = query.get_compiler()
102 sql, params = where.as_sql(compiler, schema_editor.connection)
103 return sql % tuple(schema_editor.quote_value(p) for p in params)
104
105 def create_sql(
106 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
107 ) -> Statement:
108 include = [
109 model._model_meta.get_forward_field(field_name).column
110 for field_name in self.include
111 ]
112 condition = self._get_condition_sql(model, schema_editor)
113 if self.expressions:
114 index_expressions = []
115 for expression in self.expressions:
116 index_expression = IndexExpression(expression)
117 index_expressions.append(index_expression)
118 expressions = ExpressionList(*index_expressions).resolve_expression(
119 Query(model, alias_cols=False),
120 )
121 fields = None
122 col_suffixes = ()
123 else:
124 fields = [
125 model._model_meta.get_forward_field(field_name)
126 for field_name, _ in self.fields_orders
127 ]
128 if schema_editor.connection.features.supports_index_column_ordering:
129 col_suffixes = tuple(order[1] for order in self.fields_orders)
130 else:
131 col_suffixes = ("",) * len(self.fields_orders)
132 expressions = None
133 return schema_editor._create_index_sql(
134 model,
135 fields=fields,
136 name=self.name,
137 col_suffixes=col_suffixes,
138 opclasses=self.opclasses,
139 condition=condition,
140 include=include,
141 expressions=expressions,
142 **kwargs,
143 )
144
145 def remove_sql(
146 self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
147 ) -> Statement:
148 return schema_editor._delete_index_sql(model, self.name, **kwargs)
149
150 def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
151 path = f"{self.__class__.__module__}.{self.__class__.__name__}"
152 path = path.replace("plain.models.indexes", "plain.models")
153 kwargs = {"name": self.name}
154 if self.fields:
155 kwargs["fields"] = self.fields
156 if self.opclasses:
157 kwargs["opclasses"] = self.opclasses
158 if self.condition:
159 kwargs["condition"] = self.condition
160 if self.include:
161 kwargs["include"] = self.include
162 return (path, self.expressions, kwargs)
163
164 def clone(self) -> Index:
165 """Create a copy of this Index."""
166 _, args, kwargs = self.deconstruct()
167 return self.__class__(*args, **kwargs)
168
169 def set_name_with_model(self, model: type[Model]) -> None:
170 """
171 Generate a unique name for the index.
172
173 The name is divided into 3 parts - table name (12 chars), field name
174 (8 chars) and unique hash + suffix (10 chars). Each part is made to
175 fit its size by truncating the excess length.
176 """
177 _, table_name = split_identifier(model.model_options.db_table)
178 column_names = [
179 model._model_meta.get_forward_field(field_name).column
180 for field_name, order in self.fields_orders
181 ]
182 column_names_with_order = [
183 (("-%s" if order else "%s") % column_name)
184 for column_name, (field_name, order) in zip(
185 column_names, self.fields_orders
186 )
187 ]
188 # The length of the parts of the name is based on the default max
189 # length of 30 characters.
190 hash_data = [table_name] + column_names_with_order + [self.suffix]
191 self.name = "{}_{}_{}".format(
192 table_name[:11],
193 column_names[0][:7],
194 f"{names_digest(*hash_data, length=6)}_{self.suffix}",
195 )
196 if len(self.name) > self.max_name_length:
197 raise ValueError(
198 "Index too long for multiple database support. Is self.suffix "
199 "longer than 3 characters?"
200 )
201 if self.name[0] == "_" or self.name[0].isdigit():
202 self.name = f"D{self.name[1:]}"
203
204 def __repr__(self) -> str:
205 return "<{}:{}{}{}{}{}{}>".format(
206 self.__class__.__qualname__,
207 "" if not self.fields else f" fields={repr(self.fields)}",
208 "" if not self.expressions else f" expressions={repr(self.expressions)}",
209 "" if not self.name else f" name={repr(self.name)}",
210 "" if self.condition is None else f" condition={self.condition}",
211 "" if not self.include else f" include={repr(self.include)}",
212 "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
213 )
214
215 def __eq__(self, other: object) -> bool:
216 if isinstance(other, Index):
217 return self.deconstruct() == other.deconstruct()
218 return NotImplemented
219
220
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements."""

    template = "%(expressions)s"
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Normalize the expression tree, then resolve it.

        Wrapper nodes (instances of ``wrapper_classes``) are validated,
        ordered and re-chained around the root expression. A root that does
        not resolve to a ``Col`` is wrapped in parentheses.

        Raises:
            ValueError: If a wrapper type occurs more than once, or the
                wrappers are not the topmost expressions.
        """
        flattened = list(self.flatten())
        # Separate the wrapper nodes from all other expressions.
        index_expressions, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        wrapper_names = ", ".join(
            wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
        )
        # Each wrapper type may appear at most once.
        if len({type(wrapper) for wrapper in wrappers}) != len(wrappers):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(wrapper_names)
            )
        # The wrappers must directly follow the first flattened node
        # (presumably this IndexExpression itself), i.e. sit at the top.
        if flattened[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    wrapper_names
                )
            )
        # Wrap expressions in parentheses if they are not column references.
        root_expression = index_expressions[1]
        resolved_root = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Work on copies, ordered as listed in ``wrapper_classes``.
            ordered = [
                wrapper.copy()
                for wrapper in sorted(
                    wrappers,
                    key=lambda w: self.wrapper_classes.index(type(w)),
                )
            ]
            # Nest each wrapper around the next one in the chain.
            for outer, inner in zip(ordered, ordered[1:]):
                outer.set_source_expressions([inner])
            # The deepest wrapper holds the root expression.
            ordered[-1].set_source_expressions([root_expression])
            top_expression = ordered[0]
        else:
            # Without wrappers the root expression is used directly.
            top_expression = root_expression
        self.set_source_expressions([top_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Compile for SQLite by deferring to the generic ``as_sql``."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)