mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
update ck_tile moe
This commit is contained in:
@@ -672,31 +672,48 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_m_desc = kargs.scale_m;
|
||||
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK;
|
||||
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
|
||||
auto scale_m_desc = kargs.scale_m;
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0 ? 1 : decltype(scale_m_desc)::GranularityK;
|
||||
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int AGranularityK = 32;
|
||||
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
|
||||
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
|
||||
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
|
||||
make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
|
||||
make_tuple(scale_k_packs * KThreadPerXdl, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_n = kargs.scale_n;
|
||||
|
||||
constexpr int BGranularityK = decltype(scale_n)::GranularityK;
|
||||
const auto scale_b_flat_view = [&]() {
|
||||
auto scale_n = kargs.scale_n;
|
||||
constexpr int BGranularityK = decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
|
||||
if constexpr(AQUANT_Pipeline)
|
||||
{
|
||||
index_t scale_k =
|
||||
@@ -833,17 +850,12 @@ struct MoeFlatmmKernel
|
||||
|
||||
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
auto a_scale_block_window =
|
||||
// make_tile_window(views.at(I3),
|
||||
// make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
// number<TilePartitioner::KPerBlock / GranularityK>{}),
|
||||
// {coord_m, 0});
|
||||
make_tile_window(
|
||||
views.at(I3),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
|
||||
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
|
||||
{coord_m / M_Pack, 0});
|
||||
|
||||
// constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
|
||||
constexpr int XDLPerLoadScaleB =
|
||||
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
|
||||
|
||||
@@ -943,7 +955,7 @@ struct MoeFlatmmKernel
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / 2;
|
||||
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
const AccDataType* exp_weight_ptr =
|
||||
|
||||
Reference in New Issue
Block a user