fix flatmm with scaling when WarpTileM == 32

This commit is contained in:
Feng Shijie
2025-08-04 07:16:36 +00:00
parent aa5e008fa5
commit 90e910f3a7

View File

@@ -463,6 +463,7 @@ struct CShuffleEpilogue
constexpr int kM0 = MWave; constexpr int kM0 = MWave;
constexpr int kM2 = 4; constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2; constexpr int kM1 = MPerXdl / kM2;
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
constexpr int kN0 = NWave; constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl; constexpr int kN1 = NPerXdl;
@@ -503,13 +504,11 @@ struct CShuffleEpilogue
float vec_scale_A[kM2 * MRepeat]; float vec_scale_A[kM2 * MRepeat];
float vec_scale_B[NRepeat]; float vec_scale_B[NRepeat];
#pragma unroll _Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
for(int i = 0; i < NRepeat; ++i)
{ {
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl]; vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
} }
#pragma unroll _Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
for(int i = 0; i < MRepeat; ++i)
{ {
vec_scale_A[i * kM2 + 0] = vec_scale_A[i * kM2 + 0] =
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
@@ -631,7 +630,6 @@ struct CShuffleEpilogue
constexpr int kM2 = 4; // Val constexpr int kM2 = 4; // Val
constexpr int kM1 = (64 / NPerXdl); // Thr constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
static_assert(kM0 == 1, "only support when XDL_M == 16");
const index_t iMWarp = get_warp_id() / NWave; const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave; const index_t iNWarp = get_warp_id() - iMWarp * NWave;
@@ -641,22 +639,27 @@ struct CShuffleEpilogue
float vec_scale_A[kM0 * kM2 * MRepeat]; float vec_scale_A[kM0 * kM2 * MRepeat];
float vec_scale_B[NRepeat]; float vec_scale_B[NRepeat];
#pragma unroll _Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
for(int i = 0; i < NRepeat; ++i)
{ {
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane]; vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
} }
#pragma unroll _Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
for(int i = 0; i < MRepeat; ++i)
{ {
vec_scale_A[i * kM2 + 0] = _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; {
vec_scale_A[i * kM2 + 1] = vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
vec_scale_A[i * kM2 + 2] = i * MPerXdl * MWave];
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
vec_scale_A[i * kM2 + 3] = scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave]; i * MPerXdl * MWave];
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave];
}
} }
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
@@ -668,14 +671,17 @@ struct CShuffleEpilogue
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset = constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 0] *= _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
vec_scale_A[m_xdl * kM2 + 0] * vec_scale_B[n_xdl]; {
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 1] *= lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
vec_scale_A[m_xdl * kM2 + 1] * vec_scale_B[n_xdl]; vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 2] *= lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
vec_scale_A[m_xdl * kM2 + 2] * vec_scale_B[n_xdl]; vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 3] *= lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
vec_scale_A[m_xdl * kM2 + 3] * vec_scale_B[n_xdl]; vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
}
}); });
}); });
@@ -705,18 +711,29 @@ struct CShuffleEpilogue
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) { static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset = constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product(); (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 0] *= _Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +0] * {
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; lds_tile[write_stage]
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 1] *= .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +1] * vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 2] *= vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +2] * lds_tile[write_stage]
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 3] *= vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +3] * m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl]; vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
}
}); });
}); });
} }