Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import secrets
  3from typing import Any
  4from urllib.parse import urlencode
  5
  6from plain.auth import login as auth_login
  7from plain.auth.requests import get_request_user
  8from plain.http import Request, Response, ResponseRedirect
  9from plain.runtime import settings
 10from plain.sessions import get_request_session
 11from plain.urls import reverse
 12from plain.utils.cache import add_never_cache_headers
 13from plain.utils.crypto import get_random_string
 14from plain.utils.module_loading import import_string
 15
 16from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
 17from .models import OAuthConnection
 18
 19SESSION_STATE_KEY = "plainoauth_state"
 20SESSION_NEXT_KEY = "plainoauth_next"
 21
 22
 23class OAuthToken:
 24    def __init__(
 25        self,
 26        *,
 27        access_token: str,
 28        refresh_token: str = "",
 29        access_token_expires_at: datetime.datetime | None = None,
 30        refresh_token_expires_at: datetime.datetime | None = None,
 31    ):
 32        self.access_token = access_token
 33        self.refresh_token = refresh_token
 34        self.access_token_expires_at = access_token_expires_at
 35        self.refresh_token_expires_at = refresh_token_expires_at
 36
 37
 38class OAuthUser:
 39    def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
 40        self.provider_id = provider_id  # ID on the provider's system
 41        self.user_model_fields = user_model_fields or {}
 42
 43    def __str__(self) -> str:
 44        if "email" in self.user_model_fields:
 45            return self.user_model_fields["email"]
 46        if "username" in self.user_model_fields:
 47            return self.user_model_fields["username"]
 48        return str(self.provider_id)
 49
 50
 51class OAuthProvider:
 52    authorization_url = ""
 53
 54    def __init__(
 55        self,
 56        *,
 57        # Provided automatically
 58        provider_key: str,
 59        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
 60        client_id: str,
 61        client_secret: str,
 62        # Not necessarily required, but commonly used
 63        scope: str = "",
 64    ):
 65        self.provider_key = provider_key
 66        self.client_id = client_id
 67        self.client_secret = client_secret
 68        self.scope = scope
 69
 70    def get_authorization_url_params(self, *, request: Request) -> dict:
 71        return {
 72            "redirect_uri": self.get_callback_url(request=request),
 73            "client_id": self.get_client_id(),
 74            "scope": self.get_scope(),
 75            "state": self.generate_state(),
 76            "response_type": "code",
 77        }
 78
 79    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
 80        raise NotImplementedError()
 81
 82    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
 83        raise NotImplementedError()
 84
 85    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
 86        raise NotImplementedError()
 87
 88    def get_authorization_url(self, *, request: Request) -> str:
 89        return self.authorization_url
 90
 91    def get_client_id(self) -> str:
 92        return self.client_id
 93
 94    def get_client_secret(self) -> str:
 95        return self.client_secret
 96
 97    def get_scope(self) -> str:
 98        return self.scope
 99
100    def get_callback_url(self, *, request: Request) -> str:
101        url = reverse("oauth:callback", provider=self.provider_key)
102        return request.build_absolute_uri(url)
103
104    def generate_state(self) -> str:
105        return get_random_string(length=32)
106
107    def check_request_state(self, *, request: Request) -> None:
108        if error := request.query_params.get("error"):
109            raise OAuthError(error)
110
111        try:
112            state = request.query_params["state"]
113        except KeyError as e:
114            raise OAuthStateMissingError() from e
115
116        session = get_request_session(request)
117        expected_state = session.pop(SESSION_STATE_KEY)
118        session.save()  # Make sure the pop is saved (won't save on an exception)
119        if not secrets.compare_digest(state, expected_state):
120            raise OAuthStateMismatchError()
121
122    def handle_login_request(
123        self, *, request: Request, redirect_to: str = ""
124    ) -> Response:
125        authorization_url = self.get_authorization_url(request=request)
126        authorization_params = self.get_authorization_url_params(request=request)
127
128        session = get_request_session(request)
129
130        if "state" in authorization_params:
131            # Store the state in the session so we can check on callback
132            session[SESSION_STATE_KEY] = authorization_params["state"]
133
134        # Store next url in session so we can get it on the callback request
135        if redirect_to:
136            session[SESSION_NEXT_KEY] = redirect_to
137        elif "next" in request.data:
138            session[SESSION_NEXT_KEY] = request.data["next"]
139
140        # Sort authorization params for consistency
141        sorted_authorization_params = sorted(authorization_params.items())
142        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
143        return self.get_redirect_response(redirect_url)
144
145    def handle_connect_request(
146        self, *, request: Request, redirect_to: str = ""
147    ) -> Response:
148        return self.handle_login_request(request=request, redirect_to=redirect_to)
149
150    def handle_disconnect_request(self, *, request: Request) -> Response:
151        provider_user_id = request.data["provider_user_id"]
152        connection = OAuthConnection.query.get(
153            provider_key=self.provider_key, provider_user_id=provider_user_id
154        )
155        connection.delete()
156        redirect_url = self.get_disconnect_redirect_url(request=request)
157        return self.get_redirect_response(redirect_url)
158
159    def handle_callback_request(self, *, request: Request) -> Response:
160        self.check_request_state(request=request)
161
162        oauth_token = self.get_oauth_token(
163            code=request.query_params["code"], request=request
164        )
165        oauth_user = self.get_oauth_user(oauth_token=oauth_token)
166
167        user = get_request_user(request)
168        if user:
169            connection = OAuthConnection.connect(
170                user=user,
171                provider_key=self.provider_key,
172                oauth_token=oauth_token,
173                oauth_user=oauth_user,
174            )
175            user = connection.user
176        else:
177            connection = OAuthConnection.get_or_create_user(
178                provider_key=self.provider_key,
179                oauth_token=oauth_token,
180                oauth_user=oauth_user,
181            )
182
183            user = connection.user
184
185            self.login(request=request, user=user)
186
187        redirect_url = self.get_login_redirect_url(request=request)
188        return self.get_redirect_response(redirect_url)
189
190    def login(self, *, request: Request, user: Any) -> None:
191        auth_login(request=request, user=user)
192
193    def get_login_redirect_url(self, *, request: Request) -> str:
194        session = get_request_session(request)
195        return session.pop(SESSION_NEXT_KEY, "/")
196
197    def get_disconnect_redirect_url(self, *, request: Request) -> str:
198        return request.data.get("next", "/")
199
200    def get_redirect_response(self, redirect_url: str) -> Response:
201        """
202        Returns a redirect response to the given URL.
203        This is a utility method to ensure consistent redirect handling.
204        """
205        response = ResponseRedirect(redirect_url)
206        add_never_cache_headers(response)
207        return response
208
209
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    Reads ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]``, imports the class
    named by its ``"class"`` dotted path, and calls it with ``provider_key``
    plus the entry's optional ``"kwargs"`` mapping.

    Raises:
        KeyError: if ``provider_key`` is not configured, or its entry has no
            ``"class"``.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = configured_providers[provider_key]
    provider_cls = import_string(provider_config["class"])
    init_kwargs = provider_config.get("kwargs", {})
    return provider_cls(provider_key=provider_key, **init_kwargs)
216
217
def get_provider_keys() -> list[str]:
    """Return the keys of ``settings.OAUTH_LOGIN_PROVIDERS``.

    Returns an empty list when the setting is not defined.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_keys = list(configured_providers.keys())
    return provider_keys