1from __future__ import annotations
2
3from collections import namedtuple
4from collections.abc import Generator
5from typing import TYPE_CHECKING, Any
6
7if TYPE_CHECKING:
8 from plain.models.backends.base.base import BaseDatabaseWrapper
9
# Row type produced by DatabaseIntrospection.get_table_list(): the name of a
# table or view plus a type code ("t" marks a plain table; rows with any other
# code are treated as views by table_names()).
TableInfo = namedtuple("TableInfo", ["name", "type"])

# Column description produced by DatabaseIntrospection.get_table_description().
# The first seven fields follow the DB-API cursor.description layout (PEP 249);
# ``default`` and ``collation`` are additional fields beyond that specification.
FieldInfo = namedtuple(
    "FieldInfo",
    [
        "name",
        "type_code",
        "display_size",
        "internal_size",
        "precision",
        "scale",
        "null_ok",
        "default",
        "collation",
    ],
)
19
20
class BaseDatabaseIntrospection:
    """
    Base class for backend-specific schema introspection.

    Backends override the hooks that raise NotImplementedError here
    (get_table_list, get_table_description, get_sequences, get_relations,
    get_constraints); the other helpers are implemented on top of them.
    """

    # Maps a backend column type code to a Plain field type name. Empty in the
    # base class; backends are expected to supply their own mapping.
    data_types_reverse: dict[Any, str] = {}

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        # Database wrapper whose cursors are used by the introspection helpers.
        self.connection = connection

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """
        Translate a database column type into a Plain field type name.

        The base implementation ignores ``description`` and looks
        ``data_type`` up in ``data_types_reverse`` (an unknown type raises
        KeyError). Backends override this when the type code alone is not
        enough -- on Oracle, for instance, a FloatField cannot be told apart
        from an IntegerField without the rest of the cursor description.
        """
        return self.data_types_reverse[data_type]

    def identifier_converter(self, name: str) -> str:
        """
        Normalize an identifier before comparing it with introspected names.

        The base implementation returns ``name`` untouched, i.e. comparisons
        are case sensitive.
        """
        return name

    def table_names(self, cursor: Any = None, include_views: bool = False) -> list[str]:
        """
        Return the names of the tables in the database, sorted in Python.

        Rows from get_table_list() whose type is not "t" (views) are dropped
        unless ``include_views`` is true. Sorting happens in Python instead of
        via ORDER BY so the order does not vary between database backends.
        When ``cursor`` is None, a cursor is obtained from
        ``self.connection.cursor()`` for the duration of the call.
        """

        def collect(active_cursor: Any) -> list[str]:
            wanted = [
                info.name
                for info in self.get_table_list(active_cursor)
                if include_views or info.type == "t"
            ]
            wanted.sort()
            return wanted

        if cursor is not None:
            return collect(cursor)
        with self.connection.cursor() as opened:
            return collect(opened)

    def get_table_list(self, cursor: Any) -> list[TableInfo]:
        """
        Return TableInfo tuples for every table and view, in no guaranteed
        order. Must be provided by the backend.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_table_list() "
            "method"
        )

    def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
        """
        Return FieldInfo tuples describing the columns of ``table_name``,
        following the DB-API cursor.description layout. Must be provided by
        the backend.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a "
            "get_table_description() method."
        )

    def get_migratable_models(self) -> Generator[Any, None, None]:
        """
        Return a lazy iterator over registered models migratable on this
        connection.

        Packages are visited in the order given by
        ``packages_registry.get_package_configs()``; within each package, a
        model is yielded only if
        ``model.model_options.can_migrate(self.connection)`` is true.
        """
        from plain.models import models_registry
        from plain.packages import packages_registry

        # Fetched eagerly, exactly like the first iterable of a generator
        # expression; the per-model work below stays lazy.
        package_configs = packages_registry.get_package_configs()
        return (
            candidate
            for config in package_configs
            for candidate in models_registry.get_models(
                package_label=config.package_label
            )
            if candidate.model_options.can_migrate(self.connection)
        )

    def plain_table_names(
        self, only_existing: bool = False, include_views: bool = True
    ) -> list[str]:
        """
        Return the table names owned by migratable Plain models.

        Each model contributes its ``db_table`` and the join table of every
        local many-to-many field. Duplicates are removed and the result is in
        no particular order. When ``only_existing`` is true, only names whose
        ``identifier_converter`` form appears in ``table_names()`` are kept;
        ``include_views`` is forwarded to that call.
        """
        unique_tables = set()
        for model in self.get_migratable_models():
            unique_tables.add(model.model_options.db_table)
            for m2m_field in model._model_meta.local_many_to_many:
                unique_tables.add(m2m_field.m2m_db_table())
        result = list(unique_tables)
        if not only_existing:
            return result
        present = set(self.table_names(include_views=include_views))
        return [
            name for name in result if self.identifier_converter(name) in present
        ]

    def sequence_list(self) -> list[dict[str, Any]]:
        """
        Collect the introspected sequences of every migratable model.

        One cursor is opened from ``self.connection`` and get_sequences() is
        called once per model with its ``db_table`` and local fields; the
        per-model results are concatenated in model order.
        """
        collected: list[dict[str, Any]] = []
        with self.connection.cursor() as cursor:
            for model in self.get_migratable_models():
                table = model.model_options.db_table
                fields = model._model_meta.local_fields
                collected += self.get_sequences(cursor, table, fields)
        return collected

    def get_sequences(
        self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """
        Return the sequences introspected for ``table_name``.

        Each entry is a dict of the form
        {'table': <table_name>, 'column': <column_name>}; backends with named
        sequences may also include a 'name' key. Must be provided by the
        backend.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_sequences() "
            "method"
        )

    def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
        """
        Return the foreign keys of ``table_name`` as
        {field_name: (field_name_other_table, other_table)}. Must be provided
        by the backend.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a "
            "get_relations() method."
        )

    def get_primary_key_column(self, cursor: Any, table_name: str) -> str | None:
        """
        Return the first primary key column of ``table_name``, or None when
        no primary key columns are found.
        """
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if not pk_columns:
            return None
        return pk_columns[0]

    def get_primary_key_columns(self, cursor: Any, table_name: str) -> list[str] | None:
        """
        Return the columns of the first constraint flagged as a primary key
        by get_constraints(), or None if there is no such constraint.
        """
        constraints = self.get_constraints(cursor, table_name)
        return next(
            (
                details["columns"]
                for details in constraints.values()
                if details["primary_key"]
            ),
            None,
        )

    def get_constraints(
        self, cursor: Any, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Return every constraint or key (unique, pk, fk, check, index) on
        ``table_name``, whether single- or multi-column.

        The result maps constraint names to attribute dicts with the keys:
          * columns: list of columns the constraint covers
          * primary_key: True for a primary key, else False
          * unique: True for a unique constraint, else False
          * foreign_key: (table, column) of the target, or None
          * check: True for a check constraint, else False
          * index: True for an index, else False
          * orders: ASC/DESC order defined for each indexed column
          * type: index type (btree, hash, ...)

        Backends that do not name some kinds of constraints (e.g. SQLite) may
        return synthetic names that do not exist in the database. Must be
        provided by the backend.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_constraints() "
            "method"
        )