mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Merge commit 'e135dd518d19a36466ce7c61bb9d3203ec18c8af' into develop
This commit is contained in:
@@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const HostTensor<ScaleDataType>& scale_a,
|
||||
const HostTensor<ScaleDataType>& scale_b,
|
||||
const AElementOp& = {},
|
||||
const BElementOp& = {},
|
||||
const ACCElementOp& = {})
|
||||
{
|
||||
static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
|
||||
static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
|
||||
static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
|
||||
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
|
||||
|
||||
HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
|
||||
{std::size_t(K), std::size_t(1)});
|
||||
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
|
||||
{std::size_t(1), std::size_t(K)});
|
||||
|
||||
for(std::size_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
auto a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
|
||||
auto a_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(std::size_t n = 0; n < N; n++)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
auto b_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
|
||||
auto b_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
|
||||
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// call reference gemm
|
||||
reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
|
||||
a_m_k_scaled, b_k_n_scaled, c_m_n);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
|
||||
Reference in New Issue
Block a user