Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import secrets
  3from typing import Any
  4from urllib.parse import urlencode
  5
  6from plain.auth import login as auth_login
  7from plain.http import HttpRequest, Response, ResponseRedirect
  8from plain.runtime import settings
  9from plain.urls import reverse
 10from plain.utils.cache import add_never_cache_headers
 11from plain.utils.crypto import get_random_string
 12from plain.utils.module_loading import import_string
 13
 14from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
 15from .models import OAuthConnection
 16
# Session key holding the OAuth "state" value generated in handle_login_request;
# it is popped (single use) and compared against the callback's state param.
SESSION_STATE_KEY = "plainoauth_state"
# Session key holding the URL to redirect to once the OAuth callback completes.
SESSION_NEXT_KEY = "plainoauth_next"
 19
 20
 21class OAuthToken:
 22    def __init__(
 23        self,
 24        *,
 25        access_token: str,
 26        refresh_token: str = "",
 27        access_token_expires_at: datetime.datetime | None = None,
 28        refresh_token_expires_at: datetime.datetime | None = None,
 29    ):
 30        self.access_token = access_token
 31        self.refresh_token = refresh_token
 32        self.access_token_expires_at = access_token_expires_at
 33        self.refresh_token_expires_at = refresh_token_expires_at
 34
 35
 36class OAuthUser:
 37    def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
 38        self.provider_id = provider_id  # ID on the provider's system
 39        self.user_model_fields = user_model_fields or {}
 40
 41    def __str__(self):
 42        if "email" in self.user_model_fields:
 43            return self.user_model_fields["email"]
 44        if "username" in self.user_model_fields:
 45            return self.user_model_fields["username"]
 46        return str(self.provider_id)
 47
 48
 49class OAuthProvider:
 50    authorization_url = ""
 51
 52    def __init__(
 53        self,
 54        *,
 55        # Provided automatically
 56        provider_key: str,
 57        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
 58        client_id: str,
 59        client_secret: str,
 60        # Not necessarily required, but commonly used
 61        scope: str = "",
 62    ):
 63        self.provider_key = provider_key
 64        self.client_id = client_id
 65        self.client_secret = client_secret
 66        self.scope = scope
 67
 68    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
 69        return {
 70            "redirect_uri": self.get_callback_url(request=request),
 71            "client_id": self.get_client_id(),
 72            "scope": self.get_scope(),
 73            "state": self.generate_state(),
 74            "response_type": "code",
 75        }
 76
 77    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
 78        raise NotImplementedError()
 79
 80    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
 81        raise NotImplementedError()
 82
 83    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
 84        raise NotImplementedError()
 85
 86    def get_authorization_url(self, *, request: HttpRequest) -> str:
 87        return self.authorization_url
 88
 89    def get_client_id(self) -> str:
 90        return self.client_id
 91
 92    def get_client_secret(self) -> str:
 93        return self.client_secret
 94
 95    def get_scope(self) -> str:
 96        return self.scope
 97
 98    def get_callback_url(self, *, request: HttpRequest) -> str:
 99        url = reverse("oauth:callback", provider=self.provider_key)
100        return request.build_absolute_uri(url)
101
102    def generate_state(self) -> str:
103        return get_random_string(length=32)
104
105    def check_request_state(self, *, request: HttpRequest) -> None:
106        if error := request.query_params.get("error"):
107            raise OAuthError(error)
108
109        try:
110            state = request.query_params["state"]
111        except KeyError as e:
112            raise OAuthStateMissingError() from e
113
114        expected_state = request.session.pop(SESSION_STATE_KEY)
115        request.session.save()  # Make sure the pop is saved (won't save on an exception)
116        if not secrets.compare_digest(state, expected_state):
117            raise OAuthStateMismatchError()
118
119    def handle_login_request(
120        self, *, request: HttpRequest, redirect_to: str = ""
121    ) -> Response:
122        authorization_url = self.get_authorization_url(request=request)
123        authorization_params = self.get_authorization_url_params(request=request)
124
125        if "state" in authorization_params:
126            # Store the state in the session so we can check on callback
127            request.session[SESSION_STATE_KEY] = authorization_params["state"]
128
129        # Store next url in session so we can get it on the callback request
130        if redirect_to:
131            request.session[SESSION_NEXT_KEY] = redirect_to
132        elif "next" in request.data:
133            request.session[SESSION_NEXT_KEY] = request.data["next"]
134
135        # Sort authorization params for consistency
136        sorted_authorization_params = sorted(authorization_params.items())
137        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
138        return self.get_redirect_response(redirect_url)
139
140    def handle_connect_request(
141        self, *, request: HttpRequest, redirect_to: str = ""
142    ) -> Response:
143        return self.handle_login_request(request=request, redirect_to=redirect_to)
144
145    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
146        provider_user_id = request.data["provider_user_id"]
147        connection = OAuthConnection.objects.get(
148            provider_key=self.provider_key, provider_user_id=provider_user_id
149        )
150        connection.delete()
151        redirect_url = self.get_disconnect_redirect_url(request=request)
152        return self.get_redirect_response(redirect_url)
153
154    def handle_callback_request(self, *, request: HttpRequest) -> Response:
155        self.check_request_state(request=request)
156
157        oauth_token = self.get_oauth_token(
158            code=request.query_params["code"], request=request
159        )
160        oauth_user = self.get_oauth_user(oauth_token=oauth_token)
161
162        if request.user:
163            connection = OAuthConnection.connect(
164                user=request.user,
165                provider_key=self.provider_key,
166                oauth_token=oauth_token,
167                oauth_user=oauth_user,
168            )
169            user = connection.user
170        else:
171            connection = OAuthConnection.get_or_create_user(
172                provider_key=self.provider_key,
173                oauth_token=oauth_token,
174                oauth_user=oauth_user,
175            )
176
177            user = connection.user
178
179            self.login(request=request, user=user)
180
181        redirect_url = self.get_login_redirect_url(request=request)
182        return self.get_redirect_response(redirect_url)
183
184    def login(self, *, request: HttpRequest, user: Any) -> None:
185        auth_login(request=request, user=user)
186
187    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
188        return request.session.pop(SESSION_NEXT_KEY, "/")
189
190    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
191        return request.data.get("next", "/")
192
193    def get_redirect_response(self, redirect_url: str) -> Response:
194        """
195        Returns a redirect response to the given URL.
196        This is a utility method to ensure consistent redirect handling.
197        """
198        response = ResponseRedirect(redirect_url)
199        add_never_cache_headers(response)
200        return response
201
202
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    Reads ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]``: its ``"class"``
    entry is a dotted import path, and its optional ``"kwargs"`` mapping is
    passed to the constructor along with ``provider_key``. An unknown key or
    a missing ``"class"`` entry raises KeyError.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = configured_providers[provider_key]
    provider_cls = import_string(provider_config["class"])
    return provider_cls(provider_key=provider_key, **provider_config.get("kwargs", {}))
209
210
def get_provider_keys() -> list[str]:
    """Return the keys of ``settings.OAUTH_LOGIN_PROVIDERS`` (empty if unset)."""
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return [provider_key for provider_key in configured_providers.keys()]