mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix a_wrap preload issue for large MPerBlock.
This commit is contained in:
@@ -876,13 +876,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
@@ -1016,13 +1016,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
@@ -1156,13 +1156,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
@@ -1240,13 +1240,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
@@ -1319,13 +1319,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter =
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter =
|
||||
(kIter_pack * KXdlPack + ikxdl +
|
||||
(mIter_pack * MXdlPack + imxdl + m_preload) /
|
||||
MIterPerWarp);
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user