From 971ed7da51e171230460460642704c8a2ed7151e Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 5 Dec 2025 10:12:13 +0000 Subject: [PATCH] update --- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 88b1ab8c4f..62a69a1667 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -723,6 +723,7 @@ struct MoeFlatmmKernel constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline; + /* const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, make_tuple(number{}, @@ -730,6 +731,63 @@ struct MoeFlatmmKernel {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / (isNonInterleaveGateUp ? 1 : 2)), 0}); + */ + const auto& b_flat_block_window = [&]() { + // GateUp needs to shuffle weight + if constexpr(IsGateUp) + { + // 1. Get Dimensions + const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0); + const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1); + + // 2. View Linear N as (2, N/2) -> effectively separating Gate (0) and Up (1) blocks + // Layout becomes: (BlockIdx, RowInBlock, K) + auto v_split = transform_tensor_view( + b_flat_pad_view, + make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)), + make_pass_through_transform(K)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2>{})); + + // 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K) + // This puts Gate(i) and Up(i) adjacent in the view + auto v_permute = transform_tensor_view( + v_split, + make_tuple(make_pass_through_transform(N / 2), + make_pass_through_transform(number<2>{}), + make_pass_through_transform(K)), + make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + // 4. Merge back to (N, K) -> effectively Interleaved View + auto b_interleaved_view = transform_tensor_view( + v_permute, + make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})), + make_pass_through_transform(K)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // 5. Create Window on the transformed view + return make_tile_window( + b_interleaved_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + else + { + // Default behavior for Interleaved or non-GateUp + return make_tile_window( + b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(coord_n / BlockGemmShape::WarpTile::at(I1) / + (isNonInterleaveGateUp ? 1 : 2)), + 0}); + } + }(); const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;