1from __future__ import annotations
2
3import os
4import sys
5from typing import TYPE_CHECKING, Any
6
7from plain.runtime import settings
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
# Default prefix prepended to the configured database NAME to build the test
# database name when neither an explicit TEST["NAME"] nor a caller-supplied
# prefix is available (see BaseDatabaseCreation._get_test_db_name).
TEST_DATABASE_PREFIX: str = "test_"
15
16
class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection: BaseDatabaseWrapper):
        # The wrapper whose settings_dict["NAME"] is rewritten while a test
        # database is in use.
        self.connection = connection

    def _nodb_cursor(self) -> Any:
        """
        Return the connection's "no database" cursor, used as a context
        manager for CREATE/DROP DATABASE statements (a database cannot be
        dropped while connected to it).
        """
        return self.connection._nodb_cursor()

    def log(self, msg: str) -> None:
        """Write *msg* to standard error, terminated by a newline."""
        # sys.stderr is a text stream that already translates "\n" to the
        # platform line ending; appending os.linesep instead would emit
        # "\r\r\n" on Windows.
        sys.stderr.write(f"{msg}\n")

    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
        """
        Create the test database, run migrations against it, and return its
        name.

        If issuing CREATE DATABASE fails (for example because the database
        already exists), the database is dropped and recreated without
        prompting (``autoclobber=True``). On return, both
        ``settings.DATABASE["NAME"]`` and ``self.connection.settings_dict["NAME"]``
        are set to the test database name.

        If prefix is provided, it will be prepended to the database name
        to isolate it from other test databases.

        Verbosity 1+ logs progress; 2+ also shows migration output.
        """
        from plain.models.cli import migrate

        test_database_name = self._get_test_db_name(prefix)

        if verbosity >= 1:
            self.log(f"Creating test database '{test_database_name}'...")

        self._create_test_db(
            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
        )

        # Close the current connection before repointing NAME, presumably so
        # that subsequent connections (migrate, ensure_connection below) are
        # opened against the test database.
        self.connection.close()
        settings.DATABASE["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        migrate.callback(
            package_label=None,
            migration_name=None,
            fake=False,
            plan=False,
            check_unapplied=False,
            backup=False,
            no_input=True,
            atomic_batch=False,  # No need for atomic batch when creating test database
            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
        )

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict: dict[str, Any]) -> None:
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given, by copying the primary's NAME.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def _get_test_db_name(self, prefix: str = "") -> str:
        """
        Internal implementation - return the name of the test DB that will be
        created. Used by create_test_db() and test_db_signature(); only
        meaningful while no external munging is done with the 'NAME' setting.

        Resolution order:

        - with *prefix*: ``f"{prefix}_{TEST['NAME'] or NAME}"``;
        - otherwise an explicit ``TEST['NAME']`` is used verbatim;
        - otherwise ``TEST_DATABASE_PREFIX + NAME``.
        """
        settings_dict = self.connection.settings_dict
        explicit_test_name = settings_dict["TEST"]["NAME"]
        if prefix:
            return f"{prefix}_{explicit_test_name or settings_dict['NAME']}"
        if explicit_test_name:
            return explicit_test_name
        return TEST_DATABASE_PREFIX + settings_dict["NAME"]

    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
        """
        Issue CREATE DATABASE using the pre-quoted ``dbname`` and the
        ``suffix`` from sql_table_creation_suffix().
        """
        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))

    def _create_test_db(
        self, *, test_database_name: str, verbosity: int, autoclobber: bool
    ) -> str:
        """
        Internal implementation - create the test database itself.

        If the first CREATE DATABASE fails, the database is dropped and
        created again. When *autoclobber* is false the user is asked to
        confirm first; any answer other than ``yes`` logs "Tests cancelled."
        and exits with status 1. A failure while recreating exits with
        status 2.

        Returns *test_database_name*.
        """
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database through the no-database cursor.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params)
            except Exception as create_error:
                self.log(f"Got an error creating the test database: {create_error}")
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        f"database '{test_database_name}', or 'no' to cancel: "
                    )
                    if confirm != "yes":
                        self.log("Tests cancelled.")
                        sys.exit(1)
                try:
                    if verbosity >= 1:
                        self.log(
                            f"Destroying old test database '{test_database_name}'..."
                        )
                    cursor.execute(f"DROP DATABASE {test_db_params['dbname']}")
                    self._execute_create_test_db(cursor, test_db_params)
                except Exception as recreate_error:
                    self.log(
                        f"Got an error recreating the test database: {recreate_error}"
                    )
                    sys.exit(2)

        return test_database_name

    def destroy_test_db(
        self, old_database_name: str | None = None, verbosity: int = 1
    ) -> None:
        """
        Drop the test database the connection currently points at.

        The connection is closed, then the database named by
        ``self.connection.settings_dict["NAME"]`` is dropped without any
        confirmation prompt. If *old_database_name* is given, both
        ``settings.DATABASE["NAME"]`` and the connection's
        ``settings_dict["NAME"]`` are restored to it afterwards.
        """
        self.connection.close()

        test_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            self.log(f"Destroying test database '{test_database_name}'...")
        self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASE["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
        """
        Internal implementation - drop the database named *test_database_name*.

        *verbosity* is accepted for interface compatibility and unused here.
        """
        # Use the no-database cursor rather than a connection to the test
        # database, because a database cannot be dropped while connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                f"DROP DATABASE {self.connection.ops.quote_name(test_database_name)}"
            )

    def sql_table_creation_suffix(self) -> str:
        """
        Return SQL appended to the CREATE DATABASE statement for the test
        database. Empty in this base implementation.
        """
        return ""

    def test_db_signature(self, prefix: str = "") -> tuple[str, str, str, str]:
        """
        Return a tuple of values from self.connection.settings_dict (a
        DATABASE setting value) that uniquely identifies a test database
        according to the RDBMS particularities: HOST, PORT, ENGINE and the
        test database name (honoring *prefix*).
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(prefix),
        )