1from __future__ import annotations
2
3import importlib.metadata
4from collections.abc import Callable, Iterable
5from functools import wraps
6from typing import TYPE_CHECKING, Any, ClassVar
7
8from opentelemetry import metrics, trace
9from opentelemetry.metrics import CallbackOptions, Observation
10from opentelemetry.semconv._incubating.attributes.code_attributes import (
11 CODE_FUNCTION_NAME,
12)
13from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
14 MESSAGING_DESTINATION_NAME,
15 MESSAGING_OPERATION_TYPE,
16 MESSAGING_SYSTEM,
17 MessagingOperationTypeValues,
18)
19from opentelemetry.semconv._incubating.metrics.messaging_metrics import (
20 create_messaging_client_consumed_messages,
21 create_messaging_client_operation_duration,
22 create_messaging_client_sent_messages,
23)
24from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
25
26from plain.postgres import Q
27from plain.postgres.aggregates import Count, Min
28from plain.postgres.db import return_database_connection
29from plain.utils import timezone
30from plain.utils.otel import format_exception_type
31
32if TYPE_CHECKING:
33 from .models import JobResult
34 from .workers import Worker
35
# Attribute key for the terminal-status dimension on the consumed counter.
# record_consumed() sets it to the lower-cased JobResult status.
PLAIN_JOBS_OUTCOME = "plain.jobs.outcome"

# Attribute key for the worker-liveness dimension on plain.jobs.workers.
# WorkerMetrics._gauge_workers() emits it with the values "active" and "stale".
PLAIN_JOBS_WORKER_STATE = "plain.jobs.worker.state"

# Instrumentation-scope version for the tracer and meter below: the version of
# the installed "plain.jobs" distribution, or "dev" when importlib cannot find
# that distribution's metadata (PackageNotFoundError).
try:
    _package_version = importlib.metadata.version("plain.jobs")
except importlib.metadata.PackageNotFoundError:
    _package_version = "dev"

# Module-level tracer and meter, both under the "plain.jobs" scope name.
tracer = trace.get_tracer("plain.jobs", _package_version)
meter = metrics.get_meter("plain.jobs", version=_package_version)

# Per-event instruments (semconv messaging metrics + plain.jobs queue.wait.duration).
# Created once, at import time, on the module-level meter. record_consumed()
# below records into consumed_messages_counter; the others are presumably
# recorded by callers in other modules (not visible here).
sent_messages_counter = create_messaging_client_sent_messages(meter)
consumed_messages_counter = create_messaging_client_consumed_messages(meter)
operation_duration_histogram = create_messaging_client_operation_duration(meter)
queue_wait_duration_histogram = meter.create_histogram(
    name="plain.jobs.queue.wait.duration",
    unit="s",
    description="Time a job spent waiting in the queue before a worker picked it up.",
)
59
60
def record_span_error(
    span: trace.Span,
    exc: BaseException,
    metric_attributes: dict[str, Any],
) -> str:
    """Flag *span* as failed because of *exc* and propagate its error type.

    The exception is recorded on the span, the span status becomes ERROR,
    and the formatted exception type is written both as the span's
    `error.type` attribute and into *metric_attributes* (mutated in place).

    Returns the formatted error type so callers can reuse it for other
    instruments.
    """
    formatted_type = format_exception_type(exc)

    span.record_exception(exc)
    span.set_status(trace.StatusCode.ERROR)
    span.set_attribute(ERROR_TYPE, formatted_type)

    metric_attributes.update({ERROR_TYPE: formatted_type})
    return formatted_type
75
76
def process_metric_attributes(queue: str, job_class: str) -> dict[str, Any]:
    """Build the common attribute set for process-side messaging.client.* metrics.

    JobProcess.run() extends the result with error.type for failed jobs, and
    record_consumed() extends it with the outcome dimension. Keeping a single
    builder guarantees both call sites use the same keys and values.

    Returns a fresh dict on every call, so callers may mutate it freely.
    """
    attributes: dict[str, Any] = {}
    attributes[MESSAGING_SYSTEM] = "plain.jobs"
    attributes[MESSAGING_OPERATION_TYPE] = MessagingOperationTypeValues.PROCESS.value
    attributes[MESSAGING_DESTINATION_NAME] = queue
    attributes[CODE_FUNCTION_NAME] = f"{job_class}.run"
    return attributes
90
91
def record_consumed(result: JobResult, *, error_type: str | None = None) -> None:
    """Add one point to the consumed-messages counter for a terminal JobResult.

    The point carries the base process attributes for the result's queue and
    job class, and `plain.jobs.outcome` set to the lower-cased terminal status
    (successful/errored/lost/cancelled/deferred). `error.type` is attached only
    when the caller supplies *error_type*, i.e. when the live path caught an
    exception. The rescue path (LOST) and direct cancellations pass none,
    because no exception object exists to derive it from.
    """
    point_attributes = {
        **process_metric_attributes(result.queue, result.job_class),
        PLAIN_JOBS_OUTCOME: result.status.lower(),
    }
    if error_type is not None:
        point_attributes[ERROR_TYPE] = error_type
    consumed_messages_counter.add(1, point_attributes)
105
106
def _release_db_connection(
    callback: Callable[..., Iterable[Observation]],
) -> Callable[..., Iterable[Observation]]:
    """Return a gauge callback's database connection to the pool once its
    observation has been collected.

    OTel runs observable-gauge callbacks on the PeriodicExportingMetricReader
    thread, which has no request or job lifecycle to recycle connections.
    Left unreturned, that thread's connection wrapper holds a single pooled
    connection idle between export intervals — long enough for the server (or
    a pooler) to close it. The next interval then reuses the dead connection
    and raises `OperationalError: the connection is closed`. Returning it each
    interval means every callback starts from a freshly checked-out connection.

    The callback's result is materialized into a list *before* the connection
    is returned. If a callback handed back a lazy iterable (a generator, or an
    unevaluated queryset), its queries would otherwise run only when OTel
    iterates the result. That happens after `return_database_connection()`,
    presumably checking a connection out again and leaving it idle until the
    next interval, which is the problem this decorator exists to prevent.

    The wrapper expects to sit under `@classmethod`, so it receives the class
    as its first positional argument and forwards it unchanged.
    """

    @wraps(callback)
    def wrapper(
        cls: type[WorkerMetrics], options: CallbackOptions
    ) -> Iterable[Observation]:
        try:
            # Force evaluation while the connection is still checked out.
            return list(callback(cls, options))
        finally:
            return_database_connection()

    return wrapper
132
133
class WorkerMetrics:
    """Per-Worker observable gauges (queue depth/age/scheduled, running count,
    worker process count) plus the worker-liveness gauge over the global
    WorkerHeartbeat table.

    The OTel SDK keeps the *first* callback registered for a given instrument
    name, so instruments are registered once per process (guarded by the
    class-level `_registered` flag). The Worker they observe may change across
    reload paths, so each Worker owns a WorkerMetrics; constructing one swaps
    it in as the active target for the (process-singleton) callbacks. The new
    instance simply replaces the old one in the class-level `_current` slot.
    There is no teardown/deactivate step, because either a successor swaps in
    (reload) or the process exits (signal shutdown).

    Each per-queue callback (depth, oldest age, scheduled, running) emits one
    observation per queue this Worker handles, every export interval,
    including zero for empty queues so `last_value` dashboards don't show
    stale readings after a drain. When two Workers handle the same queue they
    emit identical values; aggregate with `last_value`/`max`, never `sum`.
    The process-count gauge emits a single unlabelled value, and the
    liveness gauge emits one value per state ("active"/"stale").
    """

    # The instance whose Worker the gauge callbacks currently observe.
    _current: ClassVar[WorkerMetrics | None] = None
    # True once the observable gauges have been created on the meter.
    _registered: ClassVar[bool] = False

    def __init__(self, worker: Worker) -> None:
        """Register the gauges (first construction only), then make this
        instance the one the callbacks observe."""
        self.worker = worker
        type(self)._register_instruments()
        type(self)._current = self

    @classmethod
    def _register_instruments(cls) -> None:
        """Create all observable gauges on the module meter, at most once.

        The flag is set before the gauges are created, so a failure part-way
        through is not retried by a later construction.
        """
        if cls._registered:
            return
        cls._registered = True
        meter.create_observable_gauge(
            name="plain.jobs.worker.processes",
            callbacks=[cls._gauge_worker_processes],
            unit="{process}",
            description="OS processes spawned by this worker.",
        )
        meter.create_observable_gauge(
            name="plain.jobs.queue.depth",
            callbacks=[cls._gauge_queue_depth],
            unit="{job}",
            description="Pending JobRequests ready to run, per queue.",
        )
        meter.create_observable_gauge(
            name="plain.jobs.queue.oldest.age",
            callbacks=[cls._gauge_queue_oldest_age],
            unit="s",
            description="Age of the oldest ready-to-run JobRequest, per queue.",
        )
        meter.create_observable_gauge(
            name="plain.jobs.queue.scheduled",
            callbacks=[cls._gauge_queue_scheduled],
            unit="{job}",
            description="JobRequests with start_at in the future, per queue.",
        )
        meter.create_observable_gauge(
            name="plain.jobs.running",
            callbacks=[cls._gauge_running],
            unit="{job}",
            description="JobProcess rows currently running, per queue.",
        )
        meter.create_observable_gauge(
            name="plain.jobs.workers",
            callbacks=[cls._gauge_workers],
            unit="{worker}",
            description=(
                "WorkerHeartbeat row count, split by liveness state "
                "(active=within JOBS_HEARTBEAT_TIMEOUT, stale=past it)."
            ),
        )

    # --- Callbacks ----------------------------------------------------------

    # Each Worker-dependent callback snapshots `cls._current` to a local, so
    # the whole callback works against one WorkerMetrics. PeriodicExporting
    # MetricReader runs callbacks off the main thread, while constructing a
    # new WorkerMetrics reassigns `_current`. The None guard covers the window
    # in `__init__` between `_register_instruments()` and the `_current`
    # assignment, when the gauges already exist but no target is set.

    @classmethod
    def _gauge_worker_processes(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe how many processes the active Worker's executor holds.

        Reads the executor's private `_processes` attribute (presumably a
        concurrent.futures.ProcessPoolExecutor — confirm). No database access,
        so this callback is not wrapped with `_release_db_connection`.
        """
        active = cls._current
        if active is None:
            return []
        try:
            n = len(active.worker.executor._processes)
        except (AttributeError, TypeError):
            # Pool may be mid-shutdown; report 0 rather than crashing the export.
            n = 0
        return [Observation(n)]

    @classmethod
    @_release_db_connection
    def _gauge_queue_depth(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe ready-to-run JobRequest counts for each handled queue."""
        active = cls._current
        if active is None:
            return []
        # Lazy import - see Worker._worker_process_initializer() comment for why.
        from .models import JobRequest

        return _count_per_queue(JobRequest.query.ready_to_run(), active.worker.queues)

    @classmethod
    @_release_db_connection
    def _gauge_queue_oldest_age(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe, per handled queue, seconds since the oldest ready-to-run
        JobRequest was created (0.0 when the queue has none)."""
        active = cls._current
        if active is None:
            return []
        from .models import JobRequest

        queues = active.worker.queues
        rows = (
            JobRequest.query.ready_to_run()
            .filter(queue__in=queues)
            .values("queue")
            .annotate(oldest=Min("created_at"))
        )
        now = timezone.now()
        # `max(0, ...)` defends against Python/Postgres clock skew producing
        # a negative age. Empty queues fall through to 0.0 below.
        ages = {
            row["queue"]: max(0.0, (now - row["oldest"]).total_seconds())
            for row in rows
            if row["oldest"] is not None
        }
        return [
            Observation(ages.get(q, 0.0), {MESSAGING_DESTINATION_NAME: q})
            for q in queues
        ]

    @classmethod
    @_release_db_connection
    def _gauge_queue_scheduled(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe `JobRequest.query.scheduled()` counts for each handled queue."""
        active = cls._current
        if active is None:
            return []
        from .models import JobRequest

        return _count_per_queue(JobRequest.query.scheduled(), active.worker.queues)

    @classmethod
    @_release_db_connection
    def _gauge_running(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe `JobProcess.query.running()` counts for each handled queue."""
        active = cls._current
        if active is None:
            return []
        from .models import JobProcess

        return _count_per_queue(JobProcess.query.running(), active.worker.queues)

    # The worker-liveness gauge observes the global WorkerHeartbeat table and
    # doesn't need a calling Worker — emit unconditionally so dashboards keep
    # reporting even during a full worker drain. One snapshot of the cutoff
    # is shared across both observations so a row landing exactly at the
    # boundary can't be counted in both states (or neither).
    @classmethod
    @_release_db_connection
    def _gauge_workers(cls, options: CallbackOptions) -> Iterable[Observation]:
        """Observe WorkerHeartbeat row counts split into active and stale,
        computed in a single aggregate query."""
        from .models import WorkerHeartbeat, heartbeat_cutoff

        cutoff = heartbeat_cutoff()
        counts = WorkerHeartbeat.query.aggregate(
            active=Count("id", filter=Q(last_heartbeat_at__gte=cutoff)),
            stale=Count("id", filter=Q(last_heartbeat_at__lt=cutoff)),
        )
        return [
            Observation(counts["active"], {PLAIN_JOBS_WORKER_STATE: "active"}),
            Observation(counts["stale"], {PLAIN_JOBS_WORKER_STATE: "stale"}),
        ]
303
304
def _count_per_queue(queryset: Any, queues: list[str]) -> list[Observation]:
    """Count *queryset* rows per queue and return one Observation per queue.

    Only rows whose queue is in *queues* are counted, in one grouped query.
    Every queue in *queues* gets an observation, in the given order, with 0
    for queues that have no matching rows. Each observation is labelled with
    the messaging destination name.
    """
    grouped = queryset.filter(queue__in=queues).values("queue").annotate(c=Count("*"))
    per_queue = {row["queue"]: row["c"] for row in grouped}

    observations: list[Observation] = []
    for queue_name in queues:
        label = {MESSAGING_DESTINATION_NAME: queue_name}
        observations.append(Observation(per_queue.get(queue_name, 0), label))
    return observations