[CK_TILE] Add error threshold calculation for gemm examples (#1821)

This commit is contained in:
Bartłomiej Kocot
2025-01-18 01:01:52 +01:00
committed by GitHub
parent 0fcbb25f70
commit bdddf1eace
5 changed files with 256 additions and 11 deletions

View File

@@ -1,7 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
@@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}