Skip to content

Conversation

ayushsatyam146
Copy link
Contributor

@ayushsatyam146 ayushsatyam146 commented Aug 24, 2025

Purpose

Solves #23130. This change fixes a critical bug in vLLM's cascade attention optimization in the V1 arch. The bug is in get_num_common_prefix_blocks(), which determines how many KV cache blocks are shared among all currently running requests to enable cascade attention optimizations.

Changes made

  • Replace ref_cnt-based common prefix detection with running request tracking
  • Update get_num_common_prefix_blocks() to accept running_request_ids set
  • Fix FullAttentionManager to count actual references from running requests
  • Prevent incorrect cascade attention when async KV offloading delays cleanup

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly addresses a critical bug in cascade attention related to asynchronous KV transfer by replacing the unreliable ref_cnt-based logic with explicit tracking of running requests. The changes are well-contained and logically sound. My review includes one suggestion to optimize the performance of the new common prefix block calculation, which could be a bottleneck in scenarios with many concurrent requests.

Copy link
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even after block in self.req_to_blocks[req_id] is fixed, I'm still concern about the performance when all requests are sharing a very long prefix. The time complexity is num_requests x num_blocks_per_request. What about passing in the requests that are not running but are during kv transfer?

@ayushsatyam146 ayushsatyam146 force-pushed the kv-cache-fix branch 4 times, most recently from 542d108 to 8adac43 Compare August 29, 2025 04:48
@ayushsatyam146
Copy link
Contributor Author

Hi @heheda12345 @njhill The time complexity of the new code is O(RxB) now, which was O(RxB²) in the previous iteration. I have one caching based implementation as well in mind which will bring down the complexity to O(1) best case and O(RxB) worst case. But that makes the code a little complex for this module hence I did not want to push that version without someone's approval. PTAL if this is fine or if we need to improve this further? Thanks!

@heheda12345
Copy link
Collaborator

My example code is O((num_transfering_request+1) * num_common_blocks). It should be much faster than num_running_request * num_common_blocks for short requests.

@ayushsatyam146
Copy link
Contributor Author

Hi @heheda12345 I did the changes your way this time and have pushed it as well. Please take a look, Thanks!

@ayushsatyam146 ayushsatyam146 force-pushed the kv-cache-fix branch 4 times, most recently from 64ed09d to 0d66b57 Compare September 2, 2025 03:53
@ayushsatyam146
Copy link
Contributor Author

Hi @heheda12345 just a gentle reminder to please take a look and approve if everything is right. Thanks!

@heheda12345
Copy link
Collaborator

@ayushsatyam146 Hi, can you help to update this PR?

@ayushsatyam146
Copy link
Contributor Author

Hi @heheda12345 sorry I got sick this week and couldn't work on this. But I am good now and will update this soon, Thanks for the patience.

@ayushsatyam146
Copy link
Contributor Author

@heheda12345, I tried to address all your concerns. Can you please take a look now, Thanks!

* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers
still contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: ayushsatyam146 <ayushsatyam146@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants