@@ -75,7 +75,7 @@ def result_ready(self):
75
75
async def run(self, config=None):
76
76
print("Worker started with config:", config)
77
77
# Wait until stop is requested
78
- await self._tsignal_stopping.wait ()
78
+ await self.wait_for_stop ()
79
79
print("Worker finishing...")
80
80
81
81
async def do_work(self, data):
@@ -108,9 +108,7 @@ def __init__(self):
108
108
All operations that access or modify worker's lifecycle state must be
109
109
performed while holding this lock.
110
110
"""
111
- self ._tsignal_lifecycle_lock = (
112
- threading .RLock ()
113
- ) # Renamed lock for loop and thread
111
+ self ._tsignal_lifecycle_lock = threading .RLock ()
114
112
self ._tsignal_stopping = asyncio .Event ()
115
113
self ._tsignal_affinity = object ()
116
114
self ._tsignal_process_queue_task = None
@@ -120,8 +118,10 @@ def __init__(self):
120
118
@property
121
119
def event_loop (self ) -> asyncio .AbstractEventLoop :
122
120
"""Returns the worker's event loop"""
121
+
123
122
if not self ._tsignal_loop :
124
123
raise RuntimeError ("Worker not started" )
124
+
125
125
return self ._tsignal_loop
126
126
127
127
@t_signal
@@ -134,6 +134,7 @@ def stopped(self):
134
134
135
135
async def run (self , * args , ** kwargs ):
136
136
"""Run the worker."""
137
+
137
138
logger .debug ("[WorkerClass][run] calling super" )
138
139
139
140
super_run = getattr (super (), _WorkerConstants .RUN , None )
@@ -161,8 +162,10 @@ async def run(self, *args, **kwargs):
161
162
162
163
async def _process_queue (self ):
163
164
"""Process the task queue."""
165
+
164
166
while not self ._tsignal_stopping .is_set ():
165
167
coro = await self ._tsignal_task_queue .get ()
168
+
166
169
try :
167
170
await coro
168
171
except Exception as e :
@@ -176,12 +179,14 @@ async def _process_queue(self):
176
179
177
180
async def start_queue (self ):
178
181
"""Start the task queue processing. Returns the queue task."""
182
+
179
183
self ._tsignal_process_queue_task = asyncio .create_task (
180
184
self ._process_queue ()
181
185
)
182
186
183
187
def queue_task (self , coro ):
184
188
"""Method to add a task to the queue"""
189
+
185
190
if not asyncio .iscoroutine (coro ):
186
191
logger .error (
187
192
"[WorkerClass][queue_task] Task must be a coroutine object: %s" ,
@@ -196,6 +201,7 @@ def queue_task(self, coro):
196
201
197
202
def start (self , * args , ** kwargs ):
198
203
"""Start the worker thread."""
204
+
199
205
run_coro = kwargs .pop (_WorkerConstants .RUN_CORO , None )
200
206
201
207
if run_coro is not None and not asyncio .iscoroutine (run_coro ):
@@ -207,6 +213,7 @@ def start(self, *args, **kwargs):
207
213
208
214
def thread_main ():
209
215
"""Thread main function."""
216
+
210
217
self ._tsignal_task_queue = asyncio .Queue ()
211
218
212
219
with self ._tsignal_lifecycle_lock :
@@ -215,6 +222,7 @@ def thread_main():
215
222
216
223
async def runner ():
217
224
"""Runner function."""
225
+
218
226
self .started .emit ()
219
227
220
228
if run_coro is not None :
@@ -324,4 +332,9 @@ def move_to_thread(self, target):
324
332
self ._tsignal_affinity ,
325
333
)
326
334
335
+ async def wait_for_stop (self ):
336
+ """Wait for the worker to stop."""
337
+
338
+ await self ._tsignal_stopping .wait ()
339
+
327
340
return WorkerClass
0 commit comments