diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 5c6c7d923a..e58e040f19 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - const float appendkv_ave_time = [&] { + auto run_appendkv = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { @@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_appendkv_args fwd_appendkv_args; init_args(fwd_appendkv_args); - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); } #endif return 0.0f; - }(); + }; + const float appendkv_ave_time = run_appendkv(stream_config); if(appendkv_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return fwd_result::no_instance; } - const float fwd_ave_time = [&] { + auto run_fwd = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_pagedkv_args fmha_pagedkv_args; init_args(fmha_pagedkv_args); - const float ave_time = - fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); #if CK_TILE_FMHA_FWD_SPLITKV_API // If there is no instance for these args, fallback to fmha_fwd_splitkv if(ave_time >= 0.0f) @@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_splitkv_args fmha_splitkv_args; init_args(fmha_splitkv_args); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); } #endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; @@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); - return fmha_fwd(fmha_traits, fmha_args, stream_config); - }(); + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode, } else { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7ac86e6d12..7b30f36fd8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -223,6 +223,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } + // sync before rewriting lse_acc_lds + block_sync_lds(); // store the lse scales in shared memory. { constexpr auto spans = decltype(lse_accum)::get_distributed_spans();