Add host-side Sparge block-map pipeline for sparse attention examples

- Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim
  scoring, CDF/topk selection) and VSA delta-LUT converter.
- Add test_sparge_jenga_sparse_attn.cpp and
  test_sparge_vsa_sparse_attn.cpp as end-to-end demos.
- Update CMakeLists.txt to register both new executables.

Note: block size is currently fixed at 128; flexible block size
support is not yet addressed.
This commit is contained in:
Gino Lu
2026-03-19 23:28:36 -04:00
parent d7c761e060
commit eed42a9dfa
4 changed files with 1281 additions and 0 deletions

View File

@@ -88,6 +88,17 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)
# Sparge + Jenga Example executable
# Demo binary built from test_sparge_jenga_sparse_attn.cpp. EXCLUDE_FROM_ALL keeps it
# out of the default "all" target, so it is only built when requested by name.
set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
# Links whatever SPARSE_ATTN_JENGA_INSTANCES expands to (set outside this hunk).
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
# Adds this directory to the include path (presumably so "sparge_tool.hpp" and
# "jenga_sparse_attention.h" resolve -- confirm their location).
target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Per-target warning suppressions, same pair as the target defined just above.
target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
# ============================================================================
# VSA Sparse Attention
# ============================================================================
@@ -153,4 +164,15 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
-Wno-float-equal
)
# Sparge + VSA Example executable
# Demo binary built from test_sparge_vsa_sparse_attn.cpp. EXCLUDE_FROM_ALL keeps it
# out of the default "all" target, so it is only built when requested by name.
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
# Links whatever SPARSE_ATTN_VSA_INSTANCES expands to (set outside this hunk).
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
# Adds this directory to the include path (presumably so "sparge_tool.hpp" resolves --
# confirm its location).
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Per-target warning suppressions, same pair as the target defined just above.
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
)
# Sets the global RULE_MESSAGES property to OFF (build-wide, not specific to these targets).
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,408 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>
#include <cassert>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace sparge {
// Tunable inputs of the host-side Sparge block-map builder.
struct SpargeParams
{
    // Number of tokens pooled into one block along the Q / K sequence axes.
    int BLKQ{128};
    int BLKK{128};

    // Similarity gate: a block is flagged "similar" only when its mean pairwise
    // cosine similarity is strictly above this value (TODO: per-head support).
    float simthreshd1{0.6f};

    // Per-row selection budget. Only one of the two is consulted:
    //   - topk > 0  : keep floor(topk * K_blk) blocks (at least one);
    //   - otherwise : keep the smallest prefix whose cumulative probability
    //                 reaches cdfthreshd.
    // Both are meant to lie in [0, 1]; nothing in this struct enforces that
    // (still needs to be checked).
    float cdfthreshd{0.98f};
    float topk{-1.0f};

    // Memory layout of Q/K: true means BHSD, false means BSHD (same convention
    // as the CK examples).
    bool i_perm{true};
};
// Output format CK VSA expects.
// Filled by block_map_to_vsa_lut_delta() from a one-hot block map.
struct VSALut
{
// For every (b, hq, q_blk) row the selected K-block indices are stored front-to-back:
// the first valid entry is the index of the first selected block, every following valid
// entry is the distance to the previously selected block. Entries past the valid count
// are written as 0.
ck_tile::HostTensor<int32_t> lut; // [B, Hq, Q_blk, K_blk] delta-encoded
// Number of valid entries at the front of the matching lut row.
ck_tile::HostTensor<int32_t> valid_block_num; // [B, Hq, Q_blk]
};
namespace detail {
// Widen a tensor element of any supported type to float via ck_tile::type_convert.
template <typename T>
inline float to_f32(const T& x)
{
    const float widened = ck_tile::type_convert<float>(x);
    return widened;
}
// Fetch one element as float, hiding the BHSD / BSHD layout difference.
//   i_perm == true  : X is indexed as (b, h, s, d)
//   i_perm == false : X is indexed as (b, s, h, d)
// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D]
// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D]
template <typename T>
inline float load(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s, int d)
{
    if(i_perm)
    {
        return to_f32(X(b, h, s, d));
    }
    return to_f32(X(b, s, h, d));
}
// Average the token vectors of one block.
//   X, i_perm : source tensor and its layout flag (see load()).
//   b, h      : batch and head to read.
//   [s0, s1)  : token range of the block.
//   d         : number of channels to average.
// Returns a vector of d floats; all zeros when the token range is empty.
template <typename T>
std::vector<float>
pooled_mean_block(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s0, int s1, int d)
{
    std::vector<float> acc(d, 0.0f);

    const int n_tokens = std::max(0, s1 - s0);
    if(n_tokens > 0)
    {
        // Channel-wise sum over the tokens of the block.
        for(int tok = s0; tok < s1; ++tok)
        {
            for(int ch = 0; ch < d; ++ch)
            {
                acc[ch] += load(X, i_perm, b, h, tok, ch);
            }
        }

        // Turn the sums into means.
        const float reciprocal = 1.0f / static_cast<float>(n_tokens);
        for(float& value : acc)
        {
            value *= reciprocal;
        }
    }
    return acc;
}
// Decide whether one block is "self-similar" in the SpargeAttn sense.
//
// With x_hat the token vectors L2-normalised along D, the statistic is
//   mean_sim = sum(Gram(x_hat)) / (n * n),   n = number of tokens in [s0, s1).
// Because sum(Gram) = ||sum_i x_hat_i||^2, it is enough to accumulate the normalised
// vectors once (O(n * D)) instead of forming the n x n Gram matrix (O(n^2 * D)).
//
// Returns true when mean_sim is strictly greater than simthreshd1; an empty token
// range yields false.
template <typename T>
bool sim_block_flag(const ck_tile::HostTensor<T>& X,
                    bool i_perm,
                    int b,
                    int h,
                    int s0,
                    int s1,
                    int d,
                    float simthreshd1)
{
    const int n_tokens = std::max(0, s1 - s0);
    if(n_tokens == 0)
    {
        return false;
    }

    std::vector<float> unit_sum(d, 0.0f);
    std::vector<float> token(d, 0.0f);

    for(int tok = s0; tok < s1; ++tok)
    {
        // Read the token once and compute its squared L2 norm on the way.
        float sq_norm = 0.0f;
        for(int ch = 0; ch < d; ++ch)
        {
            token[ch] = load(X, i_perm, b, h, tok, ch);
            sq_norm += token[ch] * token[ch];
        }

        // spargeAttn uses an eps to avoid dividing by zero; here an all-zero token is
        // simply left unscaled.
        const float inv_len = (sq_norm > 0.0f) ? 1.0f / std::sqrt(sq_norm) : 1.0f;

        for(int ch = 0; ch < d; ++ch)
        {
            unit_sum[ch] += token[ch] * inv_len;
        }
    }

    // ||sum_i x_hat_i||^2
    float gram_total = 0.0f;
    for(const float component : unit_sum)
    {
        gram_total += component * component;
    }

    const float pair_count = static_cast<float>(n_tokens) * static_cast<float>(n_tokens);
    return (gram_total / pair_count) > simthreshd1;
}
// Number of leading entries of sorted_probs needed for their running sum to reach
// cdfthreshd.
//   - empty input            -> 0
//   - cdfthreshd <= 0        -> 1 (never select fewer than one)
//   - threshold never reached -> all entries
inline int select_count_from_cdf(const std::vector<float>& sorted_probs, float cdfthreshd)
{
    const int total = static_cast<int>(sorted_probs.size());
    if(total == 0)
    {
        return 0;
    }
    if(cdfthreshd <= 0.0f)
    {
        return 1;
    }

    float running = 0.0f;
    int taken     = 0;
    for(const float prob : sorted_probs)
    {
        running += prob;
        ++taken;
        if(running >= cdfthreshd)
        {
            return taken;
        }
    }
    return total;
}
// Number of K blocks to keep for a top-k ratio.
//   K_blk : number of K blocks in the row.
//   topk  : fraction of blocks to keep, intended range (0, 1].
// Returns floor(topk * K_blk) clamped to [1, K_blk], or 0 when K_blk <= 0.
//
// The upper clamp matters: callers use the result to index an array of K_blk
// entries, so a ratio above 1 must not yield more than K_blk. Ratios that are
// NaN, non-positive or >= 1 are resolved before the float-to-int conversion,
// which would otherwise be undefined for NaN or out-of-range values.
inline int select_count_from_topk(int K_blk, float topk)
{
    if(K_blk <= 0)
        return 0;
    if(!(topk > 0.0f)) // also catches NaN
        return 1;
    if(topk >= 1.0f)
        return K_blk;

    const int n = static_cast<int>(std::floor(topk * static_cast<float>(K_blk)));
    return std::min(K_blk, std::max(1, n));
}
} // namespace detail
// Build one-hot block_map[b,hq,qb,kb] in {0,1}.
// - No causal mask
// - No attention sink
// - Logic matches SpargeAttn's structure:
//   - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later
//   - if a Q-block is not "similar", force the whole row ON
//
// Parameters:
//   Q, K : query / key tensors, laid out as BHSD when p.i_perm is true, BSHD otherwise.
//          Batch and head-dim must match and Hq must be a multiple of Hk (asserted).
//   p    : block sizes, similarity gate and selection budget (topk > 0 selects by ratio,
//          otherwise by cumulative-probability threshold).
// Returns a [B, Hq, ceil(Sq/BLKQ), ceil(Sk/BLKK)] tensor; 1 marks a block to compute.
//
// NOTE(review): entries that are never selected are not written here, so the result
// relies on HostTensor handing out zero-initialised storage -- confirm.
template <typename T>
ck_tile::HostTensor<uint8_t> build_block_map_meansim(const ck_tile::HostTensor<T>& Q,
                                                     const ck_tile::HostTensor<T>& K,
                                                     const SpargeParams& p)
{
    const auto qlens = Q.get_lengths();
    const auto klens = K.get_lengths();

    const int B  = static_cast<int>(qlens[0]);
    const int Hq = p.i_perm ? static_cast<int>(qlens[1]) : static_cast<int>(qlens[2]);
    const int Sq = p.i_perm ? static_cast<int>(qlens[2]) : static_cast<int>(qlens[1]);
    const int D  = static_cast<int>(qlens[3]);

    [[maybe_unused]] const int Bk = static_cast<int>(klens[0]);
    const int Hk = p.i_perm ? static_cast<int>(klens[1]) : static_cast<int>(klens[2]);
    const int Sk = p.i_perm ? static_cast<int>(klens[2]) : static_cast<int>(klens[1]);
    [[maybe_unused]] const int Dk = static_cast<int>(klens[3]);

    assert(B == Bk && D == Dk && Hq % Hk == 0);
    assert(p.BLKQ > 0 && p.BLKK > 0);

    // Grouped-query attention: this many Q heads share one K head.
    const int nhead_ratio_qk = Hq / Hk;

    const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ);
    const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK);

    ck_tile::HostTensor<uint8_t> block_map({B, Hq, Q_blk, K_blk});

    // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D]
    // sim_q:    [B,Hq,Q_blk],   sim_k:    [B,Hk,K_blk]
    std::vector<float> pooled_q(static_cast<std::size_t>(B) * Hq * Q_blk * D, 0.0f);
    std::vector<float> pooled_k(static_cast<std::size_t>(B) * Hk * K_blk * D, 0.0f);
    std::vector<uint8_t> sim_q(static_cast<std::size_t>(B) * Hq * Q_blk, 0);
    std::vector<uint8_t> sim_k(static_cast<std::size_t>(B) * Hk * K_blk, 0);

    // Flat offsets are computed in std::size_t, matching how the buffers above are
    // sized; doing the products in int could overflow for large tensors.
    const auto zu = [](int v) { return static_cast<std::size_t>(v); };

    auto idx_sq = [&](int b, int hq, int qb) {
        return (zu(b) * zu(Hq) + zu(hq)) * zu(Q_blk) + zu(qb);
    };
    auto idx_sk = [&](int b, int hk, int kb) {
        return (zu(b) * zu(Hk) + zu(hk)) * zu(K_blk) + zu(kb);
    };
    auto idx_pq = [&](int b, int hq, int qb, int d) { return idx_sq(b, hq, qb) * zu(D) + zu(d); };
    auto idx_pk = [&](int b, int hk, int kb, int d) { return idx_sk(b, hk, kb) * zu(D) + zu(d); };

    // Pass 1: per-block pooled mean and similarity flag for Q and K.
    for(int b = 0; b < B; ++b)
    {
        for(int hq = 0; hq < Hq; ++hq)
        {
            // Q blocks
            for(int qb = 0; qb < Q_blk; ++qb)
            {
                const int s0 = qb * p.BLKQ;
                const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); // last block may be partial

                // pooled mean
                auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D);
                for(int d = 0; d < D; ++d)
                    pooled_q[idx_pq(b, hq, qb, d)] = mean[d];

                // sim flag
                sim_q[idx_sq(b, hq, qb)] =
                    detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 1 : 0;
            }
        }

        for(int hk = 0; hk < Hk; ++hk)
        {
            // K blocks
            for(int kb = 0; kb < K_blk; ++kb)
            {
                const int s0 = kb * p.BLKK;
                const int s1 = std::min(Sk, (kb + 1) * p.BLKK); // last block may be partial

                auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D);
                for(int d = 0; d < D; ++d)
                    pooled_k[idx_pk(b, hk, kb, d)] = mean[d];

                sim_k[idx_sk(b, hk, kb)] =
                    detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0;
            }
        }
    }

    const float scale = 1.0f / std::sqrt(static_cast<float>(D));

    // Pass 2: score every (Q block, K block) pair and pick the blocks to keep.
    for(int b = 0; b < B; ++b)
    {
        for(int hq = 0; hq < Hq; ++hq)
        {
            const int hk = hq / nhead_ratio_qk;

            for(int qb = 0; qb < Q_blk; ++qb)
            {
                const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0);

                // If Q-block is not "similar", force dense row.
                if(!q_is_sim)
                {
                    for(int kb = 0; kb < K_blk; ++kb)
                        block_map(b, hq, qb, kb) = 1;
                    continue;
                }

                // Compute scores over K blocks (only sim_kblocks participate in softmax;
                // others stay at -inf and are switched ON unconditionally).
                std::vector<float> score(K_blk, -std::numeric_limits<float>::infinity());
                for(int kb = 0; kb < K_blk; ++kb)
                {
                    const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0);
                    if(!k_is_sim)
                    {
                        block_map(b, hq, qb, kb) = 1;
                        continue;
                    }

                    float dot = 0.0f;
                    for(int d = 0; d < D; ++d)
                    {
                        dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)];
                    }
                    score[kb] = dot * scale;
                }

                // Softmax over K_blk (numerically stable). If all -inf, probs stay all zeros.
                float maxv = -std::numeric_limits<float>::infinity();
                for(int kb = 0; kb < K_blk; ++kb)
                    maxv = std::max(maxv, score[kb]);

                std::vector<float> prob(K_blk, 0.0f);
                if(std::isfinite(maxv))
                {
                    float sumexp = 0.0f;
                    for(int kb = 0; kb < K_blk; ++kb)
                    {
                        if(!std::isfinite(score[kb]))
                            continue;
                        const float e = std::exp(score[kb] - maxv);
                        prob[kb]      = e;
                        sumexp += e;
                    }
                    if(sumexp > 0.0f)
                    {
                        const float inv = 1.0f / sumexp;
                        for(int kb = 0; kb < K_blk; ++kb)
                            prob[kb] *= inv;
                    }
                    else
                    {
                        // All exponentials underflowed: keep zeros.
                        std::fill(prob.begin(), prob.end(), 0.0f);
                    }
                }

                // Sort indices by prob descending.
                std::vector<int> order(K_blk);
                std::iota(order.begin(), order.end(), 0);
                std::sort(order.begin(), order.end(), [&](int a, int c) {
                    if(prob[a] != prob[c])
                        return prob[a] > prob[c];
                    return a < c; // tie-breaker for determinism
                });

                // Determine how many to select.
                int num_to_select = 0;
                if(p.topk > 0.0f)
                {
                    num_to_select = detail::select_count_from_topk(K_blk, p.topk);
                }
                else
                {
                    // Use CDF threshold selection (smallest n s.t. cumulative prob >=
                    // cdfthreshd).
                    std::vector<float> sorted_probs(K_blk);
                    for(int i = 0; i < K_blk; ++i)
                        sorted_probs[i] = prob[order[i]];
                    num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd);
                    num_to_select = std::max(1, num_to_select);
                }

                // Never select more blocks than exist: covers topk > 1 as well as the
                // K_blk == 0 case, where the "at least one" rule above would otherwise
                // index an empty `order`.
                num_to_select = std::min(num_to_select, K_blk);

                // Select top-kb blocks by order[0..num_to_select-1].
                for(int i = 0; i < num_to_select; ++i)
                {
                    const int kb             = order[i];
                    block_map(b, hq, qb, kb) = 1;
                }
            }
        }
    }

    return block_map;
}
// Convert a one-hot block map into the delta-encoded LUT + valid_block_num pair
// (CK VSA format).
//
// For each (b, h, q) row the non-zero K-block positions are written front-to-back into
// lut: the first entry is the position itself, every later entry is the gap to the
// previous non-zero position. valid_block_num receives the count of such entries and
// the remaining lut slots of the row are set to 0 so the output is deterministic.
template <typename MapT>
VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor<MapT>& block_map)
{
    const auto dims   = block_map.get_lengths();
    const int n_batch = static_cast<int>(dims[0]);
    const int n_head  = static_cast<int>(dims[1]);
    const int n_qblk  = static_cast<int>(dims[2]);
    const int n_kblk  = static_cast<int>(dims[3]);

    VSALut result{
        ck_tile::HostTensor<int32_t>({n_batch, n_head, n_qblk, n_kblk}),
        ck_tile::HostTensor<int32_t>({n_batch, n_head, n_qblk}),
    };

    for(int b = 0; b < n_batch; ++b)
    {
        for(int h = 0; h < n_head; ++h)
        {
            for(int qb = 0; qb < n_qblk; ++qb)
            {
                int32_t n_valid = 0;
                int32_t last_on = 0; // position of the previous selected block (0 before any)

                for(int kb = 0; kb < n_kblk; ++kb)
                {
                    if(static_cast<int>(block_map(b, h, qb, kb)) == 0)
                    {
                        continue;
                    }
                    result.lut(b, h, qb, n_valid) = static_cast<int32_t>(kb - last_on);
                    last_on                       = static_cast<int32_t>(kb);
                    ++n_valid;
                }

                result.valid_block_num(b, h, qb) = n_valid;

                // Zero the unused tail of the row.
                for(int slot = n_valid; slot < n_kblk; ++slot)
                {
                    result.lut(b, h, qb, slot) = 0;
                }
            }
        }
    }
    return result;
}
} // namespace sparge

View File

@@ -0,0 +1,422 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> Jenga sparse attention
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
#include "sparge_tool.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocate a Q/K/V host tensor in the requested layout:
// [batch, nhead, seqlen, hdim] when i_perm is true, [batch, seqlen, nhead, hdim] otherwise.
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Return a copy of `tensor` laid out as [batch, nhead, seqlen, hdim].
// `is_bhsd` tells how the input is laid out: true means it already is BHSD (plain copy),
// false means it is [batch, seqlen, nhead, hdim] and the two middle axes are swapped.
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    auto dims                = tensor.get_lengths();
    ck_tile::index_t n_batch = dims[0];
    ck_tile::index_t n_head  = is_bhsd ? dims[1] : dims[2];
    ck_tile::index_t n_seq   = is_bhsd ? dims[2] : dims[1];
    ck_tile::index_t n_dim   = dims[3];

    ck_tile::HostTensor<T> out({n_batch, n_head, n_seq, n_dim});

    for(ck_tile::index_t b = 0; b < n_batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < n_head; ++h)
        {
            for(ck_tile::index_t s = 0; s < n_seq; ++s)
            {
                for(ck_tile::index_t d = 0; d < n_dim; ++d)
                {
                    if(is_bhsd)
                    {
                        out(b, h, s, d) = tensor(b, h, s, d);
                    }
                    else
                    {
                        out(b, h, s, d) = tensor(b, s, h, d);
                    }
                }
            }
        }
    }
    return out;
}
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Generic path: convert a value to float for the validation comparison.
template <typename T>
float to_float_for_compare(T value)
{
    const float as_float = static_cast<float>(value);
    return as_float;
}
// bf16 path: when bf16_t is not a custom class type, go through its raw bit pattern;
// otherwise a plain static_cast is used.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if !CK_TILE_USE_CUSTOM_DATA_TYPE
    const auto raw_bits = ck_tile::bit_cast<ck_tile::bf16_raw_t>(value);
    return ck_tile::bf16_to_float_raw(raw_bits);
#else
    return static_cast<float>(value);
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
// Register every command-line option of the demo and parse argv.
// Returns (parse succeeded, parser holding the values).
auto create_args(int argc, char* argv[])
{
    // One row per option: key, default value (as text), help string.
    struct OptionSpec
    {
        const char* key;
        const char* fallback;
        const char* help;
    };

    static const OptionSpec option_table[] = {
        {"v", "1", "0:no validation, 1:cpu validation"},
        {"b", "1", "batch size"},
        {"h", "4", "num of head for q"},
        {"h_k", "-1", "num of head for k/v, -1 means equal to h"},
        {"s", "4096", "seqlen_q"},
        {"s_k", "-1", "seqlen_k, -1 means equal to s"},
        {"d", "128", "head dim for q, k"},
        {"d_v", "-1", "head dim for v, -1 means equal to d"},
        {"prec", "fp16", "data type: fp16/bf16"},
        {"iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d"},
        {"operm", "1", "permute output"},
        {"seed", "42", "random seed"},
        {"warmup", "5", "warmup iterations"},
        {"repeat", "20", "benchmark iterations"},
        {"kname", "0", "print kernel name"},
        // Sparge-specific
        {"blkq", "128", "Sparge BLKQ"},
        {"blkk", "128", "Sparge BLKK"},
        {"simthreshd1", "0.6", "Sparge sim threshold"},
        {"cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)"},
        {"topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"},
    };

    ck_tile::ArgParser arg_parser;
    for(const OptionSpec& spec : option_table)
    {
        arg_parser.insert(spec.key, spec.fallback, spec.help);
    }

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
// End-to-end demo for element type T:
//   1. read the options from arg_parser and resolve the "-1 means same as ..." defaults;
//   2. skip (reported as success) unless BLKQ = BLKK = hdim_q = hdim_v = 128;
//   3. fill Q/K/V with uniform random values in [-0.5, 0.5];
//   4. build the Sparge one-hot block map on the host and report its sparsity;
//   5. run the Jenga kernel (optional kname call, warmup, timed repeats);
//   6. if requested, compare against reference_blocked_attention on the CPU.
// Returns true on pass or skip, false on a kernel exception or validation failure.
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
    int do_validation         = arg_parser.get_int("v");
    ck_tile::index_t batch    = arg_parser.get_int("b");
    ck_tile::index_t nhead    = arg_parser.get_int("h");
    ck_tile::index_t nhead_k  = arg_parser.get_int("h_k");
    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
    ck_tile::index_t hdim_q   = arg_parser.get_int("d");
    ck_tile::index_t hdim_v   = arg_parser.get_int("d_v");
    bool i_perm               = arg_parser.get_bool("iperm");
    bool o_perm               = arg_parser.get_bool("operm");
    uint32_t seed             = arg_parser.get_uint32("seed");
    int warmup                = arg_parser.get_int("warmup");
    int repeat                = arg_parser.get_int("repeat");
    int kname                 = arg_parser.get_int("kname");

    // Sparge params
    ck_tile::index_t blkq = arg_parser.get_int("blkq");
    ck_tile::index_t blkk = arg_parser.get_int("blkk");
    float simthreshd1     = arg_parser.get_float("simthreshd1");
    float cdfthreshd      = arg_parser.get_float("cdfthreshd");
    float topk            = arg_parser.get_float("topk");

    // Negative values mean "inherit from the Q-side setting".
    if(nhead_k < 0)
        nhead_k = nhead;
    if(seqlen_k < 0)
        seqlen_k = seqlen_q;
    if(hdim_v < 0)
        hdim_v = hdim_q;

    if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
    {
        std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
        std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, "
                     "hdim_q=128, hdim_v=128 only."
                  << std::endl;
        std::cout << "TEST SKIPPED" << std::endl;
        return true;
    }

    ck_tile::index_t BLKQ = blkq;
    ck_tile::index_t BLKK = blkk;

    ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;

    std::cout << "============================================================" << std::endl;
    std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl;
    std::cout << "============================================================" << std::endl;
    std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
              << std::endl;
    std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
    std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
    std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
    std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
              << std::endl;
    std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
              << ", topk=" << topk << ")" << std::endl;
    std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;

    // Create host tensors
    ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
    ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
    ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
    ck_tile::HostTensor<T> output_host =
        o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
               : ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
    // The reference output is always kept in BHSD.
    ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});

    std::cout << "\nInitializing tensors..." << std::endl;
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);

    // Build block map using Sparge tool
    std::cout << "Building Sparge block map..." << std::endl;
    sparge::SpargeParams p;
    p.BLKQ        = static_cast<int>(BLKQ);
    p.BLKK        = static_cast<int>(BLKK);
    p.simthreshd1 = simthreshd1;
    p.cdfthreshd  = cdfthreshd;
    p.topk        = topk;
    p.i_perm      = i_perm;

    ck_tile::HostTensor<uint8_t> block_relation_onehot =
        sparge::build_block_map_meansim(q_host, k_host, p);

    // Print actual sparsity
    std::size_t total_blocks  = 0;
    std::size_t active_blocks = 0;
    for(ck_tile::index_t b = 0; b < batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < nhead; ++h)
        {
            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    total_blocks++;
                    if(block_relation_onehot(b, h, qb, kb) != 0)
                        active_blocks++;
                }
            }
        }
    }
    // An empty block map (e.g. a zero-sized dimension) would otherwise print NaN.
    float actual_sparsity =
        (total_blocks > 0)
            ? 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks)
            : 0.0f;
    std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
              << total_blocks << " blocks active)" << std::endl;

    std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;

    // One kernel invocation. The trailing argument is 1 only for the "kname" call and 0
    // for warmup/timed runs -- presumably a log/verbosity level; confirm against
    // jenga_sparse_attention.h.
    auto launch_kernel = [&](int log_level) {
        jenga_sparse_attention<T>(q_host,
                                  k_host,
                                  v_host,
                                  block_relation_onehot,
                                  output_host,
                                  batch,
                                  nhead,
                                  nhead_k,
                                  seqlen_q,
                                  seqlen_k,
                                  hdim_q,
                                  hdim_v,
                                  i_perm,
                                  o_perm,
                                  seqlen_q,
                                  seqlen_k,
                                  log_level);
    };

    try
    {
        if(kname)
        {
            launch_kernel(1);
        }

        for(int i = 0; i < warmup; ++i)
        {
            launch_kernel(0);
        }

        // Drain outstanding device work so the timed region starts clean.
        [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
        auto start                         = std::chrono::high_resolution_clock::now();

        for(int i = 0; i < repeat; ++i)
        {
            launch_kernel(0);
        }

        [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
        auto end                           = std::chrono::high_resolution_clock::now();

        if(repeat > 0)
        {
            double avg_time_ms =
                std::chrono::duration<double, std::milli>(end - start).count() / repeat;
            std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms
                      << " ms <<<<" << std::endl;
        }
        else
        {
            // Avoid dividing by zero and printing inf/NaN as a timing.
            std::cout << "\n>>>> Jenga sparse attention timing skipped (repeat <= 0) <<<<"
                      << std::endl;
        }
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error during kernel execution: " << e.what() << std::endl;
        return false;
    }

    bool pass = true;
    if(do_validation)
    {
        std::cout << "\n--- Performing CPU validation ---" << std::endl;
        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));

        std::cout << "Computing reference output..." << std::endl;
        auto q_ref = to_bhsd(q_host, i_perm);
        auto k_ref = to_bhsd(k_host, i_perm);
        auto v_ref = to_bhsd(v_host, i_perm);

        ck_tile::reference_blocked_attention<T, uint8_t>(
            q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);

        auto [rtol, atol] = get_error_tolerance<T>();

        float max_diff         = 0.0f;
        float max_rel_diff     = 0.0f;
        std::size_t num_errors = 0;

        auto output_host_bhsd = to_bhsd(output_host, o_perm);
        for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
        {
            float gpu_val  = to_float_for_compare(output_host_bhsd.mData[i]);
            float ref_val  = to_float_for_compare(output_ref.mData[i]);
            float diff     = std::abs(gpu_val - ref_val);
            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;

            max_diff     = std::max(max_diff, diff);
            max_rel_diff = std::max(max_rel_diff, rel_diff);

            // An element is fine when it meets the absolute OR the relative tolerance.
            // The test is written positively and then negated so that a NaN difference
            // (every comparison false) is counted as a mismatch instead of slipping
            // through.
            const bool within_tolerance = (diff <= atol) || (rel_diff <= rtol);
            if(!within_tolerance)
            {
                num_errors++;
                if(num_errors <= 5)
                {
                    std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
                              << ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
                }
            }
        }

        std::cout << "\nValidation results:" << std::endl;
        std::cout << " Max absolute difference: " << max_diff << std::endl;
        std::cout << " Max relative difference: " << max_rel_diff << std::endl;
        std::cout << " Number of mismatches: " << num_errors << " / "
                  << output_host_bhsd.mData.size() << std::endl;

        if(num_errors == 0)
        {
            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
        }
        else
        {
            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
            pass = false;
        }
    }

    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
    return pass;
}
// ============================================================================
// Main
// ============================================================================
// Entry point: parse the options, dispatch on --prec and map the test outcome to the
// process exit code (0 on pass/skip, -1 otherwise).
int main(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    const std::string precision = parser.get_str("prec");
    const bool want_fp16        = (precision == "fp16");
    const bool want_bf16        = (precision == "bf16");

    if(!want_fp16 && !want_bf16)
    {
        std::cerr << "Unsupported precision: " << precision << std::endl;
        return -1;
    }

    const bool passed =
        want_fp16 ? run_test<ck_tile::half_t>(parser) : run_test<ck_tile::bf16_t>(parser);

    return passed ? 0 : -1;
}

View File

@@ -0,0 +1,429 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "jenga_sparse_attention.h"
#include "sparge_tool.hpp"
// ============================================================================
// Helper Functions
// ============================================================================
// Allocate a Q/K/V host tensor in the requested layout:
// [batch, nhead, seqlen, hdim] when i_perm is true, [batch, seqlen, nhead, hdim] otherwise.
template <typename T>
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
                                       ck_tile::index_t nhead,
                                       ck_tile::index_t seqlen,
                                       ck_tile::index_t hdim,
                                       bool i_perm)
{
    return i_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim})
                  : ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
}
// Return a copy of `tensor` laid out as [batch, nhead, seqlen, hdim].
// `is_bhsd` tells how the input is laid out: true means it already is BHSD (plain copy),
// false means it is [batch, seqlen, nhead, hdim] and the two middle axes are swapped.
template <typename T>
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
{
    auto dims                = tensor.get_lengths();
    ck_tile::index_t n_batch = dims[0];
    ck_tile::index_t n_head  = is_bhsd ? dims[1] : dims[2];
    ck_tile::index_t n_seq   = is_bhsd ? dims[2] : dims[1];
    ck_tile::index_t n_dim   = dims[3];

    ck_tile::HostTensor<T> out({n_batch, n_head, n_seq, n_dim});

    for(ck_tile::index_t b = 0; b < n_batch; ++b)
    {
        for(ck_tile::index_t h = 0; h < n_head; ++h)
        {
            for(ck_tile::index_t s = 0; s < n_seq; ++s)
            {
                for(ck_tile::index_t d = 0; d < n_dim; ++d)
                {
                    if(is_bhsd)
                    {
                        out(b, h, s, d) = tensor(b, h, s, d);
                    }
                    else
                    {
                        out(b, h, s, d) = tensor(b, s, h, d);
                    }
                }
            }
        }
    }
    return out;
}
template <typename T>
auto get_error_tolerance()
{
double rtol = 1e-2;
double atol = 4e-2;
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
atol = 2e-1;
rtol = 2e-1;
}
return ck_tile::make_tuple(rtol, atol);
}
// Generic path: convert a value to float for the validation comparison.
template <typename T>
float to_float_for_compare(T value)
{
    const float as_float = static_cast<float>(value);
    return as_float;
}
// bf16 path: when bf16_t is not a custom class type, go through its raw bit pattern;
// otherwise a plain static_cast is used.
template <>
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
{
#if !CK_TILE_USE_CUSTOM_DATA_TYPE
    const auto raw_bits = ck_tile::bit_cast<ck_tile::bf16_raw_t>(value);
    return ck_tile::bf16_to_float_raw(raw_bits);
#else
    return static_cast<float>(value);
#endif
}
// ============================================================================
// Command line argument parser
// ============================================================================
// Register every command-line option of the demo and parse argv.
// Returns (parse succeeded, parser holding the values).
auto create_args(int argc, char* argv[])
{
    // One row per option: key, default value (as text), help string.
    struct OptionSpec
    {
        const char* key;
        const char* fallback;
        const char* help;
    };

    static const OptionSpec option_table[] = {
        {"v", "1", "0:no validation, 1:cpu validation"},
        {"b", "1", "batch size"},
        {"h", "4", "num of head for q"},
        {"h_k", "-1", "num of head for k/v, -1 means equal to h"},
        {"s", "4096", "seqlen_q"},
        {"s_k", "-1", "seqlen_k, -1 means equal to s"},
        {"d", "128", "head dim for q, k"},
        {"d_v", "-1", "head dim for v, -1 means equal to d"},
        {"prec", "fp16", "data type: fp16/bf16"},
        {"iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d"},
        {"operm", "1", "permute output"},
        {"seed", "42", "random seed"},
        {"warmup", "5", "warmup iterations"},
        {"repeat", "20", "benchmark iterations"},
        {"kname", "0", "print kernel name"},
        // Sparge-specific
        {"blkq", "128", "Sparge BLKQ"},
        {"blkk", "128", "Sparge BLKK"},
        {"simthreshd1", "0.6", "Sparge sim threshold"},
        {"cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)"},
        {"topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"},
    };

    ck_tile::ArgParser arg_parser;
    for(const OptionSpec& spec : option_table)
    {
        arg_parser.insert(spec.key, spec.fallback, spec.help);
    }

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// ============================================================================
// Main Test Function
// ============================================================================
template <typename T>
bool run_test(const ck_tile::ArgParser& arg_parser)
{
int do_validation = arg_parser.get_int("v");
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
uint32_t seed = arg_parser.get_uint32("seed");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int kname = arg_parser.get_int("kname");
// Sparge params
ck_tile::index_t blkq = arg_parser.get_int("blkq");
ck_tile::index_t blkk = arg_parser.get_int("blkk");
float simthreshd1 = arg_parser.get_float("simthreshd1");
float cdfthreshd = arg_parser.get_float("cdfthreshd");
float topk = arg_parser.get_float("topk");
if(nhead_k < 0)
nhead_k = nhead;
if(seqlen_k < 0)
seqlen_k = seqlen_q;
if(hdim_v < 0)
hdim_v = hdim_q;
if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
{
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, "
"hdim_q=128, hdim_v=128 only."
<< std::endl;
std::cout << "TEST SKIPPED" << std::endl;
return true;
}
ck_tile::index_t BLKQ = blkq;
ck_tile::index_t BLKK = blkk;
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
std::cout << "============================================================" << std::endl;
std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl;
std::cout << "============================================================" << std::endl;
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
<< std::endl;
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
<< std::endl;
std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
<< ", topk=" << topk << ")" << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Build block map using Sparge tool
std::cout << "Building Sparge block map..." << std::endl;
sparge::SpargeParams p;
p.BLKQ = static_cast<int>(BLKQ);
p.BLKK = static_cast<int>(BLKK);
p.simthreshd1 = simthreshd1;
p.cdfthreshd = cdfthreshd;
p.topk = topk;
p.i_perm = i_perm;
ck_tile::HostTensor<uint8_t> block_relation_onehot =
sparge::build_block_map_meansim(q_host, k_host, p);
// Convert to VSA LUT (delta-encoded) + valid_block_num
std::cout << "Converting block map to VSA LUT (delta)..." << std::endl;
auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
// Print actual sparsity (based on one-hot)
std::size_t total_blocks = 0;
std::size_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
if(block_relation_onehot(b, h, qb, kb) != 0)
active_blocks++;
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
try
{
if(kname)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
for(int i = 0; i < warmup; ++i)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < repeat; ++i)
{
vsa_sparse_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
std::cout << "Computing reference output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
auto [rtol, atol] = get_error_tolerance<T>();
float max_diff = 0.0f;
float max_rel_diff = 0.0f;
std::size_t num_errors = 0;
auto output_host_bhsd = to_bhsd(output_host, o_perm);
for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
{
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
float ref_val = to_float_for_compare(output_ref.mData[i]);
float diff = std::abs(gpu_val - ref_val);
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
max_diff = std::max(max_diff, diff);
max_rel_diff = std::max(max_rel_diff, rel_diff);
if(diff > atol && rel_diff > rtol)
{
num_errors++;
if(num_errors <= 5)
{
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
}
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "
<< output_host_bhsd.mData.size() << std::endl;
if(num_errors == 0)
{
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
return pass;
}
// ============================================================================
// Main
// ============================================================================
// Entry point: parses the command line, selects the element type named by the
// "prec" option, runs the test for that type, and maps the outcome to the
// process exit status (0 when the test passes, -1 otherwise).
int main(int argc, char* argv[])
{
    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Failed to parse arguments" << std::endl;
        return -1;
    }

    const std::string precision = parser.get_str("prec");

    // Only fp16 and bf16 are wired up; reject anything else before running.
    const bool use_fp16 = (precision == "fp16");
    const bool use_bf16 = !use_fp16 && (precision == "bf16");
    if(!(use_fp16 || use_bf16))
    {
        std::cerr << "Unsupported precision: " << precision << std::endl;
        return -1;
    }

    // Exactly one of the two instantiations is executed.
    const bool passed =
        use_fp16 ? run_test<ck_tile::half_t>(parser) : run_test<ck_tile::bf16_t>(parser);

    if(passed)
    {
        return 0;
    }
    return -1;
}