Merge commit 'c6bfd97c2d186fd03866c3f5d460bb680ce667a1' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-27 03:19:57 +00:00
parent 088b4670ae
commit 477a605961
5 changed files with 107 additions and 99 deletions

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }

View File

@@ -26,9 +26,7 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using KernelTypesABD = ::testing::Types<
#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation
std::tuple<ck::Tuple<Row>,
using KernelTypesABD = ::testing::Types<std::tuple<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
@@ -108,47 +106,46 @@ using KernelTypesABD = ::testing::Types<
PassThrough,
Multiply,
PassThrough>,
#endif
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>,
std::tuple<ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<Row>,
ck::Tuple<BF16>,
ck::Tuple<I8>,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>;
TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD);
TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); }