1import os
2import sys
3
4from plain.packages import packages
5from plain.runtime import settings
6
# Prefix prepended to the configured database NAME to build the default
# test database name when the TEST settings do not provide one.
TEST_DATABASE_PREFIX: str = "test_"
10
11
class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        # Database wrapper whose settings_dict, ops, and cursors every
        # method below relies on.
        self.connection = connection

    def _nodb_cursor(self):
        """
        Delegate to the connection's ``_nodb_cursor()``.

        Used for CREATE/DROP DATABASE statements, which cannot be run while
        connected to the database being dropped.
        """
        return self.connection._nodb_cursor()

    def log(self, msg):
        """
        Write ``msg`` followed by a newline to stderr.

        A plain newline is used instead of ``os.linesep``: text streams such
        as ``sys.stderr`` normally translate newlines to the platform line
        ending themselves, so appending ``os.linesep`` would emit a doubled
        carriage return on Windows.
        """
        sys.stderr.write(msg + "\n")

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.

        After creation, the connection (and ``settings.DATABASES``) is pointed
        at the test database and migrations are run against it. When the
        connection's ``TEST["MIGRATE"]`` setting is ``False``, every package's
        migrations module is temporarily disabled for that run.

        ``serialize`` is accepted for API compatibility but is currently
        ignored: serializing the database contents is not implemented here.
        """
        from plain.models.cli import migrate

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = "Using existing" if keepdb else "Creating"
            self.log(
                "{} test database for alias {}...".format(
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        # (package config, original migrations module) pairs. Only packages
        # that were actually modified are recorded, so the restore below is
        # safe even if disabling or migrating fails part-way through.
        disabled_migrations = []
        try:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                # Disable migrations for all packages.
                for app in packages.get_package_configs():
                    disabled_migrations.append((app, app.migrations_module))
                    app.migrations_module = None
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            migrate.callback(
                package_label=None,
                migration_name=None,
                no_input=True,
                database=self.connection.alias,
                fake=False,
                fake_initial=False,
                plan=False,
                check_unapplied=False,
                run_syncdb=True,
                prune=False,
                verbosity=max(verbosity - 1, 0),
            )
        finally:
            for app, migrations_module in disabled_migrations:
                app.migrations_module = migrations_module

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.

        The alias is always shown; the database name is appended only when
        ``verbosity >= 2``.
        """
        name_part = f" ('{database_name}')" if verbosity >= 2 else ""
        return f"'{self.connection.alias}'{name_part}"

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.

        A truthy ``TEST["NAME"]`` wins; otherwise ``TEST_DATABASE_PREFIX`` is
        prepended to ``NAME``.
        """
        settings_dict = self.connection.settings_dict
        if settings_dict["TEST"]["NAME"]:
            return settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        """
        Issue the CREATE DATABASE statement.

        ``parameters`` must provide ``dbname`` (already quoted) and ``suffix``.
        ``keepdb`` is unused here; backends may override this hook to use it.
        """
        cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test database.

        If the CREATE statement fails:

        * with ``keepdb`` the error is ignored and the name is returned, so an
          existing database is reused;
        * otherwise the user is asked whether to drop and recreate it (no
          prompt when ``autoclobber`` is true). Declining calls
          ``sys.exit(1)``; a failure while recreating calls ``sys.exit(2)``.

        Return the test database name.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as create_error:
                # If we want to keep the db, there is nothing more to do.
                if keepdb:
                    return test_database_name

                self.log(
                    "Got an error creating the test database: %s" % create_error
                )
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                    if confirm != "yes":
                        self.log("Tests cancelled.")
                        sys.exit(1)

                try:
                    if verbosity >= 1:
                        self.log(
                            "Destroying old test database for alias {}...".format(
                                self._get_database_display_str(
                                    verbosity, test_database_name
                                ),
                            )
                        )
                    cursor.execute("DROP DATABASE {dbname}".format(**test_db_params))
                    self._execute_create_test_db(cursor, test_db_params, keepdb)
                except Exception as recreate_error:
                    self.log(
                        "Got an error recreating the test database: %s"
                        % recreate_error
                    )
                    sys.exit(2)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.

        ``autoclobber`` is accepted but not used by this base implementation.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Using existing clone" if keepdb else "Cloning test database"
            self.log(
                "{} for alias {}...".format(
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.

        The result is a shallow copy; the connection's own settings are not
        modified.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": f"{orig_settings_dict['NAME']}_{suffix}",
        }

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.

        Not supported by the base class; backends that support parallel test
        runs override this.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.

        With ``suffix``, the corresponding clone is targeted instead of the
        main test database. With ``keepdb``, the database is left in place.
        If ``old_database_name`` is given, the connection and
        ``settings.DATABASES`` are pointed back at it afterwards.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Preserving" if keepdb else "Destroying"
            self.log(
                "{} test database for alias {}...".format(
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - drop the test database.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        quoted_name = self.connection.ops.quote_name(test_database_name)
        with self._nodb_cursor() as cursor:
            cursor.execute(f"DROP DATABASE {quoted_name}")

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the test table creation statements.

        Empty in the base class; backends may override.
        """
        return ""

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(),
        )

    def setup_worker_connection(self, _worker_id):
        """
        Point this connection at the clone database for ``_worker_id`` and
        close it so the next use reconnects to the clone.
        """
        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in plain.models.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        self.connection.settings_dict.update(settings_dict)
        self.connection.close()