Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import gzip
  4import os
  5import subprocess
  6from pathlib import Path
  7from typing import TYPE_CHECKING
  8
  9if TYPE_CHECKING:
 10    from plain.models.backends.base.base import BaseDatabaseWrapper
 11
 12
 13class PostgresBackupClient:
 14    def __init__(self, connection: BaseDatabaseWrapper) -> None:
 15        self.connection = connection
 16
 17    def get_env(self) -> dict[str, str]:
 18        settings_dict = self.connection.settings_dict
 19        options = settings_dict.get("OPTIONS", {})
 20        env = {}
 21        if options.get("passfile"):
 22            env["PGPASSFILE"] = str(options.get("passfile"))
 23        if settings_dict.get("PASSWORD"):
 24            env["PGPASSWORD"] = str(settings_dict.get("PASSWORD"))
 25        if options.get("service"):
 26            env["PGSERVICE"] = str(options.get("service"))
 27        if options.get("sslmode"):
 28            env["PGSSLMODE"] = str(options.get("sslmode"))
 29        if options.get("sslrootcert"):
 30            env["PGSSLROOTCERT"] = str(options.get("sslrootcert"))
 31        if options.get("sslcert"):
 32            env["PGSSLCERT"] = str(options.get("sslcert"))
 33        if options.get("sslkey"):
 34            env["PGSSLKEY"] = str(options.get("sslkey"))
 35        return env
 36
 37    def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
 38        settings_dict = self.connection.settings_dict
 39
 40        args = pg_dump.split()
 41        options = settings_dict.get("OPTIONS", {})
 42
 43        host = settings_dict.get("HOST")
 44        port = settings_dict.get("PORT")
 45        dbname = settings_dict.get("NAME")
 46        user = settings_dict.get("USER")
 47        service = options.get("service")
 48
 49        if not dbname and not service:
 50            # Connect to the default 'postgres' db.
 51            dbname = "postgres"
 52        if user:
 53            args += ["-U", user]
 54        if host:
 55            args += ["-h", host]
 56        if port:
 57            args += ["-p", str(port)]
 58
 59        args += ["-Fc"]
 60        # args += ["-f", backup_path]
 61
 62        if dbname:
 63            args += [dbname]
 64
 65        # Using stdin/stdout let's us use executables from within a docker container too
 66        args += ["|", "gzip", ">", str(backup_path)]
 67
 68        cmd = " ".join(args)
 69
 70        subprocess.run(
 71            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
 72        )
 73
 74    def restore_backup(
 75        self, backup_path: Path, *, pg_restore: str = "pg_restore", psql: str = "psql"
 76    ) -> None:
 77        settings_dict = self.connection.settings_dict
 78
 79        host = settings_dict.get("HOST")
 80        port = settings_dict.get("PORT")
 81        dbname = settings_dict.get("NAME")
 82        assert dbname is not None, "Database NAME is required in settings"
 83        user = settings_dict.get("USER")
 84
 85        # Build common connection args
 86        conn_args: list[str] = []
 87        if user:
 88            conn_args += ["-U", user]
 89        if host:
 90            conn_args += ["-h", host]
 91        if port:
 92            conn_args += ["-p", str(port)]
 93
 94        # First, drop and recreate the database
 95        # Connect to 'template1' database to do this (works for all databases including 'postgres')
 96        drop_create_cmds = [
 97            f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{dbname}' AND pid <> pg_backend_pid()",
 98            f'DROP DATABASE IF EXISTS "{dbname}"',
 99            f'CREATE DATABASE "{dbname}"',
100        ]
101
102        for cmd in drop_create_cmds:
103            psql_args = (
104                psql.split()
105                + conn_args
106                + [
107                    "-d",
108                    "template1",  # Always use template1
109                    "-c",
110                    cmd,
111                ]
112            )
113            subprocess.run(psql_args, env={**os.environ, **self.get_env()}, check=True)
114
115        # Now restore into the fresh database
116        args = pg_restore.split()
117        args += conn_args
118        args += ["-d", dbname]
119
120        # Using stdin/stdout let's us use executables from within a docker container too
121        args = ["gunzip", "<", str(backup_path), "|"] + args
122
123        cmd = " ".join(args)
124
125        subprocess.run(
126            cmd, env={**os.environ, **self.get_env()}, check=True, shell=True
127        )
128
129
class SQLiteBackupClient:
    """Create and restore gzipped SQL-text backups of a SQLite database."""

    def __init__(self, connection: BaseDatabaseWrapper) -> None:
        self.connection = connection

    def create_backup(self, backup_path: Path) -> None:
        """Write a gzipped SQL dump of the database to ``backup_path``.

        Statements come from the underlying connection's ``iterdump()``. They
        are written newline-separated as UTF-8.
        """
        self.connection.ensure_connection()
        src_conn = self.connection.connection
        # Stream statements instead of joining the whole dump in memory. Pin the
        # encoding so the file does not depend on the process locale.
        with gzip.open(backup_path, "wt", encoding="utf-8") as f:
            for statement in src_conn.iterdump():
                f.write(statement)
                f.write("\n")

    def restore_backup(self, backup_path: Path) -> None:
        """Replace the database contents with the SQL dump at ``backup_path``.

        Steps:
        1. Reopen the connection.
        2. Drop every non-internal view and table.
        3. Execute the dump.

        Foreign-key enforcement is turned off during the restore and then set
        back to its previous value.
        """
        with gzip.open(backup_path, "rt", encoding="utf-8") as f:
            sql = f.read()

        # Start from a fresh connection before replacing the schema.
        self.connection.close()
        self.connection.connect()
        dest_conn = self.connection.connection

        row = dest_conn.execute("PRAGMA foreign_keys").fetchone()
        fk_enabled = bool(row and row[0])
        # With enforcement on, DROP TABLE of a referenced table raises
        # IntegrityError, and the dump re-inserts rows table by table.
        # (PRAGMA foreign_keys has no effect inside a transaction.)
        dest_conn.execute("PRAGMA foreign_keys = OFF")
        try:
            # Drop views before tables: dropping tables leaves views behind,
            # and the dump's CREATE VIEW would then fail.
            for kind in ("view", "table"):
                names = dest_conn.execute(
                    "SELECT name FROM sqlite_master WHERE type = ?", (kind,)
                ).fetchall()
                for (name,) in names:
                    if name.startswith("sqlite_"):
                        continue  # SQLite-internal objects are not ours to drop
                    quoted = name.replace('"', '""')
                    dest_conn.execute(f'DROP {kind.upper()} IF EXISTS "{quoted}"')
            dest_conn.executescript(sql)
            dest_conn.commit()
        finally:
            if dest_conn.in_transaction:
                # A failed script can leave the dump's transaction open.
                dest_conn.rollback()
            dest_conn.execute(f"PRAGMA foreign_keys = {'ON' if fk_enabled else 'OFF'}")