1import logging
 2
 3from plain.auth.requests import get_request_user
 4from plain.auth.views import AuthView
 5from plain.http import RedirectResponse, Response
 6from plain.views import TemplateView, View
 7
 8from .exceptions import (
 9    OAuthError,
10)
11from .providers import get_oauth_provider_instance
12
# Module-scoped logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
14
15
class OAuthLoginView(View):
    """Hand a login POST off to the OAuth provider named in the URL.

    If the request already has a user, the view redirects to "/" instead of
    calling the provider.
    """

    def post(self) -> Response:
        provider_key = self.url_kwargs["provider"]

        # Someone is already signed in: skip the provider round-trip.
        if get_request_user(self.request):
            return RedirectResponse("/")

        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        return oauth_provider.handle_login_request(request=self.request)
25
26
class OAuthCallbackView(TemplateView):
    """
    The callback view is used for signup, login, and connect.

    The request is delegated to the provider's ``handle_callback_request``.
    If that raises ``OAuthError``, the error is logged at WARNING level and
    stored on the view, and an error page is rendered with HTTP status 400.
    The template is the error's own ``template_name`` (when set), falling
    back to ``oauth/error.html``.
    """

    template_name = "oauth/error.html"

    def get(self) -> Response:
        provider_key = self.url_kwargs["provider"]
        oauth_provider = get_oauth_provider_instance(provider_key=provider_key)
        try:
            return oauth_provider.handle_callback_request(request=self.request)
        except OAuthError as error:
            logger.warning("OAuth error: %s", error.message)
            # Stash the error so get_template_names/get_template_context see it.
            self.oauth_error = error

            error_response = super().get()
            error_response.status_code = 400
            return error_response

    def get_template_names(self) -> list[str]:
        # Generic fallback last; an error-specific template (if any) goes first.
        oauth_error = getattr(self, "oauth_error", None)
        template_names = [self.template_name]
        if oauth_error and oauth_error.template_name:
            template_names.insert(0, oauth_error.template_name)
        return template_names

    def get_template_context(self) -> dict:
        # Expose the caught error (or None) to the template as "oauth_error".
        template_context = super().get_template_context()
        template_context.update(oauth_error=getattr(self, "oauth_error", None))
        return template_context
59
60
class OAuthConnectView(AuthView):
    """Delegate a connect POST to the OAuth provider named in the URL.

    The provider instance's ``handle_connect_request`` builds the response.
    Access rules come from ``AuthView`` (presumably signed-in users only;
    confirm against that base class).
    """

    def post(self) -> Response:
        oauth_provider = get_oauth_provider_instance(
            provider_key=self.url_kwargs["provider"]
        )
        return oauth_provider.handle_connect_request(request=self.request)
67
68
class OAuthDisconnectView(AuthView):
    """Delegate a disconnect POST to the OAuth provider named in the URL.

    The provider instance's ``handle_disconnect_request`` builds the response.
    Nothing is caught here, so any exception it raises propagates out of
    the view.
    """

    def post(self) -> Response:
        request = self.request
        provider = self.url_kwargs["provider"]
        provider_instance = get_oauth_provider_instance(provider_key=provider)
        # NOTE(review): handling for OAuthCannotDisconnectError (rendering
        # "oauth/error.html" with status 400) was once sketched here but is
        # disabled, so that error currently propagates. Reintroduce it, with
        # the matching import, if a friendly error page is wanted.
        return provider_instance.handle_disconnect_request(request=request)