1"""E2E encryption for portal sessions using SPAKE2 + NaCl.
 2
Both sides derive a shared secret from the human-readable portal code
via SPAKE2 (a password-authenticated key exchange). The SPAKE2 messages
themselves do not let an eavesdropper brute-force the code offline; the
SHA-256 channel ID described below, however, can be brute-forced if the
code space is small.
 6
 7All messages after the key exchange are encrypted with NaCl SecretBox
 8(XSalsa20-Poly1305).
 9
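The portal code is never sent to the relay. Instead, a SHA-256 hash
of the code is used as the channel ID for pairing, and the raw code is
only used locally for the SPAKE2 exchange. Because that hash is fast and
unsalted, it hides the code only as well as the code resists guessing:
a relay that enumerates a small code space can recover the code from the
channel ID and then impersonate either side in the SPAKE2 exchange.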
13"""
14
15from __future__ import annotations
16
17import base64
18import hashlib
19import json
20
21import nacl.secret
22import spake2
23from websockets.asyncio.client import ClientConnection
24
25
def channel_id(code: str) -> str:
    """Return the relay channel ID for *code*: hex SHA-256 of its UTF-8 bytes.

    Both peers compute the same value from the same code, so the relay can
    pair them without the raw code ever being transmitted.

    Caveat: this is a fast, unsalted hash of a human-readable code. If the
    code space is small, anyone who sees the channel ID (including the
    relay) can recover the code by hashing every candidate. With the code in
    hand they could run SPAKE2 as either side. The hash therefore protects
    the code only to the extent that the code itself resists enumeration.
    """
    hasher = hashlib.sha256()
    hasher.update(code.encode("utf-8"))
    return hasher.hexdigest()
33
34
async def perform_key_exchange(
    ws: ClientConnection, code: str, *, side: str
) -> PortalEncryptor:
    """Run the SPAKE2 handshake over a WebSocket and return an encryptor.

    Each side sends one base64-encoded SPAKE2 message and reads one back,
    then derives the shared key from the pair.

    Args:
        ws: Open WebSocket carrying the relay channel shared with the peer.
        code: Human-readable portal code, used as the SPAKE2 password. Both
            peers must supply the same code.
        side: ``"start"`` (SPAKE2_A / initiator) or ``"connect"``
            (SPAKE2_B / joiner). The two peers must use opposite sides.

    Returns:
        A ``PortalEncryptor`` keyed with the SPAKE2 shared secret.

    Raises:
        ValueError: if *side* is not ``"start"`` or ``"connect"``.
        binascii.Error: if the peer's message is not valid base64.

    Exceptions raised by ``spake2`` during ``finish`` propagate unchanged.
    A mismatched code is not detected here. SPAKE2 still completes, but the
    two sides derive different keys, so the failure appears on the first
    decrypt instead.
    """
    if side == "start":
        spake_cls = spake2.SPAKE2_A
    elif side == "connect":
        spake_cls = spake2.SPAKE2_B
    else:
        # Reject anything else rather than defaulting to SPAKE2_B. A typo
        # could otherwise put both peers on the same side and fail later
        # with a far less obvious error.
        raise ValueError(f'side must be "start" or "connect", got {side!r}')

    spake_instance = spake_cls(code.encode("utf-8"))
    outbound = spake_instance.start()
    await ws.send(base64.b64encode(outbound).decode("ascii"))

    # The frame arrives via the relay, so it is untrusted. validate=True
    # rejects stray non-alphabet characters instead of silently discarding
    # them before SPAKE2 sees the message.
    peer_msg = base64.b64decode(await ws.recv(), validate=True)
    key = spake_instance.finish(peer_msg)
    return PortalEncryptor(key)
50
51
class PortalEncryptor:
    """Symmetric encryption for portal messages under a shared key.

    Wraps a NaCl ``SecretBox`` (XSalsa20-Poly1305). Every ciphertext is
    self-contained: ``SecretBox.encrypt`` generates a fresh random nonce and
    prepends it, and ``SecretBox.decrypt`` reads it back off the front.
    """

    def __init__(self, key: bytes) -> None:
        """Create an encryptor from *key*.

        Args:
            key: Shared secret. ``SecretBox`` requires exactly 32 bytes and
                rejects any other length. The SPAKE2 output used by this
                module is expected to be that size.
        """
        self._box = nacl.secret.SecretBox(key)

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt *data* and return ``nonce + ciphertext``."""
        return self._box.encrypt(data)

    def decrypt(self, data: bytes) -> bytes:
        """Decrypt ``nonce + ciphertext`` as produced by ``encrypt``.

        Raises:
            nacl.exceptions.CryptoError: if the data was tampered with or was
                encrypted under a different key (e.g. the peers used
                different portal codes).
        """
        return self._box.decrypt(data)

    def encrypt_message(self, msg: dict) -> bytes:
        """Serialize *msg* as UTF-8 JSON and encrypt it."""
        plaintext = json.dumps(msg).encode("utf-8")
        return self.encrypt(plaintext)

    def decrypt_message(self, data: bytes) -> dict:
        """Decrypt *data* and parse the plaintext as a JSON object.

        Raises:
            nacl.exceptions.CryptoError: if authentication fails.
            ValueError: if the plaintext is not valid UTF-8 JSON, or if the
                JSON value is not an object. (``json.JSONDecodeError`` and
                ``UnicodeDecodeError`` are both ``ValueError`` subclasses, so
                one ``except ValueError`` covers every malformed-payload case.)
        """
        plaintext = self.decrypt(data)
        parsed = json.loads(plaintext.decode("utf-8"))
        if not isinstance(parsed, dict):
            # Authentication only proves the peer sent this; it does not prove
            # the payload is well-formed. Honour the ``-> dict`` contract
            # instead of handing callers a list or scalar.
            raise ValueError(
                f"expected a JSON object, got {type(parsed).__name__}"
            )
        return parsed