From 29d96a90f0caf7bf6a2ce2c3f20cc10046d9f28d Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 06:57:42 +0000 Subject: [PATCH] add jenga support bf16 --- .../50_sparse_attn/jenga_sparse_attention.cu | 53 ++++-- .../50_sparse_attn/jenga_sparse_attention.h | 23 ++- .../50_sparse_attn/test_jenga_sparse_attn.cpp | 162 ++++++++++-------- .../50_sparse_attn/vsa_sparse_attention.cu | 2 +- 4 files changed, 147 insertions(+), 93 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 02f48ee005..925960a0a8 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -6,17 +6,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/device_memory.hpp" +#include -ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +template +ck_tile::HostTensor +jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -31,8 +33,12 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k) { + // Determine data type string based on template parameter std::string data_type = "fp16"; - // DataType is determined at compile time via template + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -208,7 +214,30 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, fmha_jenga_fwd(fmha_traits, args, stream_config); // Copy output back to host - Y = o_buf.ToHost(); + Y = o_buf.ToHost(); return Y; } + +// Explicit template instantiations +template ck_tile::HostTensor +jenga_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); + +template ck_tile::HostTensor +jenga_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 5ebc3fb94e..8fad02ce04 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -7,18 +7,17 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" -using DataType = ck_tile::half_t; - -ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +template +ck_tile::HostTensor +jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index 75e3aa0b7d..fa0eea5b4f 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -131,6 +131,15 @@ void reference_blocked_attention( } } +// Get error tolerance based on data type +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; // Higher tolerance for bf16/fp16 + return ck_tile::make_tuple(rtol, atol); +} + // ============================================================================ // Command line argument parser // ============================================================================ @@ -148,13 +157,15 @@ auto create_args(int argc, char* argv[]) .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)") .insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)") + .insert("prec", "fp16", "data type: fp16/bf16") .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations"); + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -163,29 +174,29 @@ auto create_args(int argc, char* argv[]) // ============================================================================ // Main Test Function // ============================================================================ +template bool run_test(const ck_tile::ArgParser& arg_parser) { - using T = DataType; // Use DataType defined in header (half_t) - // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -301,28 +312,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Warmup for(int i = 0; i < warmup; ++i) { - jenga_sparse_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } // Benchmark @@ -331,28 +342,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - jenga_sparse_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -389,8 +400,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) scale); // Compare results - double rtol = 1e-2; - double atol = 4e-2; + auto [rtol, atol] = get_error_tolerance(); float max_diff = 0.0f; float max_rel_diff = 0.0f; @@ -450,6 +460,22 @@ int main(int argc, char* argv[]) return -1; } - bool test_result = run_test(arg_parser); + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + return test_result ? 0 : -1; } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index a824e9389f..e7a8fefa7a 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -220,7 +220,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, fmha_vsa_fwd(fmha_traits, args, stream_config); // Copy output back to host - Y = o_buf.ToHost(); + Y = o_buf.ToHost(); return Y; }