diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index c7f1a6d15e..631eff10db 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -507,7 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser) #if CK_TILE_FMHA_FWD_SPLITKV_API if(0 < p_drop && (1 < num_splits || 0 < page_block_size)) { - std::cerr << "dropout is not supoprted in split-kv kernels. ignoring the option 'p_drop'" + std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the option 'p_drop'" << std::endl; p_drop = 0.0f; } @@ -879,7 +879,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } }; - auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -937,69 +937,87 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); - return fmha_fwd_args{ - q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() - : bias_buf.GetDeviceBuffer(), - randval_buf.GetDeviceBuffer(), - 1 < num_splits || 0 < page_block_size ? lse_acc_buf.GetDeviceBuffer() : nullptr, - 1 < num_splits || 0 < page_block_size ? o_acc_buf.GetDeviceBuffer() : nullptr, - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), - 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr, - batch_stride_block_table, // only used if 'block_table_ptr' is not nullptr - page_block_size, // only used if 'block_table_ptr' is not nullptr - shape_seqlen_q, - shape_seqlen_k, - batch, - max_seqlen_q, - hdim_q, - hdim_v, - nhead, - nhead_k, - num_splits, // only used in split-kv kernel - scale_s, - scale_p, - scale_o, - stride_q, - stride_k, - stride_v, - bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias, - stride_randval, - stride_o_acc, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_bias, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_lse_acc, - nhead_stride_o_acc, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_bias, - batch_stride_randval, - batch_stride_lse, - batch_stride_lse_acc, - batch_stride_o_acc, - batch_stride_o, - split_stride_lse_acc, - split_stride_o_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_drop, - s_randval, - {drop_seed, drop_offset}}; - }(); + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqstart_q_ptr = seqstart_q.GetDeviceBuffer(); + args.seqstart_k_ptr = seqstart_k.GetDeviceBuffer(); + args.seqlen_k_ptr = k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(); + + args.seqlen_q = shape_seqlen_q; + args.seqlen_k = shape_seqlen_k; + args.batch = batch; + args.max_seqlen_q = max_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.stride_bias = + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias; + args.stride_o = stride_o; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + args.drop_seed_offset = std::tie(drop_seed, drop_offset); + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + 0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr; + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + }; const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API @@ -1008,12 +1026,18 @@ bool run(const ck_tile::ArgParser& arg_parser) fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_args, stream_config); + fmha_fwd_splitkv_args fmha_splitkv_args; + init_args(fmha_splitkv_args); + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } #endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); + fmha_fwd_args fmha_args; + init_args(fmha_args); + return fmha_fwd(fmha_traits, fmha_args, stream_config); }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index dcec1c68ce..46dfa1409c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -96,18 +96,76 @@ struct fmha_fwd_args const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; - void* lse_acc_ptr; - void* o_acc_ptr; void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; const void* seqstart_k_ptr; - const void* seqlen_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + std::tuple drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; void* block_table_ptr; - ck_tile::index_t batch_stride_block_table; - ck_tile::index_t page_block_size; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; @@ -118,21 +176,21 @@ struct fmha_fwd_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; ck_tile::index_t num_splits; + float scale_s; float scale_p; float scale_o; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 - ck_tile::index_t stride_randval; ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_bias; - ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_o_acc; @@ -141,23 +199,18 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_bias; - ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o; ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_o_acc; + ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; - float p_drop; - bool s_randval; - std::tuple drop_seed_offset; }; -using fmha_fwd_splitkv_args = fmha_fwd_args; - struct fmha_fwd_appendkv_args { void* q_ptr;