update ck_tile moe

This commit is contained in:
zanzhang
2025-12-18 09:39:49 +08:00
parent 9ae1e18628
commit 2f3874b515

View File

@@ -672,31 +672,48 @@ struct MoeFlatmmKernel
}
}();
auto scale_m_desc = kargs.scale_m;
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK;
const auto& scale_a_tensor_view = [&]() {
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
auto scale_m_desc = kargs.scale_m;
if constexpr(AQUANT_Pipeline)
{
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0 ? 1 : decltype(scale_m_desc)::GranularityK;
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
}
else
{
constexpr int AGranularityK = 32;
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
make_tuple(scale_k_packs * KThreadPerXdl, 1),
number<8>{},
number<1>{});
}
}();
auto scale_n = kargs.scale_n;
constexpr int BGranularityK = decltype(scale_n)::GranularityK;
const auto scale_b_flat_view = [&]() {
auto scale_n = kargs.scale_n;
constexpr int BGranularityK = decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
if constexpr(AQUANT_Pipeline)
{
index_t scale_k =
@@ -833,17 +850,12 @@ struct MoeFlatmmKernel
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
auto a_scale_block_window =
// make_tile_window(views.at(I3),
// make_tuple(number<TilePartitioner::MPerBlock>{},
// number<TilePartitioner::KPerBlock / GranularityK>{}),
// {coord_m, 0});
make_tile_window(
views.at(I3),
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{coord_m / M_Pack, 0});
// constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
constexpr int XDLPerLoadScaleB =
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
@@ -943,7 +955,7 @@ struct MoeFlatmmKernel
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) +
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / 2;
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
const AccDataType* exp_weight_ptr =