Plain is headed towards 1.0! Subscribe for development updates →

import contextlib
import gzip
import os
import shlex
import shutil
import subprocess
  4
  5
  6class PostgresBackupClient:
  7    def __init__(self, connection):
  8        self.connection = connection
  9
 10    def get_env(self):
 11        settings_dict = self.connection.settings_dict
 12        options = settings_dict.get("OPTIONS", {})
 13        env = {}
 14        if options.get("passfile"):
 15            env["PGPASSFILE"] = str(options.get("passfile"))
 16        if settings_dict.get("PASSWORD"):
 17            env["PGPASSWORD"] = str(settings_dict.get("PASSWORD"))
 18        if options.get("service"):
 19            env["PGSERVICE"] = str(options.get("service"))
 20        if options.get("sslmode"):
 21            env["PGSSLMODE"] = str(options.get("sslmode"))
 22        if options.get("sslrootcert"):
 23            env["PGSSLROOTCERT"] = str(options.get("sslrootcert"))
 24        if options.get("sslcert"):
 25            env["PGSSLCERT"] = str(options.get("sslcert"))
 26        if options.get("sslkey"):
 27            env["PGSSLKEY"] = str(options.get("sslkey"))
 28        return env
 29
 30    def create_backup(self, backup_path, *, pg_dump="pg_dump"):
 31        settings_dict = self.connection.settings_dict
 32
 33        args = pg_dump.split()
 34        options = settings_dict.get("OPTIONS", {})
 35
 36        host = settings_dict.get("HOST")
 37        port = settings_dict.get("PORT")
 38        dbname = settings_dict.get("NAME")
 39        user = settings_dict.get("USER")
 40        service = options.get("service")
 41
 42        if not dbname and not service:
 43            # Connect to the default 'postgres' db.
 44            dbname = "postgres"
 45        if user:
 46            args += ["-U", user]
 47        if host:
 48            args += ["-h", host]
 49        if port:
 50            args += ["-p", str(port)]
 51
 52        args += ["-Fc"]
 53        # args += ["-f", backup_path]
 54
 55        if dbname:
 56            args += [dbname]
 57
 58        # Using stdin/stdout let's us use executables from within a docker container too
 59        args += ["|", "gzip", ">", str(backup_path)]
 60
 61        cmd = " ".join(args)
 62
 63        subprocess.run(
 64            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
 65        )
 66
 67    def restore_backup(self, backup_path, *, pg_restore="pg_restore"):
 68        settings_dict = self.connection.settings_dict
 69
 70        args = pg_restore.split()
 71        options = settings_dict.get("OPTIONS", {})
 72
 73        host = settings_dict.get("HOST")
 74        port = settings_dict.get("PORT")
 75        dbname = settings_dict.get("NAME")
 76        user = settings_dict.get("USER")
 77        service = options.get("service")
 78
 79        if not dbname and not service:
 80            # Connect to the default 'postgres' db.
 81            dbname = "postgres"
 82        if user:
 83            args += ["-U", user]
 84        if host:
 85            args += ["-h", host]
 86        if port:
 87            args += ["-p", str(port)]
 88
 89        args += ["--clean"]  # Drop existing tables
 90        args += ["-d", dbname]
 91
 92        # Using stdin/stdout let's us use executables from within a docker container too
 93        args = ["gunzip", "<", str(backup_path), "|"] + args
 94
 95        cmd = " ".join(args)
 96
 97        subprocess.run(
 98            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
 99        )
100
101
class SQLiteBackupClient:
    """Back up and restore an SQLite database as a gzipped SQL text dump."""

    def __init__(self, connection):
        # Expected to provide ensure_connection()/close()/connect() and a
        # ``connection`` attribute holding the underlying sqlite3 connection.
        self.connection = connection

    def create_backup(self, backup_path):
        """Write the database's ``iterdump()`` SQL to ``backup_path`` (gzip, UTF-8).

        Statements are newline-separated and streamed to disk one at a time
        instead of first building the whole dump as a single string.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        # Explicit UTF-8: the text-mode default is the locale encoding, which
        # can fail on or mangle non-ASCII data on non-UTF-8 systems.
        with gzip.open(backup_path, "wt", encoding="utf-8") as f:
            for index, statement in enumerate(src_conn.iterdump()):
                if index:
                    f.write("\n")
                f.write(statement)

    def restore_backup(self, backup_path):
        """Replace the database contents with the dump at ``backup_path``.

        Reconnects, drops every user view and table, then executes the dump.
        Foreign-key enforcement is switched off during the restore. The dump
        does not necessarily create and fill tables in dependency order, and
        DROP TABLE performs an implicit DELETE that FKs can reject. The
        previous foreign_keys setting is restored afterwards.
        """
        with gzip.open(backup_path, "rt", encoding="utf-8") as f:
            sql = f.read()

        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        fk_enabled = dest_conn.execute("PRAGMA foreign_keys").fetchone()[0]
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            # Views sort first (type = 'table' is 0 for views), so they are
            # dropped before the tables they may select from.
            cur = dest_conn.execute(
                "SELECT type, name FROM sqlite_master"
                " WHERE type IN ('table', 'view') ORDER BY type = 'table'"
            )
            for kind, name in cur.fetchall():
                if name.startswith("sqlite_"):
                    continue  # leave SQLite's internal tables alone
                quoted = '"' + name.replace('"', '""') + '"'
                dest_conn.execute(f"DROP {kind.upper()} IF EXISTS {quoted}")
            dest_conn.executescript(sql)
            dest_conn.commit()
        finally:
            if dest_conn.in_transaction:
                # A failed script can leave its BEGIN open. Roll it back so the
                # PRAGMA below is not silently ignored inside a transaction.
                dest_conn.rollback()
            dest_conn.execute(f"PRAGMA foreign_keys = {'ON' if fk_enabled else 'OFF'}")