1from __future__ import annotations
2
3import re
4import traceback
5from collections.abc import Generator
6from contextlib import contextmanager
7from typing import TYPE_CHECKING, Any
8
9from opentelemetry import context as otel_context
10from opentelemetry import trace
11
12if TYPE_CHECKING:
13 from opentelemetry.trace import Span
14
15 from plain.postgres.connection import DatabaseConnection
16
17from opentelemetry.semconv._incubating.attributes.db_attributes import (
18 DB_QUERY_PARAMETER_TEMPLATE,
19 DB_USER,
20)
21from opentelemetry.semconv.attributes.code_attributes import (
22 CODE_COLUMN_NUMBER,
23 CODE_FILE_PATH,
24 CODE_FUNCTION_NAME,
25 CODE_LINE_NUMBER,
26 CODE_STACKTRACE,
27)
28from opentelemetry.semconv.attributes.db_attributes import (
29 DB_COLLECTION_NAME,
30 DB_NAMESPACE,
31 DB_OPERATION_NAME,
32 DB_QUERY_SUMMARY,
33 DB_QUERY_TEXT,
34 DB_SYSTEM_NAME,
35)
36from opentelemetry.semconv.attributes.network_attributes import (
37 NETWORK_PEER_ADDRESS,
38 NETWORK_PEER_PORT,
39)
40from opentelemetry.semconv.trace import DbSystemValues
41from opentelemetry.trace import SpanKind
42
43from plain.runtime import settings
44
# Use a stable string key so OpenTelemetry context APIs receive the expected type.
# `suppress_db_tracing` stores True under this key in the context and `db_span`
# reads it to decide whether to skip creating a span.
_SUPPRESS_KEY = "plain.postgres.suppress_db_tracing"

# Module-level tracer that `db_span` uses to start query spans.
tracer = trace.get_tracer("plain.postgres")


# Value recorded under the DB_SYSTEM_NAME attribute on every query span.
DB_SYSTEM = DbSystemValues.POSTGRESQL.value
52
53
# Matches a double-quoted identifier (PostgreSQL style) or a bare, optionally
# dotted, name. The named groups capture the identifier without its quotes.
_IDENTIFIER_PATTERN = r'(?:"(?P<quoted>[^"]+)"|(?P<bare>[\w.]+))'

# For each operation we summarize: the SQL keyword that precedes the table
# name, followed by the table identifier. The leading \b keeps the keyword
# from matching the tail of another word (e.g. the "from" in "valid_from").
_TARGET_PATTERNS = {
    operation: re.compile(rf"\b{keyword}\s+{_IDENTIFIER_PATTERN}", re.IGNORECASE)
    for operation, keyword in (
        ("SELECT", "FROM"),
        ("DELETE", "FROM"),
        ("INSERT", "INTO"),
        ("UPDATE", "UPDATE"),
    )
}

# UNION surrounded by any whitespace, so multi-line SQL is detected too.
# Applied to the upper-cased SQL, hence no IGNORECASE flag.
_UNION_PATTERN = re.compile(r"\sUNION\s")


def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
    """Extract the operation, a short summary, and the target table from SQL.

    The operation is the first whitespace-delimited token of the statement,
    upper-cased, after skipping leading parentheses. "UNKNOWN" is used when
    there is no such token (empty SQL, or SQL made only of parentheses).

    For SELECT, DELETE, INSERT and UPDATE the first identifier following
    FROM / FROM / INTO / UPDATE respectively is reported as the collection
    name (surrounding double quotes removed) and appended to the summary.
    " UNION" is appended to the summary when the SQL contains a
    whitespace-delimited UNION keyword.

    Returns: (operation, summary, collection_name)
    """
    sql_upper = sql.upper().strip()

    # Strip leading parentheses (e.g. UNION queries: "(SELECT ... UNION ...)")
    # together with any whitespace between them. Splitting first and indexing
    # only when a token exists avoids an IndexError on input such as "(".
    tokens = sql_upper.lstrip("( \t\r\n\f\v").split()
    operation = tokens[0] if tokens else "UNKNOWN"

    collection_name = None
    summary = operation

    target_pattern = _TARGET_PATTERNS.get(operation)
    if target_pattern:
        # Search the original SQL (not the upper-cased copy) so the reported
        # identifier keeps its case.
        match = target_pattern.search(sql)
        if match:
            collection_name = match.group("quoted") or match.group("bare")
            summary = f"{operation} {collection_name}"

    # Detect UNION queries
    if _UNION_PATTERN.search(sql_upper):
        summary = f"{summary} UNION"

    return operation, summary, collection_name
92
93
94def _clean_identifier(identifier: str) -> str:
95 """Remove quotes from SQL identifiers."""
96 if identifier.startswith('"') and identifier.endswith('"'):
97 return identifier[1:-1]
98 return identifier
99
100
@contextmanager
def db_span(
    db: DatabaseConnection, sql: Any, *, many: bool = False, params: Any = None
) -> Generator[Span | None]:
    """Open an OpenTelemetry CLIENT span for a database query.

    All common attributes (`db.*`, `network.*`, etc.) are set automatically.
    Follows OpenTelemetry semantic conventions for database instrumentation.

    Args:
        db: Connection whose `settings_dict` supplies the optional DATABASE,
            USER, HOST and PORT values recorded on the span.
        sql: The statement being executed; converted with `str()`.
        many: When true, " many" is appended to the query summary.
        params: Query parameters. Recorded as span attributes only when
            `settings.DEBUG` is true.

    Yields:
        The started span, or None when tracing is suppressed through
        `suppress_db_tracing`.
    """

    # Fast-exit if instrumentation suppression flag set in context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    # Extract operation and target information
    operation, summary, collection_name = extract_operation_and_target(sql)

    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: {target} or {db.operation.name} {target}
    span_name = summary[:255] if summary else operation

    # Build attribute set following semantic conventions. The SQL text is
    # recorded as received; parameter values are only attached separately
    # below, and only in DEBUG.
    attrs: dict[str, Any] = {
        DB_SYSTEM_NAME: DB_SYSTEM,
        DB_QUERY_TEXT: sql,
        DB_QUERY_SUMMARY: summary,
        DB_OPERATION_NAME: operation,
    }

    # None is not a valid span attribute value, so the namespace is only
    # recorded when the connection settings actually provide one.
    namespace = db.settings_dict.get("DATABASE")
    if namespace is not None:
        attrs[DB_NAMESPACE] = namespace

    attrs.update(_get_code_attributes())

    # Add collection name if detected
    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    # Add user attribute
    if user := db.settings_dict.get("USER"):
        attrs[DB_USER] = user

    # Network attributes
    if host := db.settings_dict.get("HOST"):
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := db.settings_dict.get("PORT"):
        try:
            attrs[NETWORK_PEER_PORT] = int(port)
        except (TypeError, ValueError):
            # Best effort: a port that is not numeric is simply not recorded.
            pass

    # Add query parameters as attributes when DEBUG is True
    if settings.DEBUG and params is not None:
        # Convert params to appropriate format based on type
        if isinstance(params, dict):
            # Dictionary params (e.g., for named placeholders)
            for key, value in params.items():
                attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}"] = str(value)
        elif isinstance(params, list | tuple):
            # Sequential params (e.g., for %s or ? placeholders), numbered from 1
            for i, value in enumerate(params):
                attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.{i + 1}"] = str(value)
        else:
            # Single param (rare but possible)
            attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.1"] = str(params)

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        yield span
        # Only reached when the caller's block finished without raising.
        span.set_status(trace.StatusCode.OK)
179
180
@contextmanager
def suppress_db_tracing() -> Generator[None]:
    """Turn off `db_span` tracing for the duration of the `with` block.

    Attaches an OpenTelemetry context in which the suppression key is set to
    True, and detaches it again on exit, even if the block raises.
    """
    suppressed_context = otel_context.set_value(_SUPPRESS_KEY, True)
    token = otel_context.attach(suppressed_context)
    try:
        yield
    finally:
        otel_context.detach(token)
188
189
190def _is_internal_frame(frame: traceback.FrameSummary) -> bool:
191 """Return True if the frame is internal to plain.postgres or contextlib."""
192 filepath = frame.filename
193 if not filepath:
194 return True
195 if "/plain/postgres/" in filepath:
196 return True
197 if filepath.endswith("contextlib.py"):
198 return True
199 return False
200
201
def _get_code_attributes() -> dict[str, Any]:
    """Describe the calling code responsible for the current database query.

    Captures the current call stack and selects the most recent frame that
    `_is_internal_frame` does not classify as internal.

    Returns a dict of OpenTelemetry code attributes for that frame, or an
    empty dict when every frame on the stack is internal.
    """
    stack = traceback.extract_stack()

    # The stack lists the oldest call first, so scanning it backwards finds
    # the non-internal caller closest to the query.
    caller = next(
        (entry for entry in reversed(stack) if not _is_internal_frame(entry)),
        None,
    )
    if caller is None:
        return {}

    code_attrs: dict[str, Any] = {CODE_FILE_PATH: caller.filename}

    # Optional details are recorded only when they carry a truthy value.
    optional_details = (
        (CODE_LINE_NUMBER, caller.lineno),
        (CODE_FUNCTION_NAME, caller.name),
        (CODE_COLUMN_NUMBER, caller.colno),
    )
    for attribute, value in optional_details:
        if value:
            code_attrs[attribute] = value

    # The full stack trace is limited to DEBUG mode because it is expensive.
    if settings.DEBUG:
        user_frames = [entry for entry in stack if not _is_internal_frame(entry)]
        code_attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return code_attrs