Plain is headed towards 1.0! Subscribe for development updates →

import datetime
import email.utils
import os
import pathlib
import socket
import subprocess
import sys
import tempfile
import time
from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from plain.test import Client

if TYPE_CHECKING:
    from playwright.sync_api import Browser
 21
 22
 23class TestBrowser:
 24    def __init__(self, browser: "Browser", database_url: str):
 25        self.browser = browser
 26
 27        self.database_url = database_url
 28        self.wsgi = "plain.wsgi:app"
 29        self.protocol = "https"
 30        self.host = "localhost"
 31        self.port = _get_available_port()
 32        self.base_url = f"{self.protocol}://{self.host}:{self.port}"
 33        self.server_process = None
 34        self.tmpdir = tempfile.TemporaryDirectory()
 35
 36        # Set the initial browser context
 37        self.reset_context()
 38
 39    def force_login(self, user):
 40        # Make sure existing session cookies are cleared
 41        self.context.clear_cookies()
 42
 43        client = Client()
 44        client.force_login(user)
 45
 46        cookies = []
 47
 48        for morsel in client.cookies.values():
 49            cookie = {
 50                "name": morsel.key,
 51                "value": morsel.value,
 52                # Set this by default because playwright needs url or domain/path pair
 53                # (Plain does this in response, but this isn't going through a response)
 54                "domain": self.host,
 55            }
 56            # These fields are all optional
 57            if url := morsel.get("url"):
 58                cookie["url"] = url
 59            if domain := morsel.get("domain"):
 60                cookie["domain"] = domain
 61            if path := morsel.get("path"):
 62                cookie["path"] = path
 63            if expires := morsel.get("expires"):
 64                cookie["expires"] = expires
 65            if httponly := morsel.get("httponly"):
 66                cookie["httpOnly"] = httponly
 67            if secure := morsel.get("secure"):
 68                cookie["secure"] = secure
 69            if samesite := morsel.get("samesite"):
 70                cookie["sameSite"] = samesite
 71
 72            cookies.append(cookie)
 73
 74        self.context.add_cookies(cookies)
 75
 76    def logout(self):
 77        self.context.clear_cookies()
 78
 79    def reset_context(self):
 80        """Create a new browser context with the base URL and ignore HTTPS errors."""
 81        self.context = self.browser.new_context(
 82            base_url=self.base_url,
 83            ignore_https_errors=True,
 84        )
 85
 86    def new_page(self):
 87        """Create a new page in the current context."""
 88        return self.context.new_page()
 89
 90    def discover_urls(self, *urls) -> list[str]:
 91        """Recursively discover all URLs on the page and related pages until we don't see anything new"""
 92
 93        def relative_url(url: str) -> str:
 94            """Convert a URL to a relative URL based on the base URL."""
 95            if url.startswith(self.base_url):
 96                return url[len(self.base_url) :]
 97            return url
 98
 99        # Start with the initial URLs
100        to_visit = {relative_url(url) for url in urls}
101        visited = set()
102
103        # Create a new page to use for all crawling
104        page = self.context.new_page()
105
106        while to_visit:
107            # Move the url from to_visit to visited
108            url = to_visit.pop()
109
110            response = page.goto(url)
111
112            visited.add(url)
113
114            # Don't process links that aren't on our site
115            if not response.url.startswith(self.base_url):
116                continue
117
118            # Find all <a> links on the page
119            for link in page.query_selector_all("a"):
120                if href := link.get_attribute("href"):
121                    # Remove fragments
122                    href = href.split("#")[0]
123                    if not href:
124                        # Empty URL, skip it
125                        continue
126
127                    parsed = urlparse(href)
128                    # Skip non-http(s) links (mailto:, tel:, javascript:, etc.)
129                    if parsed.scheme and parsed.scheme not in ("http", "https"):
130                        continue
131
132                    # Skip external HTTP links
133                    if parsed.scheme in ("http", "https") and not href.startswith(
134                        self.base_url
135                    ):
136                        continue
137
138                    visit_url = relative_url(href)
139                    if visit_url not in visited:
140                        to_visit.add(visit_url)
141
142        page.close()
143
144        return list(visited)
145
146    def generate_certificates(self):
147        """Generate self-signed certificates for HTTPS."""
148
149        # Generate private key
150        private_key = rsa.generate_private_key(
151            public_exponent=65537,
152            key_size=2048,
153        )
154
155        # Create certificate
156        subject = issuer = x509.Name(
157            [
158                x509.NameAttribute(NameOID.COMMON_NAME, self.host),
159            ]
160        )
161
162        cert = (
163            x509.CertificateBuilder()
164            .subject_name(subject)
165            .issuer_name(issuer)
166            .public_key(private_key.public_key())
167            .serial_number(x509.random_serial_number())
168            .not_valid_before(datetime.datetime.now(datetime.UTC))
169            .not_valid_after(
170                datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=365)
171            )
172            .add_extension(
173                x509.SubjectAlternativeName(
174                    [
175                        x509.DNSName(self.host),
176                    ]
177                ),
178                critical=False,
179            )
180            .sign(private_key, hashes.SHA256())
181        )
182
183        # Write certificate and key to files
184        cert_file = pathlib.Path(self.tmpdir.name) / "cert.pem"
185        key_file = pathlib.Path(self.tmpdir.name) / "key.pem"
186
187        with open(cert_file, "wb") as f:
188            f.write(cert.public_bytes(serialization.Encoding.PEM))
189
190        with open(key_file, "wb") as f:
191            f.write(
192                private_key.private_bytes(
193                    encoding=serialization.Encoding.PEM,
194                    format=serialization.PrivateFormat.PKCS8,
195                    encryption_algorithm=serialization.NoEncryption(),
196                )
197            )
198
199        return str(cert_file), str(key_file)
200
201    def run_server(self):
202        cert_file, key_file = self.generate_certificates()
203
204        env = os.environ.copy()
205
206        if self.database_url:
207            env["DATABASE_URL"] = self.database_url
208
209        gunicorn = pathlib.Path(sys.executable).with_name("gunicorn")
210
211        self.server_process = subprocess.Popen(
212            [
213                str(gunicorn),
214                self.wsgi,
215                "--bind",
216                f"{self.host}:{self.port}",
217                "--certfile",
218                cert_file,
219                "--keyfile",
220                key_file,
221                "--workers",
222                "2",
223                "--timeout",
224                "10",
225                "--log-level",
226                "warning",
227            ],
228            env=env,
229        )
230
231        time.sleep(0.7)  # quick grace period
232
233    def cleanup_server(self):
234        if self.server_process:
235            self.server_process.terminate()
236            self.server_process.wait()
237            self.server_process = None
238
239        self.tmpdir.cleanup()
240
241
def _get_available_port() -> int:
    """Ask the OS for a free TCP port and return its number.

    The probe socket is closed before returning, so the port is only free at
    the moment of the check; another process could still take it first.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the kernel choose any unused port.
        probe.bind(("", 0))
        probe.listen(1)
        _, chosen_port = probe.getsockname()
    finally:
        probe.close()
    return chosen_port