1from __future__ import annotations
2
3import os
4import sys
5from typing import TYPE_CHECKING, Any
6
7from plain.runtime import settings
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
# Prefix prepended to the configured database NAME to derive the test
# database name when neither an explicit TEST NAME nor a caller-supplied
# prefix is given (see BaseDatabaseCreation._get_test_db_name).
TEST_DATABASE_PREFIX = "test_"
15
16
class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.

    The base implementation issues plain ``CREATE DATABASE`` / ``DROP DATABASE``
    statements through the connection's "no database" cursor. Backends can
    override the small hook methods (``_execute_create_test_db``,
    ``sql_table_creation_suffix``, ...) to adjust that behavior.
    """

    def __init__(self, connection: BaseDatabaseWrapper):
        # The wrapped connection supplies settings_dict, ops.quote_name() and
        # the cursors used by every method below.
        self.connection = connection

    def _nodb_cursor(self) -> Any:
        """
        Return the connection's "no database" cursor, used as a context
        manager for statements such as CREATE/DROP DATABASE that cannot be
        run while connected to the database they target.
        """
        return self.connection._nodb_cursor()

    def log(self, msg: str) -> None:
        """Write ``msg`` followed by ``os.linesep`` to standard error."""
        sys.stderr.write(msg + os.linesep)

    def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
        """
        Create the test database, point the configuration at it and migrate it.

        Any existing database with the same name is dropped and recreated
        without prompting (``autoclobber=True``). Both ``settings.DATABASE``
        and the connection's ``settings_dict`` are switched to the test
        database name. Callers that want the original name back should keep
        it and pass it to ``destroy_test_db(old_database_name=...)``.

        Args:
            verbosity: 1+ logs progress messages; 2+ also shows migration output.
            prefix: If non-empty, prepended (joined with ``_``) to the database
                name to isolate it from other test databases.

        Returns:
            The name of the test database that was created.
        """
        from plain.models.cli.migrations import apply

        test_database_name = self._get_test_db_name(prefix)

        if verbosity >= 1:
            self.log(f"Creating test database '{test_database_name}'...")

        # autoclobber=True: an existing database with this name is dropped and
        # recreated without asking the user.
        self._create_test_db(
            test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
        )

        # Close the current connection, then point both the global settings
        # and this connection's settings at the test database.
        self.connection.close()
        settings.DATABASE["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        apply.callback(
            package_label=None,
            migration_name=None,
            fake=False,
            plan=False,
            check_unapplied=False,
            backup=False,
            no_input=True,
            atomic_batch=False,  # No need for atomic batch when creating test database
            quiet=verbosity < 2,  # Show migration output when verbosity is 2+
        )

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict: dict[str, Any]) -> None:
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given, by copying the primary's NAME.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def _get_test_db_name(self, prefix: str = "") -> str:
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.

        Resolution order:
          * ``prefix`` given: ``f"{prefix}_{TEST.NAME or NAME}"``.
          * explicit ``TEST.NAME``: used verbatim.
          * otherwise: ``TEST_DATABASE_PREFIX + NAME``.

        Raises AssertionError if no usable name is configured.
        """
        settings_dict = self.connection.settings_dict
        # An explicit TEST.NAME overrides the base NAME.
        explicit_test_name = settings_dict["TEST"]["NAME"]

        if prefix:
            base_name = explicit_test_name or settings_dict["NAME"]
            # Without this check a missing NAME produced e.g. "gw0_None".
            assert base_name is not None, "DATABASE NAME must be set"
            # NOTE(review): the prefixed form does not include
            # TEST_DATABASE_PREFIX; confirm that this is intended.
            return f"{prefix}_{base_name}"

        if explicit_test_name:
            return explicit_test_name

        name = settings_dict["NAME"]
        assert name is not None, "DATABASE NAME must be set"
        return TEST_DATABASE_PREFIX + name

    def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
        """
        Execute ``CREATE DATABASE {dbname} {suffix}`` on ``cursor``.

        ``parameters["dbname"]`` must already be quoted by the caller.
        """
        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))

    def _create_test_db(
        self, *, test_database_name: str, verbosity: int, autoclobber: bool
    ) -> str:
        """
        Internal implementation - create the test database itself (not its tables).

        If creation raises (any exception, not only "already exists"), the
        error is logged and the database is dropped and created again. This
        happens automatically when ``autoclobber`` is true; otherwise it
        happens only after the user types 'yes' at the prompt. The process
        exits with status 2 if recreation fails, or with status 1 if the
        user declines.

        Returns ``test_database_name``.
        """
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database using the no-database cursor.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params)
            except Exception as e:
                self.log(f"Got an error creating the test database: {e}")
                # `confirm` is only read when autoclobber is False, thanks to
                # the short-circuiting `or` below.
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        f"database '{test_database_name}', or 'no' to cancel: "
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                f"Destroying old test database '{test_database_name}'..."
                            )
                        cursor.execute(
                            "DROP DATABASE {dbname}".format(**test_db_params)
                        )
                        self._execute_create_test_db(cursor, test_db_params)
                    except Exception as e:
                        self.log(f"Got an error recreating the test database: {e}")
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def destroy_test_db(
        self, old_database_name: str | None = None, verbosity: int = 1
    ) -> None:
        """
        Drop the test database the connection currently points at.

        The connection is closed first. The database named by
        ``settings_dict["NAME"]`` is then dropped without prompting. If
        ``old_database_name`` is given, ``settings.DATABASE["NAME"]`` and the
        connection's ``settings_dict["NAME"]`` are restored to it afterwards.
        """
        self.connection.close()

        test_database_name = self.connection.settings_dict["NAME"]
        assert test_database_name is not None, "Test database NAME must be set"

        if verbosity >= 1:
            self.log(f"Destroying test database '{test_database_name}'...")
        self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASE["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
        """
        Internal implementation - issue DROP DATABASE for ``test_database_name``.

        ``verbosity`` is not used by this base implementation.
        """
        # Use the no-database cursor rather than a connection to the test
        # database itself: a database cannot be dropped while connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                f"DROP DATABASE {self.connection.ops.quote_name(test_database_name)}"
            )

    def sql_table_creation_suffix(self) -> str:
        """
        Return SQL appended to the test database's CREATE DATABASE statement.

        Empty in the base implementation.
        """
        return ""

    def test_db_signature(self, prefix: str = "") -> tuple[str | int, ...]:
        """
        Return a tuple of settings that uniquely identifies the test database:
        (HOST, PORT, ENGINE, test database name for ``prefix``).

        Missing or falsy HOST/PORT/ENGINE values are normalized to "".
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict.get("HOST") or "",
            settings_dict.get("PORT") or "",
            settings_dict.get("ENGINE") or "",
            self._get_test_db_name(prefix),
        )