Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import re
  4from collections.abc import Generator
  5from pathlib import Path
  6
  7import requests
  8import tomlkit
  9
 10from plain.assets.finders import APP_ASSETS_DIR
 11
 12from .exceptions import (
 13    UnknownContentTypeError,
 14    VersionMismatchError,
 15)
 16
 17VENDOR_DIR = APP_ASSETS_DIR / "vendor"
 18
 19
 20def iter_next_version(version: str) -> Generator[str, None, None]:
 21    if len(version.split(".")) == 2:
 22        major, minor = version.split(".")
 23        yield f"{int(major) + 1}.0"
 24        yield f"{major}.{int(minor) + 1}"
 25    elif len(version.split(".")) == 3:
 26        major, minor, patch = version.split(".")
 27        yield f"{int(major) + 1}.0.0"
 28        yield f"{major}.{int(minor) + 1}.0"
 29        yield f"{major}.{minor}.{int(patch) + 1}"
 30    else:
 31        raise ValueError(f"Unable to iterate next version for {version}")
 32
 33
 34class Dependency:
 35    def __init__(self, name: str, **config: str | bool):
 36        self.name = name
 37        self.url = config.get("url", "")
 38        self.installed = config.get("installed", "")
 39        self.filename = config.get("filename", "")
 40        self.sourcemap = config.get("sourcemap", "")
 41
 42    @staticmethod
 43    def parse_version_from_url(url: str) -> str:
 44        if match := re.search(r"\d+\.\d+\.\d+", url):
 45            return match.group(0)
 46
 47        if match := re.search(r"\d+\.\d+", url):
 48            return match.group(0)
 49
 50        return ""
 51
 52    def __str__(self) -> str:
 53        return f"{self.name} -> {self.url}"
 54
 55    def download(self, version: str) -> tuple[str, requests.Response]:
 56        # If the string contains a {version} placeholder, replace it
 57        download_url = self.url.replace("{version}", version)
 58
 59        response = requests.get(download_url)
 60        response.raise_for_status()
 61
 62        content_type = response.headers.get("content-type")
 63        if content_type.lower() not in (
 64            "application/javascript; charset=utf-8",
 65            "text/javascript; charset=utf-8",
 66            "application/json; charset=utf-8",
 67            "text/css; charset=utf-8",
 68        ):
 69            raise UnknownContentTypeError(
 70                f"Unknown content type for {self.name}: {content_type}"
 71            )
 72
 73        # Good chance it will redirect to a more final URL (which we hope is versioned)
 74        url = response.url
 75        version = self.parse_version_from_url(url)
 76
 77        return version, response
 78
 79    def install(self) -> Path:
 80        if self.installed:
 81            version, response = self.download(self.installed)
 82            if version != self.installed:
 83                raise VersionMismatchError(
 84                    f"Version mismatch for {self.name}: {self.installed} != {version}"
 85                )
 86            return self.vendor(response)
 87        else:
 88            return self.update()
 89
 90    def update(self) -> Path:
 91        def try_version(v: str) -> tuple[str, requests.Response | None]:
 92            try:
 93                version, response = self.download(v)
 94                return version, response
 95            except requests.RequestException:
 96                return "", None
 97
 98        if not self.installed:
 99            # If we don't know the installed version yet,
100            # just use the url as given
101            version, response = self.download("")
102        else:
103            version, response = try_version("latest")  # A lot of CDNs support this
104            if not version:
105                # Try the next few versions
106                for v in iter_next_version(self.installed):
107                    version, response = try_version(v)
108                    if version:
109                        break
110
111                    # TODO ideally this would keep going -- if we move to 2.0, and no 3.0, try 2.1, 2.2, etc.
112
113        if not version:
114            # Use the currently installed version if we found nothing else
115            version, response = self.download(self.installed)
116
117        vendored_path = self.vendor(response)
118        self.installed = version
119
120        if self.installed:
121            # If the exact version was in the string, replace it with {version} placeholder
122            self.url = self.url.replace(self.installed, "{version}")
123
124        self.save_config()
125        return vendored_path
126
127    def save_config(self) -> None:
128        with open("pyproject.toml") as f:
129            pyproject = tomlkit.load(f)
130
131        # Force [tool.plain.vendor.dependencies] to be a table
132        dependencies = tomlkit.table()
133        dependencies.update(
134            pyproject.get("tool", {})
135            .get("plain", {})
136            .get("vendor", {})
137            .get("dependencies", {})
138        )
139
140        # Force [tool.plain.vendor.dependencies.{name}] to be an inline table
141        # name = { url = "https://example.com", installed = "1.0.0" }
142        dependencies[self.name] = tomlkit.inline_table()  # type: ignore
143        dependencies[self.name]["url"] = self.url  # type: ignore
144        dependencies[self.name]["installed"] = self.installed  # type: ignore
145        if self.filename:
146            dependencies[self.name]["filename"] = self.filename  # type: ignore
147        if self.sourcemap:
148            dependencies[self.name]["sourcemap"] = self.sourcemap  # type: ignore
149
150        # Have to give it the right structure in case they don't exist
151        if "tool" not in pyproject:  # type: ignore
152            pyproject["tool"] = tomlkit.table()  # type: ignore
153        if "plain" not in pyproject["tool"]:  # type: ignore
154            pyproject["tool"]["plain"] = tomlkit.table()  # type: ignore
155        if "vendor" not in pyproject["tool"]["plain"]:  # type: ignore
156            pyproject["tool"]["plain"]["vendor"] = tomlkit.table()  # type: ignore
157
158        pyproject["tool"]["plain"]["vendor"]["dependencies"] = dependencies  # type: ignore
159
160        with open("pyproject.toml", "w") as f:
161            f.write(tomlkit.dumps(pyproject))
162
163    def vendor(self, response: requests.Response) -> Path:
164        if not VENDOR_DIR.exists():
165            VENDOR_DIR.mkdir(parents=True)
166
167        if self.filename:
168            # Use a specific filename from config
169            filename = self.filename
170        else:
171            # Otherwise, use the filename from the URL
172            filename = response.url.split("/")[-1]  # type: ignore
173            # Remove any query string or fragment
174            filename = filename.split("?")[0].split("#")[0]
175
176        vendored_path = VENDOR_DIR / filename
177
178        with open(vendored_path, "wb") as f:
179            f.write(response.content)
180
181        # If a sourcemap is requested, download it as well
182        if self.sourcemap:
183            if isinstance(self.sourcemap, str):
184                # Use a specific filename from config
185                sourcemap_filename = self.sourcemap
186            else:
187                # Otherwise, append .map to the URL
188                sourcemap_filename = f"{filename}.map"
189
190            sourcemap_url = "/".join(
191                response.url.split("/")[:-1] + [sourcemap_filename]  # type: ignore
192            )
193            sourcemap_response = requests.get(sourcemap_url)
194            sourcemap_response.raise_for_status()
195
196            sourcemap_path = VENDOR_DIR / sourcemap_filename
197
198            with open(sourcemap_path, "wb") as f:
199                f.write(sourcemap_response.content)
200
201        return vendored_path
202
203
def get_deps() -> list[Dependency]:
    """Load every dependency declared in ``pyproject.toml``.

    Reads ``[tool.plain.vendor.dependencies]`` from the ``pyproject.toml`` in
    the current working directory. Each entry's key becomes the dependency
    name and its table is passed to ``Dependency`` as keyword config. Any
    missing level of that table path yields an empty list.
    """
    # TOML documents are UTF-8, so don't rely on the locale's encoding.
    with open("pyproject.toml", encoding="utf-8") as f:
        pyproject = tomlkit.load(f)

    config = (
        pyproject.get("tool", {})
        .get("plain", {})
        .get("vendor", {})
        .get("dependencies", {})
    )

    return [Dependency(name, **data) for name, data in config.items()]