Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from datetime import UTC, datetime
  3from functools import cached_property
  4
  5import sqlparse
  6from opentelemetry.semconv._incubating.attributes import (
  7    exception_attributes,
  8    session_attributes,
  9    user_attributes,
 10)
 11from opentelemetry.semconv._incubating.attributes.db_attributes import (
 12    DB_QUERY_PARAMETER_TEMPLATE,
 13)
 14from opentelemetry.semconv.attributes import db_attributes
 15from opentelemetry.trace import format_trace_id
 16
 17from plain import models
 18
 19
 20@models.register_model
 21class Trace(models.Model):
 22    trace_id = models.CharField(max_length=255)
 23    start_time = models.DateTimeField()
 24    end_time = models.DateTimeField()
 25
 26    root_span_name = models.TextField(default="", required=False)
 27
 28    # Plain fields
 29    request_id = models.CharField(max_length=255, default="", required=False)
 30    session_id = models.CharField(max_length=255, default="", required=False)
 31    user_id = models.CharField(max_length=255, default="", required=False)
 32
 33    class Meta:
 34        ordering = ["-start_time"]
 35        constraints = [
 36            models.UniqueConstraint(
 37                fields=["trace_id"],
 38                name="observer_unique_trace_id",
 39            )
 40        ]
 41
 42    def __str__(self):
 43        return self.trace_id
 44
 45    def duration_ms(self):
 46        return (self.end_time - self.start_time).total_seconds() * 1000
 47
 48    def get_trace_summary(self, spans=None):
 49        """Get a concise summary string for toolbar display.
 50
 51        Args:
 52            spans: Optional list of span objects. If not provided, will query from database.
 53        """
 54        # Get spans from database if not provided
 55        if spans is None:
 56            spans = list(self.spans.all())
 57
 58        if not spans:
 59            return ""
 60
 61        # Count database queries and track duplicates
 62        query_counts = {}
 63        db_queries = 0
 64
 65        for span in spans:
 66            if span.attributes.get(db_attributes.DB_SYSTEM_NAME):
 67                db_queries += 1
 68                if query_text := span.attributes.get(db_attributes.DB_QUERY_TEXT):
 69                    query_counts[query_text] = query_counts.get(query_text, 0) + 1
 70
 71        # Count duplicate queries (queries that appear more than once)
 72        duplicate_count = sum(count - 1 for count in query_counts.values() if count > 1)
 73
 74        # Build summary: "n spans, n queries (n duplicates), Xms"
 75        parts = []
 76
 77        # Queries count with duplicates
 78        if db_queries > 0:
 79            query_part = f"{db_queries} quer{'y' if db_queries == 1 else 'ies'}"
 80            if duplicate_count > 0:
 81                query_part += f" ({duplicate_count} duplicate{'' if duplicate_count == 1 else 's'})"
 82            parts.append(query_part)
 83
 84        # Duration
 85        if (duration_ms := self.duration_ms()) is not None:
 86            parts.append(f"{round(duration_ms, 1)}ms")
 87
 88        return " • ".join(parts)
 89
 90    @classmethod
 91    def from_opentelemetry_spans(cls, spans):
 92        """Create a Trace instance from a list of OpenTelemetry spans."""
 93        # Get trace information from the first span
 94        first_span = spans[0]
 95        trace_id = f"0x{format_trace_id(first_span.get_span_context().trace_id)}"
 96
 97        # Find trace boundaries and root span info
 98        earliest_start = None
 99        latest_end = None
100        root_span = None
101        request_id = ""
102        user_id = ""
103        session_id = ""
104
105        for span in spans:
106            if not span.parent:
107                root_span = span
108
109            if span.start_time and (
110                earliest_start is None or span.start_time < earliest_start
111            ):
112                earliest_start = span.start_time
113            # Only update latest_end if the span has actually ended
114            if span.end_time and (latest_end is None or span.end_time > latest_end):
115                latest_end = span.end_time
116
117            # For OpenTelemetry spans, access attributes directly
118            span_attrs = getattr(span, "attributes", {})
119            request_id = request_id or span_attrs.get("plain.request.id", "")
120            user_id = user_id or span_attrs.get(user_attributes.USER_ID, "")
121            session_id = session_id or span_attrs.get(session_attributes.SESSION_ID, "")
122
123        # Convert timestamps
124        start_time = (
125            datetime.fromtimestamp(earliest_start / 1_000_000_000, tz=UTC)
126            if earliest_start
127            else None
128        )
129        end_time = (
130            datetime.fromtimestamp(latest_end / 1_000_000_000, tz=UTC)
131            if latest_end
132            else None
133        )
134
135        # Create trace instance
136        # Note: end_time might be None if there are active spans
137        # This is OK since this trace is only used for summaries, not persistence
138        return cls(
139            trace_id=trace_id,
140            start_time=start_time,
141            end_time=end_time
142            or start_time,  # Use start_time as fallback for active traces
143            request_id=request_id,
144            user_id=user_id,
145            session_id=session_id,
146            root_span_name=root_span.name if root_span else "",
147        )
148
149    def get_annotated_spans(self):
150        """Return spans with annotations and nesting information."""
151        spans = list(self.spans.all().order_by("start_time"))
152
153        # Build span dictionary for parent lookups
154        span_dict = {span.span_id: span for span in spans}
155
156        # Calculate nesting levels
157        for span in spans:
158            if not span.parent_id:
159                span.level = 0
160            else:
161                # Find parent's level and add 1
162                parent = span_dict.get(span.parent_id)
163                parent_level = parent.level if parent else 0
164                span.level = parent_level + 1
165
166        query_counts = {}
167
168        # First pass: count queries
169        for span in spans:
170            if sql_query := span.sql_query:
171                query_counts[sql_query] = query_counts.get(sql_query, 0) + 1
172
173        # Second pass: add annotations
174        query_occurrences = {}
175        for span in spans:
176            span.annotations = []
177
178            # Check for duplicate queries
179            if sql_query := span.sql_query:
180                count = query_counts[sql_query]
181                if count > 1:
182                    occurrence = query_occurrences.get(sql_query, 0) + 1
183                    query_occurrences[sql_query] = occurrence
184
185                    span.annotations.append(
186                        {
187                            "message": f"Duplicate query ({occurrence} of {count})",
188                            "severity": "warning",
189                        }
190                    )
191
192        return spans
193
194    def as_dict(self):
195        spans = [span.span_data for span in self.spans.all().order_by("start_time")]
196
197        return {
198            "trace_id": self.trace_id,
199            "start_time": self.start_time.isoformat(),
200            "end_time": self.end_time.isoformat(),
201            "duration_ms": self.duration_ms(),
202            "request_id": self.request_id,
203            "user_id": self.user_id,
204            "session_id": self.session_id,
205            "spans": spans,
206        }
207
208
@models.register_model
class Span(models.Model):
    trace = models.ForeignKey(Trace, on_delete=models.CASCADE, related_name="spans")

    span_id = models.CharField(max_length=255)

    name = models.CharField(max_length=255)
    kind = models.CharField(max_length=50)
    parent_id = models.CharField(max_length=255, default="", required=False)
    start_time = models.DateTimeField()
    end_time = models.DateTimeField()
    status = models.CharField(max_length=50, default="", required=False)
    # Parsed JSON export of the OpenTelemetry span (see from_opentelemetry_span)
    span_data = models.JSONField(default=dict, required=False)

    class Meta:
        ordering = ["-start_time"]
        constraints = [
            models.UniqueConstraint(
                fields=["trace", "span_id"],
                name="observer_unique_span_id",
            )
        ]
        indexes = [
            models.Index(fields=["trace", "span_id"]),
            models.Index(fields=["trace"]),
            models.Index(fields=["start_time"]),
        ]

    @classmethod
    def from_opentelemetry_span(cls, otel_span, trace):
        """Create an unsaved Span from an OpenTelemetry span.

        The span is round-tripped through ``otel_span.to_json()``; the parsed
        dict is kept in ``span_data`` and its ``context.span_id``, ``name``,
        ``kind``, ``parent_id``, ``start_time``, ``end_time`` and
        ``status.status_code`` entries populate the model fields.
        """
        span_data = json.loads(otel_span.to_json())

        # Extract status code as string, default to empty string if unset
        status_info = span_data.get("status") or {}
        status = status_info.get("status_code") or ""

        return cls(
            trace=trace,
            span_id=span_data["context"]["span_id"],
            name=span_data["name"],
            # e.g. "SpanKind.SERVER" -> "SERVER"; values without the prefix are kept
            kind=span_data["kind"].removeprefix("SpanKind."),
            parent_id=span_data["parent_id"] or "",
            start_time=span_data["start_time"],
            end_time=span_data["end_time"],
            status=status,
            span_data=span_data,
        )

    def __str__(self):
        return self.span_id

    @property
    def attributes(self):
        """Get attributes from span_data."""
        return self.span_data.get("attributes", {})

    @property
    def events(self):
        """Get events from span_data."""
        return self.span_data.get("events", [])

    @property
    def links(self):
        """Get links from span_data."""
        return self.span_data.get("links", [])

    @property
    def resource(self):
        """Get resource from span_data."""
        return self.span_data.get("resource", {})

    @property
    def context(self):
        """Get context from span_data."""
        return self.span_data.get("context", {})

    def duration_ms(self):
        """Return the span duration in milliseconds, or 0 if a bound is missing."""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds() * 1000
        return 0

    @cached_property
    def sql_query(self):
        """Get the SQL query text attribute if this span contains one."""
        return self.attributes.get(db_attributes.DB_QUERY_TEXT)

    @cached_property
    def sql_query_params(self):
        """Map parameter names to values from ``db.query.parameter.<name>`` attributes."""
        attributes = self.attributes or {}
        prefix = DB_QUERY_PARAMETER_TEMPLATE + "."
        # Slice off only the leading prefix; str.replace would also strip any
        # later occurrence inside the parameter name.
        return {
            key[len(prefix) :]: value
            for key, value in attributes.items()
            if key.startswith(prefix)
        }

    def get_formatted_sql(self):
        """Get the pretty-formatted SQL query if this span contains one."""
        sql = self.sql_query
        if not sql:
            return None

        return sqlparse.format(
            sql,
            reindent=True,
            keyword_case="upper",
            identifier_case="lower",
            strip_comments=False,
            strip_whitespace=True,
            indent_width=2,
            wrap_after=80,
            comma_first=False,
        )

    def format_event_timestamp(self, timestamp):
        """Convert a numeric event timestamp to an aware UTC datetime.

        Numeric values are interpreted by magnitude as nanoseconds,
        microseconds, milliseconds or seconds since the Unix epoch. Values
        that cannot be converted are returned as ``str(timestamp)``;
        non-numeric values (e.g. already formatted strings) are returned as-is.
        """
        if isinstance(timestamp, int | float):
            # Present-day epoch values are ~1.7e9 s, ~1.7e12 ms, ~1.7e15 us and
            # ~1.7e18 ns, so these cut-offs keep the units well separated.
            if timestamp > 1e17:
                seconds = timestamp / 1e9
            elif timestamp > 1e14:
                seconds = timestamp / 1e6
            elif timestamp > 1e11:
                seconds = timestamp / 1e3
            else:
                seconds = timestamp
            try:
                return datetime.fromtimestamp(seconds, tz=UTC)
            except (ValueError, OSError, OverflowError):
                return str(timestamp)
        return timestamp

    def get_exception_stacktrace(self):
        """Return the stacktrace of the first "exception" event that has attributes, else None."""
        if not self.events:
            return None

        for event in self.events:
            if event.get("name") == "exception" and event.get("attributes"):
                return event["attributes"].get(
                    exception_attributes.EXCEPTION_STACKTRACE
                )
        return None
354                )
355        return None