1from __future__ import annotations
2
3import datetime
4import sys
5import time
6from abc import ABCMeta, abstractmethod
7from contextlib import AbstractContextManager, nullcontext
8from typing import TYPE_CHECKING, Any
9
10from opentelemetry import trace
11from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
12 MESSAGING_DESTINATION_NAME,
13 MESSAGING_MESSAGE_ID,
14 MESSAGING_OPERATION_NAME,
15 MESSAGING_OPERATION_TYPE,
16 MESSAGING_SYSTEM,
17 MessagingOperationTypeValues,
18)
19from opentelemetry.semconv.attributes.code_attributes import (
20 CODE_FILE_PATH,
21 CODE_FUNCTION_NAME,
22 CODE_LINE_NUMBER,
23)
24from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
25from opentelemetry.trace import SpanKind, format_span_id, format_trace_id
26
27from plain import postgres
28from plain.postgres import transaction
29from plain.utils import timezone
30
31from .locks import postgres_advisory_lock
32from .otel import (
33 operation_duration_histogram,
34 record_span_error,
35 sent_messages_counter,
36 tracer,
37)
38from .registry import JobParameters, jobs_registry
39
40if TYPE_CHECKING:
41 from .models import JobProcess, JobRequest, JobResult
42
43
class JobType(ABCMeta):
    """
    Metaclass that records the positional and keyword arguments a job was
    constructed with, so they can be stored in the database when the job
    is scheduled.
    """

    def __call__(cls, *args: Any, **kwargs: Any) -> Job:
        job = super().__call__(*args, **kwargs)
        # Attach the constructor arguments after __init__ has completed.
        job._init_args, job._init_kwargs = args, kwargs
        return job
56
57
class Job(metaclass=JobType):
    """
    Base class for background jobs.

    Subclasses implement run(). Calling run_in_worker() on an instance saves a
    JobRequest (built from the constructor arguments captured by JobType) for a
    worker to execute. The default_*() hooks supply per-class defaults that
    callers of run_in_worker() can override per call.
    """

    # Set by JobType metaclass when the job is instantiated
    _init_args: tuple[Any, ...]
    _init_kwargs: dict[str, Any]

    # Set by JobProcess when the job is executed
    # Useful for jobs that need to query and exclude themselves
    job_process: JobProcess | None = None

    @abstractmethod
    def run(self) -> None:
        """Perform the job's work. Must be implemented by subclasses."""
        pass

    def run_in_worker(
        self,
        *,
        queue: str | None = None,
        delay: int | datetime.timedelta | datetime.datetime | None = None,
        priority: int | None = None,
        retries: int | None = None,
        retry_attempt: int = 0,
        concurrency_key: str | None = None,
    ) -> JobRequest | None:
        """
        Enqueue this job as a JobRequest for a worker to pick up.

        Args:
            queue: Queue name. Defaults to default_queue().
            delay: When the job may start. An int is seconds from now, a
                timedelta is an offset from now, a datetime is used as-is,
                and None leaves start_at unset.
            priority: Defaults to default_priority().
            retries: Defaults to default_retries().
            retry_attempt: Stored on the request as given (0 by default).
            concurrency_key: Defaults to default_concurrency_key().

        Returns:
            The saved JobRequest, or None if should_enqueue() declined it.

        Raises:
            ValueError: If delay is not an int, timedelta, datetime, or None.

        The should_enqueue() check and the insert run inside a database
        transaction while holding the lock returned by get_enqueue_lock()
        (if any), so the check-then-insert is race-free when a lock is used.
        Send metrics are not emitted for skipped enqueues; for successful
        sends they are deferred until the surrounding transaction commits.
        """
        from .models import JobRequest

        job_class_name = jobs_registry.get_job_class_name(self.__class__)

        if queue is None:
            queue = self.default_queue()

        metric_attributes: dict[str, Any] = {
            MESSAGING_SYSTEM: "plain.jobs",
            MESSAGING_OPERATION_TYPE: MessagingOperationTypeValues.SEND.value,
            MESSAGING_DESTINATION_NAME: queue,
            CODE_FUNCTION_NAME: f"{job_class_name}.run_in_worker",
        }
        start_time = time.perf_counter()
        skipped = False
        with tracer.start_as_current_span(
            f"send {queue}",
            kind=SpanKind.PRODUCER,
            attributes={**metric_attributes, MESSAGING_OPERATION_NAME: "send"},
            # We record manually via record_span_error (escaped=True) at the
            # workflow boundary; suppress the SDK's escaped=False auto-record
            # so failed sends carry a single, correctly-marked event.
            record_exception=False,
        ) as span:
            try:
                # Record where run_in_worker() was called from (frame 1 is
                # the caller of this method).
                try:
                    frame = sys._getframe(1)
                    filename = frame.f_code.co_filename
                    lineno = frame.f_lineno
                    source = f"{filename}:{lineno}"
                    span.set_attributes(
                        {
                            CODE_FILE_PATH: filename,
                            CODE_LINE_NUMBER: lineno,
                        }
                    )
                except (ValueError, AttributeError):
                    source = ""

                parameters = JobParameters.to_json(self._init_args, self._init_kwargs)

                if priority is None:
                    priority = self.default_priority()

                if retries is None:
                    retries = self.default_retries()

                if delay is None:
                    start_at = None
                elif isinstance(delay, int):
                    start_at = timezone.now() + datetime.timedelta(seconds=delay)
                elif isinstance(delay, datetime.timedelta):
                    start_at = timezone.now() + delay
                elif isinstance(delay, datetime.datetime):
                    start_at = delay
                else:
                    raise ValueError(f"Invalid delay: {delay}")

                if concurrency_key is None:
                    concurrency_key = self.default_concurrency_key()

                # Capture current trace context
                current_span = trace.get_current_span()
                span_context = current_span.get_span_context()

                # Only include trace context if the span is being recorded (sampled)
                # This ensures jobs are only linked to traces that are actually being collected
                if current_span.is_recording() and span_context.is_valid:
                    trace_id = f"0x{format_trace_id(span_context.trace_id)}"
                    span_id = f"0x{format_span_id(span_context.span_id)}"
                else:
                    trace_id = None
                    span_id = None

                # Use transaction with optional locking for race-free enqueue
                with transaction.atomic():
                    # Only None means "no lock". A plain `lock or nullcontext()`
                    # would silently drop a custom lock object that is falsy
                    # (e.g. one defining __len__/__bool__).
                    lock = self.get_enqueue_lock(concurrency_key)
                    with nullcontext() if lock is None else lock:
                        # Check with lock held (if using locks)
                        if not self.should_enqueue(concurrency_key):
                            span.set_attribute("job.enqueue.skipped", True)
                            skipped = True
                            return None

                        # Create job with lock held
                        job_request = JobRequest(
                            job_class=job_class_name,
                            parameters=parameters,
                            start_at=start_at,
                            source=source,
                            queue=queue,
                            priority=priority,
                            retries=retries,
                            retry_attempt=retry_attempt,
                            concurrency_key=concurrency_key,
                            trace_id=trace_id,
                            span_id=span_id,
                        )
                        job_request.save()

                        span.set_attribute(
                            MESSAGING_MESSAGE_ID,
                            str(job_request.uuid),
                        )

                        return job_request
            except Exception as e:
                # Stamps escaped=True on the span event, ERROR_TYPE on both
                # the span and `metric_attributes` (so the finally below
                # picks up the failed-send branch).
                record_span_error(span, e, metric_attributes)
                raise
            finally:
                # Skipped enqueues are visible on the span (`job.enqueue.skipped`)
                # but do not fire the messaging counter — no message was sent, so
                # there's nothing for `messaging.client.sent.messages` to count.
                if not skipped:
                    duration = time.perf_counter() - start_time
                    if ERROR_TYPE in metric_attributes:
                        # No commit is coming — record now so failed sends are visible.
                        sent_messages_counter.add(1, metric_attributes)
                        operation_duration_histogram.record(duration, metric_attributes)
                    else:
                        # Defer to the outer commit so a caller-level rollback
                        # doesn't leave a phantom send. Runs immediately if not
                        # inside a transaction.
                        attrs = metric_attributes

                        def _emit() -> None:
                            sent_messages_counter.add(1, attrs)
                            operation_duration_histogram.record(duration, attrs)

                        transaction.on_commit(_emit)

    def _job_query(
        self,
        model: Any,
        *,
        concurrency_key: str | None,
        include_retries: bool,
    ) -> postgres.QuerySet:
        """
        Build the query shared by get_requested_jobs() and get_processing_jobs().

        Filters model.query to this job class. A concurrency_key of None is
        resolved to self.job_process.concurrency_key (when a process is set) or
        default_concurrency_key(). The key filter is applied only when the
        resolved key is non-empty. Unless include_retries is True, only rows
        with retry_attempt=0 are kept.
        """
        job_class_name = jobs_registry.get_job_class_name(self.__class__)

        if concurrency_key is None:
            if self.job_process:
                concurrency_key = self.job_process.concurrency_key
            else:
                concurrency_key = self.default_concurrency_key()

        filters = {"job_class": job_class_name}
        if concurrency_key:
            filters["concurrency_key"] = concurrency_key

        qs = model.query.filter(**filters)

        if not include_retries:
            qs = qs.filter(retry_attempt=0)

        return qs

    def get_requested_jobs(
        self, *, concurrency_key: str | None = None, include_retries: bool = False
    ) -> postgres.QuerySet:
        """
        Get pending jobs (JobRequest) for this job class.

        Args:
            concurrency_key: Optional concurrency_key to filter by. If None, uses self.job_process.concurrency_key (if available) or self.default_concurrency_key(). An empty key disables the key filter.
            include_retries: If False (default), exclude retry attempts from results
        """
        from .models import JobRequest

        return self._job_query(
            JobRequest,
            concurrency_key=concurrency_key,
            include_retries=include_retries,
        )

    def get_processing_jobs(
        self,
        *,
        concurrency_key: str | None = None,
        include_retries: bool = False,
        include_self: bool = False,
    ) -> postgres.QuerySet:
        """
        Get currently processing jobs (JobProcess) for this job class.

        Args:
            concurrency_key: Optional concurrency_key to filter by. If None, uses self.job_process.concurrency_key (if available) or self.default_concurrency_key(). An empty key disables the key filter.
            include_retries: If False (default), exclude retry attempts from results
            include_self: If False (default) and this job is currently executing
                (self.job_process is set), exclude its own JobProcess row.
        """
        from .models import JobProcess

        qs = self._job_query(
            JobProcess,
            concurrency_key=concurrency_key,
            include_retries=include_retries,
        )

        if not include_self and self.job_process:
            qs = qs.exclude(id=self.job_process.id)

        return qs

    def should_enqueue(self, concurrency_key: str) -> bool:
        """
        Called before enqueueing job. Return False to skip.

        Args:
            concurrency_key: The resolved concurrency_key (from default_concurrency_key() or override)

        Default behavior:
        - If concurrency_key is empty: no restrictions (always enqueue)
        - If concurrency_key is set: enforce uniqueness (only one job with this key can be pending or processing)

        Override to implement custom concurrency control:
        - Concurrency limits
        - Rate limits
        - Custom business logic

        Example:
            def should_enqueue(self, concurrency_key):
                # Max 3 processing, 1 pending per concurrency_key
                processing = self.get_processing_jobs(
                    concurrency_key=concurrency_key
                ).count()
                pending = self.get_requested_jobs(
                    concurrency_key=concurrency_key
                ).count()
                return processing < 3 and pending < 1
        """
        if not concurrency_key:
            # No key = no uniqueness check
            return True

        # Key set = enforce uniqueness (include retries for strong guarantee)
        return (
            self.get_processing_jobs(
                concurrency_key=concurrency_key, include_retries=True
            ).count()
            == 0
            and self.get_requested_jobs(
                concurrency_key=concurrency_key, include_retries=True
            ).count()
            == 0
        )

    def default_concurrency_key(self) -> str:
        """
        Default identifier for this job.

        Use for:
        - Deduplication
        - Grouping related jobs
        - Concurrency control

        Return empty string (default) for no grouping.
        Can be overridden per-call via concurrency_key parameter in run_in_worker().
        """
        return ""

    def default_queue(self) -> str:
        """Default queue for this job. Can be overridden in run_in_worker()."""
        return "default"

    def default_priority(self) -> int:
        """
        Default priority for this job. Can be overridden in run_in_worker().

        Higher numbers run first: 10 > 5 > 0 > -5 > -10
        - Use positive numbers for high priority jobs
        - Use negative numbers for low priority jobs
        - Default is 0
        """
        return 0

    def default_retries(self) -> int:
        """Default number of retry attempts. Can be overridden in run_in_worker()."""
        return 0

    def calculate_retry_delay(self, attempt: int) -> int:
        """
        Calculate a delay in seconds before the next retry attempt.

        On the first retry, attempt will be 1.
        """
        return 0

    def on_aborted(self, result: JobResult) -> None:
        """
        Called when this job's process was terminated externally before run()
        could complete (status LOST or CANCELLED). Default no-op.

        See README "Worker resilience" for the full contract.
        """

    def get_enqueue_lock(
        self, concurrency_key: str
    ) -> AbstractContextManager[None] | None:
        """
        Return a context manager for the enqueue lock, or None for no locking.

        Default: PostgreSQL advisory lock (None if empty concurrency_key).
        Override to provide custom locking (Redis, etcd, etc.).

        The returned context manager is used to wrap the should_enqueue() check
        and job creation inside the same database transaction, ensuring
        atomicity. Only None disables locking.

        Example with Redis (assumes a configured `redis_client`):
            def get_enqueue_lock(self, concurrency_key):
                return redis_client.lock(f"job:{concurrency_key}", timeout=5)

        Example with custom implementation:
            from contextlib import contextmanager

            @contextmanager
            def get_enqueue_lock(self, concurrency_key):
                my_lock.acquire(concurrency_key)
                try:
                    yield
                finally:
                    my_lock.release(concurrency_key)

        To disable locking:
            def get_enqueue_lock(self, concurrency_key):
                return None
        """
        # No locking if no concurrency_key
        if not concurrency_key:
            return None

        return postgres_advisory_lock(self, concurrency_key)