mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
fix flatmm with scaling when WarpTileM == 32
This commit is contained in:
@@ -463,6 +463,7 @@ struct CShuffleEpilogue
|
|||||||
constexpr int kM0 = MWave;
|
constexpr int kM0 = MWave;
|
||||||
constexpr int kM2 = 4;
|
constexpr int kM2 = 4;
|
||||||
constexpr int kM1 = MPerXdl / kM2;
|
constexpr int kM1 = MPerXdl / kM2;
|
||||||
|
static_assert(MPerXdl == 16, "TiledMMAPermuteN only supports MPerXdl = 16 now");
|
||||||
|
|
||||||
constexpr int kN0 = NWave;
|
constexpr int kN0 = NWave;
|
||||||
constexpr int kN1 = NPerXdl;
|
constexpr int kN1 = NPerXdl;
|
||||||
@@ -503,13 +504,11 @@ struct CShuffleEpilogue
|
|||||||
float vec_scale_A[kM2 * MRepeat];
|
float vec_scale_A[kM2 * MRepeat];
|
||||||
float vec_scale_B[NRepeat];
|
float vec_scale_B[NRepeat];
|
||||||
|
|
||||||
#pragma unroll
|
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||||
for(int i = 0; i < NRepeat; ++i)
|
|
||||||
{
|
{
|
||||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||||
}
|
}
|
||||||
#pragma unroll
|
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||||
for(int i = 0; i < MRepeat; ++i)
|
|
||||||
{
|
{
|
||||||
vec_scale_A[i * kM2 + 0] =
|
vec_scale_A[i * kM2 + 0] =
|
||||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||||
@@ -631,7 +630,6 @@ struct CShuffleEpilogue
|
|||||||
constexpr int kM2 = 4; // Val
|
constexpr int kM2 = 4; // Val
|
||||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||||
static_assert(kM0 == 1, "only support when XDL_M == 16");
|
|
||||||
|
|
||||||
const index_t iMWarp = get_warp_id() / NWave;
|
const index_t iMWarp = get_warp_id() / NWave;
|
||||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||||
@@ -641,22 +639,27 @@ struct CShuffleEpilogue
|
|||||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||||
float vec_scale_B[NRepeat];
|
float vec_scale_B[NRepeat];
|
||||||
|
|
||||||
#pragma unroll
|
_Pragma("unroll") for(int i = 0; i < NRepeat; ++i)
|
||||||
for(int i = 0; i < NRepeat; ++i)
|
|
||||||
{
|
{
|
||||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||||
}
|
}
|
||||||
#pragma unroll
|
_Pragma("unroll") for(int i = 0; i < MRepeat; ++i)
|
||||||
for(int i = 0; i < MRepeat; ++i)
|
|
||||||
{
|
{
|
||||||
vec_scale_A[i * kM2 + 0] =
|
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
{
|
||||||
vec_scale_A[i * kM2 + 1] =
|
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 0] =
|
||||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
scale_m[0 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||||
vec_scale_A[i * kM2 + 2] =
|
i * MPerXdl * MWave];
|
||||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 1] =
|
||||||
vec_scale_A[i * kM2 + 3] =
|
scale_m[1 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
i * MPerXdl * MWave];
|
||||||
|
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 2] =
|
||||||
|
scale_m[2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||||
|
i * MPerXdl * MWave];
|
||||||
|
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + 3] =
|
||||||
|
scale_m[3 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||||
|
i * MPerXdl * MWave];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||||
@@ -668,14 +671,17 @@ struct CShuffleEpilogue
|
|||||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||||
constexpr int acc_xdl_offset =
|
constexpr int acc_xdl_offset =
|
||||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 0] *=
|
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||||
vec_scale_A[m_xdl * kM2 + 0] * vec_scale_B[n_xdl];
|
{
|
||||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 1] *=
|
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||||
vec_scale_A[m_xdl * kM2 + 1] * vec_scale_B[n_xdl];
|
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 0] * vec_scale_B[n_xdl];
|
||||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 2] *=
|
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||||
vec_scale_A[m_xdl * kM2 + 2] * vec_scale_B[n_xdl];
|
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 1] * vec_scale_B[n_xdl];
|
||||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 3] *=
|
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||||
vec_scale_A[m_xdl * kM2 + 3] * vec_scale_B[n_xdl];
|
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 2] * vec_scale_B[n_xdl];
|
||||||
|
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||||
|
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + 3] * vec_scale_B[n_xdl];
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -705,18 +711,29 @@ struct CShuffleEpilogue
|
|||||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||||
constexpr int acc_xdl_offset =
|
constexpr int acc_xdl_offset =
|
||||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 0] *=
|
_Pragma("unroll") for(int m0 = 0; m0 < kM0; ++m0)
|
||||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +0] *
|
{
|
||||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
lds_tile[write_stage]
|
||||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 1] *=
|
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +1] *
|
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
m_xdl * kM0 * kM2 + m0 * kM2 + 0] *
|
||||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 2] *=
|
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +2] *
|
lds_tile[write_stage]
|
||||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 3] *=
|
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +3] *
|
m_xdl * kM0 * kM2 + m0 * kM2 + 1] *
|
||||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||||
|
lds_tile[write_stage]
|
||||||
|
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||||
|
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||||
|
m_xdl * kM0 * kM2 + m0 * kM2 + 2] *
|
||||||
|
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||||
|
lds_tile[write_stage]
|
||||||
|
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||||
|
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||||
|
m_xdl * kM0 * kM2 + m0 * kM2 + 3] *
|
||||||
|
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user