mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Setup meaningfull arguments
This commit is contained in:
@@ -494,7 +494,68 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto appendkv_traits = fmha_fwd_appendkv_traits{
|
||||
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor};
|
||||
|
||||
auto appendkv_args = []() { return fmha_fwd_appendkv_args{}; }();
|
||||
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
// setup stride_* arguments
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
|
||||
}();
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
|
||||
|
||||
return fmha_fwd_appendkv_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
v_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
k_paddings_[0] < 0 ? nullptr
|
||||
: seqlen_k_buf.GetDeviceBuffer(),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
shape_seqlen_q,
|
||||
max_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
seqlen_knew,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
false,
|
||||
stride_q,
|
||||
stride_k,
|
||||
0,
|
||||
stride_v,
|
||||
0,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
0,
|
||||
nhead_stride_v,
|
||||
0,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
0,
|
||||
batch_stride_v,
|
||||
0};
|
||||
}();
|
||||
|
||||
ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user