1from __future__ import annotations
2
3from types import NoneType
4from typing import TYPE_CHECKING, Any, Self
5
6from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
7from plain.models.postgres.utils import names_digest, split_identifier
8from plain.models.query_utils import Q
9from plain.models.sql.query import Query
10from plain.utils.functional import partition
11
12if TYPE_CHECKING:
13 from plain.models.base import Model
14 from plain.models.expressions import Expression
15 from plain.models.postgres.schema import DatabaseSchemaEditor, Statement
16
# Public names of this module; only Index is listed here.
__all__ = ["Index"]
18
19
20class Index:
21 suffix = "idx"
22 # The max length of the name of the index
23 max_name_length = 30
24
25 def __init__(
26 self,
27 *expressions: Any,
28 fields: tuple[str, ...] | list[str] = (),
29 name: str | None = None,
30 opclasses: tuple[str, ...] | list[str] = (),
31 condition: Q | None = None,
32 include: tuple[str, ...] | list[str] | None = None,
33 ) -> None:
34 if opclasses and not name:
35 raise ValueError("An index must be named to use opclasses.")
36 if not isinstance(condition, NoneType | Q):
37 raise ValueError("Index.condition must be a Q instance.")
38 if condition and not name:
39 raise ValueError("An index must be named to use condition.")
40 if not isinstance(fields, list | tuple):
41 raise ValueError("Index.fields must be a list or tuple.")
42 if not isinstance(opclasses, list | tuple):
43 raise ValueError("Index.opclasses must be a list or tuple.")
44 if not expressions and not fields:
45 raise ValueError(
46 "At least one field or expression is required to define an index."
47 )
48 if expressions and fields:
49 raise ValueError(
50 "Index.fields and expressions are mutually exclusive.",
51 )
52 if expressions and not name:
53 raise ValueError("An index must be named to use expressions.")
54 if expressions and opclasses:
55 raise ValueError(
56 "Index.opclasses cannot be used with expressions. Use "
57 "a custom OpClass() instead."
58 )
59 if opclasses and len(fields) != len(opclasses):
60 raise ValueError(
61 "Index.fields and Index.opclasses must have the same number of "
62 "elements."
63 )
64 if fields and not all(isinstance(field, str) for field in fields):
65 raise ValueError("Index.fields must contain only strings with field names.")
66 if include and not name:
67 raise ValueError("A covering index must be named.")
68 if not isinstance(include, NoneType | list | tuple):
69 raise ValueError("Index.include must be a list or tuple.")
70 self.fields = list(fields)
71 # A list of 2-tuple with the field name and ordering ('' or 'DESC').
72 self.fields_orders = [
73 (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
74 for field_name in self.fields
75 ]
76 self.name = name or ""
77 self.opclasses: tuple[str, ...] = tuple(opclasses)
78 self.condition = condition
79 self.include = tuple(include) if include else ()
80 self.expressions: tuple[Expression, ...] = tuple(
81 F(expression) if isinstance(expression, str) else expression
82 for expression in expressions
83 )
84
85 @property
86 def contains_expressions(self) -> bool:
87 return bool(self.expressions)
88
89 def _get_condition_sql(
90 self, model: type[Model], schema_editor: DatabaseSchemaEditor
91 ) -> str | None:
92 if self.condition is None:
93 return None
94 query = Query(model=model, alias_cols=False)
95 where = query.build_where(self.condition)
96 compiler = query.get_compiler()
97 sql, params = where.as_sql(compiler, schema_editor.connection)
98 return sql % tuple(schema_editor.quote_value(p) for p in params)
99
100 def create_sql(
101 self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
102 ) -> Statement:
103 include = [
104 model._model_meta.get_forward_field(field_name).column
105 for field_name in self.include
106 ]
107 condition = self._get_condition_sql(model, schema_editor)
108 if self.expressions:
109 index_expressions = []
110 for expression in self.expressions:
111 index_expression = IndexExpression(expression)
112 index_expressions.append(index_expression)
113 expressions = ExpressionList(*index_expressions).resolve_expression(
114 Query(model, alias_cols=False),
115 )
116 fields = None
117 col_suffixes = ()
118 else:
119 fields = [
120 model._model_meta.get_forward_field(field_name)
121 for field_name, _ in self.fields_orders
122 ]
123 # Support index column ordering (ASC/DESC)
124 col_suffixes = tuple(order[1] for order in self.fields_orders)
125 expressions = None
126 return schema_editor._create_index_sql(
127 model,
128 fields=fields,
129 name=self.name,
130 col_suffixes=col_suffixes,
131 opclasses=self.opclasses,
132 condition=condition,
133 include=include,
134 expressions=expressions,
135 **kwargs,
136 )
137
138 def remove_sql(
139 self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
140 ) -> Statement:
141 return schema_editor._delete_index_sql(model, self.name, **kwargs)
142
143 def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
144 path = f"{self.__class__.__module__}.{self.__class__.__name__}"
145 path = path.replace("plain.models.indexes", "plain.models")
146 kwargs = {"name": self.name}
147 if self.fields:
148 kwargs["fields"] = self.fields
149 if self.opclasses:
150 kwargs["opclasses"] = self.opclasses
151 if self.condition:
152 kwargs["condition"] = self.condition
153 if self.include:
154 kwargs["include"] = self.include
155 return (path, self.expressions, kwargs)
156
157 def clone(self) -> Index:
158 """Create a copy of this Index."""
159 _, args, kwargs = self.deconstruct()
160 return self.__class__(*args, **kwargs)
161
162 def set_name_with_model(self, model: type[Model]) -> None:
163 """
164 Generate a unique name for the index.
165
166 The name is divided into 3 parts - table name (12 chars), field name
167 (8 chars) and unique hash + suffix (10 chars). Each part is made to
168 fit its size by truncating the excess length.
169 """
170 _, table_name = split_identifier(model.model_options.db_table)
171 column_names = [
172 model._model_meta.get_forward_field(field_name).column
173 for field_name, order in self.fields_orders
174 ]
175 column_names_with_order = [
176 (("-%s" if order else "%s") % column_name)
177 for column_name, (field_name, order) in zip(
178 column_names, self.fields_orders
179 )
180 ]
181 # The length of the parts of the name is based on the default max
182 # length of 30 characters.
183 hash_data = [table_name] + column_names_with_order + [self.suffix]
184 self.name = "{}_{}_{}".format(
185 table_name[:11],
186 column_names[0][:7],
187 f"{names_digest(*hash_data, length=6)}_{self.suffix}",
188 )
189 if len(self.name) > self.max_name_length:
190 raise ValueError(
191 "Index name too long. Is self.suffix longer than 3 characters?"
192 )
193 if self.name[0] == "_" or self.name[0].isdigit():
194 self.name = f"D{self.name[1:]}"
195
196 def __repr__(self) -> str:
197 return "<{}:{}{}{}{}{}{}>".format(
198 self.__class__.__qualname__,
199 "" if not self.fields else f" fields={repr(self.fields)}",
200 "" if not self.expressions else f" expressions={repr(self.expressions)}",
201 "" if not self.name else f" name={repr(self.name)}",
202 "" if self.condition is None else f" condition={self.condition}",
203 "" if not self.include else f" include={repr(self.include)}",
204 "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
205 )
206
207 def __eq__(self, other: object) -> bool:
208 if isinstance(other, Index):
209 return self.deconstruct() == other.deconstruct()
210 return NotImplemented
211
212
class IndexExpression(Func):
    """Arrange an expression and its wrappers for CREATE INDEX statements."""

    template = "%(expressions)s"
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Reorder wrappers above the root expression, then resolve.

        The root expression is wrapped in parentheses unless it resolves to
        a ``Col``. This method replaces the instance's own source
        expressions before delegating to the parent implementation.

        Raises:
            ValueError: If a wrapper class is used more than once, or if the
                wrappers are not the topmost expressions.
        """
        flattened = list(self.flatten())
        # Separate the wrapper instances from the remaining expressions.
        plain_expressions, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        wrapper_names = ", ".join(
            wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
        )
        distinct_types = {type(wrapper) for wrapper in wrappers}
        if len(distinct_types) != len(wrappers):
            raise ValueError(
                f"Multiple references to {wrapper_names} can't be used in an indexed "
                "expression."
            )
        if flattened[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                f"{wrapper_names} must be topmost expressions in an indexed expression."
            )
        # Anything that is not a plain column reference gets parentheses.
        root = plain_expressions[1]
        resolved_root = root.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if not wrappers:
            # No wrappers: the root expression is the only source.
            self.set_source_expressions([root])
        else:
            # Order wrappers as listed in wrapper_classes and work on copies.
            ordered = [
                wrapper.copy()
                for wrapper in sorted(
                    wrappers,
                    key=lambda w: self.wrapper_classes.index(type(w)),
                )
            ]
            # Chain each wrapper onto the next one down.
            for outer, inner in zip(ordered, ordered[1:]):
                outer.set_source_expressions([inner])
            # The deepest wrapper holds the root expression.
            ordered[-1].set_source_expressions([root])
            self.set_source_expressions([ordered[0]])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )