mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit 'c6bfd97c2d186fd03866c3f5d460bb680ce667a1' into develop
This commit is contained in:
@@ -1153,7 +1153,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
};
|
||||
|
||||
const float appendkv_ave_time = [&] {
|
||||
auto run_appendkv = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(need_append_kvcache)
|
||||
{
|
||||
@@ -1163,18 +1163,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_appendkv_args fwd_appendkv_args;
|
||||
init_args(fwd_appendkv_args);
|
||||
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
|
||||
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc);
|
||||
}
|
||||
#endif
|
||||
return 0.0f;
|
||||
}();
|
||||
};
|
||||
const float appendkv_ave_time = run_appendkv(stream_config);
|
||||
if(appendkv_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return fwd_result::no_instance;
|
||||
}
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(1 == num_splits && use_kvcache)
|
||||
{
|
||||
@@ -1184,8 +1185,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_pagedkv_args fmha_pagedkv_args;
|
||||
init_args(fmha_pagedkv_args);
|
||||
|
||||
const float ave_time =
|
||||
fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
|
||||
const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc);
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
// If there is no instance for these args, fallback to fmha_fwd_splitkv
|
||||
if(ave_time >= 0.0f)
|
||||
@@ -1204,7 +1204,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_splitkv_args fmha_splitkv_args;
|
||||
init_args(fmha_splitkv_args);
|
||||
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc);
|
||||
}
|
||||
#endif // CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
fmha_fwd_traits fmha_traits;
|
||||
@@ -1213,8 +1213,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
fmha_fwd_args fmha_args;
|
||||
init_args(fmha_args);
|
||||
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
}();
|
||||
return fmha_fwd(fmha_traits, fmha_args, sc);
|
||||
};
|
||||
const float fwd_ave_time = run_fwd(stream_config);
|
||||
if(fwd_ave_time < 0.0f)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
@@ -1288,6 +1289,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times
|
||||
// when time_kernel_ is set). We need to reset the q buffer and rerun all kernels.
|
||||
if(0 < rotary_dim && stream_config.time_kernel_)
|
||||
{
|
||||
const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0};
|
||||
q_buf.ToDevice(q_host.data());
|
||||
run_appendkv(stream_config2);
|
||||
run_fwd(stream_config2);
|
||||
}
|
||||
#endif
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -223,6 +223,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
// sync before rewriting lse_acc_lds
|
||||
block_sync_lds();
|
||||
// store the lse scales in shared memory.
|
||||
{
|
||||
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
|
||||
|
||||
@@ -224,7 +224,7 @@ bool profile_gemm_multi_abd_impl(int do_verification,
|
||||
auto get_b_matrix = [&]() -> auto {
|
||||
// in case of pass through we avoid allocating a new
|
||||
// tensor and copying values
|
||||
if constexpr(is_same_v<AElementOp, PassThrough>)
|
||||
if constexpr(is_same_v<BElementOp, PassThrough>)
|
||||
{
|
||||
return bs_k_n(Number<0>{});
|
||||
}
|
||||
|
||||
@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
|
||||
using KernelTypesABD = ::testing::Types<
|
||||
#if 0 // TBD: temporarily skipped because they fail HostTensorDescriptor validation
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>,
|
||||
#endif
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>;
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
|
||||
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
|
||||
|
||||
@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
|
||||
using KernelTypesABD = ::testing::Types<
|
||||
#if 0 // TBD: temporarily skipped because they fail HostTensorDescriptor validation
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>,
|
||||
#endif
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>;
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16, BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>,
|
||||
std::tuple<ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<Row>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
|
||||
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }
|
||||
|
||||
Reference in New Issue
Block a user