mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
fix after merge ginolu/add_wgmfma_dispatcher
This commit is contained in:
@@ -282,7 +282,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1);
|
||||
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
|
||||
|
||||
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
|
||||
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
|
||||
@@ -291,19 +291,19 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, f4x2_pk_t>)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = a_m_k_scale(m, k / ScaleBlockSize);
|
||||
auto a_scale = scale_a(m, k / ScaleBlockSize);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
aut a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<0>{}));
|
||||
auto a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
|
||||
auto a_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<1>{}));
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
@@ -315,19 +315,19 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDatatype, f4x2_pk_t>)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = b_k_n_scale(k / ScaleBlockSize, n);
|
||||
auto b_scale = scale_b(k / ScaleBlockSize, n);
|
||||
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
|
||||
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
|
||||
auto b_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<0>{}));
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
|
||||
auto b_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<1>{}));
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
@@ -336,7 +336,7 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
|
||||
ck_tile::type_convert<AccDataType>(b_k_n_scale(k / ScaleBlockSize, n));
|
||||
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user