mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] FMHA Fix synchronization issue in FWD splitkv combine pipeline (#2934)
* Fix validation of rotary embedding with time_kernel_ When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. * Fix synchronization issue in splitkv combine pipeline Different warps can read and then rewrite the same values of lse_acc_lds. Sometimes warps progress at different speeds, one warp can rewrite values that are still being read by another warp. Running the tests multiple times and, preferably, with multiple processes on the same GPU helps to trigger this issue: bin/test_ck_tile_fmha_fwd_fp16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --gtest_filter="TestCkTileFmhaFwd/*KV*"
This commit is contained in:
@@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
};
|
||||
|
||||
const float appendkv_ave_time = [&] {
|
||||
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(need_append_kvcache)
|
||||
{
|
||||
@@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_appendkv_args fwd_appendkv_args;
|
||||
init_args(fwd_appendkv_args);
|
||||
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
|
||||
}
|
||||
#endif
|
||||
return 0.0f;
|
||||
}();
|
||||
};
|
||||
const float appendkv_ave_time = run_appendkv(stream_config);
|
||||
if(appendkv_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return fwd_result::no_instance;
|
||||
}
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(1 == num_splits && use_kvcache)
|
||||
{
|
||||
@@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_pagedkv_args fmha_pagedkv_args;
|
||||
init_args(fmha_pagedkv_args);
|
||||
|
||||
const float ave_time =
|
||||
fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
|
||||
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
// If there is no instance for these args, fallback to fmha_fwd_splitkv
|
||||
if(ave_time >= 0.0f)
|
||||
@@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_splitkv_args fmha_splitkv_args;
|
||||
init_args(fmha_splitkv_args);
|
||||
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
|
||||
}
|
||||
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
fmha_fwd_traits fmha_traits;
|
||||
@@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_args fmha_args;
|
||||
init_args(fmha_args);
|
||||
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
}();
|
||||
return fmha_fwd(fmha_traits, fmha_args, sc);
|
||||
};
|
||||
const float fwd_ave_time = run_fwd(stream_config);
|
||||
if(fwd_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
@@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
|
||||
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
|
||||
if(0 < rotary_dim && stream_config.time_kernel_)
|
||||
{
|
||||
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
|
||||
q_buf.ToDevice(q_host.data());
|
||||
run_appendkv(stream_config2);
|
||||
run_fwd(stream_config2);
|
||||
}
|
||||
#endif
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
Reference in New Issue
Block a user