v0.151.1
  1"""OAuth 2.1 authorization server.
  2
  3Implements the subset of OAuth 2.1 an MCP client (e.g. Claude's custom
  4connector) needs to authenticate an end user against a Plain app:
  5
  6- Authorization server metadata (RFC 8414)
  7- Dynamic client registration (RFC 7591) — public clients, PKCE
  8- Authorization code grant with PKCE (RFC 7636), audience-bound (RFC 8707)
  9- Token endpoint with refresh-token rotation (RFC 6749, OAuth 2.1)
 10- Token revocation (RFC 7009)
 11"""
 12
 13from __future__ import annotations
 14
 15from datetime import timedelta
 16from typing import Any
 17from urllib.parse import urlencode, urlparse, urlunparse
 18
 19from plain.auth.views import AuthView
 20from plain.http import JsonResponse, RedirectResponse, Request, Response
 21from plain.postgres import transaction
 22from plain.runtime import settings
 23from plain.templates import Template
 24from plain.urls import reverse
 25from plain.utils import timezone
 26from plain.views import View
 27
 28from .models import (
 29    _LOOPBACK_HOSTS,
 30    AccessToken,
 31    AuthorizationCode,
 32    OAuthApplication,
 33    RefreshToken,
 34    _generate_token,
 35    _hash_token,
 36)
 37
# Grant types this server supports. Published as "grant_types_supported" in
# the server metadata and echoed as "grant_types" in registration responses.
_GRANT_TYPES = ["authorization_code", "refresh_token"]
# Response types this server supports. Published as "response_types_supported"
# in the server metadata and echoed as "response_types" on registration.
_RESPONSE_TYPES = ["code"]
 40
 41
 42def _issuer(request: Request) -> str:
 43    # request.scheme is proxy-aware (X-Forwarded-Proto), matching build_absolute_uri.
 44    return f"{request.scheme}://{request.host}"
 45
 46
 47def _is_allowed_redirect_uri(uri: str) -> bool:
 48    """OAuth 2.1: redirect URIs must be HTTPS, or loopback for native clients."""
 49    # Reject whitespace and fragments: redirect_uris are stored space-joined, so a
 50    # value containing whitespace would smuggle in a second, unvalidated URI.
 51    if any(c.isspace() for c in uri):
 52        return False
 53    parsed = urlparse(uri)
 54    if parsed.fragment:
 55        return False
 56    if parsed.scheme == "https":
 57        return True
 58    return parsed.scheme == "http" and parsed.hostname in _LOOPBACK_HOSTS
 59
 60
 61class AuthorizationServerMetadataView(View):
 62    """RFC 8414 — served at /.well-known/oauth-authorization-server."""
 63
 64    def get(self) -> JsonResponse:
 65        issuer = _issuer(self.request)
 66        metadata: dict[str, Any] = {
 67            "issuer": issuer,
 68            "authorization_endpoint": issuer + reverse("oauthserver:authorize"),
 69            "token_endpoint": issuer + reverse("oauthserver:token"),
 70            "revocation_endpoint": issuer + reverse("oauthserver:revoke"),
 71            "response_types_supported": _RESPONSE_TYPES,
 72            "grant_types_supported": _GRANT_TYPES,
 73            "code_challenge_methods_supported": ["S256"],
 74            "token_endpoint_auth_methods_supported": ["none"],
 75            "scopes_supported": list(settings.OAUTH_SERVER_SCOPES_SUPPORTED),
 76        }
 77        if settings.OAUTH_SERVER_ALLOW_DYNAMIC_REGISTRATION:
 78            metadata["registration_endpoint"] = issuer + reverse("oauthserver:register")
 79        return JsonResponse(metadata)
 80
 81
 82class RegisterView(View):
 83    """RFC 7591 — dynamic client registration.
 84
 85    Open registration: a freshly registered client can do nothing until a real
 86    user completes the authorization + consent flow, so the risk is bounded.
 87    """
 88
 89    def post(self) -> JsonResponse:
 90        if not settings.OAUTH_SERVER_ALLOW_DYNAMIC_REGISTRATION:
 91            return _oauth_error(
 92                "invalid_request", "Dynamic registration is disabled", status_code=403
 93            )
 94
 95        metadata = self.request.json_data
 96        if not isinstance(metadata, dict):
 97            return _oauth_error("invalid_client_metadata", "Body must be a JSON object")
 98
 99        redirect_uris = metadata.get("redirect_uris")
100        if not isinstance(redirect_uris, list) or not redirect_uris:
101            return _oauth_error("invalid_redirect_uri", "redirect_uris is required")
102        if not all(
103            isinstance(u, str) and _is_allowed_redirect_uri(u) for u in redirect_uris
104        ):
105            return _oauth_error(
106                "invalid_redirect_uri", "redirect_uris must be HTTPS or loopback"
107            )
108
109        # Every client is public (PKCE-proven); we never issue a secret, so any
110        # requested token_endpoint_auth_method is overridden to "none".
111        application = OAuthApplication(
112            name=metadata.get("client_name", ""),
113            redirect_uris=" ".join(redirect_uris),
114        )
115        application.create()
116
117        return JsonResponse(
118            {
119                "client_id": application.client_id,
120                "client_id_issued_at": int(application.created_at.timestamp()),
121                "redirect_uris": redirect_uris,
122                "token_endpoint_auth_method": "none",
123                "grant_types": _GRANT_TYPES,
124                "response_types": _RESPONSE_TYPES,
125                "client_name": application.name,
126            },
127            status_code=201,
128            headers={"Cache-Control": "no-store"},
129        )
130
131
class AuthorizeView(AuthView):
    """Authorization endpoint. GET shows consent; POST records the decision."""

    login_required = True

    def get(self) -> Response:
        """Render the consent page, or an error page if the request is invalid."""
        application, error = self._validate_request(self.request.query_params)
        if error:
            return self._render({"error": error})

        params = self.request.query_params
        return self._render(
            {
                "application": application,
                "scope": params.get("scope", ""),
                # Echoed back into the consent form so the POST carries the
                # same authorization request the user was shown.
                "params": {
                    "response_type": "code",
                    "client_id": params.get("client_id", ""),
                    "redirect_uri": params.get("redirect_uri", ""),
                    "scope": params.get("scope", ""),
                    "state": params.get("state", ""),
                    "resource": params.get("resource", ""),
                    "code_challenge": params.get("code_challenge", ""),
                    "code_challenge_method": params.get("code_challenge_method")
                    or "S256",
                },
            }
        )

    def post(self) -> Response:
        """Record the user's decision and redirect back to the client.

        An invalid request yields a 400 JSON error without redirecting. A
        denial redirects with `error=access_denied`; an approval creates an
        authorization code and redirects with `code`. Both redirects carry
        `state` (when provided) and the issuer as `iss`.
        """
        form = self.request.form_data
        application, error = self._validate_request(form)
        if error:
            return JsonResponse(
                {"error": "invalid_request", "error_description": error},
                status_code=400,
            )
        assert application is not None  # a None error implies a valid client

        # redirect_uri was validated against the client by _validate_request,
        # so it's safe to redirect back to.
        redirect_uri = form.get("redirect_uri", "")
        state = form.get("state", "")
        # RFC 9207: the issuer identifies this server in every authorization
        # response, error responses included — not only successful ones.
        issuer = _issuer(self.request)

        if form.get("action") != "approve":
            return _redirect(
                redirect_uri,
                {"error": "access_denied", "state": state, "iss": issuer},
            )

        auth_code = AuthorizationCode(
            application=application,
            user=self.user,
            redirect_uri=redirect_uri,
            scope=form.get("scope", ""),
            resource=form.get("resource", ""),
            code_challenge=form.get("code_challenge", ""),
            expires_at=timezone.now()
            + timedelta(seconds=settings.OAUTH_SERVER_CODE_EXPIRY),
        )
        auth_code.create()

        return _redirect(
            redirect_uri,
            {"code": auth_code.code, "state": state, "iss": issuer},
        )

    def _render(self, context: dict[str, Any]) -> Response:
        """Render the consent template with defaults for every variable it reads."""
        # Always supply every variable the template reads (Jinja runs in strict
        # mode), so the template needs no `is defined` guards.
        full = {
            "request": self.request,
            "error": None,
            "application": None,
            "scope": "",
            "params": {},
            **context,
        }
        html = Template("oauthserver/authorize.html").render(full)
        return Response(html, content_type="text/html")

    def _validate_request(
        self, params: Any
    ) -> tuple[OAuthApplication | None, str | None]:
        """Validate an authorization request. Returns (application, error).

        Stops at the first problem. `application` is set once the client
        resolves, so a redirect-back is only ever attempted against a URI
        already proven to belong to that client. `error` is None only when
        every check passed.
        """
        if params.get("response_type", "") != "code":
            return None, "response_type must be 'code'"

        client_id = params.get("client_id", "")
        if not client_id:
            return None, "Missing client_id"
        try:
            application = OAuthApplication.query.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return None, f"Unknown client_id: {client_id}"

        # Validate the redirect target before anything else can act on it —
        # OAuth 2.1 §4.1.2.1 says to inform the user here, not redirect.
        if not application.is_valid_redirect_uri(params.get("redirect_uri", "")):
            return application, "Invalid redirect_uri"

        if not params.get("code_challenge", ""):
            return application, "Missing code_challenge (PKCE is required)"
        # An omitted method is accepted; anything explicit must be S256.
        method = params.get("code_challenge_method", "")
        if method and method != "S256":
            return application, "code_challenge_method must be 'S256'"

        # Don't grant scopes the app never advertised — a consumer gating tools
        # on scopes would otherwise treat an unconfigured scope as granted.
        supported = set(settings.OAUTH_SERVER_SCOPES_SUPPORTED)
        unsupported = [s for s in params.get("scope", "").split() if s not in supported]
        if unsupported:
            return application, f"Unsupported scope(s): {' '.join(unsupported)}"

        return application, None
249
250
class TokenView(View):
    """Token endpoint for the authorization_code and refresh_token grants."""

    def post(self) -> JsonResponse:
        """Route the request to the handler for its grant_type."""
        grant_type = self.request.form_data.get("grant_type", "")
        if grant_type == "authorization_code":
            return self._authorization_code()
        if grant_type == "refresh_token":
            return self._refresh_token()
        return _oauth_error(
            "unsupported_grant_type", f"Unsupported grant_type: {grant_type!r}"
        )

    def _authorization_code(self) -> JsonResponse:
        """Exchange a single-use authorization code (plus PKCE verifier) for tokens."""
        client = _resolve_client(self.request)
        if isinstance(client, JsonResponse):
            return client

        data = self.request.form_data
        submitted_code = data.get("code", "")
        verifier = data.get("code_verifier", "")
        if not submitted_code:
            return _oauth_error("invalid_request", "Missing code")
        if not verifier:
            return _oauth_error(
                "invalid_request", "Missing code_verifier (PKCE required)"
            )

        # The row lock keeps two concurrent exchanges from both spending the code.
        with transaction.atomic():
            try:
                grant = AuthorizationCode.query.select_for_update().get(
                    code=submitted_code, application=client
                )
            except AuthorizationCode.DoesNotExist:
                return _oauth_error("invalid_grant", "Invalid authorization code")

            # Checks run in a fixed order; the first failure decides the message.
            rejection: str | None = None
            if grant.used:
                rejection = "Authorization code already used"
            elif grant.is_expired():
                rejection = "Authorization code expired"
            elif grant.redirect_uri != data.get("redirect_uri", ""):
                rejection = "redirect_uri mismatch"
            elif not grant.verify_code_challenge(verifier):
                rejection = "PKCE verification failed"
            if rejection is not None:
                return _oauth_error("invalid_grant", rejection)

            # Mark the code spent before minting tokens for it.
            grant.used = True
            grant.update(fields=["used"])

            return _issue_tokens(client, grant.user, grant.scope, grant.resource)

    def _refresh_token(self) -> JsonResponse:
        """Rotate a refresh token: revoke the old pair and issue a new one."""
        client = _resolve_client(self.request)
        if isinstance(client, JsonResponse):
            return client

        presented = self.request.form_data.get("refresh_token", "")
        if not presented:
            return _oauth_error("invalid_request", "Missing refresh_token")

        # Locking the row means concurrent reuse of one refresh token cannot
        # fork into two valid pairs — the second request waits, then finds the
        # token already revoked.
        with transaction.atomic():
            locked_tokens = RefreshToken.query.select_for_update().select_related(
                "access_token"
            )
            try:
                current = locked_tokens.get(
                    token_hash=_hash_token(presented), application=client
                )
            except RefreshToken.DoesNotExist:
                return _oauth_error("invalid_grant", "Invalid refresh token")

            if not current.is_valid():
                return _oauth_error("invalid_grant", "Refresh token is no longer valid")

            # The grant (scope + resource) lives on the old access token; read
            # it before that token is revoked.
            previous_access = current.access_token
            scope = previous_access.scope
            resource = previous_access.resource

            # Rotation: the old pair is invalidated before a new one is issued.
            current.revoked = True
            current.update(fields=["revoked"])
            previous_access.revoked = True
            previous_access.update(fields=["revoked"])

            return _issue_tokens(client, current.user, scope, resource)
341
342
class RevocationView(View):
    """RFC 7009 token revocation — answers 200 even when the token is unknown."""

    def post(self) -> Response:
        """Revoke the submitted token if it belongs to the calling client.

        The token is first matched as an access token; failing that, as a
        refresh token, in which case its linked access token is revoked too.
        Client resolution errors are returned as-is.
        """
        client = _resolve_client(self.request)
        if isinstance(client, JsonResponse):
            return client

        submitted = self.request.form_data.get("token", "")
        if not submitted:
            return Response(status_code=200)
        digest = _hash_token(submitted)

        # update() reports how many access-token rows matched and were revoked.
        matched = AccessToken.query.filter(
            token_hash=digest, application=client
        ).update(revoked=True)
        if not matched:
            try:
                refresh = RefreshToken.query.get(
                    token_hash=digest, application=client
                )
            except RefreshToken.DoesNotExist:
                # Unknown token: nothing to revoke, still a 200 below.
                pass
            else:
                refresh.revoked = True
                refresh.update(fields=["revoked"])
                # Take the paired access token down with the refresh token.
                AccessToken.query.filter(id=refresh.access_token.id).update(
                    revoked=True
                )
        return Response(status_code=200)
373
374
def _resolve_client(request: Request) -> OAuthApplication | JsonResponse:
    """Resolve the public client named by the form's client_id.

    Returns the application, or a 401 `invalid_client` error response when the
    client_id is missing or unknown. No secret is checked: PKCE or the refresh
    token itself is the proof.
    """
    client_id = request.form_data.get("client_id", "")
    if client_id:
        try:
            application = OAuthApplication.query.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return _oauth_error("invalid_client", "Unknown client", status_code=401)
        return application
    return _oauth_error("invalid_client", "Missing client_id", status_code=401)
385
386
def _issue_tokens(
    application: OAuthApplication, user: Any, scope: str, resource: str
) -> JsonResponse:
    """Create an access/refresh token pair and return the token response.

    Only hashes of the token values are stored; the plaintext values appear
    solely in the returned JSON body, which is marked `no-store`.
    """
    # Both callers already hold the transaction that locked the code/refresh
    # row, so the two inserts below are atomic with the rotation.
    access_plain = _generate_token()
    refresh_plain = _generate_token()
    issued_at = timezone.now()

    access = AccessToken(
        application=application,
        user=user,
        scope=scope,
        resource=resource,
        token_hash=_hash_token(access_plain),
        expires_at=issued_at
        + timedelta(seconds=settings.OAUTH_SERVER_ACCESS_TOKEN_EXPIRY),
    )
    access.create()

    # The refresh token is linked to the access token it was issued with.
    refresh = RefreshToken(
        application=application,
        user=user,
        access_token=access,
        token_hash=_hash_token(refresh_plain),
        expires_at=issued_at
        + timedelta(seconds=settings.OAUTH_SERVER_REFRESH_TOKEN_EXPIRY),
    )
    refresh.create()

    payload = {
        "access_token": access_plain,
        "token_type": "Bearer",
        "expires_in": settings.OAUTH_SERVER_ACCESS_TOKEN_EXPIRY,
        "refresh_token": refresh_plain,
        "scope": scope,
    }
    return JsonResponse(payload, headers={"Cache-Control": "no-store"})
424
425
def _redirect(redirect_uri: str, params: dict[str, str]) -> RedirectResponse:
    """Redirect to `redirect_uri` with `params` appended to its query string.

    Params with empty values are dropped. Any query already present on
    `redirect_uri` is preserved and the new params follow it.
    """
    extra_query = urlencode({k: v for k, v in params.items() if v})
    parsed = urlparse(redirect_uri)
    # Join only the non-empty parts so that neither a missing existing query
    # nor an empty param set leaves a dangling "&" in the URL.
    new_query = "&".join(part for part in (parsed.query, extra_query) if part)
    return RedirectResponse(
        urlunparse(parsed._replace(query=new_query)), allow_external=True
    )
434
435
def _oauth_error(
    error: str, description: str, *, status_code: int = 400
) -> JsonResponse:
    """Build an RFC 6749 error response (`{"error", "error_description"}`).

    `status_code` defaults to 400.
    """
    body = {"error": error, "error_description": description}
    return JsonResponse(body, status_code=status_code)