mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4984 (commit 962b047)
[CK_TILE] Reduce Register Spills in Stream-K Reductions (#4984) ## Motivation In CK Tile Stream-K, kernels using one of two non-atomic reduction strategies (i.e., linear, tree) have high register spill count, with the tree reduction generally being worse. These changes act a first step to help decrease the register spill count. ## Technical Details ### Problem 1: Unvectorized access to partials In both the linear and tree reductions, workgroups write partials results to a global buffer; another workgroup will later read this data. When the initial logic to support reading and writing to the partials buffer was added (see https://github.com/ROCm/composable_kernel/pull/3107), the tile distribution encoding used to read from and write to partials matches the register layout for the accumulator of the mfma instruction used for the kernel. Since we do not currently use the transposed register layout for the accumulator, we end with an encoding that is not optimized for writing to HBM. For example: Consider the register layout of the `v_mfma_f32_16x16x32_fp8_fp8` instruction. ```bash ./matrix_calculator.py --architecture gfx942 --instruction v_mfma_f32_16x16x32_fp8_fp8 --register-layout --C-matrix ``` <img width="1113" height="537" alt="image" src="https://github.com/user-attachments/assets/afc8f556-08cc-4224-a6e5-b5edabc5fc02" /> The above shows that threads are responsible for consecutive elements down a column of the C tile. If we use this distribution to read and write to partials with C in row major, then threads are unable to perform vectorized reads and writes. Note: thread 0 is shown in red and thread 1 is shown in green. Since the C-shuffle Epilogue only supports C in row major, reading and writing to partials is highly unoptimized. ### Problem 2: Missed opportunity for SPGR use in tree reduction loop Since the reduction occurs between workgroups, all threads in the workgroup follow the same execution paths in the tree reduction logic, hence various variables should be using SGPRs, but they are not. ### Implemented Solutions 1. Add a new tile distribution encoding that is optimized for accessing partials in HBM. This encoding does not change the data assignment to threads, it merely changes the addresses to which they write/read in the partials buffer. For example, continuing with the `v_mfma_f32_16x16x32_fp8_fp8` instruction, the new encoding would result in threads writing in the following layout: <img width="517" height="342" alt="image" src="https://github.com/user-attachments/assets/93b5e0ea-bafc-47b8-89bb-c40ba75cb202" /> This layout ensures that each thread writes along a row, enabling `buffer_{store|load}_dwordx4` instructions (i.e., vectorized accesses). This helps reduce register usage due to requiring fewer offset calculations. 2. To force SGPR usage in the tree reduction loop, I make use of CK Tile's `amd_wave_read_first_lane` which is a wrapper around `__builtin_amdgcn_readfirstlane`. This helps reduce VGPR spills in the tree reduction. _These changes do not fully eliminate register spills. Future work will aim to further reduce spills. But these changes make good progress._ ## Test Plan Added tests for different warp tile sizes to validate that the new encoding works with different `WarpGemm` variants. ## Test Result All tests pass locally on all gfx9 architectures. Some results for decreases in register spills on gfx942: (BL = baseline) | Kernel | SGPR Spill (BL) | SGPR Spill (new) | SGPR Delta | SGPR % | VGPR Spill (BL) | VGPR Spill (new) | VGPR Delta | VGPR % | |--------|------------------:|------------------:|-----------:|-------:|-------------------:|------------------:|-----------:|-------:| | fp16 linear F/F/F/T 256x256x32 2x2x1 32x32x16 | 223 | 0 | -223 | -100.0% | 21 | 20 | -1 | -4.8% | | fp16 tree F/F/F/T 256x256x32 2x2x1 32x32x16 | 233 | 11 | -222 | -95.3% | 443 | 23 | -420 | -94.8% | | fp8 linear F/F/F/F 256x256x32 2x2x1 32x32x32 | 221 | 3 | -218 | -98.6% | 12 | 6 | -6 | -50.0% | | fp8 tree F/F/F/F 256x256x32 2x2x1 32x32x32 | 230 | 14 | -216 | -93.9% | 396 | 12 | -384 | -97.0% | ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b042e1805a
commit
f1746955fd
@@ -75,6 +75,8 @@ struct StreamKKernel
|
||||
using TilePartitioner = TilePartitioner_;
|
||||
using GemmPipeline = GemmPipeline_;
|
||||
using EpiloguePipeline = EpiloguePipeline_;
|
||||
using WarpGemm = typename GemmPipeline::BlockGemm::WarpGemm;
|
||||
using BlockGemmShape = typename GemmPipeline::BlockGemmShape;
|
||||
|
||||
static_assert(
|
||||
TilePartitioner::PERSISTENT == PersistentDP,
|
||||
@@ -156,7 +158,7 @@ struct StreamKKernel
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
using WarpTile = typename P_::BlockGemmShape::WarpTile;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
@@ -407,16 +409,84 @@ struct StreamKKernel
|
||||
static_cast<DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
number<GemmPipeline::GetVectorSizeC()>{},
|
||||
number<GetVectorSizePartials()>{},
|
||||
number<1>{});
|
||||
|
||||
auto partial_tile_window = make_tile_window(
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0},
|
||||
c_block_tile_dist);
|
||||
MakePartialsDistribution());
|
||||
|
||||
return load_tile(partial_tile_window);
|
||||
auto partials_tile = load_tile(partial_tile_window);
|
||||
|
||||
// Since the partials distribution is not the same as the C block distribution, we must
|
||||
// describe the contents in the partials tile with the C block distribution.
|
||||
// Note: The data assigned to threads does not change between distributions.
|
||||
auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
|
||||
c_block_tile_dist, partials_tile.get_thread_buffer());
|
||||
|
||||
return partials_tile_with_c_distr;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns the vector size to be used for reading from and writing to partials.
|
||||
* @return The vector size
|
||||
*/
|
||||
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
|
||||
{
|
||||
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
|
||||
// maximum vector width
|
||||
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns distribution used for reading from and writing to partials.
|
||||
* @return The distribution.
|
||||
* @note This will result in optimized reads from and writes to partials when C is row major.
|
||||
* Additional functionality should be added to ensure optimized accesses to partials when C is
|
||||
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
|
||||
* current limitation.
|
||||
*/
|
||||
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
|
||||
{
|
||||
// Create the encoding to describe waves within a block
|
||||
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
|
||||
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
|
||||
|
||||
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Create the encoding to describe threads within a wave
|
||||
constexpr index_t vector_size = GetVectorSizePartials();
|
||||
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
|
||||
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
|
||||
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
|
||||
|
||||
// This inner encoding ensures that contiguous threads perform vectorized writes along the
|
||||
// same row in C.
|
||||
constexpr auto partials_inner_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
|
||||
sequence<warp_tile_n_threads, vector_size>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
|
||||
// Combine the outer and inner encoding
|
||||
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
|
||||
|
||||
return make_static_tile_distribution(partials_dstr_encode);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -446,14 +516,22 @@ struct StreamKKernel
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
number<GemmPipeline::GetVectorSizeC()>{},
|
||||
number<GetVectorSizePartials()>{},
|
||||
number<1>{});
|
||||
|
||||
auto partial_tile_window = make_tile_window(
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
{0, 0},
|
||||
MakePartialsDistribution());
|
||||
|
||||
// Since the C block distribution is not the same as the partials distribution, we must
|
||||
// describe the contents in the c_block_tile with the partials distribution.
|
||||
// Note: The data assigned to threads does not change between distributions.
|
||||
auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
|
||||
MakePartialsDistribution(), c_block_tile.get_thread_buffer());
|
||||
|
||||
store_tile(partial_tile_window, c_with_partials_dist);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
@@ -591,16 +669,19 @@ struct StreamKKernel
|
||||
}
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
auto accum_block_tile = c_block_tile;
|
||||
index_t tile_local_cta_idx = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx));
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
index_t stride = amd_wave_read_first_lane(1);
|
||||
|
||||
for(;; stride <<= 1)
|
||||
{
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
const index_t partner_cta_idx = amd_wave_read_first_lane(cta_idx + stride);
|
||||
const index_t partner_start_iter = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx));
|
||||
bool partner_in_tile =
|
||||
amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
|
||||
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
|
||||
Reference in New Issue
Block a user