Update MX flatmm tail pipeline

This commit is contained in:
mtgu0705
2025-09-08 21:42:47 -05:00
parent 0509597f55
commit c5030e602e

View File

@@ -1034,24 +1034,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
{
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
scale_b_flat_dram_window;
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
{scale_n_iter * NFlatPerBlockPerIter,
scale_k_iter * ScaleKFlatPerWarp});
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
}
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
@@ -1068,54 +1052,95 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter)(kIter),
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
scale_a_tile_tensor_pong(mIter)(kIter) =
load_tile(scale_a_dram_windows(mIter)(kIter));
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter)(kIter),
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
scale_b_tile_tensor_pong(nIter)(kIter) =
load_tile(scale_b_dram_windows(nIter)(kIter));
});
});
// Prefill A(loopK)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxd * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
}
});
});
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
@@ -1129,47 +1154,67 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale A
ikxd * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
}
});
});
});
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();
@@ -1177,48 +1222,67 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * NXdlPack +
inxdl)(kIter_pack * KXdlPack + ikxdl),
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale B
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale A
ikxd * MXdlPack + imxdl, // A opsel
ikxdl * NXdlPack + inxdl); // B opsel
if constexpr(mIter == 0)
dequant_mxfp4(
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
kIter / number<XDL_PerScaleK>{}),
nIter,
kIter);
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
}
});
});
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
});
});
LastHotLoopScheduler();