1import datetime
2import secrets
3from typing import Any
4from urllib.parse import urlencode
5
6from plain.auth import login as auth_login
7from plain.http import HttpRequest, Response, ResponseRedirect
8from plain.runtime import settings
9from plain.urls import reverse
10from plain.utils.cache import add_never_cache_headers
11from plain.utils.crypto import get_random_string
12from plain.utils.module_loading import import_string
13
14from .exceptions import OAuthError, OAuthStateMismatchError, OAuthStateMissingError
15from .models import OAuthConnection
16
# Session key under which the OAuth ``state`` is kept between the redirect to
# the provider and the callback that verifies it.
SESSION_STATE_KEY: str = "plainoauth_state"
# Session key under which the post-login redirect URL is kept for the callback.
SESSION_NEXT_KEY: str = "plainoauth_next"
19
20
21class OAuthToken:
22 def __init__(
23 self,
24 *,
25 access_token: str,
26 refresh_token: str = "",
27 access_token_expires_at: datetime.datetime | None = None,
28 refresh_token_expires_at: datetime.datetime | None = None,
29 ):
30 self.access_token = access_token
31 self.refresh_token = refresh_token
32 self.access_token_expires_at = access_token_expires_at
33 self.refresh_token_expires_at = refresh_token_expires_at
34
35
36class OAuthUser:
37 def __init__(self, *, provider_id: str, user_model_fields: dict | None = None):
38 self.provider_id = provider_id # ID on the provider's system
39 self.user_model_fields = user_model_fields or {}
40
41 def __str__(self):
42 if "email" in self.user_model_fields:
43 return self.user_model_fields["email"]
44 if "username" in self.user_model_fields:
45 return self.user_model_fields["username"]
46 return str(self.provider_id)
47
48
class OAuthProvider:
    """
    Base class for an OAuth 2 authorization-code login provider.

    Subclasses point ``authorization_url`` at the provider (or override
    ``get_authorization_url``) and implement ``get_oauth_token`` and
    ``get_oauth_user`` (plus ``refresh_oauth_token`` if supported). The
    ``handle_*_request`` methods implement the login/connect redirect, the
    callback, and the disconnect steps of the flow.
    """

    # Provider endpoint users are redirected to; see get_authorization_url().
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """
        Return the query parameters sent to the authorization endpoint.

        A fresh ``state`` is generated on every call; handle_login_request()
        stores it in the session and check_request_state() verifies it.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Exchange a refresh token for a new token. Subclasses must implement."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange the callback ``code`` for a token. Subclasses must implement."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider's user for *oauth_token*. Subclasses must implement."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        """Return the provider's authorization endpoint (the class attribute by default)."""
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", provider=self.provider_key)
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character ``state`` value."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """
        Verify the callback's ``state`` against the value stored in the session.

        The stored state is removed from the session (and the session saved)
        so it cannot be replayed.

        Raises:
            OAuthError: the provider returned an ``error`` query parameter.
            OAuthStateMissingError: the callback has no ``state`` parameter, or
                no state was stored in this session.
            OAuthStateMismatchError: the two states differ.
        """
        if error := request.query_params.get("error"):
            raise OAuthError(error)

        try:
            state = request.query_params["state"]
        except KeyError as e:
            raise OAuthStateMissingError() from e

        # Pop with a default: a callback without a stored state (expired
        # session, flow started elsewhere) is a client problem, not a 500.
        expected_state = request.session.pop(SESSION_STATE_KEY, None)
        request.session.save()  # Make sure the pop is saved (won't save on an exception)
        if expected_state is None:
            raise OAuthStateMissingError()

        # compare_digest() raises TypeError for non-ASCII str arguments, and
        # ``state`` comes straight from the query string, so compare bytes.
        if not secrets.compare_digest(
            state.encode("utf-8"), str(expected_state).encode("utf-8")
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """
        Redirect the user to the provider's authorization page.

        Stores the generated ``state`` and the post-login destination in the
        session for the callback. ``redirect_to`` comes from application code
        and is trusted; a user-supplied ``next`` value is kept only if it is a
        local path (prevents an open redirect after login).
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.data:
            next_url = request.data["next"]
            if self._is_safe_redirect_url(next_url):
                request.session[SESSION_NEXT_KEY] = next_url

        # Sort authorization params for consistency
        query = urlencode(sorted(authorization_params.items()))
        # Preserve any query string already present on the endpoint URL.
        separator = "&" if "?" in authorization_url else "?"
        return self.get_redirect_response(authorization_url + separator + query)

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """
        Start a flow that links a provider account to the signed-in user.

        Uses the same redirect as login; handle_callback_request() connects
        instead of logging in when ``request.user`` is set.
        """
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """
        Delete the requesting user's connection for ``provider_user_id``.

        The lookup is scoped to ``request.user`` so one account cannot remove
        another account's connection by submitting its provider user ID.
        Raises whatever ``OAuthConnection.objects.get`` raises when no
        matching connection exists.
        """
        provider_user_id = request.data["provider_user_id"]
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """
        Complete the flow: verify state, exchange the code, then log in.

        Raises:
            OAuthError: the provider reported an error, or the callback has no
                ``code`` parameter.
            OAuthStateMissingError / OAuthStateMismatchError: see
                check_request_state().
        """
        self.check_request_state(request=request)

        try:
            code = request.query_params["code"]
        except KeyError as e:
            raise OAuthError("Missing 'code' in OAuth callback") from e

        oauth_token = self.get_oauth_token(code=code, request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            # Already signed in: attach this provider account to the current user.
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            # Anonymous: resolve the user through the provider account.
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return self.get_redirect_response(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Pop and return the stored post-login destination (``/`` if none)."""
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return the user-supplied ``next`` if it is a local path, else ``/``."""
        next_url = request.data.get("next", "/")
        return next_url if self._is_safe_redirect_url(next_url) else "/"

    @staticmethod
    def _is_safe_redirect_url(url: Any) -> bool:
        """
        Return True only for a local absolute path such as ``/dashboard``.

        Rejects non-strings, absolute URLs, scheme-relative URLs (``//host``),
        any backslash (some browsers treat it as a slash), and control
        characters (browsers may strip a tab or newline, turning a path into
        a scheme-relative URL).
        """
        if not isinstance(url, str) or not url.startswith("/"):
            return False
        if url.startswith("//") or "\\" in url:
            return False
        return all(ord(char) >= 0x20 and ord(char) != 0x7F for char in url)

    def get_redirect_response(self, redirect_url: str) -> Response:
        """
        Returns a redirect response to the given URL.
        This is a utility method to ensure consistent redirect handling.
        """
        response = ResponseRedirect(redirect_url)
        add_never_cache_headers(response)
        return response
201
202
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """
    Instantiate the provider configured under *provider_key*.

    Reads ``settings.OAUTH_LOGIN_PROVIDERS`` (treated as empty if unset). The
    entry's ``"class"`` is a dotted import path; its optional ``"kwargs"`` are
    passed to the constructor together with ``provider_key``. Raises KeyError
    if the key or its ``"class"`` entry is missing.
    """
    all_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = all_providers[provider_key]
    provider_class = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_class(provider_key=provider_key, **extra_kwargs)
209
210
def get_provider_keys() -> list[str]:
    """Return the configured provider keys (empty if OAUTH_LOGIN_PROVIDERS is unset)."""
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return [key for key in configured_providers.keys()]