mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Addressing (Post Merge) code review comments for PR 1845 (#1883)
* Addressing code review comments.
* Addressing code review comments.
* Reorganized code for better readability.
* add ck_tile gemms for new types in CI
* fix jenkins syntax
* fix script syntax
* Add the test cases back
* Address the review comments
* Address review comments
* clang format
* Solve the merging issues
* Addressed the comments
* clang format
---------
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 66c5f5b0b6]
This commit is contained in:
@@ -67,7 +67,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenGemmPipeline::BlockSize,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
endif()
|
||||
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
@@ -21,27 +22,41 @@ using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
|
||||
>;
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineMem, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, SmallM)
|
||||
#ifndef TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
#define TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 1024;
|
||||
@@ -13,18 +16,25 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
|
||||
{
|
||||
if constexpr(std::is_same_v<typename TestFixture::ALayout,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 320;
|
||||
constexpr int VecLoadSize = 8;
|
||||
constexpr int VecLoadSize = (std::is_same_v<typename TestFixture::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TestFixture::ADataType, ck_tile::bf8_t>)
|
||||
? 16
|
||||
: 8;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
@@ -33,9 +43,13 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
{
|
||||
// TODO: Can we anyhow deduce used vector load size?
|
||||
if(M % VecLoadSize == 0)
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -44,7 +58,7 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, PaddK)
|
||||
TYPED_TEST(TEST_SUITE_NAME, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{128};
|
||||
constexpr int N = 1024;
|
||||
@@ -54,7 +68,7 @@ TYPED_TEST(TestCkTileGemmPipeline, PaddK)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, Regular)
|
||||
TYPED_TEST(TEST_SUITE_NAME, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 1024;
|
||||
@@ -64,7 +78,16 @@ TYPED_TEST(TestCkTileGemmPipeline, Regular)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
|
||||
TYPED_TEST(TEST_SUITE_NAME, LargeMatrix)
|
||||
{
|
||||
constexpr int M = 2048;
|
||||
constexpr int N = 2048;
|
||||
constexpr int K = 2048;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, NotSupportedArgument)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1025;
|
||||
@@ -76,3 +99,5 @@ TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
|
||||
|
||||
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,27 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
@@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
|
||||
// values.
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
@@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipeline::BlockSize,
|
||||
@@ -420,7 +441,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -79,6 +79,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
|
||||
Reference in New Issue
Block a user