mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
test(sparse_attn): CPU-ref cross-check + BLKQ cite
Wire SpargeAttn CPU reference into test_sparge: build the block_map on host via sparge::build_block_map_meansim and cross-check against the GPU-produced map; self-check the VSA delta-LUT (valid count + reachable kb indices); split PASS/FAIL into separate block_map / LUT / attention-output lines for clearer diagnosis. Set sparge_tool::SpargeParams::BLKQ default to 64 to match SpargeAttn SM90 convention (cite upstream qk_int_sv_f8_cuda_sm90.cu:143-144); tighten bf16 tolerance back to the dense FMHA baseline (4e-2 atol, 1e-2 rtol). Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
This commit is contained in:
@@ -16,7 +16,10 @@ namespace sparge {
|
||||
|
||||
struct SpargeParams
|
||||
{
|
||||
int BLKQ = 128;
|
||||
// BLKQ=64, BLKK=128 align with SpargeAttn SM90 (Hopper) convention;
|
||||
// cf. upstream csrc/qattn/qk_int_sv_f8_cuda_sm90.cu:143-144.
|
||||
// SM80/SM89 path uses the inverse 128/64 (cf. qk_int_sv_f16_cuda_sm80.cu:137-138).
|
||||
int BLKQ = 64;
|
||||
int BLKK = 128;
|
||||
|
||||
// Similarity gate threshold (TODO: per-head support).
|
||||
|
||||
@@ -57,13 +57,10 @@ ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhs
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
// Matches dense FMHA fp16/bf16 bounds; validated on (b=1,h=2,d=128,
|
||||
// s in {512, 2048, 4096, 8192}) with maxdiff = 0.00 across both dtypes.
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
@@ -76,11 +73,7 @@ float to_float_for_compare(T value)
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
return ck_tile::type_convert<float>(value);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -400,6 +393,97 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
sparge::SpargeParams sp;
|
||||
sp.BLKQ = BLKQ;
|
||||
sp.BLKK = BLKK;
|
||||
sp.simthreshd1 = simthreshd1;
|
||||
sp.cdfthreshd = cdfthreshd;
|
||||
sp.topk = topk;
|
||||
sp.i_perm = i_perm;
|
||||
|
||||
auto block_map_cpu = sparge::build_block_map_meansim<T>(q_host, k_host, sp);
|
||||
|
||||
size_t bm_total = block_map_host.mData.size();
|
||||
size_t bm_mismatch = 0;
|
||||
size_t shown = 0;
|
||||
constexpr size_t MAXSHOW = 10;
|
||||
std::cout << "\n [block_map cross-check] total=" << bm_total;
|
||||
for(size_t i = 0; i < bm_total; ++i)
|
||||
{
|
||||
uint8_t g = block_map_host.mData[i];
|
||||
uint8_t c = block_map_cpu.mData[i];
|
||||
if(g != c)
|
||||
{
|
||||
if(shown < MAXSHOW)
|
||||
{
|
||||
size_t k_idx = i % num_k_blocks;
|
||||
size_t q_idx = (i / num_k_blocks) % num_q_blocks;
|
||||
size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
|
||||
size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
|
||||
std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
|
||||
<< ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
|
||||
<< " cpu=" << int(c);
|
||||
++shown;
|
||||
}
|
||||
++bm_mismatch;
|
||||
}
|
||||
}
|
||||
bool bm_pass = (bm_mismatch == 0);
|
||||
float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
|
||||
std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
|
||||
<< " (" << std::setprecision(4) << bm_ratio << "%) "
|
||||
<< (bm_pass ? "PASS" : "FAIL");
|
||||
|
||||
auto cpu_lut = sparge::block_map_to_vsa_lut_delta<uint8_t>(block_map_cpu);
|
||||
bool lut_pass = true;
|
||||
size_t lut_fails = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
|
||||
{
|
||||
int32_t valid = cpu_lut.valid_block_num(b, h, qb);
|
||||
int32_t active_count = 0;
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
if(block_map_cpu(b, h, qb, kb))
|
||||
++active_count;
|
||||
int32_t recon_kb = 0;
|
||||
bool delta_ok = true;
|
||||
for(int32_t i = 0; i < valid; ++i)
|
||||
{
|
||||
int32_t d = cpu_lut.lut(b, h, qb, i);
|
||||
if(d < 0)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
recon_kb += d;
|
||||
if(recon_kb >= num_k_blocks)
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
if(!block_map_cpu(b, h, qb, recon_kb))
|
||||
{
|
||||
delta_ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(valid != active_count || !delta_ok)
|
||||
{
|
||||
lut_pass = false;
|
||||
if(lut_fails < MAXSHOW)
|
||||
std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
|
||||
<< ") valid=" << valid << " active=" << active_count
|
||||
<< " delta_ok=" << delta_ok;
|
||||
++lut_fails;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
|
||||
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s);
|
||||
@@ -423,9 +507,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
num_errors++;
|
||||
}
|
||||
|
||||
pass = (num_errors == 0);
|
||||
std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/"
|
||||
<< output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")";
|
||||
pass = (num_errors == 0) && bm_pass && lut_pass;
|
||||
std::cout << "\n [attention output] " << ((num_errors == 0) ? "PASS" : "FAIL")
|
||||
<< "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
|
||||
<< " maxdiff=" << max_diff << ")";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user