diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
index cb32587d3e..a70ee6f053 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
         elementwise_d_grid_descs_m_n_.reserve(group_count_);
         ds_grid_pointer_.reserve(group_count_);
         group_grid_size_.reserve(group_count_);
+        e_ptrs_.reserve(group_count_);
 
         for(std::size_t i = 0; i < gemm_descs.size(); ++i)
         {
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             const index_t block_end = grid_size_ + grid_size_grp;
             grid_size_ += grid_size_grp;
 
-            group_grid_size_[i] = grid_size_grp;
+            group_grid_size_.push_back(grid_size_grp);
             // block-to-e-tile map
             auto grouped_block_2_ctile_map =
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
             elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
             elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
             ds_grid_pointer_.push_back(p_ds_grid);
+            // Store a copy of E pointers for elementwise kernel destination
+            e_ptrs_.push_back(p_Es[i]);
         }
-        // Store a copy of E pointers for elementwise kernel destination
-        e_ptrs_ = p_Es;
     }
 
     /**
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
                     dim3(BlockSize),
                     0,
                     cast_pointer_to_constant_address_space(dev_gemm_args),
-                    arg.group_count_,
+                    arg.gemm_kernel_args_.size(),
                     arg.a_element_op_,
                     arg.b_element_op_,
                     PassThrough{});
 
             // Elementwise kernels
-            for(int i = 0; i < arg.group_count_; ++i)
+            for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
             {
                 time +=
                     launch_and_time_kernel(
                         stream_config,
diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt
index f47685cf91..55cb209772 100644
--- a/test/grouped_gemm/CMakeLists.txt
+++ b/test/grouped_gemm/CMakeLists.txt
@@ -6,6 +6,12 @@ if(result EQUAL 0)
     add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
 endif()
 
+add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
+    add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
+endif()
+
 add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
 if(result EQUAL 0)
     target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
new file mode 100644
index 0000000000..67ecbaea30
--- /dev/null
+++ b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
@@ -0,0 +1,62 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <vector>
+#include <tuple>
+
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/utility/data_type.hpp"
+
+#include "gtest/gtest.h"
+#include "test_grouped_gemm_util.hpp"
+
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
+using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
+using RRR_F16_F16_F16_LargeK =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
+using RCR_F16_F16_F16_LargeK =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
+using RRR_BF16_BF16_BF16 =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
+using RCR_BF16_BF16_BF16 =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;
+using RRR_BF16_I8_BF16 =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
+using RCR_BF16_I8_BF16 =
+    ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
+
+const std::vector<int> KBATCH{1, 2, 3, 5, 8};
+
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
+                         RRR_F16_F16_F16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
+                         RCR_F16_F16_F16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
+                         RRR_BF16_BF16_BF16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
+                         RCR_BF16_BF16_BF16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
+                         RRR_BF16_I8_BF16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
+                         RCR_BF16_I8_BF16,
+                         testing::ValuesIn(KBATCH));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
+                         RRR_F16_F16_F16_LargeK,
+                         testing::Values(32, 64));
+INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
+                         RCR_F16_F16_F16_LargeK,
+                         testing::Values(32, 64));
+
+#include "test_grouped_gemm_ut_cases.inc"
+#include "test_grouped_gemm_two_stage_ut_cases.inc"
diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
new file mode 100644
index 0000000000..40d48f4ec0
--- /dev/null
+++ b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
@@ -0,0 +1,61 @@
+#pragma once
+
+TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), N);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RRR_BF16_I8_BF16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), N);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_BF16_I8_BF16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp
index 50f423ada3..9e1395b9f8 100644
--- a/test/grouped_gemm/test_grouped_gemm_util.hpp
+++ b/test/grouped_gemm/test_grouped_gemm_util.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -22,6 +22,7 @@
 #include "ck/utility/tuple.hpp"
 #include "ck/utility/number.hpp"
 #include "profiler/profile_grouped_gemm_impl.hpp"
+#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
 
 namespace ck {
 namespace test {
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
     }
 };
 
+template <typename Tuple>
+class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
+{
+    protected:
+    using ALayout   = std::tuple_element_t<0, Tuple>;
+    using BLayout   = std::tuple_element_t<1, Tuple>;
+    using ELayout   = std::tuple_element_t<2, Tuple>;
+    using ADataType = std::tuple_element_t<3, Tuple>;
+    using BDataType = std::tuple_element_t<4, Tuple>;
+    using EDataType = std::tuple_element_t<5, Tuple>;
+
+    public:
+    static constexpr bool verify_     = true;
+    static constexpr int init_method_ = 1; // decimal value initialization
+    static constexpr bool log_        = false;
+    static constexpr bool bench_      = false; // measure kernel performance
+
+    void SetUp() override {}
+
+    void Run(const std::vector<int>& Ms,
+             const std::vector<int>& Ns,
+             const std::vector<int>& Ks,
+             const std::vector<int>& StrideAs,
+             const std::vector<int>& StrideBs,
+             const std::vector<int>& StrideCs,
+             int kbatch   = 1,
+             int n_warmup = 1,
+             int n_iter   = 10)
+    {
+        bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
+                                                                      BDataType,
+                                                                      EDataType,
+                                                                      ALayout,
+                                                                      BLayout,
+                                                                      ELayout>(verify_,
+                                                                               init_method_,
+                                                                               log_,
+                                                                               bench_,
+                                                                               Ms,
+                                                                               Ns,
+                                                                               Ks,
+                                                                               StrideAs,
+                                                                               StrideBs,
+                                                                               StrideCs,
+                                                                               kbatch,
+                                                                               n_warmup,
+                                                                               n_iter);
+        EXPECT_TRUE(pass);
+    }
+};
+
 template <typename Tuple>