add fp16xf4 moe

This commit is contained in:
Feng Shijie
2025-08-18 17:28:11 +00:00
parent 599e1f5b32
commit be55c0f9cb
10 changed files with 1345 additions and 214 deletions

View File

@@ -582,19 +582,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -637,18 +638,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -723,18 +726,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -812,18 +817,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));