From 477a605961b1d1bd1735d6f55e112e43c24cb8cd Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Sat, 27 Sep 2025 03:19:57 +0000 Subject: [PATCH] Merge commit 'c6bfd97c2d186fd03866c3f5d460bb680ce667a1' into develop --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 30 +++++-- ...lock_fmha_fwd_splitkv_combine_pipeline.hpp | 4 +- .../profiler/profile_gemm_multi_abd_impl.hpp | 2 +- .../test_gemm_multi_abd_wmma.cpp | 85 +++++++++---------- .../test_gemm_multi_abd_xdl.cpp | 85 +++++++++---------- 5 files changed, 107 insertions(+), 99 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 5c6c7d923a..e58e040f19 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - const float appendkv_ave_time = [&] { + auto run_appendkv = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { @@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_appendkv_args fwd_appendkv_args; init_args(fwd_appendkv_args); - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); } #endif return 0.0f; - }(); + }; + const float appendkv_ave_time = run_appendkv(stream_config); if(appendkv_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return fwd_result::no_instance; } - const float fwd_ave_time = [&] { + auto run_fwd = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_pagedkv_args fmha_pagedkv_args; init_args(fmha_pagedkv_args); - const float ave_time = - fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); #if CK_TILE_FMHA_FWD_SPLITKV_API // If there is no instance for these args, fallback to fmha_fwd_splitkv if(ave_time >= 0.0f) @@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_splitkv_args fmha_splitkv_args; init_args(fmha_splitkv_args); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); } #endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; @@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); - return fmha_fwd(fmha_traits, fmha_args, stream_config); - }(); + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode, } else { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 7ac86e6d12..7b30f36fd8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -223,6 +223,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline }); } + // sync before rewriting lse_acc_lds + block_sync_lds(); // store the lse scales in shared memory. { constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp index a3c5c6a3ac..46745fd02b 100644 --- a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -224,7 +224,7 @@ bool profile_gemm_multi_abd_impl(int do_verification, auto get_b_matrix = [&]() -> auto { // in case of pass through we avoid allocating a new // tensor and copying values - if constexpr(is_same_v) + if constexpr(is_same_v) { return bs_k_n(Number<0>{}); } diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp index a15f95bbf8..42584ecc02 100644 --- a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp +++ b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp @@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; -using KernelTypesABD = ::testing::Types< -#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation - std::tuple, +using KernelTypesABD = ::testing::Types, ck::Tuple, ck::Tuple, ck::Tuple, @@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types< PassThrough, Multiply, PassThrough>, -#endif - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>; + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp index a15f95bbf8..42584ecc02 100644 --- a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp +++ b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp @@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; -using KernelTypesABD = ::testing::Types< -#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation - std::tuple, +using KernelTypesABD = ::testing::Types, ck::Tuple, ck::Tuple, ck::Tuple, @@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types< PassThrough, Multiply, PassThrough>, -#endif - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAdd>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>, - std::tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>; + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }