Fix a_warp preload issue for large MPerBlock.

This commit is contained in:
mtgu0705
2025-09-18 01:19:03 -05:00
parent f2db44710f
commit 92ad6fcc0a

View File

@@ -876,13 +876,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
ua_ping.mxfp4 = load_tile(
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
@@ -1016,13 +1016,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
ua_pong.mxfp4 = load_tile(
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
@@ -1156,13 +1156,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
ua_ping.mxfp4 = load_tile(
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
@@ -1240,13 +1240,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
ua_pong.mxfp4 = load_tile(
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
@@ -1319,13 +1319,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
(mIter_pack * MXdlPack + imxdl)) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter =
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
constexpr auto AkIter =
(kIter_pack * KXdlPack + ikxdl +
(mIter_pack * MXdlPack + imxdl + m_preload) /
MIterPerWarp);
ua_ping.mxfp4 = load_tile(
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}