1from __future__ import annotations
2
3import signal
4from typing import TYPE_CHECKING
5
6from plain.models.backends.base.client import BaseDatabaseClient
7
8if TYPE_CHECKING:
9 from plain.models.connections import DatabaseConfig
10
11
class DatabaseClient(BaseDatabaseClient):
    """Database shell client for PostgreSQL, backed by the ``psql`` executable."""

    executable_name = "psql"

    @classmethod
    def settings_to_cmd_args_env(
        cls, settings_dict: DatabaseConfig, parameters: list[str]
    ) -> tuple[list[str], dict[str, str] | None]:
        """Build the ``psql`` argument vector and the extra environment for it.

        ``USER``, ``HOST`` and ``PORT`` (when truthy) become ``-U``/``-h``/``-p``
        flags. The caller-supplied ``parameters`` come next, and the database
        name comes last: psql takes its options before the positional
        database name. When neither ``NAME`` nor ``OPTIONS["service"]`` is set,
        the database name falls back to ``"postgres"``.

        The password and the libpq options read from ``OPTIONS`` (service,
        sslmode, sslrootcert, sslcert, sslkey, passfile) are returned as
        ``PG*`` environment variables rather than arguments. As a result the
        password never appears in the argument list.

        Args:
            settings_dict: Connection settings (``NAME``, ``USER``, ``HOST``,
                ``PORT``, ``PASSWORD``, ``OPTIONS``).
            parameters: Extra command-line arguments to pass through to psql.

        Returns:
            ``(args, env)``. ``env`` is ``None`` when no variable was set.
        """
        options = settings_dict.get("OPTIONS", {})
        service = options.get("service")

        database = settings_dict.get("NAME")
        if not (database or service):
            # Nothing says which database to open: use the default 'postgres' db.
            database = "postgres"

        flag_values = (
            ("-U", settings_dict.get("USER")),
            ("-h", settings_dict.get("HOST")),
            ("-p", settings_dict.get("PORT")),
        )
        command: list[str] = [cls.executable_name]
        for flag, value in flag_values:
            if not value:
                continue
            # PORT is coerced with str(); USER and HOST are passed through unchanged.
            command += (flag, str(value) if flag == "-p" else value)
        command.extend(parameters)
        if database:
            command.append(database)

        # Order here fixes the insertion order of the resulting env mapping.
        env_candidates = (
            ("PGPASSWORD", settings_dict.get("PASSWORD")),
            ("PGSERVICE", service),
            ("PGSSLMODE", options.get("sslmode")),
            ("PGSSLROOTCERT", options.get("sslrootcert")),
            ("PGSSLCERT", options.get("sslcert")),
            ("PGSSLKEY", options.get("sslkey")),
            ("PGPASSFILE", options.get("passfile")),
        )
        env = {name: str(value) for name, value in env_candidates if value}
        return command, (env if env else None)

    def runshell(self, parameters: list[str]) -> None:
        """Run the shell with SIGINT ignored in this process for the duration.

        SIGINT is ignored here so that Ctrl-C is left to psql, which uses it
        to abort the running query. The previous SIGINT handler is always
        restored, even if the shell raises.
        """
        previous_handler = signal.getsignal(signal.SIGINT)
        try:
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            signal.signal(signal.SIGINT, previous_handler)