fix copy-paste bug in get_b_matrix; re-enable all tests in multi_abd (#2939)

[ROCm/composable_kernel commit: 2aa06fbd45]
This commit is contained in:
emezh
2025-09-26 22:55:18 -04:00
committed by GitHub
parent 857566c8aa
commit daabe29bff
3 changed files with 83 additions and 89 deletions

View File

@@ -224,7 +224,7 @@ bool profile_gemm_multi_abd_impl(int do_verification,
auto get_b_matrix = [&]() -> auto {
// in case of pass through we avoid allocating a new
// tensor and copying values
if constexpr(is_same_v<AElementOp, PassThrough>)
if constexpr(is_same_v<BElementOp, PassThrough>)
{
return bs_k_n(Number<0>{});
}

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skipped temporarily because they fail HostTensorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skipped temporarily because they fail HostTensorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }