Plain is headed towards 1.0! Subscribe for development updates →

  1import os
  2import sys
  3
  4from plain.packages import packages
  5from plain.runtime import settings
  6
  7# The prefix to put on the default database name when creating
  8# the test database.
  9TEST_DATABASE_PREFIX = "test_"
 10
 11
 12class BaseDatabaseCreation:
 13    """
 14    Encapsulate backend-specific differences pertaining to creation and
 15    destruction of the test database.
 16    """
 17
 18    def __init__(self, connection):
 19        self.connection = connection
 20
 21    def _nodb_cursor(self):
 22        return self.connection._nodb_cursor()
 23
 24    def log(self, msg):
 25        sys.stderr.write(msg + os.linesep)
 26
 27    def create_test_db(
 28        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
 29    ):
 30        """
 31        Create a test database, prompting the user for confirmation if the
 32        database already exists. Return the name of the test database created.
 33        """
 34        from plain.models.cli import migrate
 35
 36        test_database_name = self._get_test_db_name()
 37
 38        if verbosity >= 1:
 39            action = "Creating"
 40            if keepdb:
 41                action = "Using existing"
 42
 43            self.log(
 44                "{} test database for alias {}...".format(
 45                    action,
 46                    self._get_database_display_str(verbosity, test_database_name),
 47                )
 48            )
 49
 50        # We could skip this call if keepdb is True, but we instead
 51        # give it the keepdb param. This is to handle the case
 52        # where the test DB doesn't exist, in which case we need to
 53        # create it, then just not destroy it. If we instead skip
 54        # this, we will get an exception.
 55        self._create_test_db(verbosity, autoclobber, keepdb)
 56
 57        self.connection.close()
 58        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
 59        self.connection.settings_dict["NAME"] = test_database_name
 60
 61        try:
 62            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
 63                # Disable migrations for all packages.
 64                for app in packages.get_package_configs():
 65                    app._old_migrations_module = app.migrations_module
 66                    app.migrations_module = None
 67            # We report migrate messages at one level lower than that
 68            # requested. This ensures we don't get flooded with messages during
 69            # testing (unless you really ask to be flooded).
 70            migrate.callback(
 71                package_label=None,
 72                migration_name=None,
 73                no_input=True,
 74                database=self.connection.alias,
 75                fake=False,
 76                fake_initial=False,
 77                plan=False,
 78                check_unapplied=False,
 79                run_syncdb=True,
 80                prune=False,
 81                verbosity=max(verbosity - 1, 0),
 82            )
 83        finally:
 84            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
 85                for app in packages.get_package_configs():
 86                    app.migrations_module = app._old_migrations_module
 87                    del app._old_migrations_module
 88
 89        # We then serialize the current state of the database into a string
 90        # and store it on the connection. This slightly horrific process is so people
 91        # who are testing on databases without transactions or who are using
 92        # a TransactionTestCase still get a clean database on every test run.
 93        # if serialize:
 94        #     self.connection._test_serialized_contents = self.serialize_db_to_string()
 95
 96        # Ensure a connection for the side effect of initializing the test database.
 97        self.connection.ensure_connection()
 98
 99        return test_database_name
100
101    def set_as_test_mirror(self, primary_settings_dict):
102        """
103        Set this database up to be used in testing as a mirror of a primary
104        database whose settings are given.
105        """
106        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]
107
108    # def serialize_db_to_string(self):
109    #     """
110    #     Serialize all data in the database into a JSON string.
111    #     Designed only for test runner usage; will not handle large
112    #     amounts of data.
113    #     """
114
115    #     # Iteratively return every object for all models to serialize.
116    #     def get_objects():
117    #         from plain.models.migrations.loader import MigrationLoader
118
119    #         loader = MigrationLoader(self.connection)
120    #         for package_config in packages.get_package_configs():
121    #             if (
122    #                 package_config.models_module is not None
123    #                 and package_config.label in loader.migrated_packages
124    #             ):
125    #                 for model in package_config.get_models():
126    #                     if model._meta.can_migrate(
127    #                         self.connection
128    #                     ) and router.allow_migrate_model(self.connection.alias, model):
129    #                         queryset = model._base_manager.using(
130    #                             self.connection.alias,
131    #                         ).order_by(model._meta.pk.name)
132    #                         yield from queryset.iterator()
133
134    #     # Serialize to a string
135    #     out = StringIO()
136    #     serializers.serialize("json", get_objects(), indent=None, stream=out)
137    #     return out.getvalue()
138
139    # def deserialize_db_from_string(self, data):
140    #     """
141    #     Reload the database with data from a string generated by
142    #     the serialize_db_to_string() method.
143    #     """
144    #     data = StringIO(data)
145    #     table_names = set()
146    #     # Load data in a transaction to handle forward references and cycles.
147    #     with atomic(using=self.connection.alias):
148    #         # Disable constraint checks, because some databases (MySQL) doesn't
149    #         # support deferred checks.
150    #         with self.connection.constraint_checks_disabled():
151    #             for obj in serializers.deserialize(
152    #                 "json", data, using=self.connection.alias
153    #             ):
154    #                 obj.save()
155    #                 table_names.add(obj.object.__class__._meta.db_table)
156    #         # Manually check for any invalid keys that might have been added,
157    #         # because constraint checks were disabled.
158    #         self.connection.check_constraints(table_names=table_names)
159
160    def _get_database_display_str(self, verbosity, database_name):
161        """
162        Return display string for a database for use in various actions.
163        """
164        return "'{}'{}".format(
165            self.connection.alias,
166            (" ('%s')" % database_name) if verbosity >= 2 else "",
167        )
168
169    def _get_test_db_name(self):
170        """
171        Internal implementation - return the name of the test DB that will be
172        created. Only useful when called from create_test_db() and
173        _create_test_db() and when no external munging is done with the 'NAME'
174        settings.
175        """
176        if self.connection.settings_dict["TEST"]["NAME"]:
177            return self.connection.settings_dict["TEST"]["NAME"]
178        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]
179
180    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
181        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
182
183    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
184        """
185        Internal implementation - create the test db tables.
186        """
187        test_database_name = self._get_test_db_name()
188        test_db_params = {
189            "dbname": self.connection.ops.quote_name(test_database_name),
190            "suffix": self.sql_table_creation_suffix(),
191        }
192        # Create the test database and connect to it.
193        with self._nodb_cursor() as cursor:
194            try:
195                self._execute_create_test_db(cursor, test_db_params, keepdb)
196            except Exception as e:
197                # if we want to keep the db, then no need to do any of the below,
198                # just return and skip it all.
199                if keepdb:
200                    return test_database_name
201
202                self.log("Got an error creating the test database: %s" % e)
203                if not autoclobber:
204                    confirm = input(
205                        "Type 'yes' if you would like to try deleting the test "
206                        "database '%s', or 'no' to cancel: " % test_database_name
207                    )
208                if autoclobber or confirm == "yes":
209                    try:
210                        if verbosity >= 1:
211                            self.log(
212                                "Destroying old test database for alias {}...".format(
213                                    self._get_database_display_str(
214                                        verbosity, test_database_name
215                                    ),
216                                )
217                            )
218                        cursor.execute(
219                            "DROP DATABASE {dbname}".format(**test_db_params)
220                        )
221                        self._execute_create_test_db(cursor, test_db_params, keepdb)
222                    except Exception as e:
223                        self.log("Got an error recreating the test database: %s" % e)
224                        sys.exit(2)
225                else:
226                    self.log("Tests cancelled.")
227                    sys.exit(1)
228
229        return test_database_name
230
231    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
232        """
233        Clone a test database.
234        """
235        source_database_name = self.connection.settings_dict["NAME"]
236
237        if verbosity >= 1:
238            action = "Cloning test database"
239            if keepdb:
240                action = "Using existing clone"
241            self.log(
242                "{} for alias {}...".format(
243                    action,
244                    self._get_database_display_str(verbosity, source_database_name),
245                )
246            )
247
248        # We could skip this call if keepdb is True, but we instead
249        # give it the keepdb param. See create_test_db for details.
250        self._clone_test_db(suffix, verbosity, keepdb)
251
252    def get_test_db_clone_settings(self, suffix):
253        """
254        Return a modified connection settings dict for the n-th clone of a DB.
255        """
256        # When this function is called, the test database has been created
257        # already and its name has been copied to settings_dict['NAME'] so
258        # we don't need to call _get_test_db_name.
259        orig_settings_dict = self.connection.settings_dict
260        return {
261            **orig_settings_dict,
262            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
263        }
264
265    def _clone_test_db(self, suffix, verbosity, keepdb=False):
266        """
267        Internal implementation - duplicate the test db tables.
268        """
269        raise NotImplementedError(
270            "The database backend doesn't support cloning databases. "
271            "Disable the option to run tests in parallel processes."
272        )
273
274    def destroy_test_db(
275        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
276    ):
277        """
278        Destroy a test database, prompting the user for confirmation if the
279        database already exists.
280        """
281        self.connection.close()
282        if suffix is None:
283            test_database_name = self.connection.settings_dict["NAME"]
284        else:
285            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
286
287        if verbosity >= 1:
288            action = "Destroying"
289            if keepdb:
290                action = "Preserving"
291            self.log(
292                "{} test database for alias {}...".format(
293                    action,
294                    self._get_database_display_str(verbosity, test_database_name),
295                )
296            )
297
298        # if we want to preserve the database
299        # skip the actual destroying piece.
300        if not keepdb:
301            self._destroy_test_db(test_database_name, verbosity)
302
303        # Restore the original database name
304        if old_database_name is not None:
305            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
306            self.connection.settings_dict["NAME"] = old_database_name
307
308    def _destroy_test_db(self, test_database_name, verbosity):
309        """
310        Internal implementation - remove the test db tables.
311        """
312        # Remove the test database to clean up after
313        # ourselves. Connect to the previous database (not the test database)
314        # to do so, because it's not allowed to delete a database while being
315        # connected to it.
316        with self._nodb_cursor() as cursor:
317            cursor.execute(
318                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
319            )
320
321    def sql_table_creation_suffix(self):
322        """
323        SQL to append to the end of the test table creation statements.
324        """
325        return ""
326
327    def test_db_signature(self):
328        """
329        Return a tuple with elements of self.connection.settings_dict (a
330        DATABASES setting value) that uniquely identify a database
331        accordingly to the RDBMS particularities.
332        """
333        settings_dict = self.connection.settings_dict
334        return (
335            settings_dict["HOST"],
336            settings_dict["PORT"],
337            settings_dict["ENGINE"],
338            self._get_test_db_name(),
339        )
340
341    def setup_worker_connection(self, _worker_id):
342        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
343        # connection.settings_dict must be updated in place for changes to be
344        # reflected in plain.models.connections. If the following line assigned
345        # connection.settings_dict = settings_dict, new threads would connect
346        # to the default database instead of the appropriate clone.
347        self.connection.settings_dict.update(settings_dict)
348        self.connection.close()