feat(HostTensor): Extend support for HostTensor class' >> operator to print more data types (#2691)

* feat(check_err): add a variable to adjust number of incorrect values to print

* feat(host_tensor): add printing capability for fp8 bf8 int8 int4

* fix(gemm_utils): update acceptable data type

* fix(host_tensor): print both 4 bit ints in pk_int4_t

* refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp

* feat(host_tensor): add print first n elements functions
This commit is contained in:
Aviral Goel
2025-08-27 21:17:24 -04:00
committed by GitHub
parent 9751583f95
commit f5f795c4d6
5 changed files with 125 additions and 49 deletions

View File

@@ -642,6 +642,51 @@ struct HostTensor
size() * FromSize / ToSize};
}
/**
* @brief Print only the first N elements of the tensor
*
* @param os Output stream to write to
* @param n Number of elements to print (default: 5)
* @return std::ostream& Reference to the output stream
*/
std::ostream& print_first_n(std::ostream& os, std::size_t n = 5) const
{
os << mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < std::min(n, mData.size()); ++idx)
{
if(0 < idx)
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
{
os << type_convert<float>(mData[idx]) << " #### ";
}
else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
{
auto unpacked = pk_int4_t_to_int8x2_t(mData[idx]);
os << "pk(" << static_cast<int>(unpacked[0]) << ", "
<< static_cast<int>(unpacked[1]) << ") #### ";
}
else if constexpr(std::is_same_v<T, int8_t>)
{
os << static_cast<int>(mData[idx]);
}
else
{
os << mData[idx];
}
}
if(mData.size() > n)
{
os << ", ...";
}
os << "]";
return os;
}
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
@@ -652,10 +697,17 @@ struct HostTensor
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t> ||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
}
else if constexpr(std::is_same_v<T, ck_tile::pk_int4_t>)
{
auto unpacked = pk_int4_t_to_int8x2_t(t.mData[idx]);
os << "pk(" << static_cast<int>(unpacked[0]) << ", "
<< static_cast<int>(unpacked[1]) << ") #### ";
}
else
{
os << t.mData[idx];