1from __future__ import annotations
2
3import base64
4import hashlib
5import secrets
6from urllib.parse import urlparse, urlunparse
7
8from plain import postgres
9from plain.postgres import types
10from plain.utils import timezone
11
# Public API of this module: the four OAuth models defined below.
__all__ = ["OAuthApplication", "AuthorizationCode", "AccessToken", "RefreshToken"]
18
19_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
20
21
22def _generate_token() -> str:
23 """A new opaque bearer token, handed to the client once and never stored."""
24 return secrets.token_urlsafe(32)
25
26
27def _hash_token(token: str) -> str:
28 """What we persist for access/refresh tokens — a DB leak can't be replayed."""
29 return hashlib.sha256(token.encode("ascii")).hexdigest()
30
31
32def _normalize_redirect_uri(uri: str) -> str:
33 """Drop the port from loopback URIs so they match regardless of it (RFC 8252)."""
34 parsed = urlparse(uri)
35 if parsed.scheme == "http" and parsed.hostname in _LOOPBACK_HOSTS:
36 return urlunparse(parsed._replace(netloc=parsed.hostname or ""))
37 return uri
38
39
@postgres.register_model
class OAuthApplication(postgres.Model):
    """A registered OAuth client (Claude's connector, a CLI).

    Always a public client — proven by PKCE on the code exchange and by the
    refresh token on refresh, never a client secret.
    """

    client_id = types.RandomStringField(length=32)
    # Optional display label; defaults to the empty string.
    name = types.TextField(max_length=255, default="", required=False)
    # Registered redirect URIs, whitespace-separated (see get_redirect_uris).
    redirect_uris = types.TextField(max_length=2000)
    created_at = types.DateTimeField(create_now=True)

    query: postgres.QuerySet[OAuthApplication] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            postgres.UniqueConstraint(
                fields=["client_id"],
                name="plainoauthserver_oauthapplication_unique_client_id",
            ),
        ],
    )

    def __str__(self) -> str:
        # `name` is optional and defaults to "", so an unnamed application
        # would otherwise render as a blank string. Fall back to the
        # (DB-unique) client_id, and never return a non-str.
        return self.name or self.client_id or ""

    def get_redirect_uris(self) -> list[str]:
        """Return the registered redirect URIs, split on any whitespace."""
        return self.redirect_uris.split()

    def is_valid_redirect_uri(self, uri: str) -> bool:
        """Return whether `uri` matches one of the registered redirect URIs.

        Both sides go through `_normalize_redirect_uri` before comparison, so
        http loopback URIs match regardless of port. Everything else must
        match exactly.
        """
        normalized = _normalize_redirect_uri(uri)
        return any(
            _normalize_redirect_uri(registered) == normalized
            for registered in self.get_redirect_uris()
        )
76
77
@postgres.register_model
class AuthorizationCode(postgres.Model):
    """Short-lived, single-use code from the authorization endpoint.

    Stored in plaintext: it's ephemeral, single-use, and bound by PKCE.
    """

    code = types.RandomStringField(length=48)
    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
    redirect_uri = types.TextField(max_length=2000)
    scope = types.TextField(max_length=500, default="", required=False)
    resource = types.TextField(max_length=2000, default="", required=False)
    # PKCE S256 challenge; checked by verify_code_challenge().
    code_challenge = types.TextField(max_length=128)
    created_at = types.DateTimeField(create_now=True)
    expires_at = types.DateTimeField()
    # Single-use flag for the code.
    used = types.BooleanField(default=False)

    query: postgres.QuerySet[AuthorizationCode] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            # The code is the credential redeemed at the token endpoint, so
            # uniqueness is enforced by the DB rather than left to chance.
            postgres.UniqueConstraint(
                fields=["code"],
                name="plainoauthserver_authorizationcode_unique_code",
            ),
        ],
        indexes=[
            postgres.Index(
                name="plainoauthserver_authorizationcode_application_id_idx",
                fields=["application"],
            ),
            postgres.Index(
                name="plainoauthserver_authorizationcode_user_id_idx",
                fields=["user"],
            ),
        ],
    )

    def is_expired(self) -> bool:
        """Return whether `expires_at` has been reached (the boundary counts as expired)."""
        return timezone.now() >= self.expires_at

    def verify_code_challenge(self, code_verifier: str) -> bool:
        """Verify a PKCE code_verifier against the stored S256 challenge.

        Computes BASE64URL(SHA256(verifier)) without padding and compares it
        to `code_challenge` in constant time.

        Returns False on any mismatch. This includes input that previously
        raised: a non-ASCII verifier used to raise ``UnicodeEncodeError``,
        and a non-ASCII stored challenge made ``compare_digest`` raise
        ``TypeError`` on str arguments.
        """
        # RFC 7636 verifiers are ASCII-only, so non-ASCII input can never be
        # legitimate. Reject it instead of letting .encode("ascii") raise.
        if not code_verifier.isascii():
            return False
        digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
        computed = base64.urlsafe_b64encode(digest).rstrip(b"=")
        # Compare as bytes: compare_digest only accepts ASCII when given str.
        return secrets.compare_digest(computed, self.code_challenge.encode("utf-8"))
127
128
@postgres.register_model
class AccessToken(postgres.Model):
    """A bearer access token; the database keeps only a hash of it."""

    # Hash of the bearer token; the token itself is never stored.
    token_hash = types.TextField(max_length=64)
    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
    # Whitespace-separated granted scopes (see the `scopes` property).
    scope = types.TextField(max_length=500, default="", required=False)
    resource = types.TextField(max_length=2000, default="", required=False)
    created_at = types.DateTimeField(create_now=True)
    expires_at = types.DateTimeField()
    revoked = types.BooleanField(default=False)

    query: postgres.QuerySet[AccessToken] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            postgres.UniqueConstraint(
                fields=["token_hash"],
                name="plainoauthserver_accesstoken_unique_token_hash",
            ),
        ],
        indexes=[
            postgres.Index(
                name="plainoauthserver_accesstoken_user_id_idx",
                fields=["user"],
            ),
            postgres.Index(
                name="plainoauthserver_accesstoken_application_id_idx",
                fields=["application"],
            ),
        ],
    )

    def is_valid(self) -> bool:
        """Return whether the token is unrevoked and `expires_at` is still in the future."""
        if self.revoked:
            return False
        return self.expires_at > timezone.now()

    @property
    def scopes(self) -> frozenset[str]:
        """Return the granted scopes as a set, the shape a resource server checks against."""
        granted = self.scope.split()
        return frozenset(granted)
169
170
@postgres.register_model
class RefreshToken(postgres.Model):
    """A refresh token, rotated on every use; only its hash is persisted.

    Scope and resource are not duplicated here: they live on the linked
    `access_token`, which every refresh token has (the FK is non-null and
    CASCADEs).
    """

    # Hash of the refresh token; the token itself is never stored.
    token_hash = types.TextField(max_length=64)
    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
    # CASCADE is load-bearing: deleting an AccessToken also deletes every
    # RefreshToken pointing at it. The cleanup chore therefore only deletes
    # access tokens that no live refresh token still references
    # (see chores.py).
    access_token = types.ForeignKeyField(AccessToken, on_delete=postgres.CASCADE)
    created_at = types.DateTimeField(create_now=True)
    expires_at = types.DateTimeField()
    revoked = types.BooleanField(default=False)

    query: postgres.QuerySet[RefreshToken] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            postgres.UniqueConstraint(
                fields=["token_hash"],
                name="plainoauthserver_refreshtoken_unique_token_hash",
            ),
        ],
        indexes=[
            postgres.Index(
                name="plainoauthserver_refreshtoken_application_id_idx",
                fields=["application"],
            ),
            postgres.Index(
                name="plainoauthserver_refreshtoken_user_id_idx",
                fields=["user"],
            ),
            postgres.Index(
                name="plainoauthserver_refreshtoken_access_token_id_idx",
                fields=["access_token"],
            ),
        ],
    )

    def is_valid(self) -> bool:
        """Return whether the token is unrevoked and `expires_at` is still in the future."""
        if self.revoked:
            return False
        return self.expires_at > timezone.now()