Fix the VGPR repack overhead for MXFP4: carry B/A register tiles as v4i32 (via union type-punning) so the compiler does not unpack and repack MXFP4 data between basic blocks

This commit is contained in:
mtgu0705
2025-09-17 21:34:03 -05:00
parent 346a400027
commit f2db44710f

View File

@@ -612,27 +612,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
// // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
// using V4UInt_Buffer = thread_buffer<uint32_t, 4>;
// union UnionB
// {
// V4UInt_Buffer u = 0;
// MXFP4_Buffer mxfp4;
// } ub;
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf
{
V4UInt_B_Buffer u = 0;
MXFP4_B_Buffer mxfp4;
} ub;
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
// pingpong buffer for Scale A and Scale B
@@ -693,7 +691,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
});
});
// move B window to next flat K
@@ -742,16 +741,33 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
block_sync_lds();
using MXFP4_A_Buffer_ping =
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf_A_ping
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_ping mxfp4;
} ua_ping;
using MXFP4_A_Buffer_pong =
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
union UnionBuf_A_pong
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_pong mxfp4;
} ua_pong;
// preload A00,A10... from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor;
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u;
});
__builtin_amdgcn_sched_barrier(0);
@@ -771,7 +787,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
});
});
@@ -828,13 +845,19 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
ua_compute.mxfp4,
ub_compute.mxfp4,
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -849,7 +872,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
@@ -858,8 +882,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
@@ -882,8 +907,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
});
HotLoopScheduler();
@@ -901,7 +926,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
});
});
@@ -958,13 +984,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
ua_compute.mxfp4,
ub_compute.mxfp4,
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -979,7 +1012,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
@@ -988,8 +1022,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
@@ -1012,8 +1047,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
});
HotLoopScheduler();
@@ -1036,7 +1071,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
});
});
@@ -1088,13 +1124,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
ua_compute.mxfp4,
ub_compute.mxfp4,
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1109,7 +1152,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
@@ -1118,8 +1162,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
@@ -1137,8 +1182,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
});
Last2ndHotLoopScheduler();
@@ -1164,13 +1209,19 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
ua_compute.mxfp4,
ub_compute.mxfp4,
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1185,7 +1236,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
@@ -1194,8 +1246,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
@@ -1234,13 +1287,20 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
ua_compute.mxfp4,
ub_compute.mxfp4,
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1255,7 +1315,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp) <
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
@@ -1264,8 +1325,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
a_warp_tensor(number<AwarpIter>{}) = load_tile(
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier