mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
add fp16xf4 moe
This commit is contained in:
@@ -80,7 +80,8 @@ enum class MoeFlatmmKind
|
||||
template <typename TilePartitioner_,
|
||||
typename FlatmmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
MoeFlatmmKind kind>
|
||||
MoeFlatmmKind kind,
|
||||
typename FusedActivation = element_wise::Silu>
|
||||
struct MoeFlatmmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -101,7 +102,8 @@ struct MoeFlatmmKernel
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using AccDataType = float;
|
||||
using AccDataType = float;
|
||||
using ActivationOp = FusedActivation;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -114,6 +116,7 @@ struct MoeFlatmmKernel
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
|
||||
static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
|
||||
@@ -128,6 +131,17 @@ struct MoeFlatmmKernel
|
||||
static constexpr index_t kNPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t kNRepeat = kNPerBlock / kNPerIteration;
|
||||
|
||||
static constexpr int OutputNPerBlock =
|
||||
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
|
||||
|
||||
// MXFP4_Pipeline only has the scale of B, and its granularityK is 32
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
|
||||
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmKernelArgs
|
||||
{
|
||||
@@ -405,10 +419,10 @@ struct MoeFlatmmKernel
|
||||
const BDataType* b_flat_ptr,
|
||||
EDataType* e_ptr,
|
||||
const AccDataType* exp_weight_ptr,
|
||||
const int expert_id,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
// static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -432,9 +446,9 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
|
||||
BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
@@ -451,7 +465,7 @@ struct MoeFlatmmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
|
||||
IsGateUp ? kargs.N / 2 : kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
@@ -461,14 +475,30 @@ struct MoeFlatmmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
|
||||
IsGateUp ? kargs.N / 2 : kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
|
||||
auto scale_n = kargs.scale_n;
|
||||
constexpr int GranularityK = decltype(scale_n)::GranularityK;
|
||||
|
||||
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -492,14 +522,9 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
// TODO: vectorized write for C in ColMajor layout
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock;
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
@@ -516,12 +541,13 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, [[maybe_unused]] const index_t i_m, const index_t i_n)
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
[[maybe_unused]] const index_t coord_m,
|
||||
const index_t coord_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(number<0>{});
|
||||
const auto& b_flat_pad_view = views.at(number<1>{});
|
||||
@@ -533,7 +559,7 @@ struct MoeFlatmmKernel
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0}); // NOTE!
|
||||
{coord_m, 0}); // NOTE!
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -544,25 +570,33 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const int problem_N_offset = kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? i_n / 2 : i_n;
|
||||
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
|
||||
|
||||
const auto& b_flat_block_window = make_tile_window(
|
||||
b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(problem_N_offset / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
|
||||
(isNonInterleaveGateUp ? 1 : 2)),
|
||||
0});
|
||||
|
||||
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock;
|
||||
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<OutputNPerBlock>{}),
|
||||
{0, // offset_m is included when construct C-scatter-window offsets
|
||||
problem_N_offset});
|
||||
output_N_offset});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
|
||||
constexpr int GranularityK = 32;
|
||||
|
||||
auto scale_block_window = make_tile_window(
|
||||
views.at(I3),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
@@ -614,16 +648,16 @@ struct MoeFlatmmKernel
|
||||
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset +
|
||||
expert_stride * expert_id;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
const AccDataType* exp_weight_ptr =
|
||||
static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
|
||||
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
|
||||
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
|
||||
@@ -631,26 +665,43 @@ struct MoeFlatmmKernel
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(number<0>{});
|
||||
const auto& b_block_window = gemm_tile_windows.at(number<1>{});
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto a_gather_block_tile =
|
||||
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution(),
|
||||
a_dram_dist,
|
||||
a_offsets); // K DRAM tile window for
|
||||
auto c_block_tile = FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
number<kind == MoeFlatmmKind::kFFN_gemm1_gate_up>{},
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
using AccTile = decltype(c_block_tile);
|
||||
|
||||
auto c_block_tile = [&] {
|
||||
if constexpr(MXFP4_Pipeline)
|
||||
{
|
||||
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
|
||||
// so don't need extra processing
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
number<IsGateUp>{},
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
}();
|
||||
using AccTile = decltype(c_block_tile);
|
||||
|
||||
// Run EpiloguePipeline Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(number<2>{});
|
||||
using ActivationOp = element_wise::Silu;
|
||||
|
||||
{
|
||||
using EpiProblem = typename EpiloguePipeline::Problem;
|
||||
@@ -666,26 +717,34 @@ struct MoeFlatmmKernel
|
||||
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
|
||||
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
|
||||
|
||||
constexpr auto lds_block_desc =
|
||||
EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
|
||||
static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
|
||||
|
||||
constexpr index_t OutputNumNXdlPerWavePerShuffle =
|
||||
IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
|
||||
constexpr index_t LDS_NPerIterationShuffle =
|
||||
IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
|
||||
|
||||
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
make_tuple(number<LDS_NPerIterationShuffle>{}, number<1>{}));
|
||||
|
||||
// EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<
|
||||
sequence<kMPerBlock,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kNPerBlock / 2 : kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
@@ -696,7 +755,7 @@ struct MoeFlatmmKernel
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<
|
||||
kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
LDS_NPerIterationShuffle,
|
||||
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
EpiProblem::kNumWaveGroups>;
|
||||
@@ -704,8 +763,24 @@ struct MoeFlatmmKernel
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
constexpr auto LdsTileDistr =
|
||||
make_static_tile_distribution(EpiloguePipeline::MakeLdsDistributionEncode());
|
||||
constexpr auto LdsTileDistr = [&] {
|
||||
if constexpr(IsGateUp)
|
||||
return make_static_tile_distribution(
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
// merge two contiguous N
|
||||
sequence<OutputNumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{},
|
||||
typename CWarpDstr::DstrEncode{}));
|
||||
else
|
||||
return make_static_tile_distribution(
|
||||
EpiloguePipeline::MakeLdsDistributionEncode());
|
||||
}();
|
||||
|
||||
using LDSTileTensor =
|
||||
decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
|
||||
@@ -719,8 +794,8 @@ struct MoeFlatmmKernel
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
|
||||
constexpr int ActVectorSize =
|
||||
c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle;
|
||||
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
|
||||
OutputNumNXdlPerWavePerShuffle;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
@@ -737,32 +812,36 @@ struct MoeFlatmmKernel
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Load scales and expert weights
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
vec_scale_B[i + NRepeat / 2] =
|
||||
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_scale_B[i * 2] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_scale_B[i * 2 + 1] =
|
||||
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave + coord_m;
|
||||
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
kargs.scale_m[row_to_token_idx(M2_offset)];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
kargs.scale_m[row_to_token_idx(M2_offset)];
|
||||
if constexpr(!IsInputGemm)
|
||||
vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
expert_weights[M2_offset];
|
||||
@@ -770,46 +849,54 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
});
|
||||
|
||||
constexpr int UpAccStride = NRepeat / 2;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pingpong process start
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
static_assert((NRepeat / NumNXdlPerWavePerShuffle) % 2 == 0);
|
||||
// gate and up are interleaved along NRepeat dimension.
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle,
|
||||
0 * NumNXdlPerWavePerShuffle + UpAccStride>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl + UpAccStride];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl + 1];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -830,16 +917,18 @@ struct MoeFlatmmKernel
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -875,51 +964,62 @@ struct MoeFlatmmKernel
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle + UpAccStride>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
up_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl + UpAccStride];
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
gate_tensor.get_thread_buffer().at(idx));
|
||||
@@ -941,7 +1041,7 @@ struct MoeFlatmmKernel
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
@@ -951,13 +1051,15 @@ struct MoeFlatmmKernel
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user