Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import re
  4from collections.abc import Generator
  5from pathlib import Path
  6
  7import requests
  8import tomlkit
  9
 10from plain.assets.finders import APP_ASSETS_DIR
 11
 12from .exceptions import (
 13    UnknownContentTypeError,
 14    VersionMismatchError,
 15)
 16
 17VENDOR_DIR = APP_ASSETS_DIR / "vendor"
 18
 19
 20def iter_next_version(version: str) -> Generator[str, None, None]:
 21    if len(version.split(".")) == 2:
 22        major, minor = version.split(".")
 23        yield f"{int(major) + 1}.0"
 24        yield f"{major}.{int(minor) + 1}"
 25    elif len(version.split(".")) == 3:
 26        major, minor, patch = version.split(".")
 27        yield f"{int(major) + 1}.0.0"
 28        yield f"{major}.{int(minor) + 1}.0"
 29        yield f"{major}.{minor}.{int(patch) + 1}"
 30    else:
 31        raise ValueError(f"Unable to iterate next version for {version}")
 32
 33
 34class Dependency:
 35    def __init__(self, name: str, **config: str | bool):
 36        self.name = name
 37        # Config values for these keys are always strings
 38        self.url: str = str(config.get("url", ""))
 39        self.installed: str = str(config.get("installed", ""))
 40        self.filename: str = str(config.get("filename", ""))
 41        self.sourcemap: str = str(config.get("sourcemap", ""))
 42
 43    @staticmethod
 44    def parse_version_from_url(url: str) -> str:
 45        if match := re.search(r"\d+\.\d+\.\d+", url):
 46            return match.group(0)
 47
 48        if match := re.search(r"\d+\.\d+", url):
 49            return match.group(0)
 50
 51        return ""
 52
 53    def __str__(self) -> str:
 54        return f"{self.name} -> {self.url}"
 55
 56    def download(self, version: str) -> tuple[str, requests.Response]:
 57        # If the string contains a {version} placeholder, replace it
 58        download_url = self.url.replace("{version}", version)
 59
 60        response = requests.get(download_url)
 61        response.raise_for_status()
 62
 63        content_type = response.headers.get("content-type", "").lower()
 64        allowed_types = (
 65            "application/javascript",
 66            "text/javascript",
 67            "application/json",
 68            "text/css",
 69        )
 70        if not any(content_type.startswith(allowed) for allowed in allowed_types):
 71            raise UnknownContentTypeError(
 72                f"Unknown content type for {self.name}: {content_type}"
 73            )
 74
 75        # Good chance it will redirect to a more final URL (which we hope is versioned)
 76        url = response.url
 77        version = self.parse_version_from_url(url)
 78
 79        return version, response
 80
 81    def install(self) -> Path:
 82        if self.installed:
 83            version, response = self.download(self.installed)
 84            if version != self.installed:
 85                raise VersionMismatchError(
 86                    f"Version mismatch for {self.name}: {self.installed} != {version}"
 87                )
 88            return self.vendor(response)
 89        else:
 90            return self.update()
 91
 92    def update(self) -> Path:
 93        def try_version(v: str) -> tuple[str, requests.Response | None]:
 94            try:
 95                version, response = self.download(v)
 96                return version, response
 97            except requests.RequestException:
 98                return "", None
 99
100        if not self.installed:
101            # If we don't know the installed version yet,
102            # just use the url as given
103            version, response = self.download("")
104        else:
105            version, response = try_version("latest")  # A lot of CDNs support this
106            if not version:
107                # Try the next few versions
108                for v in iter_next_version(self.installed):
109                    version, response = try_version(v)
110                    if version:
111                        break
112
113                    # TODO ideally this would keep going -- if we move to 2.0, and no 3.0, try 2.1, 2.2, etc.
114
115        if not version:
116            # Use the currently installed version if we found nothing else
117            version, response = self.download(self.installed)
118
119        if not response:
120            raise requests.RequestException("Unable to download dependency")
121
122        vendored_path = self.vendor(response)
123        self.installed = version
124
125        if self.installed:
126            # If the exact version was in the string, replace it with {version} placeholder
127            self.url = self.url.replace(self.installed, "{version}")
128
129        self.save_config()
130        return vendored_path
131
132    def save_config(self) -> None:
133        with open("pyproject.toml") as f:
134            pyproject = tomlkit.load(f)
135
136        # Force [tool.plain.vendor.dependencies] to be a table
137        dependencies = tomlkit.table()
138        dependencies.update(
139            pyproject.get("tool", {})
140            .get("plain", {})
141            .get("vendor", {})
142            .get("dependencies", {})
143        )
144
145        # Force [tool.plain.vendor.dependencies.{name}] to be an inline table
146        # name = { url = "https://example.com", installed = "1.0.0" }
147        dependencies[self.name] = tomlkit.inline_table()  # type: ignore
148        dependencies[self.name]["url"] = self.url  # type: ignore
149        dependencies[self.name]["installed"] = self.installed  # type: ignore
150        if self.filename:
151            dependencies[self.name]["filename"] = self.filename  # type: ignore
152        if self.sourcemap:
153            dependencies[self.name]["sourcemap"] = self.sourcemap  # type: ignore
154
155        # Have to give it the right structure in case they don't exist
156        if "tool" not in pyproject:  # type: ignore
157            pyproject["tool"] = tomlkit.table()  # type: ignore
158        if "plain" not in pyproject["tool"]:  # type: ignore
159            pyproject["tool"]["plain"] = tomlkit.table()  # type: ignore
160        if "vendor" not in pyproject["tool"]["plain"]:  # type: ignore
161            pyproject["tool"]["plain"]["vendor"] = tomlkit.table()  # type: ignore
162
163        pyproject["tool"]["plain"]["vendor"]["dependencies"] = dependencies  # type: ignore
164
165        with open("pyproject.toml", "w") as f:
166            f.write(tomlkit.dumps(pyproject))
167
168    def vendor(self, response: requests.Response) -> Path:
169        if not VENDOR_DIR.exists():
170            VENDOR_DIR.mkdir(parents=True)
171
172        if self.filename:
173            # Use a specific filename from config
174            filename = self.filename
175        else:
176            # Otherwise, use the filename from the URL
177            filename = response.url.split("/")[-1]  # type: ignore
178            # Remove any query string or fragment
179            filename = filename.split("?")[0].split("#")[0]
180
181        vendored_path = VENDOR_DIR / filename
182
183        with open(vendored_path, "wb") as f:
184            f.write(response.content)
185
186        # If a sourcemap is requested, download it as well
187        if self.sourcemap:
188            if isinstance(self.sourcemap, str):
189                # Use a specific filename from config
190                sourcemap_filename = self.sourcemap
191            else:
192                # Otherwise, append .map to the URL
193                sourcemap_filename = f"{filename}.map"
194
195            sourcemap_url = "/".join(
196                response.url.split("/")[:-1] + [sourcemap_filename]  # type: ignore
197            )
198            sourcemap_response = requests.get(sourcemap_url)
199            sourcemap_response.raise_for_status()
200
201            sourcemap_path = VENDOR_DIR / sourcemap_filename
202
203            with open(sourcemap_path, "wb") as f:
204                f.write(sourcemap_response.content)
205
206        return vendored_path
207
208
def get_deps() -> list[Dependency]:
    """Load every vendored dependency declared in ``pyproject.toml``.

    Reads ``[tool.plain.vendor.dependencies]`` from ``pyproject.toml`` in the
    current working directory; each entry's table is passed to ``Dependency``
    as keyword arguments.

    Returns:
        One ``Dependency`` per entry; empty if the section is missing.
    """
    # TOML files are UTF-8 by spec; don't rely on the platform default encoding.
    with open("pyproject.toml", encoding="utf-8") as f:
        pyproject = tomlkit.load(f)

    config = (
        pyproject.get("tool", {})
        .get("plain", {})
        .get("vendor", {})
        .get("dependencies", {})
    )

    return [Dependency(name, **data) for name, data in config.items()]