From eed42a9dfa2d81358344689f04489517ee8c0510 Mon Sep 17 00:00:00 2001
From: Gino Lu
Date: Thu, 19 Mar 2026 23:28:36 -0400
Subject: [PATCH] Add host-side Sparge block-map pipeline for sparse attention examples

- Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim scoring, CDF/topk selection) and VSA delta-LUT converter.
- Add test_sparge_jenga_sparse_attn.cpp and test_sparge_vsa_sparse_attn.cpp as end-to-end demos.
- Update CMakeLists.txt to register both new executables.

Note: block size is currently fixed at 128; flexible block size support is not yet addressed.
---
 example/ck_tile/50_sparse_attn/CMakeLists.txt | 22 +
 .../ck_tile/50_sparse_attn/sparge_tool.hpp | 408 +++++++++++++++++
 .../test_sparge_jenga_sparse_attn.cpp | 422 +++++++++++++++++
 .../test_sparge_vsa_sparse_attn.cpp | 429 ++++++++++++++++++
 4 files changed, 1281 insertions(+)
 create mode 100644 example/ck_tile/50_sparse_attn/sparge_tool.hpp
 create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
 create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp

diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt
index 65bb207764..c916f642eb 100644
--- a/example/ck_tile/50_sparse_attn/CMakeLists.txt
+++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt
@@ -88,6 +88,17 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
     -Wno-float-equal
 )
 
+# Sparge + Jenga Example executable
+set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
+message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
+add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
+target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
+target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
+    -Wno-undefined-func-template
+    -Wno-float-equal
+)
+
 # ============================================================================
 # VSA Sparse Attention
 # ============================================================================
@@ -153,4 +164,15 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
     -Wno-float-equal
 )
 
+# Sparge + VSA Example executable
+set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
+message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
+add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
+target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
+target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
+    -Wno-undefined-func-template
+    -Wno-float-equal
+)
+
 set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
new file mode 100644
index 0000000000..49c69cc6f7
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
@@ -0,0 +1,408 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+
+namespace sparge {
+
+struct SpargeParams
+{
+    int BLKQ = 128;
+    int BLKK = 128;
+
+    // Similarity gate: a block is "similar" when its mean cosine similarity exceeds this value (TODO: per-head support).
+    float simthreshd1 = 0.6f;
+
+    // Selection mode: topk > 0 keeps floor(topk * K_blk) key blocks per query block (at least 1);
+    // otherwise the smallest prefix of prob-sorted key blocks whose cumulative prob >= cdfthreshd.
+    // cdfthreshd is expected in [0, 1] and topk in (0, 1] when used; other values are not rejected.
+    float cdfthreshd = 0.98f;
+    float topk = -1.0f;
+
+    // If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples).
+    bool i_perm = true;
+};
+
+// Output format CK VSA expects.
+// NOTE(review): template argument lists are missing from this listing; the int32_t element type
+// below is inferred from the int32_t counters block_map_to_vsa_lut_delta stores - TODO confirm.
+struct VSALut
+{
+    ck_tile::HostTensor<int32_t> lut;             // [B, Hq, Q_blk, K_blk] delta-encoded
+    ck_tile::HostTensor<int32_t> valid_block_num; // [B, Hq, Q_blk]
+};
+
+namespace detail {
+
+// Convert one tensor element to float through ck_tile::type_convert.
+template <typename T>
+inline float to_f32(const T& x)
+{
+    return ck_tile::type_convert<float>(x);
+}
+
+// Read element from HostTensor with either BHSD or BSHD layout.
+// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D]
+// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D]
+template <typename T>
+inline float load(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s, int d)
+{
+    return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d));
+}
+
+// Compute pooled mean vector of one block: mean over tokens in [s0, s1).
+// Returns a vector of length d; it is all zeros when the token range is empty.
+template <typename T>
+std::vector<float>
+pooled_mean_block(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s0, int s1, int d)
+{
+    std::vector<float> mean(d, 0.0f);
+    const int bs = std::max(0, s1 - s0);
+    if(bs == 0)
+        return mean;
+
+    for(int s = s0; s < s1; ++s)
+    {
+        for(int d_ = 0; d_ < d; ++d_)
+        {
+            mean[d_] += load(X, i_perm, b, h, s, d_);
+        }
+    }
+    const float inv = 1.0f / static_cast<float>(bs);
+    for(int d_ = 0; d_ < d; ++d_)
+        mean[d_] *= inv;
+    return mean;
+}
+
+// Compute "sim" flag of one block following SpargeAttn's intent:
+// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D.
+//
+// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly
+// instead of O(BS_^2 * D).
+//
+// Returns true when mean_sim > simthreshd1, and false for an empty token range [s0, s1).
+template <typename T>
+bool sim_block_flag(const ck_tile::HostTensor<T>& X,
+                    bool i_perm,
+                    int b,
+                    int h,
+                    int s0,
+                    int s1,
+                    int d,
+                    float simthreshd1)
+{
+    const int bs = std::max(0, s1 - s0);
+    if(bs == 0)
+        return false;
+
+    std::vector<float> sum_hat(d, 0.0f);
+    // Scratch row reused for every token so each element is read and converted only once.
+    std::vector<float> row(d, 0.0f);
+
+    for(int s = s0; s < s1; ++s)
+    {
+        // Compute squared L2 norm over D while caching the converted values.
+        float norm2 = 0.0f;
+        for(int d_ = 0; d_ < d; ++d_)
+        {
+            row[d_] = load(X, i_perm, b, h, s, d_);
+            norm2 += row[d_] * row[d_];
+        }
+        // SpargeAttn adds an epsilon to avoid dividing by zero; here an all-zero token keeps
+        // inv_norm = 1 instead, so it simply contributes a zero vector to the sum.
+        float inv_norm = 1.0f;
+        if(norm2 > 0.0f)
+            inv_norm = 1.0f / std::sqrt(norm2);
+
+        // Accumulate normalized vector.
+        for(int d_ = 0; d_ < d; ++d_)
+        {
+            sum_hat[d_] += row[d_] * inv_norm;
+        }
+    }
+
+    float sum_gram = 0.0f;
+    for(int d_ = 0; d_ < d; ++d_)
+        sum_gram += sum_hat[d_] * sum_hat[d_];
+
+    const float denom    = static_cast<float>(bs) * static_cast<float>(bs);
+    const float mean_sim = sum_gram / denom;
+
+    return mean_sim > simthreshd1;
+}
+
+// Number of leading entries of sorted_probs (sorted descending by the caller) to keep.
+// Chooses the smallest n such that cdf[n-1] >= cdfthreshd.
+// Returns 0 for an empty input, 1 when cdfthreshd <= 0, and sorted_probs.size() when the
+// running sum never reaches cdfthreshd.
+inline int select_count_from_cdf(const std::vector<float>& sorted_probs, float cdfthreshd)
+{
+    if(sorted_probs.empty())
+        return 0;
+    if(cdfthreshd <= 0.0f)
+        return 1;
+
+    float c = 0.0f;
+    for(int i = 0; i < static_cast<int>(sorted_probs.size()); ++i)
+    {
+        c += sorted_probs[i];
+        if(c >= cdfthreshd)
+            return i + 1;
+    }
+    return static_cast<int>(sorted_probs.size());
+}
+
+// Number of key blocks to keep for a top-k ratio: floor(topk * K_blk) clamped to [1, K_blk].
+// Returns 0 when K_blk <= 0.
+//
+// The upper clamp matters: the caller indexes an array of K_blk entries with the result, so a
+// ratio above 1 must not yield more than K_blk. The comparison is done in float before the cast
+// so that an infinite or very large product never reaches an out-of-range float->int conversion.
+inline int select_count_from_topk(int K_blk, float topk)
+{
+    if(K_blk <= 0)
+        return 0;
+    const float scaled = std::floor(topk * static_cast<float>(K_blk));
+    // Written as a negated ">=" so that NaN also falls back to the minimum of one block.
+    if(!(scaled >= 1.0f))
+        return 1;
+    if(scaled >= static_cast<float>(K_blk))
+        return K_blk;
+    return static_cast<int>(scaled);
+}
+
+} // namespace detail
+
+// Build one-hot block_map[b,hq,qb,kb] in {0,1}.
+// - No causal mask
+// - No attention sink
+// - Logic matches SpargeAttn's structure:
+//   - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON
+//   - if a Q-block is not "similar", force the whole row ON
+//
+// Q and K are read as BHSD when p.i_perm is true, otherwise as BSHD. Query head hq uses key head
+// hq / (Hq / Hk). Shape preconditions (equal batch and head dim, Hq a multiple of Hk, positive
+// block sizes) are checked with assert only, so they are not enforced in NDEBUG builds.
+//
+// NOTE(review): template argument lists are missing from this listing; the uint8_t element type
+// of the returned tensor is an assumption - confirm against the kernel entry points.
+template <typename T>
+ck_tile::HostTensor<uint8_t> build_block_map_meansim(const ck_tile::HostTensor<T>& Q,
+                                                     const ck_tile::HostTensor<T>& K,
+                                                     const SpargeParams& p)
+{
+    const auto qlens = Q.get_lengths();
+    const auto klens = K.get_lengths();
+
+    const int B  = static_cast<int>(qlens[0]);
+    const int Hq = p.i_perm ? static_cast<int>(qlens[1]) : static_cast<int>(qlens[2]);
+    const int Sq = p.i_perm ? static_cast<int>(qlens[2]) : static_cast<int>(qlens[1]);
+    const int D  = static_cast<int>(qlens[3]);
+
+    [[maybe_unused]] const int Bk = static_cast<int>(klens[0]);
+    const int Hk = p.i_perm ? static_cast<int>(klens[1]) : static_cast<int>(klens[2]);
+    const int Sk = p.i_perm ? static_cast<int>(klens[2]) : static_cast<int>(klens[1]);
+    [[maybe_unused]] const int Dk = static_cast<int>(klens[3]);
+
+    // Hk > 0 is tested first so the modulo inside the assertion cannot divide by zero.
+    assert(B == Bk && D == Dk && Hk > 0 && Hq % Hk == 0);
+    assert(p.BLKQ > 0 && p.BLKK > 0);
+
+    const int nhead_ratio_qk = Hq / Hk;
+    const int Q_blk          = ck_tile::integer_divide_ceil(Sq, p.BLKQ);
+    const int K_blk          = ck_tile::integer_divide_ceil(Sk, p.BLKK);
+
+    ck_tile::HostTensor<uint8_t> block_map({B, Hq, Q_blk, K_blk});
+
+    // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D]
+    // sim_q:    [B,Hq,Q_blk],   sim_k:    [B,Hk,K_blk]
+    std::vector<float> pooled_q(static_cast<std::size_t>(B) * Hq * Q_blk * D, 0.0f);
+    std::vector<float> pooled_k(static_cast<std::size_t>(B) * Hk * K_blk * D, 0.0f);
+    std::vector<uint8_t> sim_q(static_cast<std::size_t>(B) * Hq * Q_blk, 0);
+    std::vector<uint8_t> sim_k(static_cast<std::size_t>(B) * Hk * K_blk, 0);
+
+    // Flat offsets are computed in std::size_t, matching how the buffers above are sized, so
+    // the products cannot overflow int for large tensors.
+    auto idx_pq = [&](int b, int hq, int qb, int d) {
+        return ((static_cast<std::size_t>(b) * Hq + hq) * Q_blk + qb) * D + d;
+    };
+    auto idx_pk = [&](int b, int hk, int kb, int d) {
+        return ((static_cast<std::size_t>(b) * Hk + hk) * K_blk + kb) * D + d;
+    };
+    auto idx_sq = [&](int b, int hq, int qb) {
+        return (static_cast<std::size_t>(b) * Hq + hq) * Q_blk + qb;
+    };
+    auto idx_sk = [&](int b, int hk, int kb) {
+        return (static_cast<std::size_t>(b) * Hk + hk) * K_blk + kb;
+    };
+
+    // Pass 1: pooled mean and similarity flag of every Q block and every K block.
+    for(int b = 0; b < B; ++b)
+    {
+        for(int hq = 0; hq < Hq; ++hq)
+        {
+            // Q blocks
+            for(int qb = 0; qb < Q_blk; ++qb)
+            {
+                const int s0 = qb * p.BLKQ;
+                const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); // last block may be partial
+
+                // pooled mean
+                auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D);
+                for(int d = 0; d < D; ++d)
+                    pooled_q[idx_pq(b, hq, qb, d)] = mean[d];
+
+                // sim flag
+                sim_q[idx_sq(b, hq, qb)] =
+                    detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 1 : 0;
+            }
+        }
+
+        for(int hk = 0; hk < Hk; ++hk)
+        {
+            // K blocks
+            for(int kb = 0; kb < K_blk; ++kb)
+            {
+                const int s0 = kb * p.BLKK;
+                const int s1 = std::min(Sk, (kb + 1) * p.BLKK); // last block may be partial
+
+                auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D);
+                for(int d = 0; d < D; ++d)
+                    pooled_k[idx_pk(b, hk, kb, d)] = mean[d];
+
+                sim_k[idx_sk(b, hk, kb)] =
+                    detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0;
+            }
+        }
+    }
+
+    const float scale = 1.0f / std::sqrt(static_cast<float>(D));
+
+    // Pass 2 (main loop): choose the key blocks of every query block.
+    for(int b = 0; b < B; ++b)
+    {
+        for(int hq = 0; hq < Hq; ++hq)
+        {
+            const int hk = hq / nhead_ratio_qk;
+
+            for(int qb = 0; qb < Q_blk; ++qb)
+            {
+                const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0);
+
+                // If Q-block is not "similar", force dense row.
+                if(!q_is_sim)
+                {
+                    for(int kb = 0; kb < K_blk; ++kb)
+                        block_map(b, hq, qb, kb) = 1;
+                    continue;
+                }
+
+                // Compute scores over K blocks (only sim_kblocks participate in softmax; others
+                // keep -inf and are forced ON right away).
+                std::vector<float> score(K_blk, -std::numeric_limits<float>::infinity());
+                for(int kb = 0; kb < K_blk; ++kb)
+                {
+                    const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0);
+                    if(!k_is_sim)
+                    {
+                        block_map(b, hq, qb, kb) = 1;
+                        continue;
+                    }
+
+                    // Similar key blocks start OFF explicitly instead of relying on the initial
+                    // contents of the tensor; the selection below turns the chosen ones ON.
+                    block_map(b, hq, qb, kb) = 0;
+
+                    float dot = 0.0f;
+                    for(int d = 0; d < D; ++d)
+                    {
+                        dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)];
+                    }
+                    score[kb] = dot * scale;
+                }
+
+                // Softmax over K_blk (numerically stable). If all -inf, probs become all zeros.
+                float maxv = -std::numeric_limits<float>::infinity();
+                for(int kb = 0; kb < K_blk; ++kb)
+                    maxv = std::max(maxv, score[kb]);
+
+                std::vector<float> prob(K_blk, 0.0f);
+                if(std::isfinite(maxv))
+                {
+                    float sumexp = 0.0f;
+                    for(int kb = 0; kb < K_blk; ++kb)
+                    {
+                        if(!std::isfinite(score[kb]))
+                            continue;
+                        const float e = std::exp(score[kb] - maxv);
+                        prob[kb]      = e;
+                        sumexp += e;
+                    }
+                    // A zero sum would mean every term is zero already, so prob can stay as is
+                    // and no explicit reset is needed in that case.
+                    if(sumexp > 0.0f)
+                    {
+                        const float inv = 1.0f / sumexp;
+                        for(int kb = 0; kb < K_blk; ++kb)
+                            prob[kb] *= inv;
+                    }
+                }
+
+                // Sort indices by prob descending.
+                std::vector<int> order(K_blk);
+                std::iota(order.begin(), order.end(), 0);
+                std::sort(order.begin(), order.end(), [&](int a, int c) {
+                    if(prob[a] != prob[c])
+                        return prob[a] > prob[c];
+                    return a < c; // tie-breaker for determinism
+                });
+
+                // Determine how many to select.
+                int num_to_select = 0;
+                if(p.topk > 0.0f)
+                {
+                    num_to_select = detail::select_count_from_topk(K_blk, p.topk);
+                }
+                else
+                {
+                    // Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd).
+                    std::vector<float> sorted_probs(K_blk);
+                    for(int i = 0; i < K_blk; ++i)
+                        sorted_probs[i] = prob[order[i]];
+                    num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd);
+                }
+                // Keep at least one block but never more than exist, so order[i] below stays in
+                // range even when K_blk is 0 or the requested count is too large.
+                num_to_select = std::min(K_blk, std::max(1, num_to_select));
+
+                // Select top-kb blocks by order[0..num_to_select-1].
+                for(int i = 0; i < num_to_select; ++i)
+                {
+                    const int kb             = order[i];
+                    block_map(b, hq, qb, kb) = 1;
+                }
+            }
+        }
+    }
+
+    return block_map;
+}
+
+// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format).
+// For every (b, h, q) row the active key-block indices are stored in ascending order as
+// differences: the first entry is the index itself (difference from 0), each later entry is the
+// distance to the previously stored index. valid_block_num holds the number of entries written
+// and the rest of the row is set to 0.
+//
+// NOTE(review): template argument lists are missing from this listing; the parameter name MaskT
+// and the int32_t LUT element type are assumptions - confirm against the original header.
+template <typename MaskT>
+VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor<MaskT>& block_map)
+{
+    const auto dims     = block_map.get_lengths();
+    const int num_batch = static_cast<int>(dims[0]);
+    const int num_head  = static_cast<int>(dims[1]);
+    const int num_q_blk = static_cast<int>(dims[2]);
+    const int num_k_blk = static_cast<int>(dims[3]);
+
+    VSALut result{
+        ck_tile::HostTensor<int32_t>({num_batch, num_head, num_q_blk, num_k_blk}),
+        ck_tile::HostTensor<int32_t>({num_batch, num_head, num_q_blk}),
+    };
+
+    for(int b = 0; b < num_batch; ++b)
+    {
+        for(int h = 0; h < num_head; ++h)
+        {
+            for(int q = 0; q < num_q_blk; ++q)
+            {
+                int32_t written = 0; // next free LUT slot, i.e. active blocks seen so far
+                int32_t last_kb = 0; // index of the most recently stored active block
+
+                for(int kb = 0; kb < num_k_blk; ++kb)
+                {
+                    // Inactive blocks leave no trace in the LUT.
+                    if(static_cast<int>(block_map(b, h, q, kb)) == 0)
+                        continue;
+
+                    result.lut(b, h, q, written) = static_cast<int32_t>(kb - last_kb);
+                    last_kb                      = static_cast<int32_t>(kb);
+                    ++written;
+                }
+
+                result.valid_block_num(b, h, q) = written;
+
+                // Zero the unused tail so the LUT contents are deterministic.
+                for(int slot = written; slot < num_k_blk; ++slot)
+                    result.lut(b, h, q, slot) = 0;
+            }
+        }
+    }
+
+    return result;
+}
+
+} // namespace sparge
diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
new file mode 100644
index 0000000000..0bd664adf6
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
@@ -0,0 +1,422 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> Jenga sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Print actual sparsity + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; + + try + { + if(kname) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + for(int i = 0; i < warmup; ++i) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + 
jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 
0 : -1; +} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp new file mode 100644 index 0000000000..dd1d3e60be --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -0,0 +1,429 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +// Runs one Sparge -> VSA sparse-attention demo for the options held in arg_parser. +// +// Steps: read the problem shape and Sparge settings, build the block map on the host with +// sparge::build_block_map_meansim, convert it with sparge::block_map_to_vsa_lut_delta, time +// `repeat` calls of vsa_sparse_attention (after `warmup` untimed calls), and, when "v" is +// non-zero, compare the output against ck_tile::reference_blocked_attention. +// +// Validation counts an element as a mismatch only when its absolute difference exceeds atol AND +// its relative difference exceeds rtol; the relative difference falls back to the absolute one +// when |ref| <= 1e-6. +// +// Returns true when the configuration is skipped, or when validation is disabled or finds no +// mismatches; false when the kernel section throws std::exception or any element mismatches. +// +// NOTE(review): the reported average is elapsed / repeat with no guard for repeat == 0, and the +// hipDeviceSynchronize() statuses are stored but never checked. +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + // Negative h_k / s_k / d_v fall back to the query-side value (nhead / seqlen_q / hdim_q). + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + // Any block size or head dim other than 128 is reported as skipped and returns true. + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + // Output layout follows o_perm: {batch, nhead, seqlen_q, hdim_v} when true, otherwise + // {batch, seqlen_q, nhead, hdim_v}. The reference tensor below always uses the former. + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Convert to VSA LUT (delta-encoded) + valid_block_num + std::cout << "Converting block map to VSA LUT (delta)..." << std::endl; + auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + + // Print actual sparsity (based on one-hot) + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + // Fraction of (batch, head, q-block, k-block) entries that are zero in the block map. + // NOTE(review): total_blocks is 0 when any loop bound is 0, which makes this 0/0 - confirm + // callers never pass an empty shape. + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; + + try + { + // When kname is non-zero, make one extra call whose final argument is 1 (every other call + // passes 0); presumably that argument enables kernel-name logging - TODO confirm. + if(kname) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + // Warm-up calls; these run before the timed loop. + for(int i = 0; i < warmup; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + 
seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +// Entry point: parses the command line via create_args and runs run_test for the precision named +// by "prec" ("fp16" or "bf16"). +// Returns 0 when run_test returns true, and -1 on a parse failure, an unsupported precision, or a +// false result from run_test. +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + // Dispatch on the requested precision string. + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 0 : -1; +}