mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
update
This commit is contained in:
@@ -723,6 +723,7 @@ struct MoeFlatmmKernel
|
||||
|
||||
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
|
||||
|
||||
/*
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
@@ -730,6 +731,63 @@ struct MoeFlatmmKernel
|
||||
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
|
||||
(isNonInterleaveGateUp ? 1 : 2)),
|
||||
0});
|
||||
*/
|
||||
const auto& b_flat_block_window = [&]() {
|
||||
// GateUp: build a re-indexed view of the weight so that Gate(i) and Up(i) rows are adjacent
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
// 1. Get Dimensions
|
||||
const auto N = b_flat_pad_view.get_tensor_descriptor().get_length(I0);
|
||||
const auto K = b_flat_pad_view.get_tensor_descriptor().get_length(I1);
|
||||
|
||||
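// 2. View linear N as (2, N/2) -> effectively separating the Gate (0) and Up (1) blocks
// NOTE(review): N / 2 is used for both halves, so N is presumably assumed to be even - confirm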
|
||||
// Layout becomes: (BlockIdx, RowInBlock, K)
|
||||
auto v_split = transform_tensor_view(
|
||||
b_flat_pad_view,
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<2>{}, N / 2)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
// 3. Permute to (N/2, 2, K) -> (RowInBlock, BlockIdx, K)
|
||||
// This puts Gate(i) and Up(i) adjacent in the view
|
||||
auto v_permute = transform_tensor_view(
|
||||
v_split,
|
||||
make_tuple(make_pass_through_transform(N / 2),
|
||||
make_pass_through_transform(number<2>{}),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
// 4. Merge back to (N, K) -> effectively Interleaved View
|
||||
auto b_interleaved_view = transform_tensor_view(
|
||||
v_permute,
|
||||
make_tuple(make_merge_transform(make_tuple(N / 2, number<2>{})),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// 5. Create Window on the transformed view
|
||||
return make_tile_window(
|
||||
b_interleaved_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
|
||||
(isNonInterleaveGateUp ? 1 : 2)),
|
||||
0});
|
||||
}
|
||||
else
|
||||
{
|
||||
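// Non-GateUp path (the only case that reaches this branch): window directly on the padded flat view; isNonInterleaveGateUp is always true here, so the divisor below is 1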
|
||||
return make_tile_window(
|
||||
b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
|
||||
(isNonInterleaveGateUp ? 1 : 2)),
|
||||
0});
|
||||
}
|
||||
}();
|
||||
|
||||
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user