1from __future__ import annotations
2
3from types import NoneType
4from typing import TYPE_CHECKING, Any, Self
5
6from plain.postgres.ddl import (
7 build_include_sql,
8 compile_expression_sql,
9 compile_index_expressions_sql,
10)
11from plain.postgres.dialect import quote_name
12from plain.postgres.expressions import Col, F, Func, OrderBy
13from plain.postgres.query_utils import Q
14from plain.utils.functional import partition
15
16if TYPE_CHECKING:
17 from plain.postgres.base import Model
18 from plain.postgres.expressions import Expression
19
# Public API of this module: only ``Index`` is listed for export.
__all__ = ["Index"]
21
22
23class Index:
24 suffix = "idx"
25 # Postgres identifier limit: NAMEDATALEN - 1 = 63
26 max_name_length = 63
27
28 def __init__(
29 self,
30 *expressions: Any,
31 name: str,
32 fields: tuple[str, ...] | list[str] = (),
33 opclasses: tuple[str, ...] | list[str] = (),
34 condition: Q | None = None,
35 include: tuple[str, ...] | list[str] | None = None,
36 ) -> None:
37 if not isinstance(condition, NoneType | Q):
38 raise ValueError("Index.condition must be a Q instance.")
39 if not isinstance(fields, list | tuple):
40 raise ValueError("Index.fields must be a list or tuple.")
41 if not isinstance(opclasses, list | tuple):
42 raise ValueError("Index.opclasses must be a list or tuple.")
43 if not expressions and not fields:
44 raise ValueError(
45 "At least one field or expression is required to define an index."
46 )
47 if expressions and fields:
48 raise ValueError(
49 "Index.fields and expressions are mutually exclusive.",
50 )
51 if expressions and opclasses:
52 raise ValueError(
53 "Index.opclasses cannot be used with expressions. Use "
54 "a custom OpClass() instead."
55 )
56 if opclasses and len(fields) != len(opclasses):
57 raise ValueError(
58 "Index.fields and Index.opclasses must have the same number of "
59 "elements."
60 )
61 if fields and not all(isinstance(field, str) for field in fields):
62 raise ValueError("Index.fields must contain only strings with field names.")
63 if not isinstance(include, NoneType | list | tuple):
64 raise ValueError("Index.include must be a list or tuple.")
65 self.fields = list(fields)
66 # A list of 2-tuple with the field name and ordering ('' or 'DESC').
67 self.fields_orders = [
68 (field_name.removeprefix("-"), "DESC" if field_name.startswith("-") else "")
69 for field_name in self.fields
70 ]
71 if not name:
72 raise ValueError("Index.name is required.")
73 self.name = name
74 self.opclasses: tuple[str, ...] = tuple(opclasses)
75 self.condition = condition
76 self.include = tuple(include) if include else ()
77 self.expressions: tuple[Expression, ...] = tuple( # ty: ignore[invalid-assignment]
78 F(expression) if isinstance(expression, str) else expression
79 for expression in expressions
80 )
81
82 @property
83 def contains_expressions(self) -> bool:
84 return bool(self.expressions)
85
86 @property
87 def is_partial(self) -> bool:
88 return self.condition is not None
89
90 def to_sql(self, model: type[Model]) -> str:
91 """Generate CREATE INDEX CONCURRENTLY SQL as a plain string."""
92 table = model.model_options.db_table
93 condition = (
94 compile_expression_sql(model, self.condition)
95 if self.condition is not None
96 else None
97 )
98
99 if self.expressions:
100 columns_sql = compile_index_expressions_sql(model, self.expressions)
101 else:
102 col_parts = []
103 for i, (field_name, suffix) in enumerate(self.fields_orders):
104 field = model._model_meta.get_forward_field(field_name)
105 col = quote_name(field.column)
106 if self.opclasses:
107 col = f"{col} {self.opclasses[i]}"
108 if suffix:
109 col = f"{col} {suffix}"
110 col_parts.append(col)
111 columns_sql = ", ".join(col_parts)
112
113 include_sql = build_include_sql(model, self.include)
114 name = quote_name(self.name)
115 table = quote_name(table)
116 condition_sql = f" WHERE ({condition})" if condition else ""
117 return f"CREATE INDEX CONCURRENTLY {name} ON {table} ({columns_sql}){include_sql}{condition_sql}"
118
119 def deconstruct(self) -> tuple[str, tuple[Expression, ...], dict[str, Any]]:
120 path = f"{self.__class__.__module__}.{self.__class__.__name__}"
121 path = path.replace("plain.postgres.indexes", "plain.postgres")
122 kwargs: dict[str, Any] = {"name": self.name}
123 if self.fields:
124 kwargs["fields"] = self.fields
125 if self.opclasses:
126 kwargs["opclasses"] = self.opclasses
127 if self.condition:
128 kwargs["condition"] = self.condition
129 if self.include:
130 kwargs["include"] = self.include
131 return (path, self.expressions, kwargs)
132
133 def clone(self) -> Index:
134 """Create a copy of this Index."""
135 _, args, kwargs = self.deconstruct()
136 return self.__class__(*args, **kwargs)
137
138 def __repr__(self) -> str:
139 return "<{}:{}{}{}{}{}{}>".format(
140 self.__class__.__qualname__,
141 "" if not self.fields else f" fields={repr(self.fields)}",
142 "" if not self.expressions else f" expressions={repr(self.expressions)}",
143 "" if not self.name else f" name={repr(self.name)}",
144 "" if self.condition is None else f" condition={self.condition}",
145 "" if not self.include else f" include={repr(self.include)}",
146 "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
147 )
148
149 def __eq__(self, other: object) -> bool:
150 if isinstance(other, Index):
151 return self.deconstruct() == other.deconstruct()
152 return NotImplemented
153
154
class IndexExpression(Func):
    """Arrange wrapper and root expressions for CREATE INDEX statements."""

    template = "%(expressions)s"
    wrapper_classes = (OrderBy,)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Self:
        """Normalize the expression tree, then resolve it via the parent class.

        Wrapper nodes (instances of ``wrapper_classes``) are validated, copied,
        re-nested in ``wrapper_classes`` order around the root expression, and
        installed as this node's single source expression.

        Raises:
            ValueError: If a wrapper type occurs more than once, or if the
                wrappers are not the topmost nodes of the flattened tree.
        """
        flattened = list(self.flatten())
        # Separate wrapper nodes from every other node.
        plain_nodes, found_wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        # Each wrapper type may appear at most once.
        distinct_types = {type(found) for found in found_wrappers}
        if len(distinct_types) != len(found_wrappers):
            raise ValueError(
                "Multiple references to {} can't be used in an indexed "
                "expression.".format(
                    ", ".join(cls.__qualname__ for cls in self.wrapper_classes)
                )
            )
        # Position 0 is skipped — presumably flatten() yields this node first
        # (TODO confirm); the wrappers must directly follow it.
        wrapper_count = len(found_wrappers)
        if flattened[1 : wrapper_count + 1] != found_wrappers:
            raise ValueError(
                "{} must be topmost expressions in an indexed expression.".format(
                    ", ".join(cls.__qualname__ for cls in self.wrapper_classes)
                )
            )
        # Anything that does not resolve to a plain column reference is
        # wrapped in parentheses.
        root = plain_nodes[1]
        resolved_root = root.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if found_wrappers:
            # Copy the wrappers in ``wrapper_classes`` order, then nest each
            # one around the next, with the root expression innermost.
            ordered = sorted(
                found_wrappers,
                key=lambda found: self.wrapper_classes.index(type(found)),
            )
            chain = [found.copy() for found in ordered]
            for outer, inner in zip(chain, chain[1:]):
                outer.set_source_expressions([inner])
            chain[-1].set_source_expressions([root])
            top = chain[0]
        else:
            # No wrappers: the root expression itself is the source.
            top = root
        self.set_source_expressions([top])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )