1import datetime
2import secrets
3from typing import Any
4from urllib.parse import urlencode
5
6from plain.auth import login as auth_login
7from plain.auth.requests import get_request_user
8from plain.http import Request, Response, ResponseRedirect
9from plain.runtime import settings
10from plain.sessions import get_request_session
11from plain.urls import reverse
12from plain.utils.cache import add_never_cache_headers
13from plain.utils.crypto import get_random_string
14from plain.utils.module_loading import import_string
15
16from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
17from .models import OAuthConnection
18
# Session key under which the OAuth "state" value is kept between the login
# request and the provider callback, where it is popped and compared.
SESSION_STATE_KEY = "plainoauth_state"
# Session key holding the URL to redirect to once the callback has completed.
SESSION_NEXT_KEY = "plainoauth_next"
21
22
23class OAuthToken:
24 def __init__(
25 self,
26 *,
27 access_token: str,
28 refresh_token: str = "",
29 access_token_expires_at: datetime.datetime | None = None,
30 refresh_token_expires_at: datetime.datetime | None = None,
31 ):
32 self.access_token = access_token
33 self.refresh_token = refresh_token
34 self.access_token_expires_at = access_token_expires_at
35 self.refresh_token_expires_at = refresh_token_expires_at
36
37
38class OAuthUser:
39 def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
40 self.provider_id = provider_id # ID on the provider's system
41 self.user_model_fields = user_model_fields or {}
42
43 def __str__(self) -> str:
44 if "email" in self.user_model_fields:
45 return self.user_model_fields["email"]
46 if "username" in self.user_model_fields:
47 return self.user_model_fields["username"]
48 return str(self.provider_id)
49
50
class OAuthProvider:
    """Base class for an OAuth login provider using the authorization-code flow.

    Subclasses are expected to set ``authorization_url`` and implement
    ``get_oauth_token`` and ``get_oauth_user`` (and ``refresh_oauth_token``
    when token refresh is supported); the base versions raise
    ``NotImplementedError``. The ``handle_*_request`` methods implement the
    login, connect, disconnect and callback steps on top of those hooks.
    """

    # URL of the provider's authorization endpoint; empty in the base class.
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: Request) -> dict:
        """Return the query parameters to append to the authorization URL.

        A new ``state`` value is generated on every call.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a refreshed token for ``oauth_token``; must be overridden."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: Request) -> OAuthToken:
        """Exchange the callback ``code`` for a token; must be overridden."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Return the provider user for ``oauth_token``; must be overridden."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: Request) -> str:
        """Return the authorization URL (the ``authorization_url`` attribute)."""
        return self.authorization_url

    def get_client_id(self) -> str:
        """Return the configured client ID."""
        return self.client_id

    def get_client_secret(self) -> str:
        """Return the configured client secret."""
        return self.client_secret

    def get_scope(self) -> str:
        """Return the configured scope string."""
        return self.scope

    def get_callback_url(self, *, request: Request) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", provider=self.provider_key)
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character state string."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: Request) -> None:
        """Validate the ``state`` query parameter of a callback request.

        The expected state is popped from the session, so it can be checked
        only once.

        Raises:
            OAuthError: the request carries an ``error`` query parameter.
            OAuthStateMissingError: the request has no ``state`` parameter.
            OAuthStateMismatchError: the session holds no stored state, or the
                stored state differs from the one in the request.
        """
        if error := request.query_params.get("error"):
            raise OAuthError(error)

        try:
            state = request.query_params["state"]
        except KeyError as e:
            raise OAuthStateMissingError() from e

        session = get_request_session(request)
        expected_state = session.pop(SESSION_STATE_KEY, None)
        if expected_state is None:
            # No login is pending for this session (expired session, replayed
            # or hand-crafted callback). Report it as a state failure rather
            # than letting a bare KeyError escape.
            raise OAuthStateMismatchError()
        session.save()  # Make sure the pop is saved (won't save on an exception)

        # Compare as bytes: compare_digest() raises TypeError for str arguments
        # containing non-ASCII characters, and ``state`` is request-controlled.
        if not secrets.compare_digest(
            state.encode("utf-8"), expected_state.encode("utf-8")
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        """Start the OAuth flow by redirecting to the provider.

        Stores the generated state (when present in the authorization params)
        and the post-login URL in the session. ``redirect_to`` takes precedence
        over a ``next`` value in the request data.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        session = get_request_session(request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        # NOTE(review): ``next`` comes from request data and is later used as
        # a redirect target without validation in this class; confirm it is
        # validated elsewhere.
        if redirect_to:
            session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.data:
            session[SESSION_NEXT_KEY] = request.data["next"]

        # Sort authorization params for consistency
        sorted_authorization_params = sorted(authorization_params.items())
        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
        return self.get_redirect_response(redirect_url)

    def handle_connect_request(
        self, *, request: Request, redirect_to: str = ""
    ) -> Response:
        """Start the flow for connecting a provider; same as a login request."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: Request) -> Response:
        """Delete the requesting user's connection and redirect.

        The connection is identified by this provider's key and the
        ``provider_user_id`` value from the request data.
        """
        provider_user_id = request.data["provider_user_id"]
        # ``provider_user_id`` is request-supplied, so scope the lookup to the
        # requesting user: otherwise one user could remove another user's
        # connection by submitting its ID. Assumes ``user`` is a queryable
        # field on OAuthConnection (it is read as ``connection.user`` below).
        connection = OAuthConnection.query.get(
            user=get_request_user(request),
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def handle_callback_request(self, *, request: Request) -> Response:
        """Finish the OAuth flow on the provider's callback request.

        Checks the state, exchanges the ``code`` for a token, loads the
        provider user, then either connects it to the current request user or
        gets/creates a user for it. The connection's user is logged in and the
        response redirects to the stored next URL.
        """
        self.check_request_state(request=request)

        oauth_token = self.get_oauth_token(
            code=request.query_params["code"], request=request
        )
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        current_user = get_request_user(request)
        if current_user:
            # Already signed in: attach the provider account to that user.
            connection = OAuthConnection.connect(
                user=current_user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def login(self, *, request: Request, user: Any) -> None:
        """Log ``user`` in on ``request``."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: Request) -> str:
        """Pop and return the stored next URL from the session, or ``"/"``."""
        session = get_request_session(request)
        return session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: Request) -> str:
        """Return the ``next`` value from the request data, or ``"/"``."""
        return request.data.get("next", "/")

    def get_redirect_response(self, redirect_url: str) -> Response:
        """
        Returns a redirect response to the given URL.
        This is a utility method to ensure consistent redirect handling.
        """
        response = ResponseRedirect(redirect_url)
        add_never_cache_headers(response)
        return response
208
209
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    Reads the ``OAUTH_LOGIN_PROVIDERS`` setting (an empty dict when unset),
    imports the dotted path stored under ``"class"`` and calls it with
    ``provider_key`` plus the optional ``"kwargs"`` mapping.
    """
    providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = providers[provider_key]
    provider_cls = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_cls(provider_key=provider_key, **extra_kwargs)
216
217
def get_provider_keys() -> list[str]:
    """Return the keys of the ``OAUTH_LOGIN_PROVIDERS`` setting as a list.

    The list is empty when the setting is not defined.
    """
    providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(providers)