1from __future__ import annotations
  2
  3import re
  4from typing import TYPE_CHECKING
  5from urllib.parse import urlparse
  6
  7from plain.http import HttpMiddleware, SuspiciousOperationError400
  8from plain.runtime import settings
  9
 10if TYPE_CHECKING:
 11    from plain.http import Request, Response
 12
 13
 14class CsrfViewMiddleware(HttpMiddleware):
 15    """
 16    Modern CSRF protection middleware using Sec-Fetch-Site headers and origin validation.
 17    Based on Filippo Valsorda's 2025 research (https://words.filippo.io/csrf/).
 18
 19    Note: This provides same-origin (not same-site) protection. Same-site origins
 20    like subdomains can have different trust levels and are rejected.
 21    """
 22
 23    def __init__(self):
 24        # Compile CSRF exempt patterns once for performance
 25        self.csrf_exempt_patterns: list[re.Pattern[str]] = [
 26            re.compile(r) for r in settings.CSRF_EXEMPT_PATHS
 27        ]
 28
 29    def before_request(self, request: Request) -> Response | None:
 30        allowed, reason = self.should_allow_request(request)
 31
 32        if not allowed:
 33            raise SuspiciousOperationError400(reason)
 34
 35        return None
 36
 37    def should_allow_request(self, request: Request) -> tuple[bool, str]:
 38        # 1. Allow safe methods (GET, HEAD, OPTIONS)
 39        if request.method in ("GET", "HEAD", "OPTIONS"):
 40            return True, f"CSRF allowed: Safe HTTP method: {request.method}"
 41
 42        # 2. Path-based exemption (regex patterns)
 43        for pattern in self.csrf_exempt_patterns:
 44            if pattern.search(request.path_info):
 45                return (
 46                    True,
 47                    f"CSRF allowed: Path {request.path_info} matches exempt pattern {pattern.pattern}",
 48                )
 49
 50        origin = request.headers.get("Origin")
 51        sec_fetch_site = request.headers.get("Sec-Fetch-Site", "").lower()
 52
 53        # 3. Check trusted origins allow-list
 54
 55        if origin and origin in settings.CSRF_TRUSTED_ORIGINS:
 56            return True, f"CSRF allowed: Trusted origin: {origin}"
 57
 58        # 4. Primary protection: Check Sec-Fetch-Site header
 59        if sec_fetch_site in ("same-origin", "none"):
 60            return (
 61                True,
 62                f"CSRF allowed: Same-origin request from Sec-Fetch-Site: {sec_fetch_site}",
 63            )
 64        elif sec_fetch_site in ("cross-site", "same-site"):
 65            return (
 66                False,
 67                f"CSRF rejected: Cross-origin request from Sec-Fetch-Site: {sec_fetch_site}",
 68            )
 69
 70        # 5. No fetch metadata or Origin headers - allow (non-browser requests)
 71        if not origin and not sec_fetch_site:
 72            return (
 73                True,
 74                "CSRF allowed: No Origin or Sec-Fetch-Site header - likely non-browser or old browser",
 75            )
 76
 77        # 6. Fallback: Origin vs Host comparison for older browsers
 78        # Note: On pre-2023 browsers, HTTP→HTTPS transitions may cause mismatches
 79        # (Origin shows :443, request sees :80 if TLS terminated upstream).
 80        # HSTS helps here; otherwise add external origins to CSRF_TRUSTED_ORIGINS.
 81        if origin == "null":
 82            return False, "CSRF rejected: Null Origin header"
 83
 84        if (parsed_origin := urlparse(origin)) and (host := request.host):
 85            try:
 86                # Scheme-agnostic host:port comparison
 87                origin_host = parsed_origin.hostname
 88                origin_port = parsed_origin.port or (
 89                    80
 90                    if parsed_origin.scheme == "http"
 91                    else 443
 92                    if parsed_origin.scheme == "https"
 93                    else None
 94                )
 95
 96                # Extract hostname from request host (similar to how we parse origin)
 97                # Use a fake scheme since we only care about host parsing
 98                parsed_host = urlparse(f"http://{host}")
 99                request_host = parsed_host.hostname or host
100                request_port = request.port
101
102                # Compare hostname and port (scheme-agnostic)
103                # Both origin_host and request_host are normalized by urlparse (IPv6 brackets stripped)
104                if origin_host and origin_port and request_host and request_port:
105                    host_match = origin_host.lower() == request_host.lower()
106                    port_match = origin_port == int(request_port)
107
108                    if host_match and port_match:
109                        return (
110                            True,
111                            f"CSRF allowed: Same-origin request - Origin {origin} matches Host {host}",
112                        )
113
114                    # Build detailed error message based on what mismatched
115                    if host_match:
116                        # Port mismatch only - show ports since they're relevant
117                        return (
118                            False,
119                            f"CSRF rejected: Origin {origin_host}:{origin_port} does not match Host {request_host}:{request_port} (port mismatch)",
120                        )
121                    elif port_match:
122                        # Host mismatch only - no need to show ports
123                        return (
124                            False,
125                            f"CSRF rejected: Origin {origin_host} does not match Host {request_host}",
126                        )
127                    else:
128                        # Both mismatch - show full details
129                        return (
130                            False,
131                            f"CSRF rejected: Origin {origin_host}:{origin_port} does not match Host {request_host}:{request_port}",
132                        )
133            except ValueError:
134                pass
135
136        # Origin present but couldn't parse/compare properly
137        return (
138            False,
139            f"CSRF rejected: Origin {origin} could not be validated against Host {request.host}",
140        )