1from __future__ import annotations
  2
  3from functools import cached_property
  4from typing import TYPE_CHECKING, Any
  5from urllib.parse import urlparse, urlunparse
  6
  7from plain.http import (
  8    ForbiddenError403,
  9    NotFoundError404,
 10    QueryDict,
 11    RedirectResponse,
 12    ResponseBase,
 13)
 14from plain.runtime import settings
 15from plain.sessions.views import SessionView
 16from plain.urls import reverse
 17from plain.utils.cache import patch_cache_control
 18from plain.views import View
 19
 20from .sessions import logout
 21from .utils import resolve_url
 22
 23if TYPE_CHECKING:
 24    from plain.postgres import Model
 25
# Impersonation support comes from plain.admin, which is optional here.
# If it cannot be imported, get_request_impersonator is None and
# AuthView.check_auth skips its impersonation branch.
try:
    from plain.admin.impersonate import get_request_impersonator
except ImportError:
    get_request_impersonator: Any = None
 30
# Public names exported by this module.
__all__ = [
    "AuthView",
    "LoginRequired",
    "LogoutView",
    "redirect_to_login",
]
 37
 38
 39class LoginRequired(Exception):
 40    def __init__(self, login_url: str | None = None, redirect_field_name: str = "next"):
 41        self.login_url = login_url or settings.AUTH_LOGIN_URL
 42        self.redirect_field_name = redirect_field_name
 43
 44
 45class AuthView(SessionView):
 46    login_required = False
 47    admin_required = False  # Implies login_required
 48    login_url = settings.AUTH_LOGIN_URL
 49
 50    @cached_property
 51    def user(self) -> Model | None:
 52        """Get the authenticated user for this request."""
 53        from .requests import get_request_user
 54
 55        return get_request_user(self.request)
 56
 57    def get_template_context(self) -> dict:
 58        """Add user and impersonator to template context."""
 59        context = super().get_template_context()
 60        context["user"] = self.user
 61        return context
 62
 63    def check_auth(self) -> None:
 64        """
 65        Raises either LoginRequired or ForbiddenError403.
 66        - LoginRequired can specify a login_url and redirect_field_name
 67        - ForbiddenError403 can specify a message
 68        """
 69        if not self.login_required and not self.admin_required:
 70            return None
 71
 72        if not self.user:
 73            raise LoginRequired(login_url=self.login_url)
 74
 75        if self.admin_required:
 76            # At this point, we know user is authenticated (from check above)
 77            # Check if impersonation is active
 78            if get_request_impersonator:
 79                if impersonator := get_request_impersonator(self.request):
 80                    # Impersonators should be able to view admin pages while impersonating.
 81                    # There's probably never a case where an impersonator isn't admin, but it can be configured.
 82                    if not impersonator.is_admin:
 83                        raise ForbiddenError403(
 84                            "You do not have permission to access this page."
 85                        )
 86                    return
 87
 88            if not self.user.is_admin:  # type: ignore[union-attr]
 89                # Show a 404 so we don't expose admin urls to non-admin users
 90                raise NotFoundError404()
 91
 92    def get_response(self) -> ResponseBase:
 93        try:
 94            self.check_auth()
 95        except LoginRequired as e:
 96            if self.login_url:
 97                # Ideally this could be handled elsewhere... like PermissionDenied
 98                # also seems like this code is used multiple places anyway...
 99                # could be easier to get redirect query param
100                path = self.request.build_absolute_uri()
101                resolved_login_url = reverse(e.login_url)
102                # If the login url is the same scheme and net location then use the
103                # path as the "next" url.
104                login_scheme, login_netloc = urlparse(resolved_login_url)[:2]
105                current_scheme, current_netloc = urlparse(path)[:2]
106                if (not login_scheme or login_scheme == current_scheme) and (
107                    not login_netloc or login_netloc == current_netloc
108                ):
109                    path = self.request.get_full_path()
110                return redirect_to_login(
111                    path,
112                    resolved_login_url,
113                    e.redirect_field_name,
114                )
115            else:
116                raise ForbiddenError403("Login required")
117
118        response = super().get_response()
119
120        if self.user:
121            # Make sure it at least has private as a default
122            patch_cache_control(response, private=True)
123
124        return response
125
126
class LogoutView(View):
    """Log the user out on POST, then redirect to the site root ("/")."""

    def post(self) -> RedirectResponse:
        # Only a POST handler is defined on this view.
        logout(self.request)
        destination = "/"
        return RedirectResponse(destination)
131
132
def redirect_to_login(
    next: str, login_url: str | None = None, redirect_field_name: str = "next"
) -> RedirectResponse:
    """
    Return a redirect to the login page with ``next`` set in its query string.

    Args:
        next: URL or path to come back to after logging in.
        login_url: Login URL or name, passed through ``resolve_url``.
            Defaults to ``settings.AUTH_LOGIN_URL`` when falsy.
        redirect_field_name: Query parameter that carries ``next``. When
            falsy, the login URL's query string is left untouched.

    Returns:
        A RedirectResponse that is allowed to point at an external host.
    """
    target = urlparse(resolve_url(login_url or settings.AUTH_LOGIN_URL))

    if redirect_field_name:
        params = QueryDict(target.query, mutable=True)
        params[redirect_field_name] = next
        # Keep "/" unescaped so path-style next values stay readable.
        target = target._replace(query=params.urlencode(safe="/"))

    return RedirectResponse(str(urlunparse(target)), allow_external=True)