1import datetime
2import os
3from pathlib import Path
4
5from plain.runtime import settings
6
7from .. import connections
8from .clients import PostgresBackupClient
9
10
class DatabaseBackups:
    """Manage the collection of named database backups on disk.

    Backups live in ``settings.PLAIN_TEMP_PATH / "backups"``; each backup is a
    sub-directory of that path and is represented by a ``DatabaseBackup``.
    """

    def __init__(self):
        self.path = settings.PLAIN_TEMP_PATH / "backups"

    def find_backups(self) -> "list[DatabaseBackup]":
        """Return the existing backups, most recently modified first.

        Returns an empty list when the backups directory does not exist.
        """
        if not self.path.exists():
            return []

        # A backup is always a directory. Stray files in the backups folder
        # (e.g. ``.DS_Store``) are skipped: they are not backups, and
        # restoring or deleting them would fail when listing their contents.
        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            if entry.is_dir()
        ]

        # Newest first, based on the backup directory's modification time.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)

        return backups

    def create(self, name: str, **create_kwargs) -> Path:
        """Create a new backup called ``name`` and return its directory.

        ``create_kwargs`` are forwarded unchanged to ``DatabaseBackup.create``.

        Raises:
            Exception: if a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        return backup.create(**create_kwargs)

    def restore(self, name: str, **restore_kwargs) -> None:
        """Restore the backup called ``name``.

        ``restore_kwargs`` are forwarded unchanged to
        ``DatabaseBackup.restore``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(**restore_kwargs)

    def delete(self, name: str) -> None:
        """Delete the backup called ``name``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
48
49
class DatabaseBackup:
    """A single named backup stored as a directory of ``<alias>.backup`` files.

    The directory is ``backups_path / name``; ``create`` writes one file per
    database connection alias into it.
    """

    def __init__(self, name: str, *, backups_path: Path):
        # Validate before building the path so that a missing name (``None``
        # or ``""``) always surfaces as this ValueError instead of a
        # TypeError raised by the path join.
        if not name:
            raise ValueError("Backup name is required")

        self.name = name
        self.path = backups_path / name

    def exists(self) -> bool:
        """Return True when the backup's directory is present on disk."""
        return self.path.exists()

    def create(self, **create_kwargs) -> Path:
        """Back up every configured connection and return the backup directory.

        ``create_kwargs`` are forwarded unchanged to
        ``PostgresBackupClient.create_backup``.

        Raises:
            Exception: if any connection's vendor is not ``"postgresql"``.
        """
        # Resolve and check every connection before touching the filesystem.
        # Otherwise an unsupported vendor would leave an empty or partial
        # backup directory behind, and that name would then be reported as
        # already existing.
        targets = [(alias, connections[alias]) for alias in connections]
        for _alias, connection in targets:
            if connection.vendor != "postgresql":
                raise Exception("Unsupported database vendor")

        self.path.mkdir(parents=True, exist_ok=True)

        for alias, connection in targets:
            PostgresBackupClient(connection).create_backup(
                self.path / f"{alias}.backup", **create_kwargs
            )

        return self.path

    def iter_files(self):
        """Yield the regular files in the backup directory named ``*.backup``."""
        for backup_file in self.path.iterdir():
            if not backup_file.is_file():
                continue
            if not backup_file.name.endswith(".backup"):
                continue
            yield backup_file

    def restore(self, **restore_kwargs) -> None:
        """Restore each ``<alias>.backup`` file into the connection ``<alias>``.

        ``restore_kwargs`` are forwarded unchanged to
        ``PostgresBackupClient.restore_backup``.

        Raises:
            Exception: if a connection is missing or its vendor is not
                ``"postgresql"``.
        """
        for backup_file in self.iter_files():
            # The file stem is the connection alias the file was created from.
            connection_alias = backup_file.stem
            connection = connections[connection_alias]
            # NOTE(review): whether an unknown alias yields a falsy value here
            # or raises inside the lookup depends on `connections` - confirm.
            if not connection:
                raise Exception(f"Connection {connection_alias} not found")

            if connection.vendor == "postgresql":
                PostgresBackupClient(connection).restore_backup(
                    backup_file, **restore_kwargs
                )
            else:
                raise Exception("Unsupported database vendor")

    def delete(self) -> None:
        """Remove the backup's ``*.backup`` files and then its directory.

        Only the files yielded by ``iter_files`` are removed, so ``rmdir``
        raises ``OSError`` if anything else is left in the directory.
        """
        for backup_file in self.iter_files():
            backup_file.unlink()

        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the backup directory's modification time as a naive local datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)