1import datetime
2import secrets
3from typing import Any
4from urllib.parse import urlencode
5
6from plain.auth import login as auth_login
7from plain.http import HttpRequest, Response, ResponseRedirect
8from plain.runtime import settings
9from plain.urls import reverse
10from plain.utils.crypto import get_random_string
11from plain.utils.module_loading import import_string
12
13from .exceptions import OAuthError, OAuthStateMismatchError
14from .models import OAuthConnection
15
16SESSION_STATE_KEY = "plainoauth_state"
17SESSION_NEXT_KEY = "plainoauth_next"
18
19
20class OAuthToken:
21 def __init__(
22 self,
23 *,
24 access_token: str,
25 refresh_token: str = "",
26 access_token_expires_at: datetime.datetime = None,
27 refresh_token_expires_at: datetime.datetime = None,
28 ):
29 self.access_token = access_token
30 self.refresh_token = refresh_token
31 self.access_token_expires_at = access_token_expires_at
32 self.refresh_token_expires_at = refresh_token_expires_at
33
34
class OAuthUser:
    """A user account as reported by an OAuth provider.

    ``id`` is the user's identifier on the provider's system. Every other
    keyword argument is kept verbatim in ``user_model_fields``.
    """

    def __init__(self, *, id: str, **user_model_fields: Any):
        self.id = id  # ID on the provider's system
        self.user_model_fields = user_model_fields

    def __str__(self) -> str:
        # Prefer a human-readable identifier (email, then username) and fall
        # back to the provider ID. Missing or empty values (e.g. email=None)
        # are skipped, and the result is always coerced to str, so __str__
        # never returns a non-string (which would make str() raise TypeError).
        for field_name in ("email", "username"):
            if value := self.user_model_fields.get(field_name):
                return str(value)
        return str(self.id)
46
47
class OAuthProvider:
    """Base class for an OAuth authorization-code login provider.

    Subclasses set ``authorization_url`` (or override
    ``get_authorization_url``) and implement ``get_oauth_token`` and
    ``get_oauth_user`` (plus ``refresh_oauth_token`` if refreshing is
    supported). Instances are built by ``get_oauth_provider_instance`` from
    the ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters appended to the authorization URL.

        A fresh ``state`` value is generated on every call. The caller
        (``handle_login_request``) stores it in the session.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Exchange a refresh token for new tokens. Implemented by subclasses."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange an authorization ``code`` for tokens. Implemented by subclasses."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider-side user for ``oauth_token``. Implemented by subclasses."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        """Return the provider's authorization endpoint (without query string)."""
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", provider=self.provider_key)
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character value for the ``state`` parameter."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """Validate the callback's ``state`` against the one stored at login.

        Raises:
            OAuthError: the provider redirected back with an ``error`` param.
            OAuthStateMismatchError: the ``state`` query param or the stored
                session state is missing, or the two do not match.
        """
        if error := request.query_params.get("error"):
            raise OAuthError(error)

        # Missing values are treated as a mismatch rather than letting a bare
        # KeyError escape (e.g. expired session, replayed or forged callback).
        state = request.query_params.get("state", "")
        # Popped unconditionally so a stored state can only be used once.
        expected_state = request.session.pop(SESSION_STATE_KEY, "")
        request.session.save()  # Make sure the pop is saved (won't save on an exception)

        if not state or not expected_state:
            raise OAuthStateMismatchError()

        # compare_digest rejects non-ASCII str arguments with TypeError, and
        # ``state`` is attacker-controlled, so compare UTF-8 bytes instead.
        if not secrets.compare_digest(state.encode(), expected_state.encode()):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Redirect the user to the provider's authorization page.

        The generated ``state`` and the post-login destination are stored in
        the session for the callback request. ``redirect_to`` comes from the
        calling code and is trusted as-is. A ``next`` value from the request
        data is user-controlled and is only kept if it is a same-site path.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.data:
            next_url = request.data["next"]
            # Refuse off-site targets to avoid an open redirect after login.
            if self._is_safe_redirect_url(next_url):
                request.session[SESSION_NEXT_KEY] = next_url

        # Sort authorization params for consistency
        sorted_authorization_params = sorted(authorization_params.items())
        redirect_url = authorization_url + "?" + urlencode(sorted_authorization_params)
        return ResponseRedirect(redirect_url)

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Start the flow to connect a provider account; same as login here."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """Delete the current user's connection identified by ``provider_user_id``.

        The lookup is restricted to ``request.user`` so a requester cannot
        remove a connection that belongs to another account. If nothing
        matches, ``objects.get`` raises (presumably the model's DoesNotExist;
        confirm against the ORM).
        """
        provider_user_id = request.data["provider_user_id"]
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """Finish the OAuth flow: verify state, fetch tokens and user, log in.

        A signed-in user gets the provider account connected via
        ``OAuthConnection.connect``. Otherwise the user is resolved through
        ``OAuthConnection.get_or_create_user``.
        """
        self.check_request_state(request=request)

        oauth_token = self.get_oauth_token(
            code=request.query_params["code"], request=request
        )
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log ``user`` into the current session."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Return (and consume) the post-login destination, defaulting to "/"."""
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return the ``next`` from the request data if it is a same-site path, else "/"."""
        next_url = request.data.get("next", "/")
        if self._is_safe_redirect_url(next_url):
            return next_url
        return "/"

    @staticmethod
    def _is_safe_redirect_url(url: Any) -> bool:
        """Return True if ``url`` is a same-site absolute path ("/...").

        Rejects anything with a scheme or host ("https://x", "//x"), backslash
        variants ("/\\x") and URLs containing control characters, since browsers
        strip tabs/newlines and treat backslashes as slashes, which can turn
        such values into protocol-relative (off-site) URLs.
        """
        if not isinstance(url, str) or not url:
            return False
        if any(ord(char) < 32 or ord(char) == 127 for char in url):
            return False
        normalized = url.replace("\\", "/")
        return normalized.startswith("/") and not normalized.startswith("//")
187
188
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Build the provider configured under ``provider_key``.

    Reads ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]``, imports the class
    named by its ``"class"`` dotted path, and instantiates it with
    ``provider_key`` plus the optional ``"kwargs"`` mapping.

    Raises KeyError if the key, or its "class" entry, is not configured.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = configured_providers[provider_key]

    provider_class = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_class(provider_key=provider_key, **extra_kwargs)
195
196
def get_provider_keys() -> list[str]:
    """Return the keys of all providers in ``settings.OAUTH_LOGIN_PROVIDERS``.

    Returns an empty list when the setting is not defined.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return [key for key in configured_providers.keys()]