1import datetime
  2import secrets
  3from abc import ABC, abstractmethod
  4from typing import Any
  5from urllib.parse import urlencode
  6
  7from plain.auth import login as auth_login
  8from plain.auth.requests import get_request_user
  9from plain.http import RedirectResponse, Request, Response
 10from plain.runtime import settings
 11from plain.sessions import get_request_session
 12from plain.urls import reverse
 13from plain.utils.cache import add_never_cache_headers
 14from plain.utils.crypto import get_random_string
 15from plain.utils.module_loading import import_string
 16
 17from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
 18from .models import OAuthConnection
 19
 20__all__ = [
 21    "OAuthProvider",
 22    "OAuthToken",
 23    "OAuthUser",
 24    "get_oauth_provider_instance",
 25    "get_provider_keys",
 26]
 27
 28SESSION_STATE_KEY = "plainoauth_state"
 29SESSION_NEXT_KEY = "plainoauth_next"
 30
 31
 32class OAuthToken:
 33    def __init__(
 34        self,
 35        *,
 36        access_token: str,
 37        refresh_token: str = "",
 38        access_token_expires_at: datetime.datetime | None = None,
 39        refresh_token_expires_at: datetime.datetime | None = None,
 40    ):
 41        self.access_token = access_token
 42        self.refresh_token = refresh_token
 43        self.access_token_expires_at = access_token_expires_at
 44        self.refresh_token_expires_at = refresh_token_expires_at
 45
 46
 47class OAuthUser:
 48    def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
 49        self.provider_id = provider_id  # ID on the provider's system
 50        self.user_model_fields = user_model_fields or {}
 51
 52    def __str__(self) -> str:
 53        if "email" in self.user_model_fields:
 54            return self.user_model_fields["email"]
 55        if "username" in self.user_model_fields:
 56            return self.user_model_fields["username"]
 57        return str(self.provider_id)
 58
 59
 60class OAuthProvider(ABC):
 61    authorization_url = ""
 62
 63    def __init__(
 64        self,
 65        *,
 66        # Provided automatically
 67        provider_key: str,
 68        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
 69        client_id: str,
 70        client_secret: str,
 71        # Not necessarily required, but commonly used
 72        scope: str = "",
 73    ):
 74        self.provider_key = provider_key
 75        self.client_id = client_id
 76        self.client_secret = client_secret
 77        self.scope = scope
 78
 79    def get_authorization_url_params(self, *, request: Request) -> dict:
 80        return {
 81            "redirect_uri": self.get_callback_url(request=request),
 82            "client_id": self.get_client_id(),
 83            "scope": self.get_scope(),
 84            "state": self.generate_state(),
 85            "response_type": "code",
 86        }
 87
 88    @abstractmethod
 89    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken: ...
 90
 91    @abstractmethod
 92    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken: ...
 93
 94    @abstractmethod
 95    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser: ...
 96
 97    def get_authorization_url(self, *, request: Request) -> str:
 98        return self.authorization_url
 99
100    def get_client_id(self) -> str:
101        return self.client_id
102
103    def get_client_secret(self) -> str:
104        return self.client_secret
105
106    def get_scope(self) -> str:
107        return self.scope
108
109    def get_callback_url(self, *, request: Request) -> str:
110        url = reverse("oauth:callback", provider=self.provider_key)
111        return request.build_absolute_uri(url)
112
113    def generate_state(self) -> str:
114        return get_random_string(length=32)
115
116    def check_request_state(self, *, request: Request) -> None:
117        if error := request.query_params.get("error"):
118            raise OAuthError(error)
119
120        try:
121            state = request.query_params["state"]
122        except KeyError as e:
123            raise OAuthStateMissingError() from e
124
125        session = get_request_session(request)
126        if SESSION_STATE_KEY not in session:
127            raise OAuthStateMissingError()
128        expected_state = session.pop(SESSION_STATE_KEY)
129        session.save()  # Make sure the pop is saved (won't save on an exception)
130        if not secrets.compare_digest(state, expected_state):
131            raise OAuthStateMismatchError()
132
133    def handle_login_request(
134        self, *, request: Request, redirect_to: str = ""
135    ) -> Response:
136        authorization_url = self.get_authorization_url(request=request)
137        authorization_params = self.get_authorization_url_params(request=request)
138
139        session = get_request_session(request)
140
141        if "state" in authorization_params:
142            # Store the state in the session so we can check on callback
143            session[SESSION_STATE_KEY] = authorization_params["state"]
144
145        # Store next url in session so we can get it on the callback request
146        if redirect_to:
147            session[SESSION_NEXT_KEY] = redirect_to
148        elif "next" in request.form_data:
149            session[SESSION_NEXT_KEY] = request.form_data["next"]
150
151        # Sort authorization params for consistency
152        sorted_authorization_params = sorted(authorization_params.items())
153        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
154        return self.get_redirect_response(redirect_url)
155
156    def handle_connect_request(
157        self, *, request: Request, redirect_to: str = ""
158    ) -> Response:
159        return self.handle_login_request(request=request, redirect_to=redirect_to)
160
161    def handle_disconnect_request(self, *, request: Request) -> Response:
162        provider_user_id = request.form_data["provider_user_id"]
163        connection = OAuthConnection.query.get(
164            provider_key=self.provider_key, provider_user_id=provider_user_id
165        )
166        connection.delete()
167        redirect_url = self.get_disconnect_redirect_url(request=request)
168        return self.get_redirect_response(redirect_url)
169
170    def handle_callback_request(self, *, request: Request) -> Response:
171        self.check_request_state(request=request)
172
173        oauth_token = self.get_oauth_token(
174            code=request.query_params["code"], request=request
175        )
176        oauth_user = self.get_oauth_user(oauth_token=oauth_token)
177
178        user = get_request_user(request)
179        if user:
180            connection = OAuthConnection.connect(
181                user=user,
182                provider_key=self.provider_key,
183                oauth_token=oauth_token,
184                oauth_user=oauth_user,
185            )
186            user = connection.user
187        else:
188            connection = OAuthConnection.get_or_create_user(
189                provider_key=self.provider_key,
190                oauth_token=oauth_token,
191                oauth_user=oauth_user,
192            )
193
194            user = connection.user
195
196            self.login(request=request, user=user)
197
198        redirect_url = self.get_login_redirect_url(request=request)
199        return self.get_redirect_response(redirect_url)
200
201    def login(self, *, request: Request, user: Any) -> None:
202        auth_login(request=request, user=user)
203
204    def get_login_redirect_url(self, *, request: Request) -> str:
205        session = get_request_session(request)
206        return session.pop(SESSION_NEXT_KEY, "/")
207
208    def get_disconnect_redirect_url(self, *, request: Request) -> str:
209        return request.form_data.get("next", "/")
210
211    def get_redirect_response(self, redirect_url: str) -> Response:
212        """
213        Returns a redirect response to the given URL.
214        This is a utility method to ensure consistent redirect handling.
215        """
216        response = RedirectResponse(redirect_url)
217        add_never_cache_headers(response)
218        return response
219
220
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """
    Build the provider configured under ``provider_key``.

    The entry in ``settings.OAUTH_LOGIN_PROVIDERS`` must have a ``"class"``
    import path. It may have ``"kwargs"``, which are passed to the
    constructor along with ``provider_key``. An unconfigured key fails on
    the lookup (``KeyError`` if the setting is a plain dict).
    """
    provider_config = settings.OAUTH_LOGIN_PROVIDERS[provider_key]
    provider_class = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_class(provider_key=provider_key, **extra_kwargs)
227
228
def get_provider_keys() -> list[str]:
    """Return the configured provider keys from ``OAUTH_LOGIN_PROVIDERS``."""
    return [*settings.OAUTH_LOGIN_PROVIDERS]