1from __future__ import annotations
  2
  3import datetime
  4import json
  5import os
  6import subprocess
  7from pathlib import Path
  8from typing import Any
  9
 10from plain.runtime import PLAIN_TEMP_PATH
 11
 12from ..db import get_connection
 13from .clients import PostgresBackupClient
 14
 15
 16def get_git_branch() -> str | None:
 17    """Get current git branch, or None if not in a git repo."""
 18    try:
 19        result = subprocess.run(
 20            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
 21            capture_output=True,
 22            text=True,
 23            check=True,
 24        )
 25        return result.stdout.strip()
 26    except (subprocess.CalledProcessError, FileNotFoundError):
 27        return None
 28
 29
 30def get_git_commit() -> str | None:
 31    """Get current git commit (short hash), or None if not in a git repo."""
 32    try:
 33        result = subprocess.run(
 34            ["git", "rev-parse", "--short", "HEAD"],
 35            capture_output=True,
 36            text=True,
 37            check=True,
 38        )
 39        return result.stdout.strip()
 40    except (subprocess.CalledProcessError, FileNotFoundError):
 41        return None
 42
 43
 44class DatabaseBackups:
 45    def __init__(self) -> None:
 46        self.path = PLAIN_TEMP_PATH / "backups"
 47
 48    def find_backups(self) -> list[DatabaseBackup]:
 49        if not self.path.exists():
 50            return []
 51
 52        backups = []
 53
 54        for backup_dir in self.path.iterdir():
 55            backup = DatabaseBackup(backup_dir.name, backups_path=self.path)
 56            backups.append(backup)
 57
 58        # Sort backups by date
 59        backups.sort(key=lambda x: x.updated_at(), reverse=True)
 60
 61        return backups
 62
 63    def create(
 64        self, name: str, *, source: str = "manual", pg_dump: str = "pg_dump"
 65    ) -> Path:
 66        backup = DatabaseBackup(name, backups_path=self.path)
 67        if backup.exists():
 68            raise Exception(f"Backup {name} already exists")
 69        backup_dir = backup.create(source=source, pg_dump=pg_dump)
 70        try:
 71            self.prune()
 72        except Exception:
 73            pass
 74        return backup_dir
 75
 76    def prune(self) -> list[str]:
 77        """Delete oldest backups on the current branch (or with no branch), keeping the most recent 20."""
 78        keep = 20
 79        current_branch = get_git_branch()
 80        backups = self.find_backups()  # sorted newest-first
 81
 82        # Only prune backups matching the current branch or with no branch metadata
 83        prunable = [
 84            b for b in backups if b.metadata.get("git_branch") in (current_branch, None)
 85        ]
 86
 87        deleted = []
 88        for backup in prunable[keep:]:
 89            backup.delete()
 90            deleted.append(backup.name)
 91        return deleted
 92
 93    def restore(self, name: str, *, pg_restore: str = "pg_restore") -> None:
 94        backup = DatabaseBackup(name, backups_path=self.path)
 95        if not backup.exists():
 96            raise Exception(f"Backup {name} not found")
 97        backup.restore(pg_restore=pg_restore)
 98
 99    def delete(self, name: str) -> None:
100        backup = DatabaseBackup(name, backups_path=self.path)
101        if not backup.exists():
102            raise Exception(f"Backup {name} not found")
103        backup.delete()
104
105
class DatabaseBackup:
    """A single named backup stored as a directory under a backups root.

    The directory holds the dump file (``default.backup``) and a
    ``metadata.json`` file recording when and where the backup was taken.
    """

    def __init__(self, name: str, *, backups_path: Path) -> None:
        if not name:
            raise ValueError("Backup name is required")

        self.name = name
        self.path = backups_path / name

    def exists(self) -> bool:
        """Return whether this backup's directory exists on disk."""
        return self.path.exists()

    def create(self, *, source: str = "manual", pg_dump: str = "pg_dump") -> Path:
        """Dump the database into this backup's directory and record metadata.

        Writes ``default.backup`` via ``PostgresBackupClient`` and then
        ``metadata.json``, which holds the creation time in UTC, ``source``,
        and the git branch and commit. Returns the backup directory.

        If any step fails and this call created the directory, the partial
        backup is removed. Otherwise it would be listed as a broken backup and
        block a retry under the same name. The original exception is always
        re-raised.
        """
        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        try:
            PostgresBackupClient(get_connection()).create_backup(
                self.path / "default.backup",
                pg_dump=pg_dump,
            )

            metadata = {
                "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
                "source": source,
                "git_branch": get_git_branch(),
                "git_commit": get_git_commit(),
            }
            metadata_path = self.path / "metadata.json"
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=2)
        except BaseException:
            # Never remove a directory that existed before this call.
            if created_dir:
                try:
                    self.delete()
                except OSError:
                    # Cleanup is best-effort; don't mask the original error.
                    pass
            raise

        return self.path

    def restore(self, *, pg_restore: str = "pg_restore") -> None:
        """Restore the database from this backup's ``default.backup`` file."""
        backup_file = self.path / "default.backup"

        PostgresBackupClient(get_connection()).restore_backup(
            backup_file,
            pg_restore=pg_restore,
        )

    @property
    def metadata(self) -> dict[str, Any]:
        """Return this backup's metadata.

        The result always contains the keys ``created_at``, ``source``,
        ``git_branch`` and ``git_commit``. Values come from ``metadata.json``
        where present and default to ``None`` otherwise. Any extra keys in the
        file are passed through unchanged.

        If the file is missing, all values are ``None``; older backups predate
        the metadata file. The same happens if the file is not valid JSON
        (e.g. truncated), so a damaged file can't break listing or pruning.
        """
        metadata: dict[str, Any] = {
            "created_at": None,
            "source": None,
            "git_branch": None,
            "git_commit": None,
        }
        metadata_path = self.path / "metadata.json"
        try:
            with open(metadata_path, encoding="utf-8") as f:
                stored = json.load(f)
        except (FileNotFoundError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return metadata

        if isinstance(stored, dict):
            metadata.update(stored)
        return metadata

    def delete(self) -> None:
        """Remove the dump and metadata files, then the (now empty) directory.

        ``rmdir`` raises if the directory still contains other files.
        """
        backup_file = self.path / "default.backup"
        backup_file.unlink(missing_ok=True)
        metadata_file = self.path / "metadata.json"
        metadata_file.unlink(missing_ok=True)
        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the directory's modification time as a naive local datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)