[rocm-libraries] ROCm/rocm-libraries#5393 (commit d51b649)

[CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution
backward weight kernel. Split-K divides the K-dimension uniformly across
a fixed `k_batch`, which causes load imbalance when the number of output
tiles doesn't evenly fill the GPU. StreamK distributes total
K-iterations evenly across workgroups, improving utilization on these
shapes.
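The distribution idea can be sketched on the host side: total K-iterations (over all output tiles) are divided as evenly as possible across CTAs, with the remainder spread one iteration at a time. This is a minimal illustration of the concept only — the function and struct names are hypothetical, not the CK Tile partitioner API.

```cpp
#include <cassert>

// Illustrative StreamK-style work split: `total_iters` K-iterations are
// divided across `num_ctas` CTAs; the first `total_iters % num_ctas` CTAs
// each take one extra iteration, so no CTA differs by more than one.
struct IterRange
{
    int begin;
    int end;
};

IterRange streamk_iter_range(int total_iters, int num_ctas, int cta_idx)
{
    const int base  = total_iters / num_ctas; // iterations every CTA gets
    const int extra = total_iters % num_ctas; // first `extra` CTAs get one more
    const int begin = cta_idx * base + (cta_idx < extra ? cta_idx : extra);
    const int end   = begin + base + (cta_idx < extra ? 1 : 0);
    return {begin, end};
}
```

With 10 iterations over 4 CTAs this yields ranges [0,3), [3,6), [6,8), [8,10) — every CTA stays busy, unlike a fixed `k_batch` split that can leave workgroups idle.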

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel,
selected by the `TilePartitioner_` template parameter. Two reduction
strategies are supported:
- **Linear**: the CTA that starts a tile sequentially accumulates partials
from all contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for
many contributors)
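The asymptotic difference between the two strategies is easy to see by counting reduction steps for one output tile with `n` contributing CTAs. This is a back-of-envelope illustration, not the kernel code:

```cpp
#include <cassert>

// Linear: the tile starter performs n-1 sequential accumulations.
int linear_steps(int n) { return n - 1; }

// Tree: contributors are paired each round, halving the count, so the
// critical-path depth is ceil(log2(n)) rounds.
int tree_depth(int n)
{
    int depth = 0;
    while(n > 1)
    {
        n = (n + 1) / 2; // pair up contributors; odd one advances for free
        ++depth;
    }
    return depth;
}
```

For 8 contributors, linear takes 7 sequential adds while tree takes 3 rounds — consistent with tree reduction winning when many CTAs contribute to a tile.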

Both persistent and non-persistent data-parallel (DP) sections are
supported.
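The compile-time selection mentioned above can be sketched as follows. The trait member and partitioner type names here are hypothetical stand-ins, not the actual CK Tile identifiers; the point is only the `if constexpr` branching pattern on the partitioner template parameter:

```cpp
#include <string>

// Hypothetical partitioner tags carrying a compile-time StreamK marker.
struct SplitKPartitioner  { static constexpr bool kIsStreamK = false; };
struct StreamKPartitioner { static constexpr bool kIsStreamK = true;  };

// Sketch of a kernel that selects its execution path at compile time from
// the partitioner type, so the split-K path pays no runtime dispatch cost.
template <typename TilePartitioner_>
struct ConvBwdWeightKernelSketch
{
    std::string Run() const
    {
        if constexpr(TilePartitioner_::kIsStreamK)
            return "RunStreamK"; // StreamK execution path
        else
            return "RunSplitK";  // existing split-K path
    }
};
```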

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution
path with `RunStreamK`/`RunStreamKLoop`, partial store/load via
workspace, flag-based cross-CTA synchronization,
`GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers)
and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by
both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy`
template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and
`--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept,
`TilePartitionerType` factory helper, `InstanceTraits` with StreamK
fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP),
negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 Split-K | 16x64 StreamK | 128x128 Split-K | 128x128 StreamK |
|---|---|---|---|---|
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on the tile configuration: with larger tiles (fewer output
tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up
to 1.16x on typical large-channel convolutions. Tree reduction
consistently outperforms Linear when multiple CTAs contribute to the
same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n)
sequential accumulation. The table reports the best of Linear and Tree
for each shape.

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace
size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group,
stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher
occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference
comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK
tests pass. Full `check-builder` suite passes. Tolerances computed
dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware).
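The idea behind a dynamically computed tolerance can be illustrated as follows: rounding error in a long fp16 accumulation grows with the reduction length, so the relative tolerance should scale with the fp16 unit roundoff times an accumulation-length factor. This is a hedged sketch of the concept only, under the assumption of error growth proportional to the square root of the accumulation length; it is not CK Tile's `calculate_rtol_atol` implementation.

```cpp
#include <cmath>

// Illustrative fp16 relative tolerance that grows with the accumulation
// length k_accum (assumption: sqrt(K) error growth, typical for roughly
// random rounding). fp16 has a 10-bit mantissa, so unit roundoff is 2^-10.
double fp16_rtol_sketch(int k_accum)
{
    const double fp16_eps = std::pow(2.0, -10);
    return fp16_eps * std::sqrt(static_cast<double>(k_accum));
}
```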

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Author: Johannes Graner
Date: 2026-03-27 09:18:14 +00:00
Committed by: assistant-librarian[bot]
Parent: 36f2ec23f5
Commit: 58475d3f45
21 changed files with 1860 additions and 348 deletions


@@ -154,6 +154,7 @@ struct StreamKKernel
using KernelArgs = StreamKKernelArgs;
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
using StreamKOps = StreamKReductionOps<TilePartitioner, GemmPipeline, StreamKKernelArgs>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -314,231 +315,6 @@ struct StreamKKernel
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
}
/**
 * @brief Signals that the current thread block (CTA) has completed storing its partial
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the appropriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// appropriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
const OAccTile& in_block_tile) const
{
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
constexpr auto o_spans = BlockType::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
});
});
}
/**
* @brief Loads a partial block tile from the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTileDist& c_block_tile_dist) const
{
const auto c_block_tile_buffer_size =
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
static_cast<DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
auto partials_tile = load_tile(partial_tile_window);
// Since the partials distribution is not the same as the C block distribution, we must
// describe the contents in the partials tile with the C block distribution.
// Note: The data assigned to threads does not change between distributions.
auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
c_block_tile_dist, partials_tile.get_thread_buffer());
return partials_tile_with_c_distr;
}
/**
* @brief Returns the vector size to be used for reading from and writing to partials.
* @return The vector size
*/
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
{
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
// maximum vector width
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
/**
* @brief Returns distribution used for reading from and writing to partials.
* @return The distribution.
* @note This will result in optimized reads from and writes to partials when C is row major.
* Additional functionality should be added to ensure optimized accesses to partials when C is
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
* current limitation.
*/
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
{
// Create the encoding to describe waves within a block
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Create the encoding to describe threads within a wave
constexpr index_t vector_size = GetVectorSizePartials();
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
// This inner encoding ensures that contiguous threads perform vectorized writes along the
// same row in C.
constexpr auto partials_inner_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
sequence<warp_tile_n_threads, vector_size>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
// Combine the outer and inner encoding
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
return make_static_tile_distribution(partials_dstr_encode);
}
/**
* @brief Stores a partial block tile to the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTile& c_block_tile) const
{
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
TilePartitioner::NPerBlock *
sizeof(typename OAccTile::DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
// Since the C block distribution is not the same as the partials distribution, we must
// describe the contents in the c_block_tile with the partials distribution.
// Note: The data assigned to threads does not change between distributions.
auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
MakePartialsDistribution(), c_block_tile.get_thread_buffer());
store_tile(partial_tile_window, c_with_partials_dist);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
 * @brief Runs the main Stream-K algorithm.
 * @param kargs Stream-K kernel arguments.
@@ -551,6 +327,7 @@ struct StreamKKernel
CK_TILE_DEVICE
void StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
{
const StreamKOps sk_ops{};
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
@@ -631,8 +408,8 @@ struct StreamKKernel
{
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
sk_ops.StorePartial(kargs, cta_idx, c_block_tile);
sk_ops.SignalStorePartialDone(kargs, cta_idx);
}
else
{
@@ -649,12 +426,12 @@ struct StreamKKernel
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
sk_ops.WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
sk_ops.AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
@@ -703,13 +480,14 @@ struct StreamKKernel
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
sk_ops.WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
@@ -717,8 +495,8 @@ struct StreamKKernel
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
sk_ops.StorePartial(kargs, cta_idx, accum_block_tile);
sk_ops.SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
@@ -739,66 +517,26 @@ struct StreamKKernel
}
/**
* @brief Entry point for the Stream-K Kernel with non-persistent DP.
* @brief Entry point for the Stream-K kernel.
*
* @par Overview
* For the Non-Persistent kernel, each data parallel workgroup will
* compute the results for their assigned macro-tile by calling `BaseGemm()`.
* The Stream-K workgroups will do their assigned work by calling
* `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
* Uses StreamKDispatch to handle both persistent and non-persistent DP sections.
* Non-persistent: dedicated DP workgroups process full tiles, then dedicated SK
* workgroups share remaining K-iterations.
* Persistent: each workgroup loops over DP tiles (round-robin), then proceeds
* to SK work.
*/
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
const index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
// Check if at the data parallel section
if(is_dp_ctas)
{
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
else
{
// Stream-K
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
}
}
/**
* @brief Entry point for the Stream-K Kernel with persistent DP.
*
* @par Overview
* For the Persistent kernel, each workgroup will first compute their
* assigned data-parallel tiles. Each data parallel tile will be computed
* by calling `BaseGemm()`. Then the workgroups will proceed with the
* Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
* in the Stream-K loop.
*/
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
StreamKGemm(kargs, block_idx, smem_ptr_0);
StreamKDispatch(
kargs.tile_partitioner,
[&](index_t tile_idx) {
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
},
[&](index_t sk_cta_idx) { StreamKGemm(kargs, sk_cta_idx, smem_ptr_0); });
}
private: