1from __future__ import annotations
2
3import re
4import traceback
5from collections.abc import Generator
6from contextlib import contextmanager
7from typing import TYPE_CHECKING, Any
8
9from opentelemetry import context as otel_context
10from opentelemetry import trace
11
12if TYPE_CHECKING:
13 from opentelemetry.trace import Span
14
15 from plain.models.backends.base.base import BaseDatabaseWrapper
16from opentelemetry.semconv._incubating.attributes.db_attributes import (
17 DB_QUERY_PARAMETER_TEMPLATE,
18 DB_USER,
19)
20from opentelemetry.semconv.attributes.code_attributes import (
21 CODE_COLUMN_NUMBER,
22 CODE_FILE_PATH,
23 CODE_FUNCTION_NAME,
24 CODE_LINE_NUMBER,
25 CODE_STACKTRACE,
26)
27from opentelemetry.semconv.attributes.db_attributes import (
28 DB_COLLECTION_NAME,
29 DB_NAMESPACE,
30 DB_OPERATION_NAME,
31 DB_QUERY_SUMMARY,
32 DB_QUERY_TEXT,
33 DB_SYSTEM_NAME,
34)
35from opentelemetry.semconv.attributes.network_attributes import (
36 NETWORK_PEER_ADDRESS,
37 NETWORK_PEER_PORT,
38)
39from opentelemetry.semconv.trace import DbSystemValues
40from opentelemetry.trace import SpanKind
41
42from plain.runtime import settings
43
# Context key under which the "suppress database tracing" flag is stored; it is
# written by ``suppress_db_tracing`` and read by ``db_span``.
# Use a stable string key so OpenTelemetry context APIs receive the expected type.
_SUPPRESS_KEY = "plain.models.suppress_db_tracing"

# Module-level tracer that ``db_span`` uses to open its spans.
tracer = trace.get_tracer("plain.models")
48
49
def db_system_for(vendor: str) -> str:  # noqa: D401 – simple helper
    """Map a backend vendor string to its ``db.system.name`` value.

    Vendors listed in the table below are translated to the matching
    ``DbSystemValues`` member value; any other vendor string is returned
    unchanged.
    """
    known_systems = {
        "postgresql": DbSystemValues.POSTGRESQL.value,
        "mysql": DbSystemValues.MYSQL.value,
        "mariadb": DbSystemValues.MARIADB.value,
        "sqlite": DbSystemValues.SQLITE.value,
    }

    if vendor in known_systems:
        return known_systems[vendor]

    # Unknown backend: report the vendor name as-is.
    return vendor
59
60
# Matches a quoted ("name", `name`, [name]) or unquoted (schema.table) SQL
# identifier. Exactly one of the four capture groups participates in a match
# and holds the identifier without its surrounding quotes.
_IDENTIFIER_PATTERN = r'(?:"([^"]+)"|`([^`]+)`|\[([^\]]+)\]|([\w.]+))'

# First whitespace-delimited token of the statement, skipping any leading
# whitespace and opening parentheses (e.g. "(SELECT ... UNION ...)").
_OPERATION_RE = re.compile(r"[\s(]*([^\s(]\S*)")

# For each operation whose target is reported: a pattern anchored on the
# keyword that directly precedes the table name. The leading ``\b`` keeps the
# keyword from matching inside a longer identifier such as ``valid_from``.
_TARGET_RES = {
    operation: re.compile(rf"\b{keyword}\s+{_IDENTIFIER_PATTERN}", re.IGNORECASE)
    for operation, keyword in (
        ("SELECT", "FROM"),
        ("DELETE", "FROM"),
        ("INSERT", "INTO"),
        ("REPLACE", "INTO"),
        ("UPDATE", "UPDATE"),
    )
}

# UNION as a standalone keyword: delimited by any whitespace (including
# newlines) or by the parentheses of the sub-selects it joins.
_UNION_RE = re.compile(r"[\s)]UNION[\s(]", re.IGNORECASE)


def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
    """Extract the operation, a short summary, and the target table from SQL.

    This is a regex heuristic, not a SQL parser: the operation is the first
    token of the statement (upper-cased), and the target is the first
    identifier following ``FROM`` (SELECT/DELETE), ``INTO`` (INSERT/REPLACE)
    or ``UPDATE`` (UPDATE). Quotes around the identifier are removed.

    Args:
        sql: The SQL statement text.

    Returns:
        ``(operation, summary, collection_name)`` where ``operation`` is
        ``"UNKNOWN"`` if the statement has no token (empty, whitespace or
        only opening parentheses), ``summary`` is ``"<operation>"`` or
        ``"<operation> <table>"`` with `` UNION`` appended for UNION queries,
        and ``collection_name`` is the table name or ``None`` if none was found.
    """
    first_token = _OPERATION_RE.match(sql)
    operation = first_token.group(1).upper() if first_token else "UNKNOWN"

    summary = operation
    collection_name = None

    target_re = _TARGET_RES.get(operation)
    target = target_re.search(sql) if target_re is not None else None
    if target:
        # Only the alternative that matched has a non-None group.
        collection_name = next(g for g in target.groups() if g is not None)
        summary = f"{operation} {collection_name}"

    if _UNION_RE.search(sql):
        summary = f"{summary} UNION"

    return operation, summary, collection_name
102
103
def _clean_identifier(identifier: str) -> str:
    """Strip one layer of SQL quoting from an identifier.

    Recognises double quotes, backticks and square brackets. An identifier
    that is not wrapped in a matching pair is returned unchanged.
    """
    quote_pairs = (('"', '"'), ("`", "`"), ("[", "]"))

    for opening, closing in quote_pairs:
        if identifier.startswith(opening) and identifier.endswith(closing):
            # Drop the first and last character, i.e. the quotes themselves.
            return identifier[1:-1]

    return identifier
114
115
def _query_parameter_attributes(params: Any) -> dict[str, str]:
    """Flatten query parameters into span attributes.

    Each attribute key is ``DB_QUERY_PARAMETER_TEMPLATE`` followed by ``.`` and
    the parameter's name (dict params) or 1-based position (list/tuple
    params); any other value is recorded as parameter ``1``. Values are
    stringified with ``str()``.
    """
    if isinstance(params, dict):
        # Named placeholders.
        return {
            f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}": str(value)
            for key, value in params.items()
        }
    if isinstance(params, list | tuple):
        # Positional placeholders (%s or ?).
        return {
            f"{DB_QUERY_PARAMETER_TEMPLATE}.{position}": str(value)
            for position, value in enumerate(params, start=1)
        }
    # Single param (rare but possible).
    return {f"{DB_QUERY_PARAMETER_TEMPLATE}.1": str(params)}


@contextmanager
def db_span(
    db: BaseDatabaseWrapper, sql: Any, *, many: bool = False, params: Any = None
) -> Generator[Span | None, None, None]:
    """Open an OpenTelemetry CLIENT span for a database query.

    All common attributes (`db.*`, `network.*`, `code.*`) are set
    automatically, following the OpenTelemetry semantic conventions for
    database instrumentation.

    Args:
        db: Connection wrapper; ``db.vendor`` and ``db.settings_dict``
            (``NAME``, ``USER``, ``HOST``, ``PORT``) feed the span attributes.
        sql: The query; converted with ``str()`` before use.
        many: When true, `` many`` is appended to the query summary.
        params: Query parameters. Recorded as span attributes only when
            ``settings.DEBUG`` is true and ``params`` is not ``None``.

    Yields:
        The started span, or ``None`` when tracing is suppressed through
        ``suppress_db_tracing``. The span status is set to OK once the body
        of the ``with`` block finishes without raising.
    """

    # Fast-exit if instrumentation suppression flag set in context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)
    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: {target} or {db.operation.name} {target}
    span_name = summary[:255] if summary else operation

    connection_settings = db.settings_dict

    # Build attribute set following semantic conventions
    attrs: dict[str, Any] = {
        DB_SYSTEM_NAME: db_system_for(db.vendor),
        DB_QUERY_TEXT: sql,  # Already parameterized from Django/Plain
        DB_QUERY_SUMMARY: summary,
        DB_OPERATION_NAME: operation,
    }

    # Span attribute values must be primitives: skip a missing NAME instead of
    # sending None, and stringify anything else (e.g. a path object).
    if (namespace := connection_settings.get("NAME")) is not None:
        attrs[DB_NAMESPACE] = str(namespace)

    attrs.update(_get_code_attributes())

    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    if user := connection_settings.get("USER"):
        attrs[DB_USER] = user

    # Network attributes
    if host := connection_settings.get("HOST"):
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := connection_settings.get("PORT"):
        try:
            attrs[NETWORK_PEER_PORT] = int(port)
        except (TypeError, ValueError):
            # Best effort: a non-numeric port is simply not reported.
            pass

    # Query parameters may contain sensitive data, so only record them in DEBUG.
    if settings.DEBUG and params is not None:
        attrs.update(_query_parameter_attributes(params))

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        yield span
        # Only reached when the caller's block did not raise.
        span.set_status(trace.StatusCode.OK)
194
195
@contextmanager
def suppress_db_tracing() -> Generator[None, None, None]:
    """Turn off database span creation inside the ``with`` block.

    Attaches an OpenTelemetry context in which ``_SUPPRESS_KEY`` is set to
    ``True`` (the flag ``db_span`` checks before opening a span) and detaches
    it again on exit, including when the block raises.
    """
    suppressed_context = otel_context.set_value(_SUPPRESS_KEY, True)
    restore_token = otel_context.attach(suppressed_context)
    try:
        yield
    finally:
        otel_context.detach(restore_token)
203
204
def _is_internal_frame(filename: str) -> bool:
    """Return True for frames that must not be attributed to user code."""
    if not filename:
        return True
    # Normalise Windows separators so the package check works on every platform.
    normalized = filename.replace("\\", "/")
    return "/plain/models/" in normalized or normalized.endswith("contextlib.py")


def _get_code_attributes() -> dict[str, Any]:
    """Extract code context attributes for the current database query.

    Walks the current call stack from the innermost frame outwards and
    describes the first frame that is neither inside ``/plain/models/`` nor in
    ``contextlib.py``.

    Returns:
        A dict of OpenTelemetry ``code.*`` attributes (file path, line number,
        function name and column number, each only when available). When
        ``settings.DEBUG`` is true it also contains the formatted stack trace
        with internal frames removed. Empty if every frame is internal.
    """
    stack = traceback.extract_stack()

    caller = next(
        (frame for frame in reversed(stack) if not _is_internal_frame(frame.filename)),
        None,
    )
    if caller is None:
        return {}

    attrs: dict[str, Any] = {CODE_FILE_PATH: caller.filename}

    if caller.lineno:
        attrs[CODE_LINE_NUMBER] = caller.lineno
    if caller.name:
        attrs[CODE_FUNCTION_NAME] = caller.name

    # Column 0 is a valid position, so test against None rather than truthiness.
    # getattr keeps this working on interpreters whose FrameSummary has no colno.
    colno = getattr(caller, "colno", None)
    if colno is not None:
        attrs[CODE_COLUMN_NUMBER] = colno

    # Add full stack trace only in DEBUG mode (expensive)
    if settings.DEBUG:
        user_frames = [
            frame for frame in stack if not _is_internal_frame(frame.filename)
        ]
        attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return attrs