mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Mx fp6 flatmm (#3601)
* add fp6 data-type and support sync/async dwordx3 load/store * clang-format * pre-commit * 1st commit * default mnk pass ut * fix a distrubution * fix * fix bdram distr * update * pass ut * improve perf * update * clean code * resolve copilot comment * reslove comment * clang-format --------- Co-authored-by: ZheWang <zhewan@amd.com>
This commit is contained in:
@@ -625,6 +625,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
if(k % pk_fp6x16_t::packed_size != 0)
|
||||
continue;
|
||||
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
|
||||
{
|
||||
a_m_k_scaled(m, k + k_) =
|
||||
pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
@@ -653,6 +664,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
if(k % pk_fp6x16_t::packed_size != 0)
|
||||
continue;
|
||||
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
|
||||
{
|
||||
b_k_n_scaled(k + k_, n) =
|
||||
pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
|
||||
Reference in New Issue
Block a user