mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
sparse_attn: wire -mask and -attention_sink (block-map prune + attn mask)
This commit is contained in:
@@ -9,11 +9,11 @@ Implemented:
|
||||
- top-k / `cdfthreshd` block selection, BlockMap LUT
|
||||
- sparse FMHA (both `vsa` and `jenga` backends)
|
||||
- per-head `topk` / `simthreshd1` / `cdfthreshd`
|
||||
- **is_causal mask in pooled score** (top-left only at block-map grain) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
|
||||
- **attention_sink** — block-map column 0 force-on ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
|
||||
|
||||
Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):
|
||||
- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
|
||||
- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
|
||||
- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
|
||||
- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
|
||||
- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))
|
||||
|
||||
@@ -40,7 +40,11 @@ ninja tile_example_sparge
|
||||
|
||||
Select a PV-skip variant with `-pv_mode={none|warp|block}` (default `warp`); finite `-pv_threshold=20` lets the per-Q-tile skip predicate fire.
|
||||
|
||||
Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k.
|
||||
Mask + attention sink:
|
||||
- `-mask` accepts the `01_fmha` grammar (`0` / `t` / `b` / `t:l,r` / `b:l,r` / `xt:N` / `xb:N` / `g:y,x`, default `0`). The block-map selection prunes past-diagonal blocks only under `mask_top_left` (`t`); `b` / SWA / generic are forwarded to the attention kernel and trigger a stderr WARN that the block-map selection is unchanged.
|
||||
- `-attention_sink={0|1}` forces block-map column `kb=0` ON for every Q-block (default `0`). Under `-mask=t` this is degenerate since `kb=0` is always causal-valid.
|
||||
|
||||
Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k. When `-mask` is anything other than `0`, or `-attention_sink=1`, the `[block_map cross-check]` and `[VSA LUT self-consistency]` cells are SKIPPED (the CPU reference does not model causal mask or sink); the `[attention output]` cell still runs but the dense reference applies no mask, so it will report FAIL on the kernel-correct output. In those configurations `-v=1` therefore gives no usable correctness verdict: the block-map checks are skipped and an `[attention output]` FAIL is expected.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -56,6 +56,11 @@ struct sparge_blockmap_args
|
||||
const float* simthreshd1_per_head_ptr = nullptr;
|
||||
const float* cdfthreshd_per_head_ptr = nullptr;
|
||||
const float* topk_per_head_ptr = nullptr;
|
||||
|
||||
// R32 Items 2+3. Pipeline only honours mask_enum::mask_top_left; CLI warns
|
||||
// on other types. Defaults preserve back-compat for callers that do not set these fields yet.
|
||||
mask_enum mask_type = mask_enum::no_mask;
|
||||
bool attention_sink = false;
|
||||
};
|
||||
|
||||
struct sparge_blockmap_workspace_layout
|
||||
@@ -105,7 +110,9 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args,
|
||||
pooled_k_ws_ptr,
|
||||
sim_k_ws_ptr,
|
||||
args.topk_per_head_ptr,
|
||||
args.cdfthreshd_per_head_ptr);
|
||||
args.cdfthreshd_per_head_ptr,
|
||||
static_cast<ck_tile::index_t>(args.mask_type),
|
||||
args.attention_sink);
|
||||
|
||||
dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "01_fmha/mask.hpp" // R32: mask_info::decode, mask_enum
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "sparge_tool.hpp"
|
||||
@@ -126,7 +127,23 @@ auto create_args(int argc, char* argv[])
|
||||
"none = no skip (kNone binary; matches VSA baseline). "
|
||||
"warp = per-wavefront butterfly vote (R25 A1; default). "
|
||||
"block = per-block AND vote via 1 LDS slot + block_sync_lds (R30). "
|
||||
"Overrides -pv_skip_compile when set explicitly.");
|
||||
"Overrides -pv_skip_compile when set explicitly.")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, "
|
||||
"window_size negative is causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, "
|
||||
"window_size negative is causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size "
|
||||
"(only debug purpose for now)")
|
||||
.insert("attention_sink",
|
||||
"0",
|
||||
"SpargeAttn: force block-map column 0 ON (kb=0 always selected). "
|
||||
"0=off, 1=on. Block-map level only; independent of -mask sink prefix.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -161,6 +178,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
int pv_skip_compile = arg_parser.get_int("pv_skip_compile");
|
||||
std::string pv_per_head_s = arg_parser.get_str("pv_threshold_per_head");
|
||||
std::string pv_mode_str = arg_parser.get_str("pv_mode");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
bool attention_sink = arg_parser.get_bool("attention_sink");
|
||||
|
||||
// R30: --pv_mode maps to the int dispatched at host.
|
||||
// none -> 0 (kNone), warp -> 1 (kPerWave), block -> 2 (kPerBlock).
|
||||
@@ -192,6 +211,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
mask_info mask = mask_info::decode(mask_str, seqlen_q, seqlen_k);
|
||||
if(mask.type != mask_enum::no_mask && mask.type != mask_enum::mask_top_left)
|
||||
std::fprintf(stderr,
|
||||
"[test_sparge] WARN: -mask='%s' (type=%d) - block-map only "
|
||||
"filters mask_top_left; selection will not prune past-diagonal "
|
||||
"blocks. attention kernel still applies the mask.\n",
|
||||
mask_str.c_str(),
|
||||
static_cast<int>(mask.type));
|
||||
|
||||
// If cdfthreshd >= 0, use CDF mode; otherwise use topk mode
|
||||
if(cdfthreshd >= 0.0f)
|
||||
topk = -1.0f;
|
||||
@@ -281,6 +309,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer();
|
||||
bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
|
||||
bmap_args.mask_type = mask.type; // R32 Item 2
|
||||
bmap_args.attention_sink = attention_sink; // R32 Item 3
|
||||
|
||||
// K-stats workspace: caller-owned, sized via host helper, allocated once outside any timing.
|
||||
const size_t ws_bytes = sparge_blockmap_get_workspace_size(bmap_traits, bmap_args);
|
||||
@@ -350,7 +380,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.hdim_v = hdim_v;
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
attn_traits.mask_type = mask.type;
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_jenga_fwd_args attn_args;
|
||||
@@ -380,9 +410,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
attn_args.window_size_left = mask.left;
|
||||
attn_args.window_size_right = mask.right;
|
||||
attn_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
}
|
||||
@@ -395,7 +425,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_traits.hdim_v = hdim_v;
|
||||
attn_traits.data_type = std::is_same_v<T, ck_tile::half_t> ? "fp16" : "bf16";
|
||||
attn_traits.is_v_rowmajor = true;
|
||||
attn_traits.mask_type = mask_enum::no_mask;
|
||||
attn_traits.mask_type = mask.type;
|
||||
attn_traits.bm0 = BLKQ;
|
||||
|
||||
fmha_sparge_fwd_args attn_args;
|
||||
@@ -435,9 +465,9 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
attn_args.batch_stride_k = k_strides[0];
|
||||
attn_args.batch_stride_v = v_strides[0];
|
||||
attn_args.batch_stride_o = o_strides[0];
|
||||
attn_args.window_size_left = -1;
|
||||
attn_args.window_size_right = -1;
|
||||
attn_args.mask_type = 0;
|
||||
attn_args.window_size_left = mask.left;
|
||||
attn_args.window_size_right = mask.right;
|
||||
attn_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
avg_ms =
|
||||
sparge_sparge_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg);
|
||||
@@ -500,96 +530,111 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
sparge::SpargeParams sp;
|
||||
sp.BLKQ = BLKQ;
|
||||
sp.BLKK = BLKK;
|
||||
sp.simthreshd1 = simthreshd1;
|
||||
sp.cdfthreshd = cdfthreshd;
|
||||
sp.topk = topk;
|
||||
sp.i_perm = i_perm;
|
||||
// R32: CPU reference lacks causal mask + attention_sink; skip block_map
|
||||
// cross-check + VSA LUT self-consistency when either is in effect. The
|
||||
// attention-output check below still runs (consumes GPU bmap).
|
||||
const bool skip_cpu_bm_check = (mask.type != mask_enum::no_mask) || attention_sink;
|
||||
|
||||
auto block_map_cpu = sparge::build_block_map_meansim<T>(q_host, k_host, sp);
|
||||
|
||||
size_t bm_total = block_map_host.mData.size();
|
||||
size_t bm_mismatch = 0;
|
||||
size_t shown = 0;
|
||||
constexpr size_t MAXSHOW = 10;
|
||||
std::cout << "\n [block_map cross-check] total=" << bm_total;
|
||||
for(size_t i = 0; i < bm_total; ++i)
|
||||
bool bm_pass = true;
|
||||
bool lut_pass = true;
|
||||
if(!skip_cpu_bm_check)
|
||||
{
|
||||
uint8_t g = block_map_host.mData[i];
|
||||
uint8_t c = block_map_cpu.mData[i];
|
||||
if(g != c)
|
||||
{
|
||||
if(shown < MAXSHOW)
|
||||
{
|
||||
size_t k_idx = i % num_k_blocks;
|
||||
size_t q_idx = (i / num_k_blocks) % num_q_blocks;
|
||||
size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
|
||||
size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
|
||||
std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
|
||||
<< ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
|
||||
<< " cpu=" << int(c);
|
||||
++shown;
|
||||
}
|
||||
++bm_mismatch;
|
||||
}
|
||||
}
|
||||
bool bm_pass = (bm_mismatch == 0);
|
||||
float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
|
||||
std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
|
||||
<< " (" << std::setprecision(4) << bm_ratio << "%) "
|
||||
<< (bm_pass ? "PASS" : "FAIL");
|
||||
|
||||
auto cpu_lut = sparge::block_map_to_vsa_lut_delta<uint8_t>(block_map_cpu);
|
||||
bool lut_pass = true;
|
||||
size_t lut_fails = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
|
||||
sparge::SpargeParams sp;
|
||||
sp.BLKQ = BLKQ;
|
||||
sp.BLKK = BLKK;
|
||||
sp.simthreshd1 = simthreshd1;
|
||||
sp.cdfthreshd = cdfthreshd;
|
||||
sp.topk = topk;
|
||||
sp.i_perm = i_perm;
|
||||
|
||||
auto block_map_cpu = sparge::build_block_map_meansim<T>(q_host, k_host, sp);
|
||||
|
||||
size_t bm_total = block_map_host.mData.size();
|
||||
size_t bm_mismatch = 0;
|
||||
size_t shown = 0;
|
||||
constexpr size_t MAXSHOW = 10;
|
||||
std::cout << "\n [block_map cross-check] total=" << bm_total;
|
||||
for(size_t i = 0; i < bm_total; ++i)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
|
||||
uint8_t g = block_map_host.mData[i];
|
||||
uint8_t c = block_map_cpu.mData[i];
|
||||
if(g != c)
|
||||
{
|
||||
int32_t valid = cpu_lut.valid_block_num(b, h, qb);
|
||||
int32_t active_count = 0;
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
if(block_map_cpu(b, h, qb, kb))
|
||||
++active_count;
|
||||
int32_t recon_kb = 0;
|
||||
bool delta_ok = true;
|
||||
for(int32_t i = 0; i < valid; ++i)
|
||||
if(shown < MAXSHOW)
|
||||
{
|
||||
int32_t d = cpu_lut.lut(b, h, qb, i);
|
||||
if(d < 0)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
recon_kb += d;
|
||||
if(recon_kb >= num_k_blocks)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
if(!block_map_cpu(b, h, qb, recon_kb))
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
size_t k_idx = i % num_k_blocks;
|
||||
size_t q_idx = (i / num_k_blocks) % num_q_blocks;
|
||||
size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
|
||||
size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
|
||||
std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
|
||||
<< ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
|
||||
<< " cpu=" << int(c);
|
||||
++shown;
|
||||
}
|
||||
if(valid != active_count || !delta_ok)
|
||||
++bm_mismatch;
|
||||
}
|
||||
}
|
||||
bm_pass = (bm_mismatch == 0);
|
||||
float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
|
||||
std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
|
||||
<< " (" << std::setprecision(4) << bm_ratio << "%) "
|
||||
<< (bm_pass ? "PASS" : "FAIL");
|
||||
|
||||
auto cpu_lut = sparge::block_map_to_vsa_lut_delta<uint8_t>(block_map_cpu);
|
||||
size_t lut_fails = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
|
||||
{
|
||||
lut_pass = false;
|
||||
if(lut_fails < MAXSHOW)
|
||||
std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
|
||||
<< ") valid=" << valid << " active=" << active_count
|
||||
<< " delta_ok=" << delta_ok;
|
||||
++lut_fails;
|
||||
int32_t valid = cpu_lut.valid_block_num(b, h, qb);
|
||||
int32_t active_count = 0;
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
if(block_map_cpu(b, h, qb, kb))
|
||||
++active_count;
|
||||
int32_t recon_kb = 0;
|
||||
bool delta_ok = true;
|
||||
for(int32_t i = 0; i < valid; ++i)
|
||||
{
|
||||
int32_t d = cpu_lut.lut(b, h, qb, i);
|
||||
if(d < 0)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
recon_kb += d;
|
||||
if(recon_kb >= num_k_blocks)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
if(!block_map_cpu(b, h, qb, recon_kb))
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(valid != active_count || !delta_ok)
|
||||
{
|
||||
lut_pass = false;
|
||||
if(lut_fails < MAXSHOW)
|
||||
std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
|
||||
<< ") valid=" << valid << " active=" << active_count
|
||||
<< " delta_ok=" << delta_ok;
|
||||
++lut_fails;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
|
||||
} // end if(!skip_cpu_bm_check)
|
||||
else
|
||||
{
|
||||
std::cout << "\n [block_map cross-check] SKIPPED (mask/sink active; CPU ref lacks)";
|
||||
std::cout << "\n [VSA LUT self-consistency] SKIPPED";
|
||||
}
|
||||
std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
|
||||
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
|
||||
@@ -65,6 +65,12 @@ struct SpargeBlockMapKernel
|
||||
// Per-head cdfthreshd (size = nhead_q floats). Required (non-null);
|
||||
// only consulted on topk<=0 path.
|
||||
const float* cdfthreshd_per_head;
|
||||
|
||||
// R32 Items 2+3. mask_type stored as index_t (not mask_enum) to keep this
|
||||
// include-tree header independent of example/01_fmha/mask.hpp. Magic
|
||||
// constant 1 == mask_enum::mask_top_left (01_fmha/mask.hpp:13-19).
|
||||
index_t mask_type;
|
||||
bool attention_sink;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
|
||||
@@ -90,7 +96,9 @@ struct SpargeBlockMapKernel
|
||||
const void* pooled_k_ws_ptr,
|
||||
const void* sim_k_ws_ptr,
|
||||
const float* topk_per_head,
|
||||
const float* cdfthreshd_per_head)
|
||||
const float* cdfthreshd_per_head,
|
||||
index_t mask_type,
|
||||
bool attention_sink)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{q_ptr,
|
||||
@@ -117,7 +125,9 @@ struct SpargeBlockMapKernel
|
||||
sim_k_ws_ptr,
|
||||
N_k,
|
||||
topk_per_head,
|
||||
cdfthreshd_per_head};
|
||||
cdfthreshd_per_head,
|
||||
mask_type,
|
||||
attention_sink};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
|
||||
@@ -220,7 +230,9 @@ struct SpargeBlockMapKernel
|
||||
valid_out,
|
||||
pooled_k_ws,
|
||||
sim_k_ws,
|
||||
static_cast<void*>(smem));
|
||||
static_cast<void*>(smem),
|
||||
kargs.mask_type,
|
||||
kargs.attention_sink);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -263,10 +263,16 @@ struct SpargeBlockMapPipeline
|
||||
int32_t* valid_block_num_ptr,
|
||||
const KDataType* __restrict__ pooled_k_ws_ptr,
|
||||
const uint8_t* __restrict__ sim_k_ws_ptr,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
index_t mask_type,
|
||||
bool attention_sink) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
|
||||
// mask_enum::mask_top_left == 1 (01_fmha/mask.hpp:16). Multiplicative
|
||||
// form handles BLKQ=64,BLKK=128 (kM0<kN0) and the kM0>=kN0 case.
|
||||
const bool is_causal_tl = (mask_type == 1);
|
||||
|
||||
// K-loop no longer reduces; only Q-stats uses smem_float0.
|
||||
// smem_float1 slab is allocated for layout compat but unused.
|
||||
auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
|
||||
@@ -320,13 +326,23 @@ struct SpargeBlockMapPipeline
|
||||
// Not similar → force all K blocks ON, early exit
|
||||
if(!sim_q)
|
||||
{
|
||||
// R32 Item 2: only fill causal-valid prefix when active.
|
||||
const index_t causal_kb_end =
|
||||
is_causal_tl ? min(N_k, integer_divide_ceil((qb + 1) * kM0, kN0)) : N_k;
|
||||
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
block_map_ptr[i] = 1;
|
||||
block_map_ptr[i] = (i < causal_kb_end) ? 1 : 0;
|
||||
|
||||
// R32 Item 3: sink force. Under top-left causal, kb=0 always
|
||||
// causal-valid for qb>=0 -> no-op; meaningful for mask=no + sink=1.
|
||||
if(attention_sink && tid == 0)
|
||||
block_map_ptr[0] = 1;
|
||||
__syncthreads(); // sink visible to LUT-build below
|
||||
|
||||
if(lut_ptr != nullptr && tid == 0)
|
||||
{
|
||||
int32_t valid = 0, prev = 0;
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
for(index_t kb = 0; kb < causal_kb_end; ++kb)
|
||||
{
|
||||
lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
|
||||
prev = static_cast<int32_t>(kb);
|
||||
@@ -354,6 +370,10 @@ struct SpargeBlockMapPipeline
|
||||
|
||||
for(index_t kb = 0; kb < N_k; ++kb)
|
||||
{
|
||||
// R32 Item 2: top-left causal at block grain.
|
||||
// (qb,kb) past-diagonal iff kb*kN0 >= (qb+1)*kM0.
|
||||
const bool causal_killed = is_causal_tl && (kb * kN0 >= (qb + 1) * kM0);
|
||||
|
||||
const KDataType* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
|
||||
float pooled_k_mean[KPerThread];
|
||||
for(index_t k = 0; k < KPerThread; ++k)
|
||||
@@ -372,17 +392,17 @@ struct SpargeBlockMapPipeline
|
||||
// ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
|
||||
// AND have score = -inf so the selection step (topk / cdf) does NOT
|
||||
// pick them again (would double-count toward topk budget).
|
||||
// Both writes MUST stay together. Any selection rewrite
|
||||
// (e.g. iterative argmax -> bitonic sort) must keep the -inf write.
|
||||
if(!sim_k)
|
||||
// R32: causal_killed gates the force-on so past-diagonal blocks are
|
||||
// NOT forced ON; bmap stays 0, scores -inf so selection excludes them.
|
||||
if(causal_killed)
|
||||
smem_scores[kb] = -numeric<float>::infinity(); // bmap stays 0
|
||||
else if(!sim_k)
|
||||
{
|
||||
smem_bmap[kb] = 1;
|
||||
smem_scores[kb] = -numeric<float>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
smem_scores[kb] = dot * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // guard selection's reads of smem_bmap / smem_scores
|
||||
@@ -517,6 +537,15 @@ struct SpargeBlockMapPipeline
|
||||
// ==================================================================
|
||||
// Write outputs to global memory
|
||||
// ==================================================================
|
||||
// R32 Item 3: force smem_bmap[0]=1 BEFORE LUT collation reads it.
|
||||
// Reuses existing LUT-build loop (R31 §4: don't manually insert into
|
||||
// delta stream). Causal post-multiply unnecessary: D.2 sets killed
|
||||
// scores to -inf; selection gate L490 `bv > 0` excludes them, so
|
||||
// smem_bmap[bi]=1 never fires for killed blocks.
|
||||
if(attention_sink && tid == 0)
|
||||
smem_bmap[0] = 1;
|
||||
__syncthreads();
|
||||
|
||||
for(index_t i = tid; i < N_k; i += kBlockSize)
|
||||
block_map_ptr[i] = smem_bmap[i];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user