1from __future__ import annotations
2
3import json
4from collections import Counter
5from collections.abc import Iterable, Mapping, Sequence
6from datetime import UTC, datetime
7from functools import cached_property
8from typing import TYPE_CHECKING, Any, cast
9
10import sqlparse
11from opentelemetry.sdk.trace import ReadableSpan
12from opentelemetry.semconv._incubating.attributes import (
13 exception_attributes,
14 session_attributes,
15 user_attributes,
16)
17from opentelemetry.semconv._incubating.attributes.code_attributes import (
18 CODE_NAMESPACE,
19)
20from opentelemetry.semconv._incubating.attributes.db_attributes import (
21 DB_QUERY_PARAMETER_TEMPLATE,
22)
23from opentelemetry.semconv.attributes import db_attributes, service_attributes
24from opentelemetry.semconv.attributes.code_attributes import (
25 CODE_COLUMN_NUMBER,
26 CODE_FILE_PATH,
27 CODE_FUNCTION_NAME,
28 CODE_LINE_NUMBER,
29 CODE_STACKTRACE,
30)
31from opentelemetry.trace import format_trace_id
32
33from plain import models
34from plain.models import types
35from plain.runtime import settings
36from plain.urls import reverse
37
38
@models.register_model
class Trace(models.Model):
    """A stored trace: its time window plus request, user and app metadata.

    A ``Trace`` can be built from OpenTelemetry spans with
    ``from_opentelemetry_spans``. Its ``Span`` and ``Log`` rows are reachable
    through the ``spans`` and ``logs`` reverse relations.
    """

    trace_id: str = types.CharField(max_length=255)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()

    root_span_name: str = types.TextField(default="", required=False)
    summary: str = types.CharField(max_length=255, default="", required=False)

    # Plain fields
    request_id: str = types.CharField(max_length=255, default="", required=False)
    session_id: str = types.CharField(max_length=255, default="", required=False)
    user_id: str = types.CharField(max_length=255, default="", required=False)
    app_name: str = types.CharField(max_length=255, default="", required=False)
    app_version: str = types.CharField(max_length=255, default="", required=False)

    # Explicit reverse relations
    spans: types.ReverseForeignKey[Span] = types.ReverseForeignKey(
        to="Span", field="trace"
    )
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(
        to="Log", field="trace"
    )

    query: models.QuerySet[Trace] = models.QuerySet()

    model_options = models.Options(
        ordering=["-start_time"],
        constraints=[
            models.UniqueConstraint(
                fields=["trace_id"],
                name="observer_unique_trace_id",
            )
        ],
        indexes=[
            models.Index(fields=["trace_id"]),
            models.Index(fields=["start_time"]),
            models.Index(fields=["request_id"]),
            models.Index(fields=["session_id"]),
        ],
    )

    def __str__(self) -> str:
        return self.trace_id

    def get_absolute_url(self) -> str:
        """Return the canonical URL for this trace."""
        return reverse("observer:trace_detail", trace_id=self.trace_id)

    def duration_ms(self) -> float:
        """Return the time between ``start_time`` and ``end_time`` in milliseconds."""
        return (self.end_time - self.start_time).total_seconds() * 1000

    def get_trace_summary(self, spans: Iterable[Span]) -> str:
        """Build a short human-readable summary for this trace.

        The summary has the form ``"3 queries (1 duplicate) • 12.3ms"``. The
        query part is omitted when none of ``spans`` carries query text.

        Args:
            spans: The spans to inspect for database query text.

        Returns:
            The summary parts joined by a bullet separator.
        """
        # Collect the query text of every span that ran a database query.
        query_texts: list[str] = [
            query_text
            for span in spans
            if (query_text := span.attributes.get(db_attributes.DB_QUERY_TEXT))
        ]

        query_total = len(query_texts)
        # Every repetition beyond the first occurrence of a text is a duplicate.
        duplicate_count = query_total - len(Counter(query_texts))

        # Build summary: "n queries (n duplicates) • Xms"
        parts: list[str] = []

        if query_total > 0:
            query_part = f"{query_total} quer{'y' if query_total == 1 else 'ies'}"
            if duplicate_count > 0:
                query_part += f" ({duplicate_count} duplicate{'' if duplicate_count == 1 else 's'})"
            parts.append(query_part)

        # duration_ms() always returns a number, so the duration is always shown.
        duration_ms = self.duration_ms()
        parts.append(f"{round(duration_ms, 1)}ms")

        # The bullet is written as an escape so a mis-decoded source file cannot
        # corrupt the separator again.
        return " \u2022 ".join(parts)

    @classmethod
    def from_opentelemetry_spans(cls, spans: Sequence[ReadableSpan]) -> Trace:
        """Create an unsaved Trace instance from a list of OpenTelemetry spans.

        Args:
            spans: Spans of a single trace. The trace id is read from the first
                span only.

        Returns:
            A new ``Trace`` covering the earliest start and latest end among
            ``spans``. When no span has ended yet, ``end_time`` falls back to
            ``start_time``.

        Raises:
            IndexError: If ``spans`` is empty.
        """
        # Get trace information from the first span
        first_span = spans[0]
        trace_id = f"0x{format_trace_id(first_span.get_span_context().trace_id)}"

        # Find trace boundaries and root span info
        earliest_start: int | None = None
        latest_end: int | None = None
        root_span: ReadableSpan | None = None
        request_id = ""
        user_id = ""
        session_id = ""
        app_name = ""
        app_version = ""

        for span in spans:
            # A span without a parent is a root; the last one seen wins.
            if not span.parent:
                root_span = span

            if span.start_time and (
                earliest_start is None or span.start_time < earliest_start
            ):
                earliest_start = span.start_time
            # Only update latest_end if the span has actually ended
            if span.end_time and (latest_end is None or span.end_time > latest_end):
                latest_end = span.end_time

            # ``attributes`` may be present but None, so fall back to an empty
            # mapping in both the missing and the None case.
            span_attrs = getattr(span, "attributes", None) or {}
            # The first non-empty value found across the spans is kept.
            request_id = request_id or span_attrs.get("plain.request.id", "")
            user_id = user_id or span_attrs.get(user_attributes.USER_ID, "")
            session_id = session_id or span_attrs.get(session_attributes.SESSION_ID, "")

            # App name/version come from the span's Resource attributes.
            if resource := getattr(span, "resource", None):
                app_name = app_name or resource.attributes.get(
                    service_attributes.SERVICE_NAME, ""
                )
                app_version = app_version or resource.attributes.get(
                    service_attributes.SERVICE_VERSION, ""
                )

        # Convert timestamps (the values are divided by 1e9, i.e. treated as
        # nanoseconds since the epoch).
        start_time = (
            datetime.fromtimestamp(earliest_start / 1_000_000_000, tz=UTC)
            if earliest_start
            else None
        )
        end_time = (
            datetime.fromtimestamp(latest_end / 1_000_000_000, tz=UTC)
            if latest_end
            else None
        )

        return cls(
            trace_id=trace_id,
            start_time=start_time,
            # Use start_time as fallback for active traces
            end_time=end_time or start_time,
            request_id=request_id,
            user_id=user_id,
            session_id=session_id,
            app_name=app_name or settings.NAME,
            app_version=app_version or settings.VERSION,
            root_span_name=root_span.name if root_span else "",
        )

    def as_dict(self) -> dict[str, Any]:
        """Return this trace, its spans and its logs as a plain dictionary.

        Spans are emitted as their stored ``span_data`` ordered by start time;
        logs are ordered by timestamp. Datetimes are ISO-8601 strings.
        """
        spans = [
            span.span_data for span in self.spans.query.all().order_by("start_time")
        ]
        logs = [
            {
                "timestamp": log.timestamp.isoformat(),
                "level": log.level,
                "message": log.message,
                "span_id": log.span_id,
            }
            for log in self.logs.query.all().order_by("timestamp")
        ]

        return {
            "trace_id": self.trace_id,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat(),
            "duration_ms": self.duration_ms(),
            "summary": self.summary,
            "root_span_name": self.root_span_name,
            "request_id": self.request_id,
            "user_id": self.user_id,
            "session_id": self.session_id,
            "app_name": self.app_name,
            "app_version": self.app_version,
            "spans": spans,
            "logs": logs,
        }

    def get_timeline_events(self) -> list[dict[str, Any]]:
        """Get chronological list of spans and logs for unified timeline display.

        Each event is a dict with ``type`` (``"span"`` or ``"log"``),
        ``timestamp``, ``instance`` and ``span_level``. Logs attached to a span
        are placed one level below that span; logs without a span get level 0.
        """
        events: list[dict[str, Any]] = []

        for span in self.spans.query.all().annotate_spans():  # type: ignore[attr-defined]
            events.append(
                {
                    "type": "span",
                    "timestamp": span.start_time,
                    "instance": span,
                    "span_level": span.level,
                }
            )

            # Add logs for this span.
            # NOTE(review): this runs one log query per span; consider fetching
            # the trace's logs once and grouping them if traces get large.
            for log in self.logs.query.filter(span=span):
                events.append(
                    {
                        "type": "log",
                        "timestamp": log.timestamp,
                        "instance": log,
                        "span_level": span.level + 1,
                    }
                )

        # Add unlinked logs (logs without span)
        for log in self.logs.query.filter(span__isnull=True):
            events.append(
                {
                    "type": "log",
                    "timestamp": log.timestamp,
                    "instance": log,
                    "span_level": 0,
                }
            )

        # sorted() is stable, so events with equal timestamps keep the
        # insertion order above (a span before its own logs).
        return sorted(events, key=lambda event: event["timestamp"])
255
256
class SpanQuerySet(models.QuerySet["Span"]):
    """QuerySet for ``Span`` with a helper that prepares spans for display."""

    def annotate_spans(self) -> list[Span]:
        """Annotate spans with nesting levels and duplicate query warnings.

        Evaluates the queryset ordered by ``start_time`` and sets two
        attributes on every returned span:

        * ``level``: 0 for a span without ``parent_id``, otherwise the
          parent's level plus one. A span whose parent is not part of this
          queryset gets level 1.
        * ``annotations``: a list of ``{"message", "severity"}`` dicts; a
          warning is added to every span whose SQL text occurs more than once.

        Returns:
            The spans as a list, ordered by ``start_time``.
        """
        spans: list[Span] = list(self.order_by("start_time"))

        # Build span dictionary for parent lookups
        span_dict: dict[str, Span] = {span.span_id: span for span in spans}

        # Resolve nesting levels by walking up the parent chain. This does not
        # rely on parents sorting before their children, which is not
        # guaranteed when a parent and child share the same start_time.
        # Levels are memoised per object (not per span_id) so that spans with
        # a repeated span_id are still resolved individually.
        levels: dict[int, int] = {}
        for span in spans:
            path: list[Span] = []  # spans on the way up whose level is unknown
            on_path: set[int] = set()
            node = span
            while True:
                if id(node) in levels:
                    # Reached an ancestor that is already resolved.
                    depth = levels[id(node)]
                    break
                if id(node) in on_path:
                    # Malformed data: the parent chain loops back on itself.
                    # Cut the loop here as if the parent were missing.
                    depth = 0
                    break
                path.append(node)
                on_path.add(id(node))
                if not node.parent_id:
                    # Root span: -1 so the increment below yields level 0.
                    depth = -1
                    break
                parent = span_dict.get(node.parent_id)
                if parent is None:
                    # Parent not in this queryset: the span ends up at level 1.
                    depth = 0
                    break
                node = parent

            # Assign levels from the top of the collected path downwards.
            for pending in reversed(path):
                depth += 1
                levels[id(pending)] = depth
                pending.level = depth

        # First pass: count how often each SQL text occurs.
        query_counts = Counter(span.sql_query for span in spans if span.sql_query)

        # Second pass: add annotations
        query_occurrences: dict[str, int] = {}
        for span in spans:
            span.annotations = []

            sql_query = span.sql_query
            if not sql_query:
                continue

            # Check for duplicate queries
            count = query_counts[sql_query]
            if count > 1:
                occurrence = query_occurrences.get(sql_query, 0) + 1
                query_occurrences[sql_query] = occurrence

                span.annotations.append(
                    {
                        "message": f"Duplicate query ({occurrence} of {count})",
                        "severity": "warning",
                    }
                )

        return spans
302
303
@models.register_model
class Span(models.Model):
    """A single span of a trace, stored with its full JSON payload.

    The indexed columns hold the values used for lookups and ordering; the
    full span as serialized by OpenTelemetry lives in ``span_data`` and is
    exposed through the read-only accessor properties below.
    """

    trace: Trace = types.ForeignKeyField(Trace, on_delete=models.CASCADE)

    span_id: str = types.CharField(max_length=255)

    name: str = types.CharField(max_length=255)
    kind: str = types.CharField(max_length=50)
    parent_id: str = types.CharField(max_length=255, default="", required=False)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()
    status: str = types.CharField(max_length=50, default="", required=False)
    span_data: dict = types.JSONField(default=dict, required=False)

    # Explicit reverse relation
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(to="Log", field="span")

    query: SpanQuerySet = SpanQuerySet()

    model_options = models.Options(
        ordering=["-start_time"],
        constraints=[
            models.UniqueConstraint(
                fields=["trace", "span_id"],
                name="observer_unique_span_id",
            )
        ],
        indexes=[
            models.Index(fields=["span_id"]),
            models.Index(fields=["trace", "span_id"]),
            models.Index(fields=["trace"]),
            models.Index(fields=["start_time"]),
        ],
    )

    if TYPE_CHECKING:
        # Set at runtime by SpanQuerySet.annotate_spans(), not stored.
        level: int
        annotations: list[dict[str, Any]]

    @classmethod
    def from_opentelemetry_span(cls, otel_span: ReadableSpan, trace: Trace) -> Span:
        """Create an unsaved Span instance from an OpenTelemetry span.

        Args:
            otel_span: The span to convert; its ``to_json()`` output becomes
                ``span_data``.
            trace: The trace the new span belongs to.

        Returns:
            A new ``Span`` populated from the JSON form of ``otel_span``.
        """
        span_data = json.loads(otel_span.to_json())

        # Extract status code as string, default to empty string if unset
        status = (span_data.get("status") or {}).get("status_code") or ""

        return cls(
            trace=trace,
            span_id=span_data["context"]["span_id"],
            name=span_data["name"],
            # Drop the enum class prefix ("SpanKind.X" -> "X"). removeprefix
            # leaves a value without the prefix untouched instead of chopping
            # off its first characters.
            kind=span_data["kind"].removeprefix("SpanKind."),
            parent_id=span_data["parent_id"] or "",
            start_time=span_data["start_time"],
            end_time=span_data["end_time"],
            status=status,
            span_data=span_data,
        )

    def __str__(self) -> str:
        return self.span_id

    # The accessors below fall back to an empty container both when the key is
    # missing and when it is stored as an explicit JSON null.

    @property
    def attributes(self) -> Mapping[str, Any]:
        """Get attributes from span_data (empty mapping if absent)."""
        return cast(Mapping[str, Any], self.span_data.get("attributes") or {})

    @property
    def events(self) -> list[Mapping[str, Any]]:
        """Get events from span_data (empty list if absent)."""
        return cast(list[Mapping[str, Any]], self.span_data.get("events") or [])

    @property
    def links(self) -> list[Mapping[str, Any]]:
        """Get links from span_data (empty list if absent)."""
        return cast(list[Mapping[str, Any]], self.span_data.get("links") or [])

    @property
    def resource(self) -> Mapping[str, Any]:
        """Get resource from span_data (empty mapping if absent)."""
        return cast(Mapping[str, Any], self.span_data.get("resource") or {})

    @property
    def context(self) -> Mapping[str, Any]:
        """Get context from span_data (empty mapping if absent)."""
        return cast(Mapping[str, Any], self.span_data.get("context") or {})

    def duration_ms(self) -> float:
        """Return the span duration in milliseconds, or 0.0 if a bound is missing."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds() * 1000
        return 0.0

    @cached_property
    def sql_query(self) -> str | None:
        """Get the SQL query if this span contains one."""
        return self.attributes.get(db_attributes.DB_QUERY_TEXT)

    @cached_property
    def sql_query_params(self) -> dict[str, Any]:
        """Get query parameters from attributes that start with 'db.query.parameter.'

        Returns:
            A dict mapping each parameter name (the attribute key with the
            prefix removed) to its value; empty if there are none.
        """
        prefix = DB_QUERY_PARAMETER_TEMPLATE + "."
        # removeprefix strips only the leading prefix; str.replace would also
        # remove any later occurrence inside the parameter name.
        return {
            key.removeprefix(prefix): value
            for key, value in self.attributes.items()
            if key.startswith(prefix)
        }

    @cached_property
    def source_code_location(self) -> dict[str, Any] | None:
        """Get the source code location attributes from this span.

        Returns:
            A dict keyed by display name ("File", "Line", ...) for every code
            attribute present on the span, or None if none are present.
        """
        attributes = self.attributes
        if not attributes:
            return None

        # Semantic-convention code attributes and the label shown for each.
        code_attribute_mappings = {
            CODE_FILE_PATH: "File",
            CODE_LINE_NUMBER: "Line",
            CODE_FUNCTION_NAME: "Function",
            CODE_NAMESPACE: "Namespace",
            CODE_COLUMN_NUMBER: "Column",
            CODE_STACKTRACE: "Stacktrace",
        }

        code_attrs = {
            display_name: attributes[attr_key]
            for attr_key, display_name in code_attribute_mappings.items()
            if attr_key in attributes
        }
        return code_attrs or None

    def get_formatted_sql(self) -> str | None:
        """Get the pretty-formatted SQL query if this span contains one."""
        sql = self.sql_query
        if not sql:
            return None

        return sqlparse.format(
            sql,
            reindent=True,
            keyword_case="upper",
            identifier_case="lower",
            strip_comments=False,
            strip_whitespace=True,
            indent_width=2,
            wrap_after=80,
            comma_first=False,
        )

    def format_event_timestamp(
        self, timestamp: float | int | datetime | str
    ) -> datetime | str:
        """Convert event timestamp to a readable datetime.

        Numeric values are interpreted as time since the Unix epoch. The unit
        is guessed from the magnitude: values below 1e11 are seconds, below
        1e14 milliseconds, below 1e17 microseconds, and anything larger
        nanoseconds. Each boundary corresponds to the year 1973 in the finer
        unit, so present-day timestamps in any of the four units are
        classified correctly.

        Args:
            timestamp: A number, a datetime, or any other value (e.g. an
                already formatted string).

        Returns:
            A timezone-aware UTC datetime for numeric input, the datetime
            itself for datetime input, and ``str(timestamp)`` for anything
            else or for numbers that cannot be converted.
        """
        if isinstance(timestamp, datetime):
            return timestamp
        if not isinstance(timestamp, int | float):
            return str(timestamp)

        try:
            seconds = float(timestamp)
            magnitude = abs(seconds)
            if magnitude >= 1e17:  # nanoseconds
                seconds /= 1e9
            elif magnitude >= 1e14:  # microseconds
                seconds /= 1e6
            elif magnitude >= 1e11:  # milliseconds
                seconds /= 1e3
            return datetime.fromtimestamp(seconds, tz=UTC)
        except (ValueError, OSError, OverflowError):
            # Out-of-range, NaN or infinite values: show the raw input.
            return str(timestamp)

    def get_exception_stacktrace(self) -> str | None:
        """Get the exception stacktrace if this span has an exception event.

        Only the first event named "exception" that has attributes is
        consulted; None is returned if it carries no stacktrace attribute.
        """
        for event in self.events:
            if event.get("name") == "exception" and event.get("attributes"):
                return event["attributes"].get(
                    exception_attributes.EXCEPTION_STACKTRACE
                )
        return None
490
491
@models.register_model
class Log(models.Model):
    """A log record captured for a trace and optionally tied to one span."""

    # Owning trace; the log row is removed together with it (CASCADE).
    trace: Trace = types.ForeignKeyField(Trace, on_delete=models.CASCADE)
    # Annotation only, no field is assigned here.
    # NOTE(review): presumably the foreign-key column of ``trace`` as exposed
    # by the ORM; confirm against plain.models.
    trace_id: int

    # Optional span; the reference is nulled if the span is deleted (SET_NULL).
    span: Span | None = types.ForeignKeyField(
        Span, on_delete=models.SET_NULL, allow_null=True, required=False
    )
    # Annotation only; see the note on ``trace_id``.
    span_id: int | None

    timestamp: datetime = types.DateTimeField()
    level: str = types.CharField(max_length=20)
    message: str = types.TextField()

    query: models.QuerySet[Log] = models.QuerySet()

    # Logs are returned oldest-first by default.
    model_options = models.Options(
        ordering=["timestamp"],
        indexes=[
            models.Index(fields=["trace", "timestamp"]),
            models.Index(fields=["trace", "span"]),
            models.Index(fields=["timestamp"]),
            models.Index(fields=["trace"]),
        ],
    )