Plain is headed towards 1.0! Subscribe for development updates →

 1from __future__ import annotations
 2
 3import signal
 4from typing import TYPE_CHECKING
 5
 6from plain.models.backends.base.client import BaseDatabaseClient
 7
 8if TYPE_CHECKING:
 9    from plain.models.connections import DatabaseConfig
10
11
class DatabaseClient(BaseDatabaseClient):
    """Database client that shells out to the ``psql`` executable."""

    executable_name = "psql"

    @classmethod
    def settings_to_cmd_args_env(
        cls, settings_dict: DatabaseConfig, parameters: list[str]
    ) -> tuple[list[str], dict[str, str] | None]:
        """Translate connection settings into a ``psql`` command line.

        Args:
            settings_dict: Connection settings; the ``NAME``, ``USER``,
                ``PASSWORD``, ``HOST``, ``PORT`` and ``OPTIONS`` keys are read.
            parameters: Extra arguments inserted after the connection flags
                and before the database name.

        Returns:
            A ``(args, env)`` pair. ``args`` starts with the executable name;
            ``env`` maps ``PG*`` variable names to string values, or is
            ``None`` when no such variable needs to be set.
        """
        options = settings_dict.get("OPTIONS", {})
        service = options.get("service")

        database = settings_dict.get("NAME")
        if not (database or service):
            # Connect to the default 'postgres' db.
            database = "postgres"

        command = [cls.executable_name]

        username = settings_dict.get("USER")
        if username:
            command.extend(["-U", username])

        hostname = settings_dict.get("HOST")
        if hostname:
            command.extend(["-h", hostname])

        port_number = settings_dict.get("PORT")
        if port_number:
            command.extend(["-p", str(port_number)])

        # Caller-supplied parameters go ahead of the positional database name.
        command.extend(parameters)
        if database:
            command.append(database)

        # Candidate environment variables, in the order they are emitted;
        # entries whose value is falsy are left out.
        candidates = (
            ("PGPASSWORD", settings_dict.get("PASSWORD")),
            ("PGSERVICE", service),
            ("PGSSLMODE", options.get("sslmode")),
            ("PGSSLROOTCERT", options.get("sslrootcert")),
            ("PGSSLCERT", options.get("sslcert")),
            ("PGSSLKEY", options.get("sslkey")),
            ("PGPASSFILE", options.get("passfile")),
        )
        environment = {
            variable: str(value) for variable, value in candidates if value
        }
        return command, environment or None

    def runshell(self, parameters: list[str]) -> None:
        """Run the interactive shell with SIGINT ignored by this process.

        The SIGINT handler that was active on entry is reinstated once the
        parent ``runshell`` returns or raises.

        Args:
            parameters: Extra arguments forwarded to the parent ``runshell``.
        """
        previous_handler = signal.getsignal(signal.SIGINT)
        try:
            # Ignore SIGINT here so it reaches psql and aborts the running
            # query there.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Put back whatever handler was installed before the shell ran.
            signal.signal(signal.SIGINT, previous_handler)