Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: 66c5f5b0b6]
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent 710aa99819
commit 7fbcd06a62
32 changed files with 511 additions and 245 deletions

View File

@@ -67,7 +67,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
CodegenGemmPipeline::BlockSize,

View File

@@ -1,4 +1,6 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
endif()

View File

@@ -0,0 +1,16 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
/// Typed-test fixture for the CompV3 GEMM pipeline test executable.
/// It adds nothing of its own; everything is inherited from TestCkTileGemmPipeline<T>.
template <typename T>
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
{
};

// The shared test bodies in test_gemm_pipeline_ut_cases.inc are written against
// TEST_SUITE_NAME, so it has to be defined before that file is included.
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3

// Register the suite with the CompV3 kernel configurations so that this executable
// exercises the CompV3 pipeline rather than the Mem pipeline type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesCompV3);

#include "test_gemm_pipeline_ut_cases.inc"

// Do not leak the helper macro past the shared test cases.
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,16 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
/// Typed-test fixture for the CompV4 GEMM pipeline test executable.
/// It adds nothing of its own; everything is inherited from TestCkTileGemmPipeline<T>.
template <typename T>
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
{
};

// The shared test bodies in test_gemm_pipeline_ut_cases.inc are written against
// TEST_SUITE_NAME, so it has to be defined before that file is included.
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4

// Register the suite with the CompV4 kernel configurations so that this executable
// exercises the CompV4 pipeline rather than the Mem pipeline type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesCompV4);

#include "test_gemm_pipeline_ut_cases.inc"

// Do not leak the helper macro past the shared test cases.
#undef TEST_SUITE_NAME

View File

@@ -10,6 +10,7 @@
using F16 = ck_tile::half_t;
using F32 = float;
using F8 = ck_tile::fp8_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
@@ -21,27 +22,41 @@ using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
using KernelTypesMem = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
>;
using KernelTypesCompV3 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>;
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>
>;
using KernelTypesCompV4 = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
#include "test_gemm_pipeline_ut_cases.inc"

View File

@@ -0,0 +1,16 @@
#include "test_gemm_pipeline_kernel_types.hpp"
#include "test_gemm_pipeline_util.hpp"
#include "gtest/gtest.h"
/// Typed-test fixture for the Mem GEMM pipeline test executable.
/// It adds nothing of its own; everything is inherited from TestCkTileGemmPipeline<T>.
template <typename T>
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
{
};
// The shared test bodies in test_gemm_pipeline_ut_cases.inc are written against
// TEST_SUITE_NAME, so it has to be defined before that file is included.
#define TEST_SUITE_NAME TestCkTileGemmPipelineMem
// Register the suite with the KernelTypesMem type list.
TYPED_TEST_SUITE(TestCkTileGemmPipelineMem, KernelTypesMem);
#include "test_gemm_pipeline_ut_cases.inc"
// Do not leak the helper macro past the shared test cases.
#undef TEST_SUITE_NAME

View File

@@ -3,7 +3,10 @@
#pragma once
TYPED_TEST(TestCkTileGemmPipeline, SmallM)
#ifndef TEST_GEMM_PIPELINE_UT_CASES_INC
#define TEST_GEMM_PIPELINE_UT_CASES_INC
TYPED_TEST(TEST_SUITE_NAME, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 1024;
@@ -13,18 +16,25 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
}
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
constexpr int VecLoadSize = 8;
constexpr int VecLoadSize = (std::is_same_v<typename TestFixture::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TestFixture::ADataType, ck_tile::bf8_t>)
? 16
: 8;
for(int M : Ms)
{
@@ -33,9 +43,13 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{
// TODO: Can we somehow deduce the vector load size that is actually used?
if(M % VecLoadSize == 0)
{
this->Run(M, N, K);
}
else
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
}
else
{
@@ -44,7 +58,7 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
}
}
TYPED_TEST(TestCkTileGemmPipeline, PaddK)
TYPED_TEST(TEST_SUITE_NAME, PaddK)
{
std::vector<int> Ms{128};
constexpr int N = 1024;
@@ -54,7 +68,7 @@ TYPED_TEST(TestCkTileGemmPipeline, PaddK)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmPipeline, Regular)
TYPED_TEST(TEST_SUITE_NAME, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
@@ -64,7 +78,16 @@ TYPED_TEST(TestCkTileGemmPipeline, Regular)
this->Run(M, N, K);
}
TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
// Runs the fixture on a single square problem with M = N = K = 2048.
TYPED_TEST(TEST_SUITE_NAME, LargeMatrix)
{
    // All three GEMM dimensions share the same extent.
    constexpr int Extent = 2048;

    this->Run(Extent, Extent, Extent);
}
TYPED_TEST(TEST_SUITE_NAME, NotSupportedArgument)
{
constexpr int M = 512;
constexpr int N = 1025;
@@ -76,3 +99,5 @@ TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
}
#endif

View File

@@ -11,6 +11,27 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
/// Computes the relative and absolute tolerances used to compare a GEMM result
/// against its reference.
///
/// Two candidate threshold pairs are evaluated and the larger value of each is kept:
///  - thresholds for accumulating ceil(K / kbatch) products per split, and
///  - thresholds for the additional split-K accumulation of kbatch partial results.
///
/// @param K                      Reduction dimension of the GEMM.
/// @param kbatch                 Number of split-K batches.
/// @param max_accumulated_value  Largest value of the reference output supplied by the caller.
/// @return ck_tile tuple holding (rtol, atol), in that order.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // Use whichever input type has the smaller size; BDataType is chosen when sizes are equal.
    using ComputeType =
        std::conditional_t<(sizeof(ADataType) < sizeof(BDataType)), ADataType, BDataType>;

    // Number of K elements accumulated inside a single split.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the accumulation performed within one split.
    const auto gemm_rtol =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto gemm_atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Thresholds for the error introduced by summing the kbatch partial results.
    const auto reduce_rtol =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto reduce_atol = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // The looser (larger) threshold of each kind wins.
    return ck_tile::make_tuple(std::max(gemm_rtol, reduce_rtol),
                               std::max(gemm_atol, reduce_atol));
}
enum struct GemmPipelineType
{
Mem,
@@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
// values.
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
@@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipeline::BlockSize,
@@ -420,7 +441,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
EXPECT_TRUE(pass);
}
};

View File

@@ -79,6 +79,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
template <typename ALayout, typename BLayout, typename CLayout>
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,