Plain is headed towards 1.0! Subscribe for development updates →

import gzip
import os
import shutil
import subprocess
  4
  5
  6class PostgresBackupClient:
  7    def __init__(self, connection):
  8        self.connection = connection
  9
 10    def get_env(self):
 11        settings_dict = self.connection.settings_dict
 12        options = settings_dict.get("OPTIONS", {})
 13        env = {}
 14        if options.get("passfile"):
 15            env["PGPASSFILE"] = str(options.get("passfile"))
 16        if settings_dict.get("PASSWORD"):
 17            env["PGPASSWORD"] = str(settings_dict.get("PASSWORD"))
 18        if options.get("service"):
 19            env["PGSERVICE"] = str(options.get("service"))
 20        if options.get("sslmode"):
 21            env["PGSSLMODE"] = str(options.get("sslmode"))
 22        if options.get("sslrootcert"):
 23            env["PGSSLROOTCERT"] = str(options.get("sslrootcert"))
 24        if options.get("sslcert"):
 25            env["PGSSLCERT"] = str(options.get("sslcert"))
 26        if options.get("sslkey"):
 27            env["PGSSLKEY"] = str(options.get("sslkey"))
 28        return env
 29
 30    def create_backup(self, backup_path, *, pg_dump="pg_dump"):
 31        settings_dict = self.connection.settings_dict
 32
 33        args = pg_dump.split()
 34        options = settings_dict.get("OPTIONS", {})
 35
 36        host = settings_dict.get("HOST")
 37        port = settings_dict.get("PORT")
 38        dbname = settings_dict.get("NAME")
 39        user = settings_dict.get("USER")
 40        service = options.get("service")
 41
 42        if not dbname and not service:
 43            # Connect to the default 'postgres' db.
 44            dbname = "postgres"
 45        if user:
 46            args += ["-U", user]
 47        if host:
 48            args += ["-h", host]
 49        if port:
 50            args += ["-p", str(port)]
 51
 52        args += ["-Fc"]
 53        # args += ["-f", backup_path]
 54
 55        if dbname:
 56            args += [dbname]
 57
 58        # Using stdin/stdout let's us use executables from within a docker container too
 59        args += ["|", "gzip", ">", str(backup_path)]
 60
 61        cmd = " ".join(args)
 62
 63        subprocess.run(
 64            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
 65        )
 66
 67    def restore_backup(self, backup_path, *, pg_restore="pg_restore", psql="psql"):
 68        settings_dict = self.connection.settings_dict
 69
 70        host = settings_dict.get("HOST")
 71        port = settings_dict.get("PORT")
 72        dbname = settings_dict.get("NAME")
 73        user = settings_dict.get("USER")
 74
 75        # Build common connection args
 76        conn_args = []
 77        if user:
 78            conn_args += ["-U", user]
 79        if host:
 80            conn_args += ["-h", host]
 81        if port:
 82            conn_args += ["-p", str(port)]
 83
 84        # First, drop and recreate the database
 85        # Connect to 'template1' database to do this (works for all databases including 'postgres')
 86        drop_create_cmds = [
 87            f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{dbname}' AND pid <> pg_backend_pid()",
 88            f'DROP DATABASE IF EXISTS "{dbname}"',
 89            f'CREATE DATABASE "{dbname}"',
 90        ]
 91
 92        for cmd in drop_create_cmds:
 93            psql_args = (
 94                psql.split()
 95                + conn_args
 96                + [
 97                    "-d",
 98                    "template1",  # Always use template1
 99                    "-c",
100                    cmd,
101                ]
102            )
103            subprocess.run(psql_args, env={**os.environ, **self.get_env()}, check=True)
104
105        # Now restore into the fresh database
106        args = pg_restore.split()
107        args += conn_args
108        args += ["-d", dbname]
109
110        # Using stdin/stdout let's us use executables from within a docker container too
111        args = ["gunzip", "<", str(backup_path), "|"] + args
112
113        cmd = " ".join(args)
114
115        subprocess.run(
116            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
117        )
118
119
class SQLiteBackupClient:
    """Back up and restore a SQLite database as a gzip-compressed SQL dump."""

    def __init__(self, connection):
        # Wrapper exposing ensure_connection()/connect()/close() and the raw
        # sqlite3 connection as ``.connection``.
        self.connection = connection

    def create_backup(self, backup_path):
        """Write the database as SQL text (``iterdump()``) into a gzip file.

        The dump is encoded as UTF-8 regardless of locale and written without
        newline translation, so text values containing carriage returns
        survive a backup/restore round trip.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        dump = "\n".join(src_conn.iterdump())
        with gzip.open(backup_path, "wt", encoding="utf-8", newline="") as f:
            f.write(dump)

    def restore_backup(self, backup_path):
        """Replace the database contents with the SQL dump at *backup_path*.

        The dump is read before the database is touched. The connection is
        then reopened, existing views and tables (except internal ``sqlite_*``
        tables) are dropped, and the dump is executed. Foreign key enforcement
        is switched off while dropping and reloading (dropped/inserted tables
        are not in dependency order) and set back to its previous state after.
        """
        # newline="" keeps "\r" inside dumped string values intact.
        with gzip.open(backup_path, "rt", encoding="utf-8", newline="") as f:
            sql = f.read()

        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        # PRAGMA foreign_keys returns no row when SQLite lacks FK support.
        row = dest_conn.execute("PRAGMA foreign_keys").fetchone()
        fk_enabled = bool(row and row[0])
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            schema_objects = dest_conn.execute(
                "SELECT type, name FROM sqlite_master WHERE type IN ('view', 'table')"
            ).fetchall()
            # Views must be dropped too: the dump recreates them with
            # CREATE VIEW, which fails if the view still exists.
            for kind in ("view", "table"):
                for obj_type, name in schema_objects:
                    if obj_type == kind and not name.startswith("sqlite_"):
                        ident = name.replace('"', '""')
                        dest_conn.execute(f'DROP {kind.upper()} IF EXISTS "{ident}"')
            dest_conn.executescript(sql)
            dest_conn.commit()
        finally:
            dest_conn.execute(f"PRAGMA foreign_keys = {'ON' if fk_enabled else 'OFF'}")