1from types import NoneType
2
3from plain.models.backends.utils import names_digest, split_identifier
4from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
5from plain.models.functions import Collate
6from plain.models.query_utils import Q
7from plain.models.sql import Query
8from plain.utils.functional import partition
9
# Names exported by ``from ... import *``; IndexExpression is not listed.
__all__ = ["Index"]
11
12
13class Index:
14 suffix = "idx"
15 # The max length of the name of the index (restricted to 30 for
16 # cross-database compatibility with Oracle)
17 max_name_length = 30
18
19 def __init__(
20 self,
21 *expressions,
22 fields=(),
23 name=None,
24 opclasses=(),
25 condition=None,
26 include=None,
27 ):
28 if opclasses and not name:
29 raise ValueError("An index must be named to use opclasses.")
30 if not isinstance(condition, NoneType | Q):
31 raise ValueError("Index.condition must be a Q instance.")
32 if condition and not name:
33 raise ValueError("An index must be named to use condition.")
34 if not isinstance(fields, list | tuple):
35 raise ValueError("Index.fields must be a list or tuple.")
36 if not isinstance(opclasses, list | tuple):
37 raise ValueError("Index.opclasses must be a list or tuple.")
38 if not expressions and not fields:
39 raise ValueError(
40 "At least one field or expression is required to define an index."
41 )
42 if expressions and fields:
43 raise ValueError(
44 "Index.fields and expressions are mutually exclusive.",
45 )
46 if expressions and not name:
47 raise ValueError("An index must be named to use expressions.")
48 if expressions and opclasses:
49 raise ValueError(
50 "Index.opclasses cannot be used with expressions. Use "
51 "a custom OpClass() instead."
52 )
53 if opclasses and len(fields) != len(opclasses):
54 raise ValueError(
55 "Index.fields and Index.opclasses must have the same number of "
56 "elements."
57 )
58 if fields and not all(isinstance(field, str) for field in fields):
59 raise ValueError("Index.fields must contain only strings with field names.")
60 if include and not name:
61 raise ValueError("A covering index must be named.")
62 if not isinstance(include, NoneType | list | tuple):
63 raise ValueError("Index.include must be a list or tuple.")
64 self.fields = list(fields)
65 # A list of 2-tuple with the field name and ordering ('' or 'DESC').
66 self.fields_orders = [
67 (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
68 for field_name in self.fields
69 ]
70 self.name = name or ""
71 self.opclasses = opclasses
72 self.condition = condition
73 self.include = tuple(include) if include else ()
74 self.expressions = tuple(
75 F(expression) if isinstance(expression, str) else expression
76 for expression in expressions
77 )
78
79 @property
80 def contains_expressions(self):
81 return bool(self.expressions)
82
83 def _get_condition_sql(self, model, schema_editor):
84 if self.condition is None:
85 return None
86 query = Query(model=model, alias_cols=False)
87 where = query.build_where(self.condition)
88 compiler = query.get_compiler()
89 sql, params = where.as_sql(compiler, schema_editor.connection)
90 return sql % tuple(schema_editor.quote_value(p) for p in params)
91
92 def create_sql(self, model, schema_editor, **kwargs):
93 include = [
94 model._meta.get_field(field_name).column for field_name in self.include
95 ]
96 condition = self._get_condition_sql(model, schema_editor)
97 if self.expressions:
98 index_expressions = []
99 for expression in self.expressions:
100 index_expression = IndexExpression(expression)
101 index_expression.set_wrapper_classes(schema_editor.connection)
102 index_expressions.append(index_expression)
103 expressions = ExpressionList(*index_expressions).resolve_expression(
104 Query(model, alias_cols=False),
105 )
106 fields = None
107 col_suffixes = None
108 else:
109 fields = [
110 model._meta.get_field(field_name)
111 for field_name, _ in self.fields_orders
112 ]
113 if schema_editor.connection.features.supports_index_column_ordering:
114 col_suffixes = [order[1] for order in self.fields_orders]
115 else:
116 col_suffixes = [""] * len(self.fields_orders)
117 expressions = None
118 return schema_editor._create_index_sql(
119 model,
120 fields=fields,
121 name=self.name,
122 col_suffixes=col_suffixes,
123 opclasses=self.opclasses,
124 condition=condition,
125 include=include,
126 expressions=expressions,
127 **kwargs,
128 )
129
130 def remove_sql(self, model, schema_editor, **kwargs):
131 return schema_editor._delete_index_sql(model, self.name, **kwargs)
132
133 def deconstruct(self):
134 path = f"{self.__class__.__module__}.{self.__class__.__name__}"
135 path = path.replace("plain.models.indexes", "plain.models")
136 kwargs = {"name": self.name}
137 if self.fields:
138 kwargs["fields"] = self.fields
139 if self.opclasses:
140 kwargs["opclasses"] = self.opclasses
141 if self.condition:
142 kwargs["condition"] = self.condition
143 if self.include:
144 kwargs["include"] = self.include
145 return (path, self.expressions, kwargs)
146
147 def clone(self):
148 """Create a copy of this Index."""
149 _, args, kwargs = self.deconstruct()
150 return self.__class__(*args, **kwargs)
151
152 def set_name_with_model(self, model):
153 """
154 Generate a unique name for the index.
155
156 The name is divided into 3 parts - table name (12 chars), field name
157 (8 chars) and unique hash + suffix (10 chars). Each part is made to
158 fit its size by truncating the excess length.
159 """
160 _, table_name = split_identifier(model._meta.db_table)
161 column_names = [
162 model._meta.get_field(field_name).column
163 for field_name, order in self.fields_orders
164 ]
165 column_names_with_order = [
166 (("-%s" if order else "%s") % column_name)
167 for column_name, (field_name, order) in zip(
168 column_names, self.fields_orders
169 )
170 ]
171 # The length of the parts of the name is based on the default max
172 # length of 30 characters.
173 hash_data = [table_name] + column_names_with_order + [self.suffix]
174 self.name = "{}_{}_{}".format(
175 table_name[:11],
176 column_names[0][:7],
177 f"{names_digest(*hash_data, length=6)}_{self.suffix}",
178 )
179 if len(self.name) > self.max_name_length:
180 raise ValueError(
181 "Index too long for multiple database support. Is self.suffix "
182 "longer than 3 characters?"
183 )
184 if self.name[0] == "_" or self.name[0].isdigit():
185 self.name = f"D{self.name[1:]}"
186
187 def __repr__(self):
188 return "<{}:{}{}{}{}{}{}>".format(
189 self.__class__.__qualname__,
190 "" if not self.fields else f" fields={repr(self.fields)}",
191 "" if not self.expressions else f" expressions={repr(self.expressions)}",
192 "" if not self.name else f" name={repr(self.name)}",
193 "" if self.condition is None else f" condition={self.condition}",
194 "" if not self.include else f" include={repr(self.include)}",
195 "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
196 )
197
198 def __eq__(self, other):
199 if self.__class__ == other.__class__:
200 return self.deconstruct() == other.deconstruct()
201 return NotImplemented
202
203
class IndexExpression(Func):
    """
    Order and wrap a single expression for use in a CREATE INDEX statement.

    During resolution, nodes whose type is listed in ``wrapper_classes``
    (ordering and collation by default) must sit at the top of the
    expression. They are re-nested in ``wrapper_classes`` order around the
    inner "root" expression. A root that does not resolve to a plain column
    reference is additionally wrapped in parentheses.
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(self, connection=None):
        """
        Adapt the hoisted wrapper types to ``connection``'s backend.

        Some databases (e.g. MySQL) treat COLLATE as part of the indexed
        expression. When the backend reports ``collate_as_index_expression``,
        Collate is removed from this instance's ``wrapper_classes``. Nothing
        happens when ``connection`` is falsy.
        """
        if not connection:
            return
        if not connection.features.collate_as_index_expression:
            return
        self.wrapper_classes = tuple(
            candidate
            for candidate in self.wrapper_classes
            if candidate is not Collate
        )

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Rebuild this node as ``wrapper(...wrapper(root)...)`` and resolve it.

        Raises:
            ValueError: if a wrapper type appears more than once, or if the
                wrappers are not the topmost nodes of the expression.
        """
        # flatten() is expected to yield this node itself first, which is
        # why index 0 is skipped in both lists below.
        nodes = list(self.flatten())
        # Split into (non-wrapper nodes, wrapper nodes).
        inner_nodes, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            nodes,
        )
        wrapper_names = ", ".join(
            [candidate.__qualname__ for candidate in self.wrapper_classes]
        )

        found_types = [type(wrapper) for wrapper in wrappers]
        if len(set(found_types)) != len(found_types):
            raise ValueError(
                f"Multiple references to {wrapper_names} can't be used in an "
                "indexed expression."
            )
        # The wrappers must be the nodes directly beneath this one.
        if nodes[1 : 1 + len(wrappers)] != wrappers:
            raise ValueError(
                f"{wrapper_names} must be topmost expressions in an indexed "
                "expression."
            )

        # Parenthesize anything that is not a bare column reference.
        root = inner_nodes[1]
        resolved_root = root.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if not wrappers:
            # No wrappers: the root expression is the sole source.
            self.set_source_expressions([root])
        else:
            rank = self.wrapper_classes.index
            chain = [
                wrapper.copy()
                for wrapper in sorted(wrappers, key=lambda w: rank(type(w)))
            ]
            # Each wrapper encloses the next one; the innermost one receives
            # the root expression.
            for outer, inner in zip(chain, chain[1:]):
                outer.set_source_expressions([inner])
            chain[-1].set_source_expressions([root])
            self.set_source_expressions([chain[0]])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Render with the generic as_sql(): casting to numeric is unnecessary
        # for index expressions (presumably overriding a SQLite-specific
        # variant inherited from Func).
        return self.as_sql(compiler, connection, **extra_context)