1import datetime
2import os
3from pathlib import Path
4
5from plain.runtime import PLAIN_TEMP_PATH
6
7from .. import db_connection
8from .clients import PostgresBackupClient, SQLiteBackupClient
9
10
class DatabaseBackups:
    """Manage the set of named database backups stored under ``PLAIN_TEMP_PATH/backups``."""

    def __init__(self):
        # Each backup is a subdirectory of this path, named after the backup.
        self.path = PLAIN_TEMP_PATH / "backups"

    def find_backups(self):
        """Return every backup, most recently updated first.

        Only subdirectories are treated as backups; stray regular files in the
        backups directory (OS/editor metadata, etc.) are ignored so they do not
        show up as bogus backups that fail later when their files are listed.
        Returns an empty list when the backups directory does not exist yet.
        """
        if not self.path.exists():
            return []

        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            if entry.is_dir()
        ]

        # Newest first, by the backup directory's modification time.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)
        return backups

    def create(self, name, **create_kwargs):
        """Create a new backup called ``name`` and return its directory.

        ``create_kwargs`` are forwarded to ``DatabaseBackup.create``.

        Raises:
            Exception: if a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        return backup.create(**create_kwargs)

    def restore(self, name, **restore_kwargs):
        """Restore the backup called ``name`` into the current database.

        ``restore_kwargs`` are forwarded to ``DatabaseBackup.restore``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(**restore_kwargs)

    def delete(self, name):
        """Delete the backup called ``name``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
48
49
class DatabaseBackup:
    """A single named backup: a directory holding one or more ``*.backup`` files."""

    # Database vendors this class knows how to back up and restore.
    _SUPPORTED_VENDORS = ("postgresql", "sqlite")

    def __init__(self, name: str, *, backups_path: Path):
        """Bind this backup to the directory ``backups_path / name``.

        Raises:
            ValueError: if ``name`` is empty/None, or is not a single path
                component (``.``, ``..``, or anything containing a separator).
        """
        # Validate before using ``name`` so a missing name reports the intended
        # ValueError rather than a TypeError from ``Path / None``.
        if not name:
            raise ValueError("Backup name is required")
        # The name must map to exactly one child of backups_path; "..", "a/b"
        # or an absolute path would otherwise escape or nest inside it.
        if name in (".", "..") or Path(name).name != name:
            raise ValueError(f"Invalid backup name: {name!r}")

        self.name = name
        self.path = backups_path / name

    def exists(self):
        """Return True if this backup's directory exists on disk."""
        return self.path.exists()

    def _require_supported_vendor(self):
        """Return the current connection's vendor, raising if it is unsupported."""
        vendor = db_connection.vendor
        if vendor not in self._SUPPORTED_VENDORS:
            raise Exception("Unsupported database vendor")
        return vendor

    def create(self, **create_kwargs):
        """Dump the current database into ``<backup dir>/default.backup``.

        Keyword args:
            pg_dump: executable used for PostgreSQL dumps (default ``"pg_dump"``).

        Returns the backup directory. If the dump fails, the partial backup
        file (and the directory, if this call created it) is removed so the
        name can be reused, and the original error is re-raised.

        Raises:
            Exception: if the database vendor is not supported.
        """
        # Check the vendor first so an unsupported database leaves nothing behind.
        vendor = self._require_supported_vendor()

        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"

        try:
            if vendor == "postgresql":
                PostgresBackupClient(db_connection).create_backup(
                    backup_path,
                    pg_dump=create_kwargs.get("pg_dump", "pg_dump"),
                )
            else:
                SQLiteBackupClient(db_connection).create_backup(backup_path)
        except BaseException:
            # Don't leave a half-written backup behind: it would be listed as
            # valid and would block re-creating a backup with the same name.
            backup_path.unlink(missing_ok=True)
            if created_dir:
                try:
                    self.path.rmdir()
                except OSError:
                    # Best-effort cleanup only; the original error matters more.
                    pass
            raise

        return self.path

    def iter_files(self):
        """Yield the regular files in the backup directory whose names end in ``.backup``.

        Files are yielded in sorted path order so restores are deterministic.
        """
        for backup_file in sorted(self.path.iterdir()):
            if not backup_file.is_file():
                continue
            if not backup_file.name.endswith(".backup"):
                continue
            yield backup_file

    def restore(self, **restore_kwargs):
        """Restore every ``*.backup`` file of this backup into the current database.

        Keyword args:
            pg_restore: executable used for PostgreSQL restores
                (default ``"pg_restore"``).

        Raises:
            Exception: if the database vendor is not supported.
        """
        vendor = self._require_supported_vendor()

        for backup_file in self.iter_files():
            if vendor == "postgresql":
                PostgresBackupClient(db_connection).restore_backup(
                    backup_file,
                    pg_restore=restore_kwargs.get("pg_restore", "pg_restore"),
                )
            else:
                SQLiteBackupClient(db_connection).restore_backup(backup_file)

    def delete(self):
        """Remove this backup's ``*.backup`` files and then its directory.

        ``Path.rmdir`` raises OSError if other entries remain in the directory.
        """
        for backup_file in self.iter_files():
            backup_file.unlink()

        self.path.rmdir()

    def updated_at(self):
        """Return the backup directory's modification time as a naive local datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)