1from __future__ import annotations
2
3import json
4from collections import Counter
5from collections.abc import Iterable, Mapping, Sequence
6from datetime import UTC, datetime
7from functools import cached_property
8from typing import TYPE_CHECKING, Any, cast
9
10import sqlparse
11from opentelemetry.sdk.trace import ReadableSpan
12from opentelemetry.semconv._incubating.attributes import (
13 exception_attributes,
14 session_attributes,
15 user_attributes,
16)
17from opentelemetry.semconv._incubating.attributes.code_attributes import (
18 CODE_NAMESPACE,
19)
20from opentelemetry.semconv._incubating.attributes.db_attributes import (
21 DB_QUERY_PARAMETER_TEMPLATE,
22)
23from opentelemetry.semconv._incubating.attributes.http_attributes import (
24 HTTP_RESPONSE_BODY_SIZE,
25)
26from opentelemetry.semconv.attributes import db_attributes, service_attributes
27from opentelemetry.semconv.attributes.code_attributes import (
28 CODE_COLUMN_NUMBER,
29 CODE_FILE_PATH,
30 CODE_FUNCTION_NAME,
31 CODE_LINE_NUMBER,
32 CODE_STACKTRACE,
33)
34from opentelemetry.trace import format_trace_id
35
36from plain import postgres
37from plain.postgres import types
38from plain.runtime import settings
39from plain.urls import reverse
40
# Public models exported by this module.
__all__ = ["Log", "Span", "Trace"]
42
43
44from .formatting import format_bytes
45
46
@postgres.register_model
class Trace(postgres.Model):
    """One captured OpenTelemetry trace.

    Rows are unique on the hex ``trace_id`` and own their spans and logs
    through the ``spans`` / ``logs`` reverse relations.
    """

    trace_id: str = types.TextField(max_length=255)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()

    root_span_name: str = types.TextField(default="", required=False)
    summary: str = types.TextField(max_length=255, default="", required=False)

    # Plain fields
    request_id: str = types.TextField(max_length=255, default="", required=False)
    session_id: str = types.TextField(max_length=255, default="", required=False)
    user_id: str = types.TextField(max_length=255, default="", required=False)
    app_name: str = types.TextField(max_length=255, default="", required=False)
    app_version: str = types.TextField(max_length=255, default="", required=False)

    # Explicit reverse relations
    spans: types.ReverseForeignKey[Span] = types.ReverseForeignKey(
        to="Span", field="trace"
    )
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(
        to="Log", field="trace"
    )

    query: postgres.QuerySet[Trace] = postgres.QuerySet()

    model_options = postgres.Options(
        ordering=["-start_time"],
        constraints=[
            postgres.UniqueConstraint(
                fields=["trace_id"],
                name="observer_unique_trace_id",
            )
        ],
        indexes=[
            postgres.Index(
                name="plainobserver_trace_start_time_idx", fields=["start_time"]
            ),
            postgres.Index(
                name="plainobserver_trace_request_id_idx", fields=["request_id"]
            ),
            postgres.Index(
                name="plainobserver_trace_session_id_idx", fields=["session_id"]
            ),
        ],
    )

    def __str__(self) -> str:
        return self.trace_id

    def get_absolute_url(self) -> str:
        """Return the canonical URL for this trace."""
        return reverse("observer:trace_detail", trace_id=self.trace_id)

    def duration_ms(self) -> float:
        """Return the time between ``start_time`` and ``end_time`` in milliseconds."""
        return (self.end_time - self.start_time).total_seconds() * 1000

    def get_trace_stats(self, spans: Iterable[Span]) -> dict[str, Any]:
        """Extract structured performance stats from trace spans.

        Args:
            spans: The spans of this trace.

        Returns:
            A dict with:

            - ``query_count``: number of spans carrying a DB query text.
            - ``duplicate_count``: how many of those repeat an earlier query text.
            - ``duration_ms``: the trace duration (see :meth:`duration_ms`).
            - ``response_body_size``: the HTTP response body size of the last
              span that reports one, or ``None``.
        """
        query_texts: list[str] = []
        response_body_size: int | None = None
        for span in spans:
            if query_text := span.attributes.get(db_attributes.DB_QUERY_TEXT):
                query_texts.append(query_text)
            if (body_size := span.attributes.get(HTTP_RESPONSE_BODY_SIZE)) is not None:
                response_body_size = int(body_size)

        query_counts = Counter(query_texts)
        return {
            "query_count": len(query_texts),
            # Every occurrence beyond the first of a given query text.
            "duplicate_count": sum(query_counts.values()) - len(query_counts),
            "duration_ms": self.duration_ms(),
            "response_body_size": response_body_size,
        }

    def get_trace_summary(self, spans: Iterable[Span]) -> str:
        """Format trace stats as a human-readable summary string.

        Parts are joined with " • ", e.g. ``"3 queries (1 duplicate) • 12.3ms"``
        followed by the formatted response size when one is known.
        """
        stats = self.get_trace_stats(spans)
        parts: list[str] = []

        query_total = stats["query_count"]
        if query_total > 0:
            query_part = f"{query_total} quer{'y' if query_total == 1 else 'ies'}"
            duplicate_count = stats["duplicate_count"]
            if duplicate_count > 0:
                query_part += f" ({duplicate_count} duplicate{'' if duplicate_count == 1 else 's'})"
            parts.append(query_part)

        if stats["duration_ms"] is not None:
            parts.append(f"{round(stats['duration_ms'], 1)}ms")

        if stats["response_body_size"] is not None:
            parts.append(format_bytes(stats["response_body_size"]))

        return " • ".join(parts)

    @classmethod
    def from_opentelemetry_spans(cls, spans: Sequence[ReadableSpan]) -> Trace:
        """Create an unsaved Trace instance from a list of OpenTelemetry spans.

        ``spans`` must be non-empty: the trace id is taken from the first span.
        The start/end time covers the earliest start and latest end across all
        spans. Identifying fields are taken from the first span that sets them.
        App name/version fall back to ``settings.NAME`` / ``settings.VERSION``.
        """
        # Get trace information from the first span
        first_span = spans[0]
        trace_id = f"0x{format_trace_id(first_span.get_span_context().trace_id)}"

        # Find trace boundaries and root span info
        earliest_start = None
        latest_end = None
        root_span = None
        request_id = ""
        user_id = ""
        session_id = ""
        app_name = ""
        app_version = ""

        for span in spans:
            if not span.parent:
                root_span = span

            if span.start_time and (
                earliest_start is None or span.start_time < earliest_start
            ):
                earliest_start = span.start_time
            # Only update latest_end if the span has actually ended
            if span.end_time and (latest_end is None or span.end_time > latest_end):
                latest_end = span.end_time

            # For OpenTelemetry spans, access attributes directly. Guard against
            # a missing *or* None attributes mapping.
            span_attrs = getattr(span, "attributes", None) or {}
            request_id = request_id or span_attrs.get("plain.request.id", "")
            user_id = user_id or span_attrs.get(user_attributes.USER_ID, "")
            session_id = session_id or span_attrs.get(session_attributes.SESSION_ID, "")

            # Access Resource attributes if not found in span attributes
            if resource := getattr(span, "resource", None):
                app_name = app_name or resource.attributes.get(
                    service_attributes.SERVICE_NAME, ""
                )
                app_version = app_version or resource.attributes.get(
                    service_attributes.SERVICE_VERSION, ""
                )

        # Convert timestamps (integer nanoseconds since the epoch) to UTC datetimes
        start_time = (
            datetime.fromtimestamp(earliest_start / 1_000_000_000, tz=UTC)
            if earliest_start
            else None
        )
        end_time = (
            datetime.fromtimestamp(latest_end / 1_000_000_000, tz=UTC)
            if latest_end
            else None
        )

        return cls(
            trace_id=trace_id,
            start_time=start_time,
            end_time=end_time
            or start_time,  # Use start_time as fallback for active traces
            request_id=request_id,
            user_id=user_id,
            session_id=session_id,
            app_name=app_name or settings.NAME,
            app_version=app_version or settings.VERSION,
            root_span_name=root_span.name if root_span else "",
        )

    def as_dict(self) -> dict[str, Any]:
        """Serialize the trace, its spans (by start time) and logs (by timestamp)."""
        spans = [
            span.span_data for span in self.spans.query.all().order_by("start_time")
        ]
        logs = [
            {
                "timestamp": log.timestamp.isoformat(),
                "level": log.level,
                "message": log.message,
                "span_id": log.span_id,
            }
            for log in self.logs.query.all().order_by("timestamp")
        ]

        return {
            "trace_id": self.trace_id,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat(),
            "duration_ms": self.duration_ms(),
            "summary": self.summary,
            "root_span_name": self.root_span_name,
            "request_id": self.request_id,
            "user_id": self.user_id,
            "session_id": self.session_id,
            "app_name": self.app_name,
            "app_version": self.app_version,
            "spans": spans,
            "logs": logs,
        }

    def get_timeline_events(self) -> list[dict[str, Any]]:
        """Get chronological list of spans and logs for unified timeline display.

        Each event is a dict with ``type`` ("span" or "log"), ``timestamp``,
        ``instance`` (the model object) and ``span_level``. A span's level is its
        nesting depth. A linked log sits one level below its span, and a log
        without a span sits at level 0.
        """
        events: list[dict[str, Any]] = []

        # Fetch all logs for this trace in ONE query and bucket them by span,
        # instead of issuing a separate query per span (N+1).
        logs_by_span: dict[int, list[Log]] = {}
        unlinked_logs: list[Log] = []
        for log in self.logs.query.all().order_by("timestamp"):
            if log.span_id is None:
                unlinked_logs.append(log)
            else:
                logs_by_span.setdefault(log.span_id, []).append(log)

        for span in self.spans.query.all().annotate_spans():  # ty: ignore[unresolved-attribute]
            events.append(
                {
                    "type": "span",
                    "timestamp": span.start_time,
                    "instance": span,
                    "span_level": span.level,
                }
            )

            # Add logs for this span (keyed by the span's primary key)
            for log in logs_by_span.get(span.id, []):
                events.append(
                    {
                        "type": "log",
                        "timestamp": log.timestamp,
                        "instance": log,
                        "span_level": span.level + 1,
                    }
                )

        # Add unlinked logs (logs without span)
        for log in unlinked_logs:
            events.append(
                {
                    "type": "log",
                    "timestamp": log.timestamp,
                    "instance": log,
                    "span_level": 0,
                }
            )

        # Sort by timestamp (stable, so equal timestamps keep insertion order)
        return sorted(events, key=lambda x: x["timestamp"])
280
281
class SpanQuerySet(postgres.QuerySet["Span"]):
    """QuerySet for :class:`Span` adding timeline annotations."""

    def annotate_spans(self) -> list[Span]:
        """Annotate spans with nesting levels and duplicate query warnings.

        Evaluates the queryset ordered by ``start_time`` and sets two
        attributes on every returned span:

        - ``level``: nesting depth. It is 0 for a span with no ``parent_id`` and
          1 for a span whose parent is not among the loaded spans. Otherwise it
          is the parent's level + 1.
        - ``annotations``: list of ``{"message", "severity"}`` dicts. A warning
          is added when the span's SQL query text occurs more than once.

        Returns:
            The annotated spans as a list, in ``start_time`` order.
        """
        spans: list[Span] = list(self.order_by("start_time"))

        # Build span dictionary for parent lookups
        span_dict: dict[str, Span] = {span.span_id: span for span in spans}

        # Resolve levels by walking parent chains instead of trusting
        # start_time order. A parent and child can share a start_time, and if
        # the child is visited first the parent's level is not yet assigned.
        levels = self._nesting_levels(spans, span_dict)
        for span in spans:
            span.level = levels[span.span_id]

        # First pass: count how often each SQL text appears
        query_counts = Counter(span.sql_query for span in spans if span.sql_query)

        # Second pass: flag repeats with their position among identical queries
        query_occurrences: dict[str, int] = {}
        for span in spans:
            span.annotations = []

            sql_query = span.sql_query
            if not sql_query:
                continue
            count = query_counts[sql_query]
            if count <= 1:
                continue

            occurrence = query_occurrences.get(sql_query, 0) + 1
            query_occurrences[sql_query] = occurrence
            span.annotations.append(
                {
                    "message": f"Duplicate query ({occurrence} of {count})",
                    "severity": "warning",
                }
            )

        return spans

    @staticmethod
    def _nesting_levels(
        spans: list[Span], span_dict: dict[str, Span]
    ) -> dict[str, int]:
        """Map each span_id to its nesting depth, independent of list order."""
        levels: dict[str, int] = {}
        for span in spans:
            if span.span_id in levels:
                continue

            # Walk upward until reaching a root, a span whose parent was not
            # loaded, or an ancestor whose level is already known.
            chain: list[Span] = []
            seen: set[str] = set()
            node = span
            while True:
                chain.append(node)
                seen.add(node.span_id)
                if not node.parent_id:
                    top_level = 0
                    break
                parent = span_dict.get(node.parent_id)
                if parent is None or parent.span_id in seen:
                    # Parent not loaded (or malformed parent cycle): treat the
                    # span as a direct child of an unseen root.
                    top_level = 1
                    break
                if parent.span_id in levels:
                    top_level = levels[parent.span_id] + 1
                    break
                node = parent

            # chain[-1] is the highest span visited; each step down adds one.
            for depth, member in enumerate(reversed(chain)):
                levels[member.span_id] = top_level + depth
        return levels
327
328
@postgres.register_model
class Span(postgres.Model):
    """A single span belonging to a :class:`Trace`.

    The parsed JSON export of the OpenTelemetry span is stored in
    ``span_data``. The typed columns hold the pieces used for lookup and
    ordering.
    """

    trace: Trace = types.ForeignKeyField(Trace, on_delete=postgres.CASCADE)

    span_id: str = types.TextField(max_length=255)

    name: str = types.TextField(max_length=255)
    kind: str = types.TextField(max_length=50)
    parent_id: str = types.TextField(max_length=255, default="", required=False)
    start_time: datetime = types.DateTimeField()
    end_time: datetime = types.DateTimeField()
    status: str = types.TextField(max_length=50, default="", required=False)
    # NOTE(review): a literal ``{}`` default may be shared between instances;
    # ``default=dict`` would avoid that but alters the field definition.
    span_data: dict = types.JSONField(default={}, required=False)

    # Explicit reverse relation
    logs: types.ReverseForeignKey[Log] = types.ReverseForeignKey(to="Log", field="span")

    query: SpanQuerySet = SpanQuerySet()

    model_options = postgres.Options(
        ordering=["-start_time"],
        constraints=[
            postgres.UniqueConstraint(
                fields=["trace", "span_id"],
                name="observer_unique_span_id",
            )
        ],
        indexes=[
            postgres.Index(name="plainobserver_span_span_id_idx", fields=["span_id"]),
            postgres.Index(
                name="plainobserver_span_start_time_idx", fields=["start_time"]
            ),
        ],
    )

    if TYPE_CHECKING:
        # Runtime-only attributes assigned by SpanQuerySet.annotate_spans().
        level: int
        annotations: list[dict[str, Any]]

    @classmethod
    def from_opentelemetry_span(cls, otel_span: ReadableSpan, trace: Trace) -> Span:
        """Create an unsaved Span instance from an OpenTelemetry span.

        The span is serialized with ``to_json()``; the parsed result is kept
        verbatim in ``span_data`` and the typed columns are copied out of it.
        """
        span_data = json.loads(otel_span.to_json())

        # Status code as a string; empty when there is no status or code.
        status = (span_data.get("status") or {}).get("status_code") or ""

        return cls(
            trace=trace,
            span_id=span_data["context"]["span_id"],
            name=span_data["name"],
            # The JSON renders the kind as e.g. "SpanKind.SERVER"; keep only the
            # member name, and leave the value intact if the prefix is absent.
            kind=span_data["kind"].removeprefix("SpanKind."),
            parent_id=span_data["parent_id"] or "",
            start_time=span_data["start_time"],
            end_time=span_data["end_time"],
            status=status,
            span_data=span_data,
        )

    def __str__(self) -> str:
        return self.span_id

    @property
    def attributes(self) -> Mapping[str, Any]:
        """Get attributes from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("attributes", {}))

    @property
    def events(self) -> list[Mapping[str, Any]]:
        """Get events from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("events", []))

    @property
    def links(self) -> list[Mapping[str, Any]]:
        """Get links from span_data."""
        return cast(list[Mapping[str, Any]], self.span_data.get("links", []))

    @property
    def resource(self) -> Mapping[str, Any]:
        """Get resource from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("resource", {}))

    @property
    def context(self) -> Mapping[str, Any]:
        """Get context from span_data."""
        return cast(Mapping[str, Any], self.span_data.get("context", {}))

    def duration_ms(self) -> float:
        """Return the span duration in milliseconds, or 0.0 if a bound is missing."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds() * 1000
        return 0.0

    @cached_property
    def sql_query(self) -> str | None:
        """Get the SQL query if this span contains one."""
        return self.attributes.get(db_attributes.DB_QUERY_TEXT)

    @cached_property
    def sql_query_params(self) -> dict[str, Any]:
        """Get query parameters from attributes that start with 'db.query.parameter.'

        Keys in the result are the attribute names with only the leading
        parameter prefix removed.
        """
        prefix = DB_QUERY_PARAMETER_TEMPLATE + "."
        return {
            key.removeprefix(prefix): value
            for key, value in self.attributes.items()
            if key.startswith(prefix)
        }

    @cached_property
    def source_code_location(self) -> dict[str, Any] | None:
        """Get the source code location attributes from this span.

        Returns a mapping of display label to value for the code.* attributes
        present, or None when none are set.
        """
        if not self.attributes:
            return None

        # Look for common semantic convention code attributes
        code_attribute_mappings = {
            CODE_FILE_PATH: "File",
            CODE_LINE_NUMBER: "Line",
            CODE_FUNCTION_NAME: "Function",
            CODE_NAMESPACE: "Namespace",
            CODE_COLUMN_NUMBER: "Column",
            CODE_STACKTRACE: "Stacktrace",
        }

        code_attrs = {
            display_name: self.attributes[attr_key]
            for attr_key, display_name in code_attribute_mappings.items()
            if attr_key in self.attributes
        }
        return code_attrs if code_attrs else None

    def get_formatted_sql(self) -> str | None:
        """Get the pretty-formatted SQL query if this span contains one."""
        sql = self.sql_query
        if not sql:
            return None

        return sqlparse.format(
            sql,
            reindent=True,
            keyword_case="upper",
            identifier_case="lower",
            strip_comments=False,
            strip_whitespace=True,
            indent_width=2,
            wrap_after=80,
            comma_first=False,
        )

    def format_event_timestamp(
        self, timestamp: float | int | datetime | str
    ) -> datetime | str:
        """Convert an event timestamp to a readable UTC datetime.

        The unit of a numeric timestamp is inferred from its magnitude:
        seconds, milliseconds, microseconds or nanoseconds since the Unix epoch.
        A datetime is returned unchanged. Anything else, or a number that cannot
        be converted, is returned as ``str(timestamp)``.
        """
        if isinstance(timestamp, datetime):
            return timestamp
        if not isinstance(timestamp, int | float):
            return str(timestamp)

        ts_value = float(timestamp)
        # Epoch seconds stay below 1e10 until the year 2286, so each further
        # factor of 1000 in magnitude identifies the next finer unit.
        if ts_value > 1e16:  # nanoseconds
            ts_value /= 1e9
        elif ts_value > 1e13:  # microseconds
            ts_value /= 1e6
        elif ts_value > 1e10:  # milliseconds
            ts_value /= 1e3

        try:
            return datetime.fromtimestamp(ts_value, tz=UTC)
        except (OverflowError, OSError, ValueError):
            # Out of range for the platform (or NaN/inf): show the raw input.
            return str(timestamp)

    def get_exception_stacktrace(self) -> str | None:
        """Get the stacktrace of the first exception event on this span, if any."""
        for event in self.events:
            if event.get("name") == "exception" and event.get("attributes"):
                return event["attributes"].get(
                    exception_attributes.EXCEPTION_STACKTRACE
                )
        return None
515
516
@postgres.register_model
class Log(postgres.Model):
    """A log record captured as part of a trace.

    Always belongs to a ``Trace``, which deletes the log when the trace is
    deleted. It is optionally linked to a ``Span``; the link is set to null
    if that span is deleted.
    """

    trace: Trace = types.ForeignKeyField(Trace, on_delete=postgres.CASCADE)
    # Type hint for the ``trace`` foreign key's id attribute (an int, presumably
    # the Trace primary key -- not the OpenTelemetry trace_id string).
    trace_id: int
    span: Span | None = types.ForeignKeyField(
        Span,
        on_delete=postgres.SET_NULL,
        allow_null=True,
        required=False,
    )
    # Type hint for the ``span`` foreign key's id attribute; None when unlinked.
    span_id: int | None

    timestamp: datetime = types.DateTimeField()
    level: str = types.TextField(max_length=20)
    message: str = types.TextField()

    query: postgres.QuerySet[Log] = postgres.QuerySet()

    model_options = postgres.Options(
        # Default ordering is chronological.
        ordering=["timestamp"],
        indexes=[
            postgres.Index(
                name="plainobserver_log_trace_id_timestamp_idx",
                fields=["trace", "timestamp"],
            ),
            postgres.Index(
                name="plainobserver_log_trace_id_span_id_idx", fields=["trace", "span"]
            ),
            postgres.Index(
                name="plainobserver_log_timestamp_idx", fields=["timestamp"]
            ),
            postgres.Index(name="plainobserver_log_span_id_idx", fields=["span"]),
        ],
    )