1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from collections.abc import Generator, Sequence
  5from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable
  6
  7from plain.models.backends.utils import CursorWrapper
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12
 13# Protocols for structural typing - allows backends to extend with extra fields
 14@runtime_checkable
 15class TableInfoProtocol(Protocol):
 16    """Protocol for TableInfo - backends can add extra fields."""
 17
 18    name: str
 19    type: str
 20
 21
 22@runtime_checkable
 23class FieldInfoProtocol(Protocol):
 24    """Protocol for FieldInfo - backends can add extra fields."""
 25
 26    name: str
 27    type_code: Any
 28    display_size: int | None
 29    internal_size: int | None
 30    precision: int | None
 31    scale: int | None
 32    null_ok: bool | None
 33    default: Any
 34
 35
 36class TableInfo(NamedTuple):
 37    """Structure returned by DatabaseIntrospection.get_table_list()."""
 38
 39    name: str
 40    type: str
 41
 42
 43class FieldInfo(NamedTuple):
 44    """Structure returned by the DB-API cursor.description interface (PEP 249)."""
 45
 46    name: str
 47    type_code: Any
 48    display_size: int | None
 49    internal_size: int | None
 50    precision: int | None
 51    scale: int | None
 52    null_ok: bool | None
 53    default: Any
 54
 55
 56class BaseDatabaseIntrospection(ABC):
 57    """Encapsulate backend-specific introspection utilities."""
 58
 59    data_types_reverse: dict[Any, str] = {}
 60
 61    def __init__(self, connection: BaseDatabaseWrapper) -> None:
 62        self.connection = connection
 63
 64    def get_field_type(self, data_type: Any, description: Any) -> str:
 65        """
 66        Hook for a database backend to use the cursor description to
 67        match a Plain field type to a database column.
 68
 69        For Oracle, the column data_type on its own is insufficient to
 70        distinguish between a FloatField and IntegerField, for example.
 71        """
 72        return self.data_types_reverse[data_type]
 73
 74    def identifier_converter(self, name: str) -> str:
 75        """
 76        Apply a conversion to the identifier for the purposes of comparison.
 77
 78        The default identifier converter is for case sensitive comparison.
 79        """
 80        return name
 81
 82    def table_names(
 83        self, cursor: CursorWrapper | None = None, include_views: bool = False
 84    ) -> list[str]:
 85        """
 86        Return a list of names of all tables that exist in the database.
 87        Sort the returned table list by Python's default sorting. Do NOT use
 88        the database's ORDER BY here to avoid subtle differences in sorting
 89        order between databases.
 90        """
 91
 92        def get_names(cursor: CursorWrapper) -> list[str]:
 93            return sorted(
 94                ti.name
 95                for ti in self.get_table_list(cursor)
 96                if include_views or ti.type == "t"
 97            )
 98
 99        if cursor is None:
100            with self.connection.cursor() as cursor:
101                return get_names(cursor)
102        return get_names(cursor)
103
104    @abstractmethod
105    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfoProtocol]:
106        """
107        Return an unsorted list of TableInfo named tuples of all tables and
108        views that exist in the database.
109        """
110        ...
111
112    @abstractmethod
113    def get_table_description(
114        self, cursor: CursorWrapper, table_name: str
115    ) -> Sequence[FieldInfoProtocol]:
116        """
117        Return a description of the table with the DB-API cursor.description
118        interface.
119        """
120        ...
121
122    def get_migratable_models(self) -> Generator[Any, None, None]:
123        from plain.models import models_registry
124        from plain.packages import packages_registry
125
126        return (
127            model
128            for package_config in packages_registry.get_package_configs()
129            for model in models_registry.get_models(
130                package_label=package_config.package_label
131            )
132            if model.model_options.can_migrate(self.connection)
133        )
134
135    def plain_table_names(
136        self, only_existing: bool = False, include_views: bool = True
137    ) -> list[str]:
138        """
139        Return a list of all table names that have associated Plain models and
140        are in INSTALLED_PACKAGES.
141
142        If only_existing is True, include only the tables in the database.
143        """
144        tables = set()
145        for model in self.get_migratable_models():
146            tables.add(model.model_options.db_table)
147            tables.update(
148                f.m2m_db_table() for f in model._model_meta.local_many_to_many
149            )
150        tables = list(tables)
151        if only_existing:
152            existing_tables = set(self.table_names(include_views=include_views))
153            tables = [
154                t for t in tables if self.identifier_converter(t) in existing_tables
155            ]
156        return tables
157
158    def sequence_list(self) -> list[dict[str, Any]]:
159        """
160        Return a list of information about all DB sequences for all models in
161        all packages.
162        """
163        sequence_list = []
164        with self.connection.cursor() as cursor:
165            for model in self.get_migratable_models():
166                sequence_list.extend(
167                    self.get_sequences(
168                        cursor,
169                        model.model_options.db_table,
170                        model._model_meta.local_fields,
171                    )
172                )
173        return sequence_list
174
175    @abstractmethod
176    def get_sequences(
177        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
178    ) -> list[dict[str, Any]]:
179        """
180        Return a list of introspected sequences for table_name. Each sequence
181        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
182        'name' key can be added if the backend supports named sequences.
183        """
184        ...
185
186    @abstractmethod
187    def get_relations(
188        self, cursor: CursorWrapper, table_name: str
189    ) -> dict[str, tuple[str, str]]:
190        """
191        Return a dictionary of {field_name: (field_name_other_table, other_table)}
192        representing all foreign keys in the given table.
193        """
194        ...
195
196    def get_primary_key_column(
197        self, cursor: CursorWrapper, table_name: str
198    ) -> str | None:
199        """
200        Return the name of the primary key column for the given table.
201        """
202        columns = self.get_primary_key_columns(cursor, table_name)
203        return columns[0] if columns else None
204
205    def get_primary_key_columns(
206        self, cursor: CursorWrapper, table_name: str
207    ) -> list[str] | None:
208        """Return a list of primary key columns for the given table."""
209        for constraint in self.get_constraints(cursor, table_name).values():
210            if constraint["primary_key"]:
211                return constraint["columns"]
212        return None
213
214    @abstractmethod
215    def get_constraints(
216        self, cursor: CursorWrapper, table_name: str
217    ) -> dict[str, dict[str, Any]]:
218        """
219        Retrieve any constraints or keys (unique, pk, fk, check, index)
220        across one or more columns.
221
222        Return a dict mapping constraint names to their attributes,
223        where attributes is a dict with keys:
224         * columns: List of columns this covers
225         * primary_key: True if primary key, False otherwise
226         * unique: True if this is a unique constraint, False otherwise
227         * foreign_key: (table, column) of target, or None
228         * check: True if check constraint, False otherwise
229         * index: True if index, False otherwise.
230         * orders: The order (ASC/DESC) defined for the columns of indexes
231         * type: The type of the index (btree, hash, etc.)
232
233        Some backends may return special constraint names that don't exist
234        if they don't name constraints of a certain type (e.g. SQLite)
235        """
236        ...