add fp16xf4 moe

This commit is contained in:
Feng Shijie
2025-08-18 17:28:11 +00:00
parent 599e1f5b32
commit be55c0f9cb
10 changed files with 1345 additions and 214 deletions

View File

@@ -370,6 +370,14 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I4);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
"ScaleM and ScaleN should have the same GranularityK");
constexpr bool DoEpiScale =
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
@@ -383,26 +391,21 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}