mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Add host-side Sparge block-map pipeline for sparse attention examples
- Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim scoring, CDF/topk selection) and VSA delta-LUT converter. - Add test_sparge_jenga_sparse_attn.cpp and test_sparge_vsa_sparse_attn.cpp as end-to-end demos. - Update CMakeLists.txt to register both new executables. Note: block size is currently fixed at 128; flexible block size support is not yet addressed.
This commit is contained in:
@@ -88,6 +88,17 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + Jenga Example executable
|
||||
set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# VSA Sparse Attention
|
||||
# ============================================================================
|
||||
@@ -153,4 +164,15 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + VSA Example executable
|
||||
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
|
||||
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
|
||||
408
example/ck_tile/50_sparse_attn/sparge_tool.hpp
Normal file
408
example/ck_tile/50_sparse_attn/sparge_tool.hpp
Normal file
@@ -0,0 +1,408 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace sparge {
|
||||
|
||||
struct SpargeParams
|
||||
{
|
||||
int BLKQ = 128;
|
||||
int BLKK = 128;
|
||||
|
||||
// Similarity gate threshold (TODO: per-head support).
|
||||
float simthreshd1 = 0.6f;
|
||||
|
||||
// Exactly one of the following should be used:
|
||||
// - Use CDF threshold if topk < 0
|
||||
// - Both should be in [0, 1] <-- NEED TO CHECK THIS
|
||||
float cdfthreshd = 0.98f;
|
||||
float topk = -1.0f;
|
||||
|
||||
// If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples).
|
||||
bool i_perm = true;
|
||||
};
|
||||
|
||||
// Output format CK VSA expects.
|
||||
struct VSALut
|
||||
{
|
||||
ck_tile::HostTensor<int32_t> lut; // [B, Hq, Q_blk, K_blk] delta-encoded
|
||||
ck_tile::HostTensor<int32_t> valid_block_num; // [B, Hq, Q_blk]
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
inline float to_f32(const T& x)
|
||||
{
|
||||
return ck_tile::type_convert<float>(x);
|
||||
}
|
||||
|
||||
// Read element from HostTensor with either BHSD or BSHD layout.
|
||||
// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D]
|
||||
// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D]
|
||||
template <typename T>
|
||||
inline float load(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s, int d)
|
||||
{
|
||||
return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d));
|
||||
}
|
||||
|
||||
// Compute pooled mean vector of one block: mean over tokens in [s0, s1).
|
||||
template <typename T>
|
||||
std::vector<float>
|
||||
pooled_mean_block(const ck_tile::HostTensor<T>& X, bool i_perm, int b, int h, int s0, int s1, int d)
|
||||
{
|
||||
std::vector<float> mean(d, 0.0f);
|
||||
const int bs = std::max(0, s1 - s0);
|
||||
if(bs == 0)
|
||||
return mean;
|
||||
|
||||
for(int s = s0; s < s1; ++s)
|
||||
{
|
||||
for(int d_ = 0; d_ < d; ++d_)
|
||||
{
|
||||
mean[d_] += load(X, i_perm, b, h, s, d_);
|
||||
}
|
||||
}
|
||||
const float inv = 1.0f / static_cast<float>(bs);
|
||||
for(int d_ = 0; d_ < d; ++d_)
|
||||
mean[d_] *= inv;
|
||||
return mean;
|
||||
}
|
||||
|
||||
// Compute "sim" flag of one block following SpargeAttn's intent:
|
||||
// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D.
|
||||
//
|
||||
// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly
|
||||
// instead of O(BS_^2 * D).
|
||||
template <typename T>
|
||||
bool sim_block_flag(const ck_tile::HostTensor<T>& X,
|
||||
bool i_perm,
|
||||
int b,
|
||||
int h,
|
||||
int s0,
|
||||
int s1,
|
||||
int d,
|
||||
float simthreshd1)
|
||||
{
|
||||
const int bs = std::max(0, s1 - s0);
|
||||
if(bs == 0)
|
||||
return false;
|
||||
|
||||
std::vector<float> sum_hat(d, 0.0f);
|
||||
|
||||
for(int s = s0; s < s1; ++s)
|
||||
{
|
||||
// Compute L2 norm over D.
|
||||
float norm2 = 0.0f;
|
||||
for(int d_ = 0; d_ < d; ++d_)
|
||||
{
|
||||
const float v = load(X, i_perm, b, h, s, d_);
|
||||
norm2 += v * v;
|
||||
}
|
||||
float inv_norm = 1.0f;
|
||||
// spargeAttn use eps to prevent division by zero
|
||||
if(norm2 > 0.0f)
|
||||
inv_norm = 1.0f / std::sqrt(norm2);
|
||||
|
||||
// Accumulate normalized vector.
|
||||
for(int d_ = 0; d_ < d; ++d_)
|
||||
{
|
||||
sum_hat[d_] += load(X, i_perm, b, h, s, d_) * inv_norm;
|
||||
}
|
||||
}
|
||||
|
||||
float sum_gram = 0.0f;
|
||||
for(int d_ = 0; d_ < d; ++d_)
|
||||
sum_gram += sum_hat[d_] * sum_hat[d_];
|
||||
|
||||
const float denom = static_cast<float>(bs) * static_cast<float>(bs);
|
||||
const float mean_sim = sum_gram / denom;
|
||||
|
||||
return mean_sim > simthreshd1;
|
||||
}
|
||||
|
||||
inline int select_count_from_cdf(const std::vector<float>& sorted_probs, float cdfthreshd)
|
||||
{
|
||||
// Choose the smallest n such that cdf[n-1] >= cdfthreshd.
|
||||
// Ensure at least 1.
|
||||
if(sorted_probs.empty())
|
||||
return 0;
|
||||
if(cdfthreshd <= 0.0f)
|
||||
return 1;
|
||||
|
||||
float c = 0.0f;
|
||||
for(int i = 0; i < static_cast<int>(sorted_probs.size()); ++i)
|
||||
{
|
||||
c += sorted_probs[i];
|
||||
if(c >= cdfthreshd)
|
||||
return i + 1;
|
||||
}
|
||||
return static_cast<int>(sorted_probs.size());
|
||||
}
|
||||
|
||||
inline int select_count_from_topk(int K_blk, float topk)
|
||||
{
|
||||
if(K_blk <= 0)
|
||||
return 0;
|
||||
int n = static_cast<int>(std::floor(topk * static_cast<float>(K_blk)));
|
||||
n = std::max(1, n);
|
||||
return n;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Build one-hot block_map[b,hq,qb,kb] in {0,1}.
|
||||
// - No causal mask
|
||||
// - No attention sink
|
||||
// - Logic matches SpargeAttn's structure:
|
||||
// - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later
|
||||
// - if a Q-block is not "similar", force the whole row ON
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<uint8_t> build_block_map_meansim(const ck_tile::HostTensor<T>& Q,
|
||||
const ck_tile::HostTensor<T>& K,
|
||||
const SpargeParams& p)
|
||||
{
|
||||
const auto qlens = Q.get_lengths();
|
||||
const auto klens = K.get_lengths();
|
||||
|
||||
const int B = static_cast<int>(qlens[0]);
|
||||
const int Hq = p.i_perm ? static_cast<int>(qlens[1]) : static_cast<int>(qlens[2]);
|
||||
const int Sq = p.i_perm ? static_cast<int>(qlens[2]) : static_cast<int>(qlens[1]);
|
||||
const int D = static_cast<int>(qlens[3]);
|
||||
|
||||
[[maybe_unused]] const int Bk = static_cast<int>(klens[0]);
|
||||
const int Hk = p.i_perm ? static_cast<int>(klens[1]) : static_cast<int>(klens[2]);
|
||||
const int Sk = p.i_perm ? static_cast<int>(klens[2]) : static_cast<int>(klens[1]);
|
||||
[[maybe_unused]] const int Dk = static_cast<int>(klens[3]);
|
||||
|
||||
assert(B == Bk && D == Dk && Hq % Hk == 0);
|
||||
assert(p.BLKQ > 0 && p.BLKK > 0);
|
||||
|
||||
const int nhead_ratio_qk = Hq / Hk;
|
||||
const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ);
|
||||
const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK);
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_map({B, Hq, Q_blk, K_blk});
|
||||
|
||||
// pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D]
|
||||
// sim_q: [B,Hq,Q_blk], sim_k: [B,Hk,K_blk]
|
||||
std::vector<float> pooled_q(static_cast<size_t>(B) * Hq * Q_blk * D, 0.0f);
|
||||
std::vector<float> pooled_k(static_cast<size_t>(B) * Hk * K_blk * D, 0.0f);
|
||||
std::vector<uint8_t> sim_q(static_cast<size_t>(B) * Hq * Q_blk, 0);
|
||||
std::vector<uint8_t> sim_k(static_cast<size_t>(B) * Hk * K_blk, 0);
|
||||
|
||||
auto idx_pq = [&](int b, int hq, int qb, int d) {
|
||||
return (((b * Hq + hq) * Q_blk + qb) * D + d);
|
||||
};
|
||||
auto idx_pk = [&](int b, int hk, int kb, int d) {
|
||||
return (((b * Hk + hk) * K_blk + kb) * D + d);
|
||||
};
|
||||
auto idx_sq = [&](int b, int hq, int qb) { return ((b * Hq + hq) * Q_blk + qb); };
|
||||
auto idx_sk = [&](int b, int hk, int kb) { return ((b * Hk + hk) * K_blk + kb); };
|
||||
|
||||
for(int b = 0; b < B; ++b)
|
||||
{
|
||||
for(int hq = 0; hq < Hq; ++hq)
|
||||
{
|
||||
// Q blocks
|
||||
for(int qb = 0; qb < Q_blk; ++qb)
|
||||
{
|
||||
const int s0 = qb * p.BLKQ;
|
||||
const int s1 = std::min(Sq, (qb + 1) * p.BLKQ);
|
||||
|
||||
// pooled mean
|
||||
auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D);
|
||||
for(int d = 0; d < D; ++d)
|
||||
pooled_q[idx_pq(b, hq, qb, d)] = mean[d];
|
||||
|
||||
// sim flag
|
||||
sim_q[idx_sq(b, hq, qb)] =
|
||||
detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
for(int hk = 0; hk < Hk; ++hk)
|
||||
{
|
||||
// K blocks
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
{
|
||||
const int s0 = kb * p.BLKK;
|
||||
const int s1 = std::min(Sk, (kb + 1) * p.BLKK);
|
||||
|
||||
auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D);
|
||||
for(int d = 0; d < D; ++d)
|
||||
pooled_k[idx_pk(b, hk, kb, d)] = mean[d];
|
||||
|
||||
sim_k[idx_sk(b, hk, kb)] =
|
||||
detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(D));
|
||||
|
||||
// Main loop
|
||||
for(int b = 0; b < B; ++b)
|
||||
{
|
||||
for(int hq = 0; hq < Hq; ++hq)
|
||||
{
|
||||
const int hk = hq / nhead_ratio_qk;
|
||||
|
||||
for(int qb = 0; qb < Q_blk; ++qb)
|
||||
{
|
||||
const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0);
|
||||
|
||||
// If Q-block is not "similar", force dense row.
|
||||
if(!q_is_sim)
|
||||
{
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
block_map(b, hq, qb, kb) = 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute scores over K blocks (only sim_kblocks participate in softmax; others set
|
||||
// to -inf).
|
||||
std::vector<float> score(K_blk, -std::numeric_limits<float>::infinity());
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
{
|
||||
const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0);
|
||||
if(!k_is_sim)
|
||||
{
|
||||
block_map(b, hq, qb, kb) = 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
float dot = 0.0f;
|
||||
for(int d = 0; d < D; ++d)
|
||||
{
|
||||
dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)];
|
||||
}
|
||||
score[kb] = dot * scale;
|
||||
}
|
||||
|
||||
// Softmax over K_blk (numerically stable). If all -inf, probs become all zeros.
|
||||
float maxv = -std::numeric_limits<float>::infinity();
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
maxv = std::max(maxv, score[kb]);
|
||||
|
||||
std::vector<float> prob(K_blk, 0.0f);
|
||||
if(std::isfinite(maxv))
|
||||
{
|
||||
float sumexp = 0.0f;
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
{
|
||||
if(!std::isfinite(score[kb]))
|
||||
continue;
|
||||
const float e = std::exp(score[kb] - maxv);
|
||||
prob[kb] = e;
|
||||
sumexp += e;
|
||||
}
|
||||
if(sumexp > 0.0f)
|
||||
{
|
||||
const float inv = 1.0f / sumexp;
|
||||
for(int kb = 0; kb < K_blk; ++kb)
|
||||
prob[kb] *= inv;
|
||||
}
|
||||
else
|
||||
{
|
||||
// All exponentials underflowed: keep zeros.
|
||||
std::fill(prob.begin(), prob.end(), 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Sort indices by prob descending.
|
||||
std::vector<int> order(K_blk);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::sort(order.begin(), order.end(), [&](int a, int c) {
|
||||
if(prob[a] != prob[c])
|
||||
return prob[a] > prob[c];
|
||||
return a < c; // tie-breaker for determinism
|
||||
});
|
||||
|
||||
// Determine how many to select.
|
||||
int num_to_select = 0;
|
||||
if(p.topk > 0.0f)
|
||||
{
|
||||
num_to_select = detail::select_count_from_topk(K_blk, p.topk);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd).
|
||||
std::vector<float> sorted_probs(K_blk);
|
||||
for(int i = 0; i < K_blk; ++i)
|
||||
sorted_probs[i] = prob[order[i]];
|
||||
num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd);
|
||||
num_to_select = std::max(1, num_to_select);
|
||||
}
|
||||
|
||||
// Select top-kb blocks by order[0..num_to_select-1].
|
||||
for(int i = 0; i < num_to_select; ++i)
|
||||
{
|
||||
const int kb = order[i];
|
||||
block_map(b, hq, qb, kb) = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return block_map;
|
||||
}
|
||||
|
||||
// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format).
|
||||
template <typename MapT>
|
||||
VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor<MapT>& block_map)
|
||||
{
|
||||
const auto lens = block_map.get_lengths();
|
||||
const int B = static_cast<int>(lens[0]);
|
||||
const int H = static_cast<int>(lens[1]);
|
||||
const int Q = static_cast<int>(lens[2]);
|
||||
const int K = static_cast<int>(lens[3]);
|
||||
|
||||
VSALut out{
|
||||
ck_tile::HostTensor<int32_t>({B, H, Q, K}),
|
||||
ck_tile::HostTensor<int32_t>({B, H, Q}),
|
||||
};
|
||||
|
||||
for(int b = 0; b < B; ++b)
|
||||
{
|
||||
for(int h = 0; h < H; ++h)
|
||||
{
|
||||
for(int q = 0; q < Q; ++q)
|
||||
{
|
||||
int32_t valid = 0;
|
||||
int32_t prev = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
const bool on = static_cast<int>(block_map(b, h, q, k)) != 0;
|
||||
if(on)
|
||||
{
|
||||
out.lut(b, h, q, valid) = static_cast<int32_t>(k - prev);
|
||||
prev = static_cast<int32_t>(k);
|
||||
++valid;
|
||||
}
|
||||
}
|
||||
|
||||
out.valid_block_num(b, h, q) = valid;
|
||||
|
||||
// Optional: zero-fill the unused tail for determinism.
|
||||
for(int i = valid; i < K; ++i)
|
||||
out.lut(b, h, q, i) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace sparge
|
||||
422
example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
Normal file
422
example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
Normal file
@@ -0,0 +1,422 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Demo: Sparge block-map -> Jenga sparse attention
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float to_float_for_compare(T value)
|
||||
{
|
||||
return static_cast<float>(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name")
|
||||
// Sparge-specific
|
||||
.insert("blkq", "128", "Sparge BLKQ")
|
||||
.insert("blkk", "128", "Sparge BLKK")
|
||||
.insert("simthreshd1", "0.6", "Sparge sim threshold")
|
||||
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
|
||||
.insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Sparge params
|
||||
ck_tile::index_t blkq = arg_parser.get_int("blkq");
|
||||
ck_tile::index_t blkk = arg_parser.get_int("blkk");
|
||||
float simthreshd1 = arg_parser.get_float("simthreshd1");
|
||||
float cdfthreshd = arg_parser.get_float("cdfthreshd");
|
||||
float topk = arg_parser.get_float("topk");
|
||||
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, "
|
||||
"hdim_q=128, hdim_v=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
ck_tile::index_t BLKQ = blkq;
|
||||
ck_tile::index_t BLKK = blkk;
|
||||
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
|
||||
<< ", topk=" << topk << ")" << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Build block map using Sparge tool
|
||||
std::cout << "Building Sparge block map..." << std::endl;
|
||||
sparge::SpargeParams p;
|
||||
p.BLKQ = static_cast<int>(BLKQ);
|
||||
p.BLKK = static_cast<int>(BLKK);
|
||||
p.simthreshd1 = simthreshd1;
|
||||
p.cdfthreshd = cdfthreshd;
|
||||
p.topk = topk;
|
||||
p.i_perm = i_perm;
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot =
|
||||
sparge::build_block_map_meansim(q_host, k_host, p);
|
||||
|
||||
// Print actual sparsity
|
||||
std::size_t total_blocks = 0;
|
||||
std::size_t active_blocks = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
if(block_relation_onehot(b, h, qb, kb) != 0)
|
||||
active_blocks++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
jenga_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
std::size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
429
example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp
Normal file
429
example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp
Normal file
@@ -0,0 +1,429 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "jenga_sparse_attention.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> make_qkv_tensor(ck_tile::index_t batch,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim,
|
||||
bool i_perm)
|
||||
{
|
||||
if(i_perm)
|
||||
{
|
||||
return ck_tile::HostTensor<T>({batch, nhead, seqlen, hdim});
|
||||
}
|
||||
return ck_tile::HostTensor<T>({batch, seqlen, nhead, hdim});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ck_tile::HostTensor<T> to_bhsd(const ck_tile::HostTensor<T>& tensor, bool is_bhsd)
|
||||
{
|
||||
auto lens = tensor.get_lengths();
|
||||
ck_tile::index_t batch = lens[0];
|
||||
ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1];
|
||||
ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2];
|
||||
ck_tile::index_t hdim = lens[3];
|
||||
|
||||
ck_tile::HostTensor<T> out({batch, nhead, seqlen, hdim});
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t s = 0; s < seqlen; ++s)
|
||||
{
|
||||
for(ck_tile::index_t d = 0; d < hdim; ++d)
|
||||
{
|
||||
out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto get_error_tolerance()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 4e-2;
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
atol = 2e-1;
|
||||
rtol = 2e-1;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float to_float_for_compare(T value)
|
||||
{
|
||||
return static_cast<float>(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
float to_float_for_compare<ck_tile::bf16_t>(ck_tile::bf16_t value)
|
||||
{
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
return static_cast<float>(value);
|
||||
#else
|
||||
return ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command line argument parser
|
||||
// ============================================================================
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "0:no validation, 1:cpu validation")
|
||||
.insert("b", "1", "batch size")
|
||||
.insert("h", "4", "num of head for q")
|
||||
.insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
|
||||
.insert("s", "4096", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("prec", "fp16", "data type: fp16/bf16")
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
.insert("kname", "0", "print kernel name")
|
||||
// Sparge-specific
|
||||
.insert("blkq", "128", "Sparge BLKQ")
|
||||
.insert("blkk", "128", "Sparge BLKK")
|
||||
.insert("simthreshd1", "0.6", "Sparge sim threshold")
|
||||
.insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)")
|
||||
.insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Test Function
|
||||
// ============================================================================
|
||||
|
||||
template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Sparge params
|
||||
ck_tile::index_t blkq = arg_parser.get_int("blkq");
|
||||
ck_tile::index_t blkk = arg_parser.get_int("blkk");
|
||||
float simthreshd1 = arg_parser.get_float("simthreshd1");
|
||||
float cdfthreshd = arg_parser.get_float("cdfthreshd");
|
||||
float topk = arg_parser.get_float("topk");
|
||||
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128)
|
||||
{
|
||||
std::cout << "\n>>> TEST SKIPPED <<<" << std::endl;
|
||||
std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, "
|
||||
"hdim_q=128, hdim_v=128 only."
|
||||
<< std::endl;
|
||||
std::cout << "TEST SKIPPED" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
ck_tile::index_t BLKQ = blkq;
|
||||
ck_tile::index_t BLKK = blkk;
|
||||
|
||||
ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
|
||||
ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
|
||||
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl;
|
||||
std::cout << "============================================================" << std::endl;
|
||||
std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k
|
||||
<< std::endl;
|
||||
std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
|
||||
std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
|
||||
std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl;
|
||||
std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks
|
||||
<< std::endl;
|
||||
std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd
|
||||
<< ", topk=" << topk << ")" << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Build block map using Sparge tool
|
||||
std::cout << "Building Sparge block map..." << std::endl;
|
||||
sparge::SpargeParams p;
|
||||
p.BLKQ = static_cast<int>(BLKQ);
|
||||
p.BLKK = static_cast<int>(BLKK);
|
||||
p.simthreshd1 = simthreshd1;
|
||||
p.cdfthreshd = cdfthreshd;
|
||||
p.topk = topk;
|
||||
p.i_perm = i_perm;
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot =
|
||||
sparge::build_block_map_meansim(q_host, k_host, p);
|
||||
|
||||
// Convert to VSA LUT (delta-encoded) + valid_block_num
|
||||
std::cout << "Converting block map to VSA LUT (delta)..." << std::endl;
|
||||
auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
|
||||
|
||||
// Print actual sparsity (based on one-hot)
|
||||
std::size_t total_blocks = 0;
|
||||
std::size_t active_blocks = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
if(block_relation_onehot(b, h, qb, kb) != 0)
|
||||
active_blocks++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparse_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
|
||||
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
auto [rtol, atol] = get_error_tolerance<T>();
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_rel_diff = 0.0f;
|
||||
std::size_t num_errors = 0;
|
||||
|
||||
auto output_host_bhsd = to_bhsd(output_host, o_perm);
|
||||
for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i)
|
||||
{
|
||||
float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]);
|
||||
float ref_val = to_float_for_compare(output_ref.mData[i]);
|
||||
float diff = std::abs(gpu_val - ref_val);
|
||||
float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
|
||||
|
||||
max_diff = std::max(max_diff, diff);
|
||||
max_rel_diff = std::max(max_rel_diff, rel_diff);
|
||||
|
||||
if(diff > atol && rel_diff > rtol)
|
||||
{
|
||||
num_errors++;
|
||||
if(num_errors <= 5)
|
||||
{
|
||||
std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val
|
||||
<< ", Ref=" << ref_val << ", Diff=" << diff << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
<< output_host_bhsd.mData.size() << std::endl;
|
||||
|
||||
if(num_errors == 0)
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main
|
||||
// ============================================================================
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Failed to parse arguments" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
bool test_result = false;
|
||||
if(prec == "fp16")
|
||||
{
|
||||
test_result = run_test<ck_tile::half_t>(arg_parser);
|
||||
}
|
||||
else if(prec == "bf16")
|
||||
{
|
||||
test_result = run_test<ck_tile::bf16_t>(arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported precision: " << prec << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return test_result ? 0 : -1;
|
||||
}
|
||||
Reference in New Issue
Block a user