v0.146.0
  1from __future__ import annotations
  2
  3import datetime
  4import sys
  5import time
  6from abc import ABCMeta, abstractmethod
  7from contextlib import AbstractContextManager, nullcontext
  8from typing import TYPE_CHECKING, Any
  9
 10from opentelemetry import trace
 11from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
 12    MESSAGING_DESTINATION_NAME,
 13    MESSAGING_MESSAGE_ID,
 14    MESSAGING_OPERATION_NAME,
 15    MESSAGING_OPERATION_TYPE,
 16    MESSAGING_SYSTEM,
 17    MessagingOperationTypeValues,
 18)
 19from opentelemetry.semconv.attributes.code_attributes import (
 20    CODE_FILE_PATH,
 21    CODE_FUNCTION_NAME,
 22    CODE_LINE_NUMBER,
 23)
 24from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
 25from opentelemetry.trace import SpanKind, format_span_id, format_trace_id
 26
 27from plain import postgres
 28from plain.postgres import transaction
 29from plain.utils import timezone
 30
 31from .locks import postgres_advisory_lock
 32from .otel import (
 33    operation_duration_histogram,
 34    record_span_error,
 35    sent_messages_counter,
 36    tracer,
 37)
 38from .registry import JobParameters, jobs_registry
 39
 40if TYPE_CHECKING:
 41    from .models import JobProcess, JobRequest, JobResult
 42
 43
 44class JobType(ABCMeta):
 45    """
 46    Metaclass allows us to capture the original args/kwargs
 47    used to instantiate the job, so we can store them in the database
 48    when we schedule the job.
 49    """
 50
 51    def __call__(self, *args: Any, **kwargs: Any) -> Job:
 52        instance = super().__call__(*args, **kwargs)
 53        instance._init_args = args
 54        instance._init_kwargs = kwargs
 55        return instance
 56
 57
 58class Job(metaclass=JobType):
 59    # Set by JobType metaclass when the job is instantiated
 60    _init_args: tuple[Any, ...]
 61    _init_kwargs: dict[str, Any]
 62
 63    # Set by JobProcess when the job is executed
 64    # Useful for jobs that need to query and exclude themselves
 65    job_process: JobProcess | None = None
 66
 67    @abstractmethod
 68    def run(self) -> None:
 69        pass
 70
    def run_in_worker(
        self,
        *,
        queue: str | None = None,
        delay: int | datetime.timedelta | datetime.datetime | None = None,
        priority: int | None = None,
        retries: int | None = None,
        retry_attempt: int = 0,
        concurrency_key: str | None = None,
    ) -> JobRequest | None:
        """
        Enqueue this job by saving a JobRequest, inside a PRODUCER span.

        The job class name is resolved through ``jobs_registry`` and the
        constructor arguments captured in ``_init_args`` / ``_init_kwargs``
        are serialized with ``JobParameters.to_json()``.

        Args:
            queue: Destination queue. Defaults to ``self.default_queue()``.
            delay: When the job should start. An ``int`` is seconds from now,
                a ``timedelta`` is an offset from now, a ``datetime`` is used
                as-is, and ``None`` stores ``start_at=None``.
            priority: Defaults to ``self.default_priority()``.
            retries: Defaults to ``self.default_retries()``.
            retry_attempt: Stored on the JobRequest unchanged.
            concurrency_key: Defaults to ``self.default_concurrency_key()``.
                Passed to ``get_enqueue_lock()`` and ``should_enqueue()``.

        Returns:
            The saved JobRequest, or ``None`` when ``should_enqueue()``
            returned False (the span is then tagged ``job.enqueue.skipped``).

        Raises:
            ValueError: If ``delay`` is not one of the supported types.
            Exception: Anything raised while enqueueing is passed to
                ``record_span_error`` and then re-raised unchanged.

        Metrics:
            A successful send emits the counter and duration histogram via
            ``transaction.on_commit``; a failed send records them right away;
            a skipped enqueue records neither.
        """
        from .models import JobRequest

        job_class_name = jobs_registry.get_job_class_name(self.__class__)

        if queue is None:
            queue = self.default_queue()

        # Shared by the span and the metrics. record_span_error() is handed
        # this same dict on failure, and the finally block below checks it for
        # ERROR_TYPE to tell the failed branch from the successful one.
        metric_attributes: dict[str, Any] = {
            MESSAGING_SYSTEM: "plain.jobs",
            MESSAGING_OPERATION_TYPE: MessagingOperationTypeValues.SEND.value,
            MESSAGING_DESTINATION_NAME: queue,
            CODE_FUNCTION_NAME: f"{job_class_name}.run_in_worker",
        }
        start_time = time.perf_counter()
        # Set to True only when should_enqueue() declines; read in the finally.
        skipped = False
        with tracer.start_as_current_span(
            f"send {queue}",
            kind=SpanKind.PRODUCER,
            attributes={**metric_attributes, MESSAGING_OPERATION_NAME: "send"},
            # We record manually via record_span_error (escaped=True) at the
            # workflow boundary; suppress the SDK's escaped=False auto-record
            # so failed sends carry a single, correctly-marked event.
            record_exception=False,
        ) as span:
            try:
                try:
                    # Frame 1 is whoever called run_in_worker(); its file and
                    # line become the JobRequest `source` and span attributes.
                    frame = sys._getframe(1)
                    filename = frame.f_code.co_filename
                    lineno = frame.f_lineno
                    source = f"{filename}:{lineno}"
                    span.set_attributes(
                        {
                            CODE_FILE_PATH: filename,
                            CODE_LINE_NUMBER: lineno,
                        }
                    )
                except (ValueError, AttributeError):
                    # Best effort: fall back to an empty source if the caller
                    # frame cannot be inspected.
                    source = ""

                parameters = JobParameters.to_json(self._init_args, self._init_kwargs)

                if priority is None:
                    priority = self.default_priority()

                if retries is None:
                    retries = self.default_retries()

                # Normalize `delay` into an absolute start time (or None).
                # NOTE: bool is a subclass of int, so True/False pass the int
                # check and are treated as 1/0 seconds.
                if delay is None:
                    start_at = None
                elif isinstance(delay, int):
                    start_at = timezone.now() + datetime.timedelta(seconds=delay)
                elif isinstance(delay, datetime.timedelta):
                    start_at = timezone.now() + delay
                elif isinstance(delay, datetime.datetime):
                    start_at = delay
                else:
                    raise ValueError(f"Invalid delay: {delay}")

                if concurrency_key is None:
                    concurrency_key = self.default_concurrency_key()

                # Capture current trace context
                # NOTE(review): this runs inside the start_as_current_span
                # block, so the current span is presumably the producer span
                # opened above rather than the caller's span — confirm.
                current_span = trace.get_current_span()
                span_context = current_span.get_span_context()

                # Only include trace context if the span is being recorded (sampled)
                # This ensures jobs are only linked to traces that are actually being collected
                if current_span.is_recording() and span_context.is_valid:
                    trace_id = f"0x{format_trace_id(span_context.trace_id)}"
                    span_id = f"0x{format_span_id(span_context.span_id)}"
                else:
                    trace_id = None
                    span_id = None

                # Use transaction with optional locking for race-free enqueue
                with transaction.atomic():
                    # Acquire lock via context manager (or nullcontext if None)
                    with self.get_enqueue_lock(concurrency_key) or nullcontext():
                        # Check with lock held (if using locks)
                        if not self.should_enqueue(concurrency_key):
                            span.set_attribute("job.enqueue.skipped", True)
                            skipped = True
                            return None

                        # Create job with lock held
                        job_request = JobRequest(
                            job_class=job_class_name,
                            parameters=parameters,
                            start_at=start_at,
                            source=source,
                            queue=queue,
                            priority=priority,
                            retries=retries,
                            retry_attempt=retry_attempt,
                            concurrency_key=concurrency_key,
                            trace_id=trace_id,
                            span_id=span_id,
                        )
                        job_request.save()

                        # Expose the new request's uuid as the message id.
                        span.set_attribute(
                            MESSAGING_MESSAGE_ID,
                            str(job_request.uuid),
                        )

                        return job_request
            except Exception as e:
                # Stamps escaped=True on the span event, ERROR_TYPE on both
                # the span and `metric_attributes` (so the finally below
                # picks up the failed-send branch).
                record_span_error(span, e, metric_attributes)
                raise
            finally:
                # Skipped enqueues are visible on the span (`job.enqueue.skipped`)
                # but do not fire the messaging counter — no message was sent, so
                # there's nothing for `messaging.client.sent.messages` to count.
                if not skipped:
                    duration = time.perf_counter() - start_time
                    if ERROR_TYPE in metric_attributes:
                        # No commit is coming — record now so failed sends are visible.
                        sent_messages_counter.add(1, metric_attributes)
                        operation_duration_histogram.record(duration, metric_attributes)
                    else:
                        # Defer to the outer commit so a caller-level rollback
                        # doesn't leave a phantom send. Runs immediately if not
                        # inside a transaction.
                        attrs = metric_attributes

                        def _emit() -> None:
                            # Closure over `attrs` and `duration` as measured
                            # above, not at commit time.
                            sent_messages_counter.add(1, attrs)
                            operation_duration_histogram.record(duration, attrs)

                        transaction.on_commit(_emit)
214
215    def get_requested_jobs(
216        self, *, concurrency_key: str | None = None, include_retries: bool = False
217    ) -> postgres.QuerySet:
218        """
219        Get pending jobs (JobRequest) for this job class.
220
221        Args:
222            concurrency_key: Optional concurrency_key to filter by. If None, uses self.job_process.concurrency_key (if available) or self.default_concurrency_key()
223            include_retries: If False (default), exclude retry attempts from results
224        """
225        from .models import JobRequest
226
227        job_class_name = jobs_registry.get_job_class_name(self.__class__)
228
229        if concurrency_key is None:
230            if self.job_process:
231                concurrency_key = self.job_process.concurrency_key
232            else:
233                concurrency_key = self.default_concurrency_key()
234
235        filters = {"job_class": job_class_name}
236        if concurrency_key:
237            filters["concurrency_key"] = concurrency_key
238
239        qs = JobRequest.query.filter(**filters)
240
241        if not include_retries:
242            qs = qs.filter(retry_attempt=0)
243
244        return qs
245
246    def get_processing_jobs(
247        self,
248        *,
249        concurrency_key: str | None = None,
250        include_retries: bool = False,
251        include_self: bool = False,
252    ) -> postgres.QuerySet:
253        """
254        Get currently processing jobs (JobProcess) for this job class.
255
256        Args:
257            concurrency_key: Optional concurrency_key to filter by. If None, uses self.job_process.concurrency_key (if available) or self.default_concurrency_key()
258            include_retries: If False (default), exclude retry attempts from results
259        """
260        from .models import JobProcess
261
262        job_class_name = jobs_registry.get_job_class_name(self.__class__)
263
264        if concurrency_key is None:
265            if self.job_process:
266                concurrency_key = self.job_process.concurrency_key
267            else:
268                concurrency_key = self.default_concurrency_key()
269
270        filters = {"job_class": job_class_name}
271        if concurrency_key:
272            filters["concurrency_key"] = concurrency_key
273
274        qs = JobProcess.query.filter(**filters)
275
276        if not include_retries:
277            qs = qs.filter(retry_attempt=0)
278
279        if not include_self and self.job_process:
280            qs = qs.exclude(id=self.job_process.id)
281
282        return qs
283
284    def should_enqueue(self, concurrency_key: str) -> bool:
285        """
286        Called before enqueueing job. Return False to skip.
287
288        Args:
289            concurrency_key: The resolved concurrency_key (from default_concurrency_key() or override)
290
291        Default behavior:
292        - If concurrency_key is empty: no restrictions (always enqueue)
293        - If concurrency_key is set: enforce uniqueness (only one job with this key can be pending or processing)
294
295        Override to implement custom concurrency control:
296        - Concurrency limits
297        - Rate limits
298        - Custom business logic
299
300        Example:
301            def should_enqueue(self, concurrency_key):
302                # Max 3 processing, 1 pending per concurrency_key
303                processing = self.get_processing_jobs(concurrency_key).count()
304                pending = self.get_requested_jobs(concurrency_key).count()
305                return processing < 3 and pending < 1
306        """
307        if not concurrency_key:
308            # No key = no uniqueness check
309            return True
310
311        # Key set = enforce uniqueness (include retries for strong guarantee)
312        return (
313            self.get_processing_jobs(
314                concurrency_key=concurrency_key, include_retries=True
315            ).count()
316            == 0
317            and self.get_requested_jobs(
318                concurrency_key=concurrency_key, include_retries=True
319            ).count()
320            == 0
321        )
322
323    def default_concurrency_key(self) -> str:
324        """
325        Default identifier for this job.
326
327        Use for:
328        - Deduplication
329        - Grouping related jobs
330        - Concurrency control
331
332        Return empty string (default) for no grouping.
333        Can be overridden per-call via concurrency_key parameter in run_in_worker().
334        """
335        return ""
336
337    def default_queue(self) -> str:
338        """Default queue for this job. Can be overridden in run_in_worker()."""
339        return "default"
340
341    def default_priority(self) -> int:
342        """
343        Default priority for this job. Can be overridden in run_in_worker().
344
345        Higher numbers run first: 10 > 5 > 0 > -5 > -10
346        - Use positive numbers for high priority jobs
347        - Use negative numbers for low priority jobs
348        - Default is 0
349        """
350        return 0
351
352    def default_retries(self) -> int:
353        """Default number of retry attempts. Can be overridden in run_in_worker()."""
354        return 0
355
356    def calculate_retry_delay(self, attempt: int) -> int:
357        """
358        Calculate a delay in seconds before the next retry attempt.
359
360        On the first retry, attempt will be 1.
361        """
362        return 0
363
364    def on_aborted(self, result: JobResult) -> None:
365        """
366        Called when this job's process was terminated externally before run()
367        could complete (status LOST or CANCELLED). Default no-op.
368
369        See README "Worker resilience" for the full contract.
370        """
371
372    def get_enqueue_lock(
373        self, concurrency_key: str
374    ) -> AbstractContextManager[None] | None:
375        """
376        Return a context manager for the enqueue lock, or None for no locking.
377
378        Default: PostgreSQL advisory lock (None if empty concurrency_key).
379        Override to provide custom locking (Redis, etcd, etc.).
380
381        The returned context manager is used to wrap the should_enqueue() check
382        and job creation, ensuring atomicity.
383
384        Example with Redis:
385            def get_enqueue_lock(self, concurrency_key):
386                import redis
387                return redis_client.lock(f"job:{concurrency_key}", timeout=5)
388
389        Example with custom implementation:
390            from contextlib import contextmanager
391
392            @contextmanager
393            def get_enqueue_lock(self, concurrency_key):
394                my_lock.acquire(concurrency_key)
395                try:
396                    yield
397                finally:
398                    my_lock.release(concurrency_key)
399
400        To disable locking:
401            def get_enqueue_lock(self, concurrency_key):
402                return None
403        """
404        # No locking if no concurrency_key
405        if not concurrency_key:
406            return None
407
408        return postgres_advisory_lock(self, concurrency_key)