1from __future__ import annotations
  2
  3import logging
  4import re
  5import threading
  6from collections import defaultdict
  7from collections.abc import Sequence
  8from typing import TYPE_CHECKING, Any, cast
  9
 10import opentelemetry.context as context_api
 11from opentelemetry import trace
 12from opentelemetry.context import Context
 13from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, sampling
 14from opentelemetry.semconv.attributes import url_attributes
 15from opentelemetry.trace import (
 16    Link,
 17    SpanKind,
 18    TraceState,
 19    format_span_id,
 20    format_trace_id,
 21)
 22from opentelemetry.util.types import Attributes
 23
 24from plain.logs import app_logger
 25from plain.models.otel import suppress_db_tracing
 26from plain.runtime import settings
 27
 28from .core import Observer, ObserverMode
 29
 30if TYPE_CHECKING:
 31    from plain.observer.models import Span as ObserverSpanModel
 32    from plain.observer.models import Trace as TraceModel
 33
 34    from .logging import ObserverLogEntry
 35
 36logger = logging.getLogger(__name__)
 37
 38
 39def get_observer_span_processor() -> ObserverSpanProcessor | None:
 40    """Get the span collector instance from the tracer provider."""
 41    if not (current_provider := trace.get_tracer_provider()):
 42        return None
 43
 44    # Look for ObserverSpanProcessor in the span processors
 45    # Check if the provider has a _active_span_processor attribute
 46    if hasattr(current_provider, "_active_span_processor"):
 47        # It's a composite processor, check its _span_processors
 48        if composite_processor := current_provider._active_span_processor:
 49            if hasattr(composite_processor, "_span_processors"):
 50                processors = cast(
 51                    Sequence[SpanProcessor],
 52                    getattr(composite_processor, "_span_processors", ()),
 53                )
 54                for processor in processors:
 55                    if isinstance(processor, ObserverSpanProcessor):
 56                        return processor
 57
 58    return None
 59
 60
 61def get_current_trace_summary() -> str | None:
 62    """Get performance summary for the currently active trace."""
 63    if not (current_span := trace.get_current_span()):
 64        return None
 65
 66    if not (processor := get_observer_span_processor()):
 67        return None
 68
 69    trace_id = f"0x{format_trace_id(current_span.get_span_context().trace_id)}"
 70
 71    # Technically we're still in the trace... so the duration and stuff could shift slightly
 72    # (though we should be at the end of the template, hopefully)
 73    return processor.get_trace_summary(trace_id)
 74
 75
 76class ObserverSampler(sampling.Sampler):
 77    """Samples traces based on request path and cookies."""
 78
 79    def __init__(self) -> None:
 80        # Custom parent-based sampler
 81        self._delegate = sampling.ParentBased(sampling.ALWAYS_OFF)
 82
 83        # TODO ignore url namespace instead? admin, observer, assets
 84        self._ignore_url_paths: list[re.Pattern[str]] = [
 85            re.compile(p) for p in settings.OBSERVER_IGNORE_URL_PATTERNS
 86        ]
 87
 88    def should_sample(
 89        self,
 90        parent_context: Context | None,
 91        trace_id: int,
 92        name: str,
 93        kind: SpanKind | None = None,
 94        attributes: Attributes = None,
 95        links: Sequence[Link] | None = None,
 96        trace_state: TraceState | None = None,
 97    ) -> sampling.SamplingResult:
 98        # First, drop if the URL should be ignored.
 99        if attributes:
100            url_path = attributes.get(url_attributes.URL_PATH, "")
101            if isinstance(url_path, str) and url_path:
102                for pattern in self._ignore_url_paths:
103                    if pattern.match(url_path):
104                        return sampling.SamplingResult(
105                            sampling.Decision.DROP,
106                            attributes=attributes,
107                        )
108
109        # If no processor decision, check headers and cookies directly for root spans
110        decision: sampling.Decision | None = None
111        if parent_context:
112            # Check Observer header (DEBUG only) and cookies
113            mode = Observer.from_otel_context(parent_context).mode()
114
115            # Set decision based on mode
116            if mode in (ObserverMode.PERSIST.value, ObserverMode.SUMMARY.value):
117                # Always use RECORD_AND_SAMPLE so ParentBased works correctly
118                # The processor will check the mode to decide whether to export
119                decision = sampling.Decision.RECORD_AND_SAMPLE
120            elif mode == ObserverMode.DISABLED.value:
121                # Explicitly disabled - never sample even with remote parent
122                decision = sampling.Decision.DROP
123
124        # If there are links, assume it is to another trace/span that we are keeping
125        if links:
126            decision = sampling.Decision.RECORD_AND_SAMPLE
127
128        # If no decision from cookies, use default
129        if decision is None:
130            result = self._delegate.should_sample(
131                parent_context,
132                trace_id,
133                name,
134                kind=kind,
135                attributes=attributes,
136                links=links,
137                trace_state=trace_state,
138            )
139            decision = result.decision
140
141        return sampling.SamplingResult(
142            decision,
143            attributes=attributes,
144        )
145
146    def get_description(self) -> str:
147        return "ObserverSampler"
148
149
class ObserverCombinedSampler(sampling.Sampler):
    """Chain two samplers, consulting ``secondary`` only when ``primary`` drops.

    Intended for pairing an application's own sampler (``primary``) with an
    ``ObserverSampler`` (``secondary``). Spans the primary would discard can
    then still be sampled by Observer. Any non-DROP decision from the primary
    is returned unchanged.
    """

    def __init__(self, primary: sampling.Sampler, secondary: sampling.Sampler) -> None:
        self.primary = primary
        self.secondary = secondary

    def should_sample(
        self,
        parent_context: Context | None,
        trace_id: int,
        name: str,
        kind: SpanKind | None = None,
        attributes: Attributes = None,
        links: Sequence[Link] | None = None,
        trace_state: TraceState | None = None,
    ) -> sampling.SamplingResult:
        """Return the primary's result unless it is DROP, else the secondary's."""
        shared_kwargs: dict[str, Any] = {
            "kind": kind,
            "attributes": attributes,
            "links": links,
            "trace_state": trace_state,
        }

        first_opinion = self.primary.should_sample(
            parent_context, trace_id, name, **shared_kwargs
        )
        if first_opinion.decision is not sampling.Decision.DROP:
            return first_opinion

        return self.secondary.should_sample(
            parent_context, trace_id, name, **shared_kwargs
        )

    def get_description(self) -> str:
        primary_desc = self.primary.get_description()
        secondary_desc = self.secondary.get_description()
        return f"ObserverCombinedSampler({primary_desc}, {secondary_desc})"
192
193
class ObserverSpanProcessor(SpanProcessor):
    """Collects spans in real-time for current trace performance monitoring.

    This processor keeps spans in memory for traces whose Observer mode is
    'summary' or 'persist' (set via cookie, or header in DEBUG). These spans
    can be accessed via get_current_trace_summary() for real-time debugging.
    Traces in 'persist' mode are also saved to the database when their local
    root span ends.

    A span counts as the local root when it has no parent, or when its parent
    is remote (``SpanContext.is_remote``, i.e. propagated from another
    process). All in-memory state is guarded by ``_traces_lock``. Database
    export runs after the lock has been released.
    """

    def __init__(self) -> None:
        # Span storage, keyed by "0x"-prefixed hex trace id.
        self._traces: defaultdict[str, dict[str, Any]] = defaultdict(
            lambda: {
                "trace": None,  # Trace model instance (cached by get_trace_summary)
                "active_otel_spans": {},  # span_id -> opentelemetry span
                "completed_otel_spans": [],  # list of opentelemetry spans
                "span_models": [],  # list of Span model instances
                "root_span_id": None,
                "mode": None,  # None, ObserverMode.SUMMARY.value, or ObserverMode.PERSIST.value
            }
        )
        self._traces_lock = threading.Lock()
        self._ignore_url_paths: list[re.Pattern[str]] = [
            re.compile(p) for p in settings.OBSERVER_IGNORE_URL_PATTERNS
        ]

    def on_start(self, span: Any, parent_context: Context | None = None) -> None:
        """Start tracking *span* if its trace is (or becomes) recorded.

        The first span seen for a trace decides, via ``_get_recording_mode``,
        whether the trace is recorded. Traces that are not recorded never get
        an entry in ``self._traces``.
        """
        span_context = span.get_span_context()
        trace_id = f"0x{format_trace_id(span_context.trace_id)}"
        span_id = f"0x{format_span_id(span_context.span_id)}"

        with self._traces_lock:
            trace_info = self._traces.get(trace_id)
            if trace_info is None:
                # NOTE: _get_recording_mode may query the database (for linked
                # spans) while the lock is held.
                mode = self._get_recording_mode(span, parent_context)
                if not mode:
                    # Don't create a trace entry for traces we won't record.
                    return

                trace_info = self._traces[trace_id]
                trace_info["mode"] = mode

            # Enable DEBUG logging only for PERSIST mode (when logs are
            # captured); released per registered span in on_end.
            if trace_info["mode"] == ObserverMode.PERSIST.value:
                app_logger.debug_mode.start()

            trace_info["active_otel_spans"][span_id] = span

            # Track the local root. A span whose parent is remote is the local
            # root too; otherwise a trace continued from another service would
            # never be finalized and its entry would stay in memory forever.
            parent = span.parent
            if trace_info["root_span_id"] is None and (
                parent is None or parent.is_remote
            ):
                trace_info["root_span_id"] = span_id

    def on_end(self, span: ReadableSpan) -> None:
        """Mark *span* as completed and finalize the trace when its root ends."""
        span_context = span.get_span_context()
        trace_id = f"0x{format_trace_id(span_context.trace_id)}"
        span_id = f"0x{format_span_id(span_context.span_id)}"

        with self._traces_lock:
            trace_info = self._traces.get(trace_id)
            if trace_info is None:
                # Trace is not recorded (mode was None) or already finalized.
                return

            # Only spans registered in on_start move to completed, and only
            # they release the debug_mode.start() made for them.
            if trace_info["active_otel_spans"].pop(span_id, None) is not None:
                trace_info["completed_otel_spans"].append(span)
                if trace_info["mode"] == ObserverMode.PERSIST.value:
                    app_logger.debug_mode.end()

            if span_id != trace_info["root_span_id"]:
                return

            # Root span ended: detach the trace so that model building and the
            # database export below do not block other threads on the lock.
            # NOTE(review): spans still active at this point keep their
            # debug_mode.start() unmatched (same as before) - confirm whether
            # debug_mode needs balancing for such late spans.
            del self._traces[trace_id]

        self._finalize_trace(trace_id, trace_info)

    def _finalize_trace(self, trace_id: str, trace_info: dict[str, Any]) -> None:
        """Export a finished trace with its spans and logs (PERSIST mode only).

        Called without ``_traces_lock`` held, after ``trace_info`` has been
        removed from ``self._traces`` so no other thread can reach it.
        """
        if trace_info["mode"] != ObserverMode.PERSIST.value:
            # Summary-mode traces are only readable while in flight.
            return

        from .logging import observer_log_handler
        from .models import Span, Trace

        # Get and remove logs for this trace first, so they are released even
        # if the export is skipped below.
        if observer_log_handler:
            logs = observer_log_handler.pop_logs_for_trace(trace_id)
        else:
            logs = []

        all_spans = trace_info["completed_otel_spans"]
        trace_model = Trace.from_opentelemetry_spans(all_spans)
        if not trace_model:
            # get_trace_summary() treats a falsy model as "nothing to report";
            # skip the export rather than failing inside span.end().
            return

        span_models = [
            Span.from_opentelemetry_span(s, trace_model) for s in all_spans
        ]

        logger.debug(
            "Exporting %d spans and %d logs for trace %s",
            len(span_models),
            len(logs),
            trace_id,
        )
        # The trace is done now, so we can get a more accurate summary.
        trace_model.summary = trace_model.get_trace_summary(span_models)
        self._export_trace(
            trace=trace_model,
            spans=span_models,
            logs=logs,
        )

    def get_trace_summary(self, trace_id: str) -> str | None:
        """Get performance summary for a specific (still tracked) trace.

        Returns None when the trace is not tracked, has no spans yet, or no
        Trace model can be built from its spans.
        """
        from .models import Span, Trace

        with self._traces_lock:
            trace_info = self._traces.get(trace_id)
            if trace_info is None:
                return None

            # Combine active and completed spans.
            all_otel_spans = (
                list(trace_info["active_otel_spans"].values())
                + trace_info["completed_otel_spans"]
            )

            if not all_otel_spans:
                return None

            # Build the trace model once and reuse it on later calls.
            if not trace_info["trace"]:
                trace_info["trace"] = Trace.from_opentelemetry_spans(all_otel_spans)

            if not trace_info["trace"]:
                return None

            # Create span model instances if needed.
            span_models = trace_info.get("span_models", [])
            if not span_models:
                span_models = [
                    Span.from_opentelemetry_span(s, trace_info["trace"])
                    for s in all_otel_spans
                ]

            return trace_info["trace"].get_trace_summary(span_models)

    def _export_trace(
        self,
        *,
        trace: TraceModel,
        spans: Sequence[ObserverSpanModel],
        logs: Sequence[ObserverLogEntry],
    ) -> None:
        """Save trace, spans, and logs, then enforce OBSERVER_TRACE_LIMIT.

        Database errors are logged as warnings instead of being raised.
        """
        from .models import Log, Span, Trace

        with suppress_db_tracing():
            try:
                trace.save()

                for span in spans:
                    span.trace = trace

                Span.query.bulk_create(spans)

                if logs:
                    # Attach each log to its span model when that span belongs
                    # to this trace; otherwise the log keeps span=None.
                    span_id_to_model = {
                        span_model.span_id: span_model for span_model in spans
                    }
                    log_models = [
                        Log(
                            trace=trace,
                            timestamp=log_entry["timestamp"],
                            level=log_entry["level"],
                            message=log_entry["message"],
                            span=span_id_to_model.get(log_entry["span_id"]),
                        )
                        for log_entry in logs
                    ]
                    Log.query.bulk_create(log_models)

            except Exception as e:
                logger.warning(
                    "Failed to export trace to database: %s",
                    e,
                    exc_info=True,
                )

            # Delete oldest traces if we exceed the limit.
            trace_limit = settings.OBSERVER_TRACE_LIMIT
            if trace_limit > 0:
                try:
                    excess_count = Trace.query.count() - trace_limit
                    if excess_count > 0:
                        # Materialize the ids instead of passing a sliced
                        # queryset to id__in, which may become a LIMIT-in-IN
                        # subquery that some databases (e.g. MySQL) reject.
                        delete_ids = list(
                            Trace.query.order_by("start_time")[
                                :excess_count
                            ].values_list(  # type: ignore[union-attr]
                                "id", flat=True
                            )
                        )
                        Trace.query.filter(id__in=delete_ids).delete()
                except Exception as e:
                    logger.warning(
                        "Failed to clean up old observer traces: %s", e, exc_info=True
                    )

    def _get_recording_mode(
        self, span: Any, parent_context: Context | None
    ) -> str | None:
        """Return the recording mode for a trace's first span, or None.

        Returns None for ignored URL paths. Returns PERSIST when a linked span
        is already stored in the database. Otherwise returns the Observer mode
        from the parent (or current) context if it is summary/persist.
        """
        # Again check the span attributes, in case we relied on another sampler.
        attributes = span.attributes
        if attributes:
            url_path = attributes.get(url_attributes.URL_PATH, "")
            # Same str guard as ObserverSampler: re.match raises on non-str.
            if isinstance(url_path, str) and url_path:
                for pattern in self._ignore_url_paths:
                    if pattern.match(url_path):
                        return None

        # If the span has links, export it when a linked span was exported too.
        for link in span.links:
            if link.context.is_valid and link.context.span_id:
                from .models import Span

                with suppress_db_tracing():
                    if Span.query.filter(
                        span_id=f"0x{format_span_id(link.context.span_id)}"
                    ).exists():
                        return ObserverMode.PERSIST.value

        if not (context := parent_context or context_api.get_current()):
            return None

        # Check Observer header (DEBUG only) and cookies.
        mode = Observer.from_otel_context(context).mode()

        # Only return valid recording modes (summary/persist), not disabled.
        if mode in (ObserverMode.SUMMARY.value, ObserverMode.PERSIST.value):
            return mode

        return None

    def shutdown(self) -> None:
        """Cleanup when shutting down: drop all in-memory traces."""
        with self._traces_lock:
            self._traces.clear()

    def force_flush(self, timeout_millis: int | None = None) -> bool:
        """Required by SpanProcessor interface; nothing is buffered for export."""
        return True