1from __future__ import annotations
 2
 3import logging
 4import threading
 5from datetime import UTC, datetime
 6from typing import TypedDict
 7
 8from opentelemetry import trace
 9from opentelemetry.trace import format_span_id, format_trace_id
10
11from .core import ObserverMode
12from .otel import get_observer_span_processor
13
14
15class ObserverLogEntry(TypedDict):
16    message: str
17    level: str
18    span_id: str
19    timestamp: datetime
20
21
22class ObserverLogHandler(logging.Handler):
23    """Custom logging handler that captures logs during active traces when observer is enabled."""
24
25    def __init__(self, level: int = logging.NOTSET) -> None:
26        super().__init__(level)
27        self._logs_lock = threading.Lock()
28        self._trace_logs: dict[str, list[ObserverLogEntry]] = {}
29
30    def emit(self, record: logging.LogRecord) -> None:
31        """Emit a log record if we're in an active observer trace."""
32        try:
33            # Get the current span to determine if we're in an active trace
34            current_span = trace.get_current_span()
35            if not current_span or not current_span.get_span_context().is_valid:
36                return
37
38            # Get trace and span IDs
39            trace_id = f"0x{format_trace_id(current_span.get_span_context().trace_id)}"
40            span_id = f"0x{format_span_id(current_span.get_span_context().span_id)}"
41
42            # Check if observer is recording this trace
43            processor = get_observer_span_processor()
44            if not processor:
45                return
46
47            # Check if we should record logs for this trace
48            with processor._traces_lock:
49                if trace_id not in processor._traces:
50                    return
51
52                trace_info = processor._traces[trace_id]
53                # Only capture logs in PERSIST mode
54                if trace_info["mode"] != ObserverMode.PERSIST.value:
55                    return
56
57            # Store the formatted message with span context
58            log_entry: ObserverLogEntry = {
59                "message": self.format(record),
60                "level": record.levelname,
61                "span_id": span_id,
62                "timestamp": datetime.fromtimestamp(record.created, tz=UTC),
63            }
64
65            with self._logs_lock:
66                if trace_id not in self._trace_logs:
67                    self._trace_logs[trace_id] = []
68                self._trace_logs[trace_id].append(log_entry)
69
70                # Limit logs per trace to prevent memory issues
71                if len(self._trace_logs[trace_id]) > 1000:
72                    self._trace_logs[trace_id] = self._trace_logs[trace_id][-500:]
73
74        except Exception:
75            # Don't let logging errors break the application
76            pass
77
78    def pop_logs_for_trace(self, trace_id: str) -> list[ObserverLogEntry]:
79        """Get and remove all logs for a specific trace in one operation."""
80        with self._logs_lock:
81            return self._trace_logs.pop(trace_id, []).copy()
82
83
# Shared module-level instance of the handler. Nothing in this module attaches
# it to a logger; presumably callers add it to a logger and later drain it with
# pop_logs_for_trace() — TODO confirm where it is registered.
observer_log_handler: ObserverLogHandler = ObserverLogHandler()