From 7e74e0071e103a875d08e615b92c8feed700a140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 18 Jan 2025 01:01:52 +0100 Subject: [PATCH] [CK_TILE] Add error threshold calculation for gemm examples (#1821) [ROCm/composable_kernel commit: bdddf1eacec17c648c13ba921a8933f8e4d0174e] --- example/ck_tile/03_gemm/run_gemm_example.inc | 47 ++++++- .../run_batched_gemm_example.inc | 47 ++++++- .../run_grouped_gemm_example.inc | 35 ++++- include/ck_tile/core/numeric/bfloat16.hpp | 12 +- include/ck_tile/host/check_err.hpp | 126 +++++++++++++++++- 5 files changed, 256 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 56d0348bd6..e8fa102643 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -1,7 +1,27 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc, ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); - + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) @@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc, ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; } diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index c14bb5668c..2fe81e87c4 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -1,8 +1,28 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + template float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, @@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::reference_batched_gemm( a_m_k, b_n_k, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; } @@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 11faa6642c..e889a85bf4 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -1,8 +1,28 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + template float invoke_gemm(int n_warmup, int n_repeat, @@ -162,7 +182,18 @@ int run_grouped_gemm_example_with_layouts(int argc, c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); - pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; } std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; } diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 499ba80a88..6ad38b1f7c 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" @@ -376,6 +376,16 @@ struct numeric } }; +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 8; + static constexpr int mant = 7; +}; + #if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t) #endif diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 529bfdff25..c4ad345d8e 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,6 +18,130 @@ namespace ck_tile { +template +double get_relative_threshold(const int number_of_accumulations = 1) +{ + using F8 = ck_tile::fp8_t; + using F16 = ck_tile::half_t; + using BF16 = ck_tile::bf16_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + double compute_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, -numeric_traits::mant) * 0.5; + } + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Warning: Unhandled OutDataType for setting up the relative threshold!"); + double output_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, -numeric_traits::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Warning: Unhandled AccDataType for setting up the relative threshold!"); + double acc_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + acc_error = std::pow(2, -numeric_traits::mant) * 0.5 * number_of_accumulations; + } + return std::max(acc_error, midway_error); +} + +template +double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) +{ + using F8 = ck_tile::fp8_t; + using F16 = ck_tile::half_t; + using BF16 = ck_tile::bf16_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + auto expo = std::log2(std::abs(max_possible_num)); + double compute_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, expo - numeric_traits::mant) * 0.5; + } + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Warning: Unhandled OutDataType for setting up the absolute threshold!"); + double output_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, expo - numeric_traits::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + "Warning: Unhandled AccDataType for setting up the absolute threshold!"); + double acc_error = 0; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + return 0; + } + else + { + acc_error = + std::pow(2, expo - numeric_traits::mant) * 0.5 * number_of_accumulations; + } + return std::max(acc_error, midway_error); +} + template std::ostream& operator<<(std::ostream& os, const std::vector& v) {