Skip to content

Commit b11852c

Browse files
fix(v1/kv_cache): resolve async KV transfer bug in cascade attention
* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers still contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: ayushsatyam146 <ayushsatyam146@gmail.com>
1 parent 0eecb31 commit b11852c

File tree

4 files changed

+90
-75
lines changed

4 files changed

+90
-75
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,27 +138,29 @@ def free(self, request_id: str) -> None:
138138
for manager in self.single_type_managers:
139139
manager.free(request_id)
140140

141-
def get_num_common_prefix_blocks(self, request_id: str,
142-
num_running_requests: int) -> list[int]:
141+
def get_num_common_prefix_blocks(
142+
self, running_request_id: str, running_request_ids: list[str],
143+
transfering_request_ids: list[str]) -> list[int]:
143144
"""
144145
Get the number of common prefix blocks for all requests in the RUNNING
145-
state for each kv cache group.
146+
and TRANSFERING state for each kv cache group.
146147
147148
Args:
148-
request_id: The request ID.
149-
num_running_requests: The total number of requests in the RUNNING
150-
state.
149+
running_request_id: The request ID of the running request.
150+
running_request_ids: List of all request IDs in the RUNNING state.
151+
transfering_request_ids: List of request IDs in
152+
WAITING_FOR_REMOTE_KVS state.
151153
152154
Returns:
153155
list[int]: The number of common prefix blocks for all requests in
154156
the RUNNING state for each kv cache group.
155157
"""
156-
num_blocks_per_group = [
157-
manager.get_num_common_prefix_blocks(request_id,
158-
num_running_requests)
158+
return [
159+
manager.get_num_common_prefix_blocks(running_request_id,
160+
running_request_ids,
161+
transfering_request_ids)
159162
for manager in self.single_type_managers
160163
]
161-
return num_blocks_per_group
162164

163165
def remove_skipped_blocks(self, request_id: str,
164166
num_computed_tokens: int) -> None:
@@ -209,8 +211,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
209211
dcp_world_size=dcp_world_size)
210212
self.num_single_type_manager = len(self.single_type_managers)
211213

212-
def get_num_common_prefix_blocks(self, request_id: str,
213-
num_running_requests: int) -> list[int]:
214+
def get_num_common_prefix_blocks(
215+
self, running_request_id: str, running_request_ids: list[str],
216+
transfering_request_ids: list[str]) -> list[int]:
214217
return [0] * self.num_single_type_manager
215218

216219
def find_longest_cache_hit(

vllm/v1/core/kv_cache_manager.py

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from vllm.v1.core.kv_cache_utils import KVCacheBlock
1111
from vllm.v1.kv_cache_interface import KVCacheConfig
1212
from vllm.v1.metrics.stats import PrefixCacheStats
13-
from vllm.v1.request import Request, RequestStatus
13+
from vllm.v1.request import Request
1414

1515
logger = init_logger(__name__)
1616

@@ -331,46 +331,29 @@ def reset_prefix_cache(self) -> bool:
331331

332332
def get_num_common_prefix_blocks(
333333
self,
334-
request: Request,
335-
num_running_requests: int,
334+
running_request_id: str,
335+
running_request_ids: list[str],
336+
transfering_request_ids: list[str],
336337
) -> list[int]:
337338
"""Calculate the number of common prefix blocks shared by all requests
338-
in the RUNNING state for each kv cache group.
339-
340-
The function determines this by selecting any request and iterating
341-
through its blocks. A block is considered a common prefix block if its
342-
`ref_cnt` equals the total number of requests in the RUNNING state.
343-
344-
NOTE(woosuk): The number of requests in the RUNNING state is **greater
345-
than or equal to** the number of requests scheduled in the current step.
346-
This is because the RUNNING state only indicates that:
347-
1. The request has not yet finished, and
348-
2. The request holds its blocks unfreed.
349-
350-
While all scheduled requests must be in the RUNNING state, the inverse
351-
is not necessarily true. There may be RUNNING requests that are not
352-
scheduled in the current step.
339+
in the RUNNING state for each kv cache group. A block is considered a
340+
common prefix block if it is referenced by ALL currently running
341+
requests.
353342
354-
This can result in an edge case where the number of common prefix blocks
355-
is 0, even though all scheduled requests share a common prefix. This
356-
occurs because there may be unscheduled RUNNING requests that do not
357-
share the common prefix. Currently, this case cannot be easily detected,
358-
so the function returns 0 in such cases.
343+
This approach correctly handles async KV offloading scenarios where
344+
completed requests may still hold block references while no longer
345+
being in the RUNNING state.
359346
360347
Args:
361-
request: Any request in the RUNNING state, used to identify the
362-
common prefix blocks.
363-
num_running_requests: The total number of requests in the RUNNING
364-
state. This can be different from the number of scheduled
365-
requests in the current step.
348+
running_request_id: The request ID of the running request.
349+
running_request_ids: List of all request IDs in the RUNNING state.
350+
transfering_request_ids: List of request IDs in transfer state.
366351
367352
Returns:
368-
list[int]: The number of common prefix blocks for each kv cache
369-
group.
353+
list[int]: Number of common prefix blocks for each kv cache group.
370354
"""
371-
assert request.status == RequestStatus.RUNNING
372355
return self.coordinator.get_num_common_prefix_blocks(
373-
request.request_id, num_running_requests)
356+
running_request_id, running_request_ids, transfering_request_ids)
374357

375358
def take_events(self) -> list[KVCacheEvent]:
376359
"""Take the KV cache events from the block pool.
@@ -396,4 +379,4 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
396379
def create_empty_block_list(self) -> KVCacheBlocks:
397380
"""Creates a new KVCacheBlocks instance with no blocks."""
398381
return KVCacheBlocks(tuple([]
399-
for _ in range(self.num_kv_cache_groups)))
382+
for _ in range(self.num_kv_cache_groups)))

vllm/v1/core/sched/scheduler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,18 @@ def schedule(self) -> SchedulerOutput:
561561
self.kv_cache_config.kv_cache_groups)
562562
if self.running:
563563
any_request = self.running[0]
564+
running_request_ids = {req.request_id for req in self.running}
565+
566+
# Include requests in KV transfer state for common prefix calc
567+
transferring_request_ids = [
568+
req_id for req_id, request in self.requests.items()
569+
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS and
570+
any(self.kv_cache_manager.get_blocks(req_id).get_block_ids())
571+
]
564572
num_common_prefix_blocks = (
565573
self.kv_cache_manager.get_num_common_prefix_blocks(
566-
any_request, len(self.running)))
574+
any_request.request_id, list(running_request_ids),
575+
transferring_request_ids))
567576

568577
# Construct the scheduler output.
569578
new_reqs_data = [

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,10 @@ def free(self, request_id: str) -> None:
171171
self.num_cached_block.pop(request_id, None)
172172

173173
@abstractmethod
174-
def get_num_common_prefix_blocks(self, request_id: str,
175-
num_running_requests: int) -> int:
176-
"""
177-
Get the number of common prefix blocks for all requests in the RUNNING
178-
state.
179-
180-
Args:
181-
request_id: The request ID.
182-
num_running_requests: The total number of requests in the RUNNING
183-
state.
184-
185-
Returns:
186-
The number of common prefix blocks for all requests in the RUNNING
187-
state.
188-
"""
189-
174+
def get_num_common_prefix_blocks(
175+
self, running_request_id: str, running_request_ids: list[str],
176+
transfering_request_ids: list[str]) -> int:
177+
"""Get the number of common prefix blocks for all running requests."""
190178
raise NotImplementedError
191179

192180
@classmethod
@@ -289,15 +277,34 @@ def remove_skipped_blocks(self, request_id: str,
289277
# No need to remove blocks for full attention.
290278
pass
291279

292-
def get_num_common_prefix_blocks(self, request_id: str,
293-
num_running_requests: int) -> int:
294-
blocks = self.req_to_blocks[request_id]
280+
def get_num_common_prefix_blocks(
281+
self, running_request_id: str, running_request_ids: list[str],
282+
transfering_request_ids: list[str]) -> int:
283+
"""Get common prefix blocks shared by all running and transferring
284+
requests."""
285+
if running_request_id not in self.req_to_blocks:
286+
return 0
287+
288+
all_request_ids = running_request_ids + transfering_request_ids
289+
request_blocks = [
290+
self.req_to_blocks[req_id] for req_id in all_request_ids
291+
if req_id in self.req_to_blocks
292+
]
293+
294+
reference_blocks = self.req_to_blocks[running_request_id]
295+
total_requests = len(all_request_ids)
296+
295297
num_common_blocks = 0
296-
for block in blocks:
297-
if block.ref_cnt == num_running_requests:
298+
for i, ref_block in enumerate(reference_blocks):
299+
requests_with_block = sum(
300+
1 for blocks in request_blocks if i < len(blocks)
301+
and blocks[i].block_id == ref_block.block_id)
302+
303+
if requests_with_block == total_requests:
298304
num_common_blocks += 1
299305
else:
300306
break
307+
301308
return num_common_blocks
302309

303310

@@ -390,8 +397,12 @@ def remove_skipped_blocks(self, request_id: str,
390397
blocks[i] = self._null_block
391398
self.block_pool.free_blocks(removed_blocks)
392399

393-
def get_num_common_prefix_blocks(self, request_id: str,
394-
num_running_requests: int) -> int:
400+
def get_num_common_prefix_blocks(
401+
self,
402+
running_request_id: str,
403+
running_request_ids: list[str],
404+
transfering_request_ids: list[str],
405+
) -> int:
395406
"""
396407
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
397408
So it's not correct to count ref_cnt like FullAttentionManager. Return
@@ -518,8 +529,12 @@ def remove_skipped_blocks(self, request_id: str,
518529
blocks[i] = self._null_block
519530
self.block_pool.free_blocks(removed_blocks)
520531

521-
def get_num_common_prefix_blocks(self, request_id: str,
522-
num_running_requests: int) -> int:
532+
def get_num_common_prefix_blocks(
533+
self,
534+
running_request_id: str,
535+
running_request_ids: list[str],
536+
transfering_request_ids: list[str],
537+
) -> int:
523538
"""
524539
cascade attention is not supported by chunked local attention.
525540
"""
@@ -555,8 +570,12 @@ def remove_skipped_blocks(self, request_id: str,
555570
# remove blocks.
556571
pass
557572

558-
def get_num_common_prefix_blocks(self, request_id: str,
559-
num_running_requests: int) -> int:
573+
def get_num_common_prefix_blocks(
574+
self,
575+
running_request_id: str,
576+
running_request_ids: list[str],
577+
transfering_request_ids: list[str],
578+
) -> int:
560579
return 0
561580

562581
def get_num_blocks_to_allocate(
@@ -618,8 +637,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
618637
# requests, so this method is not relevant.
619638
raise ValueError("Should not be called as prefix caching is disabled.")
620639

621-
def get_num_common_prefix_blocks(self, request_id: str,
622-
num_running_requests: int) -> int:
640+
def get_num_common_prefix_blocks(
641+
self, running_request_id: str, running_request_ids: list[str],
642+
transfering_request_ids: list[str]) -> int:
623643
# Cross-attention blocks contain request-specific encoder states
624644
# and are not shared between different requests
625645
return 0

0 commit comments

Comments
 (0)