import datetime
import inspect
import logging

from opentelemetry import trace
from opentelemetry.semconv._incubating.attributes.code_attributes import (
    CODE_FILEPATH,
    CODE_LINENO,
)
from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
    MESSAGING_DESTINATION_NAME,
    MESSAGING_MESSAGE_ID,
    MESSAGING_OPERATION_NAME,
    MESSAGING_OPERATION_TYPE,
    MESSAGING_SYSTEM,
    MessagingOperationTypeValues,
)
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
from opentelemetry.trace import SpanKind, format_span_id, format_trace_id

from plain.models import IntegrityError
from plain.utils import timezone

from .registry import JobParameters, jobs_registry

logger = logging.getLogger(__name__)
tracer = trace.get_tracer("plain.worker")


class JobType(type):
    """
    Metaclass that allows us to capture the original args/kwargs
    used to instantiate the job, so we can store them in the database
    when we schedule the job.
    """

    def __call__(self, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        instance._init_args = args
        instance._init_kwargs = kwargs
        return instance


class Job(metaclass=JobType):
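    """
    Base class for background jobs.

    Subclasses implement run() with the work to perform; the args/kwargs
    passed to __init__ are captured by the JobType metaclass and serialized
    with JobParameters when the job is queued.

    A minimal sketch of typical usage (class and field names are
    illustrative; registration with the jobs registry is omitted):

        class SendWelcomeEmail(Job):
            def __init__(self, user_id: int):
                self.user_id = user_id

            def run(self):
                ...

        SendWelcomeEmail(user_id=1).run_in_worker(delay=60)
    """
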
    def run(self):
        raise NotImplementedError

    def run_in_worker(
        self,
        *,
        queue: str | None = None,
        delay: int | datetime.timedelta | datetime.datetime | None = None,
        priority: int | None = None,
        retries: int | None = None,
        retry_attempt: int = 0,
        unique_key: str | None = None,
    ):
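        """
        Queue this job to be executed by a background worker.

        delay may be a number of seconds, a timedelta, or an absolute
        datetime. queue, priority, retries, and unique_key fall back to the
        job's get_* methods when not given.

        Returns the saved JobRequest, or a list of already in-progress
        JobRequest/JobProcess rows if a matching unique_key exists.
        """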
        from .models import JobRequest

        job_class_name = jobs_registry.get_job_class_name(self.__class__)

        if queue is None:
            queue = self.get_queue()

        with tracer.start_as_current_span(
            f"run_in_worker {job_class_name}",
            kind=SpanKind.PRODUCER,
            attributes={
                MESSAGING_SYSTEM: "plain.worker",
                MESSAGING_OPERATION_TYPE: MessagingOperationTypeValues.SEND.value,
                MESSAGING_OPERATION_NAME: "run_in_worker",
                MESSAGING_DESTINATION_NAME: queue,
            },
        ) as span:
            try:
                # Try to automatically annotate the source of the job
                caller = inspect.stack()[1]
                source = f"{caller.filename}:{caller.lineno}"
                span.set_attributes(
                    {
                        CODE_FILEPATH: caller.filename,
                        CODE_LINENO: caller.lineno,
                    }
                )
            except (IndexError, AttributeError):
                source = ""

            parameters = JobParameters.to_json(self._init_args, self._init_kwargs)

            if priority is None:
                priority = self.get_priority()

            if retries is None:
                retries = self.get_retries()

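            # Normalize delay into an absolute start_at timestamp.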
            if delay is None:
                start_at = None
            elif isinstance(delay, int):
                start_at = timezone.now() + datetime.timedelta(seconds=delay)
            elif isinstance(delay, datetime.timedelta):
                start_at = timezone.now() + delay
            elif isinstance(delay, datetime.datetime):
                start_at = delay
            else:
                raise ValueError(f"Invalid delay: {delay}")

            if unique_key is None:
                unique_key = self.get_unique_key()

            if unique_key:
                # We only need to check for in-progress jobs
                # when a unique key is set. Otherwise it's up
                # to the caller to use _in_progress() directly.
                if running := self._in_progress(unique_key):
                    span.set_attribute(ERROR_TYPE, "DuplicateJob")
                    return running

            # Note: is_recording() alone is not enough here, because we also record for summaries.

            # Capture current trace context
            current_span = trace.get_current_span()
            span_context = current_span.get_span_context()

            # Only include trace context if the span is being recorded (sampled).
            # This ensures jobs are only linked to traces that are actually being collected.
            if current_span.is_recording() and span_context.is_valid:
                trace_id = f"0x{format_trace_id(span_context.trace_id)}"
                span_id = f"0x{format_span_id(span_context.span_id)}"
            else:
                trace_id = None
                span_id = None

            try:
                job_request = JobRequest(
                    job_class=job_class_name,
                    parameters=parameters,
                    start_at=start_at,
                    source=source,
                    queue=queue,
                    priority=priority,
                    retries=retries,
                    retry_attempt=retry_attempt,
                    unique_key=unique_key,
                    trace_id=trace_id,
                    span_id=span_id,
                )
                # Save without clean/validate so a duplicate unique_key raises
                # IntegrityError instead of a potentially confusing ValidationError.
                job_request.save(clean_and_validate=False)

                span.set_attribute(
                    MESSAGING_MESSAGE_ID,
                    str(job_request.uuid),
                )

                # Add job UUID to current span for bidirectional linking
                span.set_attribute("job.uuid", str(job_request.uuid))
                span.set_status(trace.StatusCode.OK)

                return job_request
            except IntegrityError as e:
                span.set_attribute(ERROR_TYPE, "IntegrityError")
                span.set_status(trace.Status(trace.StatusCode.ERROR, "Duplicate job"))
                logger.warning("Job already in progress: %s", e)
                # Try to return the _in_progress list again
                return self._in_progress(unique_key)

    def _in_progress(self, unique_key):
        """Get all JobRequest and JobProcess rows that are currently in progress, regardless of queue."""
        from .models import JobProcess, JobRequest

        job_class_name = jobs_registry.get_job_class_name(self.__class__)

        job_requests = JobRequest.query.filter(
            job_class=job_class_name,
            unique_key=unique_key,
        )

        jobs = JobProcess.query.filter(
            job_class=job_class_name,
            unique_key=unique_key,
        )

        return list(job_requests) + list(jobs)
    def get_unique_key(self) -> str:
        """
        A unique key to prevent duplicate jobs from being queued.
        Enabled by returning a non-empty string.

        Note that this is not a "once and only once" guarantee, but rather
        an "at least once" guarantee. Jobs should still be idempotent in case
        multiple instances are queued in a race condition.
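
        For example (illustrative, assuming the job was given a user_id),
        returning str(self.user_id) would allow only one copy of the job
        per user to be queued at a time.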
194 """
195 return ""
196
197 def get_queue(self) -> str:
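        """Return the name of the queue this job is sent to (defaults to "default")."""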
        return "default"

    def get_priority(self) -> int:
        """
        Return the default priority for this job.

        Higher numbers run first: 10 > 5 > 0 > -5 > -10
        - Use positive numbers for high priority jobs
        - Use negative numbers for low priority jobs
        - Default is 0
        """
        return 0

    def get_retries(self) -> int:
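        """Return the number of times this job will be retried after a failure."""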
        return 0

    def get_retry_delay(self, attempt: int) -> int:
        """
        Calculate a delay in seconds before the next retry attempt.

        On the first retry, attempt will be 1.
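
        A subclass might implement exponential backoff here, for example
        returning 30 * 2 ** (attempt - 1) for delays of 30s, 60s, 120s, ...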
219 """
220 return 0