Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import gzip
  4import os
  5import subprocess
  6from pathlib import Path
  7from typing import TYPE_CHECKING
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12
 13class PostgresBackupClient:
 14    def __init__(self, connection: BaseDatabaseWrapper) -> None:
 15        self.connection = connection
 16
 17    def get_env(self) -> dict[str, str]:
 18        settings_dict = self.connection.settings_dict
 19        options = settings_dict.get("OPTIONS", {})
 20        env = {}
 21        if options.get("passfile"):
 22            env["PGPASSFILE"] = str(options.get("passfile"))
 23        if settings_dict.get("PASSWORD"):
 24            env["PGPASSWORD"] = str(settings_dict.get("PASSWORD"))
 25        if options.get("service"):
 26            env["PGSERVICE"] = str(options.get("service"))
 27        if options.get("sslmode"):
 28            env["PGSSLMODE"] = str(options.get("sslmode"))
 29        if options.get("sslrootcert"):
 30            env["PGSSLROOTCERT"] = str(options.get("sslrootcert"))
 31        if options.get("sslcert"):
 32            env["PGSSLCERT"] = str(options.get("sslcert"))
 33        if options.get("sslkey"):
 34            env["PGSSLKEY"] = str(options.get("sslkey"))
 35        return env
 36
 37    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
 38        settings_dict = self.connection.settings_dict
 39
 40        args = pg_dump.split()
 41        options = settings_dict.get("OPTIONS", {})
 42
 43        host = settings_dict.get("HOST")
 44        port = settings_dict.get("PORT")
 45        dbname = settings_dict.get("NAME")
 46        user = settings_dict.get("USER")
 47        service = options.get("service")
 48
 49        if not dbname and not service:
 50            # Connect to the default 'postgres' db.
 51            dbname = "postgres"
 52        if user:
 53            args += ["-U", user]
 54        if host:
 55            args += ["-h", host]
 56        if port:
 57            args += ["-p", str(port)]
 58
 59        args += ["-Fc"]
 60        # args += ["-f", backup_path]
 61
 62        if dbname:
 63            args += [dbname]
 64
 65        # Using stdin/stdout let's us use executables from within a docker container too
 66        args += ["|", "gzip", ">", str(backup_path)]
 67
 68        cmd = " ".join(args)
 69
 70        subprocess.run(
 71            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
 72        )
 73
 74    def restore_backup(
 75        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
 76    ) -> None:
 77        settings_dict = self.connection.settings_dict
 78
 79        host = settings_dict.get("HOST")
 80        port = settings_dict.get("PORT")
 81        dbname = settings_dict.get("NAME")
 82        user = settings_dict.get("USER")
 83
 84        # Build common connection args
 85        conn_args = []
 86        if user:
 87            conn_args += ["-U", user]
 88        if host:
 89            conn_args += ["-h", host]
 90        if port:
 91            conn_args += ["-p", str(port)]
 92
 93        # First, drop and recreate the database
 94        # Connect to 'template1' database to do this (works for all databases including 'postgres')
 95        drop_create_cmds = [
 96            f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{dbname}' AND pid <> pg_backend_pid()",
 97            f'DROP DATABASE IF EXISTS "{dbname}"',
 98            f'CREATE DATABASE "{dbname}"',
 99        ]
100
101        for cmd in drop_create_cmds:
102            psql_args = (
103                psql.split()
104                + conn_args
105                + [
106                    "-d",
107                    "template1",  # Always use template1
108                    "-c",
109                    cmd,
110                ]
111            )
112            subprocess.run(psql_args, env={**os.environ, **self.get_env()}, check=True)
113
114        # Now restore into the fresh database
115        args = pg_restore.split()
116        args += conn_args
117        args += ["-d", dbname]
118
119        # Using stdin/stdout let's us use executables from within a docker container too
120        args = ["gunzip", "<", str(backup_path), "|"] + args
121
122        cmd = " ".join(args)
123
124        subprocess.run(
125            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
126        )
127
128
129class SQLiteBackupClient:
130    def __init__(self, connection: BaseDatabaseWrapper) -> None:
131        self.connection = connection
132
133    def create_backup(self, backup_path: Path) -> None:
134        self.connection.ensure_connection()
135        src_conn = self.connection.connection
136        dump = "\n".join(src_conn.iterdump())
137        with gzip.open(backup_path, "wt") as f:
138            f.write(dump)
139
140    def restore_backup(self, backup_path: Path) -> None:
141        with gzip.open(backup_path, "rt") as f:
142            sql = f.read()
143
144        self.connection.close()
145        self.connection.connect()
146        dest_conn = self.connection.connection
147        cur = dest_conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
148        for (name,) in cur.fetchall():
149            if not name.startswith("sqlite_"):
150                dest_conn.execute(f'DROP TABLE IF EXISTS "{name}"')
151        dest_conn.executescript(sql)
152        dest_conn.commit()