1from __future__ import annotations
2
import gzip
import os
import shutil
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
class PostgresBackupClient:
    """Back up and restore a PostgreSQL database with the PostgreSQL client tools.

    Backups are ``pg_dump`` custom-format archives (``-Fc``) compressed with
    gzip. Each client command (``pg_dump``, ``pg_restore``, ``psql``) is given
    as a whitespace-separated string (split with ``str.split``), so a wrapper
    prefix such as ``"docker compose exec -T db pg_dump"`` can be used. Archive
    data travels over the subprocess's stdin/stdout, never through a file path
    the command itself would have to see.
    """

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def get_env(self) -> dict[str, str]:
        """Return libpq environment variables derived from the connection settings.

        Only settings that are present and truthy are included. Every value is
        converted with ``str()``.
        """
        settings_dict = self.connection.settings_dict
        options = settings_dict.get("OPTIONS", {})
        env: dict[str, str] = {}
        if options.get("passfile"):
            env["PGPASSFILE"] = str(options.get("passfile"))
        if settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(settings_dict.get("PASSWORD"))
        # The remaining OPTIONS map one-to-one onto libpq variables.
        for option, variable in (
            ("service", "PGSERVICE"),
            ("sslmode", "PGSSLMODE"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
        ):
            value = options.get(option)
            if value:
                env[variable] = str(value)
        return env

    def _subprocess_env(self) -> dict[str, str]:
        """Return the current process environment overlaid with ``get_env()``."""
        return {**os.environ, **self.get_env()}

    def _connection_args(self) -> list[str]:
        """Return the ``-U``/``-h``/``-p`` flags shared by every client command."""
        settings_dict = self.connection.settings_dict
        user = settings_dict.get("USER")
        host = settings_dict.get("HOST")
        port = settings_dict.get("PORT")
        args: list[str] = []
        if user:
            args += ["-U", str(user)]
        if host:
            args += ["-h", str(host)]
        if port:
            args += ["-p", str(port)]
        return args

    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
        """Dump the database to *backup_path* as a gzip-compressed custom-format archive.

        If neither ``NAME`` nor the ``service`` option is configured, the
        ``postgres`` database is dumped.

        Raises:
            subprocess.CalledProcessError: if ``pg_dump`` exits non-zero. In that
                case *backup_path* may contain an incomplete archive.
            OSError: if ``pg_dump`` cannot be started or the file cannot be written.
        """
        settings_dict = self.connection.settings_dict
        options = settings_dict.get("OPTIONS", {})

        dbname = settings_dict.get("NAME")
        if not dbname and not options.get("service"):
            # Connect to the default 'postgres' db.
            dbname = "postgres"

        args = pg_dump.split() + self._connection_args() + ["-Fc"]
        if dbname:
            args.append(str(dbname))

        # Compress pg_dump's stdout in-process instead of running a shell
        # pipeline. A pipeline's exit status is that of its last command
        # (gzip), which hid pg_dump failures. The unquoted shell string also
        # broke on paths containing spaces or shell metacharacters.
        with gzip.open(backup_path, "wb") as archive:
            with subprocess.Popen(
                args, stdout=subprocess.PIPE, env=self._subprocess_env()
            ) as proc:
                assert proc.stdout is not None  # narrows the type; PIPE was requested
                shutil.copyfileobj(proc.stdout, archive)
        # Popen's context manager has waited for the process, so returncode is set.
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, args)

    def _recreate_database(
        self, dbname: str, conn_args: list[str], env: dict[str, str], psql: str
    ) -> None:
        """Terminate other sessions on *dbname*, then drop and recreate it with psql."""
        # Escape for use inside a SQL string literal and a quoted identifier.
        literal = dbname.replace("'", "''")
        identifier = dbname.replace('"', '""')
        statements = [
            f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{literal}' AND pid <> pg_backend_pid()",
            f'DROP DATABASE IF EXISTS "{identifier}"',
            f'CREATE DATABASE "{identifier}"',
        ]
        psql_args = psql.split() + conn_args
        for statement in statements:
            # Connect to 'template1' rather than the target so the target can
            # be dropped. This also works when the target is 'postgres'.
            subprocess.run(
                [*psql_args, "-d", "template1", "-c", statement],
                env=env,
                check=True,
            )

    def restore_backup(
        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
    ) -> None:
        """Replace the database with the contents of a backup from ``create_backup``.

        The backup file is opened first. The database is then dropped (after
        terminating other sessions on it) and recreated through ``psql``
        connected to ``template1``. Finally the decompressed archive is
        streamed into ``pg_restore`` on stdin.

        Raises:
            ValueError: if the ``NAME`` setting is missing or empty.
            OSError: if *backup_path* cannot be opened. This is raised before
                the database is touched.
            subprocess.CalledProcessError: if ``psql`` or ``pg_restore`` fails.
        """
        dbname = self.connection.settings_dict.get("NAME")
        if not dbname:
            raise ValueError("Database NAME is required in settings")
        dbname = str(dbname)

        conn_args = self._connection_args()
        env = self._subprocess_env()
        args = pg_restore.split() + conn_args + ["-d", dbname]

        # Open the backup before anything destructive happens, so a missing or
        # unreadable file cannot leave the database dropped.
        with open(backup_path, "rb") as raw:
            self._recreate_database(dbname, conn_args, env, psql)
            with gzip.GzipFile(fileobj=raw, mode="rb") as archive:
                with subprocess.Popen(args, stdin=subprocess.PIPE, env=env) as proc:
                    assert proc.stdin is not None  # narrows the type; PIPE was requested
                    try:
                        shutil.copyfileobj(archive, proc.stdin)
                        proc.stdin.close()  # signal EOF to pg_restore
                    except BrokenPipeError:
                        # pg_restore exited before reading all input. Its exit
                        # status, checked below, reports the real failure.
                        pass
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, args)
128
129
class SQLiteBackupClient:
    """Back up and restore a SQLite database as a gzip-compressed SQL text dump."""

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def create_backup(self, backup_path: Path) -> None:
        """Write the database's ``iterdump()`` SQL to *backup_path*, gzip-compressed.

        Statements are separated by newlines and encoded as UTF-8.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        # Use explicit UTF-8. Text mode otherwise uses the locale encoding, so a
        # dump could fail to encode or be unreadable on a machine with a
        # different locale.
        with gzip.open(backup_path, "wt", encoding="utf-8") as f:
            # Stream statements rather than joining the whole dump in memory.
            # The bytes written match "\n".join(iterdump()).
            for index, statement in enumerate(src_conn.iterdump()):
                if index:
                    f.write("\n")
                f.write(statement)

    def restore_backup(self, backup_path: Path) -> None:
        """Replace the database contents with a dump written by ``create_backup``.

        The method reconnects, drops every user view and table (objects named
        ``sqlite_*`` are skipped), and then executes the dump. Foreign-key
        enforcement is switched off while this runs and restored afterwards if
        it was on. Otherwise, dropping referenced tables and re-inserting rows
        in dump order can fail constraint checks.
        """
        with gzip.open(backup_path, "rt", encoding="utf-8") as f:
            sql = f.read()

        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        # NOTE(review): PRAGMA foreign_keys is a no-op inside an open
        # transaction. This assumes the freshly connected wrapper has none open.
        fk_enabled = dest_conn.execute("PRAGMA foreign_keys").fetchone()[0]
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            # Drop views as well as tables: the dump recreates views, and
            # leftovers would make its CREATE VIEW statements fail.
            for kind in ("view", "table"):
                rows = dest_conn.execute(
                    "SELECT name FROM sqlite_master WHERE type = ?", (kind,)
                ).fetchall()
                for (name,) in rows:
                    if name.startswith("sqlite_"):
                        continue
                    quoted = name.replace('"', '""')
                    dest_conn.execute(f'DROP {kind.upper()} IF EXISTS "{quoted}"')
            dest_conn.executescript(sql)
            dest_conn.commit()
        finally:
            if fk_enabled:
                dest_conn.execute("PRAGMA foreign_keys = ON")