mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
feat(HostTensor): Extend support for the HostTensor class's << operator to print more data types (#2691)
* feat(check_err): add a variable to adjust the number of incorrect values to print
* feat(host_tensor): add printing capability for fp8, bf8, int8 and int4
* fix(gemm_utils): update acceptable data type
* fix(host_tensor): print both 4-bit ints in pk_int4_t
* refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp
* feat(host_tensor): add function to print the first n elements
This commit is contained in:
@@ -18,6 +18,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
 * @brief Threshold on the running error count that gates per-element mismatch output.
 *
 * The check_err overloads increment their error counter first and then print the
 * offending out/ref pair only while that counter is strictly less than this value.
 * Assuming the counter starts at zero, at most ERROR_DETAIL_LIMIT - 1 mismatches
 * are reported in detail.
 *
 * Declared `inline` so that every translation unit including this file refers to
 * the same entity instead of getting its own internal-linkage copy.
 */
inline constexpr int ERROR_DETAIL_LIMIT = 5;
|
||||
|
||||
/** @brief 8-bit floating point type */
|
||||
using F8 = ck_tile::fp8_t;
|
||||
/** @brief 8-bit brain floating point type */
|
||||
@@ -280,7 +283,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -348,7 +351,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -416,7 +419,7 @@ check_err(const Range& out,
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
@@ -478,7 +481,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
|
||||
<< std::endl;
|
||||
@@ -564,7 +567,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
|
||||
@@ -630,7 +633,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
if(err_count < ERROR_DETAIL_LIMIT)
|
||||
{
|
||||
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
|
||||
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
|
||||
|
||||
@@ -642,6 +642,51 @@ struct HostTensor
|
||||
size() * FromSize / ToSize};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Print only the first N elements of the tensor
|
||||
*
|
||||
* @param os Output stream to write to
|
||||
* @param n Number of elements to print (default: 5)
|
||||
* @return std::ostream& Reference to the output stream
|
||||
*/
|
||||
std::ostream& print_first_n(std::ostream& os, std::size_t n = 5) const
|
||||
{
|
||||
os << mDesc;
|
||||
os << "[";
|
||||
for(typename Data::size_type idx = 0; idx < std::min(n, mData.size()); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
|
||||
{
|
||||
os << type_convert<float>(mData[idx]) << " #### ";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
|
||||
{
|
||||
auto unpacked = pk_int4_t_to_int8x2_t(mData[idx]);
|
||||
os << "pk(" << static_cast<int>(unpacked[0]) << ", "
|
||||
<< static_cast<int>(unpacked[1]) << ") #### ";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
{
|
||||
os << static_cast<int>(mData[idx]);
|
||||
}
|
||||
else
|
||||
{
|
||||
os << mData[idx];
|
||||
}
|
||||
}
|
||||
if(mData.size() > n)
|
||||
{
|
||||
os << ", ...";
|
||||
}
|
||||
os << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
|
||||
{
|
||||
os << t.mDesc;
|
||||
@@ -652,10 +697,17 @@ struct HostTensor
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
|
||||
{
|
||||
os << type_convert<float>(t.mData[idx]) << " #### ";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
|
||||
{
|
||||
auto unpacked = pk_int4_t_to_int8x2_t(t.mData[idx]);
|
||||
os << "pk(" << static_cast<int>(unpacked[0]) << ", "
|
||||
<< static_cast<int>(unpacked[1]) << ") #### ";
|
||||
}
|
||||
else
|
||||
{
|
||||
os << t.mData[idx];
|
||||
|
||||
Reference in New Issue
Block a user