from __future__ import annotations

import datetime
import os
import pathlib
import socket
import subprocess
import sys
import tempfile
import time
from email.utils import parsedate_to_datetime
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from plain.test import Client

if TYPE_CHECKING:
    from playwright.sync_api import Browser, Page  # ty: ignore[unresolved-import]
23
24
class TestBrowser:
    """Drive a Playwright browser against a live HTTPS dev server.

    Owns the server subprocess lifecycle (self-signed certs, startup wait,
    teardown), browser context management, cookie-based login, and a simple
    same-site link crawler.
    """

    def __init__(self, browser: Browser, database_url: str):
        self.browser = browser

        self.database_url = database_url
        self.protocol = "https"
        self.host = "localhost"
        self.port = _get_available_port()
        self.base_url = f"{self.protocol}://{self.host}:{self.port}"
        self.server_process: subprocess.Popen | None = None
        # Holds the generated TLS cert/key; cleaned up in cleanup_server().
        self.tmpdir = tempfile.TemporaryDirectory()

        # Set the initial browser context
        self.reset_context()

    def force_login(self, user: object) -> None:
        """Log *user* into the browser by copying session cookies from a test Client."""
        # Make sure existing session cookies are cleared
        self.context.clear_cookies()

        client = Client()
        client.force_login(user)

        cookies = []

        for morsel in client.cookies.values():
            cookie = {
                "name": morsel.key,
                "value": morsel.value,
                # Set this by default because playwright needs url or domain/path pair
                # (Plain does this in response, but this isn't going through a response)
                "domain": self.host,
            }
            # These fields are all optional
            if url := morsel.get("url"):
                cookie["url"] = url
            if domain := morsel.get("domain"):
                cookie["domain"] = domain
            if path := morsel.get("path"):
                cookie["path"] = path
            if expires := morsel.get("expires"):
                # Morsels store `expires` as an HTTP date string, but Playwright
                # expects a Unix timestamp (float). Convert; if the value is
                # unparseable, omit it and let it behave as a session cookie.
                try:
                    cookie["expires"] = parsedate_to_datetime(expires).timestamp()
                except (TypeError, ValueError):
                    pass
            if httponly := morsel.get("httponly"):
                cookie["httpOnly"] = bool(httponly)
            if secure := morsel.get("secure"):
                cookie["secure"] = bool(secure)
            if samesite := morsel.get("samesite"):
                # Playwright validates sameSite against "Strict" | "Lax" | "None"
                # (capitalized); morsels store lowercase values like "lax".
                cookie["sameSite"] = samesite.capitalize()

            cookies.append(cookie)

        self.context.add_cookies(cookies)

    def logout(self) -> None:
        """Log out by clearing all cookies (including the session cookie)."""
        self.context.clear_cookies()

    def reset_context(self) -> None:
        """Create a new browser context with the base URL and ignore HTTPS errors."""
        self.context = self.browser.new_context(
            base_url=self.base_url,
            # The server uses a self-signed certificate (see generate_certificates).
            ignore_https_errors=True,
        )

    def new_page(self) -> Page:
        """Create a new page in the current context."""
        return self.context.new_page()

    def discover_urls(self, urls: list[str]) -> list[str]:
        """Recursively discover all URLs on the page and related pages until we don't see anything new"""

        def relative_url(url: str) -> str:
            """Convert a URL to a relative URL based on the base URL."""
            return url.removeprefix(self.base_url)

        # Start with the initial URLs
        to_visit = {relative_url(url) for url in urls}
        visited = set()

        # Create a new page to use for all crawling
        page = self.context.new_page()

        try:
            while to_visit:
                # Move the url from to_visit to visited
                url = to_visit.pop()

                response = page.goto(url)

                visited.add(url)

                # goto() can return None (e.g. same-document navigations);
                # nothing to crawl in that case.
                if response is None:
                    continue

                # Don't process links that aren't on our site
                if not response.url.startswith(self.base_url):
                    continue

                # Get the current page's path for resolving relative URLs
                current_page_path = response.url.removeprefix(self.base_url)

                # Find all <a> links on the page
                for link in page.query_selector_all("a"):
                    if href := link.get_attribute("href"):
                        # Remove fragments
                        href = href.split("#")[0]
                        if not href:
                            # Empty URL, skip it
                            continue

                        parsed = urlparse(href)
                        # Skip non-http(s) links (mailto:, tel:, javascript:, etc.)
                        if parsed.scheme and parsed.scheme not in ("http", "https"):
                            continue

                        # Skip external HTTP links
                        if parsed.scheme in ("http", "https") and not href.startswith(
                            self.base_url
                        ):
                            continue

                        # Handle query-only URLs (e.g., "?stage=approved")
                        if href.startswith("?"):
                            href = current_page_path.split("?")[0] + href

                        visit_url = relative_url(href)
                        if visit_url not in visited:
                            to_visit.add(visit_url)
        finally:
            # Always close the crawl page, even if a navigation raised.
            page.close()

        return list(visited)

    def generate_certificates(self) -> tuple[str, str]:
        """Generate self-signed certificates for HTTPS.

        Returns a (cert_path, key_path) tuple; files live in self.tmpdir.
        """

        # Generate private key
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )

        # Self-signed: subject and issuer are the same name
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COMMON_NAME, self.host),
            ]
        )

        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(datetime.datetime.now(datetime.UTC))
            .not_valid_after(
                datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=365)
            )
            .add_extension(
                # Modern clients validate the SAN, not the CN
                x509.SubjectAlternativeName(
                    [
                        x509.DNSName(self.host),
                    ]
                ),
                critical=False,
            )
            .sign(private_key, hashes.SHA256())
        )

        # Write certificate and key to files
        cert_file = pathlib.Path(self.tmpdir.name) / "cert.pem"
        key_file = pathlib.Path(self.tmpdir.name) / "key.pem"

        with open(cert_file, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))

        with open(key_file, "wb") as f:
            f.write(
                private_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.PKCS8,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            )

        return str(cert_file), str(key_file)

    def run_server(self) -> None:
        """Start the HTTPS dev server subprocess and block until it accepts connections."""
        cert_file, key_file = self.generate_certificates()

        env = os.environ.copy()

        if self.database_url:
            env["DATABASE_URL"] = self.database_url

        self.server_process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "plain",
                "server",
                "--bind",
                f"{self.host}:{self.port}",
                "--certfile",
                cert_file,
                "--keyfile",
                key_file,
                "--workers",
                "2",
                "--timeout",
                "10",
                "--log-level",
                "warning",
            ],
            env=env,
        )

        self._wait_for_server()

    def _wait_for_server(self, timeout: float = 10.0, interval: float = 0.1) -> None:
        """Wait until the server is accepting connections.

        Raises RuntimeError if the process exits early or the deadline passes.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            # Check that the server process hasn't crashed
            if self.server_process and self.server_process.poll() is not None:
                raise RuntimeError(
                    f"Server process exited with code {self.server_process.returncode}"
                )
            try:
                with socket.create_connection((self.host, self.port), timeout=interval):
                    return
            except OSError:
                time.sleep(interval)
        raise RuntimeError(
            f"Server did not start within {timeout}s at {self.host}:{self.port}"
        )

    def cleanup_server(self) -> None:
        """Stop the server process (if any) and remove the temp cert directory."""
        if self.server_process:
            self.server_process.terminate()
            try:
                # Bounded wait so a server that ignores SIGTERM can't hang teardown
                self.server_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                self.server_process.kill()
                self.server_process.wait()
            self.server_process = None

        self.tmpdir.cleanup()
264
265
266def _get_available_port() -> int:
267 """Get a randomly available port."""
268 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
269 s.bind(("", 0))
270 s.listen(1)
271 return s.getsockname()[1]