1from __future__ import annotations
2
3from abc import ABC, abstractmethod
4from collections.abc import Generator, Sequence
5from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable
6
7from plain.models.backends.utils import CursorWrapper
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
13# Protocols for structural typing - allows backends to extend with extra fields
@runtime_checkable
class TableInfoProtocol(Protocol):
    """Structural type for one table/view entry.

    Any object exposing the attributes below satisfies this protocol, so a
    backend may return a richer structure that carries additional fields.
    """

    # Name of the table or view.
    name: str
    # Kind marker for the entry (the base introspection class treats "t" as
    # a regular table when filtering out views).
    type: str
20
21
@runtime_checkable
class FieldInfoProtocol(Protocol):
    """Structural type for one column description.

    Any object exposing the attributes below satisfies this protocol, so a
    backend may return a richer structure that carries additional fields.
    """

    # Column name.
    name: str
    # Backend-specific type identifier for the column.
    type_code: Any
    # Size/precision metadata; each may be None.
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    # Nullability flag; may be None.
    null_ok: bool | None
    # Column default; value format is backend-defined.
    default: Any
34
35
class TableInfo(NamedTuple):
    """One table/view entry, as produced by DatabaseIntrospection.get_table_list()."""

    # Name of the table or view.
    name: str
    # Kind marker for the entry (the base introspection class treats "t" as
    # a regular table when filtering out views).
    type: str
41
42
class FieldInfo(NamedTuple):
    """One column description, modelled on DB-API ``cursor.description`` (PEP 249).

    The first seven fields follow the PEP 249 description sequence; ``default``
    is an additional eighth field.
    """

    # Column name.
    name: str
    # Backend-specific type identifier for the column.
    type_code: Any
    # Size/precision metadata; each may be None.
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    # Nullability flag; may be None.
    null_ok: bool | None
    # Column default; value format is backend-defined.
    default: Any
54
55
class BaseDatabaseIntrospection(ABC):
    """Common base for the per-backend database introspection helpers.

    Subclasses must implement the abstract hooks ``get_table_list``,
    ``get_table_description``, ``get_sequences``, ``get_relations`` and
    ``get_constraints``; the concrete methods here are built on top of them.
    """

    # Lookup table consulted by get_field_type(): backend column type ->
    # field type name. Empty here; presumably each backend overrides it.
    data_types_reverse: dict[Any, str] = {}

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        # Kept so that helpers can open their own cursors when none is given.
        self.connection = connection

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """
        Map a database column type to a Plain field type name.

        The default implementation is a plain lookup of ``data_type`` in
        ``data_types_reverse`` and raises ``KeyError`` for unknown types.
        ``description`` is unused here; it exists so that a backend whose
        type code alone is ambiguous (the classic example being Oracle,
        where it cannot tell a FloatField from an IntegerField) can override
        this hook and inspect the full column description.
        """
        return self.data_types_reverse[data_type]

    def identifier_converter(self, name: str) -> str:
        """
        Normalize an identifier before it is compared with another one.

        The base implementation returns ``name`` untouched, i.e. comparisons
        are case sensitive by default.
        """
        return name

    def table_names(
        self, cursor: CursorWrapper | None = None, include_views: bool = False
    ) -> list[str]:
        """
        Return the names of all tables in the database, sorted in Python.

        Views are only included when ``include_views`` is true. When no
        ``cursor`` is passed, one is opened on ``self.connection`` for the
        duration of the call.

        Sorting is deliberately done with Python's default ordering and NOT
        with the database's ORDER BY, to avoid subtle differences in sort
        order between databases.
        """

        def collect(active_cursor: CursorWrapper) -> list[str]:
            wanted = [
                info.name
                for info in self.get_table_list(active_cursor)
                if include_views or info.type == "t"
            ]
            wanted.sort()
            return wanted

        if cursor is not None:
            return collect(cursor)
        with self.connection.cursor() as fresh_cursor:
            return collect(fresh_cursor)

    @abstractmethod
    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfoProtocol]:
        """
        Return every table and view in the database as an unsorted sequence
        of TableInfo named tuples.
        """
        ...

    @abstractmethod
    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> Sequence[FieldInfoProtocol]:
        """
        Describe the columns of ``table_name`` in the style of the DB-API
        cursor.description interface.
        """
        ...

    def get_migratable_models(self) -> Generator[Any, None, None]:
        """
        Lazily yield every registered model, package by package, for which
        ``model.model_options.can_migrate(self.connection)`` is true.
        """
        # Imported at call time rather than at module import (presumably to
        # avoid an import cycle - not verifiable from this file alone).
        from plain.models import models_registry
        from plain.packages import packages_registry

        # The package list is resolved immediately; only the per-model work
        # below is deferred until the generator is consumed.
        package_configs = iter(packages_registry.get_package_configs())

        def iter_models() -> Generator[Any, None, None]:
            for package_config in package_configs:
                package_models = models_registry.get_models(
                    package_label=package_config.package_label
                )
                for model in package_models:
                    if model.model_options.can_migrate(self.connection):
                        yield model

        return iter_models()

    def plain_table_names(
        self, only_existing: bool = False, include_views: bool = True
    ) -> list[str]:
        """
        Return the table names used by the migratable Plain models.

        Both each model's own table and the tables of its local many-to-many
        fields are included. The result is in no particular order.

        If only_existing is True, keep only the names that (after
        ``identifier_converter``) are present in the database;
        ``include_views`` decides whether database views count as present.
        """
        names: set[str] = set()
        for model in self.get_migratable_models():
            names.add(model.model_options.db_table)
            for m2m_field in model._model_meta.local_many_to_many:
                names.add(m2m_field.m2m_db_table())

        if not only_existing:
            return list(names)

        present = set(self.table_names(include_views=include_views))
        return [name for name in names if self.identifier_converter(name) in present]

    def sequence_list(self) -> list[dict[str, Any]]:
        """
        Return the sequence information of every migratable model, gathered
        by calling ``get_sequences`` once per model on a single cursor.
        """
        with self.connection.cursor() as cursor:
            return [
                sequence
                for model in self.get_migratable_models()
                for sequence in self.get_sequences(
                    cursor,
                    model.model_options.db_table,
                    model._model_meta.local_fields,
                )
            ]

    @abstractmethod
    def get_sequences(
        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """
        Return the introspected sequences of ``table_name`` as a list of
        dicts shaped like {'table': <table_name>, 'column': <column_name>}.
        Backends that support named sequences may add a 'name' key.
        """
        ...

    @abstractmethod
    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """
        Return all foreign keys of the given table as a dictionary of
        {field_name: (field_name_other_table, other_table)}.
        """
        ...

    def get_primary_key_column(
        self, cursor: CursorWrapper, table_name: str
    ) -> str | None:
        """
        Return the first primary key column of the given table, or None when
        no primary key columns are reported.
        """
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if not pk_columns:
            return None
        return pk_columns[0]

    def get_primary_key_columns(
        self, cursor: CursorWrapper, table_name: str
    ) -> list[str] | None:
        """
        Return the columns of the first constraint flagged as primary key
        for the given table, or None if there is no such constraint.
        """
        constraints = self.get_constraints(cursor, table_name)
        return next(
            (
                details["columns"]
                for details in constraints.values()
                if details["primary_key"]
            ),
            None,
        )

    @abstractmethod
    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve the constraints and keys (unique, pk, fk, check, index)
        defined across one or more columns of the table.

        The result maps each constraint name to a dict of attributes:
         * columns: List of columns this covers
         * primary_key: True if primary key, False otherwise
         * unique: True if this is a unique constraint, False otherwise
         * foreign_key: (table, column) of target, or None
         * check: True if check constraint, False otherwise
         * index: True if index, False otherwise.
         * orders: The order (ASC/DESC) defined for the columns of indexes
         * type: The type of the index (btree, hash, etc.)

        A backend that does not name constraints of a certain type (e.g.
        SQLite) may return special constraint names that don't actually
        exist in the database.
        """
        ...