Fix a bug: set the A-matrix DS_read preload size to 4 for MXFP4

This commit is contained in:
mtgu0705
2025-09-19 00:39:46 -05:00
parent 92ad6fcc0a
commit 62607de56c
2 changed files with 48 additions and 68 deletions

View File

@@ -37,8 +37,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
// TailNumber::Empty>{});
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
@@ -999,7 +998,7 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,

View File

@@ -72,7 +72,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
static constexpr index_t DsReadPreload = 4; // default 4 for MXFP4 (MXdlPack * KXdlPack)
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t WaveSize = get_warp_size();
@@ -698,7 +698,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
// prefetch Scale A and Scale B
// prefetch Scale A
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
@@ -712,6 +712,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// move Scale A window to next K
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// prefetch Scale B
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
@@ -763,8 +764,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u;
@@ -830,10 +831,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -872,14 +870,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
@@ -905,8 +902,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
});
@@ -969,10 +966,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -1012,14 +1006,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
@@ -1045,8 +1038,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
});
@@ -1109,10 +1102,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -1152,14 +1142,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
@@ -1180,8 +1169,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
});
@@ -1194,10 +1183,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -1236,14 +1222,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
@@ -1272,10 +1257,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter =
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
mIter_pack * MXdlPack + imxdl) %
m_preload;
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -1315,14 +1297,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(