v0.146.0
  1from __future__ import annotations
  2
  3from types import NoneType
  4from typing import TYPE_CHECKING, Any, Self
  5
  6from plain.postgres.ddl import (
  7    build_include_sql,
  8    compile_expression_sql,
  9    compile_index_expressions_sql,
 10)
 11from plain.postgres.dialect import quote_name
 12from plain.postgres.expressions import Col, F, Func, OrderBy
 13from plain.postgres.query_utils import Q
 14from plain.utils.functional import partition
 15
 16if TYPE_CHECKING:
 17    from plain.postgres.base import Model
 18    from plain.postgres.expressions import Expression
 19
# Public API of this module: a star-import exposes only Index
# (IndexExpression must be imported by name).
__all__ = ["Index"]
 21
 22
 23class Index:
 24    suffix = "idx"
 25    # Postgres identifier limit: NAMEDATALEN - 1 = 63
 26    max_name_length = 63
 27
 28    def __init__(
 29        self,
 30        *expressions: Any,
 31        name: str,
 32        fields: tuple[str, ...] | list[str] = (),
 33        opclasses: tuple[str, ...] | list[str] = (),
 34        condition: Q | None = None,
 35        include: tuple[str, ...] | list[str] | None = None,
 36    ) -> None:
 37        if not isinstance(condition, NoneType | Q):
 38            raise ValueError("Index.condition must be a Q instance.")
 39        if not isinstance(fields, list | tuple):
 40            raise ValueError("Index.fields must be a list or tuple.")
 41        if not isinstance(opclasses, list | tuple):
 42            raise ValueError("Index.opclasses must be a list or tuple.")
 43        if not expressions and not fields:
 44            raise ValueError(
 45                "At least one field or expression is required to define an index."
 46            )
 47        if expressions and fields:
 48            raise ValueError(
 49                "Index.fields and expressions are mutually exclusive.",
 50            )
 51        if expressions and opclasses:
 52            raise ValueError(
 53                "Index.opclasses cannot be used with expressions. Use "
 54                "a custom OpClass() instead."
 55            )
 56        if opclasses and len(fields) != len(opclasses):
 57            raise ValueError(
 58                "Index.fields and Index.opclasses must have the same number of "
 59                "elements."
 60            )
 61        if fields and not all(isinstance(field, str) for field in fields):
 62            raise ValueError("Index.fields must contain only strings with field names.")
 63        if not isinstance(include, NoneType | list | tuple):
 64            raise ValueError("Index.include must be a list or tuple.")
 65        self.fields = list(fields)
 66        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
 67        self.fields_orders = [
 68            (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
 69            for field_name in self.fields
 70        ]
 71        if not name:
 72            raise ValueError("Index.name is required.")
 73        self.name = name
 74        self.opclasses: tuple[str, ...] = tuple(opclasses)
 75        self.condition = condition
 76        self.include = tuple(include) if include else ()
 77        self.expressions: tuple[Expression, ...] = tuple(  # ty: ignore[invalid-assignment]
 78            F(expression) if isinstance(expression, str) else expression
 79            for expression in expressions
 80        )
 81
 82    @property
 83    def contains_expressions(self) -> bool:
 84        return bool(self.expressions)
 85
 86    @property
 87    def is_partial(self) -> bool:
 88        return self.condition is not None
 89
 90    def to_sql(self, model: type[Model]) -> str:
 91        """Generate CREATE INDEX CONCURRENTLY SQL as a plain string."""
 92        table = model.model_options.db_table
 93        condition = (
 94            compile_expression_sql(model, self.condition)
 95            if self.condition is not None
 96            else None
 97        )
 98
 99        if self.expressions:
100            columns_sql = compile_index_expressions_sql(model, self.expressions)
101        else:
102            col_parts = []
103            for i, (field_name, suffix) in enumerate(self.fields_orders):
104                field = model._model_meta.get_forward_field(field_name)
105                col = quote_name(field.column)
106                if self.opclasses:
107                    col = f"{col} {self.opclasses[i]}"
108                if suffix:
109                    col = f"{col} {suffix}"
110                col_parts.append(col)
111            columns_sql = ", ".join(col_parts)
112
113        include_sql = build_include_sql(model, self.include)
114        name = quote_name(self.name)
115        table = quote_name(table)
116        condition_sql = f" WHERE ({condition})" if condition else ""
117        return f"CREATE INDEX CONCURRENTLY {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
118
119    def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
120        path = f"{self.__class__.__module__}.{self.__class__.__name__}"
121        path = path.replace("plain.postgres.indexes", "plain.postgres")
122        kwargs: dict[str, Any] = {"name": self.name}
123        if self.fields:
124            kwargs["fields"] = self.fields
125        if self.opclasses:
126            kwargs["opclasses"] = self.opclasses
127        if self.condition:
128            kwargs["condition"] = self.condition
129        if self.include:
130            kwargs["include"] = self.include
131        return (path, self.expressions, kwargs)
132
133    def clone(self) -> Index:
134        """Create a copy of this Index."""
135        _, args, kwargs = self.deconstruct()
136        return self.__class__(*args, **kwargs)
137
138    def __repr__(self) -> str:
139        return "<{}:{}{}{}{}{}{}>".format(
140            self.__class__.__qualname__,
141            "" if not self.fields else f" fields={repr(self.fields)}",
142            "" if not self.expressions else f" expressions={repr(self.expressions)}",
143            "" if not self.name else f" name={repr(self.name)}",
144            "" if self.condition is None else f" condition={self.condition}",
145            "" if not self.include else f" include={repr(self.include)}",
146            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
147        )
148
149    def __eq__(self, other: object) -> bool:
150        if isinstance(other, Index):
151            return self.deconstruct() == other.deconstruct()
152        return NotImplemented
153
154
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements.

    Wrapper nodes (instances of ``wrapper_classes``) are only allowed at the
    top of the wrapped expression tree. During resolution they are re-applied
    around the remaining root expression, which is parenthesized unless it
    resolves to a bare column reference.
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        def wrapper_names() -> str:
            # Comma-separated wrapper class names, used only in error messages.
            return ", ".join(
                [wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
            )

        flat = list(self.flatten())
        # Split into non-wrapper nodes and wrapper nodes. Position 0 of the
        # flattened list is presumably this node itself (flatten() appears to
        # yield it first), which is why the checks below start at index 1.
        plain_nodes, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flat,
        )

        seen_types = [type(wrapper) for wrapper in wrappers]
        if len(set(seen_types)) != len(seen_types):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(wrapper_names())
            )
        # Wrappers must directly follow this node in flatten() order,
        # i.e. sit above every other expression.
        if flat[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    wrapper_names()
                )
            )

        # Anything that is not a plain column reference gets parentheses.
        root = plain_nodes[1]
        resolved_root = root.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if not wrappers:
            # Nothing to re-apply: the root expression is the whole body.
            self.set_source_expressions([root])
        else:
            # Copies of the wrappers, ordered as in wrapper_classes,
            # outermost first.
            chain = [
                wrapper.copy()
                for wrapper in sorted(
                    wrappers,
                    key=lambda w: self.wrapper_classes.index(type(w)),
                )
            ]
            # Nest each wrapper inside the one before it...
            for outer, inner in zip(chain, chain[1:]):
                outer.set_source_expressions([inner])
            # ...with the root expression at the innermost level.
            chain[-1].set_source_expressions([root])
            self.set_source_expressions([chain[0]])

        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )