mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-21 07:37:38 +00:00
fix flatmm with scaling when WarpTileM == 32
This commit is contained in:
@@ -463,6 +463,7 @@ struct CShuffleEpilogue
|
||||
constexpr int kM0 = MWave;
|
||||
constexpr int kM2 = 4;
|
||||
constexpr int kM1 = MPerXdl / kM2;
|
||||
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
|
||||
|
||||
constexpr int kN0 = NWave;
|
||||
constexpr int kN1 = NPerXdl;
|
||||
@@ -503,13 +504,11 @@ struct CShuffleEpilogue
|
||||
float vec_scale_A[kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < NRepeat; ++i)
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
}
|
||||
#pragma unroll
|
||||
for(int i = 0; i < MRepeat; ++i)
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
@@ -631,7 +630,6 @@ struct CShuffleEpilogue
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
static_assert(kM0 == 1, "only support when XDL_M == 16");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
@@ -641,22 +639,27 @@ struct CShuffleEpilogue
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < NRepeat; ++i)
|
||||
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
}
|
||||
#pragma unroll
|
||||
for(int i = 0; i < MRepeat; ++i)
|
||||
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave];
|
||||
}
|
||||
}
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
@@ -668,14 +671,17 @@ struct CShuffleEpilogue
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 0] *=
|
||||
vec_scale_A[m_xdl * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 1] *=
|
||||
vec_scale_A[m_xdl * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 2] *=
|
||||
vec_scale_A[m_xdl * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 3] *=
|
||||
vec_scale_A[m_xdl * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -705,18 +711,29 @@ struct CShuffleEpilogue
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||
{
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user