1from __future__ import annotations
2
3import signal
4from typing import Any
5
6from plain.models.backends.base.client import BaseDatabaseClient
7
8
class DatabaseClient(BaseDatabaseClient):
    """Database shell client that drives the ``psql`` executable."""

    executable_name = "psql"

    @classmethod
    def settings_to_cmd_args_env(
        cls, settings_dict: dict[str, Any], parameters: list[str]
    ) -> tuple[list[str], dict[str, str] | None]:
        """Translate connection settings into a command line and environment.

        Args:
            settings_dict: Connection settings. The keys read here are
                ``HOST``, ``PORT``, ``NAME``, ``USER``, ``PASSWORD`` and
                ``OPTIONS`` (from which ``passfile``, ``service``,
                ``sslmode``, ``sslrootcert``, ``sslcert`` and ``sslkey``
                are read). A missing or ``None`` ``OPTIONS`` is treated as
                an empty mapping.
            parameters: Extra command-line arguments, inserted after the
                connection flags and before the database name.

        Returns:
            A ``(args, env)`` tuple. ``args`` starts with the executable
            name. ``env`` maps ``PG*`` environment variable names to string
            values, or is ``None`` when no such variable needs to be set.
        """
        args = [cls.executable_name]
        # ``OPTIONS`` may be absent *or* explicitly None; both mean "no options".
        options = settings_dict.get("OPTIONS") or {}

        host = settings_dict.get("HOST")
        port = settings_dict.get("PORT")
        dbname = settings_dict.get("NAME")
        user = settings_dict.get("USER")
        service = options.get("service")

        if not dbname and not service:
            # Connect to the default 'postgres' db.
            dbname = "postgres"
        if user:
            args += ["-U", user]
        if host:
            args += ["-h", host]
        if port:
            args += ["-p", str(port)]
        # Caller-supplied parameters go before the positional database name.
        args.extend(parameters)
        if dbname:
            args.append(dbname)

        # Settings that are handed over through the environment rather than
        # on the command line (this includes the password).
        env_sources = {
            "PGPASSWORD": settings_dict.get("PASSWORD"),
            "PGSERVICE": service,
            "PGSSLMODE": options.get("sslmode"),
            "PGSSLROOTCERT": options.get("sslrootcert"),
            "PGSSLCERT": options.get("sslcert"),
            "PGSSLKEY": options.get("sslkey"),
            "PGPASSFILE": options.get("passfile"),
        }
        # Only truthy settings are exported; values are coerced to str.
        env = {name: str(value) for name, value in env_sources.items() if value}
        return args, (env or None)

    def runshell(self, parameters: list[str]) -> None:
        """Run the shell via the parent class with SIGINT ignored in this process.

        The SIGINT handler that was active on entry is reinstated afterwards,
        whether or not the parent ``runshell`` raises.

        Args:
            parameters: Extra command-line arguments forwarded to the parent
                ``runshell``.
        """
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to psql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)