mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-20 15:17:41 +00:00
update mx flatmm tail pipeline
This commit is contained in:
@@ -1034,24 +1034,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
|
||||
{
|
||||
auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
|
||||
auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
|
||||
|
||||
scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
|
||||
scale_b_flat_dram_window;
|
||||
|
||||
move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
|
||||
{scale_n_iter * NFlatPerBlockPerIter,
|
||||
scale_k_iter * ScaleKFlatPerWarp});
|
||||
|
||||
scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
|
||||
load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
|
||||
}
|
||||
|
||||
auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
|
||||
auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
|
||||
|
||||
@@ -1068,54 +1052,95 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter)(kIter) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter)(kIter),
|
||||
{mIter * MWarp * WG::kM, kIter * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter)(kIter) =
|
||||
load_tile(scale_a_dram_windows(mIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter)(kIter) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter)(kIter),
|
||||
{nIter * NWarp * WG::kN, kIter * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter)(kIter) =
|
||||
load_tile(scale_b_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale A
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale B
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1129,47 +1154,67 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_pong(mIter_pack)(kIter_pack), // scale A
|
||||
scale_b_tensor_pong(nIter_pack)(kIter_pack), // scale B
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
@@ -1177,48 +1222,67 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPacke, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * NXdlPack +
|
||||
inxdl)(kIter_pack * KXdlPack + ikxdl),
|
||||
scale_a_tensor_ping(mIter_pack)(kIter_pack), // scale A
|
||||
scale_b_tensor_ping(nIter_pack)(kIter_pack), // scale B
|
||||
ikxd * MXdlPack + imxdl, // A opsel
|
||||
ikxdl * NXdlPack + inxdl); // B opsel
|
||||
|
||||
if constexpr(mIter == 0)
|
||||
dequant_mxfp4(
|
||||
b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
|
||||
scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
|
||||
kIter / number<XDL_PerScaleK>{}),
|
||||
nIter,
|
||||
kIter);
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
|
||||
Reference in New Issue
Block a user