support NRepeat=1 in A16W4_MoE_gemm2 to improve performance in the small tokens case

This commit is contained in:
Feng Shijie
2025-12-01 12:24:18 +00:00
parent abd6a4b3fc
commit bf6447cd54
4 changed files with 110 additions and 35 deletions

View File

@@ -87,11 +87,11 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp / FlatmmConfig::N_Warp_Tile;
static_assert(NRepeat == 1 || NRepeat % 2 == 0);
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
static_assert(
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
static_assert(NRepeat % 2 == 0, "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
}
using ComputeDataType = ADataType;
@@ -139,8 +139,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
scheduler,
has_hot_loop_v,
tail_number_v>>;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
constexpr int BlockedXDLN_PerWarp =
NRepeat == 1 ? 1 : 2; // determined by scale shuffle pattern
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -370,7 +370,6 @@ auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)

View File

@@ -92,7 +92,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * topk;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;

View File

@@ -873,6 +873,17 @@ struct MoeFlatmmKernel
auto& c_block_window = gemm_tile_windows.at(number<2>{});
if constexpr(MXFP4_Pipeline && kNRepeat == 1)
{
// Continuous warps compute the values spaced N_Pack XDLs
auto c_window_coord = c_block_window.get_window_origin();
int group_128_idx = c_window_coord[I1] / (N_Pack * NWave * NPerXdl);
int raked_idx = (c_window_coord[I1] / (NWave * NPerXdl)) % N_Pack;
c_window_coord[I1] = group_128_idx * (N_Pack * NWave * NPerXdl) + raked_idx * NPerXdl;
c_block_window.set_window_origin(c_window_coord);
}
// Run EpiloguePipeline
{
using EpiProblem = typename EpiloguePipeline::Problem;
@@ -1002,12 +1013,12 @@ struct MoeFlatmmKernel
kargs.scale_n.ptr + expert_id * kargs.N,
make_tuple(1, kargs.N),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
number<1>{}), // MXF4_Pipeline does't use scale_n, so there is no need to
// permute as n_pack
make_tuple(number<TilePartitioner::MPerBlock>{},
number < IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock > {}),
number<IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock>{}),
{0, IsGateUp ? coord_n / 2 : coord_n},
output_acc_tile_distr);
@@ -1016,7 +1027,7 @@ struct MoeFlatmmKernel
kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
make_tuple(1, kargs.N),
make_tuple(0, scale_stride_n),
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
number<1>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock / 2>{}),
@@ -1033,8 +1044,8 @@ struct MoeFlatmmKernel
auto exp_bias_window = make_tile_window(
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock > {}),
number<IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock>{}),
{0, IsGateUp ? coord_n / 2 : coord_n},
output_acc_tile_distr);
@@ -1110,17 +1121,47 @@ struct MoeFlatmmKernel
"Currently, the CShuffle EpiloguePipeline only supports the Row Major "
"Output layout");
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
kBlockSize,
MPerIterationShuffle,
LDS_NPerIterationShuffle,
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
EpiProblem::kNumWaveGroups>;
constexpr int OutputVectorSize =
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC();
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
LDS_NPerIterationShuffle,
OutputVectorSize,
tile_distribution_pattern::thread_raked,
EpiProblem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
constexpr auto c_dram_distribution = [&] {
if constexpr(MXFP4_Pipeline && kNRepeat == 1)
{
// Continuous warps compute the values spaced N_Pack XDLs, the remaining part
// keeps the same as dram_tile_distribution.
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>,
tuple<sequence<NWave,
get_warp_size() / (kNPerBlock / OutputVectorSize),
// M2 = M / M0 / M1
kMPerBlock / (NWave * get_warp_size() /
(kNPerBlock / OutputVectorSize))>,
sequence<kNPerBlock / NPerXdl,
N_Pack,
NPerXdl / OutputVectorSize,
OutputVectorSize>>,
tuple<sequence<1>, sequence<1, 2, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<2, 3>>{});
}
else
{
return dram_tile_distribution;
}
}();
constexpr auto LdsTileDistr = [&] {
if constexpr(IsGateUp)
return make_static_tile_distribution(
@@ -1304,22 +1345,24 @@ struct MoeFlatmmKernel
make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
c_block_window.get_window_lengths(),
c_block_window.get_window_origin(),
dram_tile_distribution,
c_dram_distribution,
c_scatter_offsets[mIter],
c_scatter_valids[mIter]);
auto redistributed_c_out_tensor = make_static_distributed_tensor<EDataType>(
c_dram_distribution, c_out_tensor.get_thread_buffer());
if constexpr(!IsInputGemm ||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
c_scatter_tile_window.update(c_out_tensor);
c_scatter_tile_window.update(redistributed_c_out_tensor);
else
c_scatter_tile_window.store(c_out_tensor);
c_scatter_tile_window.store(redistributed_c_out_tensor);
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// row_offset of out windows has been included in scatter offset
move_tile_window(c_block_window,
{0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}});
{0, step.at(number<1>{}) / number<IsGateUp ? 2 : 1>{}});
}
});
}

View File

@@ -140,11 +140,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
static_assert(KIterPerWarp % XDL_PerScaleK == 0);
static_assert(NIterPerWarp % XDL_PerScaleN == 0);
static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
static constexpr int ScaleNPerWarp = max(1, NIterPerWarp / XDL_PerScaleN);
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
@@ -168,8 +167,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread;
static constexpr index_t ScaleBload_num =
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
max(1,
kNPerBlock* kKPerBlock / NWarp / 32 / ScaleBload_K1 /
WaveSize); // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
static constexpr index_t Bload_total_num =
Bload_num_perK * KIterPerWarp + ScaleBload_num + 0X3f0;
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
@@ -559,16 +559,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
auto scale_b_flat_distribution =
PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
auto b_flat_dram_coord = b_flat_dram_block_window_tmp.get_window_origin();
auto scale_b_dram_coord = scale_b_flat_window.get_window_origin();
int global_packed_n_idx = 0;
if constexpr(NIterPerWarp == 1)
{
// Continuous warps compute the values spaced ContinuousScaleNPerThread XDLs
int group_scale_idx = b_flat_dram_coord[I0] / (ContinuousScaleNPerThread * NWarp);
global_packed_n_idx = (b_flat_dram_coord[I0] / NWarp) % ContinuousScaleNPerThread;
b_flat_dram_coord[I0] =
group_scale_idx * (ContinuousScaleNPerThread * NWarp) + global_packed_n_idx;
scale_b_dram_coord[I0] = group_scale_idx * NWarp;
}
auto b_flat_dram_window = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_dram_coord,
b_flat_distribution);
auto scale_b_flat_dram_window = make_tile_window(
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
scale_b_flat_window.get_window_origin(),
scale_b_dram_coord,
scale_b_flat_distribution);
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
@@ -718,11 +732,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
auto xdl_kIter) {
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
auto scale = [&] {
if constexpr(NIterPerWarp == 1)
{
if(global_packed_n_idx == 0)
{
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
return scale_tensor.get_thread_buffer()[scale_offset];
}
else
{
auto scale_offset =
scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{} + I1;
return scale_tensor.get_thread_buffer()[scale_offset];
}
}
else
{
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
return scale_tensor.get_thread_buffer()[scale_offset];
}
}();
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;