1from __future__ import annotations
  2
  3import datetime
  4import json
  5import os
  6import subprocess
  7from pathlib import Path
  8from typing import Any
  9
 10from plain.runtime import PLAIN_TEMP_PATH
 11
 12from .clients import PostgresBackupClient
 13
 14
 15def get_git_branch() -> str | None:
 16    """Get current git branch, or None if not in a git repo."""
 17    try:
 18        result = subprocess.run(
 19            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
 20            capture_output=True,
 21            text=True,
 22            check=True,
 23        )
 24        return result.stdout.strip()
 25    except (subprocess.CalledProcessError, FileNotFoundError):
 26        return None
 27
 28
 29def get_git_commit() -> str | None:
 30    """Get current git commit (short hash), or None if not in a git repo."""
 31    try:
 32        result = subprocess.run(
 33            ["git", "rev-parse", "--short", "HEAD"],
 34            capture_output=True,
 35            text=True,
 36            check=True,
 37        )
 38        return result.stdout.strip()
 39    except (subprocess.CalledProcessError, FileNotFoundError):
 40        return None
 41
 42
 43class DatabaseBackups:
 44    def __init__(self) -> None:
 45        self.path = PLAIN_TEMP_PATH / "backups"
 46
 47    def find_backups(self) -> list[DatabaseBackup]:
 48        if not self.path.exists():
 49            return []
 50
 51        backups = []
 52
 53        for backup_dir in self.path.iterdir():
 54            backup = DatabaseBackup(backup_dir.name, backups_path=self.path)
 55            backups.append(backup)
 56
 57        # Sort backups by date
 58        backups.sort(key=lambda x: x.updated_at(), reverse=True)
 59
 60        return backups
 61
 62    def create(
 63        self, name: str, *, source: str = "manual", pg_dump: str = "pg_dump"
 64    ) -> Path:
 65        backup = DatabaseBackup(name, backups_path=self.path)
 66        if backup.exists():
 67            raise Exception(f"Backup {name} already exists")
 68        backup_dir = backup.create(source=source, pg_dump=pg_dump)
 69        try:
 70            self.prune()
 71        except Exception:
 72            pass
 73        return backup_dir
 74
 75    def prune(self) -> list[str]:
 76        """Delete oldest backups on the current branch (or with no branch), keeping the most recent 20."""
 77        keep = 20
 78        current_branch = get_git_branch()
 79        backups = self.find_backups()  # sorted newest-first
 80
 81        # Only prune backups matching the current branch or with no branch metadata
 82        prunable = [
 83            b for b in backups if b.metadata.get("git_branch") in (current_branch, None)
 84        ]
 85
 86        deleted = []
 87        for backup in prunable[keep:]:
 88            backup.delete()
 89            deleted.append(backup.name)
 90        return deleted
 91
 92    def restore(self, name: str, *, pg_restore: str = "pg_restore") -> None:
 93        backup = DatabaseBackup(name, backups_path=self.path)
 94        if not backup.exists():
 95            raise Exception(f"Backup {name} not found")
 96        backup.restore(pg_restore=pg_restore)
 97
 98    def delete(self, name: str) -> None:
 99        backup = DatabaseBackup(name, backups_path=self.path)
100        if not backup.exists():
101            raise Exception(f"Backup {name} not found")
102        backup.delete()
103
104
class DatabaseBackup:
    """A single backup stored in the directory ``backups_path / name``.

    The directory contains ``default.backup`` (written through
    ``PostgresBackupClient.create_backup``) and ``metadata.json``.
    """

    def __init__(self, name: str, *, backups_path: Path) -> None:
        """Bind this object to ``backups_path / name``; nothing is read or written.

        Raises:
            ValueError: if *name* is empty.
        """
        self.name = name
        self.path = backups_path / name

        if not self.name:
            raise ValueError("Backup name is required")

    def exists(self) -> bool:
        """Return whether the backup directory exists."""
        return self.path.exists()

    def create(self, *, source: str = "manual", pg_dump: str = "pg_dump") -> Path:
        """Dump the database into this backup's directory and write metadata.

        Args:
            source: Free-form label stored in ``metadata.json``.
            pg_dump: Name or path of the ``pg_dump`` executable to use.

        Returns:
            The backup directory.

        If the dump fails (or is interrupted) and this call created the
        directory, the partial dump file and directory are removed before the
        error is re-raised, so a broken backup is not left behind to appear in
        listings or block re-using the name.
        """
        from plain.postgres.db import get_connection

        created_dir = not self.path.exists()
        self.path.mkdir(parents=True, exist_ok=True)

        backup_path = self.path / "default.backup"

        try:
            PostgresBackupClient(get_connection()).create_backup(
                backup_path,
                pg_dump=pg_dump,
            )
        except BaseException:
            # Only clean up what this call created; never touch a directory
            # that existed beforehand.
            if created_dir:
                backup_path.unlink(missing_ok=True)
                try:
                    self.path.rmdir()
                except OSError:
                    pass  # Something else was written there; leave it in place.
            raise

        # Write metadata alongside the dump.
        metadata = {
            "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
            "source": source,
            "git_branch": get_git_branch(),
            "git_commit": get_git_commit(),
        }
        metadata_path = self.path / "metadata.json"
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        return self.path

    def restore(self, *, pg_restore: str = "pg_restore") -> None:
        """Restore ``default.backup`` via ``PostgresBackupClient.restore_backup``."""
        from plain.postgres.db import get_connection

        backup_file = self.path / "default.backup"

        PostgresBackupClient(get_connection()).restore_backup(
            backup_file,
            pg_restore=pg_restore,
        )

    @property
    def metadata(self) -> dict[str, Any]:
        """Read ``metadata.json``, falling back to ``None`` values.

        The keys ``created_at``, ``source``, ``git_branch`` and ``git_commit``
        are always present: they default to ``None`` for backups without a
        metadata file (older backups), with an unreadable/corrupt one, or with
        one that lacks some keys. Extra keys in the file are kept.
        """
        defaults: dict[str, Any] = {
            "created_at": None,
            "source": None,
            "git_branch": None,
            "git_commit": None,
        }
        metadata_path = self.path / "metadata.json"
        try:
            with open(metadata_path, encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            return defaults
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError: e.g. a truncated write.
            # Degrade like a backup without metadata rather than breaking
            # listing and pruning of every backup.
            return defaults

        if not isinstance(loaded, dict):
            return defaults
        return {**defaults, **loaded}

    def delete(self) -> None:
        """Remove the dump file, the metadata file, and the (then empty) directory.

        ``rmdir`` raises ``OSError`` if the directory still contains other
        files, so unknown content is never deleted.
        """
        backup_file = self.path / "default.backup"
        backup_file.unlink(missing_ok=True)
        metadata_file = self.path / "metadata.json"
        metadata_file.unlink(missing_ok=True)
        self.path.rmdir()

    def updated_at(self) -> datetime.datetime:
        """Return the backup directory's mtime as a naive local-time datetime."""
        mtime = os.path.getmtime(self.path)
        return datetime.datetime.fromtimestamp(mtime)