1from __future__ import annotations
  2
  3from types import NoneType
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.backends.utils import names_digest, split_identifier
  7from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
  8from plain.models.functions import Collate
  9from plain.models.query_utils import Q
 10from plain.models.sql import Query
 11from plain.utils.functional import partition
 12
 13if TYPE_CHECKING:
 14    from plain.models.backends.base.base import BaseDatabaseWrapper
 15    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 16    from plain.models.backends.ddl_references import Statement
 17    from plain.models.base import Model
 18    from plain.models.expressions import Expression
 19    from plain.models.sql.compiler import SQLCompiler
 20
 21__all__ = ["Index"]
 22
 23
 24class Index:
 25    suffix = "idx"
 26    # The max length of the name of the index (restricted to 30 for
 27    # cross-database compatibility with Oracle)
 28    max_name_length = 30
 29
 30    def __init__(
 31        self,
 32        *expressions: Any,
 33        fields: tuple[str, ...] | list[str] = (),
 34        name: str | None = None,
 35        opclasses: tuple[str, ...] | list[str] = (),
 36        condition: Q | None = None,
 37        include: tuple[str, ...] | list[str] | None = None,
 38    ) -> None:
 39        if opclasses and not name:
 40            raise ValueError("An index must be named to use opclasses.")
 41        if not isinstance(condition, NoneType | Q):
 42            raise ValueError("Index.condition must be a Q instance.")
 43        if condition and not name:
 44            raise ValueError("An index must be named to use condition.")
 45        if not isinstance(fields, list | tuple):
 46            raise ValueError("Index.fields must be a list or tuple.")
 47        if not isinstance(opclasses, list | tuple):
 48            raise ValueError("Index.opclasses must be a list or tuple.")
 49        if not expressions and not fields:
 50            raise ValueError(
 51                "At least one field or expression is required to define an index."
 52            )
 53        if expressions and fields:
 54            raise ValueError(
 55                "Index.fields and expressions are mutually exclusive.",
 56            )
 57        if expressions and not name:
 58            raise ValueError("An index must be named to use expressions.")
 59        if expressions and opclasses:
 60            raise ValueError(
 61                "Index.opclasses cannot be used with expressions. Use "
 62                "a custom OpClass() instead."
 63            )
 64        if opclasses and len(fields) != len(opclasses):
 65            raise ValueError(
 66                "Index.fields and Index.opclasses must have the same number of "
 67                "elements."
 68            )
 69        if fields and not all(isinstance(field, str) for field in fields):
 70            raise ValueError("Index.fields must contain only strings with field names.")
 71        if include and not name:
 72            raise ValueError("A covering index must be named.")
 73        if not isinstance(include, NoneType | list | tuple):
 74            raise ValueError("Index.include must be a list or tuple.")
 75        self.fields = list(fields)
 76        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
 77        self.fields_orders = [
 78            (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
 79            for field_name in self.fields
 80        ]
 81        self.name = name or ""
 82        self.opclasses = opclasses
 83        self.condition = condition
 84        self.include = tuple(include) if include else ()
 85        self.expressions = tuple(
 86            F(expression) if isinstance(expression, str) else expression
 87            for expression in expressions
 88        )
 89
 90    @property
 91    def contains_expressions(self) -> bool:
 92        return bool(self.expressions)
 93
 94    def _get_condition_sql(
 95        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor
 96    ) -> str | None:
 97        if self.condition is None:
 98            return None
 99        query = Query(model=model, alias_cols=False)
100        where = query.build_where(self.condition)
101        compiler = query.get_compiler()
102        sql, params = where.as_sql(compiler, schema_editor.connection)
103        return sql % tuple(schema_editor.quote_value(p) for p in params)
104
105    def create_sql(
106        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
107    ) -> Statement:
108        include = [
109            model._model_meta.get_field(field_name).column
110            for field_name in self.include
111        ]
112        condition = self._get_condition_sql(model, schema_editor)
113        if self.expressions:
114            index_expressions = []
115            for expression in self.expressions:
116                index_expression = IndexExpression(expression)
117                index_expression.set_wrapper_classes(schema_editor.connection)
118                index_expressions.append(index_expression)
119            expressions = ExpressionList(*index_expressions).resolve_expression(
120                Query(model, alias_cols=False),
121            )
122            fields = None
123            col_suffixes = None
124        else:
125            fields = [
126                model._model_meta.get_field(field_name)
127                for field_name, _ in self.fields_orders
128            ]
129            if schema_editor.connection.features.supports_index_column_ordering:
130                col_suffixes = tuple(order[1] for order in self.fields_orders)
131            else:
132                col_suffixes = ("",) * len(self.fields_orders)
133            expressions = None
134        return schema_editor._create_index_sql(
135            model,
136            fields=fields,
137            name=self.name,
138            col_suffixes=col_suffixes,  # type: ignore[arg-type]
139            opclasses=self.opclasses,
140            condition=condition,
141            include=include,
142            expressions=expressions,
143            **kwargs,
144        )
145
146    def remove_sql(
147        self, model: type[Model], schema_editor: BaseDatabaseSchemaEditor, **kwargs: Any
148    ) -> Statement:
149        return schema_editor._delete_index_sql(model, self.name, **kwargs)
150
151    def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
152        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
153        path = path.replace("plain.models.indexes", "plain.models")
154        kwargs = {"name": self.name}
155        if self.fields:
156            kwargs["fields"] = self.fields
157        if self.opclasses:
158            kwargs["opclasses"] = self.opclasses
159        if self.condition:
160            kwargs["condition"] = self.condition
161        if self.include:
162            kwargs["include"] = self.include
163        return (path, self.expressions, kwargs)
164
165    def clone(self) -> Index:
166        """Create a copy of this Index."""
167        _, args, kwargs = self.deconstruct()
168        return self.__class__(*args, **kwargs)
169
170    def set_name_with_model(self, model: type[Model]) -> None:
171        """
172        Generate a unique name for the index.
173
174        The name is divided into 3 parts - table name (12 chars), field name
175        (8 chars) and unique hash + suffix (10 chars). Each part is made to
176        fit its size by truncating the excess length.
177        """
178        _, table_name = split_identifier(model.model_options.db_table)
179        column_names = [
180            model._model_meta.get_field(field_name).column
181            for field_name, order in self.fields_orders
182        ]
183        column_names_with_order = [
184            (("-%s" if order else "%s") % column_name)
185            for column_name, (field_name, order) in zip(
186                column_names, self.fields_orders
187            )
188        ]
189        # The length of the parts of the name is based on the default max
190        # length of 30 characters.
191        hash_data = [table_name] + column_names_with_order + [self.suffix]
192        self.name = "{}_{}_{}".format(
193            table_name[:11],
194            column_names[0][:7],
195            f"{names_digest(*hash_data, length=6)}_{self.suffix}",
196        )
197        if len(self.name) > self.max_name_length:
198            raise ValueError(
199                "Index too long for multiple database support. Is self.suffix "
200                "longer than 3 characters?"
201            )
202        if self.name[0] == "_" or self.name[0].isdigit():
203            self.name = f"D{self.name[1:]}"
204
205    def __repr__(self) -> str:
206        return "<{}:{}{}{}{}{}{}>".format(
207            self.__class__.__qualname__,
208            "" if not self.fields else f" fields={repr(self.fields)}",
209            "" if not self.expressions else f" expressions={repr(self.expressions)}",
210            "" if not self.name else f" name={repr(self.name)}",
211            "" if self.condition is None else f" condition={self.condition}",
212            "" if not self.include else f" include={repr(self.include)}",
213            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
214        )
215
216    def __eq__(self, other: object) -> bool:
217        if self.__class__ == other.__class__:
218            return self.deconstruct() == other.deconstruct()  # type: ignore[attr-defined]
219        return NotImplemented
220
221
class IndexExpression(Func):
    """Order and wrap a single expression for a CREATE INDEX statement.

    Wrapper nodes (instances of ``wrapper_classes``, by default ``OrderBy``
    and ``Collate``) must be the topmost nodes of the wrapped expression.
    During resolution they are re-stacked in ``wrapper_classes`` order around
    the root expression. The root is parenthesized unless it resolves to a
    plain ``Col``.
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(
        self, connection: BaseDatabaseWrapper | None = None
    ) -> None:
        """Stop treating ``Collate`` as a wrapper on backends that index it.

        Some databases (e.g. MySQL) treat COLLATE as an indexed expression.
        When ``connection.features.collate_as_index_expression`` is set,
        ``Collate`` is removed from this instance's ``wrapper_classes``. It
        then stays part of the root expression instead of being hoisted.
        """
        if not connection or not connection.features.collate_as_index_expression:
            return
        self.wrapper_classes = tuple(
            wrapper_cls
            for wrapper_cls in self.wrapper_classes
            if wrapper_cls is not Collate
        )

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Expression:
        """Validate the wrapper layout, rebuild the tree, then resolve it.

        Raises:
            ValueError: if a wrapper type occurs more than once, or if the
                wrappers are not the topmost nodes beneath this expression.
        """
        flattened = list(self.flatten())
        # Judging by the unpacking order, partition() returns
        # (non-wrappers, wrappers). Index 0 of the flattened list is taken
        # to be this node itself, hence the 1-based indexing below.
        index_expressions, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        wrapper_names = ", ".join(
            [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
        )
        if len({type(wrapper) for wrapper in wrappers}) != len(wrappers):
            raise ValueError(
                f"Multiple references to {wrapper_names} can't be used in an "
                "indexed expression."
            )
        if flattened[1 : 1 + len(wrappers)] != wrappers:
            raise ValueError(
                f"{wrapper_names} must be topmost expressions in an indexed "
                "expression."
            )

        # Parenthesize the root unless it resolves to a bare column reference.
        root = index_expressions[1]
        resolved_root = root.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        # Stack copies of the wrappers around the root, outermost first in
        # wrapper_classes order, by linking from the innermost outwards.
        # With no wrappers, the root hangs directly off this node.
        ordered = sorted(
            wrappers,
            key=lambda wrapper: self.wrapper_classes.index(type(wrapper)),
        )
        inner = root
        for wrapper in reversed([node.copy() for node in ordered]):
            wrapper.set_source_expressions([inner])
            inner = wrapper
        self.set_source_expressions([inner])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Delegate to the generic as_sql(): casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)