Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from types import NoneType
  5from typing import TYPE_CHECKING, Any, Self, cast
  6
  7from plain.models.backends.utils import names_digest, split_identifier
  8from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
  9from plain.models.functions import Collate
 10from plain.models.query_utils import Q
 11from plain.models.sql import Query
 12from plain.utils.functional import partition
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.base.base import BaseDatabaseWrapper
 16    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 17    from plain.models.backends.ddl_references import Statement
 18    from plain.models.base import Model
 19    from plain.models.expressions import Expression
 20    from plain.models.sql.compiler import SQLCompiler
 21
# Public API of this module: only Index is exported. IndexExpression is used
# internally by Index.create_sql() and is deliberately not listed here.
__all__ = ["Index"]
 23
 24
 25class Index:
 26    suffix = "idx"
 27    # The max length of the name of the index (restricted to 30 for
 28    # cross-database compatibility with Oracle)
 29    max_name_length = 30
 30
 31    def __init__(
 32        self,
 33        *expressions: Any,
 34        fields: tuple[str, ...] | list[str] = (),
 35        name: str | None = None,
 36        opclasses: tuple[str, ...] | list[str] = (),
 37        condition: Q | None = None,
 38        include: tuple[str, ...] | list[str] | None = None,
 39    ) -> None:
 40        if opclasses and not name:
 41            raise ValueError("An index must be named to use opclasses.")
 42        if not isinstance(condition, NoneType | Q):
 43            raise ValueError("Index.condition must be a Q instance.")
 44        if condition and not name:
 45            raise ValueError("An index must be named to use condition.")
 46        if not isinstance(fields, list | tuple):
 47            raise ValueError("Index.fields must be a list or tuple.")
 48        if not isinstance(opclasses, list | tuple):
 49            raise ValueError("Index.opclasses must be a list or tuple.")
 50        if not expressions and not fields:
 51            raise ValueError(
 52                "At least one field or expression is required to define an index."
 53            )
 54        if expressions and fields:
 55            raise ValueError(
 56                "Index.fields and expressions are mutually exclusive.",
 57            )
 58        if expressions and not name:
 59            raise ValueError("An index must be named to use expressions.")
 60        if expressions and opclasses:
 61            raise ValueError(
 62                "Index.opclasses cannot be used with expressions. Use "
 63                "a custom OpClass() instead."
 64            )
 65        if opclasses and len(fields) != len(opclasses):
 66            raise ValueError(
 67                "Index.fields and Index.opclasses must have the same number of "
 68                "elements."
 69            )
 70        if fields and not all(isinstance(field, str) for field in fields):
 71            raise ValueError("Index.fields must contain only strings with field names.")
 72        if include and not name:
 73            raise ValueError("A covering index must be named.")
 74        if not isinstance(include, NoneType | list | tuple):
 75            raise ValueError("Index.include must be a list or tuple.")
 76        self.fields = list(fields)
 77        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
 78        self.fields_orders = [
 79            (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
 80            for field_name in self.fields
 81        ]
 82        self.name = name or ""
 83        self.opclasses: tuple[str, ...] = tuple(opclasses)
 84        self.condition = condition
 85        self.include = tuple(include) if include else ()
 86        self.expressions: tuple[Expression, ...] = tuple(
 87            F(expression) if isinstance(expression, str) else expression
 88            for expression in expressions
 89        )
 90
 91    @property
 92    def contains_expressions(self) -> bool:
 93        return bool(self.expressions)
 94
 95    def _get_condition_sql(
 96        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 97    ) -> str | None:
 98        if self.condition is None:
 99            return None
100        query = Query(model=model, alias_cols=False)
101        where = query.build_where(self.condition)
102        compiler = query.get_compiler()
103        sql, params = where.as_sql(compiler, schema_editor.connection)
104        return sql % tuple(schema_editor.quote_value(p) for p in params)
105
106    def create_sql(
107        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
108    ) -> Statement:
109        include = [
110            model._model_meta.get_forward_field(field_name).column
111            for field_name in self.include
112        ]
113        condition = self._get_condition_sql(model, schema_editor)
114        if self.expressions:
115            index_expressions = []
116            for expression in self.expressions:
117                index_expression = IndexExpression(expression)
118                index_expression.set_wrapper_classes(schema_editor.connection)
119                index_expressions.append(index_expression)
120            expressions = ExpressionList(*index_expressions).resolve_expression(
121                Query(model, alias_cols=False),
122            )
123            fields = None
124            col_suffixes = ()
125        else:
126            fields = [
127                model._model_meta.get_forward_field(field_name)
128                for field_name, _ in self.fields_orders
129            ]
130            if schema_editor.connection.features.supports_index_column_ordering:
131                col_suffixes = tuple(order[1] for order in self.fields_orders)
132            else:
133                col_suffixes = ("",) * len(self.fields_orders)
134            expressions = None
135        return schema_editor._create_index_sql(
136            model,
137            fields=fields,
138            name=self.name,
139            col_suffixes=col_suffixes,
140            opclasses=self.opclasses,
141            condition=condition,
142            include=include,
143            expressions=expressions,
144            **kwargs,
145        )
146
147    def remove_sql(
148        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
149    ) -> Statement:
150        return schema_editor._delete_index_sql(model, self.name, **kwargs)
151
152    def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
153        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
154        path = path.replace("plain.models.indexes", "plain.models")
155        kwargs = {"name": self.name}
156        if self.fields:
157            kwargs["fields"] = self.fields
158        if self.opclasses:
159            kwargs["opclasses"] = self.opclasses
160        if self.condition:
161            kwargs["condition"] = self.condition
162        if self.include:
163            kwargs["include"] = self.include
164        return (path, self.expressions, kwargs)
165
166    def clone(self) -> Index:
167        """Create a copy of this Index."""
168        _, args, kwargs = self.deconstruct()
169        return self.__class__(*args, **kwargs)
170
171    def set_name_with_model(self, model: type[Model]) -> None:
172        """
173        Generate a unique name for the index.
174
175        The name is divided into 3 parts - table name (12 chars), field name
176        (8 chars) and unique hash + suffix (10 chars). Each part is made to
177        fit its size by truncating the excess length.
178        """
179        _, table_name = split_identifier(model.model_options.db_table)
180        column_names = [
181            model._model_meta.get_forward_field(field_name).column
182            for field_name, order in self.fields_orders
183        ]
184        column_names_with_order = [
185            (("-%s" if order else "%s") % column_name)
186            for column_name, (field_name, order) in zip(
187                column_names, self.fields_orders
188            )
189        ]
190        # The length of the parts of the name is based on the default max
191        # length of 30 characters.
192        hash_data = [table_name] + column_names_with_order + [self.suffix]
193        self.name = "{}_{}_{}".format(
194            table_name[:11],
195            column_names[0][:7],
196            f"{names_digest(*hash_data, length=6)}_{self.suffix}",
197        )
198        if len(self.name) > self.max_name_length:
199            raise ValueError(
200                "Index too long for multiple database support. Is self.suffix "
201                "longer than 3 characters?"
202            )
203        if self.name[0] == "_" or self.name[0].isdigit():
204            self.name = f"D{self.name[1:]}"
205
206    def __repr__(self) -> str:
207        return "<{}:{}{}{}{}{}{}>".format(
208            self.__class__.__qualname__,
209            "" if not self.fields else f" fields={repr(self.fields)}",
210            "" if not self.expressions else f" expressions={repr(self.expressions)}",
211            "" if not self.name else f" name={repr(self.name)}",
212            "" if self.condition is None else f" condition={self.condition}",
213            "" if not self.include else f" include={repr(self.include)}",
214            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
215        )
216
217    def __eq__(self, other: object) -> bool:
218        if isinstance(other, Index):
219            return self.deconstruct() == other.deconstruct()
220        return NotImplemented
221
222
class IndexExpression(Func):
    """
    Order and wrap a single expression for a CREATE INDEX statement.

    Wrapper expressions are instances of ``wrapper_classes``: OrderBy, plus
    Collate unless the backend treats COLLATE as an indexed expression. They
    must be the topmost nodes of the indexed expression. ``resolve_expression()``
    re-nests them in ``wrapper_classes`` order around the root expression. The
    root is parenthesized unless it resolves to a plain column (``Col``).
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(
        self, connection: BaseDatabaseWrapper | None = None
    ) -> None:
        """
        Adjust ``wrapper_classes`` for *connection*'s backend.

        Some databases (e.g. MySQL) treat COLLATE as an indexed expression, so
        Collate is not handled as a wrapper for them.
        """
        if connection and connection.features.collate_as_index_expression:
            self.wrapper_classes = tuple(
                wrapper_cls
                for wrapper_cls in self.wrapper_classes
                if wrapper_cls is not Collate
            )

    def _wrapper_class_names(self) -> str:
        """Return the wrapper class names, comma-joined, for error messages."""
        return ", ".join(
            wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
        )

    def _wrapper_rank(self, expression: Any) -> int:
        """
        Return the index in ``wrapper_classes`` of the class *expression* is
        an instance of.

        This uses isinstance(), the same test used to detect wrappers, so a
        subclass of a wrapper class (e.g. a custom OrderBy) is ranked like its
        base. An exact ``type()`` lookup would fail for such subclasses.
        """
        for rank, wrapper_cls in enumerate(self.wrapper_classes):
            if isinstance(expression, wrapper_cls):
                return rank
        raise ValueError(
            f"{expression!r} is not an instance of {self._wrapper_class_names()}."
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """
        Validate the wrapper layout, rebuild the source expressions and
        resolve them.

        Raises:
            ValueError: if a wrapper class is used more than once, or if the
                wrappers are not the topmost expressions.
        """
        expressions = list(self.flatten())
        # Split expressions and wrappers.
        index_expressions, wrappers = partition(
            lambda e: isinstance(e, self.wrapper_classes),
            expressions,
        )
        # Rank by wrapper class rather than exact type, so a subclass and its
        # base count as the same kind of wrapper.
        wrapper_ranks = [self._wrapper_rank(wrapper) for wrapper in wrappers]
        if len(wrapper_ranks) != len(set(wrapper_ranks)):
            raise ValueError(
                f"Multiple references to {self._wrapper_class_names()} can't be "
                "used in an indexed expression."
            )
        # Index 0 of the flattened list is presumably this IndexExpression
        # itself (note the 1-based slice and the root at index 1). Any
        # wrappers must come directly after it.
        if expressions[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                f"{self._wrapper_class_names()} must be topmost expressions in an "
                "indexed expression."
            )
        # Wrap expressions in parentheses if they are not column references.
        root_expression = index_expressions[1]
        resolved_root = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Order (copies of) the wrappers by wrapper_classes position and
            # nest each inside the previous one.
            ordered = [
                wrapper.copy() for wrapper in sorted(wrappers, key=self._wrapper_rank)
            ]
            for outer, inner in zip(ordered, ordered[1:]):
                outer.set_source_expressions([inner])
            # Set the root expression on the deepest wrapper.
            ordered[-1].set_source_expressions([root_expression])
            self.set_source_expressions([ordered[0]])
        else:
            # Use the root expression, if there are no wrappers.
            self.set_source_expressions([root_expression])
        # Cast needed because super() returns parent's Self type, not subclass's Self
        return cast(
            Self,
            super().resolve_expression(query, allow_joins, reuse, summarize, for_save),
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, Sequence[Any]]:
        """Render with the generic as_sql() on SQLite."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)