From 5e8a010fc612e4ecea9a3ecb94f7f2c7484f41d9 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 07:40:45 +0000 Subject: [PATCH] remove lse arg --- .../50_sparse_attn/jenga_sparse_attention.cu | 10 +--- .../50_sparse_attn/jenga_sparse_attention.h | 2 - .../50_sparse_attn/test_jenga_sparse_attn.cpp | 48 +++++++------------ .../50_sparse_attn/test_vsa_sparse_attn.cpp | 41 +++++++--------- .../50_sparse_attn/vsa_sparse_attention.cu | 10 +--- 5 files changed, 40 insertions(+), 71 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 925960a0a8..3be8cae198 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -16,7 +16,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -77,14 +76,11 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() : 0); if(bias) bias_buf.ToDevice(bias->data()); - if(lse) - lse_buf.ToDevice(lse->data()); if(seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data()); if(seqstart_k) @@ -150,7 +146,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_v = batch_stride_v; args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr; + args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); @@ -199,7 +195,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true : false; + traits.has_lse = false; traits.do_fp8_static_quant = false; traits.has_dropout = false; @@ -228,7 +224,6 @@ jenga_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); template ck_tile::HostTensor @@ -239,5 +234,4 @@ jenga_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 8fad02ce04..2f0be76bf5 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -15,7 +15,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -41,7 +40,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index fa0eea5b4f..0acde4bed3 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -161,7 +161,6 @@ auto create_args(int argc, char* argv[]) .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") - .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -178,25 +177,24 @@ template bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -240,9 +238,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Block relation onehot: [B, H, Q_blocks, K_blocks] ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); - // LSE tensor (optional) - ck_tile::HostTensor lse_host({batch, nhead, seqlen_q}); - // Initialize tensors with random values std::cout << "\nInitializing tensors..." << std::endl; ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); @@ -291,7 +286,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Optional tensors std::optional> bias_opt = std::nullopt; - std::optional> lse_opt = std::nullopt; std::optional> seqstart_q_opt = std::nullopt; std::optional> seqstart_k_opt = std::nullopt; @@ -299,10 +293,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { bias_opt = bias_host; } - if(store_lse) - { - lse_opt = lse_host; - } // Run kernel std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; @@ -318,7 +308,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, @@ -348,7 +337,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index a5c3f9f2b3..be4653c994 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -201,7 +201,6 @@ auto create_args(int argc, char* argv[]) .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") - .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -218,25 +217,24 @@ template bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -343,7 +341,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Optional tensors std::optional> bias_opt = std::nullopt; - std::optional> lse_opt = std::nullopt; std::optional> seqstart_q_opt = std::nullopt; std::optional> seqstart_k_opt = std::nullopt; @@ -367,7 +364,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, @@ -398,7 +394,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index e7a8fefa7a..e3199b444f 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -17,7 +17,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -80,7 +79,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() @@ -88,8 +86,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, if(bias) bias_buf.ToDevice(bias->data()); - if(lse) - lse_buf.ToDevice(lse->data()); if(seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data()); if(seqstart_k) @@ -156,7 +152,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_v = batch_stride_v; args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr; + args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); @@ -205,7 +201,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true : false; + traits.has_lse = false; traits.do_fp8_static_quant = false; traits.has_dropout = false; @@ -234,7 +230,6 @@ vsa_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); template ck_tile::HostTensor @@ -245,5 +240,4 @@ vsa_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int);