mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Use a generic lambda to initialize all the API traits/args
This commit is contained in:
@@ -758,131 +758,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
#endif
|
||||
std::cout << std::flush;
|
||||
|
||||
float appendkv_ave_time = 0;
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(0 < seqlen_knew || 0 < rotary_dim)
|
||||
{
|
||||
auto appendkv_traits = fmha_fwd_appendkv_traits{
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
is_v_rowmajor,
|
||||
(0 < rotary_dim
|
||||
? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated)
|
||||
: rope_enum::none)};
|
||||
|
||||
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
// setup stride_* arguments
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return 0 < page_block_size
|
||||
? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
||||
}();
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
else
|
||||
return 0 < page_block_size
|
||||
? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? seqlen_knew * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
||||
}();
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
|
||||
return fmha_fwd_appendkv_args{
|
||||
q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
knew_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
vnew_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
cache_seqlen_k_buf.GetDeviceBuffer(),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
shape_seqlen_q,
|
||||
max_seqlen_q,
|
||||
shape_seqlen_k - seqlen_knew /* kvcache seqlen_k for batch mode */,
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
rotary_cos_buf.GetDeviceBuffer(),
|
||||
rotary_sin_buf.GetDeviceBuffer(),
|
||||
rotary_dim,
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr,
|
||||
batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr
|
||||
page_block_size, // only used if 'block_table_ptr' is not nullptr
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
stride_v,
|
||||
stride_vnew,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_knew,
|
||||
nhead_stride_v,
|
||||
nhead_stride_vnew,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_knew,
|
||||
batch_stride_v,
|
||||
batch_stride_vnew};
|
||||
}();
|
||||
|
||||
appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
|
||||
}
|
||||
#endif
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
traits.do_fp8_static_quant = squant;
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
||||
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
|
||||
: rope_enum::half_rotated)
|
||||
: rope_enum::none);
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
traits.do_fp8_static_quant = squant;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -892,14 +791,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
/// 'nhead_stride_bias' are 0.
|
||||
// setup stride_* arguments
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
@@ -910,13 +816,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
|
||||
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
|
||||
else
|
||||
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_vnew = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? seqlen_knew * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_bias =
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
@@ -930,9 +843,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
|
||||
: (nhead_k * shape_seqlen_k * hdim_q));
|
||||
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v =
|
||||
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
|
||||
: (nhead_k * hdim_v * shape_seqlen_k));
|
||||
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
@@ -944,17 +859,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer();
|
||||
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
|
||||
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
|
||||
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
@@ -965,67 +875,118 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
args.scale_p = scale_p;
|
||||
args.scale_o = scale_o;
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.stride_bias =
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.nhead_stride_bias = nhead_stride_bias;
|
||||
args.nhead_stride_lse = nhead_stride_lse;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
args.batch_stride_bias = batch_stride_bias;
|
||||
args.batch_stride_lse = batch_stride_lse;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
args.knew_ptr = knew_buf.GetDeviceBuffer();
|
||||
args.vnew_ptr = vnew_buf.GetDeviceBuffer();
|
||||
args.seqlen_knew = seqlen_knew;
|
||||
|
||||
args.stride_randval = stride_randval;
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
|
||||
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
||||
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
|
||||
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
|
||||
args.rotary_dim = rotary_dim;
|
||||
|
||||
args.block_table_ptr =
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
|
||||
args.num_splits = num_splits;
|
||||
args.stride_knew = stride_knew;
|
||||
args.stride_vnew = stride_vnew;
|
||||
args.nhead_stride_knew = nhead_stride_knew;
|
||||
args.nhead_stride_vnew = nhead_stride_vnew;
|
||||
args.batch_stride_knew = batch_stride_knew;
|
||||
args.batch_stride_vnew = batch_stride_vnew;
|
||||
}
|
||||
else
|
||||
{
|
||||
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer();
|
||||
args.lse_ptr = lse_buf.GetDeviceBuffer();
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_o_acc = stride_o_acc;
|
||||
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
|
||||
args.nhead_stride_o_acc = nhead_stride_o_acc;
|
||||
args.batch_stride_lse_acc = batch_stride_lse_acc;
|
||||
args.batch_stride_o_acc = batch_stride_o_acc;
|
||||
args.split_stride_lse_acc = split_stride_lse_acc;
|
||||
args.split_stride_o_acc = split_stride_o_acc;
|
||||
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
|
||||
|
||||
args.scale_s = scale_s;
|
||||
args.scale_p = scale_p;
|
||||
args.scale_o = scale_o;
|
||||
|
||||
args.stride_bias =
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_bias = nhead_stride_bias;
|
||||
args.nhead_stride_lse = nhead_stride_lse;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_bias = batch_stride_bias;
|
||||
args.batch_stride_lse = batch_stride_lse;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_randval = stride_randval;
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
|
||||
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
||||
|
||||
args.block_table_ptr =
|
||||
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
args.page_block_size = page_block_size;
|
||||
|
||||
args.num_splits = num_splits;
|
||||
|
||||
args.stride_o_acc = stride_o_acc;
|
||||
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
|
||||
args.nhead_stride_o_acc = nhead_stride_o_acc;
|
||||
args.batch_stride_lse_acc = batch_stride_lse_acc;
|
||||
args.batch_stride_o_acc = batch_stride_o_acc;
|
||||
args.split_stride_lse_acc = split_stride_lse_acc;
|
||||
args.split_stride_o_acc = split_stride_o_acc;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const float appendkv_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(0 < seqlen_knew || 0 < rotary_dim)
|
||||
{
|
||||
fmha_fwd_appendkv_traits fwd_appendkv_traits;
|
||||
init_traits(fwd_appendkv_traits);
|
||||
|
||||
fmha_fwd_appendkv_args fwd_appendkv_args;
|
||||
init_args(fwd_appendkv_args);
|
||||
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
|
||||
}
|
||||
#endif
|
||||
return 0.0f;
|
||||
}();
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || 0 < page_block_size)
|
||||
@@ -1048,7 +1009,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
}();
|
||||
|
||||
if(appendkv_ave_time < 0 || fwd_ave_time < 0)
|
||||
if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return false;
|
||||
|
||||
@@ -221,25 +221,26 @@ struct fmha_fwd_appendkv_args
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
|
||||
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
|
||||
ck_tile::index_t rotary_dim;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table;
|
||||
ck_tile::index_t page_block_size;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
|
||||
Reference in New Issue
Block a user