[CK_TILE] Refine pk_fp4's fill, pack, and unpack (#2845)

* fix bug

* let pack/unpack return pk_fp4_t

* fix clang-format
This commit is contained in:
Gino Lu
2025-09-17 10:54:06 +08:00
committed by GitHub
parent db79fad16f
commit c2997f2b7f
3 changed files with 70 additions and 36 deletions

View File

@@ -67,7 +67,10 @@ struct FillUniformDistribution
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
@@ -77,8 +80,12 @@ struct FillUniformDistribution
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
std::generate(first, last, [&dis, &gen]() {
if constexpr(numeric_traits<T>::PackedSize == 2)
return ck_tile::type_convert<T>(fp32x2_t{dis(gen), dis(gen)});
else
return ck_tile::type_convert<T>(dis(gen));
});
}
}