Plain is headed towards 1.0! Subscribe for development updates →

 1from __future__ import annotations
 2
 3import signal
 4from typing import Any
 5
 6from plain.models.backends.base.client import BaseDatabaseClient
 7
 8
 9class DatabaseClient(BaseDatabaseClient):
10    executable_name = "psql"
11
12    @classmethod
13    def settings_to_cmd_args_env(
14        cls, settings_dict: dict[str, Any], parameters: list[str]
15    ) -> tuple[list[str], dict[str, str] | None]:
16        args = [cls.executable_name]
17        options = settings_dict.get("OPTIONS", {})
18
19        host = settings_dict.get("HOST")
20        port = settings_dict.get("PORT")
21        dbname = settings_dict.get("NAME")
22        user = settings_dict.get("USER")
23        passwd = settings_dict.get("PASSWORD")
24        passfile = options.get("passfile")
25        service = options.get("service")
26        sslmode = options.get("sslmode")
27        sslrootcert = options.get("sslrootcert")
28        sslcert = options.get("sslcert")
29        sslkey = options.get("sslkey")
30
31        if not dbname and not service:
32            # Connect to the default 'postgres' db.
33            dbname = "postgres"
34        if user:
35            args += ["-U", user]
36        if host:
37            args += ["-h", host]
38        if port:
39            args += ["-p", str(port)]
40        args.extend(parameters)
41        if dbname:
42            args += [dbname]
43
44        env = {}
45        if passwd:
46            env["PGPASSWORD"] = str(passwd)
47        if service:
48            env["PGSERVICE"] = str(service)
49        if sslmode:
50            env["PGSSLMODE"] = str(sslmode)
51        if sslrootcert:
52            env["PGSSLROOTCERT"] = str(sslrootcert)
53        if sslcert:
54            env["PGSSLCERT"] = str(sslcert)
55        if sslkey:
56            env["PGSSLKEY"] = str(sslkey)
57        if passfile:
58            env["PGPASSFILE"] = str(passfile)
59        return args, (env or None)
60
61    def runshell(self, parameters: list[str]) -> None:
62        sigint_handler = signal.getsignal(signal.SIGINT)
63        try:
64            # Allow SIGINT to pass to psql to abort queries.
65            signal.signal(signal.SIGINT, signal.SIG_IGN)
66            super().runshell(parameters)
67        finally:
68            # Restore the original SIGINT handler.
69            signal.signal(signal.SIGINT, sigint_handler)