mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
fix loop-dim mismatch and improve c_shuffle alu parallelism
This commit is contained in:
@@ -375,7 +375,8 @@ struct CShuffleEpilogue
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
using LDSTileTensor = decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
|
||||
LDSTileTensor lds_tile[2];
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -419,53 +420,121 @@ struct CShuffleEpilogue
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1; // Val
|
||||
constexpr int kM2 = 4; // Val
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
const index_t iMLane = get_lane_id() / NPerXdl;
|
||||
const index_t iNLane = get_lane_id() % NPerXdl;
|
||||
|
||||
lds_tile[0].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
float scale_B =
|
||||
scale_n[0 * NPerIterationShuffle + iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
|
||||
n_xdl * NPerXdl + iNLane];
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
fp32x4_t vec_scale_A;
|
||||
vec_scale_A.x = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
|
||||
vec_scale_A.y = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
|
||||
vec_scale_A.z = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
|
||||
vec_scale_A.w = scale_m[0 * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
|
||||
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A.x * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A.y * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A.z * scale_B;
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A.w * scale_B;
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr int read_stage = iAccess % 2;
|
||||
constexpr int write_stage = read_stage ^ 1;
|
||||
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess.value + 1>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
|
||||
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
float scale_B =
|
||||
scale_n[nIter * NPerIterationShuffle +
|
||||
iNWarp * NumNXdlPerWavePerShuffle * NPerXdl + n_xdl * NPerXdl + iNLane];
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * NumMXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
float scale_A =
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
lds_tile[write_stage].get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter * NumMXdlPerWavePerShuffle,
|
||||
nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
float scale_B = scale_n[nIter * NPerIterationShuffle +
|
||||
iNWarp * NumNXdlPerWavePerShuffle * NPerXdl +
|
||||
n_xdl * NPerXdl + iNLane];
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
fp32x4_t vec_scale_A;
|
||||
vec_scale_A.x =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + m2];
|
||||
lds_tile.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
scale_A * scale_B;
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 0];
|
||||
vec_scale_A.y =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 1];
|
||||
vec_scale_A.z =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 2];
|
||||
vec_scale_A.w =
|
||||
scale_m[mIter * MPerIterationShuffle +
|
||||
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
|
||||
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + 3];
|
||||
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 0] *=
|
||||
vec_scale_A.x * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 1] *=
|
||||
vec_scale_A.y * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 2] *=
|
||||
vec_scale_A.z * scale_B;
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + 3] *=
|
||||
vec_scale_A.w * scale_B;
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
Reference in New Issue
Block a user