Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import secrets
  3from abc import ABC, abstractmethod
  4from typing import Any
  5from urllib.parse import urlencode
  6
  7from plain.auth import login as auth_login
  8from plain.auth.requests import get_request_user
  9from plain.http import Request, Response, ResponseRedirect
 10from plain.runtime import settings
 11from plain.sessions import get_request_session
 12from plain.urls import reverse
 13from plain.utils.cache import add_never_cache_headers
 14from plain.utils.crypto import get_random_string
 15from plain.utils.module_loading import import_string
 16
 17from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
 18from .models import OAuthConnection
 19
 20SESSION_STATE_KEY = "plainoauth_state"
 21SESSION_NEXT_KEY = "plainoauth_next"
 22
 23
 24class OAuthToken:
 25    def __init__(
 26        self,
 27        *,
 28        access_token: str,
 29        refresh_token: str = "",
 30        access_token_expires_at: datetime.datetime | None = None,
 31        refresh_token_expires_at: datetime.datetime | None = None,
 32    ):
 33        self.access_token = access_token
 34        self.refresh_token = refresh_token
 35        self.access_token_expires_at = access_token_expires_at
 36        self.refresh_token_expires_at = refresh_token_expires_at
 37
 38
 39class OAuthUser:
 40    def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
 41        self.provider_id = provider_id  # ID on the provider's system
 42        self.user_model_fields = user_model_fields or {}
 43
 44    def __str__(self) -> str:
 45        if "email" in self.user_model_fields:
 46            return self.user_model_fields["email"]
 47        if "username" in self.user_model_fields:
 48            return self.user_model_fields["username"]
 49        return str(self.provider_id)
 50
 51
 52class OAuthProvider(ABC):
 53    authorization_url = ""
 54
 55    def __init__(
 56        self,
 57        *,
 58        # Provided automatically
 59        provider_key: str,
 60        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
 61        client_id: str,
 62        client_secret: str,
 63        # Not necessarily required, but commonly used
 64        scope: str = "",
 65    ):
 66        self.provider_key = provider_key
 67        self.client_id = client_id
 68        self.client_secret = client_secret
 69        self.scope = scope
 70
 71    def get_authorization_url_params(self, *, request: Request) -> dict:
 72        return {
 73            "redirect_uri": self.get_callback_url(request=request),
 74            "client_id": self.get_client_id(),
 75            "scope": self.get_scope(),
 76            "state": self.generate_state(),
 77            "response_type": "code",
 78        }
 79
 80    @abstractmethod
 81    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken: ...
 82
 83    @abstractmethod
 84    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken: ...
 85
 86    @abstractmethod
 87    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser: ...
 88
 89    def get_authorization_url(self, *, request: Request) -> str:
 90        return self.authorization_url
 91
 92    def get_client_id(self) -> str:
 93        return self.client_id
 94
 95    def get_client_secret(self) -> str:
 96        return self.client_secret
 97
 98    def get_scope(self) -> str:
 99        return self.scope
100
101    def get_callback_url(self, *, request: Request) -> str:
102        url = reverse("oauth:callback", provider=self.provider_key)
103        return request.build_absolute_uri(url)
104
105    def generate_state(self) -> str:
106        return get_random_string(length=32)
107
108    def check_request_state(self, *, request: Request) -> None:
109        if error := request.query_params.get("error"):
110            raise OAuthError(error)
111
112        try:
113            state = request.query_params["state"]
114        except KeyError as e:
115            raise OAuthStateMissingError() from e
116
117        session = get_request_session(request)
118        if SESSION_STATE_KEY not in session:
119            raise OAuthStateMissingError()
120        expected_state = session.pop(SESSION_STATE_KEY)
121        session.save()  # Make sure the pop is saved (won't save on an exception)
122        if not secrets.compare_digest(state, expected_state):
123            raise OAuthStateMismatchError()
124
125    def handle_login_request(
126        self, *, request: Request, redirect_to: str = ""
127    ) -> Response:
128        authorization_url = self.get_authorization_url(request=request)
129        authorization_params = self.get_authorization_url_params(request=request)
130
131        session = get_request_session(request)
132
133        if "state" in authorization_params:
134            # Store the state in the session so we can check on callback
135            session[SESSION_STATE_KEY] = authorization_params["state"]
136
137        # Store next url in session so we can get it on the callback request
138        if redirect_to:
139            session[SESSION_NEXT_KEY] = redirect_to
140        elif "next" in request.form_data:
141            session[SESSION_NEXT_KEY] = request.form_data["next"]
142
143        # Sort authorization params for consistency
144        sorted_authorization_params = sorted(authorization_params.items())
145        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
146        return self.get_redirect_response(redirect_url)
147
148    def handle_connect_request(
149        self, *, request: Request, redirect_to: str = ""
150    ) -> Response:
151        return self.handle_login_request(request=request, redirect_to=redirect_to)
152
153    def handle_disconnect_request(self, *, request: Request) -> Response:
154        provider_user_id = request.form_data["provider_user_id"]
155        connection = OAuthConnection.query.get(
156            provider_key=self.provider_key, provider_user_id=provider_user_id
157        )
158        connection.delete()
159        redirect_url = self.get_disconnect_redirect_url(request=request)
160        return self.get_redirect_response(redirect_url)
161
162    def handle_callback_request(self, *, request: Request) -> Response:
163        self.check_request_state(request=request)
164
165        oauth_token = self.get_oauth_token(
166            code=request.query_params["code"], request=request
167        )
168        oauth_user = self.get_oauth_user(oauth_token=oauth_token)
169
170        user = get_request_user(request)
171        if user:
172            connection = OAuthConnection.connect(
173                user=user,
174                provider_key=self.provider_key,
175                oauth_token=oauth_token,
176                oauth_user=oauth_user,
177            )
178            user = connection.user
179        else:
180            connection = OAuthConnection.get_or_create_user(
181                provider_key=self.provider_key,
182                oauth_token=oauth_token,
183                oauth_user=oauth_user,
184            )
185
186            user = connection.user
187
188            self.login(request=request, user=user)
189
190        redirect_url = self.get_login_redirect_url(request=request)
191        return self.get_redirect_response(redirect_url)
192
193    def login(self, *, request: Request, user: Any) -> None:
194        auth_login(request=request, user=user)
195
196    def get_login_redirect_url(self, *, request: Request) -> str:
197        session = get_request_session(request)
198        return session.pop(SESSION_NEXT_KEY, "/")
199
200    def get_disconnect_redirect_url(self, *, request: Request) -> str:
201        return request.form_data.get("next", "/")
202
203    def get_redirect_response(self, redirect_url: str) -> Response:
204        """
205        Returns a redirect response to the given URL.
206        This is a utility method to ensure consistent redirect handling.
207        """
208        response = ResponseRedirect(redirect_url)
209        add_never_cache_headers(response)
210        return response
211
212
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """
    Construct the provider configured under ``provider_key``.

    The ``OAUTH_LOGIN_PROVIDERS[provider_key]`` entry must have a ``"class"``
    dotted import path. It may have a ``"kwargs"`` mapping, which is passed to
    the constructor together with ``provider_key``. A KeyError propagates if
    the key or its ``"class"`` entry is missing.
    """
    provider_config = settings.OAUTH_LOGIN_PROVIDERS[provider_key]
    provider_class = import_string(provider_config["class"])
    return provider_class(
        provider_key=provider_key, **provider_config.get("kwargs", {})
    )
219
220
def get_provider_keys() -> list[str]:
    """Return every key configured in OAUTH_LOGIN_PROVIDERS, in the mapping's iteration order."""
    configured_providers = settings.OAUTH_LOGIN_PROVIDERS
    return list(configured_providers)