custom_flashinfer/docs/api/sparse.rst
Yilong Zhao 4ec2116e58 [feat] support block sparse attention w/ variable block sizes and head-wise sparse patterns (#1177)

## 📌 Description

This PR implements a block sparse attention wrapper that calls the
underlying FA2 and FA3 kernel implementations and supports:
1. Variable block sizes: recent sparse attention algorithms (e.g.,
https://arxiv.org/abs/2505.18875) call for sparse patterns with
variable rather than fixed block sizes. The following figure
illustrates one use case.
<img width="627" alt="flashinfer-variable-block-sparse"
src="https://github.com/user-attachments/assets/21a88440-a6e5-4829-abe1-dabc0b8b6310"
/>


This PR provides a straightforward API: `block_mask_map [num_kv_heads,
MB, NB]` is a boolean tensor marking which blocks are active (i.e.,
computed); `block_row_sz [num_kv_heads, MB]` gives the size of each
row block; and `block_col_sz [num_kv_heads, NB]` gives the size of each
column block. MB and NB are the numbers of row and column blocks.
2. Head-wise sparse patterns: users can specify a different block
sparse map for each `kv_head_idx` by providing `[num_kv_heads, ...]`
metadata. GQA is supported: the `gqa_group_size` query heads that share
a KV head also share its block sparse map.
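
The metadata in the list above can be illustrated with a small pure-Python sketch. It expands one KV head's block mask into a token-level mask and looks up which KV head (and therefore which map) a query head uses. It assumes that `block_row_sz[h]` sums to the query length, that `block_col_sz[h]` sums to the KV length, and that GQA groups consecutive query heads (`kv_head_idx = qo_head_idx // gqa_group_size`). The helpers `expand_block_mask` and `kv_head_for_qo_head` are hypothetical and are not part of the FlashInfer API.

```python
from typing import List


def expand_block_mask(
    block_mask: List[List[bool]],  # [MB][NB] block mask of ONE kv head
    block_row_sz: List[int],       # [MB] row-block sizes (query tokens)
    block_col_sz: List[int],       # [NB] column-block sizes (key/value tokens)
) -> List[List[bool]]:
    """Expand one kv head's block mask to a [qo_len][kv_len] token-level mask.

    Assumes sum(block_row_sz) == qo_len and sum(block_col_sz) == kv_len.
    """
    if len(block_mask) != len(block_row_sz):
        raise ValueError("block_mask must have one row per row block")
    token_mask: List[List[bool]] = []
    for mask_row, row_sz in zip(block_mask, block_row_sz):
        if len(mask_row) != len(block_col_sz):
            raise ValueError("block_mask must have one column per column block")
        # Repeat each block's flag once per key/value token in that column block.
        token_row = [flag for flag, col_sz in zip(mask_row, block_col_sz)
                     for _ in range(col_sz)]
        # All query tokens of one row block share the same key pattern.
        token_mask.extend(list(token_row) for _ in range(row_sz))
    return token_mask


def kv_head_for_qo_head(qo_head_idx: int, num_qo_heads: int, num_kv_heads: int) -> int:
    """Return the kv head (and hence the block map) used by a query head.

    Assumes GQA groups consecutive query heads, with
    gqa_group_size = num_qo_heads // num_kv_heads.
    """
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    gqa_group_size = num_qo_heads // num_kv_heads
    return qo_head_idx // gqa_group_size


# Example: one kv head with MB = 2 row blocks and NB = 3 column blocks.
mask = expand_block_mask(
    [[True, False, True],
     [False, True, False]],
    block_row_sz=[1, 2],     # qo_len = 3
    block_col_sz=[2, 1, 3],  # kv_len = 6
)
print(mask[0])  # [True, True, False, True, True, True]
print(mask[2])  # [False, False, True, False, False, False]
print(kv_head_for_qo_head(5, num_qo_heads=8, num_kv_heads=2))  # 1
```

Under these assumptions, the map for query head `h` is found by indexing all three `[num_kv_heads, ...]` tensors with `kv_head_for_qo_head(h, ...)`. Each head may therefore have its own row and column partition as well as its own mask.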

Performance benchmarks:
All benchmarks use `num_blocks_row=20` and `num_blocks_col=50` and
exclude `plan` overhead. Scripts are available at
`benchmarks/bench_block_sparse_attention.py`. Sparse FA3 achieves
near-theoretical speedup at high sparsity but falls behind at low
sparsity because of sparse-loading overhead.
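
The note above refers to a theoretical speedup. One plausible reading is the inverse of the fraction of the attention score matrix covered by active blocks. This reading is an assumption, not something taken from the benchmark script. With variable block sizes, each active block has to be weighted by its area `row_sz * col_sz` rather than simply counted. The helpers `active_fraction` and `ideal_speedup` below are hypothetical.

```python
from typing import List


def active_fraction(
    block_mask: List[List[bool]],  # [MB][NB] block mask of one kv head
    block_row_sz: List[int],       # [MB] row-block sizes
    block_col_sz: List[int],       # [NB] column-block sizes
) -> float:
    """Fraction of the qo_len x kv_len score matrix covered by active blocks.

    With variable block sizes a block contributes row_sz * col_sz scores,
    so counting active blocks alone would misestimate the work.
    """
    total = sum(block_row_sz) * sum(block_col_sz)
    active = sum(
        row_sz * col_sz
        for mask_row, row_sz in zip(block_mask, block_row_sz)
        for flag, col_sz in zip(mask_row, block_col_sz)
        if flag
    )
    return active / total


def ideal_speedup(fraction: float) -> float:
    """Speedup over dense attention if cost were proportional to active area
    (sparse-loading and plan overheads ignored)."""
    if fraction <= 0.0:
        raise ValueError("at least one block must be active")
    return 1.0 / fraction


# One small block active in a 4 x 4 score matrix: 2 of 16 scores.
frac = active_fraction([[True, False], [False, False]], [1, 3], [2, 2])
print(frac)                 # 0.125
print(ideal_speedup(frac))  # 8.0
```

Because this bound ignores the cost of loading sparse blocks, measured speedups at low sparsity can fall well short of it.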

<img width="682" alt="image"
src="https://github.com/user-attachments/assets/1433d65c-1381-4742-bc09-22ed989997ef"
/>


## 🔍 Related Issues

This PR may relate to
https://github.com/flashinfer-ai/flashinfer/issues/1091 and
https://github.com/flashinfer-ai/flashinfer/issues/886.
## 🚀 Pull Request Checklist


### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.


## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

New unit tests were added to `tests/test_block_sparse.py`, and all of them pass.
<img width="740" alt="image"
src="https://github.com/user-attachments/assets/3d3f3274-818b-418f-b299-0471fbda1e57"
/>
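
A natural oracle for such tests is to compare each query head against dense softmax attention restricted to the tokens its block mask activates. This approach is assumed here, not taken from `tests/test_block_sparse.py`. The following pure-Python sketch shows that reference for a single head. `masked_attention_reference` is a hypothetical helper. The `1/sqrt(head_dim)` scale, the absence of causal masking, and the zero output for fully masked rows are all assumptions of the sketch.

```python
import math
from typing import List

Matrix = List[List[float]]


def masked_attention_reference(
    q: Matrix,                     # [qo_len][head_dim] queries of one qo head
    k: Matrix,                     # [kv_len][head_dim] keys of its kv head
    v: Matrix,                     # [kv_len][head_dim] values of its kv head
    token_mask: List[List[bool]],  # [qo_len][kv_len]; True = attend
) -> Matrix:
    """Dense softmax attention restricted to the positions allowed by token_mask."""
    head_dim = len(q[0])
    sm_scale = 1.0 / math.sqrt(head_dim)  # assumed scale; kernels may accept a custom one
    out: Matrix = []
    for q_row, mask_row in zip(q, token_mask):
        # Scaled dot-product scores for the active key positions only.
        scores = [
            (j, sm_scale * sum(a * b for a, b in zip(q_row, k[j])))
            for j, keep in enumerate(mask_row)
            if keep
        ]
        if not scores:
            # Fully masked query row: this sketch returns zeros (an assumption).
            out.append([0.0] * head_dim)
            continue
        row_max = max(s for _, s in scores)  # subtract the max for numerical stability
        weights = [(j, math.exp(s - row_max)) for j, s in scores]
        denom = sum(w for _, w in weights)
        out.append([sum(w * v[j][d] for j, w in weights) / denom for d in range(head_dim)])
    return out


q = [[1.0, 0.0], [0.5, 0.5]]
k = [[0.0, 0.0], [0.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(masked_attention_reference(q, k, v, [[True, False], [True, True]]))
# [[1.0, 2.0], [2.0, 3.0]]
```

To compare against a kernel, one would expand the block mask of the matching KV head into a token-level mask, run this reference per query head, and compare with a tolerance suited to fp16/bf16 outputs.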

## Reviewer Notes
cc @yzh119 @andy-yang-1

---------

Co-authored-by: happierpig <zhaoyilong217@sjtu.edn.cn>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
2025-06-26 17:37:57 -07:00


.. _apisparse:

flashinfer.sparse
=================

Kernels for block-sparse FlashAttention.

.. currentmodule:: flashinfer.sparse

.. autoclass:: BlockSparseAttentionWrapper
    :members:
    :exclude-members: begin_forward, end_forward, forward, forward_return_lse

    .. automethod:: __init__

.. autoclass:: VariableBlockSparseAttentionWrapper
    :members:
    :exclude-members: begin_forward, end_forward, forward, forward_return_lse

    .. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/examples/flashinfer-variable-block-sparse.png
        :width: 600
        :alt: variable block sparse attention plan function diagram
        :align: center

    .. automethod:: __init__