[rocm-libraries] ROCm/rocm-libraries#5393 (commit d51b649)

[CK Tile] StreamK support for Bwd Weight grouped convolutions
 (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution
backward weight kernel. Split-K divides the K-dimension uniformly across
a fixed `k_batch`, which causes load imbalance when the number of output
tiles doesn't evenly fill the GPU. StreamK distributes total
K-iterations evenly across workgroups, improving utilization on these
shapes.

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel,
selected by the `TilePartitioner_` template parameter. Two reduction
strategies are supported:
- **Linear**: tile-starter sequentially accumulates partials from
contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for
many contributors)

Both persistent and non-persistent data-parallel (DP) sections are
supported.

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution
path with `RunStreamK`/`RunStreamKLoop`, partial store/load via
workspace, flag-based cross-CTA synchronization,
`GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers)
and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by
both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy`
template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and
`--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept,
`TilePartitionerType` factory helper, `InstanceTraits` with StreamK
fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP),
negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 tiles | | 128x128 tiles | |
|---|---|---|---|---|
| | Split-K | StreamK | Split-K | StreamK |
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on tile config: with larger tiles (fewer output
tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up
to 1.16x on typical large-channel convolutions. Tree reduction
consistently outperforms Linear when multiple CTAs contribute to the
same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n)
sequential accumulation. The table reports the best of Linear and Tree
for each shape.

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace
size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group,
stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher
occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference
comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK
tests pass. Full `check-builder` suite passes. Tolerances computed
dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Johannes Graner
2026-03-27 09:18:14 +00:00
committed by assistant-librarian[bot]
parent 36f2ec23f5
commit 58475d3f45
21 changed files with 1860 additions and 348 deletions

View File

@@ -17,6 +17,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_executable(tile_example_grouped_conv_bwd_weight grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
# StreamK requires cross-CU coherence (StreamKCoherency), CDNA only.
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_executable(tile_example_grouped_conv_bwd_weight_streamk grouped_convolution_backward_weight_streamk.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_streamk PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
endif()
add_executable(tile_example_grouped_conv_bwd_weight_two_stage grouped_convolution_backward_weight_two_stage.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})

View File

@@ -17,7 +17,7 @@
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker;
using Invoker = GroupedConvolutionBackwardWeightInvoker<>;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");

View File

@@ -2,7 +2,28 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
/// @brief Partitioner policies for the conv bwd weight invoker.
/// SplitKPartitionerPolicy is the default (data-parallel + split-K).
/// StreamKPartitionerPolicy selects StreamK work distribution.
struct SplitKPartitionerPolicy
{
template <typename GemmShape, typename GroupedConvTraitsType>
using type = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
};
template <ck_tile::StreamKReductionStrategy ReductionStrategy, bool Persistent = false>
struct StreamKPartitionerPolicy
{
template <typename GemmShape, typename>
using type = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, Persistent>;
};
template <typename PartitionerPolicy = SplitKPartitionerPolicy>
struct GroupedConvolutionBackwardWeightInvoker
{
template <ck_tile::index_t NDimSpatial,
@@ -40,10 +61,8 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using TilePartitioner =
typename PartitionerPolicy::template type<GemmShape, GroupedConvTraitsType>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
@@ -103,7 +122,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -111,6 +130,12 @@ struct GroupedConvolutionBackwardWeightInvoker
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
// Workspace: may be non-zero for StreamK (depends on SK/DP tile split),
// always zero for Split-K.
auto ws_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_dev(ws_size);
Kernel::SetWorkSpacePointer(kargs, workspace_dev.GetDeviceBuffer());
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
@@ -120,14 +145,25 @@ struct GroupedConvolutionBackwardWeightInvoker
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "workspace: " << ws_size << " bytes" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(kargs.k_batch > 1)
if constexpr(Kernel::IsStreamK)
{
// StreamK: zero workspace flags before each kernel launch
if(ws_size > 0)
{
ck_tile::hip_check_error(
hipMemsetAsync(workspace_dev.GetDeviceBuffer(), 0, ws_size, s.stream_id_));
}
}
else if(kargs.k_batch > 1)
{
// Split-K: zero weight buffer for atomic accumulation
ck_tile::hip_check_error(hipMemsetAsync(
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
}

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
template <template <typename PrecType> typename ConvConfig, typename Invoker>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type for this operation!");
}
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] =
create_args(argc,
argv,
{
{"streamk_reduction", "tree", "StreamK reduction strategy: linear or tree"},
{"streamk_persistent", "0", "Use persistent DP (1) or non-persistent (0)"},
});
if(!result)
return -1;
try
{
const std::string reduction = arg_parser.get_str("streamk_reduction");
const bool persistent = arg_parser.get_int("streamk_persistent") != 0;
// Dispatch on reduction strategy × persistent DP
if(reduction == "linear" && !persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Linear, false>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "linear" && persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Linear, true>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "tree" && !persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Tree, false>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "tree" && persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Tree, true>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else
{
std::cerr << "Unknown streamk_reduction: " << reduction
<< ". Use 'linear' or 'tree'.\n";
return EXIT_FAILURE;
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -4,6 +4,8 @@
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
@@ -80,7 +82,10 @@ ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_s
return n_dim_sp;
}
auto create_args(int argc, char* argv[])
auto create_args(
int argc,
char* argv[],
const std::vector<std::tuple<std::string, std::string, std::string>>& extra_args = {})
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("g", "2", "group dimension")
@@ -124,6 +129,12 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format");
// Allow per-binary CLI customization (e.g., StreamK adds --streamk_reduction).
for(const auto& [key, default_val, help] : extra_args)
{
arg_parser.insert(key, default_val, help);
}
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}