From 6a399ea47e07e975d9de6ac4af4cce51cdf79dbd Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 8 Aug 2024 19:22:53 +0000 Subject: [PATCH] Use generic lambda to init all the api traits/args --- example/ck_tile/01_fmha/fmha_fwd.cpp | 331 ++++++++++++--------------- example/ck_tile/01_fmha/fmha_fwd.hpp | 19 +- 2 files changed, 156 insertions(+), 194 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 2ed135d638..d7718a798b 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -758,131 +758,30 @@ bool run(const ck_tile::ArgParser& arg_parser) #endif std::cout << std::flush; - float appendkv_ave_time = 0; -#if CK_TILE_FMHA_FWD_APPENDKV_API - if(0 < seqlen_knew || 0 < rotary_dim) - { - auto appendkv_traits = fmha_fwd_appendkv_traits{ - hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - is_v_rowmajor, - (0 < rotary_dim - ? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated) - : rope_enum::none)}; - - auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() { - // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return 0 < page_block_size - ? (i_perm ? page_block_size : nhead_k * page_block_size) - : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_knew : nhead_k * seqlen_knew; - }(); - // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = - (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) - : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); - const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) - : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); - else - return 0 < page_block_size - ? (i_perm ? hdim_v * page_block_size : page_block_size) - : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); - }(); - const ck_tile::index_t nhead_stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? seqlen_knew * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_knew : seqlen_knew; - }(); - // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = - (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) - : (nhead_k * shape_seqlen_k * hdim_q)); - const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); - const ck_tile::index_t batch_stride_v = - (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) - : (nhead_k * hdim_v * shape_seqlen_k)); - const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - - return fmha_fwd_appendkv_args{ - q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - knew_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - vnew_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - cache_seqlen_k_buf.GetDeviceBuffer(), - batch, - nhead, - nhead_k, - shape_seqlen_q, - max_seqlen_q, - shape_seqlen_k - seqlen_knew /* kvcache seqlen_k for batch mode */, - seqlen_knew, - hdim_q, - hdim_v, - rotary_cos_buf.GetDeviceBuffer(), - rotary_sin_buf.GetDeviceBuffer(), - rotary_dim, - 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr, - batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr - page_block_size, // only used if 'block_table_ptr' is not nullptr - stride_q, - stride_k, - stride_knew, - stride_v, - stride_vnew, - nhead_stride_q, - nhead_stride_k, - nhead_stride_knew, - nhead_stride_v, - nhead_stride_vnew, - batch_stride_q, - batch_stride_k, - batch_stride_knew, - batch_stride_v, - batch_stride_vnew}; - }(); - - appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); - } -#endif - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_group_mode = (mode == mode_enum::group); - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - traits.bias_type = bias.type; - traits.has_lse = lse; - traits.do_fp8_static_quant = squant; + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_group_mode = (mode == mode_enum::group); + traits.is_v_rowmajor = is_v_rowmajor; - if constexpr(std::is_same_v>) + if constexpr(std::is_same_v>) { - traits.has_dropout = (p_drop > 0.0f); + traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved + : rope_enum::half_rotated) + : rope_enum::none); + } + else + { + traits.mask_type = mask.type; + traits.bias_type = bias.type; + traits.has_lse = lse; + traits.do_fp8_static_quant = squant; + + if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + } } }; @@ -892,14 +791,21 @@ bool run(const ck_tile::ArgParser& arg_parser) /// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// 'nhead_stride_bias' are 0. // setup stride_* arguments - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) - : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); @@ -910,13 +816,20 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t nhead_stride_k = (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); - const ck_tile::index_t nhead_stride_v = [&]() { + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) - : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); else return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) - : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); @@ -930,9 +843,11 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t batch_stride_k = (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); const ck_tile::index_t batch_stride_v = (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); @@ -944,17 +859,12 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(); - args.lse_ptr = lse_buf.GetDeviceBuffer(); - args.o_ptr = o_buf.GetDeviceBuffer(); + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer(); args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer(); - args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(); args.seqlen_q = shape_seqlen_q; args.seqlen_k = shape_seqlen_k; @@ -965,67 +875,118 @@ bool run(const ck_tile::ArgParser& arg_parser) args.nhead_q = nhead; args.nhead_k = nhead_k; - args.scale_s = scale_s; - args.scale_p = scale_p; - args.scale_o = scale_o; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.stride_bias = - bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias; - args.stride_o = stride_o; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.nhead_stride_bias = nhead_stride_bias; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - args.batch_stride_bias = batch_stride_bias; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - - if constexpr(std::is_same_v>) + if constexpr(std::is_same_v>) { - args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + args.knew_ptr = knew_buf.GetDeviceBuffer(); + args.vnew_ptr = vnew_buf.GetDeviceBuffer(); + args.seqlen_knew = seqlen_knew; - args.stride_randval = stride_randval; - args.nhead_stride_randval = nhead_stride_randval; - args.batch_stride_randval = batch_stride_randval; + args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); - args.p_drop = p_drop; - args.s_randval = s_randval; - args.drop_seed_offset = std::tie(drop_seed, drop_offset); - } - else if constexpr(std::is_same_v>) - { - args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); - args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + args.rotary_cos_ptr = rotary_cos_buf.GetDeviceBuffer(); + args.rotary_sin_ptr = rotary_sin_buf.GetDeviceBuffer(); + args.rotary_dim = rotary_dim; args.block_table_ptr = 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr; args.batch_stride_block_table = batch_stride_block_table; args.page_block_size = page_block_size; - args.num_splits = num_splits; + args.stride_knew = stride_knew; + args.stride_vnew = stride_vnew; + args.nhead_stride_knew = nhead_stride_knew; + args.nhead_stride_vnew = nhead_stride_vnew; + args.batch_stride_knew = batch_stride_knew; + args.batch_stride_vnew = batch_stride_vnew; + } + else + { + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); - args.stride_o_acc = stride_o_acc; - args.nhead_stride_lse_acc = nhead_stride_lse_acc; - args.nhead_stride_o_acc = nhead_stride_o_acc; - args.batch_stride_lse_acc = batch_stride_lse_acc; - args.batch_stride_o_acc = batch_stride_o_acc; - args.split_stride_lse_acc = split_stride_lse_acc; - args.split_stride_o_acc = split_stride_o_acc; + args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(); + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.stride_bias = + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias; + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + args.drop_seed_offset = std::tie(drop_seed, drop_offset); + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr; + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } } }; + const float appendkv_ave_time = [&] { +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < seqlen_knew || 0 < rotary_dim) + { + fmha_fwd_appendkv_traits fwd_appendkv_traits; + init_traits(fwd_appendkv_traits); + + fmha_fwd_appendkv_args fwd_appendkv_args; + init_args(fwd_appendkv_args); + + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + } +#endif + return 0.0f; + }(); + const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < num_splits || 0 < page_block_size) @@ -1048,7 +1009,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd(fmha_traits, fmha_args, stream_config); }(); - if(appendkv_ave_time < 0 || fwd_ave_time < 0) + if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return false; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 46dfa1409c..ad40061355 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -221,25 +221,26 @@ struct fmha_fwd_appendkv_args const void* seqstart_q_ptr; const void* seqstart_k_ptr; - const void* seqlen_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - ck_tile::index_t batch; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; ck_tile::index_t seqlen_q; - ck_tile::index_t max_seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; ck_tile::index_t hdim_q; ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; - const void* rotary_cos_ptr; - const void* rotary_sin_ptr; + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 ck_tile::index_t rotary_dim; void* block_table_ptr; - ck_tile::index_t batch_stride_block_table; - ck_tile::index_t page_block_size; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t stride_q; ck_tile::index_t stride_k;