v0.146.0
  1from __future__ import annotations
  2
  3import json
  4from collections import Counter
  5from collections.abc import Iterable, Mapping, Sequence
  6from datetime import UTC, datetime
  7from functools import cached_property
  8from typing import TYPE_CHECKING, Any, cast
  9
 10import sqlparse
 11from opentelemetry.sdk.trace import ReadableSpan
 12from opentelemetry.semconv._incubating.attributes import (
 13    exception_attributes,
 14    session_attributes,
 15    user_attributes,
 16)
 17from opentelemetry.semconv._incubating.attributes.code_attributes import (
 18    CODE_NAMESPACE,
 19)
 20from opentelemetry.semconv._incubating.attributes.db_attributes import (
 21    DB_QUERY_PARAMETER_TEMPLATE,
 22)
 23from opentelemetry.semconv._incubating.attributes.http_attributes import (
 24    HTTP_RESPONSE_BODY_SIZE,
 25)
 26from opentelemetry.semconv.attributes import db_attributes, service_attributes
 27from opentelemetry.semconv.attributes.code_attributes import (
 28    CODE_COLUMN_NUMBER,
 29    CODE_FILE_PATH,
 30    CODE_FUNCTION_NAME,
 31    CODE_LINE_NUMBER,
 32    CODE_STACKTRACE,
 33)
 34from opentelemetry.trace import format_trace_id
 35
 36from plain import postgres
 37from plain.postgres import types
 38from plain.runtime import settings
 39from plain.urls import reverse
 40
 41__all__ = ["Log", "Span", "Trace"]
 42
 43
 44from .formatting import format_bytes
 45
 46
 47@postgres.register_model
 48class Trace(postgres.Model):
 49    trace_id: str = types.TextField(max_length=255)
 50    start_time: datetime = types.DateTimeField()
 51    end_time: datetime = types.DateTimeField()
 52
 53    root_span_name: str = types.TextField(default="", required=False)
 54    summary: str = types.TextField(max_length=255, default="", required=False)
 55
 56    # Plain fields
 57    request_id: str = types.TextField(max_length=255, default="", required=False)
 58    session_id: str = types.TextField(max_length=255, default="", required=False)
 59    user_id: str = types.TextField(max_length=255, default="", required=False)
 60    app_name: str = types.TextField(max_length=255, default="", required=False)
 61    app_version: str = types.TextField(max_length=255, default="", required=False)
 62
 63    # Explicit reverse relations
 64    spans: types.ReverseForeignKey[Span] = types.ReverseForeignKey(
 65        to="Span", field="trace"
 66    )
 67    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(
 68        to="Log", field="trace"
 69    )
 70
 71    query: postgres.QuerySet[Trace] = postgres.QuerySet()
 72
 73    model_options = postgres.Options(
 74        ordering=["-start_time"],
 75        constraints=[
 76            postgres.UniqueConstraint(
 77                fields=["trace_id"],
 78                name="observer_unique_trace_id",
 79            )
 80        ],
 81        indexes=[
 82            postgres.Index(
 83                name="plainobserver_trace_start_time_idx", fields=["start_time"]
 84            ),
 85            postgres.Index(
 86                name="plainobserver_trace_request_id_idx", fields=["request_id"]
 87            ),
 88            postgres.Index(
 89                name="plainobserver_trace_session_id_idx", fields=["session_id"]
 90            ),
 91        ],
 92    )
 93
 94    def __str__(self) -> str:
 95        return self.trace_id
 96
 97    def get_absolute_url(self) -> str:
 98        """Return the canonical URL for this trace."""
 99        return reverse("observer:trace_detail", trace_id=self.trace_id)
100
101    def duration_ms(self) -> float:
102        return (self.end_time - self.start_time).total_seconds() * 1000
103
104    def get_trace_stats(self, spans: Iterable[Span]) -> dict[str, Any]:
105        """Extract structured performance stats from trace spans."""
106        query_texts: list[str] = []
107        response_body_size: int | None = None
108        for span in spans:
109            if query_text := span.attributes.get(db_attributes.DB_QUERY_TEXT):
110                query_texts.append(query_text)
111            if (body_size := span.attributes.get(HTTP_RESPONSE_BODY_SIZE)) is not None:
112                response_body_size = int(body_size)
113
114        query_counts = Counter(query_texts)
115        return {
116            "query_count": len(query_texts),
117            "duplicate_count": sum(query_counts.values()) - len(query_counts),
118            "duration_ms": self.duration_ms(),
119            "response_body_size": response_body_size,
120        }
121
122    def get_trace_summary(self, spans: Iterable[Span]) -> str:
123        """Format trace stats as a human-readable summary string."""
124        stats = self.get_trace_stats(spans)
125        parts: list[str] = []
126
127        query_total = stats["query_count"]
128        if query_total > 0:
129            query_part = f"{query_total} quer{'y' if query_total == 1 else 'ies'}"
130            duplicate_count = stats["duplicate_count"]
131            if duplicate_count > 0:
132                query_part += f" ({duplicate_count} duplicate{'' if duplicate_count == 1 else 's'})"
133            parts.append(query_part)
134
135        if stats["duration_ms"] is not None:
136            parts.append(f"{round(stats['duration_ms'], 1)}ms")
137
138        if stats["response_body_size"] is not None:
139            parts.append(format_bytes(stats["response_body_size"]))
140
141        return " • ".join(parts)
142
143    @classmethod
144    def from_opentelemetry_spans(cls, spans: Sequence[ReadableSpan]) -> Trace:
145        """Create a Trace instance from a list of OpenTelemetry spans."""
146        # Get trace information from the first span
147        first_span = spans[0]
148        trace_id = f"0x{format_trace_id(first_span.get_span_context().trace_id)}"
149
150        # Find trace boundaries and root span info
151        earliest_start = None
152        latest_end = None
153        root_span = None
154        request_id = ""
155        user_id = ""
156        session_id = ""
157        app_name = ""
158        app_version = ""
159
160        for span in spans:
161            if not span.parent:
162                root_span = span
163
164            if span.start_time and (
165                earliest_start is None or span.start_time < earliest_start
166            ):
167                earliest_start = span.start_time
168            # Only update latest_end if the span has actually ended
169            if span.end_time and (latest_end is None or span.end_time > latest_end):
170                latest_end = span.end_time
171
172            # For OpenTelemetry spans, access attributes directly
173            span_attrs = getattr(span, "attributes", {})
174            request_id = request_id or span_attrs.get("plain.request.id", "")
175            user_id = user_id or span_attrs.get(user_attributes.USER_ID, "")
176            session_id = session_id or span_attrs.get(session_attributes.SESSION_ID, "")
177
178            # Access Resource attributes if not found in span attributes
179            if resource := getattr(span, "resource", None):
180                app_name = app_name or resource.attributes.get(
181                    service_attributes.SERVICE_NAME, ""
182                )
183                app_version = app_version or resource.attributes.get(
184                    service_attributes.SERVICE_VERSION, ""
185                )
186
187        # Convert timestamps
188        start_time = (
189            datetime.fromtimestamp(earliest_start / 1_000_000_000, tz=UTC)
190            if earliest_start
191            else None
192        )
193        end_time = (
194            datetime.fromtimestamp(latest_end / 1_000_000_000, tz=UTC)
195            if latest_end
196            else None
197        )
198
199        return cls(
200            trace_id=trace_id,
201            start_time=start_time,
202            end_time=end_time
203            or start_time,  # Use start_time as fallback for active traces
204            request_id=request_id,
205            user_id=user_id,
206            session_id=session_id,
207            app_name=app_name or settings.NAME,
208            app_version=app_version or settings.VERSION,
209            root_span_name=root_span.name if root_span else "",
210        )
211
212    def as_dict(self) -> dict[str, Any]:
213        spans = [
214            span.span_data for span in self.spans.query.all().order_by("start_time")
215        ]
216        logs = [
217            {
218                "timestamp": log.timestamp.isoformat(),
219                "level": log.level,
220                "message": log.message,
221                "span_id": log.span_id,
222            }
223            for log in self.logs.query.all().order_by("timestamp")
224        ]
225
226        return {
227            "trace_id": self.trace_id,
228            "start_time": self.start_time.isoformat(),
229            "end_time": self.end_time.isoformat(),
230            "duration_ms": self.duration_ms(),
231            "summary": self.summary,
232            "root_span_name": self.root_span_name,
233            "request_id": self.request_id,
234            "user_id": self.user_id,
235            "session_id": self.session_id,
236            "app_name": self.app_name,
237            "app_version": self.app_version,
238            "spans": spans,
239            "logs": logs,
240        }
241
242    def get_timeline_events(self) -> list[dict[str, Any]]:
243        """Get chronological list of spans and logs for unified timeline display."""
244        events: list[dict[str, Any]] = []
245
246        for span in self.spans.query.all().annotate_spans():  # ty: ignore[unresolved-attribute]
247            events.append(
248                {
249                    "type": "span",
250                    "timestamp": span.start_time,
251                    "instance": span,
252                    "span_level": span.level,
253                }
254            )
255
256            # Add logs for this span
257            for log in self.logs.query.filter(span=span):
258                events.append(
259                    {
260                        "type": "log",
261                        "timestamp": log.timestamp,
262                        "instance": log,
263                        "span_level": span.level + 1,
264                    }
265                )
266
267        # Add unlinked logs (logs without span)
268        for log in self.logs.query.filter(span__isnull=True):
269            events.append(
270                {
271                    "type": "log",
272                    "timestamp": log.timestamp,
273                    "instance": log,
274                    "span_level": 0,
275                }
276            )
277
278        # Sort by timestamp
279        return sorted(events, key=lambda x: x["timestamp"])
280
281
class SpanQuerySet(postgres.QuerySet["Span"]):
    """QuerySet for ``Span`` with a helper that prepares spans for display."""

    def annotate_spans(self) -> list[Span]:
        """Annotate spans with nesting levels and duplicate query warnings.

        Returns the spans ordered by ``start_time``. Each span gets:

        - ``level``: 0 for a span without a ``parent_id``; otherwise the
          parent's level plus one. A span whose parent is not in this result
          set gets level 1.
        - ``annotations``: a list of ``{"message", "severity"}`` dicts. A
          warning is added to every span whose SQL query text occurs more
          than once in the result set.
        """
        spans: list[Span] = list(self.order_by("start_time"))

        # Build span dictionary for parent lookups
        spans_by_id: dict[str, Span] = {span.span_id: span for span in spans}

        # Resolve levels by walking up each parent chain, so the result does
        # not depend on parents sorting before their children (they may not
        # when start times are equal or clocks are skewed). Keyed by object
        # identity; `spans` keeps every object alive for the duration.
        levels: dict[int, int] = {}
        for span in spans:
            if id(span) in levels:
                continue

            chain: list[Span] = []
            on_chain: set[int] = set()
            current: Span | None = span
            while (
                current is not None
                and id(current) not in levels
                and id(current) not in on_chain
            ):
                chain.append(current)
                on_chain.add(id(current))
                current = (
                    spans_by_id.get(current.parent_id) if current.parent_id else None
                )

            if current is not None and id(current) in levels:
                # Reached an ancestor whose level is already known.
                level = levels[id(current)] + 1
            elif chain[-1].parent_id:
                # Parent is missing from the result set (or the chain loops
                # back on itself): treat the unknown parent as level 0.
                level = 1
            else:
                # Top of the chain is a root span.
                level = 0

            # Assign from the top of the chain down to the starting span.
            for member in reversed(chain):
                levels[id(member)] = level
                member.level = level
                level += 1

        # First pass: count queries
        query_counts = Counter(span.sql_query for span in spans if span.sql_query)

        # Second pass: add annotations
        query_occurrences: Counter[str] = Counter()
        for span in spans:
            span.annotations = []

            # Check for duplicate queries
            sql_query = span.sql_query
            if not sql_query:
                continue

            count = query_counts[sql_query]
            if count > 1:
                query_occurrences[sql_query] += 1
                occurrence = query_occurrences[sql_query]
                span.annotations.append(
                    {
                        "message": f"Duplicate query ({occurrence} of {count})",
                        "severity": "warning",
                    }
                )

        return spans
327
328
@postgres.register_model
class Span(postgres.Model):
    """A single span belonging to a ``Trace``.

    A few values (name, kind, parent, times, status) are stored as columns;
    the full JSON form of the span is kept in ``span_data`` and exposed
    through the ``attributes``, ``events``, ``links``, ``resource`` and
    ``context`` properties.
    """

    trace: Trace = types.ForeignKeyField(Trace, on_delete=postgres.CASCADE)

    span_id: str = types.TextField(max_length=255)

    name: str = types.TextField(max_length=255)
    kind: str = types.TextField(max_length=50)
    parent_id: str = types.TextField(max_length=255, default="", required=False)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()
    status: str = types.TextField(max_length=50, default="", required=False)
    # Use a callable default so each instance gets its own dict instead of
    # sharing (and possibly mutating) one module-level `{}`.
    span_data: dict = types.JSONField(default=dict, required=False)

    # Explicit reverse relation
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(to="Log", field="span")

    query: SpanQuerySet = SpanQuerySet()

    model_options = postgres.Options(
        ordering=["-start_time"],
        constraints=[
            postgres.UniqueConstraint(
                fields=["trace", "span_id"],
                name="observer_unique_span_id",
            )
        ],
        indexes=[
            postgres.Index(name="plainobserver_span_span_id_idx", fields=["span_id"]),
            postgres.Index(
                name="plainobserver_span_start_time_idx", fields=["start_time"]
            ),
        ],
    )

    if TYPE_CHECKING:
        # Set at runtime by SpanQuerySet.annotate_spans(), not stored columns.
        level: int
        annotations: list[dict[str, Any]]

    @classmethod
    def from_opentelemetry_span(cls, otel_span: ReadableSpan, trace: Trace) -> Span:
        """Create a Span instance from an OpenTelemetry span.

        The span is serialised with ``to_json()`` and the parsed result is
        both stored whole in ``span_data`` and used to fill the individual
        fields. The instance is constructed but not saved.
        """

        span_data = json.loads(otel_span.to_json())

        # Extract status code as string, default to empty string if unset
        status = ""
        if span_data.get("status") and span_data["status"].get("status_code"):
            status = span_data["status"]["status_code"]

        return cls(
            trace=trace,
            span_id=span_data["context"]["span_id"],
            name=span_data["name"],
            # Stored without the enum prefix, e.g. "SpanKind.X" -> "X".
            kind=span_data["kind"].removeprefix("SpanKind."),
            parent_id=span_data["parent_id"] or "",
            start_time=span_data["start_time"],
            end_time=span_data["end_time"],
            status=status,
            span_data=span_data,
        )

    def __str__(self) -> str:
        return self.span_id

    @property
    def attributes(self) -> Mapping[str, Any]:
        """Get attributes from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("attributes", {}))

    @property
    def events(self) -> list[Mapping[str, Any]]:
        """Get events from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("events", []))

    @property
    def links(self) -> list[Mapping[str, Any]]:
        """Get links from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("links", []))

    @property
    def resource(self) -> Mapping[str, Any]:
        """Get resource from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("resource", {}))

    @property
    def context(self) -> Mapping[str, Any]:
        """Get context from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("context", {}))

    def duration_ms(self) -> float:
        """Return the span duration in milliseconds, or 0.0 if a time is missing."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds() * 1000
        return 0.0

    @cached_property
    def sql_query(self) -> str | None:
        """Get the SQL query if this span contains one."""
        return self.attributes.get(db_attributes.DB_QUERY_TEXT)

    @cached_property
    def sql_query_params(self) -> dict[str, Any]:
        """Get query parameters from attributes under the query-parameter prefix.

        Returns a dict mapping each parameter name (the attribute key with
        the ``DB_QUERY_PARAMETER_TEMPLATE + "."`` prefix removed) to its value.
        """
        prefix = DB_QUERY_PARAMETER_TEMPLATE + "."
        # removeprefix strips only the leading prefix; str.replace would also
        # eat any later occurrence inside the parameter name.
        return {
            key.removeprefix(prefix): value
            for key, value in self.attributes.items()
            if key.startswith(prefix)
        }

    @cached_property
    def source_code_location(self) -> dict[str, Any] | None:
        """Get the source code location attributes from this span.

        Returns a dict keyed by display name ("File", "Line", ...) holding
        whichever code attributes are present, or ``None`` if there are none.
        """
        attributes = self.attributes
        if not attributes:
            return None

        # Look for common semantic convention code attributes
        code_attribute_mappings = {
            CODE_FILE_PATH: "File",
            CODE_LINE_NUMBER: "Line",
            CODE_FUNCTION_NAME: "Function",
            CODE_NAMESPACE: "Namespace",
            CODE_COLUMN_NUMBER: "Column",
            CODE_STACKTRACE: "Stacktrace",
        }

        code_attrs = {
            display_name: attributes[attr_key]
            for attr_key, display_name in code_attribute_mappings.items()
            if attr_key in attributes
        }

        return code_attrs or None

    def get_formatted_sql(self) -> str | None:
        """Get the pretty-formatted SQL query if this span contains one."""
        sql = self.sql_query
        if not sql:
            return None

        return sqlparse.format(
            sql,
            reindent=True,
            keyword_case="upper",
            identifier_case="lower",
            strip_comments=False,
            strip_whitespace=True,
            indent_width=2,
            wrap_after=80,
            comma_first=False,
        )

    def format_event_timestamp(
        self, timestamp: float | int | datetime | str
    ) -> datetime | str:
        """Convert event timestamp to a readable datetime.

        Numeric timestamps are interpreted as time since the Unix epoch. The
        unit is guessed from the magnitude: values of at least 1e17 are
        treated as nanoseconds, at least 1e14 as microseconds, at least 1e11
        as milliseconds, and anything smaller as seconds. A ``datetime`` is
        returned unchanged. Any other value, or a number that cannot be
        converted, is returned as ``str(timestamp)``.
        """
        if isinstance(timestamp, int | float):
            try:
                ts_value = float(timestamp)
                magnitude = abs(ts_value)
                # Present-day epoch values are roughly 1.7e9 s, 1.7e12 ms,
                # 1.7e15 us and 1.7e18 ns, so each threshold sits one decade
                # below the corresponding unit.
                if magnitude >= 1e17:  # Likely nanoseconds
                    ts_value /= 1e9
                elif magnitude >= 1e14:  # Likely microseconds
                    ts_value /= 1e6
                elif magnitude >= 1e11:  # Likely milliseconds
                    ts_value /= 1e3

                return datetime.fromtimestamp(ts_value, tz=UTC)
            except (ValueError, OverflowError, OSError):
                # Out-of-range or non-finite value: show the input as given.
                return str(timestamp)
        if isinstance(timestamp, datetime):
            return timestamp
        return str(timestamp)

    def get_exception_stacktrace(self) -> str | None:
        """Get the exception stacktrace if this span has an exception event.

        Looks at the first event named "exception" that has attributes and
        returns its stacktrace attribute (which may itself be ``None``).
        """
        if not self.events:
            return None

        for event in self.events:
            if event.get("name") == "exception" and event.get("attributes"):
                return event["attributes"].get(
                    exception_attributes.EXCEPTION_STACKTRACE
                )
        return None
515
516
@postgres.register_model
class Log(postgres.Model):
    """A log record attached to a ``Trace`` and, optionally, to a ``Span``."""

    # Owning trace; logs are removed together with it (CASCADE).
    trace: Trace = types.ForeignKeyField(Trace, on_delete=postgres.CASCADE)
    # Annotation only — presumably the raw key behind `trace`; confirm.
    trace_id: int

    # Optional span link; set to NULL if the span is deleted (SET_NULL).
    span: Span | None = types.ForeignKeyField(
        Span,
        on_delete=postgres.SET_NULL,
        allow_null=True,
        required=False,
    )
    # Annotation only — presumably the raw key behind `span`; confirm.
    span_id: int | None

    timestamp: datetime = types.DateTimeField()
    level: str = types.TextField(max_length=20)
    message: str = types.TextField()

    query: postgres.QuerySet[Log] = postgres.QuerySet()

    # Logs are returned oldest first by default.
    model_options = postgres.Options(
        ordering=["timestamp"],
        indexes=[
            postgres.Index(
                name="plainobserver_log_trace_id_timestamp_idx",
                fields=["trace", "timestamp"],
            ),
            postgres.Index(
                name="plainobserver_log_trace_id_span_id_idx",
                fields=["trace", "span"],
            ),
            postgres.Index(
                name="plainobserver_log_timestamp_idx",
                fields=["timestamp"],
            ),
            postgres.Index(
                name="plainobserver_log_span_id_idx",
                fields=["span"],
            ),
        ],
    )
550    )