Plain is headed towards 1.0! Subscribe for development updates →

  1# Copyright (c) Kenneth Reitz & individual contributors
  2# All rights reserved.
  3
  4# Redistribution and use in source and binary forms, with or without modification,
  5# are permitted provided that the following conditions are met:
  6
  7#     1. Redistributions of source code must retain the above copyright notice,
  8#        this list of conditions and the following disclaimer.
  9
 10#     2. Redistributions in binary form must reproduce the above copyright
 11#        notice, this list of conditions and the following disclaimer in the
 12#        documentation and/or other materials provided with the distribution.
 13
 14#     3. Neither the name of Plain nor the names of its contributors may be used
 15#        to endorse or promote products derived from this software without
 16#        specific prior written permission.
 17
 18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 19# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 20# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 21# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 22# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 23# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 24# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 25# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 26# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 27# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 28import urllib.parse as urlparse
 29from typing import Any
 30
 31from .connections import DatabaseConfig
 32
 33SCHEMES = {
 34    "postgres": "plain.models.backends.postgresql",
 35    "postgresql": "plain.models.backends.postgresql",
 36    "pgsql": "plain.models.backends.postgresql",
 37    "mysql": "plain.models.backends.mysql",
 38    "mysql2": "plain.models.backends.mysql",
 39    "sqlite": "plain.models.backends.sqlite3",
 40}
 41
 42# Register database schemes in URLs.
 43for key in SCHEMES.keys():
 44    urlparse.uses_netloc.append(key)
 45
 46
 47def parse_database_url(
 48    url: str,
 49    engine: str | None = None,
 50    conn_max_age: int | None = 0,
 51    conn_health_checks: bool = False,
 52) -> DatabaseConfig:
 53    """Parses a database URL."""
 54    if url == "sqlite://:memory:":
 55        # this is a special case, because if we pass this URL into
 56        # urlparse, urlparse will choke trying to interpret "memory"
 57        # as a port number
 58        return {"ENGINE": SCHEMES["sqlite"], "NAME": ":memory:"}
 59        # note: no other settings are required for sqlite
 60
 61    # otherwise parse the url as normal
 62    parsed_config: DatabaseConfig = {}
 63
 64    spliturl = urlparse.urlsplit(url)
 65
 66    # Split query strings from path.
 67    path = spliturl.path[1:]
 68    query = urlparse.parse_qs(spliturl.query)
 69
 70    # If we are using sqlite and we have no path, then assume we
 71    # want an in-memory database (this is the behaviour of sqlalchemy)
 72    if spliturl.scheme == "sqlite" and path == "":
 73        path = ":memory:"
 74
 75    # Handle postgres percent-encoded paths.
 76    hostname = spliturl.hostname or ""
 77    if "%" in hostname:
 78        # Switch to url.netloc to avoid lower cased paths
 79        hostname = spliturl.netloc
 80        if "@" in hostname:
 81            hostname = hostname.rsplit("@", 1)[1]
 82        # Use URL Parse library to decode % encodes
 83        hostname = urlparse.unquote(hostname)
 84
 85    # Lookup specified engine.
 86    if engine is None:
 87        engine = SCHEMES.get(spliturl.scheme)
 88        if engine is None:
 89            raise ValueError(
 90                "No support for '{}'. We support: {}".format(
 91                    spliturl.scheme, ", ".join(sorted(SCHEMES.keys()))
 92                )
 93            )
 94
 95    port = spliturl.port
 96
 97    # Update with environment configuration.
 98    parsed_config.update(
 99        {
100            "NAME": urlparse.unquote(path or ""),
101            "USER": urlparse.unquote(spliturl.username or ""),
102            "PASSWORD": urlparse.unquote(spliturl.password or ""),
103            "HOST": hostname,
104            "PORT": port or "",
105            "CONN_MAX_AGE": conn_max_age,
106            "CONN_HEALTH_CHECKS": conn_health_checks,
107            "ENGINE": engine,
108        }
109    )
110
111    # Pass the query string into OPTIONS.
112    options: dict[str, Any] = {}
113    for key, values in query.items():
114        if spliturl.scheme == "mysql" and key == "ssl-ca":
115            options["ssl"] = {"ca": values[-1]}
116            continue
117
118        options[key] = values[-1]
119
120    if options:
121        parsed_config["OPTIONS"] = options
122
123    return parsed_config
124
125
def build_database_url(config: dict) -> str:
    """Build a database URL from a configuration dictionary.

    Performs the reverse mapping of ``parse_database_url``: ``ENGINE`` is
    mapped back to a URL scheme, ``OPTIONS`` becomes the query string, and a
    MySQL ``{"ssl": {"ca": ...}}`` option becomes ``ssl-ca``.

    Args:
        config: Mapping with ``ENGINE`` and optionally ``NAME``, ``USER``,
            ``PASSWORD``, ``HOST``, ``PORT`` and ``OPTIONS``. Missing or
            ``None`` values are treated as empty.

    Returns:
        The URL string. For sqlite only ``NAME`` and ``OPTIONS`` are used.

    Raises:
        ValueError: If ``ENGINE`` is missing/empty or not present in
            ``SCHEMES``.
    """
    engine = config.get("ENGINE")
    if not engine:
        raise ValueError("ENGINE is required to build a database URL")

    # Invert SCHEMES; setdefault keeps the first scheme listed for each
    # engine (e.g. "postgres" for the postgresql backend, "mysql" for mysql).
    reverse_schemes: dict[str, str] = {}
    for scheme, eng in SCHEMES.items():
        reverse_schemes.setdefault(eng, scheme)

    scheme = reverse_schemes.get(engine)
    if scheme is None:
        # set(): several schemes share a backend, so list each engine once.
        raise ValueError(
            f"No scheme known for engine '{engine}'. We support: {', '.join(sorted(set(SCHEMES.values())))}"
        )

    options = config.get("OPTIONS") or {}
    query_parts: list[tuple[str, Any]] = []
    for key, value in options.items():
        if scheme == "mysql" and key == "ssl" and isinstance(value, dict):
            # Reverse of the "ssl-ca" handling in parse_database_url. Only
            # the CA path is representable; other keys in the dict are
            # dropped.
            ca = value.get("ca")
            if ca:
                query_parts.append(("ssl-ca", ca))
            continue

        query_parts.append((key, value))

    query = urlparse.urlencode(query_parts)

    if scheme == "sqlite":
        # str() so path-like NAME values (e.g. pathlib.Path) are accepted.
        name = str(config.get("NAME") or "")
        if name == ":memory:":
            url = "sqlite://:memory:"
        else:
            url = f"sqlite:///{urlparse.quote(name, safe='/')}"

        if query:
            url += f"?{query}"

        return url

    # safe="" so "/" inside credentials is percent-encoded too (":" and "@"
    # already are); a raw "/" would otherwise end the netloc early.
    user = urlparse.quote(str(config.get("USER") or ""), safe="")
    password = urlparse.quote(str(config.get("PASSWORD") or ""), safe="")
    # Percent-encode the host as well, so a Unix-socket directory such as
    # "/var/run/postgresql" becomes "%2Fvar%2Frun%2Fpostgresql";
    # parse_database_url decodes hosts that contain "%". Host names and IPv4
    # addresses consist only of always-safe characters and are unchanged.
    host = urlparse.quote(str(config.get("HOST") or ""), safe="")
    port = config.get("PORT") or ""
    name = urlparse.quote(str(config.get("NAME") or ""))

    netloc = ""
    if user or password:
        netloc = f"{user}:{password}@" if password else f"{user}@"
    netloc += host
    if port:
        netloc += f":{port}"

    return urlparse.urlunsplit((scheme, netloc, f"/{name}", query, ""))