1"""OAuth 2.1 authorization server.
2
3Implements the subset of OAuth 2.1 an MCP client (e.g. Claude's custom
4connector) needs to authenticate an end user against a Plain app:
5
6- Authorization server metadata (RFC 8414)
7- Dynamic client registration (RFC 7591) — public clients, PKCE
8- Authorization code grant with PKCE (RFC 7636), audience-bound (RFC 8707)
9- Token endpoint with refresh-token rotation (RFC 6749, OAuth 2.1)
10- Token revocation (RFC 7009)
11"""
12
13from __future__ import annotations
14
15from datetime import timedelta
16from typing import Any
17from urllib.parse import urlencode, urlparse, urlunparse
18
19from plain.auth.views import AuthView
20from plain.http import JsonResponse, RedirectResponse, Request, Response
21from plain.postgres import transaction
22from plain.runtime import settings
23from plain.templates import Template
24from plain.urls import reverse
25from plain.utils import timezone
26from plain.views import View
27
28from .models import (
29 _LOOPBACK_HOSTS,
30 AccessToken,
31 AuthorizationCode,
32 OAuthApplication,
33 RefreshToken,
34 _generate_token,
35 _hash_token,
36)
37
# Grants the token endpoint accepts; advertised in the server metadata and
# echoed back to clients from dynamic registration.
_GRANT_TYPES: list[str] = ["authorization_code", "refresh_token"]
# Only the authorization-code flow is offered at the authorization endpoint.
_RESPONSE_TYPES: list[str] = ["code"]
40
41
def _issuer(request: Request) -> str:
    """Return this server's issuer identifier: the origin of *request*.

    It is the RFC 8414 ``issuer``, the base of every advertised endpoint URL,
    and the ``iss`` parameter sent back in authorization responses.
    """
    # request.scheme is proxy-aware (X-Forwarded-Proto), matching build_absolute_uri.
    scheme = request.scheme
    host = request.host
    return f"{scheme}://{host}"
45
46
def _is_allowed_redirect_uri(uri: str) -> bool:
    """Return True if *uri* may be registered as a client redirect URI.

    OAuth 2.1: redirect URIs must be HTTPS, or plain HTTP to a loopback host
    for native clients. The URI must also be absolute (carry a host) and must
    not contain a fragment. Malformed URIs are rejected rather than raising.
    """
    # Reject whitespace: redirect_uris are stored space-joined, so a value
    # containing whitespace would smuggle in a second, unvalidated URI.
    # Reject any "#", including an empty fragment ("https://a/cb#"), which
    # urlparse reports as no fragment at all.
    if "#" in uri or any(c.isspace() for c in uri):
        return False
    try:
        parsed = urlparse(uri)
    except ValueError:
        # e.g. an unterminated IPv6 literal such as "http://[::1".
        return False
    # A redirect target needs a host; "https:///cb" and "https:cb" have none.
    if not parsed.hostname:
        return False
    if parsed.scheme == "https":
        return True
    return parsed.scheme == "http" and parsed.hostname in _LOOPBACK_HOSTS
59
60
class AuthorizationServerMetadataView(View):
    """RFC 8414 — served at /.well-known/oauth-authorization-server."""

    def get(self) -> JsonResponse:
        """Return the server metadata document, with absolute endpoint URLs."""
        issuer = _issuer(self.request)
        metadata: dict[str, Any] = {
            "issuer": issuer,
            "authorization_endpoint": issuer + reverse("oauthserver:authorize"),
            "token_endpoint": issuer + reverse("oauthserver:token"),
            "revocation_endpoint": issuer + reverse("oauthserver:revoke"),
            "response_types_supported": _RESPONSE_TYPES,
            "grant_types_supported": _GRANT_TYPES,
            "code_challenge_methods_supported": ["S256"],
            # Every client is public and identifies itself by client_id alone.
            "token_endpoint_auth_methods_supported": ["none"],
            # RFC 8414 §2: if omitted, this defaults to "client_secret_basic",
            # which no client here can use — advertise "none" explicitly.
            "revocation_endpoint_auth_methods_supported": ["none"],
            "scopes_supported": list(settings.OAUTH_SERVER_SCOPES_SUPPORTED),
        }
        # Only advertise the registration endpoint when it is enabled.
        if settings.OAUTH_SERVER_ALLOW_DYNAMIC_REGISTRATION:
            metadata["registration_endpoint"] = issuer + reverse("oauthserver:register")
        return JsonResponse(metadata)
80
81
class RegisterView(View):
    """RFC 7591 — dynamic client registration.

    Open registration: a freshly registered client can do nothing until a real
    user completes the authorization + consent flow, so the risk is bounded.
    Every registered client is public; no client secret is ever issued.
    """

    def post(self) -> JsonResponse:
        """Register a client from a JSON metadata body and return its client_id.

        Returns 403 when registration is disabled, and an RFC 7591 error
        (``invalid_client_metadata`` / ``invalid_redirect_uri``) for bad input.
        """
        if not settings.OAUTH_SERVER_ALLOW_DYNAMIC_REGISTRATION:
            return _oauth_error(
                "invalid_request", "Dynamic registration is disabled", status_code=403
            )

        metadata = self.request.json_data
        if not isinstance(metadata, dict):
            return _oauth_error("invalid_client_metadata", "Body must be a JSON object")

        redirect_uris = metadata.get("redirect_uris")
        if not isinstance(redirect_uris, list) or not redirect_uris:
            return _oauth_error("invalid_redirect_uri", "redirect_uris is required")
        if not all(
            isinstance(u, str) and _is_allowed_redirect_uri(u) for u in redirect_uris
        ):
            return _oauth_error(
                "invalid_redirect_uri", "redirect_uris must be HTTPS or loopback"
            )

        # client_name is optional. Treat an explicit JSON null as absent, but
        # reject any other non-string instead of storing it as the app name.
        client_name = metadata.get("client_name")
        if client_name is None:
            client_name = ""
        elif not isinstance(client_name, str):
            return _oauth_error(
                "invalid_client_metadata", "client_name must be a string"
            )

        # Every client is public (PKCE-proven); we never issue a secret, so any
        # requested token_endpoint_auth_method is overridden to "none".
        application = OAuthApplication(
            name=client_name,
            redirect_uris=" ".join(redirect_uris),
        )
        application.create()

        return JsonResponse(
            {
                "client_id": application.client_id,
                "client_id_issued_at": int(application.created_at.timestamp()),
                "redirect_uris": redirect_uris,
                "token_endpoint_auth_method": "none",
                "grant_types": _GRANT_TYPES,
                "response_types": _RESPONSE_TYPES,
                "client_name": application.name,
            },
            status_code=201,
            headers={"Cache-Control": "no-store"},
        )
130
131
class AuthorizeView(AuthView):
    """Authorization endpoint. GET shows consent; POST records the decision.

    ``login_required`` means only a signed-in user reaches either method, and
    that user is the one an approved authorization code is issued for.
    """

    login_required = True

    def get(self) -> Response:
        """Validate the authorization request and render the consent page."""
        params = self.request.query_params
        application, error = self._validate_request(params)
        if error:
            # Show the problem on our own page; never redirect a request
            # that failed validation.
            return self._render({"error": error})

        return self._render(
            {
                "application": application,
                "scope": params.get("scope", ""),
                # Handed to the template for the consent form; POST
                # re-validates whatever comes back.
                "params": {
                    "response_type": "code",
                    "client_id": params.get("client_id", ""),
                    "redirect_uri": params.get("redirect_uri", ""),
                    "scope": params.get("scope", ""),
                    "state": params.get("state", ""),
                    "resource": params.get("resource", ""),
                    "code_challenge": params.get("code_challenge", ""),
                    "code_challenge_method": params.get("code_challenge_method")
                    or "S256",
                },
            }
        )

    def post(self) -> Response:
        """Record the consent decision and redirect back to the client.

        On approval, an authorization code bound to the user, redirect URI,
        scope, resource and PKCE challenge is created and sent as ``code``.
        Otherwise the client receives ``error=access_denied``. Both redirects
        carry ``state`` (when given) and the ``iss`` issuer identifier.
        """
        form = self.request.form_data
        application, error = self._validate_request(form)
        if error:
            return _oauth_error("invalid_request", error)
        assert application is not None  # a None error implies a valid client

        # redirect_uri was validated against the client by _validate_request,
        # so it's safe to redirect back to.
        redirect_uri = form.get("redirect_uri", "")
        state = form.get("state", "")
        # RFC 9207: identify the issuer on error responses as well as on
        # successes, so the client's mix-up defense works either way.
        issuer = _issuer(self.request)

        if form.get("action") != "approve":
            return _redirect(
                redirect_uri,
                {"error": "access_denied", "state": state, "iss": issuer},
            )

        auth_code = AuthorizationCode(
            application=application,
            user=self.user,
            redirect_uri=redirect_uri,
            scope=form.get("scope", ""),
            resource=form.get("resource", ""),
            code_challenge=form.get("code_challenge", ""),
            expires_at=timezone.now()
            + timedelta(seconds=settings.OAUTH_SERVER_CODE_EXPIRY),
        )
        auth_code.create()

        return _redirect(
            redirect_uri,
            {"code": auth_code.code, "state": state, "iss": issuer},
        )

    def _render(self, context: dict[str, Any]) -> Response:
        """Render the consent template with *context* over safe defaults."""
        # Always supply every variable the template reads (Jinja runs in strict
        # mode), so the template needs no `is defined` guards.
        full = {
            "request": self.request,
            "error": None,
            "application": None,
            "scope": "",
            "params": {},
            **context,
        }
        html = Template("oauthserver/authorize.html").render(full)
        return Response(html, content_type="text/html")

    def _validate_request(
        self, params: Any
    ) -> tuple[OAuthApplication | None, str | None]:
        """Validate an authorization request. Returns (application, error).

        Stops at the first problem. `application` is set once the client
        resolves, so a redirect-back is only ever attempted against a URI
        already proven to belong to that client.
        """
        if params.get("response_type", "") != "code":
            return None, "response_type must be 'code'"

        client_id = params.get("client_id", "")
        if not client_id:
            return None, "Missing client_id"
        try:
            application = OAuthApplication.query.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return None, f"Unknown client_id: {client_id}"

        # Validate the redirect target before anything else can act on it —
        # OAuth 2.1 §4.1.2.1 says to inform the user here, not redirect.
        if not application.is_valid_redirect_uri(params.get("redirect_uri", "")):
            return application, "Invalid redirect_uri"

        # PKCE is mandatory, and only the S256 method is accepted (an absent
        # method is allowed here).
        if not params.get("code_challenge", ""):
            return application, "Missing code_challenge (PKCE is required)"
        method = params.get("code_challenge_method", "")
        if method and method != "S256":
            return application, "code_challenge_method must be 'S256'"

        # Don't grant scopes the app never advertised — a consumer gating tools
        # on scopes would otherwise treat an unconfigured scope as granted.
        supported = set(settings.OAUTH_SERVER_SCOPES_SUPPORTED)
        unsupported = [s for s in params.get("scope", "").split() if s not in supported]
        if unsupported:
            return application, f"Unsupported scope(s): {' '.join(unsupported)}"

        return application, None
249
250
class TokenView(View):
    """Token endpoint: trades an authorization code or a refresh token for tokens.

    Dispatches on the form's ``grant_type``. Any value other than the two
    supported grants is answered with ``unsupported_grant_type``.
    """

    def post(self) -> JsonResponse:
        grant_type = self.request.form_data.get("grant_type", "")
        if grant_type == "authorization_code":
            return self._authorization_code()
        if grant_type == "refresh_token":
            return self._refresh_token()
        return _oauth_error(
            "unsupported_grant_type", f"Unsupported grant_type: {grant_type!r}"
        )

    def _authorization_code(self) -> JsonResponse:
        """Spend a PKCE-verified authorization code and issue a token pair."""
        if isinstance(client := _resolve_client(self.request), JsonResponse):
            return client

        form = self.request.form_data
        code_value = form.get("code", "")
        code_verifier = form.get("code_verifier", "")
        if not code_value:
            return _oauth_error("invalid_request", "Missing code")
        if not code_verifier:
            return _oauth_error(
                "invalid_request", "Missing code_verifier (PKCE required)"
            )

        # select_for_update locks the code row, so of two concurrent exchanges
        # only one can observe used=False and spend it.
        with transaction.atomic():
            try:
                auth_code = AuthorizationCode.query.select_for_update().get(
                    code=code_value, application=client
                )
            except AuthorizationCode.DoesNotExist:
                return _oauth_error("invalid_grant", "Invalid authorization code")

            # Checked in this order; the first failing check is reported.
            if auth_code.used:
                problem = "Authorization code already used"
            elif auth_code.is_expired():
                problem = "Authorization code expired"
            elif auth_code.redirect_uri != form.get("redirect_uri", ""):
                problem = "redirect_uri mismatch"
            elif not auth_code.verify_code_challenge(code_verifier):
                problem = "PKCE verification failed"
            else:
                problem = None
            if problem is not None:
                return _oauth_error("invalid_grant", problem)

            auth_code.used = True
            auth_code.update(fields=["used"])

            # Issued inside the same transaction that spent the code.
            return _issue_tokens(
                client, auth_code.user, auth_code.scope, auth_code.resource
            )

    def _refresh_token(self) -> JsonResponse:
        """Rotate a refresh token: revoke the old pair and issue a new one."""
        client = _resolve_client(self.request)
        if isinstance(client, JsonResponse):
            return client

        presented = self.request.form_data.get("refresh_token", "")
        if not presented:
            return _oauth_error("invalid_request", "Missing refresh_token")

        # Lock the row so concurrent reuse of one refresh token can't fork into
        # two valid token pairs — the second waits, then sees it revoked.
        with transaction.atomic():
            locked = RefreshToken.query.select_for_update().select_related(
                "access_token"
            )
            try:
                refresh = locked.get(
                    token_hash=_hash_token(presented), application=client
                )
            except RefreshToken.DoesNotExist:
                return _oauth_error("invalid_grant", "Invalid refresh token")

            if not refresh.is_valid():
                return _oauth_error("invalid_grant", "Refresh token is no longer valid")

            old_access = refresh.access_token
            # The new pair inherits the grant recorded on the old access token.
            scope, resource = old_access.scope, old_access.resource

            # Rotate: invalidate the old pair before issuing a new one.
            refresh.revoked = True
            refresh.update(fields=["revoked"])
            old_access.revoked = True
            old_access.update(fields=["revoked"])

            return _issue_tokens(client, refresh.user, scope, resource)
341
342
class RevocationView(View):
    """RFC 7009 token revocation — always responds 200, even for unknown tokens.

    ``token`` may be an access token or a refresh token. Tokens owned by a
    different client are treated exactly like unknown ones.
    """

    def post(self) -> Response:
        client = _resolve_client(self.request)
        if isinstance(client, JsonResponse):
            return client

        presented = self.request.form_data.get("token", "")
        if presented:
            self._revoke(client, _hash_token(presented))
        return Response(status_code=200)

    def _revoke(self, application: OAuthApplication, token_hash: str) -> None:
        """Revoke the access or refresh token with *token_hash*, if it exists."""
        # Try it as an access token first; any matching row means we're done.
        access_hits = AccessToken.query.filter(
            token_hash=token_hash, application=application
        ).update(revoked=True)
        if access_hits:
            return

        try:
            refresh = RefreshToken.query.get(
                token_hash=token_hash, application=application
            )
        except RefreshToken.DoesNotExist:
            return  # Unknown token: still answered with 200.

        refresh.revoked = True
        refresh.update(fields=["revoked"])
        # Revoking a refresh token also revokes the access token paired with it.
        AccessToken.query.filter(id=refresh.access_token.id).update(revoked=True)
373
374
def _resolve_client(request: Request) -> OAuthApplication | JsonResponse:
    """Return the public client named by the form's ``client_id``.

    Clients have no secret, so this only identifies the caller; PKCE or the
    refresh token is what proves it. On failure an ``invalid_client`` error
    response (HTTP 401) is returned in place of the application, so callers
    check ``isinstance(result, JsonResponse)``.
    """
    client_id = request.form_data.get("client_id", "")
    if client_id:
        try:
            return OAuthApplication.query.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            reason = "Unknown client"
    else:
        reason = "Missing client_id"
    return _oauth_error("invalid_client", reason, status_code=401)
385
386
def _issue_tokens(
    application: OAuthApplication, user: Any, scope: str, resource: str
) -> JsonResponse:
    """Create a new access/refresh token pair and return the token response.

    Only hashes of the token values are stored; the plaintext values appear
    solely in the (non-cacheable) JSON response. Both tokens share one
    issue time, and the refresh token is linked to its access token.
    """
    # Both callers run inside the transaction that locked the code/refresh row,
    # so these two inserts are already atomic with the rotation.
    access_value, refresh_value = _generate_token(), _generate_token()
    issued_at = timezone.now()
    access_ttl = settings.OAUTH_SERVER_ACCESS_TOKEN_EXPIRY

    access_token = AccessToken(
        application=application,
        user=user,
        scope=scope,
        resource=resource,
        token_hash=_hash_token(access_value),
        expires_at=issued_at + timedelta(seconds=access_ttl),
    )
    access_token.create()
    RefreshToken(
        application=application,
        user=user,
        access_token=access_token,
        token_hash=_hash_token(refresh_value),
        expires_at=issued_at
        + timedelta(seconds=settings.OAUTH_SERVER_REFRESH_TOKEN_EXPIRY),
    ).create()

    body = {
        "access_token": access_value,
        "token_type": "Bearer",
        "expires_in": access_ttl,
        "refresh_token": refresh_value,
        "scope": scope,
    }
    return JsonResponse(body, headers={"Cache-Control": "no-store"})
424
425
def _redirect(redirect_uri: str, params: dict[str, str]) -> RedirectResponse:
    """Redirect back to the client's *redirect_uri* with *params* appended.

    Parameters with empty values (e.g. no ``state``) are left out. Any query
    string already on *redirect_uri* is kept, with the new parameters after it.
    """
    encoded = urlencode({key: value for key, value in params.items() if value})
    parsed = urlparse(redirect_uri)
    # Join only the non-empty parts so an existing query never gains a
    # dangling "&" when there is nothing to add.
    query = "&".join(part for part in (parsed.query, encoded) if part)
    # allow_external: the target is the client's redirect URI, typically on
    # another host.
    return RedirectResponse(
        urlunparse(parsed._replace(query=query)), allow_external=True
    )
434
435
def _oauth_error(
    error: str, description: str, *, status_code: int = 400
) -> JsonResponse:
    """Build an RFC 6749 error response.

    The JSON body is ``{"error": <code>, "error_description": <text>}``. The
    HTTP status defaults to 400; callers pass e.g. 401 for ``invalid_client``.
    """
    payload = dict(error=error, error_description=description)
    return JsonResponse(payload, status_code=status_code)