1import re
2from contextlib import contextmanager
3from typing import Any
4
5from opentelemetry import context as otel_context
6from opentelemetry import trace
7from opentelemetry.semconv._incubating.attributes.db_attributes import (
8 DB_QUERY_PARAMETER_TEMPLATE,
9 DB_USER,
10)
11from opentelemetry.semconv.attributes.db_attributes import (
12 DB_COLLECTION_NAME,
13 DB_NAMESPACE,
14 DB_OPERATION_NAME,
15 DB_QUERY_SUMMARY,
16 DB_QUERY_TEXT,
17 DB_SYSTEM_NAME,
18)
19from opentelemetry.semconv.attributes.network_attributes import (
20 NETWORK_PEER_ADDRESS,
21 NETWORK_PEER_PORT,
22)
23from opentelemetry.semconv.trace import DbSystemValues
24from opentelemetry.trace import SpanKind
25
26from plain.runtime import settings
27
# Context key under which the "suppress database tracing" flag is stored.
# ``suppress_db_tracing`` sets it to True and ``db_span`` checks it before
# opening a span. A fresh ``object()`` is unique by identity.
_SUPPRESS_KEY = object()

# Tracer used by ``db_span`` to start the database CLIENT spans.
tracer = trace.get_tracer("plain.models")
31
32
def db_system_for(vendor: str) -> str:  # noqa: D401 – simple helper
    """Map a backend vendor string to its ``db.system.name`` value.

    Vendors without an entry in the table below are returned unchanged.
    """
    known_systems = {
        "postgresql": DbSystemValues.POSTGRESQL.value,
        "mysql": DbSystemValues.MYSQL.value,
        "mariadb": DbSystemValues.MARIADB.value,
        "sqlite": DbSystemValues.SQLITE.value,
    }

    if vendor in known_systems:
        return known_systems[vendor]
    return vendor
42
43
44def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
45 """Extract operation, table name, and collection from SQL.
46
47 Returns: (operation, summary, collection_name)
48 """
49 sql_upper = sql.upper().strip()
50 operation = sql_upper.split()[0] if sql_upper else "UNKNOWN"
51
52 # Pattern to match quoted and unquoted identifiers
53 # Matches: "quoted", `quoted`, [quoted], unquoted.name
54 identifier_pattern = r'("([^"]+)"|`([^`]+)`|\[([^\]]+)\]|([\w.]+))'
55
56 # Extract table/collection name based on operation
57 collection_name = None
58 summary = operation
59
60 if operation in ("SELECT", "DELETE"):
61 match = re.search(rf"FROM\s+{identifier_pattern}", sql, re.IGNORECASE)
62 if match:
63 collection_name = _clean_identifier(match.group(1))
64 summary = f"{operation} {collection_name}"
65
66 elif operation in ("INSERT", "REPLACE"):
67 match = re.search(rf"INTO\s+{identifier_pattern}", sql, re.IGNORECASE)
68 if match:
69 collection_name = _clean_identifier(match.group(1))
70 summary = f"{operation} {collection_name}"
71
72 elif operation == "UPDATE":
73 match = re.search(rf"UPDATE\s+{identifier_pattern}", sql, re.IGNORECASE)
74 if match:
75 collection_name = _clean_identifier(match.group(1))
76 summary = f"{operation} {collection_name}"
77
78 return operation, summary, collection_name
79
80
def _clean_identifier(identifier: str) -> str:
    """Strip one layer of SQL quoting from an identifier.

    Handles double quotes, backticks and square brackets. An identifier that
    is not wrapped in a matching pair is returned as-is.
    """
    quote_pairs = (('"', '"'), ("`", "`"), ("[", "]"))
    for opening, closing in quote_pairs:
        if identifier.startswith(opening) and identifier.endswith(closing):
            # Drop the first and last character, i.e. the quote marks.
            return identifier[1:-1]
    return identifier
91
92
def _query_parameter_attributes(params) -> dict[str, str]:
    """Build the per-parameter span attributes for a query's ``params``."""
    if isinstance(params, dict):
        # Named placeholders: keyed by parameter name.
        keyed_values = params.items()
    elif isinstance(params, list | tuple):
        # Sequential placeholders (%s or ?): keyed by 1-based position.
        # NOTE(review): the OTel database conventions appear to recommend a
        # 0-based index; left 1-based so emitted keys do not change — confirm.
        keyed_values = enumerate(params, start=1)
    else:
        # Single bare param (rare but possible).
        keyed_values = [(1, params)]

    return {
        f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}": str(value)
        for key, value in keyed_values
    }


@contextmanager
def db_span(db, sql: Any, *, many: bool = False, params=None):
    """Open an OpenTelemetry CLIENT span for a database query.

    Does nothing (yields ``None``) while the suppression flag set by
    ``suppress_db_tracing`` is present in the current context.

    Args:
        db: Connection-like object; its ``vendor`` attribute and
            ``settings_dict`` mapping (``NAME``, ``USER``, ``HOST``,
            ``PORT``) are read to populate the span attributes.
        sql: The query; converted with ``str()`` before use.
        many: When true, ``" many"`` is appended to the query summary and
            therefore to the span name.
        params: Query parameters. Recorded as span attributes only when
            ``settings.DEBUG`` is true and ``params`` is not ``None``.

    Yields:
        The started span, or ``None`` when tracing is suppressed. The span
        status is set to OK if the ``with`` body finishes without raising.
    """

    # Fast-exit if instrumentation suppression flag set in context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)
    if many:
        summary = f"{summary} many"

    # The span name is the summary, capped at 255 characters.
    span_name = summary[:255] if summary else operation

    settings_dict = db.settings_dict

    attrs: dict[str, Any] = {
        DB_SYSTEM_NAME: db_system_for(db.vendor),
        DB_QUERY_TEXT: sql,
        DB_QUERY_SUMMARY: summary,
        DB_OPERATION_NAME: operation,
    }

    # None is not a valid attribute value, so leave the namespace out
    # entirely when no database name is configured.
    if (namespace := settings_dict.get("NAME")) is not None:
        attrs[DB_NAMESPACE] = namespace

    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    if user := settings_dict.get("USER"):
        attrs[DB_USER] = user

    # Network attributes
    if host := settings_dict.get("HOST"):
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := settings_dict.get("PORT"):
        try:
            attrs[NETWORK_PEER_PORT] = int(port)
        except (TypeError, ValueError):
            # Best effort: a port that is not an integer is simply omitted.
            pass

    # Parameter values are only recorded when DEBUG is on.
    if settings.DEBUG and params is not None:
        attrs.update(_query_parameter_attributes(params))

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        yield span
        # Only reached when the caller's block did not raise.
        span.set_status(trace.StatusCode.OK)
167
168
@contextmanager
def suppress_db_tracing():
    """Disable ``db_span`` instrumentation for the duration of the block.

    Attaches a context carrying the suppression flag and detaches it again
    on exit, including when the block raises.
    """
    suppressed_context = otel_context.set_value(_SUPPRESS_KEY, True)
    attach_token = otel_context.attach(suppressed_context)
    try:
        yield
    finally:
        otel_context.detach(attach_token)