mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add unit tests for grouped gemm two stage (#1256)
* add unit tests for grouped gemm two stage
* add reviewers suggestions
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 3e3471d5d2]
This commit is contained in:
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_d_grid_descs_m_n_.reserve(group_count_);
|
||||
ds_grid_pointer_.reserve(group_count_);
|
||||
group_grid_size_.reserve(group_count_);
|
||||
e_ptrs_.reserve(group_count_);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
group_grid_size_[i] = grid_size_grp;
|
||||
group_grid_size_.push_back(grid_size_grp);
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
|
||||
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
|
||||
ds_grid_pointer_.push_back(p_ds_grid);
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_.push_back(p_Es[i]);
|
||||
}
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_ = p_Es;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.gemm_kernel_args_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
PassThrough{});
|
||||
|
||||
// Elementwise kernels
|
||||
for(int i = 0; i < arg.group_count_; ++i)
|
||||
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
|
||||
@@ -6,6 +6,12 @@ if(result EQUAL 0)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
|
||||
endif()
|
||||
|
||||
# Build the two-stage split-K grouped GEMM unit test from
# test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp.
# When `result` equals 0 the executable is linked against `utility` and
# `device_grouped_gemm_instance`, and `test_grouped_gemm` is made to depend on it.
# NOTE(review): `result` is presumably set by add_gtest_executable — confirm.
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
// Element-type and layout shorthands used to spell the fixture aliases below.
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// Fixture aliases. The tuple is ordered
// <ALayout, BLayout, ELayout, ADataType, BDataType, EDataType>.
// The alias name doubles as the gtest suite name in TEST_P / INSTANTIATE_TEST_SUITE_P.
using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
// The *_LargeK aliases name the same fixture types as the two aliases above;
// they only provide distinct suite names that are instantiated with different parameter values.
using RRR_F16_F16_F16_LargeK =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16_LargeK =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
using RRR_BF16_BF16_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
|
||||
using RCR_BF16_BF16_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;
|
||||
// Mixed-precision variants: A is BF16, B is int8, E is BF16.
using RRR_BF16_I8_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
|
||||
using RCR_BF16_I8_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
|
||||
|
||||
// Parameter values handed to every non-LargeK suite instantiation.
// NOTE(review): presumably split-K batch counts (the tests forward GetParam() as `kbatch`) — confirm.
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
|
||||
|
||||
// Instantiate each suite once per value in KBATCH; the value is what the
// tests read back through GetParam().
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
|
||||
RRR_F16_F16_F16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
|
||||
RCR_F16_F16_F16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
|
||||
RRR_BF16_BF16_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
|
||||
RCR_BF16_BF16_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
|
||||
RRR_BF16_I8_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
|
||||
RCR_BF16_I8_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
// The LargeK suites use the fixed values 32 and 64 instead of KBATCH.
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
|
||||
RRR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
|
||||
RCR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
#include "test_grouped_gemm_two_stage_ut_cases.inc"
|
||||
61
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
Normal file
61
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
Normal file
@@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
|
||||
// Runs four GEMM groups (M = 127, 150, 188, 210) that all share N = 136 and K = 280,
// passing the suite parameter as the seventh argument of Run().
// NOTE(review): the sizes are presumably chosen so that M, N and K need padding — confirm
// against the tile sizes of the kernel instances.
TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
// Every group uses the same N, K and strides; only M varies.
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
// B is row-major in this suite's alias, and its stride is set to N.
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
// Runs four GEMM groups (M = 127, 150, 188, 210) that all share N = 136 and K = 280,
// passing the suite parameter as the seventh argument of Run().
// NOTE(review): the sizes are presumably chosen so that M, N and K need padding — confirm
// against the tile sizes of the kernel instances.
TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
// Every group uses the same N, K and strides; only M varies.
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
// B is column-major in this suite's alias, and its stride is set to K.
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
// Mixed-precision case (A is BF16, B is int8, E is BF16 per the suite alias).
// Runs four GEMM groups (M = 127, 150, 188, 210) that all share N = 136 and K = 280,
// passing the suite parameter as the seventh argument of Run().
// NOTE(review): the sizes are presumably chosen so that M, N and K need padding — confirm
// against the tile sizes of the kernel instances.
TEST_P(RRR_BF16_I8_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
// Every group uses the same N, K and strides; only M varies.
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
// B is row-major in this suite's alias, and its stride is set to N.
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
// Mixed-precision case (A is BF16, B is int8, E is BF16 per the suite alias).
// Runs four GEMM groups (M = 127, 150, 188, 210) that all share N = 136 and K = 280,
// passing the suite parameter as the seventh argument of Run().
// NOTE(review): the sizes are presumably chosen so that M, N and K need padding — confirm
// against the tile sizes of the kernel instances.
TEST_P(RCR_BF16_I8_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
// Every group uses the same N, K and strides; only M varies.
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
// B is column-major in this suite's alias, and its stride is set to K.
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
|
||||
}
|
||||
};
|
||||
|
||||
// Value-parameterized (int) gtest fixture that forwards a set of grouped-GEMM problem
// sizes to ck::profiler::profile_grouped_gemm_two_stage_impl and expects it to succeed.
//
// Tuple must provide, in order:
//   <ALayout, BLayout, ELayout, ADataType, BDataType, EDataType>.
template <typename Tuple>
|
||||
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
|
||||
{
|
||||
protected:
|
||||
// Unpack layouts (elements 0-2) and data types (elements 3-5) from Tuple.
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
// Fixed settings passed as the first four arguments of the profiler call in Run().
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
|
||||
// No per-test setup is required.
void SetUp() override {}
|
||||
|
||||
// Calls the two-stage grouped GEMM profiler with the given per-group vectors and
// records a non-fatal gtest failure if it returns false.
//
// Ms, Ns, Ks                   - per-group problem dimensions.
// StrideAs, StrideBs, StrideCs - per-group strides, forwarded unchanged.
// kbatch, n_warmup, n_iter     - forwarded unchanged to the profiler.
// NOTE(review): all six vectors are presumably expected to have equal length — confirm
// against the profiler implementation.
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
// NOTE(review): the fourth template argument is hard-coded to float — presumably the
// accumulator type; confirm against the profiler's template parameter list.
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
|
||||
Reference in New Issue
Block a user