From 8eb813de429378e5c0d83ed9c2384385e0003178 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Tue, 28 Oct 2025 15:53:36 +0100 Subject: [PATCH] [CK_TILE] Fixed multi-abd GEMM test, NaN problem (#2979) * Multi-ABD NaN problem * Rollback tests --------- Co-authored-by: root Co-authored-by: Thomas Ning Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> [ROCm/composable_kernel commit: da4247a6df4fb7b0941ce9b4a93b1b348e6cbc34] --- .../ops/gemm/kernel/gemm_multi_abd_kernel.hpp | 8 -- .../test_gemm_multi_abd_cshuffle.cpp | 17 ++-- .../test_gemm_multi_abd_default2d.cpp | 16 ++-- .../test_gemm_multi_abd_ut_cases_cshuffle.inc | 90 +++++++++++++++++++ .../test_gemm_multi_abd_util.hpp | 37 +------- 5 files changed, 108 insertions(+), 60 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index b4ddc33e8d..7b73b89ede 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -185,14 +185,6 @@ struct GemmKernelMultiABD { return false; } - // Currently MultiABD kernel doesn't support F8 data type - if(ck_tile::get_device_name() == "gfx950" && - (std::is_same::value || - std::is_same::value || - std::is_same::value)) - { - return false; - } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 87d6a9101c..1eef00813d 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -20,20 +20,19 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue enabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> - - // Currently MultiABD kernel doesn't support F8 data type - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index f2476e803f..5664ded81e 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -20,19 +20,17 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using KernelTypes = ::testing::Types< // Has cshuffle epilogue disabled // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> - - // Currently MultiABD kernel doesn't support F8 data type - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc index 33eb404fbe..e2a7e0115f 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -1,5 +1,95 @@ #pragma once +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) { constexpr int M = 512; diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index 428bed4e25..8b5d845d6b 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -13,40 +13,9 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -struct AddScale -{ - template - CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const - { - a = scale * (ck_tile::type_convert(a0) + ck_tile::type_convert(a1)); - } - - float scale = 1.0; -}; - -struct MultiplyMultiply -{ - template - CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void - { - const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * - ck_tile::type_convert(d1); - - e = ck_tile::type_convert(x0_f); - } -}; - -struct ElementWiseAddAdd -{ - template - CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void - { - const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + - ck_tile::type_convert(d1); - - e = ck_tile::type_convert(x0_f); - } -}; +using AddScale = ck_tile::element_wise::AddScale; +using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; +using MultiplyMultiply = ck_tile::element_wise::MultiDMultiply; template static constexpr inline auto is_row_major(Layout layout_)