1import datetime
2import os
3import pathlib
4import socket
5import subprocess
6import sys
7import tempfile
8import time
9from typing import TYPE_CHECKING
10from urllib.parse import urlparse
11
12from cryptography import x509
13from cryptography.hazmat.primitives import hashes, serialization
14from cryptography.hazmat.primitives.asymmetric import rsa
15from cryptography.x509.oid import NameOID
16
17from plain.test import Client
18
19if TYPE_CHECKING:
20 from playwright.sync_api import Browser
21
22
23class TestBrowser:
24 def __init__(self, browser: "Browser", database_url: str):
25 self.browser = browser
26
27 self.database_url = database_url
28 self.wsgi = "plain.wsgi:app"
29 self.protocol = "https"
30 self.host = "localhost"
31 self.port = _get_available_port()
32 self.base_url = f"{self.protocol}://{self.host}:{self.port}"
33 self.server_process = None
34 self.tmpdir = tempfile.TemporaryDirectory()
35
36 # Set the initial browser context
37 self.reset_context()
38
39 def force_login(self, user):
40 # Make sure existing session cookies are cleared
41 self.context.clear_cookies()
42
43 client = Client()
44 client.force_login(user)
45
46 cookies = []
47
48 for morsel in client.cookies.values():
49 cookie = {
50 "name": morsel.key,
51 "value": morsel.value,
52 # Set this by default because playwright needs url or domain/path pair
53 # (Plain does this in response, but this isn't going through a response)
54 "domain": self.host,
55 }
56 # These fields are all optional
57 if url := morsel.get("url"):
58 cookie["url"] = url
59 if domain := morsel.get("domain"):
60 cookie["domain"] = domain
61 if path := morsel.get("path"):
62 cookie["path"] = path
63 if expires := morsel.get("expires"):
64 cookie["expires"] = expires
65 if httponly := morsel.get("httponly"):
66 cookie["httpOnly"] = httponly
67 if secure := morsel.get("secure"):
68 cookie["secure"] = secure
69 if samesite := morsel.get("samesite"):
70 cookie["sameSite"] = samesite
71
72 cookies.append(cookie)
73
74 self.context.add_cookies(cookies)
75
76 def logout(self):
77 self.context.clear_cookies()
78
79 def reset_context(self):
80 """Create a new browser context with the base URL and ignore HTTPS errors."""
81 self.context = self.browser.new_context(
82 base_url=self.base_url,
83 ignore_https_errors=True,
84 )
85
86 def new_page(self):
87 """Create a new page in the current context."""
88 return self.context.new_page()
89
90 def discover_urls(self, *urls) -> list[str]:
91 """Recursively discover all URLs on the page and related pages until we don't see anything new"""
92
93 def relative_url(url: str) -> str:
94 """Convert a URL to a relative URL based on the base URL."""
95 if url.startswith(self.base_url):
96 return url[len(self.base_url) :]
97 return url
98
99 # Start with the initial URLs
100 to_visit = {relative_url(url) for url in urls}
101 visited = set()
102
103 # Create a new page to use for all crawling
104 page = self.context.new_page()
105
106 while to_visit:
107 # Move the url from to_visit to visited
108 url = to_visit.pop()
109
110 response = page.goto(url)
111
112 visited.add(url)
113
114 # Don't process links that aren't on our site
115 if not response.url.startswith(self.base_url):
116 continue
117
118 # Find all <a> links on the page
119 for link in page.query_selector_all("a"):
120 if href := link.get_attribute("href"):
121 # Remove fragments
122 href = href.split("#")[0]
123 if not href:
124 # Empty URL, skip it
125 continue
126
127 parsed = urlparse(href)
128 # Skip non-http(s) links (mailto:, tel:, javascript:, etc.)
129 if parsed.scheme and parsed.scheme not in ("http", "https"):
130 continue
131
132 # Skip external HTTP links
133 if parsed.scheme in ("http", "https") and not href.startswith(
134 self.base_url
135 ):
136 continue
137
138 visit_url = relative_url(href)
139 if visit_url not in visited:
140 to_visit.add(visit_url)
141
142 page.close()
143
144 return list(visited)
145
146 def generate_certificates(self):
147 """Generate self-signed certificates for HTTPS."""
148
149 # Generate private key
150 private_key = rsa.generate_private_key(
151 public_exponent=65537,
152 key_size=2048,
153 )
154
155 # Create certificate
156 subject = issuer = x509.Name(
157 [
158 x509.NameAttribute(NameOID.COMMON_NAME, self.host),
159 ]
160 )
161
162 cert = (
163 x509.CertificateBuilder()
164 .subject_name(subject)
165 .issuer_name(issuer)
166 .public_key(private_key.public_key())
167 .serial_number(x509.random_serial_number())
168 .not_valid_before(datetime.datetime.now(datetime.UTC))
169 .not_valid_after(
170 datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=365)
171 )
172 .add_extension(
173 x509.SubjectAlternativeName(
174 [
175 x509.DNSName(self.host),
176 ]
177 ),
178 critical=False,
179 )
180 .sign(private_key, hashes.SHA256())
181 )
182
183 # Write certificate and key to files
184 cert_file = pathlib.Path(self.tmpdir.name) / "cert.pem"
185 key_file = pathlib.Path(self.tmpdir.name) / "key.pem"
186
187 with open(cert_file, "wb") as f:
188 f.write(cert.public_bytes(serialization.Encoding.PEM))
189
190 with open(key_file, "wb") as f:
191 f.write(
192 private_key.private_bytes(
193 encoding=serialization.Encoding.PEM,
194 format=serialization.PrivateFormat.PKCS8,
195 encryption_algorithm=serialization.NoEncryption(),
196 )
197 )
198
199 return str(cert_file), str(key_file)
200
201 def run_server(self):
202 cert_file, key_file = self.generate_certificates()
203
204 env = os.environ.copy()
205
206 if self.database_url:
207 env["DATABASE_URL"] = self.database_url
208
209 gunicorn = pathlib.Path(sys.executable).with_name("gunicorn")
210
211 self.server_process = subprocess.Popen(
212 [
213 str(gunicorn),
214 self.wsgi,
215 "--bind",
216 f"{self.host}:{self.port}",
217 "--certfile",
218 cert_file,
219 "--keyfile",
220 key_file,
221 "--workers",
222 "2",
223 "--timeout",
224 "10",
225 "--log-level",
226 "warning",
227 ],
228 env=env,
229 )
230
231 time.sleep(0.7) # quick grace period
232
233 def cleanup_server(self):
234 if self.server_process:
235 self.server_process.terminate()
236 self.server_process.wait()
237 self.server_process = None
238
239 self.tmpdir.cleanup()
240
241
242def _get_available_port() -> int:
243 """Get a randomly available port."""
244 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
245 s.bind(("", 0))
246 s.listen(1)
247 port = s.getsockname()[1]
248 return port