Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import os
  4import sys
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.runtime import settings
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12# The prefix to put on the default database name when creating
 13# the test database.
 14TEST_DATABASE_PREFIX = "test_"
 15
 16
 17class BaseDatabaseCreation:
 18    """
 19    Encapsulate backend-specific differences pertaining to creation and
 20    destruction of the test database.
 21    """
 22
 23    def __init__(self, connection: BaseDatabaseWrapper):
 24        self.connection = connection
 25
 26    def _nodb_cursor(self) -> Any:
 27        return self.connection._nodb_cursor()
 28
 29    def log(self, msg: str) -> None:
 30        sys.stderr.write(msg + os.linesep)
 31
 32    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
 33        """
 34        Create a test database, prompting the user for confirmation if the
 35        database already exists. Return the name of the test database created.
 36
 37        If prefix is provided, it will be prepended to the database name
 38        to isolate it from other test databases.
 39        """
 40        from plain.models.cli import migrate
 41
 42        test_database_name = self._get_test_db_name(prefix)
 43
 44        if verbosity >= 1:
 45            self.log(f"Creating test database '{test_database_name}'...")
 46
 47        self._create_test_db(
 48            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
 49        )
 50
 51        self.connection.close()
 52        settings.DATABASE["NAME"] = test_database_name
 53        self.connection.settings_dict["NAME"] = test_database_name
 54
 55        migrate.callback(
 56            package_label=None,
 57            migration_name=None,
 58            fake=False,
 59            plan=False,
 60            check_unapplied=False,
 61            backup=False,
 62            no_input=True,
 63            atomic_batch=False,  # No need for atomic batch when creating test database
 64            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
 65        )
 66
 67        # Ensure a connection for the side effect of initializing the test database.
 68        self.connection.ensure_connection()
 69
 70        return test_database_name
 71
 72    def set_as_test_mirror(self, primary_settings_dict: dict[str, Any]) -> None:
 73        """
 74        Set this database up to be used in testing as a mirror of a primary
 75        database whose settings are given.
 76        """
 77        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]
 78
 79    # def serialize_db_to_string(self):
 80    #     """
 81    #     Serialize all data in the database into a JSON string.
 82    #     Designed only for test runner usage; will not handle large
 83    #     amounts of data.
 84    #     """
 85
 86    #     # Iteratively return every object for all models to serialize.
 87    #     def get_objects():
 88    #         from plain.models.migrations.loader import MigrationLoader
 89
 90    #         loader = MigrationLoader(self.connection)
 91    #         for package_config in packages.get_package_configs():
 92    #             if (
 93    #                 package_config.models_module is not None
 94    #                 and package_config.package_label in loader.migrated_packages
 95    #             ):
 96    #                 for model in package_config.get_models():
 97    #                     if model.model_options.can_migrate(
 98    #                         self.connection
 99    #                     ) and router.allow_migrate_model(self.connection.alias, model):
100    #                         queryset = model._base_manager.using(
101    #                             self.connection.alias,
102    #                         ).order_by("id")
103    #                         yield from queryset.iterator()
104
105    #     # Serialize to a string
106    #     out = StringIO()
107    #     serializers.serialize("json", get_objects(), indent=None, stream=out)
108    #     return out.getvalue()
109
110    # def deserialize_db_from_string(self, data):
111    #     """
112    #     Reload the database with data from a string generated by
113    #     the serialize_db_to_string() method.
114    #     """
115    #     data = StringIO(data)
116    #     table_names = set()
117    #     # Load data in a transaction to handle forward references and cycles.
118    #     with atomic(using=self.connection.alias):
119    #         # Disable constraint checks, because some databases (MySQL) doesn't
120    #         # support deferred checks.
121    #         with self.connection.constraint_checks_disabled():
122    #             for obj in serializers.deserialize(
123    #                 "json", data, using=self.connection.alias
124    #             ):
125    #                 obj.save()
126    #                 table_names.add(obj.object.model_options.db_table)
127    #         # Manually check for any invalid keys that might have been added,
128    #         # because constraint checks were disabled.
129    #         self.connection.check_constraints(table_names=table_names)
130
131    def _get_test_db_name(self, prefix: str = "") -> str:
132        """
133        Internal implementation - return the name of the test DB that will be
134        created. Only useful when called from create_test_db() and
135        _create_test_db() and when no external munging is done with the 'NAME'
136        settings.
137
138        If prefix is provided, it will be prepended to the database name.
139        """
140        # Determine the base name: explicit TEST.NAME overrides base NAME.
141        base_name = (
142            self.connection.settings_dict["TEST"]["NAME"]
143            or self.connection.settings_dict["NAME"]
144        )
145        if prefix:
146            return f"{prefix}_{base_name}"
147        if self.connection.settings_dict["TEST"]["NAME"]:
148            return self.connection.settings_dict["TEST"]["NAME"]
149        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]
150
151    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
152        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
153
154    def _create_test_db(
155        self, *, test_database_name: str, verbosity: int, autoclobber: bool
156    ) -> str:
157        """
158        Internal implementation - create the test db tables.
159        """
160        test_db_params = {
161            "dbname": self.connection.ops.quote_name(test_database_name),
162            "suffix": self.sql_table_creation_suffix(),
163        }
164        # Create the test database and connect to it.
165        with self._nodb_cursor() as cursor:
166            try:
167                self._execute_create_test_db(cursor, test_db_params)
168            except Exception as e:
169                self.log(f"Got an error creating the test database: {e}")
170                if not autoclobber:
171                    confirm = input(
172                        "Type 'yes' if you would like to try deleting the test "
173                        f"database '{test_database_name}', or 'no' to cancel: "
174                    )
175                if autoclobber or confirm == "yes":
176                    try:
177                        if verbosity >= 1:
178                            self.log(
179                                f"Destroying old test database '{test_database_name}'..."
180                            )
181                        cursor.execute(
182                            "DROP DATABASE {dbname}".format(**test_db_params)
183                        )
184                        self._execute_create_test_db(cursor, test_db_params)
185                    except Exception as e:
186                        self.log(f"Got an error recreating the test database: {e}")
187                        sys.exit(2)
188                else:
189                    self.log("Tests cancelled.")
190                    sys.exit(1)
191
192        return test_database_name
193
194    def destroy_test_db(
195        self, old_database_name: str | None = None, verbosity: int = 1
196    ) -> None:
197        """
198        Destroy a test database, prompting the user for confirmation if the
199        database already exists.
200        """
201        self.connection.close()
202
203        test_database_name = self.connection.settings_dict["NAME"]
204
205        if verbosity >= 1:
206            self.log(f"Destroying test database '{test_database_name}'...")
207        self._destroy_test_db(test_database_name, verbosity)
208
209        # Restore the original database name
210        if old_database_name is not None:
211            settings.DATABASE["NAME"] = old_database_name
212            self.connection.settings_dict["NAME"] = old_database_name
213
214    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
215        """
216        Internal implementation - remove the test db tables.
217        """
218        # Remove the test database to clean up after
219        # ourselves. Connect to the previous database (not the test database)
220        # to do so, because it's not allowed to delete a database while being
221        # connected to it.
222        with self._nodb_cursor() as cursor:
223            cursor.execute(
224                f"DROP DATABASE {self.connection.ops.quote_name(test_database_name)}"
225            )
226
227    def sql_table_creation_suffix(self) -> str:
228        """
229        SQL to append to the end of the test table creation statements.
230        """
231        return ""
232
233    def test_db_signature(self, prefix: str = "") -> tuple[str, str, str, str]:
234        """
235        Return a tuple with elements of self.connection.settings_dict (a
236        DATABASE setting value) that uniquely identify a database
237        accordingly to the RDBMS particularities.
238        """
239        settings_dict = self.connection.settings_dict
240        return (
241            settings_dict["HOST"],
242            settings_dict["PORT"],
243            settings_dict["ENGINE"],
244            self._get_test_db_name(prefix),
245        )