1from __future__ import annotations
2
3from typing import TYPE_CHECKING
4
5from ..results import AuditResult, CheckResult
6from .base import Audit
7
8if TYPE_CHECKING:
9 from ..scanner import Scanner
10
11
class CORSAudit(Audit):
    """CORS (Cross-Origin Resource Sharing) security checks.

    Inspects the ``Access-Control-Allow-Origin``,
    ``Access-Control-Allow-Credentials`` and ``Vary`` headers of the response
    returned by ``scanner.fetch()`` and reports three checks:

    * ``wildcard-credentials`` - ``*`` origin combined with credentials.
    * ``null-origin`` - the ``null`` origin is allowed.
    * ``vary-header`` - non-wildcard origins without ``Vary: Origin``.
    """

    name = "Cross-Origin Resource Sharing (CORS)"
    slug = "cors"
    required = False  # CORS is only needed for cross-origin API endpoints
    description = "Checks for Cross-Origin Resource Sharing misconfigurations that could allow unauthorized access to resources. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS"

    def check(self, scanner: Scanner) -> AuditResult:
        """Check if CORS is configured securely.

        Args:
            scanner: Scanner whose ``fetch()`` response headers are inspected.

        Returns:
            An ``AuditResult`` with ``detected=False`` and no checks when the
            response has no (or an empty) ``Access-Control-Allow-Origin``
            header; otherwise ``detected=True`` with the three CORS checks.
        """
        response = scanner.fetch()
        # NOTE(review): whether fetch() sends an Origin request header is not
        # visible here. Servers that only emit CORS headers in reply to an
        # Origin header (or reflect it) would show up as "not detected" - confirm.

        # Check for CORS headers
        allow_origin = response.headers.get("Access-Control-Allow-Origin")
        allow_credentials = response.headers.get("Access-Control-Allow-Credentials")
        vary_header = response.headers.get("Vary")

        if not allow_origin:
            # CORS not detected: nothing to evaluate.
            return AuditResult(
                name=self.name,
                detected=False,
                required=self.required,
                checks=[],
                description=self.description,
            )

        # CORS detected - run security checks
        checks = [
            self._check_wildcard_with_credentials(allow_origin, allow_credentials),
            self._check_null_origin(allow_origin),
            self._check_vary_header(allow_origin, vary_header),
        ]

        return AuditResult(
            name=self.name,
            detected=True,
            required=self.required,
            checks=checks,
            description=self.description,
        )

    def _check_wildcard_with_credentials(
        self, allow_origin: str, allow_credentials: str | None
    ) -> CheckResult:
        """Check for dangerous ``*`` origin combined with credentials.

        Args:
            allow_origin: Raw ``Access-Control-Allow-Origin`` header value.
            allow_credentials: Raw ``Access-Control-Allow-Credentials`` header
                value, or ``None`` when the header is absent.

        Returns:
            A failing ``CheckResult`` when the origin is ``*`` and credentials
            are allowed; a passing one otherwise.
        """
        is_wildcard = allow_origin.strip() == "*"
        # Always a real bool (the old `x and ...` form could yield None/"").
        # Compared case-insensitively so "True"/"TRUE" are flagged as well.
        allows_credentials = (
            allow_credentials is not None
            and allow_credentials.strip().lower() == "true"
        )

        if is_wildcard and allows_credentials:
            return CheckResult(
                name="wildcard-credentials",
                passed=False,
                message="CORS allows all origins (*) with credentials (major security risk)",
            )

        if is_wildcard:
            return CheckResult(
                name="wildcard-credentials",
                passed=True,
                message="CORS allows all origins (*) without credentials (acceptable for public resources)",
            )

        return CheckResult(
            name="wildcard-credentials",
            passed=True,
            message="CORS origin is not set to wildcard",
        )

    def _check_null_origin(self, allow_origin: str) -> CheckResult:
        """Check for dangerous ``null`` origin.

        Args:
            allow_origin: Raw ``Access-Control-Allow-Origin`` header value.

        Returns:
            A failing ``CheckResult`` when the value is ``null`` (compared
            case-insensitively, ignoring surrounding whitespace); a passing
            one otherwise.
        """
        if allow_origin.strip().lower() == "null":
            return CheckResult(
                name="null-origin",
                passed=False,
                message="CORS allows 'null' origin (can be exploited by sandboxed iframes)",
            )

        return CheckResult(
            name="null-origin",
            passed=True,
            message="CORS does not allow null origin",
        )

    def _check_vary_header(
        self, allow_origin: str, vary_header: str | None
    ) -> CheckResult:
        """Check for ``Vary: Origin`` header to prevent cache poisoning.

        Args:
            allow_origin: Raw ``Access-Control-Allow-Origin`` header value.
            vary_header: Raw ``Vary`` header value, or ``None`` when absent.

        Returns:
            A passing ``CheckResult`` for a wildcard origin, or when ``Vary``
            lists ``Origin`` or ``*``. A failing one when ``Vary`` is missing
            or lists neither.
        """
        # Wildcard doesn't need Vary header since it doesn't vary by origin
        if allow_origin.strip() == "*":
            return CheckResult(
                name="vary-header",
                passed=True,
                message="Vary header not required for wildcard origin",
            )

        # For specific origins, Vary: Origin is needed to prevent cache poisoning
        if not vary_header:
            return CheckResult(
                name="vary-header",
                passed=False,
                message="Missing 'Vary: Origin' header (required to prevent cache poisoning when using specific origins)",
            )

        # Vary is a comma-separated, case-insensitive list of field names.
        vary_values = {v.strip().lower() for v in vary_header.split(",")}

        # "Vary: *" means the response varies on everything, so a cache can
        # never reuse it for another request. That subsumes "Vary: Origin".
        if "*" in vary_values:
            return CheckResult(
                name="vary-header",
                passed=True,
                message="Vary: * header set (response is never reused from cache, which covers Origin)",
            )

        if "origin" not in vary_values:
            return CheckResult(
                name="vary-header",
                passed=False,
                message=f"'Vary' header present but missing 'Origin' value (found: {vary_header})",
            )

        return CheckResult(
            name="vary-header",
            passed=True,
            message="Vary: Origin header correctly set",
        )