Plain is headed towards 1.0! Subscribe for development updates →

  1from types import NoneType
  2
  3from plain.models.backends.utils import names_digest, split_identifier
  4from plain.models.expressions import Col, ExpressionList, F, Func, OrderBy
  5from plain.models.functions import Collate
  6from plain.models.query_utils import Q
  7from plain.models.sql import Query
  8from plain.utils.functional import partition
  9
 10__all__ = ["Index"]
 11
 12
 13class Index:
 14    suffix = "idx"
 15    # The max length of the name of the index (restricted to 30 for
 16    # cross-database compatibility with Oracle)
 17    max_name_length = 30
 18
 19    def __init__(
 20        self,
 21        *expressions,
 22        fields=(),
 23        name=None,
 24        opclasses=(),
 25        condition=None,
 26        include=None,
 27    ):
 28        if opclasses and not name:
 29            raise ValueError("An index must be named to use opclasses.")
 30        if not isinstance(condition, NoneType | Q):
 31            raise ValueError("Index.condition must be a Q instance.")
 32        if condition and not name:
 33            raise ValueError("An index must be named to use condition.")
 34        if not isinstance(fields, list | tuple):
 35            raise ValueError("Index.fields must be a list or tuple.")
 36        if not isinstance(opclasses, list | tuple):
 37            raise ValueError("Index.opclasses must be a list or tuple.")
 38        if not expressions and not fields:
 39            raise ValueError(
 40                "At least one field or expression is required to define an index."
 41            )
 42        if expressions and fields:
 43            raise ValueError(
 44                "Index.fields and expressions are mutually exclusive.",
 45            )
 46        if expressions and not name:
 47            raise ValueError("An index must be named to use expressions.")
 48        if expressions and opclasses:
 49            raise ValueError(
 50                "Index.opclasses cannot be used with expressions. Use "
 51                "a custom OpClass() instead."
 52            )
 53        if opclasses and len(fields) != len(opclasses):
 54            raise ValueError(
 55                "Index.fields and Index.opclasses must have the same number of "
 56                "elements."
 57            )
 58        if fields and not all(isinstance(field, str) for field in fields):
 59            raise ValueError("Index.fields must contain only strings with field names.")
 60        if include and not name:
 61            raise ValueError("A covering index must be named.")
 62        if not isinstance(include, NoneType | list | tuple):
 63            raise ValueError("Index.include must be a list or tuple.")
 64        self.fields = list(fields)
 65        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
 66        self.fields_orders = [
 67            (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
 68            for field_name in self.fields
 69        ]
 70        self.name = name or ""
 71        self.opclasses = opclasses
 72        self.condition = condition
 73        self.include = tuple(include) if include else ()
 74        self.expressions = tuple(
 75            F(expression) if isinstance(expression, str) else expression
 76            for expression in expressions
 77        )
 78
 79    @property
 80    def contains_expressions(self):
 81        return bool(self.expressions)
 82
 83    def _get_condition_sql(self, model, schema_editor):
 84        if self.condition is None:
 85            return None
 86        query = Query(model=model, alias_cols=False)
 87        where = query.build_where(self.condition)
 88        compiler = query.get_compiler()
 89        sql, params = where.as_sql(compiler, schema_editor.connection)
 90        return sql % tuple(schema_editor.quote_value(p) for p in params)
 91
 92    def create_sql(self, model, schema_editor, **kwargs):
 93        include = [
 94            model._meta.get_field(field_name).column for field_name in self.include
 95        ]
 96        condition = self._get_condition_sql(model, schema_editor)
 97        if self.expressions:
 98            index_expressions = []
 99            for expression in self.expressions:
100                index_expression = IndexExpression(expression)
101                index_expression.set_wrapper_classes(schema_editor.connection)
102                index_expressions.append(index_expression)
103            expressions = ExpressionList(*index_expressions).resolve_expression(
104                Query(model, alias_cols=False),
105            )
106            fields = None
107            col_suffixes = None
108        else:
109            fields = [
110                model._meta.get_field(field_name)
111                for field_name, _ in self.fields_orders
112            ]
113            if schema_editor.connection.features.supports_index_column_ordering:
114                col_suffixes = [order[1] for order in self.fields_orders]
115            else:
116                col_suffixes = [""] * len(self.fields_orders)
117            expressions = None
118        return schema_editor._create_index_sql(
119            model,
120            fields=fields,
121            name=self.name,
122            col_suffixes=col_suffixes,
123            opclasses=self.opclasses,
124            condition=condition,
125            include=include,
126            expressions=expressions,
127            **kwargs,
128        )
129
130    def remove_sql(self, model, schema_editor, **kwargs):
131        return schema_editor._delete_index_sql(model, self.name, **kwargs)
132
133    def deconstruct(self):
134        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
135        path = path.replace("plain.models.indexes", "plain.models")
136        kwargs = {"name": self.name}
137        if self.fields:
138            kwargs["fields"] = self.fields
139        if self.opclasses:
140            kwargs["opclasses"] = self.opclasses
141        if self.condition:
142            kwargs["condition"] = self.condition
143        if self.include:
144            kwargs["include"] = self.include
145        return (path, self.expressions, kwargs)
146
147    def clone(self):
148        """Create a copy of this Index."""
149        _, args, kwargs = self.deconstruct()
150        return self.__class__(*args, **kwargs)
151
152    def set_name_with_model(self, model):
153        """
154        Generate a unique name for the index.
155
156        The name is divided into 3 parts - table name (12 chars), field name
157        (8 chars) and unique hash + suffix (10 chars). Each part is made to
158        fit its size by truncating the excess length.
159        """
160        _, table_name = split_identifier(model._meta.db_table)
161        column_names = [
162            model._meta.get_field(field_name).column
163            for field_name, order in self.fields_orders
164        ]
165        column_names_with_order = [
166            (("-%s" if order else "%s") % column_name)
167            for column_name, (field_name, order) in zip(
168                column_names, self.fields_orders
169            )
170        ]
171        # The length of the parts of the name is based on the default max
172        # length of 30 characters.
173        hash_data = [table_name] + column_names_with_order + [self.suffix]
174        self.name = "{}_{}_{}".format(
175            table_name[:11],
176            column_names[0][:7],
177            f"{names_digest(*hash_data, length=6)}_{self.suffix}",
178        )
179        if len(self.name) > self.max_name_length:
180            raise ValueError(
181                "Index too long for multiple database support. Is self.suffix "
182                "longer than 3 characters?"
183            )
184        if self.name[0] == "_" or self.name[0].isdigit():
185            self.name = f"D{self.name[1:]}"
186
187    def __repr__(self):
188        return "<{}:{}{}{}{}{}{}>".format(
189            self.__class__.__qualname__,
190            "" if not self.fields else f" fields={repr(self.fields)}",
191            "" if not self.expressions else f" expressions={repr(self.expressions)}",
192            "" if not self.name else f" name={repr(self.name)}",
193            "" if self.condition is None else f" condition={self.condition}",
194            "" if not self.include else f" include={repr(self.include)}",
195            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
196        )
197
198    def __eq__(self, other):
199        if self.__class__ == other.__class__:
200            return self.deconstruct() == other.deconstruct()
201        return NotImplemented
202
203
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements.

    ``resolve_expression`` peels the "wrapper" expressions off the top of the
    wrapped expression. Wrappers are instances of ``wrapper_classes``, which
    defaults to ``OrderBy`` and ``Collate``. They are re-stacked in
    ``wrapper_classes`` order, and the innermost expression is parenthesized
    unless it resolves to a ``Col``.
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(self, connection=None):
        """Drop ``Collate`` from the wrappers when *connection* requires it.

        Some databases (e.g. MySQL) treat COLLATE as an indexed expression.
        This is signalled by ``connection.features.collate_as_index_expression``.
        The narrowed tuple is stored on this instance only.
        """
        if connection and connection.features.collate_as_index_expression:
            self.wrapper_classes = tuple(
                wrapper_cls
                for wrapper_cls in self.wrapper_classes
                if wrapper_cls is not Collate
            )

    def _wrapper_names(self):
        """Return the wrapper class names, comma-joined, for error messages."""
        return ", ".join(
            wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
        )

    def _wrapper_rank(self, expression):
        """Return the index of the first wrapper class *expression* is an instance of.

        This uses ``isinstance``, which is also how wrappers are detected in
        ``resolve_expression``. A subclass of a wrapper class is therefore
        ranked like its base class, instead of failing an exact-type
        ``tuple.index`` lookup.
        """
        for rank, wrapper_cls in enumerate(self.wrapper_classes):
            if isinstance(expression, wrapper_cls):
                return rank
        # Unreachable for values that passed the wrapper partition.
        raise TypeError(f"{expression!r} is not an index wrapper expression.")

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """Restructure the wrapped expression for an index, then resolve it.

        Raises:
            ValueError: if the same kind of wrapper appears more than once, or
                if the wrappers are not the topmost expressions.
        """
        expressions = list(self.flatten())
        # Split expressions and wrappers.
        index_expressions, wrappers = partition(
            lambda e: isinstance(e, self.wrapper_classes),
            expressions,
        )
        # At most one wrapper of each kind. Kinds are compared by rank, so a
        # wrapper and a subclass of the same wrapper class also collide.
        wrapper_ranks = [self._wrapper_rank(wrapper) for wrapper in wrappers]
        if len(wrapper_ranks) != len(set(wrapper_ranks)):
            raise ValueError(
                f"Multiple references to {self._wrapper_names()} can't be used "
                "in an indexed expression."
            )
        # Wrappers must directly follow expressions[0] in the flattened
        # order, with nothing interleaved between them. expressions[0] is
        # presumably this IndexExpression itself (flatten() starting at self).
        if expressions[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                f"{self._wrapper_names()} must be topmost expressions in an "
                "indexed expression."
            )
        # index_expressions[0] is presumably self (see above), so the first
        # wrapped non-wrapper expression is at index 1. Wrap it in parentheses
        # if it is not a column reference.
        root_expression = index_expressions[1]
        resolve_root_expression = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolve_root_expression, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        if wrappers:
            # Re-stack copies of the wrappers in wrapper_classes order, each
            # wrapping the next, with the root expression innermost.
            ordered = [
                wrapper.copy() for wrapper in sorted(wrappers, key=self._wrapper_rank)
            ]
            for outer, inner in zip(ordered, ordered[1:]):
                outer.set_source_expressions([inner])
            ordered[-1].set_source_expressions([root_expression])
            self.set_source_expressions([ordered[0]])
        else:
            # Use the root expression, if there are no wrappers.
            self.set_source_expressions([root_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render with plain ``as_sql``, bypassing the parent's SQLite hook."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
291        # Casting to numeric is unnecessary.
292        return self.as_sql(compiler, connection, **extra_context)