Plain is headed towards 1.0! Subscribe for development updates →

 1import sys
 2
 3from plain.exceptions import ImproperlyConfigured
 4from plain.models.backends.base.creation import BaseDatabaseCreation
 5from plain.models.backends.postgresql.psycopg_any import errors
 6from plain.models.backends.utils import strip_quotes
 7
 8
 9class DatabaseCreation(BaseDatabaseCreation):
10    def _quote_name(self, name):
11        return self.connection.ops.quote_name(name)
12
13    def _get_database_create_suffix(self, encoding=None, template=None):
14        suffix = ""
15        if encoding:
16            suffix += f" ENCODING '{encoding}'"
17        if template:
18            suffix += f" TEMPLATE {self._quote_name(template)}"
19        return suffix and "WITH" + suffix
20
21    def sql_table_creation_suffix(self):
22        test_settings = self.connection.settings_dict["TEST"]
23        if test_settings.get("COLLATION") is not None:
24            raise ImproperlyConfigured(
25                "PostgreSQL does not support collation setting at database "
26                "creation time."
27            )
28        return self._get_database_create_suffix(
29            encoding=test_settings["CHARSET"],
30            template=test_settings.get("TEMPLATE"),
31        )
32
33    def _database_exists(self, cursor, database_name):
34        cursor.execute(
35            "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
36            [strip_quotes(database_name)],
37        )
38        return cursor.fetchone() is not None
39
40    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
41        try:
42            if keepdb and self._database_exists(cursor, parameters["dbname"]):
43                # If the database should be kept and it already exists, don't
44                # try to create a new one.
45                return
46            super()._execute_create_test_db(cursor, parameters, keepdb)
47        except Exception as e:
48            cause = e.__cause__
49            if cause and not isinstance(cause, errors.DuplicateDatabase):
50                # All errors except "database already exists" cancel tests.
51                self.log("Got an error creating the test database: %s" % e)
52                sys.exit(2)
53            elif not keepdb:
54                # If the database should be kept, ignore "database already
55                # exists".
56                raise
57
58    def _clone_test_db(self, suffix, verbosity, keepdb=False):
59        # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
60        # to the template database.
61        self.connection.close()
62
63        source_database_name = self.connection.settings_dict["NAME"]
64        target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
65        test_db_params = {
66            "dbname": self._quote_name(target_database_name),
67            "suffix": self._get_database_create_suffix(template=source_database_name),
68        }
69        with self._nodb_cursor() as cursor:
70            try:
71                self._execute_create_test_db(cursor, test_db_params, keepdb)
72            except Exception:
73                try:
74                    if verbosity >= 1:
75                        self.log(
76                            "Destroying old test database for alias {}...".format(
77                                self._get_database_display_str(
78                                    verbosity, target_database_name
79                                ),
80                            )
81                        )
82                    cursor.execute("DROP DATABASE {dbname}".format(**test_db_params))
83                    self._execute_create_test_db(cursor, test_db_params, keepdb)
84                except Exception as e:
85                    self.log("Got an error cloning the test database: %s" % e)
86                    sys.exit(2)