mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Merge commit 'c6bfd97c2d186fd03866c3f5d460bb680ce667a1' into develop
This commit is contained in:
@@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
};
|
||||
|
||||
const float appendkv_ave_time = [&] {
|
||||
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(need_append_kvcache)
|
||||
{
|
||||
@@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_appendkv_args fwd_appendkv_args;
|
||||
init_args(fwd_appendkv_args);
|
||||
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
|
||||
}
|
||||
#endif
|
||||
return 0.0f;
|
||||
}();
|
||||
};
|
||||
const float appendkv_ave_time = run_appendkv(stream_config);
|
||||
if(appendkv_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return fwd_result::no_instance;
|
||||
}
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(1 == num_splits && use_kvcache)
|
||||
{
|
||||
@@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_pagedkv_args fmha_pagedkv_args;
|
||||
init_args(fmha_pagedkv_args);
|
||||
|
||||
const float ave_time =
|
||||
fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
|
||||
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
// If there is no instance for these args, fallback to fmha_fwd_splitkv
|
||||
if(ave_time >= 0.0f)
|
||||
@@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_splitkv_args fmha_splitkv_args;
|
||||
init_args(fmha_splitkv_args);
|
||||
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
|
||||
}
|
||||
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
fmha_fwd_traits fmha_traits;
|
||||
@@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_args fmha_args;
|
||||
init_args(fmha_args);
|
||||
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
}();
|
||||
return fmha_fwd(fmha_traits, fmha_args, sc);
|
||||
};
|
||||
const float fwd_ave_time = run_fwd(stream_config);
|
||||
if(fwd_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
@@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
|
||||
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
|
||||
if(0 < rotary_dim && stream_config.time_kernel_)
|
||||
{
|
||||
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
|
||||
q_buf.ToDevice(q_host.data());
|
||||
run_appendkv(stream_config2);
|
||||
run_fwd(stream_config2);
|
||||
}
|
||||
#endif
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
Reference in New Issue
Block a user