1import datetime
2import secrets
3from abc import ABC, abstractmethod
4from typing import Any
5from urllib.parse import urlencode
6
7from plain.auth import login as auth_login
8from plain.auth.requests import get_request_user
9from plain.http import RedirectResponse, Request, Response
10from plain.runtime import settings
11from plain.sessions import get_request_session
12from plain.urls import reverse
13from plain.utils.cache import add_never_cache_headers
14from plain.utils.crypto import get_random_string
15from plain.utils.module_loading import import_string
16
17from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
18from .models import OAuthConnection
19
# Public names exported by ``from <this module> import *``.
__all__ = [
    "OAuthProvider",
    "OAuthToken",
    "OAuthUser",
    "get_oauth_provider_instance",
    "get_provider_keys",
]

# Session key for the ``state`` value stored by OAuthProvider.handle_login_request
# and popped again by OAuthProvider.check_request_state on the callback.
SESSION_STATE_KEY = "plainoauth_state"
# Session key for the post-login redirect target stored by
# OAuthProvider.handle_login_request and popped by get_login_redirect_url.
SESSION_NEXT_KEY = "plainoauth_next"
30
31
32class OAuthToken:
33 def __init__(
34 self,
35 *,
36 access_token: str,
37 refresh_token: str = "",
38 access_token_expires_at: datetime.datetime | None = None,
39 refresh_token_expires_at: datetime.datetime | None = None,
40 ):
41 self.access_token = access_token
42 self.refresh_token = refresh_token
43 self.access_token_expires_at = access_token_expires_at
44 self.refresh_token_expires_at = refresh_token_expires_at
45
46
47class OAuthUser:
48 def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
49 self.provider_id = provider_id # ID on the provider's system
50 self.user_model_fields = user_model_fields or {}
51
52 def __str__(self) -> str:
53 if "email" in self.user_model_fields:
54 return self.user_model_fields["email"]
55 if "username" in self.user_model_fields:
56 return self.user_model_fields["username"]
57 return str(self.provider_id)
58
59
class OAuthProvider(ABC):
    """Base class for an OAuth authorization-code login provider.

    Subclasses set ``authorization_url`` (or override ``get_authorization_url``)
    and implement the token exchange, token refresh and user lookup.
    ``get_oauth_provider_instance`` constructs instances from the
    ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: Request) -> dict:
        """Return the query parameters for the provider's authorization URL.

        A new ``state`` value is generated on every call.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    @abstractmethod
    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a new token obtained with ``oauth_token``'s refresh token."""

    @abstractmethod
    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
        """Exchange the callback's authorization ``code`` for a token."""

    @abstractmethod
    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider-side user that ``oauth_token`` belongs to."""

    def get_authorization_url(self, *, request: Request) -> str:
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: Request) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", provider=self.provider_key)
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        return get_random_string(length=32)

    def check_request_state(self, *, request: Request) -> None:
        """Validate the ``state`` query parameter of a callback request.

        Raises:
            OAuthError: the provider redirected back with an ``error`` param.
            OAuthStateMissingError: no ``state`` in the query string, or no
                expected state stored in the session.
            OAuthStateMismatchError: the received and expected states differ.
        """
        if error := request.query_params.get("error"):
            raise OAuthError(error)

        try:
            state = request.query_params["state"]
        except KeyError as e:
            raise OAuthStateMissingError() from e

        session = get_request_session(request)
        if SESSION_STATE_KEY not in session:
            raise OAuthStateMissingError()
        # The stored state is single-use. Persist the pop now, because the
        # session would not be saved if the mismatch check below raises.
        expected_state = session.pop(SESSION_STATE_KEY)
        session.save()
        # Compare bytes rather than str: compare_digest() raises TypeError for
        # str arguments containing non-ASCII characters, and ``state`` comes
        # straight from the query string.
        if not secrets.compare_digest(
            state.encode("utf-8"), str(expected_state).encode("utf-8")
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        """Redirect the user to the provider's authorization page.

        The generated ``state`` and the post-login destination are stored in
        the session for the callback request. ``redirect_to`` (supplied by
        server code) is stored as given. A ``next`` value from the submitted
        form is stored only if ``_is_safe_redirect_url`` accepts it, which
        prevents an open redirect after login.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        session = get_request_session(request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        if redirect_to:
            session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.form_data:
            next_url = request.form_data["next"]
            if self._is_safe_redirect_url(next_url):
                session[SESSION_NEXT_KEY] = next_url

        # Sort params for a deterministic URL. Append with "&" when the
        # configured authorization URL already carries a query string.
        query = urlencode(sorted(authorization_params.items()))
        separator = "&" if "?" in authorization_url else "?"
        return self.get_redirect_response(f"{authorization_url}{separator}{query}")

    def handle_connect_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: Request) -> Response:
        """Delete the requesting user's connection for ``provider_user_id``.

        The lookup is scoped to the current user, so a posted
        ``provider_user_id`` cannot delete another account's connection.
        """
        provider_user_id = request.form_data["provider_user_id"]
        # NOTE(review): relies on OAuthConnection having a ``user`` field, as
        # OAuthConnection.connect(user=...) and ``connection.user`` indicate.
        connection = OAuthConnection.query.get(
            user=get_request_user(request),
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def handle_callback_request(self, *, request: Request) -> Response:
        """Complete the flow when the provider redirects back.

        Validates ``state`` and exchanges ``code`` for a token. It then fetches
        the provider user and either connects it to the logged-in user or
        resolves a user via ``OAuthConnection.get_or_create_user``. Finally it
        logs that user in and redirects.

        Raises:
            OAuthError: provider-reported error, or no ``code`` in the callback.
            OAuthStateMissingError, OAuthStateMismatchError: raised by
                ``check_request_state``.
        """
        self.check_request_state(request=request)

        code = request.query_params.get("code")
        if not code:
            raise OAuthError("Missing authorization code")

        oauth_token = self.get_oauth_token(code=code, request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if user := get_request_user(request):
            connection = OAuthConnection.connect(
                user=user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def login(self, *, request: Request, user: Any) -> None:
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: Request) -> str:
        """Pop and return the stored post-login target, defaulting to ``/``."""
        session = get_request_session(request)
        return session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: Request) -> str:
        """Return the form's ``next`` if it is a local URL, otherwise ``/``."""
        next_url = request.form_data.get("next", "/")
        return next_url if self._is_safe_redirect_url(next_url) else "/"

    def get_redirect_response(self, redirect_url: str) -> Response:
        """
        Returns a redirect response to the given URL.
        This is a utility method to ensure consistent redirect handling.
        """
        response = RedirectResponse(redirect_url)
        add_never_cache_headers(response)
        return response

    @staticmethod
    def _is_safe_redirect_url(url: object) -> bool:
        """Return True if ``url`` stays on this site (no scheme, no host).

        Rejects the following:
          - non-strings and empty values;
          - URLs with control characters or leading whitespace, which browsers
            strip before parsing;
          - protocol-relative URLs (``//host``), including backslash
            variants that browsers normalise to ``//``;
          - anything with a scheme or host.

        Absolute URLs pointing at this same site are also rejected.
        """
        from urllib.parse import urlsplit

        if not isinstance(url, str) or not url:
            return False
        if url[0].isspace() or any(ord(ch) < 32 or ord(ch) == 127 for ch in url):
            return False
        normalized = url.replace("\\", "/")
        if normalized.startswith("//"):
            return False
        parts = urlsplit(normalized)
        return not parts.scheme and not parts.netloc
219
220
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    The ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]`` entry must contain a
    ``"class"`` import path. It may contain a ``"kwargs"`` mapping that is
    forwarded to the constructor alongside ``provider_key``. An unknown key or
    a missing ``"class"`` raises ``KeyError``.
    """
    provider_config = settings.OAUTH_LOGIN_PROVIDERS[provider_key]
    provider_class = import_string(provider_config["class"])
    return provider_class(
        provider_key=provider_key,
        **provider_config.get("kwargs", {}),
    )
227
228
def get_provider_keys() -> list[str]:
    """Return the keys of ``settings.OAUTH_LOGIN_PROVIDERS`` as a new list."""
    return [*settings.OAUTH_LOGIN_PROVIDERS]