1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import Any
  5from urllib.parse import urlparse, urlunparse
  6
  7from plain.http import (
  8    ForbiddenError403,
  9    NotFoundError404,
 10    QueryDict,
 11    RedirectResponse,
 12    ResponseBase,
 13)
 14from plain.runtime import settings
 15from plain.sessions.views import SessionView
 16from plain.urls import reverse
 17from plain.utils.cache import patch_cache_control
 18from plain.views import View
 19
 20from .sessions import logout
 21from .utils import resolve_url
 22
 23try:
 24    from plain.admin.impersonate import get_request_impersonator
 25except ImportError:
 26    get_request_impersonator: Any = None
 27
 28__all__ = [
 29    "AuthView",
 30    "LoginRequired",
 31    "LogoutView",
 32    "redirect_to_login",
 33]
 34
 35
 36class LoginRequired(Exception):
 37    def __init__(self, login_url: str | None = None, redirect_field_name: str = "next"):
 38        self.login_url = login_url or settings.AUTH_LOGIN_URL
 39        self.redirect_field_name = redirect_field_name
 40
 41
 42class AuthView(SessionView):
 43    login_required = False
 44    admin_required = False  # Implies login_required
 45    login_url = settings.AUTH_LOGIN_URL
 46
 47    @cached_property
 48    def user(self) -> Any | None:
 49        """Get the authenticated user for this request."""
 50        from .requests import get_request_user
 51
 52        return get_request_user(self.request)
 53
 54    def get_template_context(self) -> dict:
 55        """Add user and impersonator to template context."""
 56        context = super().get_template_context()
 57        context["user"] = self.user
 58        return context
 59
 60    def check_auth(self) -> None:
 61        """
 62        Raises either LoginRequired or ForbiddenError403.
 63        - LoginRequired can specify a login_url and redirect_field_name
 64        - ForbiddenError403 can specify a message
 65        """
 66        if not self.login_required and not self.admin_required:
 67            return None
 68
 69        if not self.user:
 70            raise LoginRequired(login_url=self.login_url)
 71
 72        if self.admin_required:
 73            # At this point, we know user is authenticated (from check above)
 74            # Check if impersonation is active
 75            if get_request_impersonator:
 76                if impersonator := get_request_impersonator(self.request):
 77                    # Impersonators should be able to view admin pages while impersonating.
 78                    # There's probably never a case where an impersonator isn't admin, but it can be configured.
 79                    if not impersonator.is_admin:
 80                        raise ForbiddenError403(
 81                            "You do not have permission to access this page."
 82                        )
 83                    return
 84
 85            if not self.user.is_admin:
 86                # Show a 404 so we don't expose admin urls to non-admin users
 87                raise NotFoundError404()
 88
 89    def get_response(self) -> ResponseBase:
 90        try:
 91            self.check_auth()
 92        except LoginRequired as e:
 93            if self.login_url:
 94                # Ideally this could be handled elsewhere... like PermissionDenied
 95                # also seems like this code is used multiple places anyway...
 96                # could be easier to get redirect query param
 97                path = self.request.build_absolute_uri()
 98                resolved_login_url = reverse(e.login_url)
 99                # If the login url is the same scheme and net location then use the
100                # path as the "next" url.
101                login_scheme, login_netloc = urlparse(resolved_login_url)[:2]
102                current_scheme, current_netloc = urlparse(path)[:2]
103                if (not login_scheme or login_scheme == current_scheme) and (
104                    not login_netloc or login_netloc == current_netloc
105                ):
106                    path = self.request.get_full_path()
107                return redirect_to_login(
108                    path,
109                    resolved_login_url,
110                    e.redirect_field_name,
111                )
112            else:
113                raise ForbiddenError403("Login required")
114
115        response = super().get_response()
116
117        if self.user:
118            # Make sure it at least has private as a default
119            patch_cache_control(response, private=True)
120
121        return response
122
123
class LogoutView(View):
    """Log the current user out on POST, then redirect to the site root."""

    def post(self) -> RedirectResponse:
        """Call ``logout`` for this request and redirect the browser to ``/``."""
        current_request = self.request
        logout(current_request)
        return RedirectResponse("/")
128
129
def redirect_to_login(
    next: str, login_url: str | None = None, redirect_field_name: str = "next"
) -> RedirectResponse:
    """
    Return a redirect to the login page that records where to go afterwards.

    Args:
        next: URL (absolute, or a path) to return to after logging in.
        login_url: Anything ``resolve_url`` accepts. Defaults to
            ``settings.AUTH_LOGIN_URL`` when not given (or empty).
        redirect_field_name: Query parameter that receives ``next``. When it
            is empty, the resolved login URL is used unchanged.

    Returns:
        A RedirectResponse to the login URL. Its existing query parameters are
        kept, and ``next`` is stored under ``redirect_field_name``.
    """
    parsed = urlparse(resolve_url(login_url or settings.AUTH_LOGIN_URL))

    if redirect_field_name:
        params = QueryDict(parsed.query, mutable=True)
        params[redirect_field_name] = next
        # safe="/" leaves slashes in the "next" value unescaped.
        parsed = parsed._replace(query=params.urlencode(safe="/"))

    return RedirectResponse(str(urlunparse(parsed)))