Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import os
  5from collections.abc import Generator
  6from pathlib import Path
  7from typing import TYPE_CHECKING, Any, cast
  8
  9from plain.runtime import PLAIN_TEMP_PATH
 10
 11from .. import db_connection as _db_connection
 12from .clients import PostgresBackupClient, SQLiteBackupClient
 13
 14if TYPE_CHECKING:
 15    from plain.models.backends.base.base import BaseDatabaseWrapper
 16
 17    db_connection = cast("BaseDatabaseWrapper", _db_connection)
 18else:
 19    db_connection = _db_connection
 20
 21
 22class DatabaseBackups:
 23    def __init__(self) -> None:
 24        self.path = PLAIN_TEMP_PATH / "backups"
 25
 26    def find_backups(self) -> list[DatabaseBackup]:
 27        if not self.path.exists():
 28            return []
 29
 30        backups = []
 31
 32        for backup_dir in self.path.iterdir():
 33            backup = DatabaseBackup(backup_dir.name, backups_path=self.path)
 34            backups.append(backup)
 35
 36        # Sort backups by date
 37        backups.sort(key=lambda x: x.updated_at(), reverse=True)
 38
 39        return backups
 40
 41    def create(self, name: str, **create_kwargs: Any) -> Path:
 42        backup = DatabaseBackup(name, backups_path=self.path)
 43        if backup.exists():
 44            raise Exception(f"Backup {name} already exists")
 45        backup_dir = backup.create(**create_kwargs)
 46        return backup_dir
 47
 48    def restore(self, name: str, **restore_kwargs: Any) -> None:
 49        backup = DatabaseBackup(name, backups_path=self.path)
 50        if not backup.exists():
 51            raise Exception(f"Backup {name} not found")
 52        backup.restore(**restore_kwargs)
 53
 54    def delete(self, name: str) -> None:
 55        backup = DatabaseBackup(name, backups_path=self.path)
 56        if not backup.exists():
 57            raise Exception(f"Backup {name} not found")
 58        backup.delete()
 59
 60
 61class DatabaseBackup:
 62    def __init__(self, name: str, *, backups_path: Path) -> None:
 63        self.name = name
 64        self.path = backups_path / name
 65
 66        if not self.name:
 67            raise ValueError("Backup name is required")
 68
 69    def exists(self) -> bool:
 70        return self.path.exists()
 71
 72    def create(self, **create_kwargs: Any) -> Path:
 73        self.path.mkdir(parents=True, exist_ok=True)
 74
 75        backup_path = self.path / "default.backup"
 76
 77        if db_connection.vendor == "postgresql":
 78            PostgresBackupClient(db_connection).create_backup(
 79                backup_path,
 80                pg_dump=create_kwargs.get("pg_dump", "pg_dump"),
 81            )
 82        elif db_connection.vendor == "sqlite":
 83            SQLiteBackupClient(db_connection).create_backup(backup_path)
 84        else:
 85            raise Exception("Unsupported database vendor")
 86
 87        return self.path
 88
 89    def iter_files(self) -> Generator[Path, None, None]:
 90        for backup_file in self.path.iterdir():
 91            if not backup_file.is_file():
 92                continue
 93            if not backup_file.name.endswith(".backup"):
 94                continue
 95            yield backup_file
 96
 97    def restore(self, **restore_kwargs: Any) -> None:
 98        for backup_file in self.iter_files():
 99            if db_connection.vendor == "postgresql":
100                PostgresBackupClient(db_connection).restore_backup(
101                    backup_file,
102                    pg_restore=restore_kwargs.get("pg_restore", "pg_restore"),
103                )
104            elif db_connection.vendor == "sqlite":
105                SQLiteBackupClient(db_connection).restore_backup(backup_file)
106            else:
107                raise Exception("Unsupported database vendor")
108
109    def delete(self) -> None:
110        for backup_file in self.iter_files():
111            backup_file.unlink()
112
113        self.path.rmdir()
114
115    def updated_at(self) -> datetime.datetime:
116        mtime = os.path.getmtime(self.path)
117        return datetime.datetime.fromtimestamp(mtime)