add fp16xf4 moe

Feng Shijie
2025-08-18 17:28:11 +00:00
parent 599e1f5b32
commit be55c0f9cb
10 changed files with 1345 additions and 214 deletions


@@ -80,7 +80,8 @@ enum class MoeFlatmmKind
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
MoeFlatmmKind kind>
MoeFlatmmKind kind,
typename FusedActivation = element_wise::Silu>
struct MoeFlatmmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -101,7 +102,8 @@ struct MoeFlatmmKernel
// The type below is actually the accumulation data type - the output of the block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = float;
using ActivationOp = FusedActivation;
static constexpr index_t NumDTensor = DsDataType::size();
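
The new FusedActivation parameter defaults to element_wise::Silu, so existing instantiations keep their behavior. A minimal sketch of the functor shape the epilogue expects - a two-argument operator() writing the activated value into the first argument, as in ActivationOp{}(y, x) further down; the actual ck_tile implementation may differ:

#include <cmath>

// Hedged sketch of a FusedActivation functor; silu(x) = x * sigmoid(x).
struct SiluSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        const float xf = static_cast<float>(x);
        y = static_cast<Y>(xf / (1.0f + std::exp(-xf)));
    }
};
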
@@ -114,6 +116,7 @@ struct MoeFlatmmKernel
"The size of DsLayout and DsDataType should be the same");
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
@@ -128,6 +131,17 @@ struct MoeFlatmmKernel
static constexpr index_t kNPerIteration = NPerXdl * NWave;
static constexpr index_t kNRepeat = kNPerBlock / kNPerIteration;
static constexpr int OutputNPerBlock =
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// The MXFP4 pipeline only has a scale for B, and its granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
struct MoeFlatmmKernelArgs
{
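
Why WeightPackedSize matters: with pk_fp4_t, two 4-bit weights share one byte, so numeric_traits<pk_fp4_t>::PackedSize is 2 and raw element offsets must be divided by the packed size before pointer arithmetic, as done for b_flat_ptr further down. A hedged sketch of that arithmetic with illustrative names:

#include <cstddef>
#include <cstdint>

// Assumption: two fp4 values per pk_fp4_t, mirroring PackedSize == 2.
constexpr int kWeightPackedSize = 2;

// Translate a logical fp4 element offset into a byte address; the
// low/high nibble then selects the value within the byte.
inline const std::uint8_t* fp4_elem_ptr(const std::uint8_t* base,
                                        std::size_t elem_offset)
{
    return base + elem_offset / kWeightPackedSize;
}
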
@@ -405,10 +419,10 @@ struct MoeFlatmmKernel
const BDataType* b_flat_ptr,
EDataType* e_ptr,
const AccDataType* exp_weight_ptr,
const int expert_id,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
// static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -432,9 +446,9 @@ struct MoeFlatmmKernel
}
}();
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
@@ -451,7 +465,7 @@ struct MoeFlatmmKernel
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
@@ -461,14 +475,30 @@ struct MoeFlatmmKernel
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
auto scale_n = kargs.scale_n;
constexpr int GranularityK = decltype(scale_n)::GranularityK;
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
}
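
For reference, scale_k is a ceiling division: with K = 4096 and GranularityK = 32 there are 128 scales per weight column. The e8m0_t scale follows the OCP MX convention of a power-of-two block scale; a hedged sketch of the decoding (ck_tile's own e8m0_t may handle the NaN encoding differently):

#include <cmath>
#include <cstdint>

// Hedged sketch of e8m0 decoding per the OCP MX spec: the byte is a
// pure biased exponent, value = 2^(e - 127), with 0xFF reserved as NaN.
inline float decode_e8m0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}
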
template <typename TensorView>
@@ -492,14 +522,9 @@ struct MoeFlatmmKernel
}
}();
const auto& b_flat_tensor_view = views.at(I1);
// TODO: vectorized write for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
@@ -516,12 +541,13 @@ struct MoeFlatmmKernel
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, [[maybe_unused]] const index_t i_m, const index_t i_n)
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
[[maybe_unused]] const index_t coord_m,
const index_t coord_n)
{
const auto& a_pad_view = views.at(number<0>{});
const auto& b_flat_pad_view = views.at(number<1>{});
@@ -533,7 +559,7 @@ struct MoeFlatmmKernel
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0}); // NOTE!
{coord_m, 0}); // NOTE!
}
else
{
@@ -544,25 +570,33 @@ struct MoeFlatmmKernel
}
}();
const int problem_N_offset = kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? i_n / 2 : i_n;
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
const auto& b_flat_block_window = make_tile_window(
b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(problem_N_offset / BlockGemmShape::WarpTile::at(I1)), 0});
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock;
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<OutputNPerBlock>{}),
{0, // offset_m is already included when constructing the C scatter-window offsets
problem_N_offset});
output_N_offset});
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
constexpr int GranularityK = 32;
auto scale_block_window = make_tile_window(
views.at(I3),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
}
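
A hedged worked example of the scale-window K extent above, with an illustrative flatKPerWarp (not taken from this kernel's actual config): each warp tile then covers flatKPerWarp * N_Pack * 4 / GranularityK scale entries along K, and the N origin is divided once by the warp-tile N and once more by N_Pack.

// Illustrative numbers only; flatKPerWarp = 256 is an assumption.
constexpr int flatKPerWarp = 256;
constexpr int N_Pack       = 2;
constexpr int GranularityK = 32;
static_assert(flatKPerWarp * N_Pack * 4 / GranularityK == 64,
              "one warp tile spans 64 scale entries along K in this example");
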
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
@@ -614,16 +648,16 @@ struct MoeFlatmmKernel
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset +
expert_stride * expert_id;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) +
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
const AccDataType* exp_weight_ptr =
static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, kargs, splitk_batch_offset);
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
@@ -631,26 +665,43 @@ struct MoeFlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run the GEMM cooperatively across the whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(number<0>{});
const auto& b_block_window = gemm_tile_windows.at(number<1>{});
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& scale_block_window = gemm_tile_windows.at(I3);
auto a_gather_block_tile =
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution(),
a_dram_dist,
a_offsets); // gather offsets for the A DRAM tile window
auto c_block_tile = FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
number<kind == MoeFlatmmKind::kFFN_gemm1_gate_up>{},
num_loop,
smem_ptr_ping,
smem_ptr_pong);
auto c_block_tile = [&] {
if constexpr(MXFP4_Pipeline)
{
// The MXFP4 pipeline uses a gate-up interleave-16 weight layout,
// so no extra gate/up processing is needed here
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
scale_block_window, // weight scale with granularityK = 32
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
else
{
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
number<IsGateUp>{},
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
}();
using AccTile = decltype(c_block_tile);
// Run the epilogue pipeline
auto& c_block_window = gemm_tile_windows.at(number<2>{});
using ActivationOp = element_wise::Silu;
{
using EpiProblem = typename EpiloguePipeline::Problem;
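
The MXFP4 branch's comment refers to the interleave-16 gate-up weight layout: the pipeline's accumulator then alternates gate and up partials along N, which is why the epilogue below reads even XDL tiles as gate and odd tiles as up. A hedged scalar sketch of that de-interleave:

#include <cstddef>
#include <vector>

// Hedged sketch: accumulator tile 2*n holds gate partials and tile
// 2*n + 1 holds up partials; names are illustrative.
void deinterleave_gate_up(const std::vector<float>& acc_tiles,
                          std::vector<float>& gate,
                          std::vector<float>& up)
{
    const std::size_t n_out = acc_tiles.size() / 2;
    gate.resize(n_out);
    up.resize(n_out);
    for(std::size_t n = 0; n < n_out; ++n)
    {
        gate[n] = acc_tiles[2 * n];     // even tile -> gate
        up[n]   = acc_tiles[2 * n + 1]; // odd tile  -> up
    }
}
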
@@ -666,26 +717,34 @@ struct MoeFlatmmKernel
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
constexpr auto lds_block_desc =
EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
constexpr index_t OutputNumNXdlPerWavePerShuffle =
IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
constexpr index_t LDS_NPerIterationShuffle =
IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
make_tuple(number<LDS_NPerIterationShuffle>{}, number<1>{}));
// EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
{0, 0});
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<
sequence<kMPerBlock,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kNPerBlock / 2 : kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
@@ -696,7 +755,7 @@ struct MoeFlatmmKernel
using TileEncodingPattern = TileDistributionEncodingPattern2D<
kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
LDS_NPerIterationShuffle,
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
EpiProblem::kNumWaveGroups>;
@@ -704,8 +763,24 @@ struct MoeFlatmmKernel
constexpr auto dram_tile_distribution =
TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto LdsTileDistr =
make_static_tile_distribution(EpiloguePipeline::MakeLdsDistributionEncode());
constexpr auto LdsTileDistr = [&] {
if constexpr(IsGateUp)
return make_static_tile_distribution(
detail::make_embed_tile_distribution_encoding(
tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
// merge two contiguous N
sequence<OutputNumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{},
typename CWarpDstr::DstrEncode{}));
else
return make_static_tile_distribution(
EpiloguePipeline::MakeLdsDistributionEncode());
}();
using LDSTileTensor =
decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
@@ -719,8 +794,8 @@ struct MoeFlatmmKernel
constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
constexpr int ActVectorSize =
c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle;
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
OutputNumNXdlPerWavePerShuffle;
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
@@ -737,32 +812,36 @@ struct MoeFlatmmKernel
//===----------------------------------------------------------------------===//
// Load scales and expert weights
//===----------------------------------------------------------------------===//
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(!MXFP4_Pipeline)
{
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
vec_scale_B[i + NRepeat / 2] =
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
if constexpr(IsGateUp)
{
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
vec_scale_B[i * 2] =
kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
iNWarp * NPerXdl + iNLane];
vec_scale_B[i * 2 + 1] =
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto i) {
vec_scale_B[i] =
kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
iNWarp * NPerXdl + iNLane];
});
}
}
else
{
static_for<0, NRepeat, 1>{}([&](auto i) {
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
}
static_for<0, MRepeat, 1>{}([&](auto i) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave + coord_m;
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
kargs.scale_m[row_to_token_idx(M2_offset)];
if constexpr(!MXFP4_Pipeline)
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
kargs.scale_m[row_to_token_idx(M2_offset)];
if constexpr(!IsInputGemm)
vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] =
expert_weights[M2_offset];
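
In the non-MXFP4 gate-up case above, gate scales live in the first half of the expert's N range and up scales are offset by N / 2; the interleaved vec_scale_B layout (2*i, 2*i+1) matches the de-interleaved accumulator order used later. A hedged index sketch with names mirroring the loop above:

// Hedged sketch of the B-scale gather above; parameters are
// illustrative stand-ins for the kernel's compile-time constants.
inline void gather_gate_up_scales(const float* scale_n, float* vec_scale_B,
                                  int expert_id, int N, int coord_n,
                                  int NRepeat, int NWave, int NPerXdl,
                                  int iNWarp, int iNLane)
{
    for(int i = 0; i < NRepeat / 2; ++i)
    {
        const int col = coord_n / 2 + i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane;
        vec_scale_B[i * 2]     = scale_n[expert_id * N + col];         // gate half
        vec_scale_B[i * 2 + 1] = scale_n[expert_id * N + N / 2 + col]; // up half
    }
}
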
@@ -770,46 +849,54 @@ struct MoeFlatmmKernel
});
});
constexpr int UpAccStride = NRepeat / 2;
//===----------------------------------------------------------------------===//
// Pingpong process start
//===----------------------------------------------------------------------===//
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(IsGateUp)
{
LDSTileTensor gate_tensor, up_tensor;
static_assert((NRepeat / NumNXdlPerWavePerShuffle) % 2 == 0);
// gate and up are interleaved along the NRepeat dimension.
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
gate_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle,
0 * NumNXdlPerWavePerShuffle + UpAccStride>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
});
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl + UpAccStride];
if constexpr(!MXFP4_Pipeline)
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[2 * n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[2 * n_xdl + 1];
});
});
});
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
@@ -830,16 +917,18 @@ struct MoeFlatmmKernel
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
if constexpr(!IsInputGemm)
lds_tile[0]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
if constexpr(!MXFP4_Pipeline)
lds_tile[0]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
});
});
});
@@ -875,51 +964,62 @@ struct MoeFlatmmKernel
if constexpr(iAccess < num_access - 1)
{
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(IsGateUp)
{
LDSTileTensor gate_tensor, up_tensor;
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle + UpAccStride>{},
c_warp_y_index_zeros),
merge_sequences(
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
gate_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
up_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl + UpAccStride];
up_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl + 1>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
});
if constexpr(!MXFP4_Pipeline)
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor.get_thread_buffer()[acc_xdl_offset +
m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset +
m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl + 1];
});
});
});
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
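
Elementwise, the activation on the gate tensor is presumably followed (outside this hunk) by a product with the up tensor - the standard SwiGLU combination. A hedged scalar sketch of that math:

#include <cmath>

// Hedged sketch: out = SiLU(gate) * up (SwiGLU); illustrative only.
inline float swiglu(float gate, float up)
{
    const float act = gate / (1.0f + std::exp(-gate)); // SiLU(gate)
    return act * up;
}
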
@@ -941,7 +1041,7 @@ struct MoeFlatmmKernel
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
@@ -951,13 +1051,15 @@ struct MoeFlatmmKernel
m2] *= vec_expert_weights
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
if constexpr(!MXFP4_Pipeline)
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
});
});
});