Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import os
  4import sys
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.runtime import settings
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12# The prefix to put on the default database name when creating
 13# the test database.
 14TEST_DATABASE_PREFIX = "test_"
 15
 16
 17class BaseDatabaseCreation:
 18    """
 19    Encapsulate backend-specific differences pertaining to creation and
 20    destruction of the test database.
 21    """
 22
 23    def __init__(self, connection: BaseDatabaseWrapper):
 24        self.connection = connection
 25
 26    def _nodb_cursor(self) -> Any:
 27        return self.connection._nodb_cursor()
 28
 29    def log(self, msg: str) -> None:
 30        sys.stderr.write(msg + os.linesep)
 31
 32    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
 33        """
 34        Create a test database, prompting the user for confirmation if the
 35        database already exists. Return the name of the test database created.
 36
 37        If prefix is provided, it will be prepended to the database name
 38        to isolate it from other test databases.
 39        """
 40        from plain.models.cli.migrations import apply
 41
 42        test_database_name = self._get_test_db_name(prefix)
 43
 44        if verbosity >= 1:
 45            self.log(f"Creating test database '{test_database_name}'...")
 46
 47        self._create_test_db(
 48            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
 49        )
 50
 51        self.connection.close()
 52        settings.DATABASE["NAME"] = test_database_name
 53        self.connection.settings_dict["NAME"] = test_database_name
 54
 55        apply.callback(
 56            package_label=None,
 57            migration_name=None,
 58            fake=False,
 59            plan=False,
 60            check_unapplied=False,
 61            backup=False,
 62            no_input=True,
 63            atomic_batch=False,  # No need for atomic batch when creating test database
 64            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
 65        )
 66
 67        # Ensure a connection for the side effect of initializing the test database.
 68        self.connection.ensure_connection()
 69
 70        return test_database_name
 71
 72    def set_as_test_mirror(self, primary_settings_dict: dict[str, Any]) -> None:
 73        """
 74        Set this database up to be used in testing as a mirror of a primary
 75        database whose settings are given.
 76        """
 77        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]
 78
 79    # def serialize_db_to_string(self):
 80    #     """
 81    #     Serialize all data in the database into a JSON string.
 82    #     Designed only for test runner usage; will not handle large
 83    #     amounts of data.
 84    #     """
 85
 86    #     # Iteratively return every object for all models to serialize.
 87    #     def get_objects():
 88    #         from plain.models.migrations.loader import MigrationLoader
 89
 90    #         loader = MigrationLoader(self.connection)
 91    #         for package_config in packages.get_package_configs():
 92    #             if (
 93    #                 package_config.models_module is not None
 94    #                 and package_config.package_label in loader.migrated_packages
 95    #             ):
 96    #                 for model in package_config.get_models():
 97    #                     if model.model_options.can_migrate(
 98    #                         self.connection
 99    #                     ) and router.allow_migrate_model(self.connection.alias, model):
100    #                         queryset = model._base_manager.using(
101    #                             self.connection.alias,
102    #                         ).order_by("id")
103    #                         yield from queryset.iterator()
104
105    #     # Serialize to a string
106    #     out = StringIO()
107    #     serializers.serialize("json", get_objects(), indent=None, stream=out)
108    #     return out.getvalue()
109
110    # def deserialize_db_from_string(self, data):
111    #     """
112    #     Reload the database with data from a string generated by
113    #     the serialize_db_to_string() method.
114    #     """
115    #     data = StringIO(data)
116    #     table_names = set()
117    #     # Load data in a transaction to handle forward references and cycles.
118    #     with atomic(using=self.connection.alias):
119    #         # Disable constraint checks, because some databases (MySQL) doesn't
120    #         # support deferred checks.
121    #         with self.connection.constraint_checks_disabled():
122    #             for obj in serializers.deserialize(
123    #                 "json", data, using=self.connection.alias
124    #             ):
125    #                 obj.save()
126    #                 table_names.add(obj.object.model_options.db_table)
127    #         # Manually check for any invalid keys that might have been added,
128    #         # because constraint checks were disabled.
129    #         self.connection.check_constraints(table_names=table_names)
130
131    def _get_test_db_name(self, prefix: str = "") -> str:
132        """
133        Internal implementation - return the name of the test DB that will be
134        created. Only useful when called from create_test_db() and
135        _create_test_db() and when no external munging is done with the 'NAME'
136        settings.
137
138        If prefix is provided, it will be prepended to the database name.
139        """
140        # Determine the base name: explicit TEST.NAME overrides base NAME.
141        base_name = (
142            self.connection.settings_dict["TEST"]["NAME"]
143            or self.connection.settings_dict["NAME"]
144        )
145        if prefix:
146            return f"{prefix}_{base_name}"
147        if self.connection.settings_dict["TEST"]["NAME"]:
148            return self.connection.settings_dict["TEST"]["NAME"]
149        name = self.connection.settings_dict["NAME"]
150        assert name is not None, "DATABASE NAME must be set"
151        return TEST_DATABASE_PREFIX + name
152
153    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
154        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
155
156    def _create_test_db(
157        self, *, test_database_name: str, verbosity: int, autoclobber: bool
158    ) -> str:
159        """
160        Internal implementation - create the test db tables.
161        """
162        test_db_params = {
163            "dbname": self.connection.ops.quote_name(test_database_name),
164            "suffix": self.sql_table_creation_suffix(),
165        }
166        # Create the test database and connect to it.
167        with self._nodb_cursor() as cursor:
168            try:
169                self._execute_create_test_db(cursor, test_db_params)
170            except Exception as e:
171                self.log(f"Got an error creating the test database: {e}")
172                if not autoclobber:
173                    confirm = input(
174                        "Type 'yes' if you would like to try deleting the test "
175                        f"database '{test_database_name}', or 'no' to cancel: "
176                    )
177                if autoclobber or confirm == "yes":
178                    try:
179                        if verbosity >= 1:
180                            self.log(
181                                f"Destroying old test database '{test_database_name}'..."
182                            )
183                        cursor.execute(
184                            "DROP DATABASE {dbname}".format(**test_db_params)
185                        )
186                        self._execute_create_test_db(cursor, test_db_params)
187                    except Exception as e:
188                        self.log(f"Got an error recreating the test database: {e}")
189                        sys.exit(2)
190                else:
191                    self.log("Tests cancelled.")
192                    sys.exit(1)
193
194        return test_database_name
195
196    def destroy_test_db(
197        self, old_database_name: str | None = None, verbosity: int = 1
198    ) -> None:
199        """
200        Destroy a test database, prompting the user for confirmation if the
201        database already exists.
202        """
203        self.connection.close()
204
205        test_database_name = self.connection.settings_dict["NAME"]
206        assert test_database_name is not None, "Test database NAME must be set"
207
208        if verbosity >= 1:
209            self.log(f"Destroying test database '{test_database_name}'...")
210        self._destroy_test_db(test_database_name, verbosity)
211
212        # Restore the original database name
213        if old_database_name is not None:
214            settings.DATABASE["NAME"] = old_database_name
215            self.connection.settings_dict["NAME"] = old_database_name
216
217    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
218        """
219        Internal implementation - remove the test db tables.
220        """
221        # Remove the test database to clean up after
222        # ourselves. Connect to the previous database (not the test database)
223        # to do so, because it's not allowed to delete a database while being
224        # connected to it.
225        with self._nodb_cursor() as cursor:
226            cursor.execute(
227                f"DROP DATABASE {self.connection.ops.quote_name(test_database_name)}"
228            )
229
230    def sql_table_creation_suffix(self) -> str:
231        """
232        SQL to append to the end of the test table creation statements.
233        """
234        return ""
235
236    def test_db_signature(self, prefix: str = "") -> tuple[str | int, ...]:
237        """
238        Return a tuple with elements of self.connection.settings_dict (a
239        DATABASE setting value) that uniquely identify a database
240        accordingly to the RDBMS particularities.
241        """
242        settings_dict = self.connection.settings_dict
243        return (
244            settings_dict.get("HOST") or "",
245            settings_dict.get("PORT") or "",
246            settings_dict.get("ENGINE") or "",
247            self._get_test_db_name(prefix),
248        )