[CK_TILE] FMHA Fix synchronization issue in FWD splitkv combine pipeline (#2934)

* Fix validation of rotary embedding with time_kernel_

When rotary embedding is used, the appendkv kernel modifies the q tensor
(multiple times when time_kernel_ is set). We need to reset the q buffer
and rerun all kernels.

* Fix synchronization issue in splitkv combine pipeline

Different warps can read and then rewrite the same values of lse_acc_lds.
Sometimes warps progress at different speeds, one warp can rewrite
values that are still being read by another warp.

Running the tests multiple times and, preferably, with multiple
processes on the same GPU helps to trigger this issue:

bin/test_ck_tile_fmha_fwd_fp16 --gtest_repeat=-1 --gtest_shuffle --gtest_throw_on_failure --gtest_filter="TestCkTileFmhaFwd/*KV*"
This commit is contained in:
Anton Gorenko
2025-09-27 09:16:10 +06:00
committed by GitHub
parent 2aa06fbd45
commit c6bfd97c2d
2 changed files with 24 additions and 10 deletions

View File

@@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
};
const float appendkv_ave_time = [&] {
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(need_append_kvcache)
{
@@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_appendkv_args fwd_appendkv_args;
init_args(fwd_appendkv_args);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
}
#endif
return 0.0f;
}();
};
const float appendkv_ave_time = run_appendkv(stream_config);
if(appendkv_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return fwd_result::no_instance;
}
const float fwd_ave_time = [&] {
auto run_fwd = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_PAGEDKV_API
if(1 == num_splits && use_kvcache)
{
@@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_pagedkv_args fmha_pagedkv_args;
init_args(fmha_pagedkv_args);
const float ave_time =
fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
#if CK_TILE_FMHA_FWD_SPLITKV_API
// If there is no instance for these args, fallback to fmha_fwd_splitkv
if(ave_time >= 0.0f)
@@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_splitkv_args fmha_splitkv_args;
init_args(fmha_splitkv_args);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
}
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
fmha_fwd_traits fmha_traits;
@@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_args fmha_args;
init_args(fmha_args);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();
return fmha_fwd(fmha_traits, fmha_args, sc);
};
const float fwd_ave_time = run_fwd(stream_config);
if(fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
@@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
else
{
#if CK_TILE_FMHA_FWD_APPENDKV_API
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
if(0 < rotary_dim && stream_config.time_kernel_)
{
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
q_buf.ToDevice(q_host.data());
run_appendkv(stream_config2);
run_fwd(stream_config2);
}
#endif
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());