From 5256e754cc3c985d891c4cf9a90ea9d4330f851e Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 27 Aug 2025 21:17:24 -0400 Subject: [PATCH] feat(HostTensor): Extend support for HostTensor class' >> operator to print more data types (#2691) * feat(check_err): add a variable to adjust number of incorrect values to print * feat(host_tensor): add printing capability for fp8 bf8 int8 int4 * fix(gemm_utils): update acceptable data type * fix(host_tensor): print both 4 bit ints in pk_int4_t * refactor(HostTensor): define pk_int4_t_to_int8x2_t and fix typo in vector_type.hpp * feat(host_tensor): add print first n elements functions [ROCm/composable_kernel commit: f5f795c4d6cdfa86e282ba077839aad409ca3103] --- example/ck_tile/03_gemm/gemm_utils.hpp | 2 +- include/ck_tile/core/numeric/pk_int4.hpp | 21 +++++ include/ck_tile/core/numeric/vector_type.hpp | 82 ++++++++++---------- include/ck_tile/host/check_err.hpp | 15 ++-- include/ck_tile/host/host_tensor.hpp | 54 ++++++++++++- 5 files changed, 125 insertions(+), 49 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index ed2006d4b9..7f2af946e6 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -486,7 +486,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index 0b0eb70beb..ad7956d32a 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -100,6 +100,7 @@ struct numeric_traits using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using int8x2_t = int8_t __attribute__((ext_vector_type(2))); CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x) { @@ -165,4 +166,24 @@ CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x) return res; } +CK_TILE_HOST_DEVICE int8x2_t pk_int4_t_to_int8x2_t(const pk_int4_t& x) +{ + uint8_t x_u8 = ck_tile::bit_cast(x); + + int8_t x_l = (x_u8 & 0x0F); + int8_t x_h = (x_u8 & 0xF0) >> 4; + + if(x_l & 0x08) + x_l |= 0xF0; + if(x_h & 0x08) + x_h |= 0xF0; + +#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE + int8x2_t res = {x_h, x_l}; +#else + int8x2_t res = {x_l, x_h}; +#endif + return res; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index bbd3d53827..5d8b109901 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -176,65 +176,65 @@ using uint16x64_t = uint16_t __attribute__((ext_vector_type(64))); // i8 // using int8_t -using int8x2_t = int8_t __attribute((ext_vector_type(2))); -using int8x4_t = int8_t __attribute((ext_vector_type(4))); -using int8x8_t = int8_t __attribute((ext_vector_type(8))); -using int8x16_t = int8_t __attribute((ext_vector_type(16))); -using int8x32_t = int8_t __attribute((ext_vector_type(32))); -using int8x64_t = int8_t __attribute((ext_vector_type(64))); +using int8x2_t = int8_t __attribute__((ext_vector_type(2))); +using int8x4_t = int8_t __attribute__((ext_vector_type(4))); +using int8x8_t = int8_t __attribute__((ext_vector_type(8))); +using int8x16_t = int8_t __attribute__((ext_vector_type(16))); +using int8x32_t = int8_t __attribute__((ext_vector_type(32))); +using int8x64_t = int8_t __attribute__((ext_vector_type(64))); // ui8 // using uint8_t -using uint8x2_t = uint8_t __attribute((ext_vector_type(2))); -using uint8x4_t = uint8_t __attribute((ext_vector_type(4))); -using uint8x8_t = uint8_t __attribute((ext_vector_type(8))); -using uint8x16_t = uint8_t __attribute((ext_vector_type(16))); -using uint8x32_t = uint8_t __attribute((ext_vector_type(32))); -using uint8x64_t = uint8_t __attribute((ext_vector_type(64))); +using uint8x2_t = uint8_t __attribute__((ext_vector_type(2))); +using uint8x4_t = uint8_t __attribute__((ext_vector_type(4))); +using uint8x8_t = uint8_t __attribute__((ext_vector_type(8))); +using uint8x16_t = uint8_t __attribute__((ext_vector_type(16))); +using uint8x32_t = uint8_t __attribute__((ext_vector_type(32))); +using uint8x64_t = uint8_t __attribute__((ext_vector_type(64))); #if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t -using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2))); -using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4))); -using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8))); -using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16))); -using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32))); -using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64))); +using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2))); +using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4))); +using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8))); +using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16))); +using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32))); +using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64))); // bf8 // using bf8_t -using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2))); -using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4))); -using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8))); -using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16))); -using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32))); -using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64))); +using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2))); +using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4))); +using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8))); +using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16))); +using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32))); +using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64))); #else // f8 // using fp8_t -using fp8x2_t = fp8_t __attribute((ext_vector_type(2))); -using fp8x4_t = fp8_t __attribute((ext_vector_type(4))); -using fp8x8_t = fp8_t __attribute((ext_vector_type(8))); -using fp8x16_t = fp8_t __attribute((ext_vector_type(16))); -using fp8x32_t = fp8_t __attribute((ext_vector_type(32))); -using fp8x64_t = fp8_t __attribute((ext_vector_type(64))); +using fp8x2_t = fp8_t __attribute__((ext_vector_type(2))); +using fp8x4_t = fp8_t __attribute__((ext_vector_type(4))); +using fp8x8_t = fp8_t __attribute__((ext_vector_type(8))); +using fp8x16_t = fp8_t __attribute__((ext_vector_type(16))); +using fp8x32_t = fp8_t __attribute__((ext_vector_type(32))); +using fp8x64_t = fp8_t __attribute__((ext_vector_type(64))); // bf8 // using bf8_t -using bf8x2_t = bf8_t __attribute((ext_vector_type(2))); -using bf8x4_t = bf8_t __attribute((ext_vector_type(4))); -using bf8x8_t = bf8_t __attribute((ext_vector_type(8))); -using bf8x16_t = bf8_t __attribute((ext_vector_type(16))); -using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); -using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); +using bf8x2_t = bf8_t __attribute__((ext_vector_type(2))); +using bf8x4_t = bf8_t __attribute__((ext_vector_type(4))); +using bf8x8_t = bf8_t __attribute__((ext_vector_type(8))); +using bf8x16_t = bf8_t __attribute__((ext_vector_type(16))); +using bf8x32_t = bf8_t __attribute__((ext_vector_type(32))); +using bf8x64_t = bf8_t __attribute__((ext_vector_type(64))); #endif // pk_int4_t // using pk_int4_t -using pk_int4x2_t = int8_t __attribute((ext_vector_type(2))); -using pk_int4x4_t = int8_t __attribute((ext_vector_type(4))); -using pk_int4x8_t = int8_t __attribute((ext_vector_type(8))); -using pk_int4x16_t = int8_t __attribute((ext_vector_type(16))); -using pk_int4x32_t = int8_t __attribute((ext_vector_type(32))); +using pk_int4x2_t = int8_t __attribute__((ext_vector_type(2))); +using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4))); +using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8))); +using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16))); +using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32))); } // namespace ck_tile diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 171384be61..1a15271dc4 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -18,6 +18,9 @@ namespace ck_tile { +/** @brief Maximum number of error values to display when checking errors */ +constexpr int ERROR_DETAIL_LIMIT = 5; + /** @brief 8-bit floating point type */ using F8 = ck_tile::fp8_t; /** @brief 8-bit brain floating point type */ @@ -280,7 +283,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -348,7 +351,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -416,7 +419,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -478,7 +481,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -564,7 +567,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; @@ -630,7 +633,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < ERROR_DETAIL_LIMIT) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index b7329fcac7..9b87518161 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -642,6 +642,51 @@ struct HostTensor size() * FromSize / ToSize}; } + /** + * @brief Print only the first N elements of the tensor + * + * @param os Output stream to write to + * @param n Number of elements to print (default: 5) + * @return std::ostream& Reference to the output stream + */ + std::ostream& print_first_n(std::ostream& os, std::size_t n = 5) const + { + os << mDesc; + os << "["; + for(typename Data::size_type idx = 0; idx < std::min(n, mData.size()); ++idx) + { + if(0 < idx) + { + os << ", "; + } + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + os << type_convert(mData[idx]) << " #### "; + } + else if constexpr(std::is_same_v) + { + auto unpacked = pk_int4_t_to_int8x2_t(mData[idx]); + os << "pk(" << static_cast(unpacked[0]) << ", " + << static_cast(unpacked[1]) << ") #### "; + } + else if constexpr(std::is_same_v) + { + os << static_cast(mData[idx]); + } + else + { + os << mData[idx]; + } + } + if(mData.size() > n) + { + os << ", ..."; + } + os << "]"; + return os; + } + friend std::ostream& operator<<(std::ostream& os, const HostTensor& t) { os << t.mDesc; @@ -652,10 +697,17 @@ struct HostTensor { os << ", "; } - if constexpr(std::is_same_v || std::is_same_v) + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { os << type_convert(t.mData[idx]) << " #### "; } + else if constexpr(std::is_same_v) + { + auto unpacked = pk_int4_t_to_int8x2_t(t.mData[idx]); + os << "pk(" << static_cast(unpacked[0]) << ", " + << static_cast(unpacked[1]) << ") #### "; + } else { os << t.mData[idx];