feat(HostTensor): Extend support for HostTensor class' >> operator to print more data types (#2691)

* feat(check_err): add a variable to adjust number of incorrect values to print

* feat(host_tensor): add printing capability for fp8 bf8 int8 int4

* fix(gemm_utils): update acceptable data type

* fix(host_tensor): print both 4 bit ints in pk_int4_t

* refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp

* feat(host_tensor): add print first n elements functions
This commit is contained in:
Aviral Goel
2025-08-27 21:17:24 -04:00
committed by GitHub
parent 9751583f95
commit f5f795c4d6
5 changed files with 125 additions and 49 deletions

View File

@@ -18,6 +18,9 @@
namespace ck_tile {
/** @brief Maximum number of error values to display when checking errors */
constexpr int ERROR_DETAIL_LIMIT = 5;
/** @brief 8-bit floating point type */
using F8 = ck_tile::fp8_t;
/** @brief 8-bit brain floating point type */
@@ -280,7 +283,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -348,7 +351,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -416,7 +419,7 @@ check_err(const Range& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
@@ -478,7 +481,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
<< std::endl;
@@ -564,7 +567,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
@@ -630,7 +633,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < ERROR_DETAIL_LIMIT)
{
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;