Addressing (Post Merge) code review comments for PR 1845 (#1883)

* Addressing code review comments.

* Addressing code review comments.

* Reorganized code for better readability.

* add ck_tile gemms for new types in CI

* fix jenkins syntax

* fix script syntax

* Add the test cases back

* Address the review comments

* Address review comments

* clang format

* Solve the merging issues

* Addressed the comments

* clang format

---------

Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
kylasa
2025-03-06 11:40:30 -08:00
committed by GitHub
parent c12fb0a624
commit 66c5f5b0b6
32 changed files with 511 additions and 245 deletions

View File

@@ -11,6 +11,27 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
enum struct GemmPipelineType
{
Mem,
@@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
// TODO: This should be parameterized in tests
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
// values.
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
@@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipeline::BlockSize,
@@ -420,7 +441,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
EXPECT_TRUE(pass);
}
};