optimize scaling epilogue

commit aa5e008fa5
parent ac5908c0bb
Author: Feng Shijie
Date:   2025-08-01 11:01:23 +00:00


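In short: the change hoists the row scales (scale_m) and column scales (scale_n) out of the per-element scaling loops. Each thread now loads its factors once into the thread-local arrays vec_scale_A and vec_scale_B and multiplies from those local copies, instead of re-evaluating the scale_m/scale_n indexing for every accumulator element. A minimal standalone sketch of the pattern, with hypothetical names and shapes rather than the kernel's actual tile machinery:

    #include <cstddef>

    // Sketch only: per-row and per-column scaling with the scale loads hoisted
    // out of the element loop. M, N, and all names here are illustrative.
    template <std::size_t M, std::size_t N>
    void scale_tile(float (&acc)[M][N], const float* scale_m, const float* scale_n)
    {
        float vec_scale_A[M]; // one local copy per row factor
        float vec_scale_B[N]; // one local copy per column factor
        for(std::size_t i = 0; i < M; ++i)
            vec_scale_A[i] = scale_m[i]; // loaded once, not once per element
        for(std::size_t j = 0; j < N; ++j)
            vec_scale_B[j] = scale_n[j];
        for(std::size_t i = 0; i < M; ++i)
            for(std::size_t j = 0; j < N; ++j)
                acc[i][j] *= vec_scale_A[i] * vec_scale_B[j]; // scales read from locals
    }

The diff applies this idea twice: once in the register-based epilogue path (first two hunks) and once in the LDS-staged path (last two hunks).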
@@ -488,7 +488,9 @@ struct CShuffleEpilogue
             to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
-        auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
+        using ShuffleAcc =
+            decltype(make_static_distributed_tensor<AccDataType>(dram_tile_distribution));
+        ShuffleAcc shuffle_acc[MRepeat];
         auto c_out_tensor_fp32 =
             make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
         auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
@@ -498,30 +500,53 @@ struct CShuffleEpilogue
         const index_t iMLane = get_lane_id() / NPerXdl;
         const index_t iNLane = get_lane_id() % NPerXdl;
+        float vec_scale_A[kM2 * MRepeat];
+        float vec_scale_B[NRepeat];
+        #pragma unroll
+        for(int i = 0; i < NRepeat; ++i)
+        {
+            vec_scale_B[i] = scale_n[i + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
+        }
+        #pragma unroll
+        for(int i = 0; i < MRepeat; ++i)
+        {
+            vec_scale_A[i * kM2 + 0] =
+                scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 1] =
+                scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 2] =
+                scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 3] =
+                scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+        }
         static_for<0, MRepeat, 1>{}([&](auto mIter) {
-            shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
+            shuffle_acc[mIter].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
                 merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
                 merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
             static_for<0, NRepeat, 1>{}([&](auto n_idx) {
-                float scale_B = scale_n[n_idx + iNLane * NRepeat + iNWarp * NRepeat * NPerXdl];
+                shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 0] *= vec_scale_B[n_idx];
+                shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 1] *= vec_scale_B[n_idx];
+                shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 2] *= vec_scale_B[n_idx];
+                shuffle_acc[mIter].get_thread_buffer()[n_idx * kM2 + 3] *= vec_scale_B[n_idx];
+            });
+        });
+        static_for<0, MRepeat, 1>{}([&](auto mIter) {
+            static_for<0, NRepeat, 1>{}([&](auto n_idx) {
                 c_out_tensor_fp32.get_thread_buffer()[n_idx + 0 * NRepeat] =
-                    shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
-                    scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
-                    scale_B;
+                    shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 0] *
+                    vec_scale_A[mIter * kM2 + 0];
                 c_out_tensor_fp32.get_thread_buffer()[n_idx + 1 * NRepeat] =
-                    shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
-                    scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
-                    scale_B;
+                    shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 1] *
+                    vec_scale_A[mIter * kM2 + 1];
                 c_out_tensor_fp32.get_thread_buffer()[n_idx + 2 * NRepeat] =
-                    shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
-                    scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
-                    scale_B;
+                    shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 2] *
+                    vec_scale_A[mIter * kM2 + 2];
                 c_out_tensor_fp32.get_thread_buffer()[n_idx + 3 * NRepeat] =
-                    shuffle_acc.get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
-                    scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + mIter * MPerXdl * MWave] *
-                    scale_B;
+                    shuffle_acc[mIter].get_thread_buffer()[n_idx * c_warp_y_lengths.product() + 3] *
+                    vec_scale_A[mIter * kM2 + 3];
             });
             c_out_tensor = cast_tile<ODataType>(c_out_tensor_fp32);
@@ -606,49 +631,51 @@ struct CShuffleEpilogue
         constexpr int kM2 = 4;                   // Val
         constexpr int kM1 = (64 / NPerXdl);      // Thr
         constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
+        static_assert(kM0 == 1, "only support when XDL_M == 16");
         const index_t iMWarp = get_warp_id() / NWave;
         const index_t iNWarp = get_warp_id() - iMWarp * NWave;
         const index_t iMLane = get_lane_id() / NPerXdl;
         const index_t iNLane = get_lane_id() % NPerXdl;
+        float vec_scale_A[kM0 * kM2 * MRepeat];
+        float vec_scale_B[NRepeat];
+        #pragma unroll
+        for(int i = 0; i < NRepeat; ++i)
+        {
+            vec_scale_B[i] = scale_n[i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
+        }
+        #pragma unroll
+        for(int i = 0; i < MRepeat; ++i)
+        {
+            vec_scale_A[i * kM2 + 0] =
+                scale_m[0 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 1] =
+                scale_m[1 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 2] =
+                scale_m[2 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+            vec_scale_A[i * kM2 + 3] =
+                scale_m[3 + iMLane * kM2 + iMWarp * MPerXdl + i * MPerXdl * MWave];
+        }
         lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
             merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
                             c_warp_y_index_zeros),
             merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
                             c_warp_y_lengths));
         static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
-            float scale_B =
-                scale_n[0 * NPerIterationShuffle + iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
-                        n_xdl * NPerXdl + iNLane];
             static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
                 constexpr int acc_xdl_offset =
                     (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
-                static_for<0, kM0, 1>{}([&](auto m0) {
-                    fp32x4_t vec_scale_A;
-                    vec_scale_A.x = scale_m[0 * MPerIterationShuffle +
-                                            iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                            m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
-                    vec_scale_A.y = scale_m[0 * MPerIterationShuffle +
-                                            iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                            m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
-                    vec_scale_A.z = scale_m[0 * MPerIterationShuffle +
-                                            iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                            m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
-                    vec_scale_A.w = scale_m[0 * MPerIterationShuffle +
-                                            iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                            m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
-                    lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
-                        vec_scale_A.x * scale_B;
-                    lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
-                        vec_scale_A.y * scale_B;
-                    lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
-                        vec_scale_A.z * scale_B;
-                    lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
-                        vec_scale_A.w * scale_B;
-                });
+                lds_tile[0].get_thread_buffer()[acc_xdl_offset + 0] *=
+                    vec_scale_A[m_xdl * kM2 + 0] * vec_scale_B[n_xdl];
+                lds_tile[0].get_thread_buffer()[acc_xdl_offset + 1] *=
+                    vec_scale_A[m_xdl * kM2 + 1] * vec_scale_B[n_xdl];
+                lds_tile[0].get_thread_buffer()[acc_xdl_offset + 2] *=
+                    vec_scale_A[m_xdl * kM2 + 2] * vec_scale_B[n_xdl];
+                lds_tile[0].get_thread_buffer()[acc_xdl_offset + 3] *=
+                    vec_scale_A[m_xdl * kM2 + 3] * vec_scale_B[n_xdl];
            });
        });
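For reference, the new static_assert pins down the only decomposition the rewritten path handles. With a 64-lane wave, kM1 = 64 / NPerXdl lane groups tile the M direction and each lane keeps kM2 = 4 consecutive M values, so kM0 = MPerXdl / kM1 / kM2 = MPerXdl * NPerXdl / 256; requiring kM0 == 1 with XDL_M == 16 forces MPerXdl = NPerXdl = 16. A worked instance of that arithmetic (values derived from the assert, not taken from a build):

    // Worked instance of the kM0/kM1/kM2 split for the asserted XDL_M == 16 case.
    constexpr int MPerXdl = 16;              // XDL_M == 16
    constexpr int NPerXdl = 16;              // forced by kM0 == 1 (see below)
    constexpr int kM2 = 4;                   // values per lane along M
    constexpr int kM1 = 64 / NPerXdl;        // = 4 lane groups along M
    constexpr int kM0 = MPerXdl / kM1 / kM2; // = 16 / 4 / 4 = 1
    static_assert(kM0 == 1, "only support when XDL_M == 16");

With kM0 == 1 the old static_for over m0 collapses, which is what lets the rewrite index vec_scale_A with m_xdl * kM2 + {0..3} directly.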
@@ -675,45 +702,21 @@ struct CShuffleEpilogue
                 merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
                                 c_warp_y_lengths));
             static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
-                float scale_B = scale_n[nIter * NPerIterationShuffle +
-                                        iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
-                                        n_xdl * NPerXdl + iNLane];
                 static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
                     constexpr int acc_xdl_offset =
                         (m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
-                    static_for<0, kM0, 1>{}([&](auto m0) {
-                        fp32x4_t vec_scale_A;
-                        vec_scale_A.x =
-                            scale_m[mIter * MPerIterationShuffle +
-                                    iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                    m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
-                        vec_scale_A.y =
-                            scale_m[mIter * MPerIterationShuffle +
-                                    iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                    m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
-                        vec_scale_A.z =
-                            scale_m[mIter * MPerIterationShuffle +
-                                    iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                    m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
-                        vec_scale_A.w =
-                            scale_m[mIter * MPerIterationShuffle +
-                                    iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
-                                    m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
-                        lds_tile[write_stage]
-                            .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
-                            vec_scale_A.x * scale_B;
-                        lds_tile[write_stage]
-                            .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
-                            vec_scale_A.y * scale_B;
-                        lds_tile[write_stage]
-                            .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
-                            vec_scale_A.z * scale_B;
-                        lds_tile[write_stage]
-                            .get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
-                            vec_scale_A.w * scale_B;
-                    });
+                    lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 0] *=
+                        vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + 0] *
+                        vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
+                    lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 1] *=
+                        vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + 1] *
+                        vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
+                    lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 2] *=
+                        vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + 2] *
+                        vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
+                    lds_tile[write_stage].get_thread_buffer()[acc_xdl_offset + 3] *=
+                        vec_scale_A[mIter * NumMXdlPerWavePerShuffle * kM2 + m_xdl * kM2 + 3] *
+                        vec_scale_B[nIter * NumNXdlPerWavePerShuffle + n_xdl];
                 });
             });
         }