Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import json
  5import os
  6import subprocess
  7from pathlib import Path
  8from typing import TYPE_CHECKING, Any, cast
  9
 10from plain.runtime import PLAIN_TEMP_PATH
 11
 12from .. import db_connection as _db_connection
 13from .clients import PostgresBackupClient, SQLiteBackupClient
 14
 15
 16def get_git_branch() -> str | None:
 17    """Get current git branch, or None if not in a git repo."""
 18    try:
 19        result = subprocess.run(
 20            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
 21            capture_output=True,
 22            text=True,
 23            check=True,
 24        )
 25        return result.stdout.strip()
 26    except (subprocess.CalledProcessError, FileNotFoundError):
 27        return None
 28
 29
 30def get_git_commit() -> str | None:
 31    """Get current git commit (short hash), or None if not in a git repo."""
 32    try:
 33        result = subprocess.run(
 34            ["git", "rev-parse", "--short", "HEAD"],
 35            capture_output=True,
 36            text=True,
 37            check=True,
 38        )
 39        return result.stdout.strip()
 40    except (subprocess.CalledProcessError, FileNotFoundError):
 41        return None
 42
 43
# Imported only for static type checking; TYPE_CHECKING is False at runtime,
# so this import is never executed.
if TYPE_CHECKING:
    from plain.models.backends.base.base import BaseDatabaseWrapper

# Cast for type checkers; runtime value is _db_connection (DatabaseConnection)
# cast() returns its argument unchanged, so db_connection is the very same
# object as _db_connection. The class is named as a string because
# BaseDatabaseWrapper is only imported under TYPE_CHECKING above.
db_connection = cast("BaseDatabaseWrapper", _db_connection)
 49
 50
 51class DatabaseBackups:
 52    def __init__(self) -> None:
 53        self.path = PLAIN_TEMP_PATH / "backups"
 54
 55    def find_backups(self) -> list[DatabaseBackup]:
 56        if not self.path.exists():
 57            return []
 58
 59        backups = []
 60
 61        for backup_dir in self.path.iterdir():
 62            backup = DatabaseBackup(backup_dir.name, backups_path=self.path)
 63            backups.append(backup)
 64
 65        # Sort backups by date
 66        backups.sort(key=lambda x: x.updated_at(), reverse=True)
 67
 68        return backups
 69
 70    def create(self, name: str, **create_kwargs: Any) -> Path:
 71        backup = DatabaseBackup(name, backups_path=self.path)
 72        if backup.exists():
 73            raise Exception(f"Backup {name} already exists")
 74        backup_dir = backup.create(**create_kwargs)
 75        return backup_dir
 76
 77    def restore(self, name: str, **restore_kwargs: Any) -> None:
 78        backup = DatabaseBackup(name, backups_path=self.path)
 79        if not backup.exists():
 80            raise Exception(f"Backup {name} not found")
 81        backup.restore(**restore_kwargs)
 82
 83    def delete(self, name: str) -> None:
 84        backup = DatabaseBackup(name, backups_path=self.path)
 85        if not backup.exists():
 86            raise Exception(f"Backup {name} not found")
 87        backup.delete()
 88
 89
 90class DatabaseBackup:
 91    def __init__(self, name: str, *, backups_path: Path) -> None:
 92        self.name = name
 93        self.path = backups_path / name
 94
 95        if not self.name:
 96            raise ValueError("Backup name is required")
 97
 98    def exists(self) -> bool:
 99        return self.path.exists()
100
101    def create(self, *, source: str = "manual", **create_kwargs: Any) -> Path:
102        self.path.mkdir(parents=True, exist_ok=True)
103
104        backup_path = self.path / "default.backup"
105
106        if db_connection.vendor == "postgresql":
107            PostgresBackupClient(db_connection).create_backup(
108                backup_path,
109                pg_dump=create_kwargs.get("pg_dump", "pg_dump"),
110            )
111        elif db_connection.vendor == "sqlite":
112            SQLiteBackupClient(db_connection).create_backup(backup_path)
113        else:
114            raise Exception("Unsupported database vendor")
115
116        # Write metadata
117        metadata = {
118            "created_at": datetime.datetime.now(datetime.UTC).isoformat(),
119            "source": source,
120            "git_branch": get_git_branch(),
121            "git_commit": get_git_commit(),
122        }
123        metadata_path = self.path / "metadata.json"
124        with open(metadata_path, "w") as f:
125            json.dump(metadata, f, indent=2)
126
127        return self.path
128
129    def restore(self, **restore_kwargs: Any) -> None:
130        backup_file = self.path / "default.backup"
131
132        if db_connection.vendor == "postgresql":
133            PostgresBackupClient(db_connection).restore_backup(
134                backup_file,
135                pg_restore=restore_kwargs.get("pg_restore", "pg_restore"),
136            )
137        elif db_connection.vendor == "sqlite":
138            SQLiteBackupClient(db_connection).restore_backup(backup_file)
139        else:
140            raise Exception("Unsupported database vendor")
141
142    @property
143    def metadata(self) -> dict[str, Any]:
144        """Read metadata from metadata.json, with fallback for old backups."""
145        metadata_path = self.path / "metadata.json"
146        if metadata_path.exists():
147            with open(metadata_path) as f:
148                return json.load(f)
149
150        return {
151            "created_at": None,
152            "source": None,
153            "git_branch": None,
154            "git_commit": None,
155        }
156
157    def delete(self) -> None:
158        backup_file = self.path / "default.backup"
159        backup_file.unlink()
160        metadata_file = self.path / "metadata.json"
161        if metadata_file.exists():
162            metadata_file.unlink()
163        self.path.rmdir()
164
165    def updated_at(self) -> datetime.datetime:
166        mtime = os.path.getmtime(self.path)
167        return datetime.datetime.fromtimestamp(mtime)