1from __future__ import annotations
  2
  3from types import NoneType
  4from typing import TYPE_CHECKING, Any, Self
  5
  6from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
  7from plain.models.postgres.utils import names_digest, split_identifier
  8from plain.models.query_utils import Q
  9from plain.models.sql.query import Query
 10from plain.utils.functional import partition
 11
 12if TYPE_CHECKING:
 13    from plain.models.base import Model
 14    from plain.models.expressions import Expression
 15    from plain.models.postgres.schema import DatabaseSchemaEditor, Statement
 16
# Only Index is part of the public API; IndexExpression is not exported and,
# within this module, is used by Index.create_sql.
__all__ = ["Index"]
 18
 19
 20class Index:
 21    suffix = "idx"
 22    # The max length of the name of the index
 23    max_name_length = 30
 24
 25    def __init__(
 26        self,
 27        *expressions: Any,
 28        fields: tuple[str, ...] | list[str] = (),
 29        name: str | None = None,
 30        opclasses: tuple[str, ...] | list[str] = (),
 31        condition: Q | None = None,
 32        include: tuple[str, ...] | list[str] | None = None,
 33    ) -> None:
 34        if opclasses and not name:
 35            raise ValueError("An index must be named to use opclasses.")
 36        if not isinstance(condition, NoneType | Q):
 37            raise ValueError("Index.condition must be a Q instance.")
 38        if condition and not name:
 39            raise ValueError("An index must be named to use condition.")
 40        if not isinstance(fields, list | tuple):
 41            raise ValueError("Index.fields must be a list or tuple.")
 42        if not isinstance(opclasses, list | tuple):
 43            raise ValueError("Index.opclasses must be a list or tuple.")
 44        if not expressions and not fields:
 45            raise ValueError(
 46                "At least one field or expression is required to define an index."
 47            )
 48        if expressions and fields:
 49            raise ValueError(
 50                "Index.fields and expressions are mutually exclusive.",
 51            )
 52        if expressions and not name:
 53            raise ValueError("An index must be named to use expressions.")
 54        if expressions and opclasses:
 55            raise ValueError(
 56                "Index.opclasses cannot be used with expressions. Use "
 57                "a custom OpClass() instead."
 58            )
 59        if opclasses and len(fields) != len(opclasses):
 60            raise ValueError(
 61                "Index.fields and Index.opclasses must have the same number of "
 62                "elements."
 63            )
 64        if fields and not all(isinstance(field, str) for field in fields):
 65            raise ValueError("Index.fields must contain only strings with field names.")
 66        if include and not name:
 67            raise ValueError("A covering index must be named.")
 68        if not isinstance(include, NoneType | list | tuple):
 69            raise ValueError("Index.include must be a list or tuple.")
 70        self.fields = list(fields)
 71        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
 72        self.fields_orders = [
 73            (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
 74            for field_name in self.fields
 75        ]
 76        self.name = name or ""
 77        self.opclasses: tuple[str, ...] = tuple(opclasses)
 78        self.condition = condition
 79        self.include = tuple(include) if include else ()
 80        self.expressions: tuple[Expression, ...] = tuple(
 81            F(expression) if isinstance(expression, str) else expression
 82            for expression in expressions
 83        )
 84
 85    @property
 86    def contains_expressions(self) -> bool:
 87        return bool(self.expressions)
 88
 89    def _get_condition_sql(
 90        self, model: type[Model], schema_editor: DatabaseSchemaEditor
 91    ) -> str | None:
 92        if self.condition is None:
 93            return None
 94        query = Query(model=model, alias_cols=False)
 95        where = query.build_where(self.condition)
 96        compiler = query.get_compiler()
 97        sql, params = where.as_sql(compiler, schema_editor.connection)
 98        return sql % tuple(schema_editor.quote_value(p) for p in params)
 99
100    def create_sql(
101        self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
102    ) -> Statement:
103        include = [
104            model._model_meta.get_forward_field(field_name).column
105            for field_name in self.include
106        ]
107        condition = self._get_condition_sql(model, schema_editor)
108        if self.expressions:
109            index_expressions = []
110            for expression in self.expressions:
111                index_expression = IndexExpression(expression)
112                index_expressions.append(index_expression)
113            expressions = ExpressionList(*index_expressions).resolve_expression(
114                Query(model, alias_cols=False),
115            )
116            fields = None
117            col_suffixes = ()
118        else:
119            fields = [
120                model._model_meta.get_forward_field(field_name)
121                for field_name, _ in self.fields_orders
122            ]
123            # Support index column ordering (ASC/DESC)
124            col_suffixes = tuple(order[1] for order in self.fields_orders)
125            expressions = None
126        return schema_editor._create_index_sql(
127            model,
128            fields=fields,
129            name=self.name,
130            col_suffixes=col_suffixes,
131            opclasses=self.opclasses,
132            condition=condition,
133            include=include,
134            expressions=expressions,
135            **kwargs,
136        )
137
138    def remove_sql(
139        self, model: type[Model], schema_editor: DatabaseSchemaEditor, **kwargs: Any
140    ) -> Statement:
141        return schema_editor._delete_index_sql(model, self.name, **kwargs)
142
143    def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
144        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
145        path = path.replace("plain.models.indexes", "plain.models")
146        kwargs = {"name": self.name}
147        if self.fields:
148            kwargs["fields"] = self.fields
149        if self.opclasses:
150            kwargs["opclasses"] = self.opclasses
151        if self.condition:
152            kwargs["condition"] = self.condition
153        if self.include:
154            kwargs["include"] = self.include
155        return (path, self.expressions, kwargs)
156
157    def clone(self) -> Index:
158        """Create a copy of this Index."""
159        _, args, kwargs = self.deconstruct()
160        return self.__class__(*args, **kwargs)
161
162    def set_name_with_model(self, model: type[Model]) -> None:
163        """
164        Generate a unique name for the index.
165
166        The name is divided into 3 parts - table name (12 chars), field name
167        (8 chars) and unique hash + suffix (10 chars). Each part is made to
168        fit its size by truncating the excess length.
169        """
170        _, table_name = split_identifier(model.model_options.db_table)
171        column_names = [
172            model._model_meta.get_forward_field(field_name).column
173            for field_name, order in self.fields_orders
174        ]
175        column_names_with_order = [
176            (("-%s" if order else "%s") % column_name)
177            for column_name, (field_name, order) in zip(
178                column_names, self.fields_orders
179            )
180        ]
181        # The length of the parts of the name is based on the default max
182        # length of 30 characters.
183        hash_data = [table_name] + column_names_with_order + [self.suffix]
184        self.name = "{}_{}_{}".format(
185            table_name[:11],
186            column_names[0][:7],
187            f"{names_digest(*hash_data, length=6)}_{self.suffix}",
188        )
189        if len(self.name) > self.max_name_length:
190            raise ValueError(
191                "Index name too long. Is self.suffix longer than 3 characters?"
192            )
193        if self.name[0] == "_" or self.name[0].isdigit():
194            self.name = f"D{self.name[1:]}"
195
196    def __repr__(self) -> str:
197        return "<{}:{}{}{}{}{}{}>".format(
198            self.__class__.__qualname__,
199            "" if not self.fields else f" fields={repr(self.fields)}",
200            "" if not self.expressions else f" expressions={repr(self.expressions)}",
201            "" if not self.name else f" name={repr(self.name)}",
202            "" if self.condition is None else f" condition={self.condition}",
203            "" if not self.include else f" include={repr(self.include)}",
204            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
205        )
206
207    def __eq__(self, other: object) -> bool:
208        if isinstance(other, Index):
209            return self.deconstruct() == other.deconstruct()
210        return NotImplemented
211
212
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements.

    On resolution, wrapper expressions (instances of ``wrapper_classes``) are
    re-nested at the top of the expression tree. The remaining root
    expression is wrapped in parentheses unless it resolves to a plain
    column reference (``Col``).
    """

    # Render the source expression(s) with no additional SQL around them.
    template = "%(expressions)s"
    # Expression types that may only appear at the top of an indexed
    # expression. Their order in this tuple is the nesting order used when
    # re-assembling them (first entry = outermost).
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Validate wrapper placement, rebuild the expression tree and resolve it.

        This replaces this node's own source expressions in place (via
        ``set_source_expressions``) before delegating to
        ``Func.resolve_expression``.

        Raises:
            ValueError: if the same wrapper class is used more than once, or if
                the wrappers are not the topmost expressions of the tree.
        """
        expressions = list(self.flatten())
        # Split expressions and wrappers.
        # NOTE(review): the [1:...] indexing below assumes expressions[0] is
        # this node itself (flatten() yields self first), and that partition()
        # returns (non-matching, matching) items, as the names here imply.
        index_expressions, wrappers = partition(
            lambda e: isinstance(e, self.wrapper_classes),
            expressions,
        )
        wrapper_types = [type(wrapper) for wrapper in wrappers]
        # Each wrapper class may appear at most once.
        if len(wrapper_types) != len(set(wrapper_types)):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(
                    ", ".join(
                        [
                            wrapper_cls.__qualname__
                            for wrapper_cls in self.wrapper_classes
                        ]
                    )
                )
            )
        # Wrappers must be the first items after this node in flattened order,
        # i.e. nothing else may sit above them.
        if expressions[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    ", ".join(
                        [
                            wrapper_cls.__qualname__
                            for wrapper_cls in self.wrapper_classes
                        ]
                    )
                )
            )
        # Wrap expressions in parentheses if they are not column references.
        # The first non-wrapper below this node is the expression being indexed.
        root_expression = index_expressions[1]
        # Resolved here only to inspect the result type; the unresolved
        # root_expression is what gets stored below and handed to
        # super().resolve_expression().
        resolve_root_expression = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolve_root_expression, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Order wrappers and set their expressions.
            wrappers = sorted(
                wrappers,
                key=lambda w: self.wrapper_classes.index(type(w)),
            )
            # Rebuild the chain on copies — presumably so the caller's original
            # wrapper objects are not re-parented.
            wrappers = [wrapper.copy() for wrapper in wrappers]
            # Nest each wrapper inside the previous one (wrappers[0] outermost).
            for i, wrapper in enumerate(wrappers[:-1]):
                wrapper.set_source_expressions([wrappers[i + 1]])
            # Set the root expression on the deepest wrapper.
            wrappers[-1].set_source_expressions([root_expression])
            self.set_source_expressions([wrappers[0]])
        else:
            # Use the root expression, if there are no wrappers.
            self.set_source_expressions([root_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )