mirror of
https://github.com/kvcache-ai/custom_flashinfer.git
synced 2026-06-29 10:47:12 +00:00
## 📌 Description

This PR implements a block sparse attention wrapper that calls the underlying FA2 and FA3 kernel implementations. It supports:

1. Variable block sizes for block sparse attention. Recent sparse attention algorithms (e.g., https://arxiv.org/abs/2505.18875) require sparse patterns with variable block sizes instead of fixed ones. The following figure illustrates one use case.

   <img width="627" alt="flashinfer-variable-block-sparse" src="https://github.com/user-attachments/assets/21a88440-a6e5-4829-abe1-dabc0b8b6310" />

   This PR provides a straightforward API: `block_mask_map [num_kv_heads, MB, NB]` is a boolean tensor marking which blocks are activated; `block_row_sz [num_kv_heads, MB]` gives the size of each row block; `block_col_sz [num_kv_heads, NB]` gives the size of each column block.

2. Variable sparse patterns for each attention head. Users can specify different block sparse maps for different `kv_head_idx` by providing `[num_kv_heads, ...]` metadata. GQA is supported by using the same attention map for all `gqa_group_size` qo_heads that share a KV head.

Perf Benchmarks: All benchmarks are conducted with `num_blocks_row=20` and `num_blocks_col=50` (without considering `plan` overhead). Scripts are available at `benchmarks/bench_block_sparse_attention.py`. Sparse FA3 achieves near-theoretical speedup at high sparsity but falls behind at low sparsity because of sparse loading overhead.

<img width="682" alt="image" src="https://github.com/user-attachments/assets/1433d65c-1381-4742-bc09-22ed989997ef" />

## 🔍 Related Issues

This PR may relate to https://github.com/flashinfer-ai/flashinfer/issues/1091 and https://github.com/flashinfer-ai/flashinfer/issues/886.
## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

New unit tests are added to `tests/test_block_sparse.py`. All passed.

<img width="740" alt="image" src="https://github.com/user-attachments/assets/3d3f3274-818b-418f-b299-0471fbda1e57" />

## Reviewer Notes

cc @yzh119 @andy-yang-1

---------

Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
27 lines
708 B
ReStructuredText
.. _apisparse:

flashinfer.sparse
=================

Kernels for block sparse flashattention.

.. currentmodule:: flashinfer.sparse

.. autoclass:: BlockSparseAttentionWrapper
    :members:
    :exclude-members: begin_forward, end_forward, forward, forward_return_lse

    .. automethod:: __init__

.. autoclass:: VariableBlockSparseAttentionWrapper
    :members:
    :exclude-members: begin_forward, end_forward, forward, forward_return_lse

    .. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/examples/flashinfer-variable-block-sparse.png
        :width: 600
        :alt: variable block sparse attention plan function diagram
        :align: center

    .. automethod:: __init__
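
The variable block sparse layout used by ``VariableBlockSparseAttentionWrapper`` can be easier to follow with a small reference model. The sketch below is plain Python and does not call FlashInfer. It expands one KV head's ``block_mask_map``, ``block_row_sz`` and ``block_col_sz`` into a dense token-level mask and computes the fraction of the attention matrix that is actually evaluated. The helper names ``expand_block_mask``, ``density`` and ``kv_head_for`` are hypothetical, and the contiguous grouping of qo heads per KV head is an assumption about the GQA layout.

```python
from itertools import accumulate


def expand_block_mask(block_mask_map, block_row_sz, block_col_sz):
    """Expand one KV head's block-level mask into a dense token-level mask.

    block_mask_map: MB x NB nested list of bools (True = block is attended).
    block_row_sz:   MB row-block sizes (query tokens per row block).
    block_col_sz:   NB column-block sizes (key/value tokens per column block).
    """
    # Prefix sums give the token offset at which each block starts.
    row_starts = [0, *accumulate(block_row_sz)]
    col_starts = [0, *accumulate(block_col_sz)]
    dense = [[False] * col_starts[-1] for _ in range(row_starts[-1])]
    for i, row in enumerate(block_mask_map):
        for j, active in enumerate(row):
            if not active:
                continue
            for r in range(row_starts[i], row_starts[i + 1]):
                for c in range(col_starts[j], col_starts[j + 1]):
                    dense[r][c] = True
    return dense


def density(block_mask_map, block_row_sz, block_col_sz):
    """Fraction of (query, key) pairs covered by active blocks."""
    active = sum(
        block_row_sz[i] * block_col_sz[j]
        for i, row in enumerate(block_mask_map)
        for j, on in enumerate(row)
        if on
    )
    return active / (sum(block_row_sz) * sum(block_col_sz))


def kv_head_for(qo_head_idx, gqa_group_size):
    """Assumed GQA layout: consecutive qo heads share one KV head's mask."""
    return qo_head_idx // gqa_group_size


# Toy metadata for a single KV head: MB = 2 row blocks, NB = 3 column blocks.
block_mask_map = [[True, False, True],
                  [False, True, False]]
block_row_sz = [1, 2]      # 3 query tokens in total
block_col_sz = [2, 1, 1]   # 4 key/value tokens in total

dense = expand_block_mask(block_mask_map, block_row_sz, block_col_sz)
for row in dense:
    print("".join("x" if v else "." for v in row))
print("density:", density(block_mask_map, block_row_sz, block_col_sz))
```

Under this model, an idealized upper bound on speedup over dense attention is ``1 / density``. Real kernels add sparse loading overhead on top of that bound.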