1from __future__ import annotations
  2
import datetime
import os
import pathlib
import socket
import subprocess
import sys
import tempfile
import time
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from plain.test import Client
 20
 21if TYPE_CHECKING:
 22    from playwright.sync_api import Browser, Page  # ty: ignore[unresolved-import]
 23
 24
 25class TestBrowser:
 26    def __init__(self, browser: Browser, database_url: str):
 27        self.browser = browser
 28
 29        self.database_url = database_url
 30        self.protocol = "https"
 31        self.host = "localhost"
 32        self.port = _get_available_port()
 33        self.base_url = f"{self.protocol}://{self.host}:{self.port}"
 34        self.server_process: subprocess.Popen | None = None
 35        self.tmpdir = tempfile.TemporaryDirectory()
 36
 37        # Set the initial browser context
 38        self.reset_context()
 39
 40    def force_login(self, user: object) -> None:
 41        # Make sure existing session cookies are cleared
 42        self.context.clear_cookies()
 43
 44        client = Client()
 45        client.force_login(user)
 46
 47        cookies = []
 48
 49        for morsel in client.cookies.values():
 50            cookie = {
 51                "name": morsel.key,
 52                "value": morsel.value,
 53                # Set this by default because playwright needs url or domain/path pair
 54                # (Plain does this in response, but this isn't going through a response)
 55                "domain": self.host,
 56            }
 57            # These fields are all optional
 58            if url := morsel.get("url"):
 59                cookie["url"] = url
 60            if domain := morsel.get("domain"):
 61                cookie["domain"] = domain
 62            if path := morsel.get("path"):
 63                cookie["path"] = path
 64            if expires := morsel.get("expires"):
 65                cookie["expires"] = expires
 66            if httponly := morsel.get("httponly"):
 67                cookie["httpOnly"] = httponly
 68            if secure := morsel.get("secure"):
 69                cookie["secure"] = secure
 70            if samesite := morsel.get("samesite"):
 71                cookie["sameSite"] = samesite
 72
 73            cookies.append(cookie)
 74
 75        self.context.add_cookies(cookies)
 76
 77    def logout(self) -> None:
 78        self.context.clear_cookies()
 79
 80    def reset_context(self) -> None:
 81        """Create a new browser context with the base URL and ignore HTTPS errors."""
 82        self.context = self.browser.new_context(
 83            base_url=self.base_url,
 84            ignore_https_errors=True,
 85        )
 86
 87    def new_page(self) -> Page:
 88        """Create a new page in the current context."""
 89        return self.context.new_page()
 90
 91    def discover_urls(self, urls: list[str]) -> list[str]:
 92        """Recursively discover all URLs on the page and related pages until we don't see anything new"""
 93
 94        def relative_url(url: str) -> str:
 95            """Convert a URL to a relative URL based on the base URL."""
 96            return url.removeprefix(self.base_url)
 97
 98        # Start with the initial URLs
 99        to_visit = {relative_url(url) for url in urls}
100        visited = set()
101
102        # Create a new page to use for all crawling
103        page = self.context.new_page()
104
105        while to_visit:
106            # Move the url from to_visit to visited
107            url = to_visit.pop()
108
109            response = page.goto(url)
110
111            visited.add(url)
112
113            # Don't process links that aren't on our site
114            if not response.url.startswith(self.base_url):
115                continue
116
117            # Get the current page's path for resolving relative URLs
118            current_page_path = response.url.removeprefix(self.base_url)
119
120            # Find all <a> links on the page
121            for link in page.query_selector_all("a"):
122                if href := link.get_attribute("href"):
123                    # Remove fragments
124                    href = href.split("#")[0]
125                    if not href:
126                        # Empty URL, skip it
127                        continue
128
129                    parsed = urlparse(href)
130                    # Skip non-http(s) links (mailto:, tel:, javascript:, etc.)
131                    if parsed.scheme and parsed.scheme not in ("http", "https"):
132                        continue
133
134                    # Skip external HTTP links
135                    if parsed.scheme in ("http", "https") and not href.startswith(
136                        self.base_url
137                    ):
138                        continue
139
140                    # Handle query-only URLs (e.g., "?stage=approved")
141                    if href.startswith("?"):
142                        href = current_page_path.split("?")[0] + href
143
144                    visit_url = relative_url(href)
145                    if visit_url not in visited:
146                        to_visit.add(visit_url)
147
148        page.close()
149
150        return list(visited)
151
152    def generate_certificates(self) -> tuple[str, str]:
153        """Generate self-signed certificates for HTTPS."""
154
155        # Generate private key
156        private_key = rsa.generate_private_key(
157            public_exponent=65537,
158            key_size=2048,
159        )
160
161        # Create certificate
162        subject = issuer = x509.Name(
163            [
164                x509.NameAttribute(NameOID.COMMON_NAME, self.host),
165            ]
166        )
167
168        cert = (
169            x509.CertificateBuilder()
170            .subject_name(subject)
171            .issuer_name(issuer)
172            .public_key(private_key.public_key())
173            .serial_number(x509.random_serial_number())
174            .not_valid_before(datetime.datetime.now(datetime.UTC))
175            .not_valid_after(
176                datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=365)
177            )
178            .add_extension(
179                x509.SubjectAlternativeName(
180                    [
181                        x509.DNSName(self.host),
182                    ]
183                ),
184                critical=False,
185            )
186            .sign(private_key, hashes.SHA256())
187        )
188
189        # Write certificate and key to files
190        cert_file = pathlib.Path(self.tmpdir.name) / "cert.pem"
191        key_file = pathlib.Path(self.tmpdir.name) / "key.pem"
192
193        with open(cert_file, "wb") as f:
194            f.write(cert.public_bytes(serialization.Encoding.PEM))
195
196        with open(key_file, "wb") as f:
197            f.write(
198                private_key.private_bytes(
199                    encoding=serialization.Encoding.PEM,
200                    format=serialization.PrivateFormat.PKCS8,
201                    encryption_algorithm=serialization.NoEncryption(),
202                )
203            )
204
205        return str(cert_file), str(key_file)
206
207    def run_server(self) -> None:
208        cert_file, key_file = self.generate_certificates()
209
210        env = os.environ.copy()
211
212        if self.database_url:
213            env["DATABASE_URL"] = self.database_url
214
215        self.server_process = subprocess.Popen(
216            [
217                sys.executable,
218                "-m",
219                "plain",
220                "server",
221                "--bind",
222                f"{self.host}:{self.port}",
223                "--certfile",
224                cert_file,
225                "--keyfile",
226                key_file,
227                "--workers",
228                "2",
229                "--timeout",
230                "10",
231                "--log-level",
232                "warning",
233            ],
234            env=env,
235        )
236
237        self._wait_for_server()
238
239    def _wait_for_server(self, timeout: float = 10.0, interval: float = 0.1) -> None:
240        """Wait until the server is accepting connections."""
241        deadline = time.monotonic() + timeout
242        while time.monotonic() < deadline:
243            # Check that the server process hasn't crashed
244            if self.server_process and self.server_process.poll() is not None:
245                raise RuntimeError(
246                    f"Server process exited with code {self.server_process.returncode}"
247                )
248            try:
249                with socket.create_connection((self.host, self.port), timeout=interval):
250                    return
251            except OSError:
252                time.sleep(interval)
253        raise RuntimeError(
254            f"Server did not start within {timeout}s at {self.host}:{self.port}"
255        )
256
257    def cleanup_server(self) -> None:
258        if self.server_process:
259            self.server_process.terminate()
260            self.server_process.wait()
261            self.server_process = None
262
263        self.tmpdir.cleanup()
264
265
def _get_available_port() -> int:
    """Return a TCP port number that the OS currently reports as free.

    An IPv4 socket is bound to port 0 on all interfaces, and the kernel's
    assigned port is read back. The probe socket is closed before returning,
    so another process could still claim the port before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.listen(1)
        _, port = probe.getsockname()
    return port