1"""E2E encryption for portal sessions using SPAKE2 + NaCl.
2
3Both sides derive a shared secret from the human-readable portal code
4via SPAKE2 (a password-authenticated key exchange). An eavesdropper
5observing the relay traffic cannot brute-force the code offline.
6
7All messages after the key exchange are encrypted with NaCl SecretBox
8(XSalsa20-Poly1305).
9
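The portal code is never sent to the relay. Instead, a SHA-256 hash
of the code is used as the channel ID for pairing. The raw code is
only used locally for the SPAKE2 exchange. Note that the hash is
unsalted, so anyone who sees the channel ID can test guesses of the
code offline; the code needs enough entropy to resist such guessing.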
13"""
14
15from __future__ import annotations
16
17import base64
18import hashlib
19import json
20
21import nacl.secret
22import spake2
23from websockets.asyncio.client import ClientConnection
24
25
def channel_id(code: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded portal code.

    The digest is used in place of the code itself, so the raw code does
    not appear in the returned value.

    NOTE(review): the hash is unsalted, so whoever sees the digest can
    check guesses of the code offline. This presumably relies on the code
    having enough entropy; confirm against how codes are generated.
    """
    hasher = hashlib.sha256()
    hasher.update(code.encode("utf-8"))
    return hasher.hexdigest()
33
34
async def perform_key_exchange(
    ws: ClientConnection, code: str, *, side: str
) -> PortalEncryptor:
    """Run the SPAKE2 handshake over a WebSocket and return an encryptor.

    `side` must be "start" (SPAKE2_A / initiator) or "connect" (SPAKE2_B /
    joiner); any other value raises ValueError before anything is sent.

    Each peer sends its own SPAKE2 message as one base64 text message,
    then reads the peer's message from `ws` and derives the shared key,
    which is wrapped in a PortalEncryptor.

    NOTE(review): nothing here confirms that both peers derived the same
    key. A mismatched code presumably only shows up when the first
    encrypted message fails to decrypt; confirm that callers handle it.
    """
    # Reject unknown roles up front. Treating every unrecognised value as
    # the joiner would let a typo put both peers on the same SPAKE2 side,
    # which only surfaces later as an opaque handshake failure.
    if side not in ("start", "connect"):
        raise ValueError(f'side must be "start" or "connect", got {side!r}')

    spake_cls = spake2.SPAKE2_A if side == "start" else spake2.SPAKE2_B
    spake_instance = spake_cls(code.encode("utf-8"))
    outbound = spake_instance.start()

    # Base64 keeps the binary SPAKE2 message safe to send as a text frame.
    await ws.send(base64.b64encode(outbound).decode("ascii"))

    # b64decode accepts ASCII str as well as bytes, so either frame type
    # received from the peer can be decoded.
    inbound = base64.b64decode(await ws.recv())
    key = spake_instance.finish(inbound)
    return PortalEncryptor(key)
50
51
class PortalEncryptor:
    """Symmetric encryption for portal traffic, backed by one NaCl SecretBox."""

    def __init__(self, key: bytes) -> None:
        # The SPAKE2 output is 32 bytes long, which matches the key size
        # SecretBox requires, so it is used as the key directly.
        box = nacl.secret.SecretBox(key)
        self._box = box

    def encrypt(self, data: bytes) -> bytes:
        """Seal raw bytes. The result is the nonce followed by the ciphertext."""
        sealed = self._box.encrypt(data)
        return sealed

    def decrypt(self, data: bytes) -> bytes:
        """Open a nonce + ciphertext blob and return the plaintext bytes."""
        opened = self._box.decrypt(data)
        return opened

    def encrypt_message(self, msg: dict) -> bytes:
        """Serialize a message dict to UTF-8 JSON and encrypt it."""
        return self.encrypt(json.dumps(msg).encode("utf-8"))

    def decrypt_message(self, data: bytes) -> dict:
        """Decrypt a blob and parse the UTF-8 JSON it contains."""
        return json.loads(self.decrypt(data).decode("utf-8"))