from __future__ import annotations

import os
import subprocess
from pathlib import Path
from typing import IO, TYPE_CHECKING

from plain.exceptions import ImproperlyConfigured

if TYPE_CHECKING:
    from plain.postgres.connection import DatabaseConnection
12
13
class PostgresBackupClient:
    """Create and restore gzip-compressed ``pg_dump`` custom-format archives.

    Connection details (``DATABASE``, ``USER``, ``HOST``, ``PORT``,
    ``PASSWORD`` and selected ``OPTIONS``) are read from
    ``connection.settings_dict`` and handed to the PostgreSQL client tools
    either as command-line arguments or as ``PG*`` environment variables.
    """

    def __init__(self, connection: DatabaseConnection) -> None:
        self.connection = connection

    def get_env(self) -> dict[str, str]:
        """Return the ``PG*`` environment variables derived from settings.

        ``PASSWORD`` becomes ``PGPASSWORD`` and each supported ``OPTIONS`` key
        is mapped to its corresponding environment variable. Missing or empty
        values are skipped.
        """
        settings_dict = self.connection.settings_dict
        # ``OPTIONS`` may be present but explicitly None; treat that as empty.
        options = settings_dict.get("OPTIONS") or {}
        env: dict[str, str] = {}

        if password := settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(password)

        # Map OPTIONS keys to their corresponding environment variables.
        option_env_vars = {
            "passfile": "PGPASSFILE",
            "sslmode": "PGSSLMODE",
            "sslrootcert": "PGSSLROOTCERT",
            "sslcert": "PGSSLCERT",
            "sslkey": "PGSSLKEY",
        }
        for option_key, env_var in option_env_vars.items():
            if value := options.get(option_key):
                env[env_var] = str(value)

        return env

    def _get_conn_args(self) -> list[str]:
        """Build common connection CLI args (``-U``/``-h``/``-p``) from settings."""
        settings_dict = self.connection.settings_dict
        args: list[str] = []
        if user := settings_dict.get("USER"):
            args += ["-U", str(user)]
        if host := settings_dict.get("HOST"):
            args += ["-h", str(host)]
        if port := settings_dict.get("PORT"):
            args += ["-p", str(port)]
        return args

    def _get_dbname(self) -> str:
        """Return the configured database name.

        Raises:
            ImproperlyConfigured: if ``DATABASE`` is missing or empty.
        """
        dbname = self.connection.settings_dict.get("DATABASE")
        if not dbname:
            raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")
        return str(dbname)

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a SQL identifier, doubling embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Quote *value* as a SQL string literal, doubling embedded single quotes.

        Backslashes are not escaped; this relies on ``standard_conforming_strings``
        being on for the server session.
        """
        return "'" + value.replace("'", "''") + "'"

    def _run(self, cmd: str | list[str], *, shell: bool = False) -> None:
        """Run *cmd* with the settings-derived ``PG*`` variables added to the env.

        Raises:
            subprocess.CalledProcessError: if the command exits non-zero.
        """
        subprocess.run(
            cmd, env={**os.environ, **self.get_env()}, check=True, shell=shell
        )

    def _run_pipeline(
        self,
        producer: list[str],
        consumer: list[str],
        *,
        stdin: IO[bytes] | None = None,
        stdout: IO[bytes] | None = None,
    ) -> None:
        """Run ``producer | consumer`` without a shell and check BOTH exit codes.

        A shell pipeline reports only the last command's status, so a failing
        producer would go unnoticed. Here a non-zero exit from either process
        raises :class:`subprocess.CalledProcessError`.

        Args:
            producer: argv of the process writing into the pipe.
            consumer: argv of the process reading from the pipe.
            stdin: optional file the producer reads from (inherited if None).
            stdout: optional file the consumer writes to (inherited if None).
        """
        env = {**os.environ, **self.get_env()}
        producer_proc = subprocess.Popen(
            producer, env=env, stdin=stdin, stdout=subprocess.PIPE
        )
        try:
            consumer_proc = subprocess.Popen(
                consumer, env=env, stdin=producer_proc.stdout, stdout=stdout
            )
        except BaseException:
            # The consumer could not start; don't leave the producer running.
            producer_proc.kill()
            producer_proc.wait()
            raise
        finally:
            # Close the parent's copy of the pipe so the producer sees a broken
            # pipe (instead of blocking) if the consumer exits early.
            if producer_proc.stdout is not None:
                producer_proc.stdout.close()

        consumer_rc = consumer_proc.wait()
        producer_rc = producer_proc.wait()
        # Report the consumer first: when it fails, the producer can die of a
        # broken pipe, which is the less informative error.
        if consumer_rc != 0:
            raise subprocess.CalledProcessError(consumer_rc, consumer)
        if producer_rc != 0:
            raise subprocess.CalledProcessError(producer_rc, producer)

    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
        """Dump the configured database with ``pg_dump -Fc`` piped through gzip.

        Args:
            backup_path: destination file; created or truncated.
            pg_dump: command used to invoke pg_dump; split on whitespace, so it
                may carry extra leading words or arguments.

        Raises:
            ImproperlyConfigured: if ``DATABASE`` is not configured.
            subprocess.CalledProcessError: if pg_dump or gzip fails; the
                partially written backup file is removed in that case.
        """
        dbname = self._get_dbname()
        dump_cmd = pg_dump.split() + self._get_conn_args() + ["-Fc", dbname]

        backup_path = Path(backup_path)
        out = backup_path.open("wb")
        try:
            with out:
                # Pipe through gzip for compression (no shell: paths and names
                # are passed verbatim, never interpreted).
                self._run_pipeline(dump_cmd, ["gzip"], stdout=out)
        except BaseException:
            # Never leave a truncated file that looks like a valid backup.
            backup_path.unlink(missing_ok=True)
            raise

    def restore_backup(
        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
    ) -> None:
        """Replace the configured database with the contents of *backup_path*.

        This is destructive. Other sessions on the database are terminated, and
        the database is dropped and recreated (connecting via ``template1``).
        The gunzip'd archive is then piped into pg_restore.

        Args:
            backup_path: gzip-compressed archive produced by ``create_backup``.
            pg_restore: command used to invoke pg_restore (split on whitespace).
            psql: command used to invoke psql (split on whitespace).

        Raises:
            ImproperlyConfigured: if ``DATABASE`` is not configured.
            OSError: if *backup_path* cannot be opened (nothing is dropped).
            subprocess.CalledProcessError: if any client command fails.
        """
        dbname = self._get_dbname()
        conn_args = self._get_conn_args()

        # Open the archive BEFORE dropping anything, so a missing or unreadable
        # file cannot leave us with an empty database.
        with open(backup_path, "rb") as archive:
            # Drop and recreate the database via template1.
            drop_create_cmds = [
                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
                f"WHERE datname = {self._quote_literal(dbname)} "
                "AND pid <> pg_backend_pid()",
                f"DROP DATABASE IF EXISTS {self._quote_ident(dbname)}",
                f"CREATE DATABASE {self._quote_ident(dbname)}",
            ]
            for sql in drop_create_cmds:
                self._run(psql.split() + conn_args + ["-d", "template1", "-c", sql])

            # Restore into the fresh database, decompressing with gunzip.
            restore_cmd = pg_restore.split() + conn_args + ["-d", dbname]
            self._run_pipeline(["gunzip"], restore_cmd, stdin=archive)