Plain is headed towards 1.0! Subscribe for development updates →

import datetime
import secrets
from typing import Any, Optional
from urllib.parse import urlencode

from plain.auth import login as auth_login
from plain.http import HttpRequest, Response, ResponseRedirect
from plain.runtime import settings
from plain.urls import reverse
from plain.utils.crypto import get_random_string
from plain.utils.module_loading import import_string

from .exceptions import OAuthError, OAuthStateMismatchError
from .models import OAuthConnection
 15
# Session key under which the ``state`` value sent to the provider at login is
# stored, so the callback request can be checked against it.
SESSION_STATE_KEY = "plainoauth_state"
# Session key under which the post-login redirect target (``next``) is stored
# between the login request and the provider's callback.
SESSION_NEXT_KEY = "plainoauth_next"
 18
 19
 20class OAuthToken:
 21    def __init__(
 22        self,
 23        *,
 24        access_token: str,
 25        refresh_token: str = "",
 26        access_token_expires_at: datetime.datetime = None,
 27        refresh_token_expires_at: datetime.datetime = None,
 28    ):
 29        self.access_token = access_token
 30        self.refresh_token = refresh_token
 31        self.access_token_expires_at = access_token_expires_at
 32        self.refresh_token_expires_at = refresh_token_expires_at
 33
 34
 35class OAuthUser:
 36    def __init__(self, *, id: str, **user_model_fields: dict):
 37        self.id = id  # ID on the provider's system
 38        self.user_model_fields = user_model_fields
 39
 40    def __str__(self):
 41        if "email" in self.user_model_fields:
 42            return self.user_model_fields["email"]
 43        if "username" in self.user_model_fields:
 44            return self.user_model_fields["username"]
 45        return str(self.id)
 46
 47
 48class OAuthProvider:
 49    authorization_url = ""
 50
 51    def __init__(
 52        self,
 53        *,
 54        # Provided automatically
 55        provider_key: str,
 56        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
 57        client_id: str,
 58        client_secret: str,
 59        # Not necessarily required, but commonly used
 60        scope: str = "",
 61    ):
 62        self.provider_key = provider_key
 63        self.client_id = client_id
 64        self.client_secret = client_secret
 65        self.scope = scope
 66
 67    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
 68        return {
 69            "redirect_uri": self.get_callback_url(request=request),
 70            "client_id": self.get_client_id(),
 71            "scope": self.get_scope(),
 72            "state": self.generate_state(),
 73            "response_type": "code",
 74        }
 75
 76    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
 77        raise NotImplementedError()
 78
 79    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
 80        raise NotImplementedError()
 81
 82    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
 83        raise NotImplementedError()
 84
 85    def get_authorization_url(self, *, request: HttpRequest) -> str:
 86        return self.authorization_url
 87
 88    def get_client_id(self) -> str:
 89        return self.client_id
 90
 91    def get_client_secret(self) -> str:
 92        return self.client_secret
 93
 94    def get_scope(self) -> str:
 95        return self.scope
 96
 97    def get_callback_url(self, *, request: HttpRequest) -> str:
 98        url = reverse("oauth:callback", provider=self.provider_key)
 99        return request.build_absolute_uri(url)
100
101    def generate_state(self) -> str:
102        return get_random_string(length=32)
103
104    def check_request_state(self, *, request: HttpRequest) -> None:
105        if error := request.query_params.get("error"):
106            raise OAuthError(error)
107
108        state = request.query_params["state"]
109        expected_state = request.session.pop(SESSION_STATE_KEY)
110        request.session.save()  # Make sure the pop is saved (won't save on an exception)
111        if not secrets.compare_digest(state, expected_state):
112            raise OAuthStateMismatchError()
113
114    def handle_login_request(
115        self, *, request: HttpRequest, redirect_to: str = ""
116    ) -> Response:
117        authorization_url = self.get_authorization_url(request=request)
118        authorization_params = self.get_authorization_url_params(request=request)
119
120        if "state" in authorization_params:
121            # Store the state in the session so we can check on callback
122            request.session[SESSION_STATE_KEY] = authorization_params["state"]
123
124        # Store next url in session so we can get it on the callback request
125        if redirect_to:
126            request.session[SESSION_NEXT_KEY] = redirect_to
127        elif "next" in request.data:
128            request.session[SESSION_NEXT_KEY] = request.data["next"]
129
130        # Sort authorization params for consistency
131        sorted_authorization_params = sorted(authorization_params.items())
132        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
133        return ResponseRedirect(redirect_url)
134
135    def handle_connect_request(
136        self, *, request: HttpRequest, redirect_to: str = ""
137    ) -> Response:
138        return self.handle_login_request(request=request, redirect_to=redirect_to)
139
140    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
141        provider_user_id = request.data["provider_user_id"]
142        connection = OAuthConnection.objects.get(
143            provider_key=self.provider_key, provider_user_id=provider_user_id
144        )
145        connection.delete()
146        redirect_url = self.get_disconnect_redirect_url(request=request)
147        return ResponseRedirect(redirect_url)
148
149    def handle_callback_request(self, *, request: HttpRequest) -> Response:
150        self.check_request_state(request=request)
151
152        oauth_token = self.get_oauth_token(
153            code=request.query_params["code"], request=request
154        )
155        oauth_user = self.get_oauth_user(oauth_token=oauth_token)
156
157        if request.user:
158            connection = OAuthConnection.connect(
159                user=request.user,
160                provider_key=self.provider_key,
161                oauth_token=oauth_token,
162                oauth_user=oauth_user,
163            )
164            user = connection.user
165        else:
166            connection = OAuthConnection.get_or_create_user(
167                provider_key=self.provider_key,
168                oauth_token=oauth_token,
169                oauth_user=oauth_user,
170            )
171
172            user = connection.user
173
174            self.login(request=request, user=user)
175
176        redirect_url = self.get_login_redirect_url(request=request)
177        return ResponseRedirect(redirect_url)
178
179    def login(self, *, request: HttpRequest, user: Any) -> Response:
180        auth_login(request=request, user=user)
181
182    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
183        return request.session.pop(SESSION_NEXT_KEY, "/")
184
185    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
186        return request.data.get("next", "/")
187
188
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    Reads ``OAUTH_LOGIN_PROVIDERS[provider_key]`` from settings, imports the
    dotted path in its ``"class"`` entry, and calls it with ``provider_key``
    plus the optional ``"kwargs"`` mapping. An unconfigured key raises
    KeyError.
    """
    providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    config = providers[provider_key]
    provider_cls = import_string(config["class"])
    extra_kwargs = config.get("kwargs", {})
    return provider_cls(provider_key=provider_key, **extra_kwargs)
195
196
def get_provider_keys() -> list[str]:
    """Return the configured provider keys (empty when the setting is absent)."""
    configured = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(configured)