1from __future__ import annotations
  2
  3import re
  4import traceback
  5from collections.abc import Generator
  6from contextlib import contextmanager
  7from typing import TYPE_CHECKING, Any
  8
  9from opentelemetry import context as otel_context
 10from opentelemetry import trace
 11
 12if TYPE_CHECKING:
 13    from opentelemetry.trace import Span
 14
 15    from plain.postgres.connection import DatabaseConnection
 16
 17from opentelemetry.semconv._incubating.attributes.db_attributes import (
 18    DB_QUERY_PARAMETER_TEMPLATE,
 19    DB_USER,
 20)
 21from opentelemetry.semconv.attributes.code_attributes import (
 22    CODE_COLUMN_NUMBER,
 23    CODE_FILE_PATH,
 24    CODE_FUNCTION_NAME,
 25    CODE_LINE_NUMBER,
 26    CODE_STACKTRACE,
 27)
 28from opentelemetry.semconv.attributes.db_attributes import (
 29    DB_COLLECTION_NAME,
 30    DB_NAMESPACE,
 31    DB_OPERATION_NAME,
 32    DB_QUERY_SUMMARY,
 33    DB_QUERY_TEXT,
 34    DB_SYSTEM_NAME,
 35)
 36from opentelemetry.semconv.attributes.network_attributes import (
 37    NETWORK_PEER_ADDRESS,
 38    NETWORK_PEER_PORT,
 39)
 40from opentelemetry.semconv.trace import DbSystemValues
 41from opentelemetry.trace import SpanKind
 42
 43from plain.runtime import settings
 44
 45# Use a stable string key so OpenTelemetry context APIs receive the expected type.
 46_SUPPRESS_KEY = "plain.postgres.suppress_db_tracing"
 47
 48tracer = trace.get_tracer("plain.postgres")
 49
 50
 51DB_SYSTEM = DbSystemValues.POSTGRESQL.value
 52
 53
 54def extract_operation_and_target(sql: str) -> tuple[str, str | None, str | None]:
 55    """Extract operation, table name, and collection from SQL.
 56
 57    Returns: (operation, summary, collection_name)
 58    """
 59    sql_upper = sql.upper().strip()
 60
 61    # Strip leading parentheses (e.g. UNION queries: "(SELECT ... UNION ...)")
 62    operation = sql_upper.lstrip("(").split()[0] if sql_upper else "UNKNOWN"
 63
 64    # Pattern to match quoted and unquoted identifiers
 65    # Matches: "quoted" (PostgreSQL), unquoted.name
 66    identifier_pattern = r'("([^"]+)"|([\w.]+))'
 67
 68    # Map operations to the SQL keyword that precedes the table name.
 69    keyword_by_operation = {
 70        "SELECT": "FROM",
 71        "DELETE": "FROM",
 72        "INSERT": "INTO",
 73        "UPDATE": "UPDATE",
 74    }
 75
 76    # Extract table/collection name based on operation
 77    collection_name = None
 78    summary = operation
 79
 80    keyword = keyword_by_operation.get(operation)
 81    if keyword:
 82        match = re.search(rf"{keyword}\s+{identifier_pattern}", sql, re.IGNORECASE)
 83        if match:
 84            collection_name = _clean_identifier(match.group(1))
 85            summary = f"{operation} {collection_name}"
 86
 87    # Detect UNION queries
 88    if " UNION " in sql_upper and summary:
 89        summary = f"{summary} UNION"
 90
 91    return operation, summary, collection_name
 92
 93
 94def _clean_identifier(identifier: str) -> str:
 95    """Remove quotes from SQL identifiers."""
 96    if identifier.startswith('"') and identifier.endswith('"'):
 97        return identifier[1:-1]
 98    return identifier
 99
100
@contextmanager
def db_span(
    db: DatabaseConnection, sql: Any, *, many: bool = False, params: Any = None
) -> Generator[Span | None]:
    """Open an OpenTelemetry CLIENT span around one database query.

    Attributes follow the OpenTelemetry database semantic conventions:
    system name, namespace, query text, summary and operation name, the
    target collection when one is detected, the connection's user, peer
    address and port, plus code-location attributes for the calling frame.
    Optional attributes are only set when a value is available.

    Args:
        db: Connection whose ``settings_dict`` provides DATABASE, USER,
            HOST and PORT.
        sql: The parameterized query; converted with ``str()``.
        many: True for executemany-style calls; " many" is appended to the
            query summary and span name.
        params: Query parameters; recorded as attributes only when
            ``settings.DEBUG`` is true.

    Yields:
        The started span (its status is set to OK when the block exits
        normally), or ``None`` when ``suppress_db_tracing()`` is active.
    """
    # Fast-exit when suppress_db_tracing() has flagged the current context.
    if otel_context.get_value(_SUPPRESS_KEY):
        yield None
        return

    sql = str(sql)  # Ensure SQL is a string for span attributes.

    operation, summary, collection_name = extract_operation_and_target(sql)
    if many:
        summary = f"{summary} many"

    # Span name follows semantic conventions: {db.query.summary}, capped at
    # 255 characters; fall back to the bare operation name.
    span_name = summary[:255] if summary else operation

    settings_dict = db.settings_dict

    attrs: dict[str, Any] = {DB_SYSTEM_NAME: DB_SYSTEM}
    # Only set db.namespace when configured: OpenTelemetry drops None-valued
    # attributes with a warning instead of silently omitting them.
    if namespace := settings_dict.get("DATABASE"):
        attrs[DB_NAMESPACE] = namespace
    attrs[DB_QUERY_TEXT] = sql  # Already parameterized (placeholders only).
    attrs[DB_QUERY_SUMMARY] = summary
    attrs[DB_OPERATION_NAME] = operation

    attrs.update(_get_code_attributes())

    if collection_name:
        attrs[DB_COLLECTION_NAME] = collection_name

    if user := settings_dict.get("USER"):
        attrs[DB_USER] = user

    if host := settings_dict.get("HOST"):
        attrs[NETWORK_PEER_ADDRESS] = host

    if port := settings_dict.get("PORT"):
        try:
            attrs[NETWORK_PEER_PORT] = int(port)
        except (TypeError, ValueError):
            pass  # Best-effort: a non-numeric PORT is simply not reported.

    # Query parameter values are only recorded when DEBUG is on.
    if settings.DEBUG and params is not None:
        # NOTE(review): with many=True, params is presumably a sequence of
        # parameter sets, so each attribute would hold a whole row — confirm.
        attrs.update(_query_parameter_attributes(params))

    with tracer.start_as_current_span(
        span_name, kind=SpanKind.CLIENT, attributes=attrs
    ) as span:
        yield span
        span.set_status(trace.StatusCode.OK)


def _query_parameter_attributes(params: Any) -> dict[str, str]:
    """Map query parameters to ``f"{DB_QUERY_PARAMETER_TEMPLATE}.<key>"`` attributes.

    Dict params use their own keys; list/tuple params use 1-based positions;
    any other value is treated as a single parameter with key ``1``. Every
    value is converted with ``str()``.
    """
    # NOTE(review): the semantic conventions appear to call for 0-based
    # indexes for positional parameters; 1-based keys are kept so emitted
    # attribute names do not change. Confirm before switching.
    if isinstance(params, dict):
        items = params.items()
    elif isinstance(params, list | tuple):
        items = enumerate(params, start=1)
    else:
        items = [(1, params)]
    return {f"{DB_QUERY_PARAMETER_TEMPLATE}.{key}": str(value) for key, value in items}
179
180
@contextmanager
def suppress_db_tracing() -> Generator[None]:
    """Disable db_span() instrumentation for queries run inside this block.

    Attaches a context carrying the suppression flag and always detaches it
    on exit, including when the body raises.
    """
    flagged_context = otel_context.set_value(_SUPPRESS_KEY, True)
    attach_token = otel_context.attach(flagged_context)
    try:
        yield
    finally:
        otel_context.detach(attach_token)
188
189
def _is_internal_frame(frame: traceback.FrameSummary) -> bool:
    """Return True if *frame* should be skipped when looking for user code.

    A frame is internal when it has no filename, lives under a
    ``plain/postgres/`` directory, or comes from a file named exactly
    ``contextlib.py`` (the ``@contextmanager`` machinery wrapping db_span).
    Path separators are normalized, so Windows-style paths are recognized.
    """
    filepath = frame.filename
    if not filepath:
        return True
    # Frame filenames use backslashes on Windows; normalize before matching.
    normalized = filepath.replace("\\", "/")
    if "/plain/postgres/" in normalized:
        return True
    # Compare the basename exactly so e.g. "my_contextlib.py" is not skipped.
    return normalized.rsplit("/", 1)[-1] == "contextlib.py"
200
201
def _get_code_attributes() -> dict[str, Any]:
    """Build OpenTelemetry ``code.*`` attributes for the current query's caller.

    The caller is the innermost (most recent) stack frame that
    ``_is_internal_frame`` does not classify as internal. File path is always
    set; line number, function name and column are added only when truthy.
    In DEBUG mode the formatted stack of all non-internal frames is included.

    Returns:
        The attribute dict, or an empty dict if every frame is internal.
    """
    stack = traceback.extract_stack()

    # reversed(): extract_stack lists the oldest frame first, so walk from
    # the most recent call back toward the program entry point.
    caller = next(
        (entry for entry in reversed(stack) if not _is_internal_frame(entry)),
        None,
    )
    if caller is None:
        return {}

    attrs: dict[str, Any] = {CODE_FILE_PATH: caller.filename}
    optional_fields = (
        (CODE_LINE_NUMBER, caller.lineno),
        (CODE_FUNCTION_NAME, caller.name),
        (CODE_COLUMN_NUMBER, caller.colno),
    )
    for attr_key, attr_value in optional_fields:
        if attr_value:
            attrs[attr_key] = attr_value

    # The full stack trace is expensive to format, so it is DEBUG-only.
    if settings.DEBUG:
        user_frames = [entry for entry in stack if not _is_internal_frame(entry)]
        attrs[CODE_STACKTRACE] = "".join(traceback.format_list(user_frames))

    return attrs