1from __future__ import annotations
2
3import datetime
4import json
5import os
6import subprocess
7from pathlib import Path
8from typing import TYPE_CHECKING, Any, cast
9
10from plain.runtime import PLAIN_TEMP_PATH
11
12from .. import db_connection as _db_connection
13from .clients import PostgresBackupClient, SQLiteBackupClient
14
15
def _run_git(*args: str) -> str | None:
    """Run ``git <args>`` in the current directory and return its stripped stdout.

    Returns None when the command cannot be run or exits non-zero (for
    example outside a repository). Git details are best-effort metadata, so a
    failure here must not abort the caller.
    """
    try:
        result = subprocess.run(
            ["git", *args],
            capture_output=True,
            text=True,
            # Ref names are bytes to git; never let an undecodable name raise.
            encoding="utf-8",
            errors="replace",
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (git not installed) as well as
        # PermissionError and other failures to launch the executable.
        return None
    return result.stdout.strip()


def get_git_branch() -> str | None:
    """Get current git branch, or None if not in a git repo.

    Note: with a detached HEAD, ``git rev-parse --abbrev-ref HEAD`` prints
    the literal ``HEAD`` rather than a branch name.
    """
    return _run_git("rev-parse", "--abbrev-ref", "HEAD")


def get_git_commit() -> str | None:
    """Get current git commit (short hash), or None if not in a git repo."""
    return _run_git("rev-parse", "--short", "HEAD")
42
43
if TYPE_CHECKING:
    # Imported only for static type checking; this import never runs at runtime.
    from plain.models.backends.base.base import BaseDatabaseWrapper

# Cast for type checkers; runtime value is _db_connection (DatabaseConnection)
# The target type is given as a string because BaseDatabaseWrapper is only
# imported under TYPE_CHECKING; typing.cast returns its argument unchanged.
db_connection = cast("BaseDatabaseWrapper", _db_connection)
49
50
class DatabaseBackups:
    """Manage the collection of named backups under ``PLAIN_TEMP_PATH / "backups"``.

    Each backup is a subdirectory of that path, named after the backup.
    """

    def __init__(self) -> None:
        self.path = PLAIN_TEMP_PATH / "backups"

    def find_backups(self) -> list[DatabaseBackup]:
        """Return every backup, most recently modified first.

        Only subdirectories count as backups. Stray files in the backups
        directory are ignored, because they could be neither restored nor
        deleted as backups.
        """
        if not self.path.exists():
            return []

        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            if entry.is_dir()
        ]

        # Newest first, ordered by each backup directory's modification time.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)

        return backups

    def create(self, name: str, **create_kwargs: Any) -> Path:
        """Create a new backup called *name* and return its directory.

        ``create_kwargs`` are forwarded to ``DatabaseBackup.create``.

        Raises:
            Exception: If a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        backup_dir = backup.create(**create_kwargs)
        return backup_dir

    def restore(self, name: str, **restore_kwargs: Any) -> None:
        """Restore the database from the backup called *name*.

        ``restore_kwargs`` are forwarded to ``DatabaseBackup.restore``.

        Raises:
            Exception: If no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(**restore_kwargs)

    def delete(self, name: str) -> None:
        """Delete the backup called *name*.

        Raises:
            Exception: If no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
88
89
class DatabaseBackup:
    """A single named backup stored in ``<backups_path>/<name>/``.

    The directory holds the dump as ``default.backup``. Backups made by
    ``create`` also hold a ``metadata.json`` file.
    """

    # Keys always present in the dict returned by ``metadata``.
    _METADATA_KEYS = ("created_at", "source", "git_branch", "git_commit")

    def __init__(self, name: str, *, backups_path: Path) -> None:
        self.name = name
        self.path = backups_path / name

        if not self.name:
            raise ValueError("Backup name is required")

    def exists(self) -> bool:
        """Return True if this backup's directory exists."""
        return self.path.exists()

    def create(self, *, source: str = "manual", **create_kwargs: Any) -> Path:
        """Dump the current database into this backup's directory.

        Args:
            source: Label recorded in ``metadata.json`` (default ``"manual"``).
            **create_kwargs: For PostgreSQL, ``pg_dump`` names the pg_dump
                executable (default ``"pg_dump"``). Other keys are ignored.

        Returns:
            The backup directory.

        Raises:
            Exception: If the database vendor is neither PostgreSQL nor SQLite.
        """
        vendor = db_connection.vendor
        # Reject unsupported vendors before touching the filesystem, so no
        # empty directory is left behind looking like a backup.
        if vendor not in ("postgresql", "sqlite"):
            raise Exception("Unsupported database vendor")

        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"
        try:
            if vendor == "postgresql":
                PostgresBackupClient(db_connection).create_backup(
                    backup_path,
                    pg_dump=create_kwargs.get("pg_dump", "pg_dump"),
                )
            else:
                SQLiteBackupClient(db_connection).create_backup(backup_path)
        except BaseException:
            # Don't leave a half-written backup that exists() would report as
            # valid. Only undo a directory this call created, and keep
            # cleanup errors from masking the original failure.
            if created_dir:
                try:
                    self._remove_files()
                except OSError:
                    pass
            raise

        # Metadata is written after a successful dump. A backup without it is
        # still usable (see ``metadata``).
        metadata = {
            "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
            "source": source,
            "git_branch": get_git_branch(),
            "git_commit": get_git_commit(),
        }
        metadata_path = self.path / "metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        return self.path

    def restore(self, **restore_kwargs: Any) -> None:
        """Restore the database from this backup's ``default.backup`` file.

        Args:
            **restore_kwargs: For PostgreSQL, ``pg_restore`` names the
                pg_restore executable (default ``"pg_restore"``).

        Raises:
            Exception: If the database vendor is neither PostgreSQL nor SQLite.
        """
        backup_file = self.path / "default.backup"

        if db_connection.vendor == "postgresql":
            PostgresBackupClient(db_connection).restore_backup(
                backup_file,
                pg_restore=restore_kwargs.get("pg_restore", "pg_restore"),
            )
        elif db_connection.vendor == "sqlite":
            SQLiteBackupClient(db_connection).restore_backup(backup_file)
        else:
            raise Exception("Unsupported database vendor")

    @property
    def metadata(self) -> dict[str, Any]:
        """Read metadata from metadata.json, with fallback for old backups.

        Every key in ``_METADATA_KEYS`` is always present (None when unknown).
        Values from the file override these defaults, and any extra keys in
        the file are kept.
        """
        metadata: dict[str, Any] = dict.fromkeys(self._METADATA_KEYS)
        metadata_path = self.path / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, encoding="utf-8") as f:
                loaded = json.load(f)
            # Only a JSON object can describe a backup. Anything else is
            # treated as if no metadata were recorded.
            if isinstance(loaded, dict):
                metadata.update(loaded)
        return metadata

    def _remove_files(self) -> None:
        """Remove ``default.backup``, ``metadata.json`` and the directory itself.

        Missing files are tolerated. ``rmdir`` raises OSError if any other
        files remain in the directory.
        """
        (self.path / "default.backup").unlink(missing_ok=True)
        (self.path / "metadata.json").unlink(missing_ok=True)
        self.path.rmdir()

    def delete(self) -> None:
        """Delete this backup from disk.

        A missing ``default.backup`` (e.g. after an interrupted dump) is
        tolerated, so broken backups can still be removed.
        """
        self._remove_files()

    def updated_at(self) -> datetime.datetime:
        """Return the backup directory's mtime as a naive local-time datetime.

        NOTE: ``metadata["created_at"]`` is a timezone-aware UTC ISO string,
        so the two values should not be compared directly.
        """
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)