1from __future__ import annotations
2
3import re
4from typing import TYPE_CHECKING
5from urllib.parse import urlparse
6
7from plain.http import HttpMiddleware, SuspiciousOperationError400
8from plain.runtime import settings
9
10if TYPE_CHECKING:
11 from plain.http import Request, Response
12
13
class CsrfViewMiddleware(HttpMiddleware):
    """
    Modern CSRF protection middleware using Sec-Fetch-Site headers and origin validation.
    Based on Filippo Valsorda's 2025 research (https://words.filippo.io/csrf/).

    Note: This provides same-origin (not same-site) protection. Same-site origins
    like subdomains can have different trust levels and are rejected.
    """

    def __init__(self):
        # Compile CSRF exempt patterns once for performance
        self.csrf_exempt_patterns: list[re.Pattern[str]] = [
            re.compile(r) for r in settings.CSRF_EXEMPT_PATHS
        ]

    def before_request(self, request: Request) -> Response | None:
        """
        Reject the request before it is handled if it fails the CSRF checks.

        Returns None when the request is allowed.

        Raises:
            SuspiciousOperationError400: when should_allow_request() rejects the
                request; the rejection reason is used as the error message.
        """
        allowed, reason = self.should_allow_request(request)

        if not allowed:
            raise SuspiciousOperationError400(reason)

        return None

    def should_allow_request(self, request: Request) -> tuple[bool, str]:
        """
        Decide whether the request passes CSRF protection.

        Returns an ``(allowed, reason)`` tuple where ``reason`` is a
        human-readable explanation of the decision. The checks run in order
        and the first one that reaches a verdict wins:

        1. Safe methods (GET, HEAD, OPTIONS) are allowed.
        2. Paths matching a compiled CSRF_EXEMPT_PATHS pattern are allowed.
        3. An Origin listed in CSRF_TRUSTED_ORIGINS is allowed.
        4. Sec-Fetch-Site decides: same-origin/none allow, cross-site/same-site reject.
        5. Requests with neither Origin nor Sec-Fetch-Site are allowed.
        6. Otherwise the Origin host:port is compared against the request host:port.

        A malformed Origin or Host never raises from here; it ends in a rejection.
        """
        # 1. Allow safe methods (GET, HEAD, OPTIONS)
        if request.method in ("GET", "HEAD", "OPTIONS"):
            return True, f"CSRF allowed: Safe HTTP method: {request.method}"

        # 2. Path-based exemption (regex patterns)
        for pattern in self.csrf_exempt_patterns:
            if pattern.search(request.path_info):
                return (
                    True,
                    f"CSRF allowed: Path {request.path_info} matches exempt pattern {pattern.pattern}",
                )

        origin = request.headers.get("Origin")
        sec_fetch_site = request.headers.get("Sec-Fetch-Site", "").lower()

        # 3. Check trusted origins allow-list
        if origin and origin in settings.CSRF_TRUSTED_ORIGINS:
            return True, f"CSRF allowed: Trusted origin: {origin}"

        # 4. Primary protection: Check Sec-Fetch-Site header
        if sec_fetch_site in ("same-origin", "none"):
            return (
                True,
                f"CSRF allowed: Same-origin request from Sec-Fetch-Site: {sec_fetch_site}",
            )
        elif sec_fetch_site in ("cross-site", "same-site"):
            return (
                False,
                f"CSRF rejected: Cross-origin request from Sec-Fetch-Site: {sec_fetch_site}",
            )

        # 5. No fetch metadata or Origin headers - allow (non-browser requests)
        if not origin and not sec_fetch_site:
            return (
                True,
                "CSRF allowed: No Origin or Sec-Fetch-Site header - likely non-browser or old browser",
            )

        # 6. Fallback: Origin vs Host comparison for older browsers
        # Note: On pre-2023 browsers, HTTP→HTTPS transitions may cause mismatches
        # (Origin shows :443, request sees :80 if TLS terminated upstream).
        # HSTS helps here; otherwise add external origins to CSRF_TRUSTED_ORIGINS.
        if origin == "null":
            return False, "CSRF rejected: Null Origin header"

        # `origin` can still be missing here when Sec-Fetch-Site carried an
        # unrecognized value; skip parsing and fall through to the rejection.
        if origin and (host := request.host):
            try:
                # Parse inside the try: urlparse() itself raises ValueError for
                # malformed netlocs (e.g. unmatched IPv6 brackets), and reading
                # .port raises ValueError for an invalid port. Either must end
                # in a rejection rather than an unhandled exception.
                parsed_origin = urlparse(origin)

                # Scheme-agnostic host:port comparison
                origin_host = parsed_origin.hostname
                origin_port = parsed_origin.port or (
                    80
                    if parsed_origin.scheme == "http"
                    else 443
                    if parsed_origin.scheme == "https"
                    else None
                )

                # Extract hostname from request host (similar to how we parse origin)
                # Use a fake scheme since we only care about host parsing
                parsed_host = urlparse(f"http://{host}")
                request_host = parsed_host.hostname or host
                request_port = request.port

                # Compare hostname and port (scheme-agnostic)
                # Both origin_host and request_host are normalized by urlparse (IPv6 brackets stripped)
                if origin_host and origin_port and request_host and request_port:
                    host_match = origin_host.lower() == request_host.lower()
                    port_match = origin_port == int(request_port)

                    if host_match and port_match:
                        return (
                            True,
                            f"CSRF allowed: Same-origin request - Origin {origin} matches Host {host}",
                        )

                    # Build detailed error message based on what mismatched
                    if host_match:
                        # Port mismatch only - show ports since they're relevant
                        return (
                            False,
                            f"CSRF rejected: Origin {origin_host}:{origin_port} does not match Host {request_host}:{request_port} (port mismatch)",
                        )
                    elif port_match:
                        # Host mismatch only - no need to show ports
                        return (
                            False,
                            f"CSRF rejected: Origin {origin_host} does not match Host {request_host}",
                        )
                    else:
                        # Both mismatch - show full details
                        return (
                            False,
                            f"CSRF rejected: Origin {origin_host}:{origin_port} does not match Host {request_host}:{request_port}",
                        )
            except ValueError:
                # Unparseable Origin/Host or non-numeric port: deliberately fall
                # through to the generic rejection below.
                pass

        # Origin present but couldn't parse/compare properly
        return (
            False,
            f"CSRF rejected: Origin {origin} could not be validated against Host {request.host}",
        )