mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
optimize scaling epilogue
This commit is contained in:
@@ -488,7 +488,9 @@ struct CShuffleEpilogue
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
using ShuffleAcc =
|
||||
decltype(make_static_distributed_tensor<AccDataType>(dram_tile_distribution));
|
||||
ShuffleAcc shuffle_acc[MRepeat];
|
||||
auto c_out_tensor_fp32 =
|
||||
make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
|
||||
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
|
||||
@@ -498,30 +500,53 @@ struct CShuffleEpilogue
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
}
|
||||
#pragma unroll
|
||||
for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
float scale_B = scale_n[n_idx + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
|
||||
scale_B;
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
|
||||
vec_scale_A[mIter * kM2 + 0];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
|
||||
scale_B;
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
|
||||
vec_scale_A[mIter * kM2 + 1];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
|
||||
scale_B;
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
|
||||
vec_scale_A[mIter * kM2 + 2];
|
||||
c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
|
||||
shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
|
||||
scale_B;
|
||||
shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
|
||||
vec_scale_A[mIter * kM2 + 3];
|
||||
});
|
||||
|
||||
c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
|
||||
@@ -606,49 +631,51 @@ struct CShuffleEpilogue
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
static_assert(kM0 == 1, "only support when XDL_M == 16");
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
float vec_scale_A[kM0 * kM2 * MRepeat];
|
||||
float vec_scale_B[NRepeat];
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < NRepeat; ++i)
|
||||
{
|
||||
vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
}
|
||||
#pragma unroll
|
||||
for(int i = 0; i < MRepeat; ++i)
|
||||
{
|
||||
vec_scale_A[i * kM2 + 0] =
|
||||
scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 1] =
|
||||
scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 2] =
|
||||
scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
vec_scale_A[i * kM2 + 3] =
|
||||
scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
|
||||
}
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
float scale_B =
|
||||
scale_n[0 * NPerIterationShuffle + iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
|
||||
n_xdl * NPerXdl + iNLane];
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
fp32x4_t vec_scale_A;
|
||||
vec_scale_A.x = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
|
||||
vec_scale_A.y = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
|
||||
vec_scale_A.z = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
|
||||
vec_scale_A.w = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
|
||||
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A.x * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A.y * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A.z * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A.w * scale_B;
|
||||
});
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 0] *=
|
||||
vec_scale_A[m_xdl * kM2 + 0] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 1] *=
|
||||
vec_scale_A[m_xdl * kM2 + 1] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 2] *=
|
||||
vec_scale_A[m_xdl * kM2 + 2] * vec_scale_B[n_xdl];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + 3] *=
|
||||
vec_scale_A[m_xdl * kM2 + 3] * vec_scale_B[n_xdl];
|
||||
});
|
||||
});
|
||||
|
||||
@@ -675,45 +702,21 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
float scale_B = scale_n[nIter * NPerIterationShuffle +
|
||||
iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
|
||||
n_xdl * NPerXdl + iNLane];
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
fp32x4_t vec_scale_A;
|
||||
vec_scale_A.x =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
|
||||
vec_scale_A.y =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
|
||||
vec_scale_A.z =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
|
||||
vec_scale_A.w =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
|
||||
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A.x * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A.y * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A.z * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A.w * scale_B;
|
||||
});
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 0] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +0] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 1] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +1] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 2] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +2] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 3] *=
|
||||
vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + +3] *
|
||||
vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user