Plain is headed towards 1.0! Subscribe for development updates →

  1from collections import namedtuple
  2
  3# Structure returned by DatabaseIntrospection.get_table_list()
  4TableInfo = namedtuple("TableInfo", ["name", "type"])
  5
  6# Structure returned by the DB-API cursor.description interface (PEP 249)
  7FieldInfo = namedtuple(
  8    "FieldInfo",
  9    "name type_code display_size internal_size precision scale null_ok "
 10    "default collation",
 11)
 12
 13
 14class BaseDatabaseIntrospection:
 15    """Encapsulate backend-specific introspection utilities."""
 16
 17    data_types_reverse = {}
 18
 19    def __init__(self, connection):
 20        self.connection = connection
 21
 22    def get_field_type(self, data_type, description):
 23        """
 24        Hook for a database backend to use the cursor description to
 25        match a Plain field type to a database column.
 26
 27        For Oracle, the column data_type on its own is insufficient to
 28        distinguish between a FloatField and IntegerField, for example.
 29        """
 30        return self.data_types_reverse[data_type]
 31
 32    def identifier_converter(self, name):
 33        """
 34        Apply a conversion to the identifier for the purposes of comparison.
 35
 36        The default identifier converter is for case sensitive comparison.
 37        """
 38        return name
 39
 40    def table_names(self, cursor=None, include_views=False):
 41        """
 42        Return a list of names of all tables that exist in the database.
 43        Sort the returned table list by Python's default sorting. Do NOT use
 44        the database's ORDER BY here to avoid subtle differences in sorting
 45        order between databases.
 46        """
 47
 48        def get_names(cursor):
 49            return sorted(
 50                ti.name
 51                for ti in self.get_table_list(cursor)
 52                if include_views or ti.type == "t"
 53            )
 54
 55        if cursor is None:
 56            with self.connection.cursor() as cursor:
 57                return get_names(cursor)
 58        return get_names(cursor)
 59
 60    def get_table_list(self, cursor):
 61        """
 62        Return an unsorted list of TableInfo named tuples of all tables and
 63        views that exist in the database.
 64        """
 65        raise NotImplementedError(
 66            "subclasses of BaseDatabaseIntrospection may require a get_table_list() "
 67            "method"
 68        )
 69
 70    def get_table_description(self, cursor, table_name):
 71        """
 72        Return a description of the table with the DB-API cursor.description
 73        interface.
 74        """
 75        raise NotImplementedError(
 76            "subclasses of BaseDatabaseIntrospection may require a "
 77            "get_table_description() method."
 78        )
 79
 80    def get_migratable_models(self):
 81        from plain.models.db import router
 82        from plain.packages import packages
 83
 84        return (
 85            model
 86            for package_config in packages.get_package_configs()
 87            for model in router.get_migratable_models(
 88                package_config, self.connection.alias
 89            )
 90            if model._meta.can_migrate(self.connection)
 91        )
 92
 93    def plain_table_names(self, only_existing=False, include_views=True):
 94        """
 95        Return a list of all table names that have associated Plain models and
 96        are in INSTALLED_PACKAGES.
 97
 98        If only_existing is True, include only the tables in the database.
 99        """
100        tables = set()
101        for model in self.get_migratable_models():
102            if not model._meta.managed:
103                continue
104            tables.add(model._meta.db_table)
105            tables.update(
106                f.m2m_db_table()
107                for f in model._meta.local_many_to_many
108                if f.remote_field.through._meta.managed
109            )
110        tables = list(tables)
111        if only_existing:
112            existing_tables = set(self.table_names(include_views=include_views))
113            tables = [
114                t for t in tables if self.identifier_converter(t) in existing_tables
115            ]
116        return tables
117
118    def installed_models(self, tables):
119        """
120        Return a set of all models represented by the provided list of table
121        names.
122        """
123        tables = set(map(self.identifier_converter, tables))
124        return {
125            m
126            for m in self.get_migratable_models()
127            if self.identifier_converter(m._meta.db_table) in tables
128        }
129
130    def sequence_list(self):
131        """
132        Return a list of information about all DB sequences for all models in
133        all packages.
134        """
135        sequence_list = []
136        with self.connection.cursor() as cursor:
137            for model in self.get_migratable_models():
138                if not model._meta.managed:
139                    continue
140                if model._meta.swapped:
141                    continue
142                sequence_list.extend(
143                    self.get_sequences(
144                        cursor, model._meta.db_table, model._meta.local_fields
145                    )
146                )
147                for f in model._meta.local_many_to_many:
148                    # If this is an m2m using an intermediate table,
149                    # we don't need to reset the sequence.
150                    if f.remote_field.through._meta.auto_created:
151                        sequence = self.get_sequences(cursor, f.m2m_db_table())
152                        sequence_list.extend(
153                            sequence or [{"table": f.m2m_db_table(), "column": None}]
154                        )
155        return sequence_list
156
157    def get_sequences(self, cursor, table_name, table_fields=()):
158        """
159        Return a list of introspected sequences for table_name. Each sequence
160        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
161        'name' key can be added if the backend supports named sequences.
162        """
163        raise NotImplementedError(
164            "subclasses of BaseDatabaseIntrospection may require a get_sequences() "
165            "method"
166        )
167
168    def get_relations(self, cursor, table_name):
169        """
170        Return a dictionary of {field_name: (field_name_other_table, other_table)}
171        representing all foreign keys in the given table.
172        """
173        raise NotImplementedError(
174            "subclasses of BaseDatabaseIntrospection may require a "
175            "get_relations() method."
176        )
177
178    def get_primary_key_column(self, cursor, table_name):
179        """
180        Return the name of the primary key column for the given table.
181        """
182        columns = self.get_primary_key_columns(cursor, table_name)
183        return columns[0] if columns else None
184
185    def get_primary_key_columns(self, cursor, table_name):
186        """Return a list of primary key columns for the given table."""
187        for constraint in self.get_constraints(cursor, table_name).values():
188            if constraint["primary_key"]:
189                return constraint["columns"]
190        return None
191
192    def get_constraints(self, cursor, table_name):
193        """
194        Retrieve any constraints or keys (unique, pk, fk, check, index)
195        across one or more columns.
196
197        Return a dict mapping constraint names to their attributes,
198        where attributes is a dict with keys:
199         * columns: List of columns this covers
200         * primary_key: True if primary key, False otherwise
201         * unique: True if this is a unique constraint, False otherwise
202         * foreign_key: (table, column) of target, or None
203         * check: True if check constraint, False otherwise
204         * index: True if index, False otherwise.
205         * orders: The order (ASC/DESC) defined for the columns of indexes
206         * type: The type of the index (btree, hash, etc.)
207
208        Some backends may return special constraint names that don't exist
209        if they don't name constraints of a certain type (e.g. SQLite)
210        """
211        raise NotImplementedError(
212            "subclasses of BaseDatabaseIntrospection may require a get_constraints() "
213            "method"
214        )