Mx fp6 flatmm (#3601)

* add fp6 data-type and support sync/async dwordx3 load/store

* clang-format

* pre-commit

* 1st commit

* default mnk pass ut

* fix a distribution

* fix

* fix bdram distr

* update

* pass ut

* improve perf

* update

* clean code

* resolve copilot comment

* resolve comment

* clang-format

---------

Co-authored-by: ZheWang <zhewan@amd.com>
This commit is contained in:
ZheWang
2026-02-02 16:04:40 +08:00
committed by GitHub
parent 1ae83137eb
commit e6bcd192d4
21 changed files with 761 additions and 136 deletions

View File

@@ -625,6 +625,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
}
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
{
if(k % pk_fp6x16_t::packed_size != 0)
continue;
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
{
a_m_k_scaled(m, k + k_) =
pk_fp6x16_t::fp6_e2m3_to_float(a_m_k(m, k).unpack(k_)) * a_scale;
}
}
else
{
a_m_k_scaled(m, k) =
@@ -653,6 +664,17 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
}
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
{
if(k % pk_fp6x16_t::packed_size != 0)
continue;
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
for(std::size_t k_ = 0; k_ < pk_fp6x16_t::packed_size; k_++)
{
b_k_n_scaled(k + k_, n) =
pk_fp6x16_t::fp6_e2m3_to_float(b_k_n(k, n).unpack(k_)) * b_scale;
}
}
else
{
b_k_n_scaled(k, n) =