from __future__ import annotations

import gzip
import os
import shlex
import shutil
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
class PostgresBackupClient:
    """Create and restore PostgreSQL backups with pg_dump / pg_restore / psql.

    Each tool is given as a command string (for example
    ``"docker compose exec -T db pg_dump"``) that is split with
    :func:`shlex.split`, so the client tools may run inside a container.
    Archives travel through the tools' stdin/stdout and are gzip-(de)compressed
    in Python. No shell is involved, so paths and names containing spaces or
    shell metacharacters are passed through verbatim.
    """

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def get_env(self) -> dict[str, str]:
        """Return libpq environment variables derived from the connection settings.

        ``PASSWORD`` comes from the top-level settings and everything else from
        ``OPTIONS``. Only truthy values are included, converted to ``str``.
        """
        settings_dict = self.connection.settings_dict
        options = settings_dict.get("OPTIONS") or {}
        env: dict[str, str] = {}
        if options.get("passfile"):
            env["PGPASSFILE"] = str(options["passfile"])
        if settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(settings_dict["PASSWORD"])
        for option, variable in (
            ("service", "PGSERVICE"),
            ("sslmode", "PGSSLMODE"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
        ):
            if options.get(option):
                env[variable] = str(options[option])
        return env

    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
        """Dump the database to *backup_path* as a gzip-compressed custom-format archive.

        Args:
            backup_path: Destination file; overwritten if it already exists.
            pg_dump: Command used to run pg_dump, split with ``shlex.split``.

        Raises:
            subprocess.CalledProcessError: pg_dump exited with a non-zero status.
                The partially written *backup_path* is deleted before raising.
            OSError: pg_dump could not be started or the file could not be written.
        """
        args = shlex.split(pg_dump) + self._connection_args() + ["-Fc"]
        dbname = self._database_name()
        if dbname:
            # Without a name, the configured libpq service (PGSERVICE) selects it.
            args.append(dbname)

        backup_path = Path(backup_path)
        try:
            # A shell pipeline ("pg_dump | gzip > file") only reports gzip's exit
            # status, which hides pg_dump failures. Compress here instead and
            # check pg_dump's own status. Level 6 matches gzip(1)'s default.
            with gzip.open(backup_path, "wb", compresslevel=6) as archive:
                with subprocess.Popen(
                    args, stdout=subprocess.PIPE, env=self._subprocess_env()
                ) as proc:
                    shutil.copyfileobj(proc.stdout, archive)
            # Popen's context manager has waited for the process at this point.
            if proc.returncode != 0:
                raise subprocess.CalledProcessError(proc.returncode, args)
        except BaseException:
            # Never leave a truncated file behind that looks like a valid backup.
            backup_path.unlink(missing_ok=True)
            raise

    def restore_backup(
        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
    ) -> None:
        """Replace the database with the contents of a backup made by ``create_backup``.

        Other sessions on the database are terminated, and the database is
        dropped and recreated with psql connected to ``template1``. The
        decompressed archive is then streamed into pg_restore's stdin.

        Args:
            backup_path: gzip-compressed pg_dump custom-format archive.
            pg_restore: Command used to run pg_restore, split with ``shlex.split``.
            psql: Command used to run psql, split with ``shlex.split``.

        Raises:
            ValueError: NAME is empty while a service is configured, so there is
                no database name to drop and recreate.
            OSError: *backup_path* is missing or is not gzip data
                (e.g. ``FileNotFoundError``, ``gzip.BadGzipFile``). This is
                raised before the existing database is touched.
            subprocess.CalledProcessError: psql or pg_restore exited non-zero.
        """
        dbname = self._database_name()
        if not dbname:
            raise ValueError(
                "Restoring a PostgreSQL backup requires the NAME setting; "
                "the database is dropped and recreated by name."
            )
        conn_args = self._connection_args()
        env = self._subprocess_env()

        with gzip.open(backup_path, "rb") as archive:
            # Read the gzip header now, so that a missing or corrupt backup fails
            # before the current database is dropped.
            archive.peek(1)

            name_literal = self._quote_literal(dbname)
            name_ident = self._quote_identifier(dbname)
            # Run these from 'template1': a database cannot be dropped while
            # connected to it. This also covers restoring 'postgres' itself.
            for statement in (
                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity "
                f"WHERE datname = {name_literal} AND pid <> pg_backend_pid()",
                f"DROP DATABASE IF EXISTS {name_ident}",
                f"CREATE DATABASE {name_ident}",
            ):
                subprocess.run(
                    shlex.split(psql) + conn_args + ["-d", "template1", "-c", statement],
                    env=env,
                    check=True,
                )

            args = shlex.split(pg_restore) + conn_args + ["-d", dbname]
            with subprocess.Popen(args, stdin=subprocess.PIPE, env=env) as proc:
                self._feed_stdin(proc, archive)
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, args)

    def _database_name(self) -> str | None:
        """Return NAME, or ``"postgres"`` when neither NAME nor a service is set.

        Returns ``None`` when NAME is empty but a service is configured.
        """
        settings_dict = self.connection.settings_dict
        options = settings_dict.get("OPTIONS") or {}
        dbname = settings_dict.get("NAME")
        if not dbname and not options.get("service"):
            # Nothing selects a database: use the default 'postgres' db.
            return "postgres"
        return dbname or None

    def _connection_args(self) -> list[str]:
        """Return the -U/-h/-p flags shared by pg_dump, pg_restore and psql."""
        settings_dict = self.connection.settings_dict
        user = settings_dict.get("USER")
        host = settings_dict.get("HOST")
        port = settings_dict.get("PORT")
        args: list[str] = []
        if user:
            args += ["-U", user]
        if host:
            args += ["-h", host]
        if port:
            args += ["-p", str(port)]
        return args

    def _subprocess_env(self) -> dict[str, str]:
        """Return the current process environment overlaid with ``get_env()``."""
        return {**os.environ, **self.get_env()}

    @staticmethod
    def _feed_stdin(proc: subprocess.Popen, source: gzip.GzipFile) -> None:
        """Copy *source* into *proc*'s stdin, then close it.

        A BrokenPipeError means the child exited before reading everything. It
        is ignored here because the child's exit status reports the failure.
        """
        try:
            shutil.copyfileobj(source, proc.stdin)
        except BrokenPipeError:
            pass
        try:
            proc.stdin.close()
        except BrokenPipeError:
            # Flushing still-buffered data can hit the same closed pipe.
            pass

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote *name* as an SQL identifier, doubling embedded double quotes."""
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Quote *value* as an SQL string literal, like PostgreSQL's quote_literal()."""
        text = str(value)
        escaped = text.replace("'", "''")
        if "\\" in text:
            # An E'' literal treats backslash as an escape regardless of
            # standard_conforming_strings, so double it explicitly.
            return "E'" + escaped.replace("\\", "\\\\") + "'"
        return "'" + escaped + "'"
127
128
class SQLiteBackupClient:
    """Create and restore SQLite backups as gzip-compressed SQL text dumps."""

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def create_backup(self, backup_path: Path) -> None:
        """Write the database as SQL (``Connection.iterdump``) to *backup_path*.

        Statements are separated by newlines, and the text is stored as
        gzip-compressed UTF-8.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        # Use UTF-8 explicitly. The default text encoding depends on the locale,
        # which would make backups non-portable between machines.
        with gzip.open(backup_path, "wt", encoding="utf-8") as f:
            # Write statement by statement instead of building the whole dump
            # in memory. The output is identical to "\n".join(iterdump()).
            for index, statement in enumerate(src_conn.iterdump()):
                if index:
                    f.write("\n")
                f.write(statement)

    def restore_backup(self, backup_path: Path) -> None:
        """Replace the database contents with a dump written by ``create_backup``.

        Reopens the connection, drops every user view and table, and executes
        the dump. Foreign-key enforcement is switched off during the restore
        and switched back on afterwards if it was on before.
        """
        with gzip.open(backup_path, "rt", encoding="utf-8") as f:
            sql = f.read()

        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        foreign_keys_enabled = dest_conn.execute("PRAGMA foreign_keys").fetchone()[0]
        # With enforcement on, DROP TABLE does an implicit DELETE that fails on
        # rows other tables still reference. The dump's INSERTs are also not
        # ordered by foreign-key dependency. The pragma is a no-op inside a
        # transaction, so it must be set before the dump's BEGIN.
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            schema_objects = dest_conn.execute(
                "SELECT type, name FROM sqlite_master WHERE type IN ('view', 'table')"
            ).fetchall()
            # Views go first. The dump recreates them, so a leftover view would
            # make its CREATE VIEW fail.
            for kind, name in sorted(schema_objects, key=lambda row: row[0] != "view"):
                if name.startswith("sqlite_"):
                    # Reserved for SQLite's internal tables; leave them alone.
                    continue
                quoted_name = '"' + name.replace('"', '""') + '"'
                dest_conn.execute(f"DROP {kind.upper()} IF EXISTS {quoted_name}")
            dest_conn.executescript(sql)
            dest_conn.commit()
        except BaseException:
            # Abandon a half-applied dump so that the pragma below takes effect.
            if dest_conn.in_transaction:
                dest_conn.rollback()
            raise
        finally:
            if foreign_keys_enabled:
                dest_conn.execute("PRAGMA foreign_keys = ON")