1from __future__ import annotations
2
3import datetime
4import os
5from collections.abc import Generator
6from pathlib import Path
7from typing import TYPE_CHECKING, Any, cast
8
9from plain.runtime import PLAIN_TEMP_PATH
10
11from .. import db_connection as _db_connection
12from .clients import PostgresBackupClient, SQLiteBackupClient
13
# Expose the shared connection under the name ``db_connection``.
# Under static type checking it is cast to ``BaseDatabaseWrapper`` so that
# attribute access (e.g. ``db_connection.vendor``) is type-checked; at runtime
# the imported object is used unchanged and the wrapper class is not imported.
if TYPE_CHECKING:
    from plain.models.backends.base.base import BaseDatabaseWrapper

    db_connection = cast("BaseDatabaseWrapper", _db_connection)
else:
    db_connection = _db_connection
20
21
class DatabaseBackups:
    """Manage the collection of named backups stored under one directory.

    Each backup is a sub-directory of ``self.path`` (``PLAIN_TEMP_PATH /
    "backups"``) whose name is the backup name.
    """

    def __init__(self) -> None:
        self.path = PLAIN_TEMP_PATH / "backups"

    def find_backups(self) -> list[DatabaseBackup]:
        """Return all existing backups, most recently modified first.

        Returns an empty list when the backups directory does not exist.
        Entries that are not directories are ignored, since a backup is
        always a directory.
        """
        # ``is_dir`` (rather than ``exists``) also covers the case where the
        # path exists but is not a directory, which would make ``iterdir`` fail.
        if not self.path.is_dir():
            return []

        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            # Skip stray files (editor/OS droppings) sitting next to backups.
            if entry.is_dir()
        ]

        # Newest first, based on each backup directory's modification time.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)

        return backups

    def create(self, name: str, **create_kwargs: Any) -> Path:
        """Create a new backup called ``name`` and return its directory.

        ``create_kwargs`` are forwarded to ``DatabaseBackup.create``.

        Raises:
            Exception: if a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        return backup.create(**create_kwargs)

    def restore(self, name: str, **restore_kwargs: Any) -> None:
        """Restore the database from the backup called ``name``.

        ``restore_kwargs`` are forwarded to ``DatabaseBackup.restore``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(**restore_kwargs)

    def delete(self, name: str) -> None:
        """Delete the backup called ``name``.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
59
60
class DatabaseBackup:
    """A single named backup, stored as a directory of ``*.backup`` files."""

    def __init__(self, name: str, *, backups_path: Path) -> None:
        """Bind the backup ``name`` to its directory under ``backups_path``.

        Raises:
            ValueError: if ``name`` is empty.
        """
        # Validate first: an empty name would otherwise resolve to the
        # backups directory itself.
        if not name:
            raise ValueError("Backup name is required")

        self.name = name
        self.path = backups_path / name

    def exists(self) -> bool:
        """Return whether the backup directory is present on disk."""
        return self.path.exists()

    def create(self, **create_kwargs: Any) -> Path:
        """Dump the current database into this backup and return its directory.

        Recognised ``create_kwargs``:
            pg_dump: value passed as ``pg_dump`` to the PostgreSQL backup
                client (defaults to ``"pg_dump"``).

        Raises:
            Exception: if the database vendor is neither ``postgresql`` nor
                ``sqlite``. Errors raised by the backup client propagate.
        """
        vendor = db_connection.vendor
        # Reject unsupported vendors before touching the filesystem so that no
        # empty backup directory is left behind.
        if vendor not in ("postgresql", "sqlite"):
            raise Exception("Unsupported database vendor")

        # Remember whether this call created the directory; cleanup on failure
        # must never remove a directory (or file) that existed beforehand.
        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"

        try:
            if vendor == "postgresql":
                PostgresBackupClient(db_connection).create_backup(
                    backup_path,
                    pg_dump=create_kwargs.get("pg_dump", "pg_dump"),
                )
            else:
                SQLiteBackupClient(db_connection).create_backup(backup_path)
        except BaseException:
            # A failed dump would otherwise leave a directory that makes the
            # backup look like it exists and blocks a retry under this name.
            if created_dir:
                self._discard_partial(backup_path)
            raise

        return self.path

    def _discard_partial(self, backup_path: Path) -> None:
        """Best-effort removal of a partially written backup."""
        try:
            backup_path.unlink()
        except OSError:
            pass
        try:
            # Only succeeds when the directory is empty; anything else is
            # left in place rather than deleted blindly.
            self.path.rmdir()
        except OSError:
            pass

    def iter_files(self) -> Generator[Path, None, None]:
        """Yield the regular files in the backup directory ending in ``.backup``.

        Files are yielded in sorted order so that restore and delete process
        them deterministically.
        """
        for backup_file in sorted(self.path.iterdir()):
            if backup_file.is_file() and backup_file.name.endswith(".backup"):
                yield backup_file

    def restore(self, **restore_kwargs: Any) -> None:
        """Restore the database from every backup file in this backup.

        Recognised ``restore_kwargs``:
            pg_restore: value passed as ``pg_restore`` to the PostgreSQL
                backup client (defaults to ``"pg_restore"``).

        Raises:
            Exception: if a backup file is found and the database vendor is
                neither ``postgresql`` nor ``sqlite``.
        """
        for backup_file in self.iter_files():
            if db_connection.vendor == "postgresql":
                PostgresBackupClient(db_connection).restore_backup(
                    backup_file,
                    pg_restore=restore_kwargs.get("pg_restore", "pg_restore"),
                )
            elif db_connection.vendor == "sqlite":
                SQLiteBackupClient(db_connection).restore_backup(backup_file)
            else:
                raise Exception("Unsupported database vendor")

    def delete(self) -> None:
        """Remove the backup files and then the backup directory.

        Only ``*.backup`` files are removed; ``rmdir`` raises ``OSError`` if
        anything else is still inside the directory.
        """
        for backup_file in self.iter_files():
            backup_file.unlink()

        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the backup directory's modification time.

        The result is a naive ``datetime`` in local time, as produced by
        ``datetime.fromtimestamp`` without a ``tz`` argument.
        """
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)