mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
add fp16xf4 moe
This commit is contained in:
@@ -582,19 +582,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
@@ -637,18 +638,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -723,18 +726,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -812,18 +817,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
|
||||
Reference in New Issue
Block a user