mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 16:47:40 +00:00
[CK_TILE] Enable MXFP6 for MX GEMM op ## Motivation Add support for MXFP6 in the MX GEMM op in CK-Tile. Depends on https://github.com/ROCm/rocm-libraries/pull/4594 ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
791 lines
28 KiB
C++
791 lines
28 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <cstdlib>
|
|
#include <iostream>
|
|
#include <iomanip>
|
|
#include <iterator>
|
|
#include <limits>
|
|
#include <type_traits>
|
|
#include <vector>
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/ranges.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
/** @brief Maximum number of error values to display when checking errors */
|
|
constexpr int ERROR_DETAIL_LIMIT = 16;
|
|
|
|
/** @brief 8-bit floating point type */
|
|
using F8 = ck_tile::fp8_t;
|
|
/** @brief 8-bit brain floating point type */
|
|
using BF8 = ck_tile::bf8_t;
|
|
/** @brief 16-bit floating point (half precision) type */
|
|
using F16 = ck_tile::half_t;
|
|
/** @brief 16-bit brain floating point type */
|
|
using BF16 = ck_tile::bf16_t;
|
|
/** @brief 32-bit floating point (single precision) type */
|
|
using F32 = float;
|
|
/** @brief 8-bit signed integer type */
|
|
using I8 = int8_t;
|
|
/** @brief 32-bit signed integer type */
|
|
using I32 = int32_t;
|
|
|
|
/**
|
|
* @brief Calculate relative error threshold for numerical comparisons
|
|
*
|
|
* Calculates the relative error threshold based on the mantissa bits and characteristics
|
|
* of the data types involved in the computation.
|
|
*
|
|
* @tparam ComputeDataType Type used for computation
|
|
* @tparam OutDataType Type used for output
|
|
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
|
* @param number_of_accumulations Number of accumulation operations performed
|
|
* @return Relative error threshold based on data type characteristics
|
|
*/
|
|
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
|
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
|
|
{
|
|
|
|
static_assert(is_any_of<ComputeDataType,
|
|
F8,
|
|
BF8,
|
|
F16,
|
|
BF16,
|
|
F32,
|
|
tf32_t,
|
|
pk_fp4_t,
|
|
pk_fp4_raw_t,
|
|
pk_fp6x16_t,
|
|
pk_int4_t,
|
|
I8,
|
|
I32,
|
|
int>::value,
|
|
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
|
|
|
double compute_error = 0;
|
|
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
|
}
|
|
|
|
static_assert(
|
|
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
|
|
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
|
|
|
double output_error = 0;
|
|
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
|
|
}
|
|
double midway_error = std::max(compute_error, output_error);
|
|
|
|
static_assert(
|
|
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
|
|
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
|
|
|
double acc_error = 0;
|
|
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
|
}
|
|
return std::max(acc_error, midway_error);
|
|
}
|
|
|
|
/**
|
|
* @brief Calculate absolute error threshold for numerical comparisons
|
|
*
|
|
* Calculates the absolute error threshold based on the maximum possible value and
|
|
* the characteristics of the data types involved in the computation.
|
|
*
|
|
* @tparam ComputeDataType Type used for computation
|
|
* @tparam OutDataType Type used for output
|
|
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
|
* @param max_possible_num Maximum possible value in the computation
|
|
* @param number_of_accumulations Number of accumulation operations performed
|
|
* @return Absolute error threshold based on data type characteristics and maximum value
|
|
*/
|
|
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
|
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
|
const int number_of_accumulations = 1)
|
|
{
|
|
|
|
static_assert(is_any_of<ComputeDataType,
|
|
F8,
|
|
BF8,
|
|
F16,
|
|
BF16,
|
|
F32,
|
|
tf32_t,
|
|
pk_fp4_t,
|
|
pk_fp4_raw_t,
|
|
pk_fp6x16_t,
|
|
pk_int4_t,
|
|
I8,
|
|
I32,
|
|
int>::value,
|
|
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
|
|
|
// Use discrete exponent (floor of log2) to match actual floating-point exponent levels
|
|
// This ensures ULP calculation matches the discrete precision levels of FP representation
|
|
int discrete_expo =
|
|
std::floor(static_cast<int>(std::floor(std::log2(std::abs(max_possible_num)))));
|
|
double compute_error = 0;
|
|
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
compute_error = std::pow(2, discrete_expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
|
}
|
|
|
|
static_assert(
|
|
is_any_of<OutDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
|
|
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
|
|
|
double output_error = 0;
|
|
if constexpr(is_any_of<OutDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
// Use full ULP (1.0) instead of half ULP (0.5) for output_error to account for
|
|
// hardware vs software conversion differences (e.g., hardware __bf16 vs software
|
|
// float_to_bf16 can differ by up to 1 ULP at tie cases)
|
|
output_error = std::pow(2, discrete_expo - numeric_traits<OutDataType>::mant) * 1.0;
|
|
}
|
|
double midway_error = std::max(compute_error, output_error);
|
|
|
|
static_assert(
|
|
is_any_of<AccDataType, F8, BF8, F16, BF16, F32, tf32_t, pk_int4_t, I8, I32, int>::value,
|
|
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
|
|
|
double acc_error = 0;
|
|
if constexpr(is_any_of<AccDataType, pk_int4_t, I8, I32, int>::value)
|
|
{
|
|
return 0;
|
|
}
|
|
else
|
|
{
|
|
acc_error = std::pow(2, discrete_expo - numeric_traits<AccDataType>::mant) * 0.5 *
|
|
number_of_accumulations;
|
|
}
|
|
return std::max(acc_error, midway_error);
|
|
}
|
|
|
|
/**
 * @brief Write a vector to a stream as a bracketed, comma-separated list
 *
 * An empty vector is written as "[]"; otherwise the elements are streamed in order with ", "
 * between neighbours, e.g. "[1, 2, 3]". Handy for debugging and logging.
 *
 * @tparam T Type of vector elements
 * @param os Output stream
 * @param v Vector to output
 * @return Reference to the output stream
 */
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";

    auto cursor     = v.begin();
    const auto last = v.end();
    if(cursor != last)
    {
        // The first element carries no separator; every later one is preceded by ", ".
        os << *cursor;
        for(++cursor; cursor != last; ++cursor)
        {
            os << ", ";
            os << *cursor;
        }
    }

    os << "]";
    return os;
}
|
|
|
|
/**
|
|
* @brief Check for size mismatch between output and reference ranges
|
|
*
|
|
* Verifies that the output and reference ranges are the same size.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if sizes mismatch
|
|
* @return True if sizes mismatch, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
CK_TILE_HOST bool check_size_mismatch(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!")
|
|
{
|
|
if(out.size() != ref.size())
|
|
{
|
|
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
|
<< std::endl;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/**
|
|
* @brief Report error statistics for numerical comparisons
|
|
*
|
|
* Outputs statistics about numerical comparison errors including count and maximum error.
|
|
*
|
|
* @param err_count Number of errors found
|
|
* @param max_err Maximum error value encountered
|
|
* @param total_size Total number of elements compared
|
|
*/
|
|
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
|
|
{
|
|
const float error_percent =
|
|
static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
|
|
std::cerr << "max err: " << max_err;
|
|
std::cerr << ", number of errors: " << err_count;
|
|
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between floating point ranges using the specified tolerances.
|
|
*
|
|
* Compares two ranges of floating point values within specified relative and absolute tolerances.
|
|
* This overload handles standard floating point types except half precision floating point.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param rtol Relative tolerance
|
|
* @param atol Absolute tolerance
|
|
* @param allow_infinity_ref Whether to allow infinity in reference values
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
typename std::enable_if<
|
|
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_floating_point_v<ranges::range_value_t<Range>> &&
|
|
!std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
|
bool>::type CK_TILE_HOST
|
|
check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double rtol = 1e-5,
|
|
double atol = 3e-6,
|
|
bool allow_infinity_ref = false)
|
|
{
|
|
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
const auto is_infinity_error = [=](auto o, auto r) {
|
|
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
|
const bool both_infinite_and_same =
|
|
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
|
|
|
|
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
|
|
};
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
double err = 0;
|
|
double max_err = std::numeric_limits<double>::min();
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const double o = *std::next(std::begin(out), i);
|
|
const double r = *std::next(std::begin(ref), i);
|
|
err = std::abs(o - r);
|
|
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
|
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between floating point ranges using the specified tolerances
|
|
*
|
|
* Compares two ranges of brain floating point values within specified relative and absolute
|
|
* tolerances.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param rtol Relative tolerance
|
|
* @param atol Absolute tolerance
|
|
* @param allow_infinity_ref Whether to allow infinity in reference values
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
typename std::enable_if<
|
|
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
|
|
bool>::type CK_TILE_HOST
|
|
check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double rtol = 1e-3,
|
|
double atol = 1e-3,
|
|
bool allow_infinity_ref = false)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
const auto is_infinity_error = [=](auto o, auto r) {
|
|
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
|
const bool both_infinite_and_same =
|
|
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
|
|
|
|
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
|
|
};
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
double err = 0;
|
|
// TODO: This is a hack. We should have proper specialization for bf16_t data type.
|
|
double max_err = std::numeric_limits<float>::min();
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
|
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
|
err = std::abs(o - r);
|
|
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
|
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between half precision floating point ranges
|
|
*
|
|
* Compares two ranges of half precision floating point values within specified tolerances.
|
|
* This specialization handles the specific requirements and characteristics of half precision
|
|
* floating point comparisons.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param rtol Relative tolerance
|
|
* @param atol Absolute tolerance
|
|
* @param allow_infinity_ref Whether to allow infinity in reference values
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
typename std::enable_if<
|
|
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, half_t>,
|
|
bool>::type CK_TILE_HOST
|
|
check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double rtol = 1e-3,
|
|
double atol = 1e-3,
|
|
bool allow_infinity_ref = false)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
const auto is_infinity_error = [=](auto o, auto r) {
|
|
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
|
const bool both_infinite_and_same =
|
|
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
|
|
|
|
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
|
|
};
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
double err = 0;
|
|
double max_err = static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
|
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
|
err = std::abs(o - r);
|
|
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
|
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between integer ranges
|
|
*
|
|
* Compares two ranges of integer values with an absolute tolerance.
|
|
* This specialization handles integer types and optionally int4_t when the
|
|
* experimental bit int extension is enabled.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param atol Absolute tolerance
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_integral_v<ranges::range_value_t<Range>> &&
|
|
!std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
|
|
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
|
|| std::is_same_v<ranges::range_value_t<Range>, int4_t>
|
|
#endif
|
|
,
|
|
bool>
|
|
CK_TILE_HOST check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double = 0,
|
|
double atol = 0)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
int64_t err = 0;
|
|
int64_t max_err = std::numeric_limits<int64_t>::min();
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const int64_t o = *std::next(std::begin(out), i);
|
|
const int64_t r = *std::next(std::begin(ref), i);
|
|
err = std::abs(o - r);
|
|
|
|
if(err > atol)
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
|
|
<< std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, static_cast<double>(max_err), ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between FP8 ranges
|
|
*
|
|
* Specialized comparison for 8-bit floating point values that takes into account
|
|
* the unique characteristics and limitations of FP8 arithmetic, including
|
|
* rounding point distances and special handling of infinity values.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param max_rounding_point_distance Maximum allowed distance between rounding points
|
|
* @param atol Absolute tolerance
|
|
* @param allow_infinity_ref Whether to allow infinity in reference values
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
|
|
bool>
|
|
CK_TILE_HOST check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
unsigned max_rounding_point_distance = 1,
|
|
double atol = 1e-1,
|
|
bool allow_infinity_ref = false)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
const auto is_infinity_error = [=](auto o, auto r) {
|
|
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
|
const bool both_infinite_and_same =
|
|
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
|
|
|
|
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
|
|
};
|
|
|
|
static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
|
|
static const auto get_sign_bit = [](fp8_t v) -> bool {
|
|
return 0x80 & bit_cast<uint8_t>(v);
|
|
};
|
|
|
|
if(get_sign_bit(o) ^ get_sign_bit(r))
|
|
{
|
|
return std::numeric_limits<unsigned>::max();
|
|
}
|
|
else
|
|
{
|
|
return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
|
|
}
|
|
};
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
double err = 0;
|
|
double max_err = std::numeric_limits<float>::min();
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const fp8_t o_fp8 = *std::next(std::begin(out), i);
|
|
const fp8_t r_fp8 = *std::next(std::begin(ref), i);
|
|
const double o_fp64 = type_convert<float>(o_fp8);
|
|
const double r_fp64 = type_convert<float>(r_fp8);
|
|
err = std::abs(o_fp64 - r_fp64);
|
|
if(!(less_equal<double>{}(err, atol) ||
|
|
get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
|
|
is_infinity_error(o_fp64, r_fp64))
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
|
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between BF8 ranges
|
|
*
|
|
* Specialized comparison for 8-bit brain floating point values that considers
|
|
* the specific numerical properties and error characteristics of the BF8 format.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @param rtol Relative tolerance
|
|
* @param atol Absolute tolerance
|
|
* @param allow_infinity_ref Whether to allow infinity in reference values
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
|
bool>
|
|
CK_TILE_HOST check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double rtol = 1e-3,
|
|
double atol = 1e-3,
|
|
bool allow_infinity_ref = false)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
const auto is_infinity_error = [=](auto o, auto r) {
|
|
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
|
const bool both_infinite_and_same =
|
|
std::isinf(o) && std::isinf(r) && (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
|
|
|
|
return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
|
|
};
|
|
|
|
bool res{true};
|
|
int err_count = 0;
|
|
double err = 0;
|
|
double max_err = std::numeric_limits<float>::min();
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const double o = type_convert<float>(*std::next(std::begin(out), i));
|
|
const double r = type_convert<float>(*std::next(std::begin(ref), i));
|
|
err = std::abs(o - r);
|
|
if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
|
|
{
|
|
max_err = err > max_err ? err : max_err;
|
|
err_count++;
|
|
if(err_count < ERROR_DETAIL_LIMIT)
|
|
{
|
|
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
|
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
|
}
|
|
res = false;
|
|
}
|
|
}
|
|
if(!res)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between pk_fp4_t ranges
|
|
*
|
|
* Compares two ranges of pk_fp4_t without tolerance.
|
|
* This specialization handles ck_tile::pk_fp4_t type.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, pk_fp4_t>),
|
|
bool>
|
|
CK_TILE_HOST check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double = 0,
|
|
double = 0)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
int err_count = 0;
|
|
|
|
auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) {
|
|
if(o != r)
|
|
{
|
|
std::cerr << msg << " out[" << index << "] != ref[" << index
|
|
<< "]: " << type_convert<float>(pk_fp4_t{o})
|
|
<< " != " << type_convert<float>(pk_fp4_t{r}) << std::endl;
|
|
++err_count;
|
|
}
|
|
};
|
|
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const pk_fp4_t o = *std::next(std::begin(out), i);
|
|
const pk_fp4_t r = *std::next(std::begin(ref), i);
|
|
update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2);
|
|
update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1);
|
|
}
|
|
if(err_count > 0)
|
|
{
|
|
report_error_stats(err_count, numeric<pk_fp4_t>::max(), ref.size());
|
|
}
|
|
return err_count == 0;
|
|
}
|
|
|
|
/**
|
|
* @brief Check errors between pk_fp6x16_t ranges
|
|
*
|
|
* Compares two ranges of pk_fp6x16_t without tolerance.
|
|
* This specialization handles ck_tile::pk_fp6x16_t type.
|
|
*
|
|
* @tparam Range Type of output range
|
|
* @tparam RefRange Type of reference range
|
|
* @param out Output range to check
|
|
* @param ref Reference range to check against
|
|
* @param msg Error message to display if check fails
|
|
* @return True if check passes, false otherwise
|
|
*/
|
|
template <typename Range, typename RefRange>
|
|
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
|
std::is_same_v<ranges::range_value_t<Range>, pk_fp6x16_t>),
|
|
bool>
|
|
CK_TILE_HOST check_err(const Range& out,
|
|
const RefRange& ref,
|
|
const std::string& msg = "Error: Incorrect results!",
|
|
double = 0,
|
|
double = 0)
|
|
{
|
|
if(check_size_mismatch(out, ref, msg))
|
|
return false;
|
|
|
|
int err_count = 0;
|
|
float max_err = 0.0f;
|
|
auto update_err = [&](float o, float r, std::size_t index) {
|
|
if(std::fabs(o - r) > 1e-8)
|
|
{
|
|
std::cerr << msg << " out[" << index << "] != ref[" << index << "]: " << o
|
|
<< " != " << r << std::endl;
|
|
++err_count;
|
|
max_err = max_err < std::fabs(o - r) ? o : max_err;
|
|
}
|
|
};
|
|
for(std::size_t i = 0; i < ref.size(); ++i)
|
|
{
|
|
const pk_fp6x16_t o = *std::next(std::begin(out), i);
|
|
const pk_fp6x16_t r = *std::next(std::begin(ref), i);
|
|
for(std::size_t j = 0; j < numeric_traits<pk_fp6x16_t>::PackedSize; j++)
|
|
{
|
|
update_err(o.unpack(j), r.unpack(j), i * numeric_traits<pk_fp6x16_t>::PackedSize + j);
|
|
}
|
|
}
|
|
if(err_count > 0)
|
|
{
|
|
report_error_stats(err_count, max_err, ref.size());
|
|
}
|
|
return err_count == 0;
|
|
}
|
|
|
|
} // namespace ck_tile
|