1from __future__ import annotations
2
3import datetime
4import json
5import os
6import subprocess
7from pathlib import Path
8from typing import Any
9
10from plain.runtime import PLAIN_TEMP_PATH
11
12from ..db import get_connection
13from .clients import PostgresBackupClient
14
15
16def get_git_branch() -> str | None:
17 """Get current git branch, or None if not in a git repo."""
18 try:
19 result = subprocess.run(
20 ["git", "rev-parse", "--abbrev-ref", "HEAD"],
21 capture_output=True,
22 text=True,
23 check=True,
24 )
25 return result.stdout.strip()
26 except (subprocess.CalledProcessError, FileNotFoundError):
27 return None
28
29
30def get_git_commit() -> str | None:
31 """Get current git commit (short hash), or None if not in a git repo."""
32 try:
33 result = subprocess.run(
34 ["git", "rev-parse", "--short", "HEAD"],
35 capture_output=True,
36 text=True,
37 check=True,
38 )
39 return result.stdout.strip()
40 except (subprocess.CalledProcessError, FileNotFoundError):
41 return None
42
43
class DatabaseBackups:
    """Manage the collection of local database backups.

    Each backup is a subdirectory of ``PLAIN_TEMP_PATH / "backups"`` and is
    represented by a ``DatabaseBackup`` instance.
    """

    def __init__(self) -> None:
        self.path = PLAIN_TEMP_PATH / "backups"

    def find_backups(self) -> list[DatabaseBackup]:
        """Return all backups, most recently modified first.

        Only subdirectories of the backups path are treated as backups; stray
        files in that directory are ignored (a plain file cannot hold a
        backup, and ``DatabaseBackup.delete`` would fail on it). Returns an
        empty list when the backups path does not exist yet.
        """
        if not self.path.exists():
            return []

        backups = [
            DatabaseBackup(entry.name, backups_path=self.path)
            for entry in self.path.iterdir()
            if entry.is_dir()
        ]

        # Newest first, ordered by each backup directory's modification time.
        backups.sort(key=lambda backup: backup.updated_at(), reverse=True)

        return backups

    def create(
        self, name: str, *, source: str = "manual", pg_dump: str = "pg_dump"
    ) -> Path:
        """Create a new backup called *name*, then prune old backups.

        Args:
            name: Directory name for the new backup.
            source: Free-form label stored in the backup's metadata.
            pg_dump: Executable passed through to the backup client.

        Returns:
            The path of the new backup directory.

        Raises:
            Exception: If a backup with this name already exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if backup.exists():
            raise Exception(f"Backup {name} already exists")
        backup_dir = backup.create(source=source, pg_dump=pg_dump)
        # Pruning is best-effort housekeeping: the backup itself has already
        # succeeded, so a pruning failure must not be reported as a failed
        # backup.
        try:
            self.prune()
        except Exception:
            pass
        return backup_dir

    def prune(self, *, keep: int = 20) -> list[str]:
        """Delete old backups on the current git branch, keeping the newest *keep*.

        Only backups whose recorded ``git_branch`` matches the current branch,
        or that have no branch recorded, are candidates. Backups from other
        branches are never deleted.

        Args:
            keep: Number of most recent candidate backups to retain.

        Returns:
            Names of the backups that were deleted.

        Raises:
            ValueError: If *keep* is negative.
        """
        if keep < 0:
            raise ValueError("keep must be non-negative")

        current_branch = get_git_branch()
        backups = self.find_backups()  # sorted newest-first

        # Only prune backups matching the current branch or with no branch metadata
        prunable = [
            b for b in backups if b.metadata.get("git_branch") in (current_branch, None)
        ]

        deleted = []
        for backup in prunable[keep:]:
            backup.delete()
            deleted.append(backup.name)
        return deleted

    def restore(self, name: str, *, pg_restore: str = "pg_restore") -> None:
        """Restore the database from the backup called *name*.

        Raises:
            Exception: If no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.restore(pg_restore=pg_restore)

    def delete(self, name: str) -> None:
        """Delete the backup called *name*.

        Raises:
            Exception: If no backup with this name exists.
        """
        backup = DatabaseBackup(name, backups_path=self.path)
        if not backup.exists():
            raise Exception(f"Backup {name} not found")
        backup.delete()
104
105
class DatabaseBackup:
    """A single named backup stored in the directory ``<backups_path>/<name>``.

    The directory holds ``default.backup`` (the dump file handed to
    ``PostgresBackupClient``) and ``metadata.json``.
    """

    def __init__(self, name: str, *, backups_path: Path) -> None:
        if not name:
            raise ValueError("Backup name is required")

        self.name = name
        # NOTE(review): *name* is joined onto backups_path unchecked; a name
        # containing path separators or ".." would point outside backups_path.
        self.path = backups_path / name

    def exists(self) -> bool:
        """Return True if this backup's directory exists."""
        return self.path.exists()

    def create(self, *, source: str = "manual", pg_dump: str = "pg_dump") -> Path:
        """Dump the database into this backup's directory and write metadata.

        Args:
            source: Free-form label stored in ``metadata.json``.
            pg_dump: Executable passed through to the backup client.

        Returns:
            The backup directory path.

        If the dump or the metadata write fails (including an interrupt such
        as Ctrl-C), and this call created the directory, the partial
        directory is removed before the error propagates. Otherwise a
        half-written backup would make ``exists()`` true and block retries
        under the same name.
        """
        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"
        metadata_path = self.path / "metadata.json"

        try:
            PostgresBackupClient(get_connection()).create_backup(
                backup_path,
                pg_dump=pg_dump,
            )

            metadata = {
                "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
                "source": source,
                "git_branch": get_git_branch(),
                "git_commit": get_git_commit(),
            }
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=2)
        except BaseException:
            # Only clean up a directory this call created; never touch a
            # pre-existing one.
            if created_dir:
                try:
                    self.delete()
                except OSError:
                    # Best effort: surface the original failure, not this one.
                    pass
            raise

        return self.path

    def restore(self, *, pg_restore: str = "pg_restore") -> None:
        """Restore the database from this backup's ``default.backup`` file."""
        backup_file = self.path / "default.backup"

        PostgresBackupClient(get_connection()).restore_backup(
            backup_file,
            pg_restore=pg_restore,
        )

    @staticmethod
    def _empty_metadata() -> dict[str, Any]:
        """Return the placeholder metadata used when none can be read."""
        return {
            "created_at": None,
            "source": None,
            "git_branch": None,
            "git_commit": None,
        }

    @property
    def metadata(self) -> dict[str, Any]:
        """Read ``metadata.json``, falling back to all-None placeholder values.

        The fallback is used for old backups without a metadata file, and for
        unreadable metadata: invalid JSON, invalid UTF-8, or a JSON value
        that is not an object. One damaged file therefore cannot break
        listing or pruning of all backups.
        """
        metadata_path = self.path / "metadata.json"
        if not metadata_path.exists():
            return self._empty_metadata()

        try:
            with open(metadata_path, encoding="utf-8") as f:
                data = json.load(f)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            return self._empty_metadata()

        if not isinstance(data, dict):
            return self._empty_metadata()
        return data

    def delete(self) -> None:
        """Remove the backup file, metadata file, and the backup directory.

        Missing files are ignored. ``Path.rmdir`` only removes empty
        directories, so an OSError is raised if the directory holds anything
        else.
        """
        backup_file = self.path / "default.backup"
        backup_file.unlink(missing_ok=True)
        metadata_file = self.path / "metadata.json"
        metadata_file.unlink(missing_ok=True)
        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the backup directory's modification time as a naive local datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)