1from __future__ import annotations
2
3import datetime
4import json
5import os
6import subprocess
7from pathlib import Path
8from typing import Any
9
10from plain.runtime import PLAIN_TEMP_PATH
11
12from .clients import PostgresBackupClient
13
14
def get_git_branch() -> str | None:
    """Return the name of the currently checked-out git branch.

    Runs ``git rev-parse --abbrev-ref HEAD`` in the current working directory
    and returns its output with surrounding whitespace stripped.

    Returns:
        The branch name, or ``None`` when the ``git`` executable is missing or
        the command exits with a non-zero status (e.g. not inside a git repo).
    """
    command = ["git", "rev-parse", "--abbrev-ref", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            check=True,
            text=True,
            capture_output=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return completed.stdout.strip()
27
28
def get_git_commit() -> str | None:
    """Return the abbreviated hash of the commit that ``HEAD`` points at.

    The value is what ``git rev-parse --short HEAD`` prints in the current
    working directory, with surrounding whitespace removed.

    Returns:
        The short commit hash, or ``None`` if ``git`` cannot be executed or the
        command fails (for instance when not inside a git repository).
    """
    git_args = ["git", "rev-parse", "--short", "HEAD"]
    try:
        proc = subprocess.run(git_args, text=True, check=True, capture_output=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        # git missing, or non-zero exit (not a repo, no commits yet, ...)
        return None
    short_hash = proc.stdout.strip()
    return short_hash
41
42
class DatabaseBackups:
    """Manage the collection of local database backups.

    Every backup lives in its own subdirectory of ``PLAIN_TEMP_PATH / "backups"``
    (named after the backup) and is handled through ``DatabaseBackup``.
    """

    def __init__(self) -> None:
        # Root directory holding one subdirectory per backup.
        self.path = PLAIN_TEMP_PATH / "backups"

    def find_backups(self) -> list[DatabaseBackup]:
        """Return all backups, newest first by directory modification time.

        Only subdirectories of the backups root are treated as backups; stray
        files in the root (``.DS_Store``, editor swap files, ...) are ignored,
        since they could never be restored or deleted as a backup.
        """
        if not self.path.exists():
            return []

        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            if entry.is_dir()
        ]

        # Newest first; prune() relies on this ordering.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)
        return backups

    def create(
        self, name: str, *, source: str = "manual", pg_dump: str = "pg_dump"
    ) -> Path:
        """Create a new backup called *name*, then prune old backups.

        Args:
            name: directory name for the new backup.
            source: free-form label stored in the backup's metadata.
            pg_dump: ``pg_dump`` executable passed through to the client.

        Returns:
            The new backup's directory.

        Raises:
            Exception: if a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        backup_dir = backup.create(source=source, pg_dump=pg_dump)

        # Pruning is best-effort: the backup itself succeeded, so a pruning
        # failure must not fail the command -- but it should not be silent.
        try:
            self.prune()
        except Exception as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Backup %s was created, but pruning old backups failed: %s",
                name,
                exc,
            )
        return backup_dir

    def prune(self, *, keep: int = 20) -> list[str]:
        """Delete the oldest backups on the current branch, keeping *keep*.

        Candidates are backups whose metadata ``git_branch`` equals the current
        git branch, plus backups with no branch recorded. Backups made on other
        branches are never deleted.

        Args:
            keep: number of most recent candidate backups to retain.

        Returns:
            Names of the deleted backups (newest first).

        Raises:
            ValueError: if *keep* is negative.
        """
        if keep < 0:
            # A negative slice start would delete from the wrong end of the list.
            raise ValueError("keep must be >= 0")

        current_branch = get_git_branch()
        backups = self.find_backups()  # sorted newest-first

        # Only prune backups matching the current branch or with no branch metadata
        prunable = [
            b for b in backups if b.metadata.get("git_branch") in (current_branch, None)
        ]

        deleted = []
        for backup in prunable[keep:]:
            backup.delete()
            deleted.append(backup.name)
        return deleted

    def restore(self, name: str, *, pg_restore: str = "pg_restore") -> None:
        """Restore the database from the backup called *name*.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(pg_restore=pg_restore)

    def delete(self, name: str) -> None:
        """Delete the backup called *name*.

        Raises:
            Exception: if no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
103
104
class DatabaseBackup:
    """A single named backup stored in ``<backups_path>/<name>/``.

    The directory holds ``default.backup`` (the dump file handed to
    ``PostgresBackupClient``) and ``metadata.json``.
    """

    def __init__(self, name: str, *, backups_path: Path) -> None:
        # Fail fast on an empty name before touching the path.
        if not name:
            raise ValueError("Backup name is required")
        self.name = name
        # NOTE(review): ``name`` is joined onto ``backups_path`` unchecked, so a
        # name containing ``/`` or ``..`` resolves outside the backups directory.
        self.path = backups_path / name

    def exists(self) -> bool:
        """Return whether this backup's directory exists."""
        return self.path.exists()

    def create(self, *, source: str = "manual", pg_dump: str = "pg_dump") -> Path:
        """Dump the database into this backup's directory and write metadata.

        If obtaining the connection or running the dump fails (or is
        interrupted) and this call created the directory, the partial dump and
        the directory are removed so the name can be reused. The original
        exception is always re-raised.

        Args:
            source: free-form label recorded in ``metadata.json``.
            pg_dump: ``pg_dump`` executable passed to the client.

        Returns:
            The backup directory.
        """
        from plain.postgres.db import get_connection

        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"

        try:
            PostgresBackupClient(get_connection()).create_backup(
                backup_path,
                pg_dump=pg_dump,
            )
        except BaseException:
            # BaseException so Ctrl-C during a long dump also cleans up.
            if created_dir:
                self._discard_partial(backup_path)
            raise

        # Write metadata
        metadata = {
            "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
            "source": source,
            "git_branch": get_git_branch(),
            "git_commit": get_git_commit(),
        }
        metadata_path = self.path / "metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        return self.path

    def _discard_partial(self, backup_file: Path) -> None:
        """Best-effort removal of a failed backup's dump file and directory."""
        try:
            backup_file.unlink(missing_ok=True)
            self.path.rmdir()
        except OSError:
            # Leave any remains in place rather than mask the original error.
            pass

    def restore(self, *, pg_restore: str = "pg_restore") -> None:
        """Restore the database from this backup's ``default.backup``.

        Does not check that the file exists first.
        """
        from plain.postgres.db import get_connection

        backup_file = self.path / "default.backup"

        PostgresBackupClient(get_connection()).restore_backup(
            backup_file,
            pg_restore=pg_restore,
        )

    @property
    def metadata(self) -> dict[str, Any]:
        """Return the parsed ``metadata.json``, with a fallback.

        Old backups without metadata, and backups whose file is unreadable as
        JSON or is not a JSON object, get a dict with every known key set to
        ``None``.
        """
        metadata_path = self.path / "metadata.json"
        data: Any = None
        try:
            with open(metadata_path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            pass
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
            # a corrupt file should not break listing or pruning backups.
            pass

        if isinstance(data, dict):
            return data

        return {
            "created_at": None,
            "source": None,
            "git_branch": None,
            "git_commit": None,
        }

    def delete(self) -> None:
        """Remove the dump file, the metadata file, and then the directory.

        ``Path.rmdir`` raises ``OSError`` if any other files remain inside.
        """
        backup_file = self.path / "default.backup"
        backup_file.unlink(missing_ok=True)
        metadata_file = self.path / "metadata.json"
        metadata_file.unlink(missing_ok=True)
        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the directory's modification time as a naive local datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)