mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
support NRepeat=1 in A16W4_MoE_gemm2 to improve performance in the small tokens case
This commit is contained in:
@@ -87,11 +87,11 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp / FlatmmConfig::N_Warp_Tile;
|
||||
static_assert(NRepeat == 1 || NRepeat % 2 == 0);
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
static_assert(NRepeat % 2 == 0, "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
@@ -139,8 +139,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
constexpr int BlockedXDLN_PerWarp =
|
||||
NRepeat == 1 ? 1 : 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
@@ -370,7 +370,6 @@ auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
|
||||
@@ -92,7 +92,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * topk;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
|
||||
@@ -873,6 +873,17 @@ struct MoeFlatmmKernel
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(number<2>{});
|
||||
|
||||
if constexpr(MXFP4_Pipeline && kNRepeat == 1)
|
||||
{
|
||||
// Continuous warps compute the values spaced N_Pack XDLs
|
||||
auto c_window_coord = c_block_window.get_window_origin();
|
||||
int group_128_idx = c_window_coord[I1] / (N_Pack * NWave * NPerXdl);
|
||||
int raked_idx = (c_window_coord[I1] / (NWave * NPerXdl)) % N_Pack;
|
||||
c_window_coord[I1] = group_128_idx * (N_Pack * NWave * NPerXdl) + raked_idx * NPerXdl;
|
||||
|
||||
c_block_window.set_window_origin(c_window_coord);
|
||||
}
|
||||
|
||||
// Run EpiloguePipeline
|
||||
{
|
||||
using EpiProblem = typename EpiloguePipeline::Problem;
|
||||
@@ -1002,12 +1013,12 @@ struct MoeFlatmmKernel
|
||||
kargs.scale_n.ptr + expert_id * kargs.N,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{}), // MXF4_Pipeline does't use scale_n, so there is no need to
|
||||
// permute as n_pack
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock > {}),
|
||||
number<IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock>{}),
|
||||
{0, IsGateUp ? coord_n / 2 : coord_n},
|
||||
output_acc_tile_distr);
|
||||
|
||||
@@ -1016,7 +1027,7 @@ struct MoeFlatmmKernel
|
||||
kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
|
||||
make_tuple(1, kargs.N),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
|
||||
number<1>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock / 2>{}),
|
||||
@@ -1033,8 +1044,8 @@ struct MoeFlatmmKernel
|
||||
auto exp_bias_window = make_tile_window(
|
||||
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number < IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock > {}),
|
||||
number<IsGateUp ? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock>{}),
|
||||
{0, IsGateUp ? coord_n / 2 : coord_n},
|
||||
output_acc_tile_distr);
|
||||
|
||||
@@ -1110,17 +1121,47 @@ struct MoeFlatmmKernel
|
||||
"Currently, the CShuffle EpiloguePipeline only supports the Row Major "
|
||||
"Output layout");
|
||||
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
|
||||
kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
LDS_NPerIterationShuffle,
|
||||
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
EpiProblem::kNumWaveGroups>;
|
||||
constexpr int OutputVectorSize =
|
||||
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
LDS_NPerIterationShuffle,
|
||||
OutputVectorSize,
|
||||
tile_distribution_pattern::thread_raked,
|
||||
EpiProblem::kNumWaveGroups>;
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
constexpr auto c_dram_distribution = [&] {
|
||||
if constexpr(MXFP4_Pipeline && kNRepeat == 1)
|
||||
{
|
||||
// Continuous warps compute the values spaced N_Pack XDLs, the remaining part
|
||||
// keeps the same as dram_tile_distribution.
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NWave,
|
||||
get_warp_size() / (kNPerBlock / OutputVectorSize),
|
||||
// M2 = M / M0 / M1
|
||||
kMPerBlock / (NWave * get_warp_size() /
|
||||
(kNPerBlock / OutputVectorSize))>,
|
||||
sequence<kNPerBlock / NPerXdl,
|
||||
N_Pack,
|
||||
NPerXdl / OutputVectorSize,
|
||||
OutputVectorSize>>,
|
||||
tuple<sequence<1>, sequence<1, 2, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return dram_tile_distribution;
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr auto LdsTileDistr = [&] {
|
||||
if constexpr(IsGateUp)
|
||||
return make_static_tile_distribution(
|
||||
@@ -1304,22 +1345,24 @@ struct MoeFlatmmKernel
|
||||
make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
|
||||
c_block_window.get_window_lengths(),
|
||||
c_block_window.get_window_origin(),
|
||||
dram_tile_distribution,
|
||||
c_dram_distribution,
|
||||
c_scatter_offsets[mIter],
|
||||
c_scatter_valids[mIter]);
|
||||
|
||||
auto redistributed_c_out_tensor = make_static_distributed_tensor<EDataType>(
|
||||
c_dram_distribution, c_out_tensor.get_thread_buffer());
|
||||
if constexpr(!IsInputGemm ||
|
||||
EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
|
||||
c_scatter_tile_window.update(c_out_tensor);
|
||||
c_scatter_tile_window.update(redistributed_c_out_tensor);
|
||||
else
|
||||
c_scatter_tile_window.store(c_out_tensor);
|
||||
c_scatter_tile_window.store(redistributed_c_out_tensor);
|
||||
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
// row_offset of out windows has been included in scatter offset
|
||||
move_tile_window(c_block_window,
|
||||
{0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}});
|
||||
{0, step.at(number<1>{}) / number<IsGateUp ? 2 : 1>{}});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -140,11 +140,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
|
||||
static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
|
||||
static_assert(KIterPerWarp % XDL_PerScaleK == 0);
|
||||
static_assert(NIterPerWarp % XDL_PerScaleN == 0);
|
||||
|
||||
static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
|
||||
static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
|
||||
static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
|
||||
static constexpr int ScaleNPerWarp = max(1, NIterPerWarp / XDL_PerScaleN);
|
||||
|
||||
static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp;
|
||||
|
||||
@@ -168,8 +167,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread;
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
max(1,
|
||||
kNPerBlock* kKPerBlock / NWarp / 32 / ScaleBload_K1 /
|
||||
WaveSize); // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
|
||||
static constexpr index_t Bload_total_num =
|
||||
Bload_num_perK * KIterPerWarp + ScaleBload_num + 0X3f0;
|
||||
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
@@ -559,16 +559,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
auto scale_b_flat_distribution =
|
||||
PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_coord = b_flat_dram_block_window_tmp.get_window_origin();
|
||||
auto scale_b_dram_coord = scale_b_flat_window.get_window_origin();
|
||||
int global_packed_n_idx = 0;
|
||||
|
||||
if constexpr(NIterPerWarp == 1)
|
||||
{
|
||||
// Continuous warps compute the values spaced ContinuousScaleNPerThread XDLs
|
||||
int group_scale_idx = b_flat_dram_coord[I0] / (ContinuousScaleNPerThread * NWarp);
|
||||
global_packed_n_idx = (b_flat_dram_coord[I0] / NWarp) % ContinuousScaleNPerThread;
|
||||
b_flat_dram_coord[I0] =
|
||||
group_scale_idx * (ContinuousScaleNPerThread * NWarp) + global_packed_n_idx;
|
||||
scale_b_dram_coord[I0] = group_scale_idx * NWarp;
|
||||
}
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_dram_coord,
|
||||
b_flat_distribution);
|
||||
|
||||
auto scale_b_flat_dram_window = make_tile_window(
|
||||
scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<ScaleKFlatPerWarp>{}),
|
||||
scale_b_flat_window.get_window_origin(),
|
||||
scale_b_dram_coord,
|
||||
scale_b_flat_distribution);
|
||||
|
||||
using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
@@ -718,11 +732,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1
|
||||
auto xdl_kIter) {
|
||||
auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
|
||||
auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
|
||||
|
||||
auto scale = scale_tensor.get_thread_buffer()[scale_offset];
|
||||
auto scale = [&] {
|
||||
if constexpr(NIterPerWarp == 1)
|
||||
{
|
||||
if(global_packed_n_idx == 0)
|
||||
{
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
return scale_tensor.get_thread_buffer()[scale_offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
auto scale_offset =
|
||||
scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{} + I1;
|
||||
return scale_tensor.get_thread_buffer()[scale_offset];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
|
||||
return scale_tensor.get_thread_buffer()[scale_offset];
|
||||
}
|
||||
}();
|
||||
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
|
||||
|
||||
Reference in New Issue
Block a user