diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 0678e87e47..b71415819a 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -87,11 +87,11 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config constexpr bool MXFP4_Pipeline = std::is_same_v; - if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) + constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp / FlatmmConfig::N_Warp_Tile; + static_assert(NRepeat == 1 || NRepeat % 2 == 0); + if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) { - static_assert( - FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0, - "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up"); + static_assert(NRepeat % 2 == 0, "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up"); } using ComputeDataType = ADataType; @@ -139,8 +139,8 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config scheduler, has_hot_loop_v, tail_number_v>>; - - constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern + constexpr int BlockedXDLN_PerWarp = + NRepeat == 1 ? 1 : 2; // determined by scale shuffle pattern using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& scale, int experts_cnt) constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4 static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16"); - static_assert(FlatmmConfig::N_Repeat % N_Pack == 0); static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0); if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index 45df126540..f00de5e29e 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -92,7 +92,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, // TODO: replace the magic declaration const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile; - ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk; + ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * topk; ck_tile::index_t valid_tile_num = sorted_tile_num; ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock; diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b3b34a6da0..f8ee8a65a5 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -873,6 +873,17 @@ struct MoeFlatmmKernel auto& c_block_window = gemm_tile_windows.at(number<2>{}); + if constexpr(MXFP4_Pipeline && kNRepeat == 1) + { + // Continuous warps compute the values spaced N_Pack XDLs + auto c_window_coord = c_block_window.get_window_origin(); + int group_128_idx = c_window_coord[I1] / (N_Pack * NWave * NPerXdl); + int raked_idx = (c_window_coord[I1] / (NWave * NPerXdl)) % N_Pack; + c_window_coord[I1] = group_128_idx * (N_Pack * NWave * NPerXdl) + raked_idx * NPerXdl; + + c_block_window.set_window_origin(c_window_coord); + } + // Run EpiloguePipeline { using EpiProblem = typename EpiloguePipeline::Problem; @@ -1002,12 +1013,12 @@ struct MoeFlatmmKernel kargs.scale_n.ptr + expert_id * kargs.N, make_tuple(1, kargs.N), make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + number{}, number<1>{}), // MXF4_Pipeline does't use scale_n, so there is no need to // permute as n_pack make_tuple(number{}, - number < IsGateUp ? TilePartitioner::NPerBlock / 2 - : TilePartitioner::NPerBlock > {}), + number{}), {0, IsGateUp ? coord_n / 2 : coord_n}, output_acc_tile_distr); @@ -1016,7 +1027,7 @@ struct MoeFlatmmKernel kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2, make_tuple(1, kargs.N), make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + number{}, number<1>{}), make_tuple(number{}, number{}), @@ -1033,8 +1044,8 @@ struct MoeFlatmmKernel auto exp_bias_window = make_tile_window( permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}), make_tuple(number{}, - number < IsGateUp ? TilePartitioner::NPerBlock / 2 - : TilePartitioner::NPerBlock > {}), + number{}), {0, IsGateUp ? coord_n / 2 : coord_n}, output_acc_tile_distr); @@ -1110,17 +1121,47 @@ struct MoeFlatmmKernel "Currently, the CShuffle EpiloguePipeline only supports the Row Major " "Output layout"); - using TileEncodingPattern = tile_distribution_encoding_pattern_2d< - kBlockSize, - MPerIterationShuffle, - LDS_NPerIterationShuffle, - kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(), - tile_distribution_pattern::thread_raked, - EpiProblem::kNumWaveGroups>; + constexpr int OutputVectorSize = + kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; constexpr auto dram_tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution(); + constexpr auto c_dram_distribution = [&] { + if constexpr(MXFP4_Pipeline && kNRepeat == 1) + { + // Continuous warps compute the values spaced N_Pack XDLs, the remaining part + // keeps the same as dram_tile_distribution. + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<2, 3>>{}); + } + else + { + return dram_tile_distribution; + } + }(); + constexpr auto LdsTileDistr = [&] { if constexpr(IsGateUp) return make_static_tile_distribution( @@ -1304,22 +1345,24 @@ struct MoeFlatmmKernel make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(), c_block_window.get_window_lengths(), c_block_window.get_window_origin(), - dram_tile_distribution, + c_dram_distribution, c_scatter_offsets[mIter], c_scatter_valids[mIter]); + auto redistributed_c_out_tensor = make_static_distributed_tensor( + c_dram_distribution, c_out_tensor.get_thread_buffer()); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) - c_scatter_tile_window.update(c_out_tensor); + c_scatter_tile_window.update(redistributed_c_out_tensor); else - c_scatter_tile_window.store(c_out_tensor); + c_scatter_tile_window.store(redistributed_c_out_tensor); if constexpr(iAccess != num_access - 1) { constexpr auto step = SFC::get_forward_step(iAccess); // row_offset of out windows has been included in scatter offset move_tile_window(c_block_window, - {0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}}); + {0, step.at(number<1>{}) / number{}}); } }); } diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 8ec23b7570..da671e22ca 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -140,11 +140,10 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2 static_assert(XDL_PerScaleK % XDL_PerWeightK == 0); static_assert(KIterPerWarp % XDL_PerScaleK == 0); - static_assert(NIterPerWarp % XDL_PerScaleN == 0); static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK; static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK; - static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN; + static constexpr int ScaleNPerWarp = max(1, NIterPerWarp / XDL_PerScaleN); static constexpr int MXFP4K_PerScaleK = MXFP4KPerWarp / ScaleKPerWarp; @@ -168,8 +167,9 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize; static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread * ContinuousScaleKPerThread; static constexpr index_t ScaleBload_num = - kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / - WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize + max(1, + kNPerBlock* kKPerBlock / NWarp / 32 / ScaleBload_K1 / + WaveSize); // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize static constexpr index_t Bload_total_num = Bload_num_perK * KIterPerWarp + ScaleBload_num + 0X3f0; static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num; @@ -559,16 +559,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 auto scale_b_flat_distribution = PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution(); + auto b_flat_dram_coord = b_flat_dram_block_window_tmp.get_window_origin(); + auto scale_b_dram_coord = scale_b_flat_window.get_window_origin(); + int global_packed_n_idx = 0; + + if constexpr(NIterPerWarp == 1) + { + // Continuous warps compute the values spaced ContinuousScaleNPerThread XDLs + int group_scale_idx = b_flat_dram_coord[I0] / (ContinuousScaleNPerThread * NWarp); + global_packed_n_idx = (b_flat_dram_coord[I0] / NWarp) % ContinuousScaleNPerThread; + b_flat_dram_coord[I0] = + group_scale_idx * (ContinuousScaleNPerThread * NWarp) + global_packed_n_idx; + scale_b_dram_coord[I0] = group_scale_idx * NWarp; + } + auto b_flat_dram_window = make_tile_window( b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_dram_coord, b_flat_distribution); auto scale_b_flat_dram_window = make_tile_window( scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views make_tuple(number{}, number{}), - scale_b_flat_window.get_window_origin(), + scale_b_dram_coord, scale_b_flat_distribution); using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window)); @@ -718,11 +732,30 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 auto xdl_kIter) { auto quant_idx_k = xdl_kIter % number{}; - auto scale_idx_n = xdl_nIter % number{}; - auto scale_idx_k = (xdl_kIter % number{}) / number{}; - auto scale_offset = scale_idx_n + scale_idx_k * number{}; + auto scale_idx_n = xdl_nIter % number{}; + auto scale_idx_k = (xdl_kIter % number{}) / number{}; - auto scale = scale_tensor.get_thread_buffer()[scale_offset]; + auto scale = [&] { + if constexpr(NIterPerWarp == 1) + { + if(global_packed_n_idx == 0) + { + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + return scale_tensor.get_thread_buffer()[scale_offset]; + } + else + { + auto scale_offset = + scale_idx_n + scale_idx_k * number{} + I1; + return scale_tensor.get_thread_buffer()[scale_offset]; + } + } + else + { + auto scale_offset = scale_idx_n + scale_idx_k * number{}; + return scale_tensor.get_thread_buffer()[scale_offset]; + } + }(); constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size(); constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;