1import datetime
2import secrets
3from abc import ABC, abstractmethod
4from typing import Any
5from urllib.parse import urlencode
6
7from plain.auth import login as auth_login
8from plain.auth.requests import get_request_user
9from plain.http import Request, Response, ResponseRedirect
10from plain.runtime import settings
11from plain.sessions import get_request_session
12from plain.urls import reverse
13from plain.utils.cache import add_never_cache_headers
14from plain.utils.crypto import get_random_string
15from plain.utils.module_loading import import_string
16
17from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
18from .models import OAuthConnection
19
# Session key under which the `state` value sent to the provider is stored
# when a login starts; it is popped and compared on the callback request.
SESSION_STATE_KEY = "plainoauth_state"
# Session key under which the post-login redirect URL ("next") is stored;
# it is popped when building the redirect after a successful callback.
SESSION_NEXT_KEY = "plainoauth_next"
22
23
24class OAuthToken:
25 def __init__(
26 self,
27 *,
28 access_token: str,
29 refresh_token: str = "",
30 access_token_expires_at: datetime.datetime | None = None,
31 refresh_token_expires_at: datetime.datetime | None = None,
32 ):
33 self.access_token = access_token
34 self.refresh_token = refresh_token
35 self.access_token_expires_at = access_token_expires_at
36 self.refresh_token_expires_at = refresh_token_expires_at
37
38
39class OAuthUser:
40 def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
41 self.provider_id = provider_id # ID on the provider's system
42 self.user_model_fields = user_model_fields or {}
43
44 def __str__(self) -> str:
45 if "email" in self.user_model_fields:
46 return self.user_model_fields["email"]
47 if "username" in self.user_model_fields:
48 return self.user_model_fields["username"]
49 return str(self.provider_id)
50
51
class OAuthProvider(ABC):
    """
    Base class for an OAuth provider using the authorization-code flow.

    It handles the login, connect, disconnect, and callback requests.
    Subclasses must implement `get_oauth_token`, `refresh_oauth_token`, and
    `get_oauth_user`. The default `get_authorization_url` returns the
    `authorization_url` class attribute.
    """

    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: Request) -> dict:
        """
        Return the query parameters appended to the authorization URL.

        A fresh `state` value is generated on every call.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    @abstractmethod
    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a refreshed token for the given token."""

    @abstractmethod
    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
        """Return the token obtained for the `code` received on the callback."""

    @abstractmethod
    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Return the provider's user for the given token."""

    def get_authorization_url(self, *, request: Request) -> str:
        """Return the provider URL the user is redirected to for authorization."""
        return self.authorization_url

    def get_client_id(self) -> str:
        """Return the configured client ID."""
        return self.client_id

    def get_client_secret(self) -> str:
        """Return the configured client secret."""
        return self.client_secret

    def get_scope(self) -> str:
        """Return the configured scope string."""
        return self.scope

    def get_callback_url(self, *, request: Request) -> str:
        """Return the absolute URL of this provider's `oauth:callback` route."""
        url = reverse("oauth:callback", provider=self.provider_key)
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a random 32-character string to use as the `state` value."""
        return get_random_string(length=32)

    @staticmethod
    def _states_match(received_state: str, expected_state: str) -> bool:
        """
        Compare two state strings in constant time.

        `secrets.compare_digest` raises `TypeError` when given `str` values
        containing non-ASCII characters, and the received state comes straight
        from the query string. Comparing the encoded bytes makes arbitrary
        text a plain mismatch instead of an unhandled error. `surrogatepass`
        keeps the encoding itself from failing on lone surrogates.
        """
        return secrets.compare_digest(
            received_state.encode("utf-8", "surrogatepass"),
            expected_state.encode("utf-8", "surrogatepass"),
        )

    def check_request_state(self, *, request: Request) -> None:
        """
        Validate the `state` of a callback request against the session.

        Raises:
            OAuthError: the request carries an `error` query parameter.
            OAuthStateMissingError: `state` is absent from the query
                parameters or from the session.
            OAuthStateMismatchError: the two state values differ.
        """
        if error := request.query_params.get("error"):
            raise OAuthError(error)

        try:
            state = request.query_params["state"]
        except KeyError as e:
            raise OAuthStateMissingError() from e

        session = get_request_session(request)
        if SESSION_STATE_KEY not in session:
            raise OAuthStateMissingError()
        expected_state = session.pop(SESSION_STATE_KEY)
        session.save()  # Make sure the pop is saved (won't save on an exception)
        if not self._states_match(state, expected_state):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        """
        Start the flow: remember `state` and the "next" URL in the session and
        redirect to the provider's authorization URL.

        The "next" URL is `redirect_to` when given, otherwise the `next`
        value of the submitted form data (if any).
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        session = get_request_session(request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        if redirect_to:
            session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.form_data:
            session[SESSION_NEXT_KEY] = request.form_data["next"]

        # Sort authorization params for consistency
        sorted_authorization_params = sorted(authorization_params.items())
        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
        return self.get_redirect_response(redirect_url)

    def handle_connect_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        """
        Start the flow for connecting a provider account.

        This is the same as a login request; the callback connects to the
        request's user when there is one.
        """
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: Request) -> Response:
        """
        Delete the connection identified by the form's `provider_user_id` and
        redirect to the disconnect redirect URL.

        Raises:
            OAuthError: the connection does not belong to the request's user.
        """
        provider_user_id = request.form_data["provider_user_id"]
        connection = OAuthConnection.query.get(
            provider_key=self.provider_key, provider_user_id=provider_user_id
        )
        # `provider_user_id` is user-supplied, so never delete a connection
        # that belongs to someone other than the user making the request.
        if connection.user != get_request_user(request):
            raise OAuthError("OAuth connection does not belong to the current user")
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def handle_callback_request(self, *, request: Request) -> Response:
        """
        Finish the flow: validate state, fetch the token and provider user,
        connect or create the local user, log them in, and redirect.

        Raises:
            OAuthError: the callback has no `code` query parameter (in
                addition to anything raised by `check_request_state`).
        """
        self.check_request_state(request=request)

        try:
            code = request.query_params["code"]
        except KeyError as e:
            # Surface a malformed callback as an OAuth error, not a KeyError
            raise OAuthError("Missing authorization code in OAuth callback") from e

        oauth_token = self.get_oauth_token(code=code, request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        user = get_request_user(request)
        if user:
            # Already logged in: attach the provider account to this user
            connection = OAuthConnection.connect(
                user=user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        user = connection.user

        self.login(request=request, user=user)

        redirect_url = self.get_login_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def login(self, *, request: Request, user: Any) -> None:
        """Log the given user in on this request."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: Request) -> str:
        """Pop the stored "next" URL from the session, defaulting to "/"."""
        session = get_request_session(request)
        return session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: Request) -> str:
        """Return the form's `next` value, defaulting to "/"."""
        return request.form_data.get("next", "/")

    def get_redirect_response(self, redirect_url: str) -> Response:
        """
        Returns a redirect response to the given URL with never-cache headers.
        This is a utility method to ensure consistent redirect handling.
        """
        response = ResponseRedirect(redirect_url)
        add_never_cache_headers(response)
        return response
211
212
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """
    Build the provider configured under `provider_key` in the
    OAUTH_LOGIN_PROVIDERS setting.

    The entry's "class" is a dotted path that is imported, then instantiated
    with `provider_key` plus the entry's optional "kwargs" dict.
    """
    provider_config = settings.OAUTH_LOGIN_PROVIDERS[provider_key]
    provider_cls = import_string(provider_config["class"])
    init_kwargs = provider_config.get("kwargs", {})
    return provider_cls(provider_key=provider_key, **init_kwargs)
219
220
def get_provider_keys() -> list[str]:
    """Return the keys of the OAUTH_LOGIN_PROVIDERS setting as a list."""
    configured_providers = settings.OAUTH_LOGIN_PROVIDERS
    return [*configured_providers.keys()]