1import logging
2
3from plain.auth.requests import get_request_user
4from plain.auth.views import AuthView
5from plain.http import RedirectResponse, Response
6from plain.views import TemplateView, View
7
8from .exceptions import (
9 OAuthError,
10)
11from .providers import get_oauth_provider_instance
12
# Module-level logger; OAuthCallbackView uses it to record OAuthError messages.
logger = logging.getLogger(__name__)
14
15
class OAuthLoginView(View):
    """
    Hand a login POST off to the OAuth provider named in the URL.

    Requests that already have a user attached (``get_request_user`` returns
    a truthy value) are redirected to ``/`` instead of being delegated.
    """

    def post(self) -> Response:
        # Read the URL kwarg first so a missing "provider" fails the same way
        # regardless of whether a user is attached to the request.
        provider_key = self.url_kwargs["provider"]

        if get_request_user(self.request):
            return RedirectResponse("/")

        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        return oauth_provider.handle_login_request(request=self.request)
25
26
class OAuthCallbackView(TemplateView):
    """
    Receive the provider's redirect back to this site.

    A single callback endpoint serves signup, login, and connect. On success
    the provider's response is returned as-is. If the provider raises
    OAuthError, the error is logged and an error page is rendered with a
    400 status code.
    """

    template_name = "oauth/error.html"

    def get(self) -> Response:
        oauth_provider = get_oauth_provider_instance(
            provider_key=self.url_kwargs["provider"]
        )
        try:
            return oauth_provider.handle_callback_request(request=self.request)
        except OAuthError as error:
            logger.warning("OAuth error: %s", error.message)
            # Stash the error so template selection and context can read it.
            self.oauth_error = error

        # Only reached when an OAuthError was caught above.
        error_page = super().get()
        error_page.status_code = 400
        return error_page

    def get_template_names(self) -> list[str]:
        """
        Return candidate templates, most specific first.

        An error-specific template (``oauth_error.template_name``), when set,
        takes priority over the default ``template_name``.
        """
        oauth_error = getattr(self, "oauth_error", None)
        error_template = oauth_error.template_name if oauth_error else None
        if error_template:
            return [error_template, self.template_name]
        return [self.template_name]

    def get_template_context(self) -> dict:
        """Expose the caught OAuthError (or None) to the template as ``oauth_error``."""
        template_context = super().get_template_context()
        template_context["oauth_error"] = getattr(self, "oauth_error", None)
        return template_context
59
60
class OAuthConnectView(AuthView):
    """
    Hand a connect POST off to the OAuth provider named in the URL.

    Inherits from AuthView; presumably that enforces an authenticated user
    (confirm in plain.auth).
    """

    def post(self) -> Response:
        oauth_provider = get_oauth_provider_instance(
            provider_key=self.url_kwargs["provider"]
        )
        return oauth_provider.handle_connect_request(request=self.request)
67
68
class OAuthDisconnectView(AuthView):
    """
    Hand a disconnect POST off to the OAuth provider named in the URL.

    Inherits from AuthView; presumably that enforces an authenticated user
    (confirm in plain.auth).
    """

    def post(self) -> Response:
        request = self.request
        provider = self.url_kwargs["provider"]
        provider_instance = get_oauth_provider_instance(provider_key=provider)
        # Exceptions raised by the provider are not caught here and propagate
        # to the caller; no error page is rendered by this view.
        # TODO(review): consider rendering "oauth/error.html" with a 400 for a
        # cannot-disconnect error, mirroring OAuthCallbackView's handling.
        return provider_instance.handle_disconnect_request(request=request)