1from __future__ import annotations
 2
import os
import signal
import subprocess
from pathlib import Path
from typing import IO, TYPE_CHECKING

from plain.exceptions import ImproperlyConfigured
 9
10if TYPE_CHECKING:
11    from plain.postgres.connection import DatabaseConnection
12
13
class PostgresBackupClient:
    """Create and restore gzip-compressed ``pg_dump`` backups with the Postgres CLI tools.

    Connection details come from ``connection.settings_dict``: ``USER``,
    ``HOST`` and ``PORT`` become command-line flags, while ``PASSWORD`` and a
    few ``OPTIONS`` entries become ``PG*`` environment variables for the child
    processes.
    """

    def __init__(self, connection: DatabaseConnection) -> None:
        self.connection = connection

    def get_env(self) -> dict[str, str]:
        """Return the ``PG*`` environment variables derived from the settings.

        ``PASSWORD`` maps to ``PGPASSWORD`` and the ``passfile``/``ssl*``
        ``OPTIONS`` keys map to their ``PG*`` counterparts. Missing or empty
        values are left out.
        """
        settings_dict = self.connection.settings_dict
        # `or {}` also covers an explicit ``"OPTIONS": None`` in settings.
        options = settings_dict.get("OPTIONS") or {}
        env: dict[str, str] = {}

        if password := settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(password)

        # Map OPTIONS keys to their corresponding environment variables.
        option_env_vars = {
            "passfile": "PGPASSFILE",
            "sslmode": "PGSSLMODE",
            "sslrootcert": "PGSSLROOTCERT",
            "sslcert": "PGSSLCERT",
            "sslkey": "PGSSLKEY",
        }
        for option_key, env_var in option_env_vars.items():
            if value := options.get(option_key):
                env[env_var] = str(value)

        return env

    def _get_conn_args(self) -> list[str]:
        """Build the ``-U``/``-h``/``-p`` connection flags shared by the client tools."""
        settings_dict = self.connection.settings_dict
        args: list[str] = []
        if user := settings_dict.get("USER"):
            args += ["-U", str(user)]
        if host := settings_dict.get("HOST"):
            args += ["-h", str(host)]
        if port := settings_dict.get("PORT"):
            args += ["-p", str(port)]
        return args

    def _subprocess_env(self) -> dict[str, str]:
        """Return the current process environment overlaid with :meth:`get_env`."""
        return {**os.environ, **self.get_env()}

    def _run(self, cmd: str | list[str], *, shell: bool = False) -> None:
        """Run *cmd* with the connection environment; raise if it exits non-zero.

        ``shell`` is kept for backward compatibility; the methods of this class
        call it with an argv list and no shell.

        Raises:
            subprocess.CalledProcessError: If the command exits non-zero.
        """
        subprocess.run(cmd, env=self._subprocess_env(), check=True, shell=shell)

    def _run_pipeline(
        self,
        producer: list[str],
        consumer: list[str],
        *,
        stdin: IO[bytes] | None = None,
        stdout: IO[bytes] | None = None,
    ) -> None:
        """Run ``producer | consumer`` without a shell.

        *stdin* feeds the producer and *stdout* receives the consumer's output.
        Unlike ``sh -c "producer | consumer"`` -- whose exit status is only the
        consumer's -- a failure of either process raises. The consumer is
        checked first, so its error is the one reported when both fail.

        Raises:
            subprocess.CalledProcessError: If either process exits unsuccessfully.
        """
        env = self._subprocess_env()
        with subprocess.Popen(
            producer, stdin=stdin, stdout=subprocess.PIPE, env=env
        ) as producer_proc:
            try:
                consumer_result = subprocess.run(
                    consumer, stdin=producer_proc.stdout, stdout=stdout, env=env
                )
            except BaseException:
                # The consumer could not be run (or we were interrupted):
                # don't leave the producer blocked on a pipe nobody reads.
                producer_proc.kill()
                raise
        # Leaving the ``with`` closed our end of the pipe and waited for the producer.
        consumer_result.check_returncode()
        # A producer killed by SIGPIPE only stopped because the (successful)
        # consumer quit reading early; that is not a failure. On platforms
        # without SIGPIPE this set is just {0}.
        benign_codes = {0, -getattr(signal, "SIGPIPE", 0)}
        if producer_proc.returncode not in benign_codes:
            raise subprocess.CalledProcessError(producer_proc.returncode, producer)

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Quote *value* as an SQL string literal by doubling embedded ``'``.

        Assumes ``standard_conforming_strings`` is on (the server default), so
        backslashes need no escaping.
        """
        return "'" + value.replace("'", "''") + "'"

    @staticmethod
    def _quote_ident(value: str) -> str:
        """Quote *value* as an SQL identifier by doubling embedded ``"``."""
        return '"' + value.replace('"', '""') + '"'

    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
        """Write a gzip-compressed custom-format dump of the database to *backup_path*.

        Args:
            backup_path: Destination file; created or truncated.
            pg_dump: Command that runs ``pg_dump``. It is split on whitespace,
                so it may carry a wrapper command or extra leading arguments
                (shell syntax such as quotes or redirects is not interpreted).

        Raises:
            ImproperlyConfigured: If the ``DATABASE`` setting is empty or missing.
            subprocess.CalledProcessError: If ``pg_dump`` or ``gzip`` fails.
        """
        settings_dict = self.connection.settings_dict
        dbname = settings_dict.get("DATABASE")
        if not dbname:
            raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")

        # -Fc selects pg_dump's custom archive format, which pg_restore reads.
        dump_cmd = pg_dump.split() + self._get_conn_args() + ["-Fc", str(dbname)]

        # Pipe through gzip for compression. Running the pipeline without a
        # shell keeps paths and names from being shell-interpreted, and lets a
        # pg_dump failure surface instead of being hidden behind gzip's status.
        with open(backup_path, "wb") as backup_file:
            self._run_pipeline(dump_cmd, ["gzip"], stdout=backup_file)

    def restore_backup(
        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
    ) -> None:
        """Replace the database with the contents of a gzip-compressed dump.

        Other sessions on the database are terminated, the database is dropped
        and recreated while connected to ``template1``, and the gunzipped
        archive is streamed into ``pg_restore``. ``pg_restore`` and ``psql`` are
        split on whitespace, like ``pg_dump`` in :meth:`create_backup`.

        The backup file is opened *before* anything is dropped, so a missing or
        unreadable file leaves the database untouched.

        Raises:
            ImproperlyConfigured: If the ``DATABASE`` setting is empty or missing.
            OSError: If *backup_path* cannot be opened for reading.
            subprocess.CalledProcessError: If ``psql``, ``gunzip`` or
                ``pg_restore`` fails.
        """
        settings_dict = self.connection.settings_dict
        dbname = settings_dict.get("DATABASE")
        if not dbname:
            raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")
        dbname = str(dbname)

        conn_args = self._get_conn_args()
        db_literal = self._quote_literal(dbname)
        db_ident = self._quote_ident(dbname)

        # Drop and recreate the database via template1 (a database cannot be
        # dropped while connected to it).
        drop_create_cmds = [
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
            f"WHERE datname = {db_literal} AND pid <> pg_backend_pid()",
            f"DROP DATABASE IF EXISTS {db_ident}",
            f"CREATE DATABASE {db_ident}",
        ]
        restore_cmd = pg_restore.split() + conn_args + ["-d", dbname]

        # Open first: a bad path must fail before the database is dropped.
        with open(backup_path, "rb") as backup_file:
            for sql in drop_create_cmds:
                self._run(psql.split() + conn_args + ["-d", "template1", "-c", sql])
            # Pipe through gunzip for decompression, then into pg_restore.
            self._run_pipeline(["gunzip"], restore_cmd, stdin=backup_file)