mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Addressing (Post Merge) code review comments for PR 1845 (#1883)
* Addressing code review comments. * Addressing code review comments. * Reorganized code for better readability. * add ck_tile gemms for new types in CI * fix jenkins syntax * fix script syntax * Add the test cases back * Address the review comments * Address review comments * clang format * Solve the merging issues * Addressed the comments * clang format --------- Co-authored-by: illsilin <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -11,6 +11,27 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
@@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
|
||||
// values.
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
@@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipeline::BlockSize,
|
||||
@@ -420,7 +441,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user