Use generic lambda to init all the api traits/args

This commit is contained in:
PoYen, Chen
2024-08-08 19:22:53 +00:00
parent 9206808835
commit 6a399ea47e
2 changed files with 156 additions and 194 deletions

View File

@@ -758,131 +758,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
#endif
std::cout << std::flush;
float appendkv_ave_time = 0;
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < seqlen_knew || 0 < rotary_dim)
{
auto appendkv_traits = fmha_fwd_appendkv_traits{
hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
(0 < rotary_dim
? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated)
: rope_enum::none)};
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return 0 < page_block_size
? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k =
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
else
return 0 < page_block_size
? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? seqlen_knew * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k =
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
: (nhead_k * shape_seqlen_k * hdim_q));
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
return fmha_fwd_appendkv_args{
q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
knew_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
vnew_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
cache_seqlen_k_buf.GetDeviceBuffer(),
batch,
nhead,
nhead_k,
shape_seqlen_q,
max_seqlen_q,
shape_seqlen_k - seqlen_knew /* kvcache seqlen_k for batch mode */,
seqlen_knew,
hdim_q,
hdim_v,
rotary_cos_buf.GetDeviceBuffer(),
rotary_sin_buf.GetDeviceBuffer(),
rotary_dim,
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr,
batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr
page_block_size, // only used if 'block_table_ptr' is not nullptr
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew};
}();
appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
}
#endif
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_group_mode = (mode == mode_enum::group);
traits.is_v_rowmajor = is_v_rowmajor;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_group_mode = (mode == mode_enum::group);
traits.is_v_rowmajor = is_v_rowmajor;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
: rope_enum::half_rotated)
: rope_enum::none);
}
else
{
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
}
}
};
@@ -892,14 +791,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
@@ -910,13 +816,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_k =
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_v = [&]() {
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
else
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? seqlen_knew * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
@@ -930,9 +843,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_k =
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
: (nhead_k * shape_seqlen_k * hdim_q));
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
@@ -944,17 +859,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer();
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer();
args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer();
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
args.seqlen_q = shape_seqlen_q;
args.seqlen_k = shape_seqlen_k;
@@ -965,67 +875,118 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.scale_s = scale_s;
args.scale_p = scale_p;
args.scale_o = scale_o;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.stride_bias =
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
args.stride_o = stride_o;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_lse = nhead_stride_lse;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.batch_stride_bias = batch_stride_bias;
args.batch_stride_lse = batch_stride_lse;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
{
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.knew_ptr = knew_buf.GetDeviceBuffer();
args.vnew_ptr = vnew_buf.GetDeviceBuffer();
args.seqlen_knew = seqlen_knew;
args.stride_randval = stride_randval;
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer();
args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer();
args.rotary_dim = rotary_dim;
args.block_table_ptr =
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.num_splits = num_splits;
args.stride_knew = stride_knew;
args.stride_vnew = stride_vnew;
args.nhead_stride_knew = nhead_stride_knew;
args.nhead_stride_vnew = nhead_stride_vnew;
args.batch_stride_knew = batch_stride_knew;
args.batch_stride_vnew = batch_stride_vnew;
}
else
{
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer();
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer();
args.scale_s = scale_s;
args.scale_p = scale_p;
args.scale_o = scale_o;
args.stride_bias =
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias;
args.stride_o = stride_o;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_lse = nhead_stride_lse;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_bias = batch_stride_bias;
args.batch_stride_lse = batch_stride_lse;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.block_table_ptr =
0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr;
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.num_splits = num_splits;
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
}
};
const float appendkv_ave_time = [&] {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < seqlen_knew || 0 < rotary_dim)
{
fmha_fwd_appendkv_traits fwd_appendkv_traits;
init_traits(fwd_appendkv_traits);
fmha_fwd_appendkv_args fwd_appendkv_args;
init_args(fwd_appendkv_args);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
}
#endif
return 0.0f;
}();
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || 0 < page_block_size)
@@ -1048,7 +1009,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();
if(appendkv_ave_time < 0 || fwd_ave_time < 0)
if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return false;

View File

@@ -221,25 +221,26 @@ struct fmha_fwd_appendkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile::index_t batch;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t seqlen_q;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
ck_tile::index_t rotary_dim;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;