1from __future__ import annotations
  2
  3import re
  4import traceback
  5from collections.abc import Generator
  6from contextlib import contextmanager
  7from typing import TYPE_CHECKING, Any
  8
  9from opentelemetry import context as otel_context
 10from opentelemetry import trace
 11
 12if TYPE_CHECKING:
 13    from opentelemetry.trace import Span
 14
 15    from plain.models.backends.base.base import BaseDatabaseWrapper
 16from opentelemetry.semconv._incubating.attributes.db_attributes import (
 17    DB_QUERY_PARAMETER_TEMPLATE,
 18    DB_USER,
 19)
 20from opentelemetry.semconv.attributes.code_attributes import (
 21    CODE_COLUMN_NUMBER,
 22    CODE_FILE_PATH,
 23    CODE_FUNCTION_NAME,
 24    CODE_LINE_NUMBER,
 25    CODE_STACKTRACE,
 26)
 27from opentelemetry.semconv.attributes.db_attributes import (
 28    DB_COLLECTION_NAME,
 29    DB_NAMESPACE,
 30    DB_OPERATION_NAME,
 31    DB_QUERY_SUMMARY,
 32    DB_QUERY_TEXT,
 33    DB_SYSTEM_NAME,
 34)
 35from opentelemetry.semconv.attributes.network_attributes import (
 36    NETWORK_PEER_ADDRESS,
 37    NETWORK_PEER_PORT,
 38)
 39from opentelemetry.semconv.trace import DbSystemValues
 40from opentelemetry.trace import SpanKind
 41
 42from plain.runtime import settings
 43
 44# Use a stable string key so OpenTelemetry context APIs receive the expected type.
 45_SUPPRESS_KEY = "plain.models.suppress_db_tracing"
 46
 47tracer = trace.get_tracer("plain.models")
 48
 49
 50def db_system_for(vendor: str) -> str:  # noqa: D401 – simple helper
 51    """Return the canonical ``db.system.name`` value for a backend vendor."""
 52
 53    return {
 54        "postgresql": DbSystemValues.POSTGRESQL.value,
 55        "mysql": DbSystemValues.MYSQL.value,
 56        "mariadb": DbSystemValues.MARIADB.value,
 57        "sqlite": DbSystemValues.SQLITE.value,
 58    }.get(vendor, vendor)
 59
 60
 61def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
 62    """Extract operation, table name, and collection from SQL.
 63
 64    Returns: (operation, summary, collection_name)
 65    """
 66    sql_upper = sql.upper().strip()
 67
 68    # Strip leading parentheses (e.g. UNION queries: "(SELECT ... UNION ...)")
 69    operation = sql_upper.lstrip("(").split()[0] if sql_upper else "UNKNOWN"
 70
 71    # Pattern to match quoted and unquoted identifiers
 72    # Matches: "quoted", `quoted`, [quoted], unquoted.name
 73    identifier_pattern = r'("([^"]+)"|`([^`]+)`|\[([^\]]+)\]|([\w.]+))'
 74
 75    # Extract table/collection name based on operation
 76    collection_name = None
 77    summary = operation
 78
 79    if operation in ("SELECT", "DELETE"):
 80        match = re.search(rf"FROM\s+{identifier_pattern}", sql, re.IGNORECASE)
 81        if match:
 82            collection_name = _clean_identifier(match.group(1))
 83            summary = f"{operation} {collection_name}"
 84
 85    elif operation in ("INSERT", "REPLACE"):
 86        match = re.search(rf"INTO\s+{identifier_pattern}", sql, re.IGNORECASE)
 87        if match:
 88            collection_name = _clean_identifier(match.group(1))
 89            summary = f"{operation} {collection_name}"
 90
 91    elif operation == "UPDATE":
 92        match = re.search(rf"UPDATE\s+{identifier_pattern}", sql, re.IGNORECASE)
 93        if match:
 94            collection_name = _clean_identifier(match.group(1))
 95            summary = f"{operation} {collection_name}"
 96
 97    # Detect UNION queries
 98    if " UNION " in sql_upper and summary:
 99        summary = f"{summary} UNION"
100
101    return operation, summary, collection_name
102
103
def _clean_identifier(identifier: str) -> str:
    """Strip one enclosing pair of SQL quoting characters from ``identifier``.

    Double quotes, backticks and square brackets are recognised, checked in
    that order. An identifier not wrapped in a matching pair is returned
    unchanged.
    """
    quote_pairs = (('"', '"'), ("`", "`"), ("[", "]"))
    for opening, closing in quote_pairs:
        if identifier.startswith(opening) and identifier.endswith(closing):
            return identifier[1:-1]
    return identifier
114
115
@contextmanager
def db_span(
    db: BaseDatabaseWrapper, sql: Any, *, many: bool = False, params: Any = None
) -> Generator[Span | None, None, None]:
    """Open an OpenTelemetry CLIENT span for a database query.

    Common attributes (``db.*``, ``network.*`` and the calling code location)
    are set automatically, following the OpenTelemetry semantic conventions
    for database instrumentation.

    Args:
        db: Connection wrapper. Its ``vendor`` and ``settings_dict`` (``NAME``,
            ``USER``, ``HOST``, ``PORT``) supply the system, namespace, user
            and peer attributes. Missing or empty settings are omitted.
        sql: The query; converted with ``str()`` before use.
        many: When true, ``" many"`` is appended to the query summary
            (presumably for executemany-style batches).
        params: Query parameters. They are recorded as attributes only when
            ``settings.DEBUG`` is enabled.

    Yields:
        The active span, or ``None`` when tracing is suppressed in the current
        context (see ``suppress_db_tracing``). The span status is set to OK
        only if the caller's ``with`` body finishes without raising.
    """

    # Fast exit when suppress_db_tracing() is active in the current context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)

    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: {target} or {db.operation.name} {target}
    span_name = summary[:255] if summary else operation

    settings_dict = db.settings_dict

    # Build the attribute set. Optional values are only added when present:
    # OpenTelemetry attribute values must be str/bool/int/float (or sequences
    # of them). The SDK drops a None or Path value and warns on every span.
    attrs: dict[str, Any] = {DB_SYSTEM_NAME: db_system_for(db.vendor)}
    if namespace := settings_dict.get("NAME"):
        # NAME may be a non-str (e.g. a filesystem path); coerce it.
        attrs[DB_NAMESPACE] = str(namespace)
    attrs[DB_QUERY_TEXT] = sql  # Already parameterized from Django/Plain
    attrs[DB_QUERY_SUMMARY] = summary
    attrs[DB_OPERATION_NAME] = operation

    attrs.update(_get_code_attributes())

    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    if user := settings_dict.get("USER"):
        attrs[DB_USER] = user

    if host := settings_dict.get("HOST"):
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := settings_dict.get("PORT"):
        try:
            attrs[NETWORK_PEER_PORT] = int(port)
        except (TypeError, ValueError):
            # Best effort: a port that is not an integer is simply not recorded.
            pass

    # Parameter values may be sensitive, so they are only attached in DEBUG.
    if settings.DEBUG and params is not None:
        if isinstance(params, dict):
            # Named placeholders: key each attribute by parameter name.
            for key, value in params.items():
                attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}"] = str(value)
        elif isinstance(params, list | tuple):
            # Positional placeholders (%s / ?): 1-based position.
            for position, value in enumerate(params, start=1):
                attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.{position}"] = str(value)
        else:
            # A single scalar parameter.
            attrs[f"{DB_QUERY_PARAMETER_TEMPLATE}.1"] = str(params)

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        yield span
        # Only reached when the caller's block completed without raising.
        span.set_status(trace.StatusCode.OK)
194
195
@contextmanager
def suppress_db_tracing() -> Generator[None, None, None]:
    """Disable database spans for code run inside this ``with`` block.

    Attaches an OpenTelemetry context with ``_SUPPRESS_KEY`` set to True; while
    it is attached, ``db_span`` yields ``None`` instead of opening a span. The
    previous context is restored on exit, even if the block raises.
    """
    suppressed_context = otel_context.set_value(_SUPPRESS_KEY, True)
    attach_token = otel_context.attach(suppressed_context)
    try:
        yield
    finally:
        otel_context.detach(attach_token)
203
204
def _get_code_attributes() -> dict[str, Any]:
    """Return code-location attributes for the code that issued the query.

    Searches the current stack from the innermost frame outwards for the first
    frame that is not internal (see ``_is_internal_frame``). From that frame it
    reports the file path, plus line number, function name and column number
    when available. When ``settings.DEBUG`` is enabled it also records the
    full stack, with internal frames removed, under ``CODE_STACKTRACE``.

    Returns an empty dict when every frame on the stack is internal.
    """
    stack = traceback.extract_stack()

    # extract_stack() lists the innermost frame last, so search from the end.
    caller = next(
        (frame for frame in reversed(stack) if not _is_internal_frame(frame.filename)),
        None,
    )
    if caller is None:
        return {}

    attrs: dict[str, Any] = {CODE_FILE_PATH: caller.filename}
    if caller.lineno:
        attrs[CODE_LINE_NUMBER] = caller.lineno
    if caller.name:
        attrs[CODE_FUNCTION_NAME] = caller.name
    if caller.colno:
        attrs[CODE_COLUMN_NUMBER] = caller.colno

    # Formatting the whole stack is expensive, so only do it in DEBUG mode.
    if settings.DEBUG:
        user_frames = [f for f in stack if not _is_internal_frame(f.filename)]
        attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return attrs


def _is_internal_frame(filepath: str | None) -> bool:
    """Return True for frames that must not be reported as the query's caller."""
    if not filepath:
        return True
    # Normalise Windows separators so the package check works on every OS.
    normalized = filepath.replace("\\", "/")
    return "/plain/models/" in normalized or normalized.endswith("contextlib.py")