1from __future__ import annotations
2
3from functools import cached_property
4from typing import Any
5from urllib.parse import urlparse, urlunparse
6
7from plain.http import (
8 ForbiddenError403,
9 NotFoundError404,
10 QueryDict,
11 RedirectResponse,
12 ResponseBase,
13)
14from plain.runtime import settings
15from plain.sessions.views import SessionView
16from plain.urls import reverse
17from plain.utils.cache import patch_cache_control
18from plain.views import View
19
20from .sessions import logout
21from .utils import resolve_url
22
# Impersonation support comes from plain.admin, which may not be installed.
# When the import fails the name is bound to None, and callers must check it
# for truthiness before calling it (AuthView.check_auth does exactly that).
try:
    from plain.admin.impersonate import get_request_impersonator
except ImportError:
    get_request_impersonator: Any = None

# Names exported by `from <this module> import *`.
__all__ = [
    "AuthView",
    "LoginRequired",
    "LogoutView",
    "redirect_to_login",
]
34
35
36class LoginRequired(Exception):
37 def __init__(self, login_url: str | None = None, redirect_field_name: str = "next"):
38 self.login_url = login_url or settings.AUTH_LOGIN_URL
39 self.redirect_field_name = redirect_field_name
40
41
class AuthView(SessionView):
    """
    Session-backed view with optional login and admin gating.

    Subclasses opt in by setting ``login_required`` or ``admin_required``
    (``admin_required`` implies ``login_required``). When a login is needed,
    the user is redirected to ``login_url``. If ``login_url`` is falsy, a 403 is
    raised instead of redirecting.
    """

    login_required = False
    admin_required = False  # Implies login_required
    # Read from settings once, when this class body executes (module import).
    login_url = settings.AUTH_LOGIN_URL

    @cached_property
    def user(self) -> Any | None:
        """Get the authenticated user for this request (computed once per view instance)."""
        from .requests import get_request_user

        return get_request_user(self.request)

    def get_template_context(self) -> dict:
        """Add the request's user to the template context under ``"user"``."""
        context = super().get_template_context()
        context["user"] = self.user
        return context

    def check_auth(self) -> None:
        """
        Enforce ``login_required`` / ``admin_required`` for the current request.

        Raises:
            LoginRequired: no user is associated with the request. It carries
                this view's ``login_url`` and can also specify a
                ``redirect_field_name``.
            ForbiddenError403: ``admin_required`` is set,
                ``get_request_impersonator`` reports an impersonator, and that
                impersonator is not an admin.
            NotFoundError404: ``admin_required`` is set and the user is not an
                admin. A 404 is used so admin URLs aren't revealed to
                non-admin users.
        """
        if not self.login_required and not self.admin_required:
            return None

        if not self.user:
            raise LoginRequired(login_url=self.login_url)

        if not self.admin_required:
            return None

        # The user is authenticated at this point. While impersonating, the
        # impersonator's permissions decide access, so an admin can keep viewing
        # admin pages as another user. (get_request_impersonator is None when
        # plain.admin isn't installed.)
        impersonator = (
            get_request_impersonator(self.request) if get_request_impersonator else None
        )
        if impersonator:
            # There's probably never a case where an impersonator isn't admin,
            # but it can be configured.
            if not impersonator.is_admin:
                raise ForbiddenError403(
                    "You do not have permission to access this page."
                )
            return None

        if not self.user.is_admin:
            raise NotFoundError404()

        return None

    def get_response(self) -> ResponseBase:
        """
        Run ``check_auth`` and then build the normal response.

        A LoginRequired from ``check_auth`` becomes a login redirect (or a 403,
        see ``_login_required_response``). Responses for logged-in users are
        marked private.
        """
        try:
            self.check_auth()
        except LoginRequired as e:
            # Called inside the except block so a raised 403 keeps the
            # LoginRequired as its __context__, exactly as before.
            return self._login_required_response(e)

        response = super().get_response()

        if self.user:
            # Add "private" to Cache-Control so per-user content isn't stored by
            # shared caches.
            patch_cache_control(response, private=True)

        return response

    def _login_required_response(self, e: LoginRequired) -> ResponseBase:
        """
        Build the redirect to the login page for a LoginRequired raised by check_auth.

        Raises:
            ForbiddenError403: this view has no ``login_url``, so there is
                nowhere to redirect to.
        """
        if not self.login_url:
            raise ForbiddenError403("Login required")

        path = self.request.build_absolute_uri()
        resolved_login_url = reverse(e.login_url)
        # When the login page has the same scheme and host as the current request,
        # a relative path is enough for "next". Otherwise the absolute URL is kept
        # so the user can be sent back to the right origin.
        login_scheme, login_netloc = urlparse(resolved_login_url)[:2]
        current_scheme, current_netloc = urlparse(path)[:2]
        if (not login_scheme or login_scheme == current_scheme) and (
            not login_netloc or login_netloc == current_netloc
        ):
            path = self.request.get_full_path()

        return redirect_to_login(
            path,
            resolved_login_url,
            e.redirect_field_name,
        )
122
123
class LogoutView(View):
    """View that logs the requesting user out on POST and redirects to ``/``."""

    def post(self) -> RedirectResponse:
        """Call ``logout()`` for this request, then redirect to the site root."""
        current_request = self.request
        logout(current_request)
        destination = "/"
        return RedirectResponse(destination)
128
129
def redirect_to_login(
    next: str, login_url: str | None = None, redirect_field_name: str = "next"
) -> RedirectResponse:
    """
    Redirect the user to the login page, passing the given 'next' page.

    Args:
        next: Path or URL to return to after logging in. (It shadows the
            ``next`` builtin, but the name is kept for keyword callers.)
        login_url: Value passed to ``resolve_url`` to get the login page.
            Falls back to ``settings.AUTH_LOGIN_URL`` when falsy.
        redirect_field_name: Query parameter that carries ``next``. When
            empty, the login URL is used without adding a parameter.

    Returns:
        A RedirectResponse to the resolved login URL. Any query string it
        already has is parsed and re-encoded together with the
        ``redirect_field_name`` entry.
    """
    resolved_url = resolve_url(login_url or settings.AUTH_LOGIN_URL)
    parsed = urlparse(resolved_url)

    if redirect_field_name:
        querystring = QueryDict(parsed.query, mutable=True)
        querystring[redirect_field_name] = next
        # Leave "/" unescaped so the next path stays readable in the URL.
        parsed = parsed._replace(query=querystring.urlencode(safe="/"))

    return RedirectResponse(urlunparse(parsed))