From b7e5da5e8301905f955b3bd724c9ba3ff2e32a90 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 20 Oct 2025 14:47:04 +0800 Subject: [PATCH] [CK_TILE] Patch for pk_fp4 ref check and buffer load. (#3044) * Patch for pk_fp4_raw_t buffer load and ref check [ROCm/composable_kernel commit: fb1d090f3c475907fbcbdaf9dcfd2829f92d3c26] --- .../arch/amd_buffer_addressing_builtins.hpp | 2 + include/ck_tile/host/check_err.hpp | 52 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 38e033cd92..4a86ca785d 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1405,6 +1405,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), "wrong! not implemented"); diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1a15271dc4..91d387796f 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -648,4 +648,56 @@ std::enable_if_t<(std::is_same_v, ranges::range_val return res; } +/** + * @brief Check errors between pk_fp4_t ranges + * + * Compares two ranges of pk_fp4_t without tolerance. + * This specialization handles ck_tile::pk_fp4_t type. 
+ *
+ * @tparam Range Type of output range
+ * @tparam RefRange Type of reference range
+ * @param out Output range to check
+ * @param ref Reference range to check against
+ * @param msg Error message to display if check fails
+ * @return True if check passes, false otherwise
+ */
+template // NOTE(review): the template parameter list and the other angle-bracket arguments below appear stripped in this extract; presumably <typename Range, typename RefRange> - confirm against the repository
+std::enable_if_t<(std::is_same_v, ranges::range_value_t> &&
+                  std::is_same_v, pk_fp4_t>),
+                 bool>
+    CK_TILE_HOST check_err(const Range& out,
+                           const RefRange& ref,
+                           const std::string& msg = "Error: Incorrect results!",
+                           double = 0, // unnamed parameter: accepted but never read by this overload
+                           double = 0) // unnamed parameter: accepted but never read by this overload
+{
+    if(check_size_mismatch(out, ref, msg)) // bail out as a failure when the helper flags a size mismatch
+        return false;
+
+    int err_count = 0; // mismatching unpacked values; each pk_fp4_t element contributes up to two
+
+    auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) { // compares one pair of unpacked raw values
+        if(o != r) // exact comparison, no tolerance is applied
+        {
+            std::cerr << msg << " out[" << index << "] != ref[" << index // index is in unpacked-value space (see the i * 2 arithmetic below)
+                      << "]: " << type_convert(pk_fp4_t{o}) // NOTE(review): conversion target type is missing from this extract
+                      << " != " << type_convert(pk_fp4_t{r}) << std::endl;
+            ++err_count; // captured by reference, so the total is visible after the loop
+        }
+    };
+
+    for(std::size_t i = 0; i < ref.size(); ++i) // iterates over packed elements; ref.size() drives the bound for both ranges
+    {
+        const pk_fp4_t o = *std::next(std::begin(out), i); // NOTE(review): restarting from begin() each iteration is linear per step for non-random-access ranges
+        const pk_fp4_t r = *std::next(std::begin(ref), i);
+        update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2); // first unpacked value of element i
+        update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1); // second unpacked value of element i
+    }
+    if(err_count > 0) // statistics are reported only when at least one value differed
+    {
+        report_error_stats(err_count, numeric::max(), ref.size()); // NOTE(review): err_count counts unpacked values while ref.size() counts packed elements; if the helper derives a ratio it may be off by 2x - confirm
+    }
+    return err_count == 0; // true only when every unpacked value matched
+}
+
 } // namespace ck_tile