diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
index 49c69cc6f7..94426c6fd8 100644
--- a/example/ck_tile/50_sparse_attn/sparge_tool.hpp
+++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
@@ -16,7 +16,10 @@ namespace sparge {
 
 struct SpargeParams
 {
-    int BLKQ = 128;
+    // BLKQ=64, BLKK=128 align with SpargeAttn SM90 (Hopper) convention;
+    // cf. upstream csrc/qattn/qk_int_sv_f8_cuda_sm90.cu:143-144.
+    // SM80/SM89 path uses the inverse 128/64 (cf. qk_int_sv_f16_cuda_sm80.cu:137-138).
+    int BLKQ = 64;
     int BLKK = 128;
 
     // Similarity gate threshold (TODO: per-head support).
diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp
index a2cf101cf1..73e58f0c24 100644
--- a/example/ck_tile/50_sparse_attn/test_sparge.cpp
+++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp
@@ -57,13 +57,10 @@ ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhs
 template
 auto get_error_tolerance()
 {
+    // Matches dense FMHA fp16/bf16 bounds; validated on (b=1,h=2,d=128,
+    // s in {512, 2048, 4096, 8192}) with maxdiff = 0.00 across both dtypes.
     double rtol = 1e-2;
     double atol = 4e-2;
-    if constexpr(std::is_same_v)
-    {
-        atol = 2e-1;
-        rtol = 2e-1;
-    }
     return ck_tile::make_tuple(rtol, atol);
 }
 
@@ -76,11 +73,7 @@ float to_float_for_compare(T value)
 template <>
 float to_float_for_compare(ck_tile::bf16_t value)
 {
-#if CK_TILE_USE_CUSTOM_DATA_TYPE
-    return static_cast(value);
-#else
-    return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value));
-#endif
+    return ck_tile::type_convert(value);
 }
 
 // ============================================================================
@@ -400,6 +393,107 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
         auto k_ref = to_bhsd(k_host, i_perm);
         auto v_ref = to_bhsd(v_host, i_perm);
 
+        sparge::SpargeParams sp;
+        sp.BLKQ        = BLKQ;
+        sp.BLKK        = BLKK;
+        sp.simthreshd1 = simthreshd1;
+        sp.cdfthreshd  = cdfthreshd;
+        sp.topk        = topk;
+        sp.i_perm      = i_perm;
+
+        auto block_map_cpu = sparge::build_block_map_meansim(q_host, k_host, sp);
+
+        size_t bm_total = block_map_host.mData.size();
+        // The CPU map is built independently, so its element count is checked rather than
+        // assumed; only the common prefix is compared to avoid reading out of bounds.
+        size_t bm_cpu_total        = block_map_cpu.mData.size();
+        size_t bm_cmp              = bm_cpu_total < bm_total ? bm_cpu_total : bm_total;
+        size_t bm_mismatch         = 0;
+        size_t shown               = 0;
+        constexpr size_t MAXSHOW   = 10;
+        std::cout << "\n [block_map cross-check] total=" << bm_total;
+        if(bm_cpu_total != bm_total)
+            std::cout << "\n size_mismatch gpu=" << bm_total << " cpu=" << bm_cpu_total;
+        for(size_t i = 0; i < bm_cmp; ++i)
+        {
+            uint8_t g = block_map_host.mData[i];
+            uint8_t c = block_map_cpu.mData[i];
+            if(g != c)
+            {
+                if(shown < MAXSHOW)
+                {
+                    size_t k_idx = i % num_k_blocks;
+                    size_t q_idx = (i / num_k_blocks) % num_q_blocks;
+                    size_t h_idx = (i / (num_k_blocks * num_q_blocks)) % nhead;
+                    size_t b_idx = i / (num_k_blocks * num_q_blocks * nhead);
+                    std::cout << "\n miss[" << shown << "] (b=" << b_idx << ",h=" << h_idx
+                              << ",qb=" << q_idx << ",kb=" << k_idx << ") gpu=" << int(g)
+                              << " cpu=" << int(c);
+                    ++shown;
+                }
+                ++bm_mismatch;
+            }
+        }
+        bool bm_pass   = (bm_mismatch == 0) && (bm_cpu_total == bm_total);
+        float bm_ratio = bm_total ? 100.0f * float(bm_mismatch) / float(bm_total) : 0.0f;
+        // std::setprecision is sticky on std::cout; remember the previous precision so the
+        // maxdiff report printed later keeps its original formatting.
+        const auto bm_saved_precision = std::cout.precision();
+        std::cout << "\n [block_map cross-check] mismatch=" << bm_mismatch << "/" << bm_total
+                  << " (" << std::setprecision(4) << bm_ratio << "%) "
+                  << (bm_pass ? "PASS" : "FAIL");
+        std::cout.precision(bm_saved_precision);
+
+        auto cpu_lut     = sparge::block_map_to_vsa_lut_delta(block_map_cpu);
+        bool lut_pass    = true;
+        size_t lut_fails = 0;
+        for(ck_tile::index_t b = 0; b < batch && lut_fails < MAXSHOW; ++b)
+        {
+            for(ck_tile::index_t h = 0; h < nhead && lut_fails < MAXSHOW; ++h)
+            {
+                for(ck_tile::index_t qb = 0; qb < num_q_blocks && lut_fails < MAXSHOW; ++qb)
+                {
+                    int32_t valid        = cpu_lut.valid_block_num(b, h, qb);
+                    int32_t active_count = 0;
+                    for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
+                        if(block_map_cpu(b, h, qb, kb))
+                            ++active_count;
+                    int32_t recon_kb = 0;
+                    bool delta_ok    = true;
+                    for(int32_t i = 0; i < valid; ++i)
+                    {
+                        int32_t d = cpu_lut.lut(b, h, qb, i);
+                        if(d < 0)
+                        {
+                            delta_ok = false;
+                            break;
+                        }
+                        recon_kb += d;
+                        if(recon_kb >= num_k_blocks)
+                        {
+                            delta_ok = false;
+                            break;
+                        }
+                        if(!block_map_cpu(b, h, qb, recon_kb))
+                        {
+                            delta_ok = false;
+                            break;
+                        }
+                    }
+                    if(valid != active_count || !delta_ok)
+                    {
+                        lut_pass = false;
+                        if(lut_fails < MAXSHOW)
+                            std::cout << "\n lut_fail (b=" << b << ",h=" << h << ",qb=" << qb
+                                      << ") valid=" << valid << " active=" << active_count
+                                      << " delta_ok=" << delta_ok;
+                        ++lut_fails;
+                    }
+                }
+            }
+        }
+        std::cout << "\n [VSA LUT self-consistency] " << (lut_pass ? "PASS" : "FAIL");
+
         ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v});
         ck_tile::reference_blocked_attention(
             q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s);
@@ -423,9 +517,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
                 num_errors++;
         }
 
-        pass = (num_errors == 0);
-        std::cout << ", " << (pass ? "PASS" : "FAIL") << "(err=" << num_errors << "/"
-                  << output_host_bhsd.mData.size() << " maxdiff=" << max_diff << ")";
+        pass = (num_errors == 0) && bm_pass && lut_pass;
+        std::cout << "\n [attention output] " << ((num_errors == 0) ? "PASS" : "FAIL")
+                  << "(err=" << num_errors << "/" << output_host_bhsd.mData.size()
+                  << " maxdiff=" << max_diff << ")";
     }
 
     std::cout << std::endl;