mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
remove lse arg
This commit is contained in:
@@ -16,7 +16,6 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
@@ -77,14 +76,11 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
|
||||
// Optional buffers
|
||||
ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() : 0);
|
||||
|
||||
if(bias)
|
||||
bias_buf.ToDevice(bias->data());
|
||||
if(lse)
|
||||
lse_buf.ToDevice(lse->data());
|
||||
if(seqstart_q)
|
||||
seqstart_q_buf.ToDevice(seqstart_q->data());
|
||||
if(seqstart_k)
|
||||
@@ -150,7 +146,7 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr;
|
||||
args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr;
|
||||
args.lse_ptr = nullptr;
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
|
||||
@@ -199,7 +195,7 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = static_cast<bias_enum>(bias_type);
|
||||
traits.has_lse = lse ? true : false;
|
||||
traits.has_lse = false;
|
||||
traits.do_fp8_static_quant = false;
|
||||
|
||||
traits.has_dropout = false;
|
||||
@@ -228,7 +224,6 @@ jenga_sparse_attention<ck_tile::half_t>(
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
@@ -239,5 +234,4 @@ jenga_sparse_attention<ck_tile::bf16_t>(
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
@@ -15,7 +15,6 @@ jenga_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<DataType_>& Tblock_relation_onehot,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
@@ -41,7 +40,6 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<int32_t>& TKV_blocks, // valid_block_num must be int32_t
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
|
||||
@@ -161,7 +161,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi")
|
||||
.insert("lse", "0", "0:not store lse, 1:store lse")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
@@ -178,25 +177,24 @@ template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
[[maybe_unused]] bool store_lse = arg_parser.get_bool("lse");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
[[maybe_unused]] int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
[[maybe_unused]] int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
@@ -240,9 +238,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
// Block relation onehot: [B, H, Q_blocks, K_blocks]
|
||||
ck_tile::HostTensor<T> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
|
||||
// LSE tensor (optional)
|
||||
ck_tile::HostTensor<T> lse_host({batch, nhead, seqlen_q});
|
||||
|
||||
// Initialize tensors with random values
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
@@ -291,7 +286,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// Optional tensors
|
||||
std::optional<ck_tile::HostTensor<T>> bias_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> lse_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> seqstart_q_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> seqstart_k_opt = std::nullopt;
|
||||
|
||||
@@ -299,10 +293,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
bias_opt = bias_host;
|
||||
}
|
||||
if(store_lse)
|
||||
{
|
||||
lse_opt = lse_host;
|
||||
}
|
||||
|
||||
// Run kernel
|
||||
std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl;
|
||||
@@ -318,7 +308,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
@@ -348,7 +337,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
block_relation_onehot,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
|
||||
@@ -201,7 +201,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi")
|
||||
.insert("lse", "0", "0:not store lse, 1:store lse")
|
||||
.insert("seed", "42", "random seed")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations")
|
||||
@@ -218,25 +217,24 @@ template <typename T>
|
||||
bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Parse arguments
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
[[maybe_unused]] bool store_lse = arg_parser.get_bool("lse");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
[[maybe_unused]] int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int mode = arg_parser.get_int("mode");
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
ck_tile::index_t block_size = arg_parser.get_int("block_size");
|
||||
float sparsity = arg_parser.get_float("sparsity");
|
||||
bool i_perm = arg_parser.get_bool("iperm");
|
||||
bool o_perm = arg_parser.get_bool("operm");
|
||||
int bias_type = arg_parser.get_int("bias");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
[[maybe_unused]] int kname = arg_parser.get_int("kname");
|
||||
|
||||
// Handle default values
|
||||
if(nhead_k < 0)
|
||||
@@ -343,7 +341,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// Optional tensors
|
||||
std::optional<ck_tile::HostTensor<T>> bias_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> lse_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> seqstart_q_opt = std::nullopt;
|
||||
std::optional<ck_tile::HostTensor<T>> seqstart_k_opt = std::nullopt;
|
||||
|
||||
@@ -367,7 +364,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
@@ -398,7 +394,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
valid_block_num_host,
|
||||
output_host,
|
||||
bias_opt,
|
||||
lse_opt,
|
||||
seqstart_q_opt,
|
||||
seqstart_k_opt,
|
||||
bias_type,
|
||||
|
||||
@@ -17,7 +17,6 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> bias,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> lse,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_q,
|
||||
std::optional<ck_tile::HostTensor<DataType_>> seqstart_k,
|
||||
int bias_type,
|
||||
@@ -80,7 +79,6 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
|
||||
// Optional buffers
|
||||
ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0);
|
||||
ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes()
|
||||
: 0);
|
||||
ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes()
|
||||
@@ -88,8 +86,6 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
|
||||
if(bias)
|
||||
bias_buf.ToDevice(bias->data());
|
||||
if(lse)
|
||||
lse_buf.ToDevice(lse->data());
|
||||
if(seqstart_q)
|
||||
seqstart_q_buf.ToDevice(seqstart_q->data());
|
||||
if(seqstart_k)
|
||||
@@ -156,7 +152,7 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr;
|
||||
args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr;
|
||||
args.lse_ptr = nullptr;
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
|
||||
@@ -205,7 +201,7 @@ vsa_sparse_attention(ck_tile::HostTensor<DataType_>& TQ,
|
||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = static_cast<bias_enum>(bias_type);
|
||||
traits.has_lse = lse ? true : false;
|
||||
traits.has_lse = false;
|
||||
traits.do_fp8_static_quant = false;
|
||||
|
||||
traits.has_dropout = false;
|
||||
@@ -234,7 +230,6 @@ vsa_sparse_attention<ck_tile::half_t>(
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::half_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
@@ -245,5 +240,4 @@ vsa_sparse_attention<ck_tile::bf16_t>(
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
std::optional<ck_tile::HostTensor<ck_tile::bf16_t>>,
|
||||
int, int, int, int, int, int, int, int, int, bool, bool, int, int);
|
||||
|
||||
Reference in New Issue
Block a user