1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from collections.abc import Generator, Sequence
5from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable
6
7from plain.models.backends.utils import CursorWrapper
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
# Protocols for structural typing: backends may return their own record types
# with extra fields, as long as they expose the attributes declared below.
@runtime_checkable
class TableInfoProtocol(Protocol):
    """Structural type for entries produced by ``get_table_list()``.

    Any object exposing the attributes below satisfies this protocol, so a
    backend is free to return its own record type carrying extra fields.
    """

    # Name of the table or view.
    name: str
    # Relation kind code; BaseDatabaseIntrospection.table_names() keeps only
    # entries whose kind is "t" unless views are requested.
    type: str
20
21
@runtime_checkable
class FieldInfoProtocol(Protocol):
    """Structural type for column descriptions from ``get_table_description()``.

    Matches the attributes of ``FieldInfo``; a backend may supply a richer
    object with additional fields as long as all of these are present.
    """

    # The first seven attributes follow the PEP 249 cursor.description layout.
    name: str
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    null_ok: bool | None
    # Attributes beyond the PEP 249 seven-item description.
    default: Any
    collation: str | None
35
36
class TableInfo(NamedTuple):
    """Record describing one table or view, as produced by ``get_table_list()``."""

    # Table or view name as reported by the database.
    name: str
    # Relation kind code; "t" denotes a table.
    type: str
42
43
class FieldInfo(NamedTuple):
    """Column description in the style of DB-API ``cursor.description``.

    The first seven fields follow the PEP 249 description sequence
    (name through null_ok); ``default`` and ``collation`` are additional
    fields carried alongside them.
    """

    # --- PEP 249 cursor.description items ---
    name: str
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    null_ok: bool | None
    # --- extra items ---
    default: Any
    collation: str | None
56
57
class BaseDatabaseIntrospection(ABC):
    """Encapsulate backend-specific introspection utilities.

    Concrete backends subclass this and implement the abstract methods
    (table listing, table description, sequences, relations and
    constraints); the non-abstract helpers here are built on top of them.
    """

    # Maps a backend type code to a Plain field type name; consulted by
    # get_field_type(). This dict object is shared by every subclass that
    # does not override it, so subclasses should assign their own dict
    # rather than mutate this one.
    data_types_reverse: dict[Any, str] = {}

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """
        Hook for a database backend to use the cursor description to
        match a Plain field type to a database column.

        The default implementation looks ``data_type`` up in
        ``data_types_reverse`` and ignores ``description``. Backends whose
        type code alone is not enough to pick a field type (for example, to
        tell a FloatField from an IntegerField) can override this and
        inspect ``description``.

        Raises KeyError if ``data_type`` is not in ``data_types_reverse``.
        """
        return self.data_types_reverse[data_type]

    def identifier_converter(self, name: str) -> str:
        """
        Apply a conversion to the identifier for the purposes of comparison.

        The default identifier converter is for case sensitive comparison.
        """
        return name

    def table_names(
        self, cursor: CursorWrapper | None = None, include_views: bool = False
    ) -> list[str]:
        """
        Return a sorted list of the names of all tables in the database.

        Entries from get_table_list() whose ``type`` is not ``"t"`` (views)
        are included only when ``include_views`` is True. Sorting is done
        with Python's default ordering. Do NOT use the database's ORDER BY
        here, because sort order subtly differs between databases.

        If ``cursor`` is None, a cursor is obtained from ``self.connection``
        and used inside a ``with`` block for the duration of the call.
        """

        def get_names(cursor: CursorWrapper) -> list[str]:
            return sorted(
                ti.name
                for ti in self.get_table_list(cursor)
                if include_views or ti.type == "t"
            )

        if cursor is None:
            with self.connection.cursor() as cursor:
                return get_names(cursor)
        return get_names(cursor)

    @abstractmethod
    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfoProtocol]:
        """
        Return an unsorted list of TableInfo named tuples of all tables and
        views that exist in the database.
        """
        ...

    @abstractmethod
    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> Sequence[FieldInfoProtocol]:
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        ...

    def get_migratable_models(self) -> Generator[Any, None, None]:
        """
        Lazily yield every registered model, package by package, for which
        ``model.model_options.can_migrate(self.connection)`` is true.
        """
        # Imported here rather than at module level, presumably to avoid an
        # import cycle with the models/packages registries.
        from plain.models import models_registry
        from plain.packages import packages_registry

        return (
            model
            for package_config in packages_registry.get_package_configs()
            for model in models_registry.get_models(
                package_label=package_config.package_label
            )
            if model.model_options.can_migrate(self.connection)
        )

    def plain_table_names(
        self, only_existing: bool = False, include_views: bool = True
    ) -> list[str]:
        """
        Return a sorted list of all table names that have associated Plain
        models and are in INSTALLED_PACKAGES. Each model contributes its own
        table plus the tables of its local many-to-many fields. Duplicates
        are removed.

        If only_existing is True, include only the tables in the database.
        ``include_views`` is then forwarded to table_names(). Names are
        compared after passing through identifier_converter().
        """
        names: set[str] = set()
        for model in self.get_migratable_models():
            names.add(model.model_options.db_table)
            names.update(
                f.m2m_db_table() for f in model._model_meta.local_many_to_many
            )
        if only_existing:
            existing_tables = set(self.table_names(include_views=include_views))
            names = {
                name
                for name in names
                if self.identifier_converter(name) in existing_tables
            }
        # Sort so the result is deterministic (set iteration order is not)
        # and consistent with the sorted output of table_names().
        return sorted(names)

    def sequence_list(self) -> list[dict[str, Any]]:
        """
        Return a list of information about all DB sequences for all models in
        all packages.

        One cursor is opened for the whole scan. The result concatenates
        get_sequences() for every model from get_migratable_models(), in
        that order.
        """
        sequence_list: list[dict[str, Any]] = []
        with self.connection.cursor() as cursor:
            for model in self.get_migratable_models():
                sequence_list.extend(
                    self.get_sequences(
                        cursor,
                        model.model_options.db_table,
                        model._model_meta.local_fields,
                    )
                )
        return sequence_list

    @abstractmethod
    def get_sequences(
        self,
        cursor: CursorWrapper,
        table_name: str,
        table_fields: Sequence[Any] = (),
    ) -> list[dict[str, Any]]:
        """
        Return a list of introspected sequences for table_name. Each sequence
        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
        'name' key can be added if the backend supports named sequences.

        ``table_fields`` is any sequence of the model's fields.
        sequence_list() passes ``model._model_meta.local_fields``.
        """
        ...

    @abstractmethod
    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        ...

    def get_primary_key_column(
        self, cursor: CursorWrapper, table_name: str
    ) -> str | None:
        """
        Return the name of the primary key column for the given table.

        For a composite key this is the first listed column. Returns None
        when get_primary_key_columns() finds no primary key or returns an
        empty list.
        """
        columns = self.get_primary_key_columns(cursor, table_name)
        return columns[0] if columns else None

    def get_primary_key_columns(
        self, cursor: CursorWrapper, table_name: str
    ) -> list[str] | None:
        """
        Return the ``columns`` of the first constraint from get_constraints()
        whose ``primary_key`` flag is truthy, or None if there is none.
        """
        for constraint in self.get_constraints(cursor, table_name).values():
            if constraint["primary_key"]:
                return constraint["columns"]
        return None

    @abstractmethod
    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index)
        across one or more columns.

        Return a dict mapping constraint names to their attributes,
        where attributes is a dict with keys:
         * columns: List of columns this covers
         * primary_key: True if primary key, False otherwise
         * unique: True if this is a unique constraint, False otherwise
         * foreign_key: (table, column) of target, or None
         * check: True if check constraint, False otherwise
         * index: True if index, False otherwise.
         * orders: The order (ASC/DESC) defined for the columns of indexes
         * type: The type of the index (btree, hash, etc.)

        Some backends may return special constraint names that don't exist
        if they don't name constraints of a certain type (e.g. SQLite)
        """
        ...