[rocm-libraries] ROCm/rocm-libraries#4984 (commit 962b047)

[CK_TILE] Reduce Register Spills in Stream-K Reductions (#4984)

## Motivation

In CK Tile Stream-K, kernels using either of the two non-atomic reduction
strategies (linear or tree) have high register spill counts, with tree
reduction generally being worse. These changes are a first step toward
decreasing the spill count.

## Technical Details
### Problem 1: Unvectorized access to partials
In both the linear and tree reductions, workgroups write partial
results to a global buffer; another workgroup later reads this data.
When the initial logic for reading and writing the partials buffer was
added (see https://github.com/ROCm/composable_kernel/pull/3107), the
tile distribution encoding used to access partials matched the register
layout of the accumulator of the mfma instruction used by the kernel.
Since we do not currently use the transposed register layout for the
accumulator, we end up with an encoding that is not optimized for
writing to HBM.

For example: Consider the register layout of the
`v_mfma_f32_16x16x32_fp8_fp8` instruction.
```bash
./matrix_calculator.py --architecture gfx942 --instruction  v_mfma_f32_16x16x32_fp8_fp8 --register-layout --C-matrix
```
<img width="1113" height="537" alt="image"
src="https://github.com/user-attachments/assets/afc8f556-08cc-4224-a6e5-b5edabc5fc02"
/>

The above shows that each thread is responsible for consecutive elements
down a column of the C tile (thread 0 is shown in red, thread 1 in
green). If we use this distribution to read and write partials with C in
row major, threads cannot perform vectorized reads and writes.

Since the C-Shuffle epilogue only supports C in row major, reading and
writing partials is highly unoptimized.
### Problem 2: Missed opportunity for SGPR use in tree reduction loop
Because the reduction occurs between workgroups, all threads in a
workgroup follow the same execution path through the tree reduction
logic. Various values are therefore wave-uniform and should live in
SGPRs, but they do not.

### Implemented Solutions
1. Add a new tile distribution encoding optimized for accessing
partials in HBM. This encoding does not change the assignment of data to
threads; it merely changes the addresses to which they read and write in
the partials buffer. Continuing with the
`v_mfma_f32_16x16x32_fp8_fp8` example, the new encoding results
in threads writing in the following layout:
<img width="517" height="342" alt="image"
src="https://github.com/user-attachments/assets/93b5e0ea-bafc-47b8-89bb-c40ba75cb202"
/>

This layout ensures that each thread writes along a row, enabling
vectorized `buffer_{store|load}_dwordx4` instructions. It also reduces
register usage because fewer offset calculations are required.

2. To force SGPR usage in the tree reduction loop, I use CK Tile's
`amd_wave_read_first_lane`, a wrapper around
`__builtin_amdgcn_readfirstlane`. This reduces VGPR spills in the
tree reduction.

_These changes do not fully eliminate register spills; future work will
aim to reduce them further. But these changes make good progress._

## Test Plan

Added tests for different warp tile sizes to validate that the new
encoding works with different `WarpGemm` variants.

## Test Result

All tests pass locally on all gfx9 architectures.

Some results for decreases in register spills on gfx942 (BL = baseline):

| Kernel | SGPR Spill (BL) | SGPR Spill (new) | SGPR Delta | SGPR % | VGPR Spill (BL) | VGPR Spill (new) | VGPR Delta | VGPR % |
|--------|----------------:|-----------------:|-----------:|-------:|----------------:|-----------------:|-----------:|-------:|
| fp16 linear F/F/F/T 256x256x32 2x2x1 32x32x16 | 223 | 0 | -223 | -100.0% | 21 | 20 | -1 | -4.8% |
| fp16 tree F/F/F/T 256x256x32 2x2x1 32x32x16 | 233 | 11 | -222 | -95.3% | 443 | 23 | -420 | -94.8% |
| fp8 linear F/F/F/F 256x256x32 2x2x1 32x32x32 | 221 | 3 | -218 | -98.6% | 12 | 6 | -6 | -50.0% |
| fp8 tree F/F/F/F 256x256x32 2x2x1 32x32x32 | 230 | 14 | -216 | -93.9% | 396 | 12 | -384 | -97.0% |

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Emily Martins, 2026-03-02 17:40:34 +00:00
committed by assistant-librarian[bot]
parent b042e1805a, commit f1746955fd
2 changed files with 99 additions and 18 deletions


@@ -75,6 +75,8 @@ struct StreamKKernel
using TilePartitioner = TilePartitioner_;
using GemmPipeline = GemmPipeline_;
using EpiloguePipeline = EpiloguePipeline_;
using WarpGemm = typename GemmPipeline::BlockGemm::WarpGemm;
using BlockGemmShape = typename GemmPipeline::BlockGemmShape;
static_assert(
TilePartitioner::PERSISTENT == PersistentDP,
@@ -156,7 +158,7 @@ struct StreamKKernel
{
// clang-format off
using P_ = GemmPipeline;
-        using WarpTile = typename P_::BlockGemmShape::WarpTile;
+        using WarpTile = typename BlockGemmShape::WarpTile;
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
@@ -407,16 +409,84 @@ struct StreamKKernel
static_cast<DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
-            number<GemmPipeline::GetVectorSizeC()>{},
+            number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
-            c_block_tile_dist);
+            MakePartialsDistribution());
-        return load_tile(partial_tile_window);
+        auto partials_tile = load_tile(partial_tile_window);
+        // Since the partials distribution is not the same as the C block distribution, we must
+        // describe the contents in the partials tile with the C block distribution.
+        // Note: The data assigned to threads does not change between distributions.
+        auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
+            c_block_tile_dist, partials_tile.get_thread_buffer());
+        return partials_tile_with_c_distr;
}
/**
* @brief Returns the vector size to be used for reading from and writing to partials.
* @return The vector size
*/
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
{
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
// maximum vector width
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
/**
* @brief Returns distribution used for reading from and writing to partials.
* @return The distribution.
* @note This will result in optimized reads from and writes to partials when C is row major.
* Additional functionality should be added to ensure optimized accesses to partials when C is
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
* current limitation.
*/
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
{
// Create the encoding to describe waves within a block
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Create the encoding to describe threads within a wave
constexpr index_t vector_size = GetVectorSizePartials();
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
// This inner encoding ensures that contiguous threads perform vectorized writes along the
// same row in C.
constexpr auto partials_inner_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
sequence<warp_tile_n_threads, vector_size>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
// Combine the outer and inner encoding
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
return make_static_tile_distribution(partials_dstr_encode);
}
/**
@@ -446,14 +516,22 @@ struct StreamKKernel
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
-            number<GemmPipeline::GetVectorSizeC()>{},
+            number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {0, 0});
-        store_tile(partial_tile_window, c_block_tile);
+            {0, 0},
+            MakePartialsDistribution());
+        // Since the C block distribution is not the same as the partials distribution, we must
+        // describe the contents in the c_block_tile with the partials distribution.
+        // Note: The data assigned to threads does not change between distributions.
+        auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
+            MakePartialsDistribution(), c_block_tile.get_thread_buffer());
+        store_tile(partial_tile_window, c_with_partials_dist);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
@@ -591,16 +669,19 @@ struct StreamKKernel
}
else // Tree Reduction
{
-            auto accum_block_tile = c_block_tile;
-            index_t tile_local_cta_idx =
-                kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
+            auto accum_block_tile = c_block_tile;
+            index_t tile_local_cta_idx = amd_wave_read_first_lane(
+                kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx));
-            for(index_t stride = 1;; stride <<= 1)
+            index_t stride = amd_wave_read_first_lane(1);
+            for(;; stride <<= 1)
{
-                const index_t partner_cta_idx = cta_idx + stride;
-                const index_t partner_start_iter =
-                    kargs.tile_partitioner.get_start_iter(partner_cta_idx);
-                bool partner_in_tile = partner_start_iter < tile_iter_end;
+                const index_t partner_cta_idx = amd_wave_read_first_lane(cta_idx + stride);
+                const index_t partner_start_iter = amd_wave_read_first_lane(
+                    kargs.tile_partitioner.get_start_iter(partner_cta_idx));
+                bool partner_in_tile =
+                    amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C