// Patch note: free-form text placed ahead of the `diff --git` header; it is not part of any hunk.
//
// Target: example/ck_tile/01_fmha/fmha_fwd.cpp, a single hunk inside
// `bool run(const ck_tile::ArgParser& arg_parser)`.
//
// Purpose: replace the lambda that returned an empty-brace-initialized
// `fmha_fwd_appendkv_args{}` with an immediately-invoked lambda that fills the
// struct in before it is handed to
// `fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config)`.
//
// Inside the new lambda:
//   * stride_q / stride_k / stride_v, nhead_stride_q / _k / _v and
//     batch_stride_q / _k / _v are computed from i_perm, is_v_rowmajor, nhead,
//     nhead_k, hdim_q, hdim_v, shape_seqlen_q and shape_seqlen_k. The two V
//     strides that go through a nested lambda pick a different formula
//     depending on is_v_rowmajor.
//   * The initializer passes the device buffers of q_buf, k_buf, v_buf,
//     seqstart_q and seqstart_k. The device buffer of seqlen_k_buf is passed
//     only when k_paddings_[0] >= 0; otherwise that slot receives nullptr.
//   * The remaining nullptr / 0 / false entries are purely positional. Which
//     fields they land on depends on the declaration order of
//     fmha_fwd_appendkv_args, which this patch does not show; verify against
//     the struct definition before reordering or inserting any entry.
//
// NOTE(review): `k_paddings_ = seqlen_kpads` is a by-value init-capture, so
// whatever seqlen_kpads is (presumably a per-batch container; TODO confirm)
// gets copied for a lambda that is invoked on the spot. `k_paddings_[0]` also
// presumes at least one element. Confirm both are intended.
//
// NOTE(review): the hunk header announces 68 post-image lines, but the text
// below has lost its original line breaks. Restore one diff line per physical
// line before feeding this to `git apply` or `patch`.
diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 9de5d31022..8995339cd9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -494,7 +494,68 @@ bool run(const ck_tile::ArgParser& arg_parser) auto appendkv_traits = fmha_fwd_appendkv_traits{ hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor}; - auto appendkv_args = []() { return fmha_fwd_appendkv_args{}; }(); + auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() { + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + }(); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + + return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + nullptr, + v_buf.GetDeviceBuffer(), + nullptr, + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + k_paddings_[0] < 0 ? 
nullptr + : seqlen_k_buf.GetDeviceBuffer(), + batch, + nhead, + nhead_k, + shape_seqlen_q, + max_seqlen_q, + shape_seqlen_k, + seqlen_knew, + hdim_q, + hdim_v, + nullptr, + nullptr, + 0, + false, + stride_q, + stride_k, + 0, + stride_v, + 0, + nhead_stride_q, + nhead_stride_k, + 0, + nhead_stride_v, + 0, + batch_stride_q, + batch_stride_k, + 0, + batch_stride_v, + 0}; + }(); ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); }