Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from abc import ABC, abstractmethod
  4from collections.abc import Generator, Sequence
  5from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable
  6
  7from plain.models.backends.utils import CursorWrapper
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12
 13# Protocols for structural typing - allows backends to extend with extra fields
 14@runtime_checkable
 15class TableInfoProtocol(Protocol):
 16    """Protocol for TableInfo - backends can add extra fields."""
 17
 18    name: str
 19    type: str
 20
 21
 22@runtime_checkable
 23class FieldInfoProtocol(Protocol):
 24    """Protocol for FieldInfo - backends can add extra fields."""
 25
 26    name: str
 27    type_code: Any
 28    display_size: int | None
 29    internal_size: int | None
 30    precision: int | None
 31    scale: int | None
 32    null_ok: bool | None
 33    default: Any
 34    collation: str | None
 35
 36
class TableInfo(NamedTuple):
    """
    Structure returned by BaseDatabaseIntrospection.get_table_list().

    Field order is the positional contract of this tuple; do not reorder.
    """

    # Table or view name as reported by the database.
    name: str
    # Relation kind. table_names() keeps only entries whose type is "t"
    # unless include_views is set; other values presumably mark views
    # (e.g. "v") — confirm per backend.
    type: str
 42
 43
class FieldInfo(NamedTuple):
    """
    Structure returned by the DB-API cursor.description interface (PEP 249).

    The first seven fields (name through null_ok) mirror the seven items of
    a PEP 249 cursor.description entry; ``default`` and ``collation`` are
    additional fields. Field order is the positional contract of this
    tuple; do not reorder.
    """

    # Column name.
    name: str
    # Backend-specific column type code.
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    # Whether the column accepts NULL; None when unknown.
    null_ok: bool | None
    # Column default as reported by the backend.
    default: Any
    # Column collation, or None if not applicable/unknown.
    collation: str | None
 56
 57
 58class BaseDatabaseIntrospection(ABC):
 59    """Encapsulate backend-specific introspection utilities."""
 60
 61    data_types_reverse: dict[Any, str] = {}
 62
 63    def __init__(self, connection: BaseDatabaseWrapper) -> None:
 64        self.connection = connection
 65
 66    def get_field_type(self, data_type: Any, description: Any) -> str:
 67        """
 68        Hook for a database backend to use the cursor description to
 69        match a Plain field type to a database column.
 70
 71        For Oracle, the column data_type on its own is insufficient to
 72        distinguish between a FloatField and IntegerField, for example.
 73        """
 74        return self.data_types_reverse[data_type]
 75
 76    def identifier_converter(self, name: str) -> str:
 77        """
 78        Apply a conversion to the identifier for the purposes of comparison.
 79
 80        The default identifier converter is for case sensitive comparison.
 81        """
 82        return name
 83
 84    def table_names(
 85        self, cursor: CursorWrapper | None = None, include_views: bool = False
 86    ) -> list[str]:
 87        """
 88        Return a list of names of all tables that exist in the database.
 89        Sort the returned table list by Python's default sorting. Do NOT use
 90        the database's ORDER BY here to avoid subtle differences in sorting
 91        order between databases.
 92        """
 93
 94        def get_names(cursor: CursorWrapper) -> list[str]:
 95            return sorted(
 96                ti.name
 97                for ti in self.get_table_list(cursor)
 98                if include_views or ti.type == "t"
 99            )
100
101        if cursor is None:
102            with self.connection.cursor() as cursor:
103                return get_names(cursor)
104        return get_names(cursor)
105
106    @abstractmethod
107    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfoProtocol]:
108        """
109        Return an unsorted list of TableInfo named tuples of all tables and
110        views that exist in the database.
111        """
112        ...
113
114    @abstractmethod
115    def get_table_description(
116        self, cursor: CursorWrapper, table_name: str
117    ) -> Sequence[FieldInfoProtocol]:
118        """
119        Return a description of the table with the DB-API cursor.description
120        interface.
121        """
122        ...
123
124    def get_migratable_models(self) -> Generator[Any, None, None]:
125        from plain.models import models_registry
126        from plain.packages import packages_registry
127
128        return (
129            model
130            for package_config in packages_registry.get_package_configs()
131            for model in models_registry.get_models(
132                package_label=package_config.package_label
133            )
134            if model.model_options.can_migrate(self.connection)
135        )
136
137    def plain_table_names(
138        self, only_existing: bool = False, include_views: bool = True
139    ) -> list[str]:
140        """
141        Return a list of all table names that have associated Plain models and
142        are in INSTALLED_PACKAGES.
143
144        If only_existing is True, include only the tables in the database.
145        """
146        tables = set()
147        for model in self.get_migratable_models():
148            tables.add(model.model_options.db_table)
149            tables.update(
150                f.m2m_db_table() for f in model._model_meta.local_many_to_many
151            )
152        tables = list(tables)
153        if only_existing:
154            existing_tables = set(self.table_names(include_views=include_views))
155            tables = [
156                t for t in tables if self.identifier_converter(t) in existing_tables
157            ]
158        return tables
159
160    def sequence_list(self) -> list[dict[str, Any]]:
161        """
162        Return a list of information about all DB sequences for all models in
163        all packages.
164        """
165        sequence_list = []
166        with self.connection.cursor() as cursor:
167            for model in self.get_migratable_models():
168                sequence_list.extend(
169                    self.get_sequences(
170                        cursor,
171                        model.model_options.db_table,
172                        model._model_meta.local_fields,
173                    )
174                )
175        return sequence_list
176
177    @abstractmethod
178    def get_sequences(
179        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
180    ) -> list[dict[str, Any]]:
181        """
182        Return a list of introspected sequences for table_name. Each sequence
183        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
184        'name' key can be added if the backend supports named sequences.
185        """
186        ...
187
188    @abstractmethod
189    def get_relations(
190        self, cursor: CursorWrapper, table_name: str
191    ) -> dict[str, tuple[str, str]]:
192        """
193        Return a dictionary of {field_name: (field_name_other_table, other_table)}
194        representing all foreign keys in the given table.
195        """
196        ...
197
198    def get_primary_key_column(
199        self, cursor: CursorWrapper, table_name: str
200    ) -> str | None:
201        """
202        Return the name of the primary key column for the given table.
203        """
204        columns = self.get_primary_key_columns(cursor, table_name)
205        return columns[0] if columns else None
206
207    def get_primary_key_columns(
208        self, cursor: CursorWrapper, table_name: str
209    ) -> list[str] | None:
210        """Return a list of primary key columns for the given table."""
211        for constraint in self.get_constraints(cursor, table_name).values():
212            if constraint["primary_key"]:
213                return constraint["columns"]
214        return None
215
216    @abstractmethod
217    def get_constraints(
218        self, cursor: CursorWrapper, table_name: str
219    ) -> dict[str, dict[str, Any]]:
220        """
221        Retrieve any constraints or keys (unique, pk, fk, check, index)
222        across one or more columns.
223
224        Return a dict mapping constraint names to their attributes,
225        where attributes is a dict with keys:
226         * columns: List of columns this covers
227         * primary_key: True if primary key, False otherwise
228         * unique: True if this is a unique constraint, False otherwise
229         * foreign_key: (table, column) of target, or None
230         * check: True if check constraint, False otherwise
231         * index: True if index, False otherwise.
232         * orders: The order (ASC/DESC) defined for the columns of indexes
233         * type: The type of the index (btree, hash, etc.)
234
235        Some backends may return special constraint names that don't exist
236        if they don't name constraints of a certain type (e.g. SQLite)
237        """
238        ...