@@ -171,22 +171,10 @@ def free(self, request_id: str) -> None:
171
171
self .num_cached_block .pop (request_id , None )
172
172
173
173
@abstractmethod
174
- def get_num_common_prefix_blocks (self , request_id : str ,
175
- num_running_requests : int ) -> int :
176
- """
177
- Get the number of common prefix blocks for all requests in the RUNNING
178
- state.
179
-
180
- Args:
181
- request_id: The request ID.
182
- num_running_requests: The total number of requests in the RUNNING
183
- state.
184
-
185
- Returns:
186
- The number of common prefix blocks for all requests in the RUNNING
187
- state.
188
- """
189
-
174
+ def get_num_common_prefix_blocks (
175
+ self , running_request_id : str , running_request_ids : list [str ],
176
+ transfering_request_ids : list [str ]) -> int :
177
+ """Get the number of common prefix blocks for all running requests."""
190
178
raise NotImplementedError
191
179
192
180
@classmethod
@@ -289,15 +277,34 @@ def remove_skipped_blocks(self, request_id: str,
289
277
# No need to remove blocks for full attention.
290
278
pass
291
279
292
- def get_num_common_prefix_blocks (self , request_id : str ,
293
- num_running_requests : int ) -> int :
294
- blocks = self .req_to_blocks [request_id ]
280
+ def get_num_common_prefix_blocks (
281
+ self , running_request_id : str , running_request_ids : list [str ],
282
+ transfering_request_ids : list [str ]) -> int :
283
+ """Get common prefix blocks shared by all running and transferring
284
+ requests."""
285
+ if running_request_id not in self .req_to_blocks :
286
+ return 0
287
+
288
+ all_request_ids = running_request_ids + transfering_request_ids
289
+ request_blocks = [
290
+ self .req_to_blocks [req_id ] for req_id in all_request_ids
291
+ if req_id in self .req_to_blocks
292
+ ]
293
+
294
+ reference_blocks = self .req_to_blocks [running_request_id ]
295
+ total_requests = len (all_request_ids )
296
+
295
297
num_common_blocks = 0
296
- for block in blocks :
297
- if block .ref_cnt == num_running_requests :
298
+ for i , ref_block in enumerate (reference_blocks ):
299
+ requests_with_block = sum (
300
+ 1 for blocks in request_blocks if i < len (blocks )
301
+ and blocks [i ].block_id == ref_block .block_id )
302
+
303
+ if requests_with_block == total_requests :
298
304
num_common_blocks += 1
299
305
else :
300
306
break
307
+
301
308
return num_common_blocks
302
309
303
310
@@ -390,8 +397,12 @@ def remove_skipped_blocks(self, request_id: str,
390
397
blocks [i ] = self ._null_block
391
398
self .block_pool .free_blocks (removed_blocks )
392
399
393
- def get_num_common_prefix_blocks (self , request_id : str ,
394
- num_running_requests : int ) -> int :
400
+ def get_num_common_prefix_blocks (
401
+ self ,
402
+ running_request_id : str ,
403
+ running_request_ids : list [str ],
404
+ transfering_request_ids : list [str ],
405
+ ) -> int :
395
406
"""
396
407
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
397
408
So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -518,8 +529,12 @@ def remove_skipped_blocks(self, request_id: str,
518
529
blocks [i ] = self ._null_block
519
530
self .block_pool .free_blocks (removed_blocks )
520
531
521
- def get_num_common_prefix_blocks (self , request_id : str ,
522
- num_running_requests : int ) -> int :
532
+ def get_num_common_prefix_blocks (
533
+ self ,
534
+ running_request_id : str ,
535
+ running_request_ids : list [str ],
536
+ transfering_request_ids : list [str ],
537
+ ) -> int :
523
538
"""
524
539
cascade attention is not supported by chunked local attention.
525
540
"""
@@ -555,8 +570,12 @@ def remove_skipped_blocks(self, request_id: str,
555
570
# remove blocks.
556
571
pass
557
572
558
- def get_num_common_prefix_blocks (self , request_id : str ,
559
- num_running_requests : int ) -> int :
573
+ def get_num_common_prefix_blocks (
574
+ self ,
575
+ running_request_id : str ,
576
+ running_request_ids : list [str ],
577
+ transfering_request_ids : list [str ],
578
+ ) -> int :
560
579
return 0
561
580
562
581
def get_num_blocks_to_allocate (
@@ -618,8 +637,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
618
637
# requests, so this method is not relevant.
619
638
raise ValueError ("Should not be called as prefix caching is disabled." )
620
639
621
- def get_num_common_prefix_blocks (self , request_id : str ,
622
- num_running_requests : int ) -> int :
640
+ def get_num_common_prefix_blocks (
641
+ self , running_request_id : str , running_request_ids : list [str ],
642
+ transfering_request_ids : list [str ]) -> int :
623
643
# Cross-attention blocks contain request-specific encoder states
624
644
# and are not shared between different requests
625
645
return 0
0 commit comments