Set up meaningful arguments

This commit is contained in:
PoYen, Chen
2024-06-24 14:34:31 +00:00
parent 342c8cf01d
commit eee035ade5

View File

@@ -494,7 +494,68 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto appendkv_traits = fmha_fwd_appendkv_traits{
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor};
auto appendkv_args = []() { return fmha_fwd_appendkv_args{}; }();
// Assemble the fmha_fwd_appendkv_args value that is handed to fmha_fwd_appendkv.
// The lambda is invoked immediately, so appendkv_args holds its return value,
// not a callable. k_paddings_ is a by-value copy of seqlen_kpads; every other
// outer name is captured by reference.
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
// stride_*: the head dimension alone when i_perm is set, otherwise the head
// dimension multiplied by the head count (nhead for Q, nhead_k for K and V).
// NOTE(review): presumably i_perm selects between a head-major and a
// sequence-major tensor layout -- confirm where the buffers are allocated.
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
// V also depends on is_v_rowmajor: the base extent is hdim_v when it is set
// and shape_seqlen_k when it is not.
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}();
// nhead_stride_*: sequence length times head dimension when i_perm is set,
// otherwise just the head dimension (or shape_seqlen_k for non-row-major V).
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
// With i_perm set, both V branches yield the same product
// (shape_seqlen_k * hdim_v); they differ only when i_perm is false.
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
// batch_stride_*: head count * sequence length * head dimension. These do not
// depend on i_perm or is_v_rowmajor.
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
// The initializer below is positional, so the order of the entries must match
// the order expected by fmha_fwd_appendkv_args.
// NOTE(review): the nullptr / 0 / false entries fill slots for which this
// caller supplies no value; the names of those slots are not visible here --
// TODO confirm each one against the fmha_fwd_appendkv_args declaration.
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
nullptr,
v_buf.GetDeviceBuffer(),
nullptr,
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
// seqlen_k_buf is passed only when the first entry of k_paddings_ is
// non-negative; otherwise this slot is null.
// NOTE(review): presumably a negative first entry means "no K padding
// requested" -- confirm against the argument parsing of seqlen_kpads.
k_paddings_[0] < 0 ? nullptr
: seqlen_k_buf.GetDeviceBuffer(),
batch,
nhead,
nhead_k,
shape_seqlen_q,
max_seqlen_q,
shape_seqlen_k,
seqlen_knew,
hdim_q,
hdim_v,
nullptr,
nullptr,
0,
false,
// Strides computed above; each group is followed in order by Q, K, a zero
// placeholder, V, and another zero placeholder.
stride_q,
stride_k,
0,
stride_v,
0,
nhead_stride_q,
nhead_stride_k,
0,
nhead_stride_v,
0,
batch_stride_q,
batch_stride_k,
0,
batch_stride_v,
0};
}();
ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
}