1from __future__ import annotations
  2
  3import json
  4from collections import Counter
  5from collections.abc import Iterable, Mapping, Sequence
  6from datetime import UTC, datetime
  7from functools import cached_property
  8from typing import TYPE_CHECKING, Any, cast
  9
 10import sqlparse
 11from opentelemetry.sdk.trace import ReadableSpan
 12from opentelemetry.semconv._incubating.attributes import (
 13    exception_attributes,
 14    session_attributes,
 15    user_attributes,
 16)
 17from opentelemetry.semconv._incubating.attributes.code_attributes import (
 18    CODE_NAMESPACE,
 19)
 20from opentelemetry.semconv._incubating.attributes.db_attributes import (
 21    DB_QUERY_PARAMETER_TEMPLATE,
 22)
 23from opentelemetry.semconv.attributes import db_attributes, service_attributes
 24from opentelemetry.semconv.attributes.code_attributes import (
 25    CODE_COLUMN_NUMBER,
 26    CODE_FILE_PATH,
 27    CODE_FUNCTION_NAME,
 28    CODE_LINE_NUMBER,
 29    CODE_STACKTRACE,
 30)
 31from opentelemetry.trace import format_trace_id
 32
 33from plain import models
 34from plain.models import types
 35from plain.runtime import settings
 36from plain.urls import reverse
 37
 38
 39@models.register_model
 40class Trace(models.Model):
 41    trace_id: str = types.CharField(max_length=255)
 42    start_time: datetime = types.DateTimeField()
 43    end_time: datetime = types.DateTimeField()
 44
 45    root_span_name: str = types.TextField(default="", required=False)
 46    summary: str = types.CharField(max_length=255, default="", required=False)
 47
 48    # Plain fields
 49    request_id: str = types.CharField(max_length=255, default="", required=False)
 50    session_id: str = types.CharField(max_length=255, default="", required=False)
 51    user_id: str = types.CharField(max_length=255, default="", required=False)
 52    app_name: str = types.CharField(max_length=255, default="", required=False)
 53    app_version: str = types.CharField(max_length=255, default="", required=False)
 54
 55    # Explicit reverse relations
 56    spans: types.ReverseForeignKey[Span] = types.ReverseForeignKey(
 57        to="Span", field="trace"
 58    )
 59    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(
 60        to="Log", field="trace"
 61    )
 62
 63    query: models.QuerySet[Trace] = models.QuerySet()
 64
 65    model_options = models.Options(
 66        ordering=["-start_time"],
 67        constraints=[
 68            models.UniqueConstraint(
 69                fields=["trace_id"],
 70                name="observer_unique_trace_id",
 71            )
 72        ],
 73        indexes=[
 74            models.Index(fields=["trace_id"]),
 75            models.Index(fields=["start_time"]),
 76            models.Index(fields=["request_id"]),
 77            models.Index(fields=["session_id"]),
 78        ],
 79    )
 80
 81    def __str__(self) -> str:
 82        return self.trace_id
 83
 84    def get_absolute_url(self) -> str:
 85        """Return the canonical URL for this trace."""
 86        return reverse("observer:trace_detail", trace_id=self.trace_id)
 87
 88    def duration_ms(self) -> float:
 89        return (self.end_time - self.start_time).total_seconds() * 1000
 90
 91    def get_trace_summary(self, spans: Iterable[Span]) -> str:
 92        # Count database queries with query text and track duplicates
 93        query_texts: list[str] = []
 94        for span in spans:
 95            if query_text := span.attributes.get(db_attributes.DB_QUERY_TEXT):
 96                query_texts.append(query_text)
 97
 98        query_counts = Counter(query_texts)
 99        query_total = len(query_texts)
100        duplicate_count = sum(query_counts.values()) - len(query_counts)
101
102        # Build summary: "n spans, n queries (n duplicates), Xms"
103        parts: list[str] = []
104
105        # Queries count with duplicates
106        if query_total > 0:
107            query_part = f"{query_total} quer{'y' if query_total == 1 else 'ies'}"
108            if duplicate_count > 0:
109                query_part += f" ({duplicate_count} duplicate{'' if duplicate_count == 1 else 's'})"
110            parts.append(query_part)
111
112        # Duration
113        if (duration_ms := self.duration_ms()) is not None:
114            parts.append(f"{round(duration_ms, 1)}ms")
115
116        return " โ€ข ".join(parts)
117
118    @classmethod
119    def from_opentelemetry_spans(cls, spans: Sequence[ReadableSpan]) -> Trace:
120        """Create a Trace instance from a list of OpenTelemetry spans."""
121        # Get trace information from the first span
122        first_span = spans[0]
123        trace_id = f"0x{format_trace_id(first_span.get_span_context().trace_id)}"
124
125        # Find trace boundaries and root span info
126        earliest_start = None
127        latest_end = None
128        root_span = None
129        request_id = ""
130        user_id = ""
131        session_id = ""
132        app_name = ""
133        app_version = ""
134
135        for span in spans:
136            if not span.parent:
137                root_span = span
138
139            if span.start_time and (
140                earliest_start is None or span.start_time < earliest_start
141            ):
142                earliest_start = span.start_time
143            # Only update latest_end if the span has actually ended
144            if span.end_time and (latest_end is None or span.end_time > latest_end):
145                latest_end = span.end_time
146
147            # For OpenTelemetry spans, access attributes directly
148            span_attrs = getattr(span, "attributes", {})
149            request_id = request_id or span_attrs.get("plain.request.id", "")
150            user_id = user_id or span_attrs.get(user_attributes.USER_ID, "")
151            session_id = session_id or span_attrs.get(session_attributes.SESSION_ID, "")
152
153            # Access Resource attributes if not found in span attributes
154            if resource := getattr(span, "resource", None):
155                app_name = app_name or resource.attributes.get(
156                    service_attributes.SERVICE_NAME, ""
157                )
158                app_version = app_version or resource.attributes.get(
159                    service_attributes.SERVICE_VERSION, ""
160                )
161
162        # Convert timestamps
163        start_time = (
164            datetime.fromtimestamp(earliest_start / 1_000_000_000, tz=UTC)
165            if earliest_start
166            else None
167        )
168        end_time = (
169            datetime.fromtimestamp(latest_end / 1_000_000_000, tz=UTC)
170            if latest_end
171            else None
172        )
173
174        return cls(
175            trace_id=trace_id,
176            start_time=start_time,
177            end_time=end_time
178            or start_time,  # Use start_time as fallback for active traces
179            request_id=request_id,
180            user_id=user_id,
181            session_id=session_id,
182            app_name=app_name or settings.NAME,
183            app_version=app_version or settings.VERSION,
184            root_span_name=root_span.name if root_span else "",
185        )
186
187    def as_dict(self) -> dict[str, Any]:
188        spans = [
189            span.span_data for span in self.spans.query.all().order_by("start_time")
190        ]
191        logs = [
192            {
193                "timestamp": log.timestamp.isoformat(),
194                "level": log.level,
195                "message": log.message,
196                "span_id": log.span_id,
197            }
198            for log in self.logs.query.all().order_by("timestamp")
199        ]
200
201        return {
202            "trace_id": self.trace_id,
203            "start_time": self.start_time.isoformat(),
204            "end_time": self.end_time.isoformat(),
205            "duration_ms": self.duration_ms(),
206            "summary": self.summary,
207            "root_span_name": self.root_span_name,
208            "request_id": self.request_id,
209            "user_id": self.user_id,
210            "session_id": self.session_id,
211            "app_name": self.app_name,
212            "app_version": self.app_version,
213            "spans": spans,
214            "logs": logs,
215        }
216
217    def get_timeline_events(self) -> list[dict[str, Any]]:
218        """Get chronological list of spans and logs for unified timeline display."""
219        events: list[dict[str, Any]] = []
220
221        for span in self.spans.query.all().annotate_spans():  # type: ignore[attr-defined]
222            events.append(
223                {
224                    "type": "span",
225                    "timestamp": span.start_time,
226                    "instance": span,
227                    "span_level": span.level,
228                }
229            )
230
231            # Add logs for this span
232            for log in self.logs.query.filter(span=span):
233                events.append(
234                    {
235                        "type": "log",
236                        "timestamp": log.timestamp,
237                        "instance": log,
238                        "span_level": span.level + 1,
239                    }
240                )
241
242        # Add unlinked logs (logs without span)
243        for log in self.logs.query.filter(span__isnull=True):
244            events.append(
245                {
246                    "type": "log",
247                    "timestamp": log.timestamp,
248                    "instance": log,
249                    "span_level": 0,
250                }
251            )
252
253        # Sort by timestamp
254        return sorted(events, key=lambda x: x["timestamp"])
255
256
class SpanQuerySet(models.QuerySet["Span"]):
    """QuerySet for ``Span`` with a helper that prepares spans for display."""

    def annotate_spans(self) -> list[Span]:
        """Annotate spans with nesting levels and duplicate query warnings.

        Evaluates the queryset ordered by ``start_time`` and returns a plain
        list. Every returned span gets two extra attributes:

        - ``level``: nesting depth. A span without ``parent_id`` is level 0; a
          span whose parent is not part of this queryset is level 1.
        - ``annotations``: a list of ``{"message", "severity"}`` dicts. A
          "Duplicate query (i of n)" warning is added to each span whose SQL
          text occurs more than once.
        """
        spans: list[Span] = list(self.order_by("start_time"))

        # Build span dictionary for parent lookups
        span_dict: dict[str, Span] = {span.span_id: span for span in spans}

        # Resolve nesting levels by walking up the parent chain. This must not
        # rely on parents sorting before their children: with equal start
        # times or clock skew a child can come first in the list.
        levels: dict[str, int] = {}
        for span in spans:
            pending: list[Span] = []  # spans whose level depends on ``node``
            visiting: set[str] = set()  # guards against parent cycles
            node = span
            while True:
                node_id = node.span_id
                if node_id in levels:
                    depth = levels[node_id]
                    break
                if not node.parent_id:
                    depth = levels[node_id] = 0
                    break
                visiting.add(node_id)
                parent = span_dict.get(node.parent_id)
                if parent is None or parent.span_id in visiting:
                    # Unknown parent (or a cycle): treat the parent as level 0.
                    depth = levels[node_id] = 1
                    break
                pending.append(node)
                node = parent

            # Unwind: each pending span sits one level below the one after it.
            for child in reversed(pending):
                depth += 1
                levels[child.span_id] = depth

            span.level = levels[span.span_id]

        # First pass: count how often each SQL text occurs
        query_counts = Counter(
            sql_query for span in spans if (sql_query := span.sql_query)
        )

        # Second pass: add annotations
        query_occurrences: dict[str, int] = {}
        for span in spans:
            span.annotations = []

            sql_query = span.sql_query
            if not sql_query:
                continue

            count = query_counts[sql_query]
            if count > 1:
                occurrence = query_occurrences.get(sql_query, 0) + 1
                query_occurrences[sql_query] = occurrence

                span.annotations.append(
                    {
                        "message": f"Duplicate query ({occurrence} of {count})",
                        "severity": "warning",
                    }
                )

        return spans
302
303
@models.register_model
class Span(models.Model):
    """A single span of a trace, stored with its full JSON export.

    The most-used values (IDs, name, kind, times, status) are stored in their
    own columns; everything else is read from ``span_data``.
    """

    trace: Trace = types.ForeignKeyField(Trace, on_delete=models.CASCADE)

    span_id: str = types.CharField(max_length=255)

    name: str = types.CharField(max_length=255)
    kind: str = types.CharField(max_length=50)
    parent_id: str = types.CharField(max_length=255, default="", required=False)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()
    status: str = types.CharField(max_length=50, default="", required=False)
    span_data: dict = types.JSONField(default=dict, required=False)

    # Explicit reverse relation
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(to="Log", field="span")

    query: SpanQuerySet = SpanQuerySet()

    model_options = models.Options(
        ordering=["-start_time"],
        constraints=[
            models.UniqueConstraint(
                fields=["trace", "span_id"],
                name="observer_unique_span_id",
            )
        ],
        indexes=[
            models.Index(fields=["span_id"]),
            models.Index(fields=["trace", "span_id"]),
            models.Index(fields=["trace"]),
            models.Index(fields=["start_time"]),
        ],
    )

    # Set at runtime by SpanQuerySet.annotate_spans(); declared for type checkers.
    if TYPE_CHECKING:
        level: int
        annotations: list[dict[str, Any]]

    @classmethod
    def from_opentelemetry_span(cls, otel_span: ReadableSpan, trace: Trace) -> Span:
        """Create an unsaved Span instance from an OpenTelemetry span.

        The span is serialized with ``to_json()`` and the parsed result is
        kept whole in ``span_data``.
        """
        span_data = json.loads(otel_span.to_json())

        # Extract status code as string, default to empty string if unset
        status = (span_data.get("status") or {}).get("status_code") or ""

        return cls(
            trace=trace,
            span_id=span_data["context"]["span_id"],
            name=span_data["name"],
            # The JSON kind looks like "SpanKind.SERVER"; keep only the member name.
            kind=span_data["kind"].removeprefix("SpanKind."),
            parent_id=span_data["parent_id"] or "",
            start_time=span_data["start_time"],
            end_time=span_data["end_time"],
            status=status,
            span_data=span_data,
        )

    def __str__(self) -> str:
        return self.span_id

    # The accessors below fall back to an empty container both when the key is
    # absent and when it holds an explicit JSON null.

    @property
    def attributes(self) -> Mapping[str, Any]:
        """Get attributes from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("attributes") or {})

    @property
    def events(self) -> list[Mapping[str, Any]]:
        """Get events from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("events") or [])

    @property
    def links(self) -> list[Mapping[str, Any]]:
        """Get links from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("links") or [])

    @property
    def resource(self) -> Mapping[str, Any]:
        """Get resource from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("resource") or {})

    @property
    def context(self) -> Mapping[str, Any]:
        """Get context from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("context") or {})

    def duration_ms(self) -> float:
        """Return the span duration in milliseconds, or 0.0 if a time is missing."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds() * 1000
        return 0.0

    @cached_property
    def sql_query(self) -> str | None:
        """Get the SQL query if this span contains one."""
        return self.attributes.get(db_attributes.DB_QUERY_TEXT)

    @cached_property
    def sql_query_params(self) -> dict[str, Any]:
        """Get query parameters from attributes that start with 'db.query.parameter.'

        The returned keys are the attribute names with that prefix removed.
        """
        prefix = DB_QUERY_PARAMETER_TEMPLATE + "."
        # removeprefix strips only the leading prefix, so a parameter name that
        # happens to contain the same text again is left intact.
        return {
            key.removeprefix(prefix): value
            for key, value in self.attributes.items()
            if key.startswith(prefix)
        }

    @cached_property
    def source_code_location(self) -> dict[str, Any] | None:
        """Get the source code location attributes from this span.

        Returns a dict keyed by display name ("File", "Line", ...) holding the
        code attributes that are present, or None when there are none.
        """
        attributes = self.attributes

        # Semantic convention code attributes mapped to display names
        code_attribute_mappings = {
            CODE_FILE_PATH: "File",
            CODE_LINE_NUMBER: "Line",
            CODE_FUNCTION_NAME: "Function",
            CODE_NAMESPACE: "Namespace",
            CODE_COLUMN_NUMBER: "Column",
            CODE_STACKTRACE: "Stacktrace",
        }

        code_attrs = {
            display_name: attributes[attr_key]
            for attr_key, display_name in code_attribute_mappings.items()
            if attr_key in attributes
        }

        return code_attrs or None

    def get_formatted_sql(self) -> str | None:
        """Get the pretty-formatted SQL query if this span contains one."""
        sql = self.sql_query
        if not sql:
            return None

        return sqlparse.format(
            sql,
            reindent=True,
            keyword_case="upper",
            identifier_case="lower",
            strip_comments=False,
            strip_whitespace=True,
            indent_width=2,
            wrap_after=80,
            comma_first=False,
        )

    def format_event_timestamp(
        self, timestamp: float | int | datetime | str
    ) -> datetime | str:
        """Convert event timestamp to a readable datetime.

        Numeric values are taken as a Unix epoch timestamp whose unit
        (seconds, milliseconds, microseconds or nanoseconds) is guessed from
        its magnitude, and are returned as a UTC datetime. Datetimes are
        returned unchanged. Anything else, including numbers that cannot be
        converted, is returned as a string.
        """
        if isinstance(timestamp, int | float):
            ts_value = float(timestamp)
            try:
                # Present-day epoch values are about 1e9 s, 1e12 ms, 1e15 us
                # and 1e18 ns, so each cutoff sits between two neighbouring
                # units (every cutoff corresponds to early 1973 in the
                # smaller unit).
                magnitude = abs(ts_value)
                if magnitude >= 1e17:  # Nanoseconds
                    ts_value /= 1e9
                elif magnitude >= 1e14:  # Microseconds
                    ts_value /= 1e6
                elif magnitude >= 1e11:  # Milliseconds
                    ts_value /= 1e3

                return datetime.fromtimestamp(ts_value, tz=UTC)
            except (ValueError, OSError, OverflowError):
                # Out-of-range, NaN or infinite values cannot be converted.
                return str(ts_value)
        if isinstance(timestamp, datetime):
            return timestamp
        return str(timestamp)

    def get_exception_stacktrace(self) -> str | None:
        """Get the exception stacktrace if this span has an exception event.

        Only the first event named "exception" that has attributes is used.
        """
        for event in self.events:
            if event.get("name") == "exception" and event.get("attributes"):
                return event["attributes"].get(
                    exception_attributes.EXCEPTION_STACKTRACE
                )
        return None
490
491
@models.register_model
class Log(models.Model):
    """A log record that belongs to a trace and, optionally, to one span."""

    # Owning trace; on_delete=CASCADE removes the log together with its trace.
    trace: Trace = types.ForeignKeyField(Trace, on_delete=models.CASCADE)
    # Annotation only (no field is assigned here); presumably the raw
    # foreign-key value that the ``trace`` field exposes -- TODO confirm.
    trace_id: int
    # Optional span; on_delete=SET_NULL keeps the log and clears the link when
    # the span is deleted.
    span: Span | None = types.ForeignKeyField(
        Span,
        on_delete=models.SET_NULL,
        allow_null=True,
        required=False,
    )
    # Annotation only, like ``trace_id``; None when the log has no span.
    span_id: int | None

    timestamp: datetime = types.DateTimeField()
    level: str = types.CharField(max_length=20)
    message: str = types.TextField()

    query: models.QuerySet[Log] = models.QuerySet()

    model_options = models.Options(
        # Default ordering is oldest first.
        ordering=["timestamp"],
        indexes=[
            models.Index(fields=["trace", "timestamp"]),
            models.Index(fields=["trace", "span"]),
            models.Index(fields=["timestamp"]),
            models.Index(fields=["trace"]),
        ],
    )