mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
fix a bug, set the A DS_read preload size to 4 for MXFP4
This commit is contained in:
@@ -37,8 +37,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
|
||||
// TailNumber::Empty>{});
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -999,7 +998,7 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
|
||||
@@ -72,7 +72,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
static constexpr index_t DsReadPreload = 4; // default 4 for MXFP4 (MXdlPack * KXdlPack)
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
@@ -698,7 +698,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
|
||||
// prefetch Scale A and Scale B
|
||||
// prefetch Scale A
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
@@ -712,6 +712,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// move Scale A window to next K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// prefetch Scale B
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
@@ -763,8 +764,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u;
|
||||
@@ -830,10 +831,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -872,14 +870,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
@@ -905,8 +902,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
});
|
||||
@@ -969,10 +966,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -1012,14 +1006,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
@@ -1045,8 +1038,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
|
||||
});
|
||||
@@ -1109,10 +1102,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -1152,14 +1142,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
@@ -1180,8 +1169,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
});
|
||||
@@ -1194,10 +1183,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -1236,14 +1222,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
@@ -1272,10 +1257,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter =
|
||||
((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
mIter_pack * MXdlPack + imxdl) %
|
||||
m_preload;
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -1315,14 +1297,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr(((kIter_pack * KXdlPack + ikxdl) * MIterPerWarp +
|
||||
(mIter_pack * MXdlPack + imxdl)) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
|
||||
Reference in New Issue
Block a user