This commit is contained in:
yadaish
2025-12-05 10:12:13 +00:00
parent 12d764e999
commit 971ed7da51

View File

@@ -723,6 +723,7 @@ struct MoeFlatmmKernel
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
/*
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
@@ -730,6 +731,63 @@ struct MoeFlatmmKernel
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
*/
const auto& b_flat_block_window = [&]() {
// GateUp needs to shuffle weight
if constexpr(IsGateUp)
{
// 1. Get Dimensions
const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0);
const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1);
// 2. View Linear N as (2, N/2) -> effectively separating Gate (0) and Up (1) blocks
// Layout becomes: (BlockIdx, RowInBlock, K)
auto v_split = transform_tensor_view(
b_flat_pad_view,
make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)),
make_pass_through_transform(K)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2>{}));
// 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K)
// This puts Gate(i) and Up(i) adjacent in the view
auto v_permute = transform_tensor_view(
v_split,
make_tuple(make_pass_through_transform(N / 2),
make_pass_through_transform(number<2>{}),
make_pass_through_transform(K)),
make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
// 4. Merge back to (N, K) -> effectively Interleaved View
auto b_interleaved_view = transform_tensor_view(
v_permute,
make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})),
make_pass_through_transform(K)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// 5. Create Window on the transformed view
return make_tile_window(
b_interleaved_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
}
else
{
// Default behavior for Interleaved or non-GateUp
return make_tile_window(
b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
}
}();
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;