mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5095 (commit 7e55766)
[CK_TILE] Enable MXFP6 for MX GEMM op

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile. Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a5d0200ccf
commit
d7c761e060
@@ -61,6 +61,7 @@ CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1
|
||||
tf32_t,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_fp6x16_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
@@ -135,6 +136,7 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
||||
tf32_t,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_fp6x16_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
|
||||
@@ -169,6 +169,41 @@ struct FillUniformDistribution<ck_tile::pk_int4_t>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FillUniformDistribution<ck_tile::pk_fp6x16_t>
|
||||
{
|
||||
float a_{-2.f};
|
||||
float b_{2.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
while(first != last)
|
||||
{
|
||||
ck_tile::pk_fp6x16_t pk{};
|
||||
for(ck_tile::index_t i = 0; i < ck_tile::pk_fp6x16_t::packed_size; ++i)
|
||||
{
|
||||
pk.pack(ck_tile::pk_fp6x16_t::float_to_fp6_e2m3(dis(gen)), i);
|
||||
}
|
||||
*first = pk;
|
||||
++first;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillUniformDistribution&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// clang-format off
|
||||
|
||||
Reference in New Issue
Block a user