v0.151.1
  1from __future__ import annotations
  2
  3import base64
  4import hashlib
  5import secrets
  6from urllib.parse import urlparse, urlunparse
  7
  8from plain import postgres
  9from plain.postgres import types
 10from plain.utils import timezone
 11
 12__all__ = [
 13    "OAuthApplication",
 14    "AuthorizationCode",
 15    "AccessToken",
 16    "RefreshToken",
 17]
 18
 19_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1"}
 20
 21
 22def _generate_token() -> str:
 23    """A new opaque bearer token, handed to the client once and never stored."""
 24    return secrets.token_urlsafe(32)
 25
 26
 27def _hash_token(token: str) -> str:
 28    """What we persist for access/refresh tokens — a DB leak can't be replayed."""
 29    return hashlib.sha256(token.encode("ascii")).hexdigest()
 30
 31
 32def _normalize_redirect_uri(uri: str) -> str:
 33    """Drop the port from loopback URIs so they match regardless of it (RFC 8252)."""
 34    parsed = urlparse(uri)
 35    if parsed.scheme == "http" and parsed.hostname in _LOOPBACK_HOSTS:
 36        return urlunparse(parsed._replace(netloc=parsed.hostname or ""))
 37    return uri
 38
 39
 40@postgres.register_model
 41class OAuthApplication(postgres.Model):
 42    """A registered OAuth client (Claude's connector, a CLI).
 43
 44    Always a public client — proven by PKCE on the code exchange and by the
 45    refresh token on refresh, never a client secret.
 46    """
 47
 48    client_id = types.RandomStringField(length=32)
 49    name = types.TextField(max_length=255, default="", required=False)
 50    redirect_uris = types.TextField(max_length=2000)
 51    created_at = types.DateTimeField(create_now=True)
 52
 53    query: postgres.QuerySet[OAuthApplication] = postgres.QuerySet()
 54
 55    model_options = postgres.Options(
 56        constraints=[
 57            postgres.UniqueConstraint(
 58                fields=["client_id"],
 59                name="plainoauthserver_oauthapplication_unique_client_id",
 60            ),
 61        ],
 62    )
 63
 64    def __str__(self) -> str:
 65        return self.name
 66
 67    def get_redirect_uris(self) -> list[str]:
 68        return self.redirect_uris.split()
 69
 70    def is_valid_redirect_uri(self, uri: str) -> bool:
 71        normalized = _normalize_redirect_uri(uri)
 72        return any(
 73            _normalize_redirect_uri(registered) == normalized
 74            for registered in self.get_redirect_uris()
 75        )
 76
 77
 78@postgres.register_model
 79class AuthorizationCode(postgres.Model):
 80    """Short-lived, single-use code from the authorization endpoint.
 81
 82    Stored in plaintext: it's ephemeral, single-use, and bound by PKCE.
 83    """
 84
 85    code = types.RandomStringField(length=48)
 86    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
 87    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
 88    redirect_uri = types.TextField(max_length=2000)
 89    scope = types.TextField(max_length=500, default="", required=False)
 90    resource = types.TextField(max_length=2000, default="", required=False)
 91    code_challenge = types.TextField(max_length=128)
 92    created_at = types.DateTimeField(create_now=True)
 93    expires_at = types.DateTimeField()
 94    used = types.BooleanField(default=False)
 95
 96    query: postgres.QuerySet[AuthorizationCode] = postgres.QuerySet()
 97
 98    model_options = postgres.Options(
 99        constraints=[
100            # The code is the credential redeemed at the token endpoint, so
101            # uniqueness is enforced by the DB rather than left to chance.
102            postgres.UniqueConstraint(
103                fields=["code"],
104                name="plainoauthserver_authorizationcode_unique_code",
105            ),
106        ],
107        indexes=[
108            postgres.Index(
109                name="plainoauthserver_authorizationcode_application_id_idx",
110                fields=["application"],
111            ),
112            postgres.Index(
113                name="plainoauthserver_authorizationcode_user_id_idx",
114                fields=["user"],
115            ),
116        ],
117    )
118
119    def is_expired(self) -> bool:
120        return timezone.now() >= self.expires_at
121
122    def verify_code_challenge(self, code_verifier: str) -> bool:
123        """Verify a PKCE code_verifier against the stored S256 challenge."""
124        digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
125        computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
126        return secrets.compare_digest(computed, self.code_challenge)
127
128
@postgres.register_model
class AccessToken(postgres.Model):
    """Bearer access token granted to an application for a user.

    The plaintext token is never persisted — ``token_hash`` holds only its
    64-character hex SHA-256 digest, and no two rows may share one.
    """

    token_hash = types.TextField(max_length=64)
    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
    scope = types.TextField(max_length=500, default="", required=False)
    resource = types.TextField(max_length=2000, default="", required=False)
    created_at = types.DateTimeField(create_now=True)
    expires_at = types.DateTimeField()
    revoked = types.BooleanField(default=False)

    query: postgres.QuerySet[AccessToken] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            postgres.UniqueConstraint(
                fields=["token_hash"],
                name="plainoauthserver_accesstoken_unique_token_hash",
            ),
        ],
        indexes=[
            postgres.Index(
                fields=["user"],
                name="plainoauthserver_accesstoken_user_id_idx",
            ),
            postgres.Index(
                fields=["application"],
                name="plainoauthserver_accesstoken_application_id_idx",
            ),
        ],
    )

    def is_valid(self) -> bool:
        """Tell whether the token is still usable: not revoked, not yet expired."""
        if self.revoked:
            return False
        return self.expires_at > timezone.now()

    @property
    def scopes(self) -> frozenset[str]:
        """Granted scopes as a set, ready for resource-server membership checks."""
        granted = self.scope.split()
        return frozenset(granted)
169
170
@postgres.register_model
class RefreshToken(postgres.Model):
    """Refresh token, rotated on every use; only its hash is ever stored.

    There are no scope/resource columns here: every refresh token points at
    an `access_token` (a required, CASCADE-deleting FK), and that row already
    carries them, so duplicating them would be redundant.
    """

    token_hash = types.TextField(max_length=64)
    application = types.ForeignKeyField(OAuthApplication, on_delete=postgres.CASCADE)
    user = types.ForeignKeyField("users.User", on_delete=postgres.CASCADE)
    # CASCADE is load-bearing for the cleanup chore: it deletes access tokens
    # only when no live refresh token still points at them (see chores.py).
    access_token = types.ForeignKeyField(AccessToken, on_delete=postgres.CASCADE)
    created_at = types.DateTimeField(create_now=True)
    expires_at = types.DateTimeField()
    revoked = types.BooleanField(default=False)

    query: postgres.QuerySet[RefreshToken] = postgres.QuerySet()

    model_options = postgres.Options(
        constraints=[
            postgres.UniqueConstraint(
                fields=["token_hash"],
                name="plainoauthserver_refreshtoken_unique_token_hash",
            ),
        ],
        indexes=[
            postgres.Index(
                fields=["application"],
                name="plainoauthserver_refreshtoken_application_id_idx",
            ),
            postgres.Index(
                fields=["user"],
                name="plainoauthserver_refreshtoken_user_id_idx",
            ),
            postgres.Index(
                fields=["access_token"],
                name="plainoauthserver_refreshtoken_access_token_id_idx",
            ),
        ],
    )

    def is_valid(self) -> bool:
        """Tell whether the token can still be redeemed: not revoked, not expired."""
        if self.revoked:
            return False
        return self.expires_at > timezone.now()
214        return not self.revoked and timezone.now() < self.expires_at