import gzip
import os
import shlex
import shutil
import subprocess
4
5
class PostgresBackupClient:
    """Create and restore PostgreSQL backups with ``pg_dump`` / ``pg_restore``.

    ``connection`` must expose a ``settings_dict`` mapping with the keys read
    below (``NAME``, ``USER``, ``PASSWORD``, ``HOST``, ``PORT``, ``OPTIONS``);
    presumably a Django database connection -- confirm against callers.

    The ``pg_dump`` / ``pg_restore`` / ``psql`` arguments are command lines
    split with POSIX shell quoting rules (no shell is run). Backup data only
    travels over the tool's stdin/stdout and is (de)compressed in Python, so
    the executables can also run inside a container (e.g. behind a
    ``docker exec ...`` prefix).
    """

    # OPTIONS keys forwarded to libpq as environment variables.
    _OPTION_ENV_VARS = {
        "passfile": "PGPASSFILE",
        "service": "PGSERVICE",
        "sslmode": "PGSSLMODE",
        "sslrootcert": "PGSSLROOTCERT",
        "sslcert": "PGSSLCERT",
        "sslkey": "PGSSLKEY",
    }

    def __init__(self, connection):
        self.connection = connection

    def get_env(self):
        """Return the libpq environment variables implied by the settings.

        Only settings with a truthy value are included, converted to ``str``:
        ``PASSWORD`` becomes ``PGPASSWORD`` and the OPTIONS keys listed in
        ``_OPTION_ENV_VARS`` become their ``PG*`` counterparts.
        """
        settings_dict = self.connection.settings_dict
        options = settings_dict.get("OPTIONS", {})
        env = {
            env_var: str(options[key])
            for key, env_var in self._OPTION_ENV_VARS.items()
            if options.get(key)
        }
        if settings_dict.get("PASSWORD"):
            env["PGPASSWORD"] = str(settings_dict["PASSWORD"])
        return env

    def create_backup(self, backup_path, *, pg_dump="pg_dump"):
        """Dump the database in custom format (``-Fc``), gzip-compressed, to *backup_path*.

        Args:
            backup_path: Destination file (str or path-like); overwritten if present.
            pg_dump: Command line used to run pg_dump.

        Raises:
            subprocess.CalledProcessError: pg_dump exited with a non-zero status.
        """
        args = shlex.split(pg_dump) + self._connection_args() + ["-Fc"]
        dbname = self._database_name()
        if dbname:
            args.append(dbname)

        # Stream pg_dump's stdout straight into the gzip file. A shell pipeline
        # (pg_dump | gzip) would only report gzip's exit status and hide a
        # failed dump; here pg_dump's own status is checked.
        with subprocess.Popen(
            args, stdout=subprocess.PIPE, env=self._subprocess_env()
        ) as proc, gzip.open(backup_path, "wb") as out:
            shutil.copyfileobj(proc.stdout, out)
        if proc.returncode:
            raise subprocess.CalledProcessError(proc.returncode, args)

    def restore_backup(self, backup_path, *, pg_restore="pg_restore", psql="psql"):
        """Drop and recreate the database, then restore *backup_path* into it.

        The terminate/DROP/CREATE statements run through *psql* connected to
        the ``template1`` database, so any target database (including
        ``postgres``) can be dropped. Other sessions on the target are
        terminated first.

        Args:
            backup_path: gzip-compressed pg_dump custom-format file, as written
                by ``create_backup``.
            pg_restore: Command line used to run pg_restore.
            psql: Command line used to run psql.

        Raises:
            ValueError: NAME is empty but OPTIONS['service'] is set, so the
                database to drop/recreate cannot be determined.
            OSError: *backup_path* is missing or is not gzip data; raised
                before the database is touched.
            subprocess.CalledProcessError: psql or pg_restore failed.
        """
        dbname = self._database_name()
        if not dbname:
            raise ValueError(
                "Cannot restore: settings NAME is empty and the database name "
                "cannot be taken from OPTIONS['service']."
            )
        conn_args = self._connection_args()
        env = self._subprocess_env()

        with gzip.open(backup_path, "rb") as backup:
            # gzip.open already fails for a missing file; peeking decompresses
            # the header so non-gzip data also fails *before* the destructive
            # DROP DATABASE below.
            backup.peek(1)

            for sql in (
                "SELECT pg_terminate_backend(pid) FROM pg_stat_activity"
                f" WHERE datname = {self._quote_literal(dbname)}"
                " AND pid <> pg_backend_pid()",
                f"DROP DATABASE IF EXISTS {self._quote_ident(dbname)}",
                f"CREATE DATABASE {self._quote_ident(dbname)}",
            ):
                subprocess.run(
                    shlex.split(psql) + conn_args + ["-d", "template1", "-c", sql],
                    env=env,
                    check=True,
                )

            # Restore into the fresh database, feeding the decompressed dump on stdin.
            self._feed_stdin(
                shlex.split(pg_restore) + conn_args + ["-d", dbname], backup, env
            )

    def _connection_args(self):
        """Return the ``-U``/``-h``/``-p`` options shared by pg_dump, psql and pg_restore."""
        settings_dict = self.connection.settings_dict
        args = []
        if settings_dict.get("USER"):
            args += ["-U", settings_dict["USER"]]
        if settings_dict.get("HOST"):
            args += ["-h", settings_dict["HOST"]]
        if settings_dict.get("PORT"):
            args += ["-p", str(settings_dict["PORT"])]
        return args

    def _database_name(self):
        """Return NAME, defaulting to ``"postgres"`` when neither NAME nor a service is set.

        Returns a falsy value when only OPTIONS['service'] is configured; the
        database is then presumably taken from the service definition.
        """
        settings_dict = self.connection.settings_dict
        dbname = settings_dict.get("NAME")
        if not dbname and not settings_dict.get("OPTIONS", {}).get("service"):
            dbname = "postgres"
        return dbname

    def _subprocess_env(self):
        """Return the current process environment overlaid with ``get_env()``."""
        return {**os.environ, **self.get_env()}

    @staticmethod
    def _quote_literal(value):
        """Quote *value* as an SQL string literal (embedded ``'`` doubled)."""
        return "'" + value.replace("'", "''") + "'"

    @staticmethod
    def _quote_ident(value):
        """Quote *value* as an SQL identifier (embedded ``"`` doubled)."""
        return '"' + value.replace('"', '""') + '"'

    @staticmethod
    def _feed_stdin(args, source, env):
        """Run *args*, streaming the binary file object *source* to its stdin.

        Raises subprocess.CalledProcessError if the command exits non-zero.
        """
        proc = subprocess.Popen(args, stdin=subprocess.PIPE, env=env)
        try:
            shutil.copyfileobj(source, proc.stdin)
        except BrokenPipeError:
            # The command stopped reading early; its exit status (checked below)
            # reports the actual failure.
            pass
        finally:
            try:
                proc.stdin.close()
            except BrokenPipeError:
                pass
            returncode = proc.wait()
        if returncode:
            raise subprocess.CalledProcessError(returncode, args)
118
119
class SQLiteBackupClient:
    """Create and restore gzip-compressed SQL-text dumps of a SQLite database.

    ``connection`` must provide ``ensure_connection()``, ``close()``,
    ``connect()`` and the underlying sqlite3 connection as ``.connection``
    (presumably a Django SQLite connection wrapper -- confirm against callers).
    """

    def __init__(self, connection):
        self.connection = connection

    def create_backup(self, backup_path):
        """Write the database as newline-joined SQL statements, gzip-compressed, to *backup_path*.

        The text is always encoded as UTF-8, independent of the platform locale,
        so backups restore identically on any machine.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        # Build the whole dump before opening the file so a failing dump does
        # not truncate an existing backup at backup_path.
        dump = "\n".join(src_conn.iterdump())
        with gzip.open(backup_path, "wt", encoding="utf-8") as f:
            f.write(dump)

    def restore_backup(self, backup_path):
        """Replace all tables and views with the contents of the dump at *backup_path*.

        Foreign-key enforcement is switched off while the old schema is dropped
        and the dump is replayed. Otherwise dropping a referenced table, or
        inserting child rows before their parents, fails. Enforcement is
        switched back on afterwards if it was on. If replaying fails, the open
        transaction is rolled back; the DROPs that already ran are not undone.
        """
        with gzip.open(backup_path, "rt", encoding="utf-8") as f:
            sql = f.read()

        # Start from a fresh connection.
        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        fk_was_on = dest_conn.execute("PRAGMA foreign_keys").fetchone()[0]
        # NOTE(review): SQLite ignores this pragma inside a transaction; assumes
        # the freshly connected handle is not in one -- confirm for the wrapper.
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            # Views first so none is left pointing at an already-dropped table;
            # indexes and triggers are deleted together with their tables.
            for obj_type in ("view", "table"):
                names = dest_conn.execute(
                    "SELECT name FROM sqlite_master WHERE type = ?", (obj_type,)
                ).fetchall()
                for (name,) in names:
                    if name.startswith("sqlite_"):
                        continue  # reserved for SQLite's internal use
                    quoted = '"' + name.replace('"', '""') + '"'
                    dest_conn.execute(f"DROP {obj_type.upper()} IF EXISTS {quoted}")
            dest_conn.executescript(sql)
            dest_conn.commit()
        finally:
            # After a failed replay, abandon the half-applied transaction so the
            # connection is usable and the pragma below takes effect.
            if dest_conn.in_transaction:
                dest_conn.rollback()
            if fk_was_on:
                dest_conn.execute("PRAGMA foreign_keys = ON")