update codes

This commit is contained in:
mtgu0705
2025-08-30 03:19:07 -05:00
parent 9c37e55d13
commit 16993acd1d
9 changed files with 2095 additions and 88 deletions

View File

@@ -71,6 +71,91 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const HostTensor<ScaleDataType>& scale_a,
const HostTensor<ScaleDataType>& scale_b,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
const std::size_t ScaleBlockSize = K / a_m_k_scale.get_length(1);
HostTensor<AccDataType> a_m_k_scaled({M, K}, {K, 1});
HostTensor<AccDataType> b_k_n_scaled({K, N}, {1, N});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
if constexpr(std::is_same_v<ADataType, f4x2_pk_t>)
{
if(k % 2 == 1)
continue; // skip odd k
auto a_f4x2 = a_m_k(m, k);
auto a_scale = a_m_k_scale(m, k / ScaleBlockSize);
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
aut a_f4_lo =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<0>{}));
auto a_f4_hi =
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(Number<1>{}));
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
}
}
}
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
if constexpr(std::is_same_v<BDatatype, f4x2_pk_t>)
{
if(k % 2 == 1)
continue; // skip odd k
auto b_f4x2 = b_k_n(k, n);
auto b_scale = b_k_n_scale(k / ScaleBlockSize, n);
// auto f4_lo = ck_tile::type_convert<AccDataType>(f4x2)[0];
// auto f4_hi = ck_tile::type_convert<AccDataType>(f4x2)[1];
auto b_f4_lo =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<0>{}));
auto b_f4_hi =
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(Number<1>{}));
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
}
else
{
b_k_n_scaled(k, n) =
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
ck_tile::type_convert<AccDataType>(b_k_n_scale(k / ScaleBlockSize, n));
}
}
}
// call reference gemm
reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
a_m_k_scaled, b_k_n_scaled, c_m_n);
}
template <typename ADataType,
typename BDataType,
typename DsDataType,