Plain is headed towards 1.0! Subscribe for development updates →

from __future__ import annotations

import re
from collections.abc import Generator
from pathlib import Path
from urllib.parse import urlsplit, urlunsplit

import requests
import tomlkit

from plain.assets.finders import APP_ASSETS_DIR

from .exceptions import (
    UnknownContentTypeError,
    VersionMismatchError,
)
 16
 17VENDOR_DIR = APP_ASSETS_DIR / "vendor"
 18
 19
 20def iter_next_version(version: str) -> Generator[str, None, None]:
 21    if len(version.split(".")) == 2:
 22        major, minor = version.split(".")
 23        yield f"{int(major) + 1}.0"
 24        yield f"{major}.{int(minor) + 1}"
 25    elif len(version.split(".")) == 3:
 26        major, minor, patch = version.split(".")
 27        yield f"{int(major) + 1}.0.0"
 28        yield f"{major}.{int(minor) + 1}.0"
 29        yield f"{major}.{minor}.{int(patch) + 1}"
 30    else:
 31        raise ValueError(f"Unable to iterate next version for {version}")
 32
 33
 34class Dependency:
 35    def __init__(self, name: str, **config: str | bool):
 36        self.name = name
 37        # Config values for these keys are always strings
 38        self.url: str = str(config.get("url", ""))
 39        self.installed: str = str(config.get("installed", ""))
 40        self.filename: str = str(config.get("filename", ""))
 41        # sourcemap can be bool (True/False) or str (custom filename)
 42        self.sourcemap: bool | str = config.get("sourcemap", False)
 43
 44    @staticmethod
 45    def parse_version_from_url(url: str) -> str:
 46        if match := re.search(r"\d+\.\d+\.\d+", url):
 47            return match.group(0)
 48
 49        if match := re.search(r"\d+\.\d+", url):
 50            return match.group(0)
 51
 52        return ""
 53
 54    def __str__(self) -> str:
 55        return f"{self.name} -> {self.url}"
 56
 57    def download(self, version: str) -> tuple[str, requests.Response]:
 58        # If the string contains a {version} placeholder, replace it
 59        download_url = self.url.replace("{version}", version)
 60
 61        response = requests.get(download_url)
 62        response.raise_for_status()
 63
 64        content_type = response.headers.get("content-type", "").lower()
 65        allowed_types = (
 66            "application/javascript",
 67            "text/javascript",
 68            "application/json",
 69            "text/css",
 70        )
 71        if not any(content_type.startswith(allowed) for allowed in allowed_types):
 72            raise UnknownContentTypeError(
 73                f"Unknown content type for {self.name}: {content_type}"
 74            )
 75
 76        # Good chance it will redirect to a more final URL (which we hope is versioned)
 77        url = response.url
 78        version = self.parse_version_from_url(url)
 79
 80        return version, response
 81
 82    def install(self) -> Path:
 83        if self.installed:
 84            version, response = self.download(self.installed)
 85            if version != self.installed:
 86                raise VersionMismatchError(
 87                    f"Version mismatch for {self.name}: {self.installed} != {version}"
 88                )
 89            return self.vendor(response)
 90        else:
 91            return self.update()
 92
 93    def update(self) -> Path:
 94        def try_version(v: str) -> tuple[str, requests.Response | None]:
 95            try:
 96                version, response = self.download(v)
 97                return version, response
 98            except requests.RequestException:
 99                return "", None
100
101        if not self.installed:
102            # If we don't know the installed version yet,
103            # just use the url as given
104            version, response = self.download("")
105        else:
106            version, response = try_version("latest")  # A lot of CDNs support this
107            if not version:
108                # Try the next few versions
109                for v in iter_next_version(self.installed):
110                    version, response = try_version(v)
111                    if version:
112                        break
113
114                    # TODO ideally this would keep going -- if we move to 2.0, and no 3.0, try 2.1, 2.2, etc.
115
116        if not version:
117            # Use the currently installed version if we found nothing else
118            version, response = self.download(self.installed)
119
120        if not response:
121            raise requests.RequestException("Unable to download dependency")
122
123        vendored_path = self.vendor(response)
124        self.installed = version
125
126        if self.installed:
127            # If the exact version was in the string, replace it with {version} placeholder
128            self.url = self.url.replace(self.installed, "{version}")
129
130        self.save_config()
131        return vendored_path
132
133    def save_config(self) -> None:
134        with open("pyproject.toml") as f:
135            pyproject = tomlkit.load(f)
136
137        # Force [tool.plain.vendor.dependencies] to be a table
138        dependencies = tomlkit.table()
139        dependencies.update(
140            pyproject.get("tool", {})
141            .get("plain", {})
142            .get("vendor", {})
143            .get("dependencies", {})
144        )
145
146        # Force [tool.plain.vendor.dependencies.{name}] to be an inline table
147        # name = { url = "https://example.com", installed = "1.0.0" }
148        dependencies[self.name] = tomlkit.inline_table()
149        dependencies[self.name]["url"] = self.url  # type: ignore[index]
150        dependencies[self.name]["installed"] = self.installed  # type: ignore[index]
151        if self.filename:
152            dependencies[self.name]["filename"] = self.filename  # type: ignore[index]
153        if self.sourcemap:
154            dependencies[self.name]["sourcemap"] = self.sourcemap  # type: ignore[index]
155
156        # Have to give it the right structure in case they don't exist
157        if "tool" not in pyproject:
158            pyproject["tool"] = tomlkit.table()
159        if "plain" not in pyproject["tool"]:  # type: ignore[operator]
160            pyproject["tool"]["plain"] = tomlkit.table()  # type: ignore[index]
161        if "vendor" not in pyproject["tool"]["plain"]:  # type: ignore[operator, index]
162            pyproject["tool"]["plain"]["vendor"] = tomlkit.table()  # type: ignore[index]
163
164        pyproject["tool"]["plain"]["vendor"]["dependencies"] = dependencies  # type: ignore[index]
165
166        with open("pyproject.toml", "w") as f:
167            f.write(tomlkit.dumps(pyproject))
168
169    def vendor(self, response: requests.Response) -> Path:
170        if not VENDOR_DIR.exists():
171            VENDOR_DIR.mkdir(parents=True)
172
173        if self.filename:
174            # Use a specific filename from config
175            filename = self.filename
176        else:
177            # Otherwise, use the filename from the URL
178            assert response.url is not None
179            filename = response.url.split("/")[-1]
180            # Remove any query string or fragment
181            filename = filename.split("?")[0].split("#")[0]
182
183        vendored_path = VENDOR_DIR / filename
184
185        with open(vendored_path, "wb") as f:
186            f.write(response.content)
187
188        # If a sourcemap is requested, download it as well
189        if self.sourcemap:
190            if self.sourcemap is True:
191                # Default: append .map to the filename
192                sourcemap_filename = f"{filename}.map"
193            else:
194                # Use a specific filename from config
195                sourcemap_filename = str(self.sourcemap)
196
197            assert response.url is not None
198            sourcemap_url = "/".join(
199                response.url.split("/")[:-1] + [sourcemap_filename]
200            )
201            sourcemap_response = requests.get(sourcemap_url)
202            sourcemap_response.raise_for_status()
203
204            sourcemap_path = VENDOR_DIR / sourcemap_filename
205
206            with open(sourcemap_path, "wb") as f:
207                f.write(sourcemap_response.content)
208
209        return vendored_path
210
211
def get_deps() -> list[Dependency]:
    """Build a ``Dependency`` for each entry of ``[tool.plain.vendor.dependencies]``.

    ``pyproject.toml`` is read from the current working directory. Returns an
    empty list when that table (or any of its parent tables) is missing.
    """
    # TOML files are UTF-8 by specification; don't rely on the locale default.
    with open("pyproject.toml", encoding="utf-8") as f:
        pyproject = tomlkit.load(f)

    config = (
        pyproject.get("tool", {})
        .get("plain", {})
        .get("vendor", {})
        .get("dependencies", {})
    )

    # Each entry's keys (url, installed, filename, sourcemap) become keyword
    # arguments to Dependency.
    return [Dependency(name, **data) for name, data in config.items()]