[CK_TILE] Patch for pk_fp4 ref check and buffer load. (#3044)

* Patch for pk_fp4_raw_t buffer load and ref check
This commit is contained in:
Gino Lu
2025-10-20 14:47:04 +08:00
committed by GitHub
parent af3786fe08
commit fb1d090f3c
2 changed files with 54 additions and 0 deletions

View File

@@ -648,4 +648,56 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
return res;
}
/**
* @brief Check errors between pk_fp4_t ranges
*
* Compares two ranges of pk_fp4_t without tolerance.
* This specialization handles ck_tile::pk_fp4_t type.
*
* @tparam Range Type of output range
* @tparam RefRange Type of reference range
* @param out Output range to check
* @param ref Reference range to check against
* @param msg Error message to display if check fails
* @return True if check passes, false otherwise
*/
template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_same_v<ranges::range_value_t<Range>, pk_fp4_t>),
bool>
CK_TILE_HOST check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double = 0,
double = 0)
{
if(check_size_mismatch(out, ref, msg))
return false;
int err_count = 0;
auto update_err = [&](pk_fp4_raw_t o, pk_fp4_raw_t r, std::size_t index) {
if(o != r)
{
std::cerr << msg << " out[" << index << "] != ref[" << index
<< "]: " << type_convert<float>(pk_fp4_t{o})
<< " != " << type_convert<float>(pk_fp4_t{r}) << std::endl;
++err_count;
}
};
for(std::size_t i = 0; i < ref.size(); ++i)
{
const pk_fp4_t o = *std::next(std::begin(out), i);
const pk_fp4_t r = *std::next(std::begin(ref), i);
update_err(o._unpack(number<0>{}), r._unpack(number<0>{}), i * 2);
update_err(o._unpack(number<1>{}), r._unpack(number<1>{}), i * 2 + 1);
}
if(err_count > 0)
{
report_error_stats(err_count, numeric<pk_fp4_t>::max(), ref.size());
}
return err_count == 0;
}
} // namespace ck_tile