From 299c63d198f05942620cfdcaf9b577d2b3fa9802 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Wed, 21 May 2025 15:38:34 +0800 Subject: [PATCH] fix moe gemm for not gate only --- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 62 ++++ .../ops/moe_gemm/kernel/moe_gemm_kernel.hpp | 89 +---- ..._gemm_pipeline_agmem_bgmem_creg_flatmm.hpp | 305 ++++++++---------- 3 files changed, 191 insertions(+), 265 deletions(-) diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 18b2fe6483..ff9a793076 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -117,6 +117,68 @@ struct BlockFlatmmASmemBSmemCRegV1 }); }); } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + ABlockTensor& a_block_tensor, + BFlatBlockTensor& b_warp_tensor) const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp b/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp index 9652b04348..ebf655ba5f 100644 --- a/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp +++ b/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp @@ -66,6 +66,7 @@ struct MoeGemmKernel remove_cvref_t; // TileFlatmmShape static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm; + static constexpr bool IsGateOnly = FlatmmPipeline::IsGateOnly; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -134,7 +135,6 @@ struct MoeGemmKernel CK_TILE_HOST static constexpr MoeGemmKernelArgs MakeKernelArgs(const MoeGemmHostArgs& hostArgs) { - printf("in moe gemm kernel args! \n"); return MoeGemmKernelArgs{hostArgs.p_sorted_token_ids, hostArgs.p_sorted_expert_ids, hostArgs.p_max_token_id, @@ -263,70 +263,6 @@ struct MoeGemmKernel number<1>{}); }(); - - // const auto& b_tensor_view = [&]() { - // if constexpr(std::is_same_v) - // { - // if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - // { - // constexpr index_t K1 = FlatmmPipeline::GetSmemPackB(); - // const index_t K0 = splitk_batch_offset.splitted_k / K1; - // constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB()); - // const auto b_k0_n_k1_desc = - // make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - // make_tuple(kargs.N * K1, K1, I1), - // number{}, - // number<1>{}); - // const auto b_n_k_desc = transform_tensor_descriptor( - // b_k0_n_k1_desc, - // make_tuple(make_merge_transform(make_tuple(K0, K1)), - // make_pass_through_transform(kargs.N)), - // make_tuple(sequence<0, 2>{}, sequence<1>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - // return make_tensor_view(b_ptr, b_n_k_desc); - // } - // else - // { - // return make_naive_tensor_view( - // b_ptr, - // make_tuple(splitk_batch_offset.splitted_k, kargs.N), - // make_tuple(kargs.stride_B, 1), - // number{}, - // number<1>{}); - // } - // } - // else - // { - // if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - // { - // constexpr index_t K1 = FlatmmPipeline::GetSmemPackB(); - // const index_t K0 = splitk_batch_offset.splitted_k / K1; - // constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB()); - // const auto b_k0_n_k1_desc = - // make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - // make_tuple(kargs.N * K1, K1, I1), - // number{}, - // number<1>{}); - // const auto b_n_k_desc = transform_tensor_descriptor( - // b_k0_n_k1_desc, - // make_tuple(make_merge_transform(make_tuple(K0, K1)), - // make_pass_through_transform(kargs.N)), - // make_tuple(sequence<0, 2>{}, sequence<1>{}), - // make_tuple(sequence<1>{}, sequence<0>{})); - // return make_tensor_view(b_ptr, b_n_k_desc); - // } - // else - // { - // return make_naive_tensor_view( - // b_ptr, - // make_tuple(kargs.N, splitk_batch_offset.splitted_k), - // make_tuple(kargs.stride_B, 1), - // number{}, - // number<1>{}); - // } - // } - // }(); - // TODO: enable vector write for C in ColMajor const auto& c_tensor_view = [&]() { if constexpr(std::is_same_v) @@ -422,29 +358,6 @@ struct MoeGemmKernel make_tuple(sequence<0>{}, sequence<1>{})); } - // template - // CK_TILE_DEVICE static auto GetCTransformGemmView(const CView& view, const index_t token_id) - // { - // if constexpr(std::is_same_v) - // return transform_tensor_view( - // view, - // make_tuple(make_indexing_transform( - // view.get_tensor_descriptor().get_length(number<0>()), token_id), - // make_pass_through_transform( - // view.get_tensor_descriptor().get_length(number<1>()))), - // make_tuple(sequence<0>{}, sequence<1>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - // else - // return transform_tensor_view( - // view, - // make_tuple(make_pass_through_transform( - // view.get_tensor_descriptor().get_length(number<0>())), - // make_indexing_transform( - // view.get_tensor_descriptor().get_length(number<1>()), token_id)), - // make_tuple(sequence<0>{}, sequence<1>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - // } - template CK_TILE_DEVICE static auto TransformGemmPadViews(const PadView& views, const index_t token_id) { diff --git a/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp b/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp index 80b8bd8074..a6abf6b953 100644 --- a/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp +++ b/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp @@ -11,7 +11,7 @@ namespace ck_tile { template -struct MoeGemmPipelineAgBgCrImpl +struct MoeGemmPipelineAgBgCrImpl : public FlatmmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -95,140 +95,6 @@ struct MoeGemmPipelineAgBgCrImpl return MRepeat; } - template - CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - static_assert( - std::is_same_v>, - "wrong!"); - - static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}], - "wrong!"); - static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}], - "wrong!"); - - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // Block GEMM - auto block_flatmm = BlockFlatmm(); - - // B flat DRAM window for load - auto b_flat_distribution = - PipelinePolicy::template MakeBFlatDramTileDistribution(); - auto b_flat_dram_window = // tile_window_with_static_distribution - make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); - - // Acc register tile - auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){}; - - // prefetch - // global read 0 - auto a_block_tile = a_dram_block_window.load(); - - { - // move to 1 - move_tile_window(a_dram_block_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - // LDS write 0 - if constexpr(std::is_same_v) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - PipelinePolicy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); - } - else - { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); - } - } - - index_t iCounter = num_loop - 1; - while(iCounter > 0) - { - // global read i + 1 - a_dram_block_window.load(a_block_tile); - - block_sync_lds(); - - // GEMM i - block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); - - block_sync_lds(); - - // move to i + 2 - move_tile_window(a_dram_block_window, {0, kKPerBlock}); - - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); - - // move to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - iCounter--; - } - - // tail - { - block_sync_lds(); - - // GEMM num_loop - 1 - block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); - } - - sweep_tile(c_block_tile, - [&](auto idx0, auto idx1) { - fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)}; - GateActivation{}(v_, v_); - c_block_tile(idx0) = v_.x; - c_block_tile(idx1) = v_.y; - }, - sequence<1, 2>{}); - - return c_block_tile; - } - - template - CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem) const - { - return operator()( - a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, - b_flat_dram_block_window_tmp, - num_loop, - p_smem); - } - template CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window, const AElementFunction& a_element_func, @@ -246,6 +112,26 @@ struct MoeGemmPipelineAgBgCrImpl static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); @@ -268,106 +154,171 @@ struct MoeGemmPipelineAgBgCrImpl // B flat DRAM window for load auto b_flat_distribution = PipelinePolicy::template MakeBFlatDramTileDistribution(); - - auto b_gate_flat_dram_window = + auto b_gate_flat_dram_window = // tile_window_with_static_distribution make_tile_window( b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views make_tuple(number{}, number{}), b_flat_dram_block_window_tmp.get_window_origin(), b_flat_distribution); - b_flat_dram_block_window_tmp.move({N, 0}) - auto b_up_flat_dram_window = + move_tile(b_flat_dram_block_window_tmp, {N, 0}); + auto b_up_flat_dram_window = // tile_window_with_static_distribution make_tile_window( b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views make_tuple(number{}, number{}), b_flat_dram_block_window_tmp.get_window_origin(), b_flat_distribution); + // Acc register tile using c_block_tile_type = decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window)); - auto c_block_tiles[2] = {c_block_tile_type{}, c_block_tile_type{}}; + auto c_gate_block_tile = c_block_tile_type{}; + auto c_up_block_tile = c_block_tile_type{} // prefetch // global read 0 - auto a_block_tile = a_dram_block_window.load(); + a_block_tile = load_tile(a_dram_block_window); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_2; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); { // move to 1 move_tile_window(a_dram_block_window, {0, kKPerBlock}); + // move to next flat K + move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[0]); - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[1]); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - if constexpr(std::is_same_v) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - PipelinePolicy::template MakeShuffledARegBlockDistribution()); - shuffle_tile(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); - } - else - { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); - } + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + block_sync_lds(); } index_t iCounter = num_loop - 1; while(iCounter > 0) { // global read i + 1 - a_dram_block_window.load(a_block_tile); - - block_sync_lds(); + a_block_tile = load_tile(a_dram_block_window); // GEMM i - block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window); - - //TODO: simply add b_gate flatmm - block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window); + block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor); block_sync_lds(); - // move to i + 2 - move_tile_window(a_dram_block_window, {0, kKPerBlock}); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window; - // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // move to next flat K + move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // GEMM i + block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // move to next flat K move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // LDS write i + 1 + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // HotLoopScheduler(); + block_sync_lds(); iCounter--; } // tail { + // GEMM i + block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // HotLoopScheduler(); block_sync_lds(); // GEMM num_loop - 1 - block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window); - block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window); + block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2); } - sweep_tile(c_block_tiles[0], + sweep_tile(c_gate_block_tile, [&](auto idx0, auto idx1) { - fp32x2_t v_{c_block_tiles[0].at(number<0>{})(idx0), c_block_tiles[0].at(number<0>{})(idx1)}; + fp32x2_t v_{c_gate_block_tile.at(number<0>{})(idx0), c_gate_block_tile.at(number<0>{})(idx1)}; typename Problem::GateActivation{}(v_, v_); - c_block_tiles[0].at(number<0>{})(idx0) = v_.x; - c_block_tiles[0].at(number<0>{})(idx1) = v_.y; + c_gate_block_tile.at(number<0>{})(idx0) = v_.x; + c_gate_block_tile.at(number<0>{})(idx1) = v_.y; }, sequence<1, 2>{}); auto c_block_tile = tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; }, - c_block_tiles[0], - c_block_tiles[1]); + c_gate_block_tile, + c_up_block_tile); - return c_block_tiles[0]; + return c_block_tile; } template