mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Add error threshold calculation for gemm examples (#1821)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -18,6 +18,130 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
|
||||
std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
|
||||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
double compute_error = 0;
|
||||
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
|
||||
std::is_same_v<ComputeDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
|
||||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
|
||||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
|
||||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
|
||||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
|
||||
std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
|
||||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
|
||||
std::is_same_v<ComputeDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
|
||||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
|
||||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
double output_error = 0;
|
||||
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
|
||||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
|
||||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
double acc_error = 0;
|
||||
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error =
|
||||
std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user