Merge commit 'c6bfd97c2d186fd03866c3f5d460bb680ce667a1' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-27 03:19:57 +00:00
parent 088b4670ae
commit 477a605961
5 changed files with 107 additions and 99 deletions

View File

@@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
};
const float appendkv_ave_time = [&] {
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(need_append_kvcache)
{
@@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_appendkv_args fwd_appendkv_args;
init_args(fwd_appendkv_args);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
}
#endif
return 0.0f;
}();
};
const float appendkv_ave_time = run_appendkv(stream_config);
if(appendkv_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return fwd_result::no_instance;
}
const float fwd_ave_time = [&] {
auto run_fwd = [&](const ck_tile::stream_config& sc) {
#if CK_TILE_FMHA_FWD_PAGEDKV_API
if(1 == num_splits && use_kvcache)
{
@@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_pagedkv_args fmha_pagedkv_args;
init_args(fmha_pagedkv_args);
const float ave_time =
fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
#if CK_TILE_FMHA_FWD_SPLITKV_API
// If there is no instance for these args, fallback to fmha_fwd_splitkv
if(ave_time >= 0.0f)
@@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_splitkv_args fmha_splitkv_args;
init_args(fmha_splitkv_args);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
}
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
fmha_fwd_traits fmha_traits;
@@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_args fmha_args;
init_args(fmha_args);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();
return fmha_fwd(fmha_traits, fmha_args, sc);
};
const float fwd_ave_time = run_fwd(stream_config);
if(fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
@@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
else
{
#if CK_TILE_FMHA_FWD_APPENDKV_API
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
if(0 < rotary_dim && stream_config.time_kernel_)
{
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
q_buf.ToDevice(q_host.data());
run_appendkv(stream_config2);
run_fwd(stream_config2);
}
#endif
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -223,6 +223,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
// sync before rewriting lse_acc_lds
block_sync_lds();
// store the lse scales in shared memory.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();

View File

@@ -224,7 +224,7 @@ bool profile_gemm_multi_abd_impl(int do_verification,
auto get_b_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<AElementOp, PassThrough>)
if constexpr(is_same_v<BElementOp, PassThrough>)
{
return bs_k_n(Number<0>{});
}

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skip temporarily because they fail HostTensorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skip temporarily because they fail HostTensorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }