1from __future__ import annotations
  2
import ipaddress
import socket
import ssl
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from urllib.parse import urlparse
  8
  9from ..results import AuditResult, CheckResult
 10from .base import Audit
 11
 12if TYPE_CHECKING:
 13    from ..scanner import Scanner
 14
 15
 16class TLSAudit(Audit):
 17    """TLS/SSL security checks."""
 18
 19    name = "TLS/SSL"
 20    slug = "tls"
 21    description = "Basic TLS/SSL validation including certificate expiry and protocol version. For comprehensive TLS testing, use SSL Labs (https://www.ssllabs.com/ssltest/)."
 22
 23    def check(self, scanner: Scanner) -> AuditResult:
 24        """Check TLS certificate and configuration."""
 25
 26        # Check if there was a TLS/SSL error during the initial fetch
 27        # This allows us to report certificate issues even when the connection fails
 28        if scanner.fetch_exception is not None:
 29            # Report the TLS/SSL error
 30            error_msg = str(scanner.fetch_exception)
 31
 32            # Try to make the error message more user-friendly
 33            if "CERTIFICATE_VERIFY_FAILED" in error_msg:
 34                error_msg = "Certificate verification failed (certificate may be expired, self-signed, or for wrong hostname)"
 35            elif "certificate verify failed" in error_msg.lower():
 36                error_msg = "Certificate verification failed"
 37
 38            return AuditResult(
 39                name=self.name,
 40                detected=True,
 41                required=self.required,
 42                checks=[
 43                    CheckResult(
 44                        name="connection",
 45                        passed=False,
 46                        message=f"Failed to establish secure TLS connection: {error_msg}",
 47                    )
 48                ],
 49                description=self.description,
 50            )
 51
 52        response = scanner.fetch()
 53
 54        initial_parsed = urlparse(scanner.url)
 55        final_parsed = urlparse(response.url)
 56
 57        # Prefer the final HTTPS endpoint if we followed redirects
 58        target_parsed = (
 59            final_parsed if final_parsed.scheme == "https" else initial_parsed
 60        )
 61
 62        if target_parsed.scheme != "https":
 63            return AuditResult(
 64                name=self.name,
 65                detected=False,
 66                required=self.required,
 67                checks=[],
 68                description=self.description,
 69            )
 70
 71        hostname = target_parsed.hostname
 72        port = target_parsed.port or 443
 73
 74        if not hostname:
 75            return AuditResult(
 76                name=self.name,
 77                detected=False,
 78                required=self.required,
 79                checks=[],
 80                description=self.description,
 81            )
 82
 83        # Try to get certificate info
 84        try:
 85            cert_info = self._get_certificate_info(hostname, port)
 86        except Exception as e:
 87            # TLS connection failed
 88            return AuditResult(
 89                name=self.name,
 90                detected=True,
 91                required=self.required,
 92                checks=[
 93                    CheckResult(
 94                        name="connection",
 95                        passed=False,
 96                        message=f"Failed to connect via TLS: {str(e)}",
 97                    )
 98                ],
 99                description=self.description,
100            )
101
102        # Run checks on the certificate
103        checks = [
104            self._check_certificate_expiry(cert_info),
105            self._check_tls_version(cert_info),
106            self._check_legacy_tls(cert_info),
107            self._check_certificate_hostname(cert_info, hostname),
108        ]
109
110        return AuditResult(
111            name=self.name,
112            detected=True,
113            required=self.required,
114            checks=checks,
115            description=self.description,
116        )
117
118    def _get_certificate_info(self, hostname: str, port: int) -> dict:
119        """Get certificate information from the server."""
120        context = ssl.create_default_context()
121
122        with socket.create_connection((hostname, port), timeout=10) as sock:
123            with context.wrap_socket(sock, server_hostname=hostname) as ssock:
124                cert = ssock.getpeercert()
125                tls_version = ssock.version()
126
127                return {
128                    "cert": cert,
129                    "tls_version": tls_version,
130                }
131
132    def _check_certificate_expiry(self, cert_info: dict) -> CheckResult:
133        """Check if certificate is expired or expiring soon."""
134        cert = cert_info["cert"]
135
136        # Parse notAfter date
137        not_after_str = cert.get("notAfter")
138        if not not_after_str:
139            return CheckResult(
140                name="certificate-expiry",
141                passed=False,
142                message="Certificate has no expiration date",
143            )
144
145        # Parse the date string (format: 'Jul 15 12:00:00 2025 GMT')
146        not_after = datetime.strptime(not_after_str, "%b %d %H:%M:%S %Y %Z")
147        not_after = not_after.replace(tzinfo=UTC)
148
149        now = datetime.now(UTC)
150        days_until_expiry = (not_after - now).days
151
152        if days_until_expiry < 0:
153            return CheckResult(
154                name="certificate-expiry",
155                passed=False,
156                message=f"Certificate expired {abs(days_until_expiry)} days ago",
157            )
158
159        if days_until_expiry < 30:
160            return CheckResult(
161                name="certificate-expiry",
162                passed=False,
163                message=f"Certificate expires in {days_until_expiry} days (renew soon)",
164            )
165
166        return CheckResult(
167            name="certificate-expiry",
168            passed=True,
169            message=f"Certificate valid for {days_until_expiry} days",
170        )
171
172    def _check_tls_version(self, cert_info: dict) -> CheckResult:
173        """Check if using a secure TLS version."""
174        tls_version = cert_info["tls_version"]
175
176        # TLS 1.2 and 1.3 are secure
177        secure_versions = ["TLSv1.2", "TLSv1.3"]
178
179        if tls_version in secure_versions:
180            return CheckResult(
181                name="tls-version",
182                passed=True,
183                message=f"Using {tls_version}",
184            )
185
186        # Older versions are insecure
187        return CheckResult(
188            name="tls-version",
189            passed=False,
190            message=f"Using outdated {tls_version} (upgrade to TLS 1.2 or 1.3)",
191        )
192
193    def _check_legacy_tls(self, cert_info: dict) -> CheckResult:
194        """Check if server supports legacy TLS 1.0/1.1 (should be disabled)."""
195        tls_version = cert_info["tls_version"]
196
197        # If we connected with TLS 1.0 or 1.1, that's already bad
198        if tls_version in ["TLSv1", "TLSv1.0", "TLSv1.1"]:
199            return CheckResult(
200                name="legacy-tls",
201                passed=False,
202                message=f"Server is using legacy {tls_version} (should be disabled)",
203            )
204
205        # If we're using TLS 1.2+, we can't easily test if 1.0/1.1 are also enabled
206        # without making additional connections with specific protocols
207        # For now, just report that we're using a modern version
208        return CheckResult(
209            name="legacy-tls",
210            passed=True,
211            message=f"Connection uses modern TLS ({tls_version})",
212        )
213
214    def _check_certificate_hostname(
215        self, cert_info: dict, hostname: str
216    ) -> CheckResult:
217        """Check if certificate hostname matches the requested hostname."""
218        cert = cert_info["cert"]
219
220        # Get subject common name
221        subject = dict(x[0] for x in cert.get("subject", ()))
222        common_name = subject.get("commonName", "")
223
224        # Get subject alternative names
225        san_list = []
226        for item in cert.get("subjectAltName", ()):
227            if item[0] == "DNS":
228                san_list.append(item[1])
229
230        # Check if hostname matches CN or any SAN
231        if hostname == common_name or hostname in san_list:
232            return CheckResult(
233                name="certificate-hostname",
234                passed=True,
235                message=f"Certificate hostname matches: {common_name}",
236            )
237
238        # Check for wildcard matches
239        for san in san_list:
240            if san.startswith("*.") and hostname.endswith(san[1:]):
241                return CheckResult(
242                    name="certificate-hostname",
243                    passed=True,
244                    message=f"Certificate hostname matches wildcard: {san}",
245                )
246
247        # Hostname mismatch
248        all_names = [common_name] + san_list
249        return CheckResult(
250            name="certificate-hostname",
251            passed=False,
252            message=f"Certificate hostname mismatch (expected: {hostname}, found: {', '.join(all_names)})",
253        )