Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections import namedtuple
  4from collections.abc import Generator
  5from typing import TYPE_CHECKING, Any
  6
  7if TYPE_CHECKING:
  8    from plain.models.backends.base.base import BaseDatabaseWrapper
  9
 10# Structure returned by DatabaseIntrospection.get_table_list()
 11TableInfo = namedtuple("TableInfo", ["name", "type"])
 12
 13# Structure returned by the DB-API cursor.description interface (PEP 249)
 14FieldInfo = namedtuple(
 15    "FieldInfo",
 16    "name type_code display_size internal_size precision scale null_ok "
 17    "default collation",
 18)
 19
 20
 21class BaseDatabaseIntrospection:
 22    """Encapsulate backend-specific introspection utilities."""
 23
 24    data_types_reverse: dict[Any, str] = {}
 25
 26    def __init__(self, connection: BaseDatabaseWrapper) -> None:
 27        self.connection = connection
 28
 29    def get_field_type(self, data_type: Any, description: Any) -> str:
 30        """
 31        Hook for a database backend to use the cursor description to
 32        match a Plain field type to a database column.
 33
 34        For Oracle, the column data_type on its own is insufficient to
 35        distinguish between a FloatField and IntegerField, for example.
 36        """
 37        return self.data_types_reverse[data_type]
 38
 39    def identifier_converter(self, name: str) -> str:
 40        """
 41        Apply a conversion to the identifier for the purposes of comparison.
 42
 43        The default identifier converter is for case sensitive comparison.
 44        """
 45        return name
 46
 47    def table_names(self, cursor: Any = None, include_views: bool = False) -> list[str]:
 48        """
 49        Return a list of names of all tables that exist in the database.
 50        Sort the returned table list by Python's default sorting. Do NOT use
 51        the database's ORDER BY here to avoid subtle differences in sorting
 52        order between databases.
 53        """
 54
 55        def get_names(cursor: Any) -> list[str]:
 56            return sorted(
 57                ti.name
 58                for ti in self.get_table_list(cursor)
 59                if include_views or ti.type == "t"
 60            )
 61
 62        if cursor is None:
 63            with self.connection.cursor() as cursor:
 64                return get_names(cursor)
 65        return get_names(cursor)
 66
 67    def get_table_list(self, cursor: Any) -> list[TableInfo]:
 68        """
 69        Return an unsorted list of TableInfo named tuples of all tables and
 70        views that exist in the database.
 71        """
 72        raise NotImplementedError(
 73            "subclasses of BaseDatabaseIntrospection may require a get_table_list() "
 74            "method"
 75        )
 76
 77    def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
 78        """
 79        Return a description of the table with the DB-API cursor.description
 80        interface.
 81        """
 82        raise NotImplementedError(
 83            "subclasses of BaseDatabaseIntrospection may require a "
 84            "get_table_description() method."
 85        )
 86
 87    def get_migratable_models(self) -> Generator[Any, None, None]:
 88        from plain.models import models_registry
 89        from plain.packages import packages_registry
 90
 91        return (
 92            model
 93            for package_config in packages_registry.get_package_configs()
 94            for model in models_registry.get_models(
 95                package_label=package_config.package_label
 96            )
 97            if model.model_options.can_migrate(self.connection)
 98        )
 99
100    def plain_table_names(
101        self, only_existing: bool = False, include_views: bool = True
102    ) -> list[str]:
103        """
104        Return a list of all table names that have associated Plain models and
105        are in INSTALLED_PACKAGES.
106
107        If only_existing is True, include only the tables in the database.
108        """
109        tables = set()
110        for model in self.get_migratable_models():
111            tables.add(model.model_options.db_table)
112            tables.update(
113                f.m2m_db_table() for f in model._model_meta.local_many_to_many
114            )
115        tables = list(tables)
116        if only_existing:
117            existing_tables = set(self.table_names(include_views=include_views))
118            tables = [
119                t for t in tables if self.identifier_converter(t) in existing_tables
120            ]
121        return tables
122
123    def sequence_list(self) -> list[dict[str, Any]]:
124        """
125        Return a list of information about all DB sequences for all models in
126        all packages.
127        """
128        sequence_list = []
129        with self.connection.cursor() as cursor:
130            for model in self.get_migratable_models():
131                sequence_list.extend(
132                    self.get_sequences(
133                        cursor,
134                        model.model_options.db_table,
135                        model._model_meta.local_fields,
136                    )
137                )
138        return sequence_list
139
140    def get_sequences(
141        self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
142    ) -> list[dict[str, Any]]:
143        """
144        Return a list of introspected sequences for table_name. Each sequence
145        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
146        'name' key can be added if the backend supports named sequences.
147        """
148        raise NotImplementedError(
149            "subclasses of BaseDatabaseIntrospection may require a get_sequences() "
150            "method"
151        )
152
153    def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
154        """
155        Return a dictionary of {field_name: (field_name_other_table, other_table)}
156        representing all foreign keys in the given table.
157        """
158        raise NotImplementedError(
159            "subclasses of BaseDatabaseIntrospection may require a "
160            "get_relations() method."
161        )
162
163    def get_primary_key_column(self, cursor: Any, table_name: str) -> str | None:
164        """
165        Return the name of the primary key column for the given table.
166        """
167        columns = self.get_primary_key_columns(cursor, table_name)
168        return columns[0] if columns else None
169
170    def get_primary_key_columns(self, cursor: Any, table_name: str) -> list[str] | None:
171        """Return a list of primary key columns for the given table."""
172        for constraint in self.get_constraints(cursor, table_name).values():
173            if constraint["primary_key"]:
174                return constraint["columns"]
175        return None
176
177    def get_constraints(
178        self, cursor: Any, table_name: str
179    ) -> dict[str, dict[str, Any]]:
180        """
181        Retrieve any constraints or keys (unique, pk, fk, check, index)
182        across one or more columns.
183
184        Return a dict mapping constraint names to their attributes,
185        where attributes is a dict with keys:
186         * columns: List of columns this covers
187         * primary_key: True if primary key, False otherwise
188         * unique: True if this is a unique constraint, False otherwise
189         * foreign_key: (table, column) of target, or None
190         * check: True if check constraint, False otherwise
191         * index: True if index, False otherwise.
192         * orders: The order (ASC/DESC) defined for the columns of indexes
193         * type: The type of the index (btree, hash, etc.)
194
195        Some backends may return special constraint names that don't exist
196        if they don't name constraints of a certain type (e.g. SQLite)
197        """
198        raise NotImplementedError(
199            "subclasses of BaseDatabaseIntrospection may require a get_constraints() "
200            "method"
201        )