[CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution
backward weight kernel. Split-K divides the K-dimension uniformly across
a fixed `k_batch`, which causes load imbalance when the number of output
tiles doesn't evenly fill the GPU. StreamK distributes total
K-iterations evenly across workgroups, improving utilization on these
shapes.

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel,
selected by the `TilePartitioner_` template parameter. Two reduction
strategies are supported:
- **Linear**: tile-starter sequentially accumulates partials from
contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for
many contributors)

Both persistent and non-persistent data-parallel (DP) sections are
supported.

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution
path with `RunStreamK`/`RunStreamKLoop`, partial store/load via
workspace, flag-based cross-CTA synchronization,
`GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers)
and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by
both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy`
template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and
`--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept,
`TilePartitionerType` factory helper, `InstanceTraits` with StreamK
fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP),
negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 tiles | | 128x128 tiles | |
|---|---|---|---|---|
| | Split-K | StreamK | Split-K | StreamK |
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on tile config: with larger tiles (fewer output
tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up
to 1.16x on typical large-channel convolutions. Tree reduction
consistently outperforms Linear when multiple CTAs contribute to the
same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n)
sequential accumulation. The table reports the best of Linear and Tree
for each shape.

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace
size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group,
stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher
occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference
comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK
tests pass. Full `check-builder` suite passes. Tolerances computed
dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Johannes Graner
2026-03-27 10:17:10 +01:00
committed by GitHub
parent 54272c6fa6
commit c60514f371
21 changed files with 1860 additions and 348 deletions

View File

@@ -17,6 +17,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12")
add_executable(tile_example_grouped_conv_bwd_weight grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
# StreamK requires cross-CU coherence (StreamKCoherency), CDNA only.
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_executable(tile_example_grouped_conv_bwd_weight_streamk grouped_convolution_backward_weight_streamk.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_streamk PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
endif()
add_executable(tile_example_grouped_conv_bwd_weight_two_stage grouped_convolution_backward_weight_two_stage.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})

View File

@@ -17,7 +17,7 @@
template <template <typename PrecType> typename ConvConfig>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker;
using Invoker = GroupedConvolutionBackwardWeightInvoker<>;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");

View File

@@ -2,7 +2,28 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "grouped_convolution_utils.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
/// @brief Partitioner policies for the conv bwd weight invoker.
/// SplitKPartitionerPolicy is the default (data-parallel + split-K).
/// StreamKPartitionerPolicy selects StreamK work distribution.
struct SplitKPartitionerPolicy
{
template <typename GemmShape, typename GroupedConvTraitsType>
using type = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
};
template <ck_tile::StreamKReductionStrategy ReductionStrategy, bool Persistent = false>
struct StreamKPartitionerPolicy
{
template <typename GemmShape, typename>
using type = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, Persistent>;
};
template <typename PartitionerPolicy = SplitKPartitionerPolicy>
struct GroupedConvolutionBackwardWeightInvoker
{
template <ck_tile::index_t NDimSpatial,
@@ -40,10 +61,8 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using TilePartitioner =
typename PartitionerPolicy::template type<GemmShape, GroupedConvTraitsType>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
@@ -103,7 +122,7 @@ struct GroupedConvolutionBackwardWeightInvoker
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
@@ -111,6 +130,12 @@ struct GroupedConvolutionBackwardWeightInvoker
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
// Workspace: may be non-zero for StreamK (depends on SK/DP tile split),
// always zero for Split-K.
auto ws_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_dev(ws_size);
Kernel::SetWorkSpacePointer(kargs, workspace_dev.GetDeviceBuffer());
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
@@ -120,14 +145,25 @@ struct GroupedConvolutionBackwardWeightInvoker
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "workspace: " << ws_size << " bytes" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
auto preprocess = [&]() {
if(kargs.k_batch > 1)
if constexpr(Kernel::IsStreamK)
{
// StreamK: zero workspace flags before each kernel launch
if(ws_size > 0)
{
ck_tile::hip_check_error(
hipMemsetAsync(workspace_dev.GetDeviceBuffer(), 0, ws_size, s.stream_id_));
}
}
else if(kargs.k_batch > 1)
{
// Split-K: zero weight buffer for atomic accumulation
ck_tile::hip_check_error(hipMemsetAsync(
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));
}

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_backward_weight_invoker.hpp"
#include "run_grouped_convolution_bwd_weight_example.inc"
template <template <typename PrecType> typename ConvConfig, typename Invoker>
int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
ConvConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
ConvConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type for this operation!");
}
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] =
create_args(argc,
argv,
{
{"streamk_reduction", "tree", "StreamK reduction strategy: linear or tree"},
{"streamk_persistent", "0", "Use persistent DP (1) or non-persistent (0)"},
});
if(!result)
return -1;
try
{
const std::string reduction = arg_parser.get_str("streamk_reduction");
const bool persistent = arg_parser.get_int("streamk_persistent") != 0;
// Dispatch on reduction strategy × persistent DP
if(reduction == "linear" && !persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Linear, false>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "linear" && persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Linear, true>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "tree" && !persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Tree, false>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else if(reduction == "tree" && persistent)
{
using Invoker = GroupedConvolutionBackwardWeightInvoker<
StreamKPartitionerPolicy<ck_tile::StreamKReductionStrategy::Tree, true>>;
return !run_grouped_conv_bwd_weight_example<ConvConfigComputeV3, Invoker>(arg_parser);
}
else
{
std::cerr << "Unknown streamk_reduction: " << reduction
<< ". Use 'linear' or 'tree'.\n";
return EXIT_FAILURE;
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -4,6 +4,8 @@
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
@@ -80,7 +82,10 @@ ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_s
return n_dim_sp;
}
auto create_args(int argc, char* argv[])
auto create_args(
int argc,
char* argv[],
const std::vector<std::tuple<std::string, std::string, std::string>>& extra_args = {})
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("g", "2", "group dimension")
@@ -124,6 +129,12 @@ auto create_args(int argc, char* argv[])
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format");
// Allow per-binary CLI customization (e.g., StreamK adds --streamk_reduction).
for(const auto& [key, default_val, help] : extra_args)
{
arg_parser.insert(key, default_val, help);
}
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -470,4 +470,17 @@ concept SpecifiesDlEpilogue = requires {
{ T::transfer.c } -> DlEpilogueDescriptor;
};
// Concept to detect StreamK configuration in a tile algorithm descriptor.
template <typename T>
concept StreamKDescriptor = requires(T t) {
{ t.reduction_strategy } -> std::convertible_to<StreamKReductionStrategy>;
{ t.persistent } -> std::convertible_to<bool>;
};
// Concept to check if a tile algorithm specifies StreamK work distribution.
template <typename T>
concept SpecifiesStreamK = requires {
{ T::streamk } -> StreamKDescriptor;
};
} // namespace ck_tile::builder

View File

@@ -65,10 +65,8 @@ struct ConvTileFactory
ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
using TilePartitioner =
typename internal::TilePartitionerType<ALGORITHM, GemmShape, GroupedConvTraitsType>::type;
using ConvOutDataType = std::conditional_t<OPTIMIZATIONS.two_stage,
typename Types::AccDataType,

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
@@ -186,4 +187,38 @@ consteval TileOptimizations SetTileOptimizations()
.two_stage = OPT.two_stage};
}
// Maps builder StreamKReductionStrategy to ck_tile::StreamKReductionStrategy.
consteval ck_tile::StreamKReductionStrategy
MapStreamKReductionStrategy(StreamKReductionStrategy strategy)
{
switch(strategy)
{
case StreamKReductionStrategy::LINEAR: return ck_tile::StreamKReductionStrategy::Linear;
case StreamKReductionStrategy::TREE: return ck_tile::StreamKReductionStrategy::Tree;
default: throw "Unknown StreamKReductionStrategy";
}
}
// Selects the tile partitioner type based on whether the algorithm specifies StreamK.
// Usage: typename TilePartitionerType<ALGORITHM, GemmShape, ConvTraitsType>::type
template <ConvAlgorithmDescriptor auto ALGORITHM, typename GemmShape_, typename ConvTraitsType_>
struct TilePartitionerType
{
using type = ck_tile::GemmSpatiallyLocalTilePartitioner<
GemmShape_,
ConvTraitsType_::FixedGemmParams::TilePartitionerGroupNum,
ConvTraitsType_::FixedGemmParams::TilePartitionerM01>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM, typename GemmShape_, typename ConvTraitsType_>
requires SpecifiesStreamK<decltype(ALGORITHM)>
struct TilePartitionerType<ALGORITHM, GemmShape_, ConvTraitsType_>
{
static constexpr auto CK_STRATEGY =
MapStreamKReductionStrategy(ALGORITHM.streamk.reduction_strategy);
static constexpr bool PERSISTENT = ALGORITHM.streamk.persistent;
using type = ck_tile::StreamKTilePartitioner<GemmShape_, CK_STRATEGY, PERSISTENT>;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -66,14 +66,18 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
static constexpr int kMPerBlock = TilePartitioner_::MPerBlock;
static constexpr int kNPerBlock = TilePartitioner_::NPerBlock;
static constexpr int kKPerBlock = TilePartitioner_::KPerBlock;
// Partitioner — detect StreamK by checking for PERSISTENT member
static constexpr bool kIsStreamK = requires { TilePartitioner_::PERSISTENT; };
static constexpr int kMWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<0>{});
static constexpr int kNWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<1>{});
static constexpr int kKWarp = TilePartitioner_::BlockGemmShape::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = TilePartitioner_::BlockGemmShape::WarpTile::at(number<2>{});
// Warp configuration — sourced from pipeline's BlockGemmShape (works for both
// GemmSpatiallyLocalTilePartitioner and StreamKTilePartitioner).
using BlockGemmShape_ = typename GemmPipeline_::BlockGemmShape;
static constexpr int kMWarp = BlockGemmShape_::BlockWarps::at(number<0>{});
static constexpr int kNWarp = BlockGemmShape_::BlockWarps::at(number<1>{});
static constexpr int kKWarp = BlockGemmShape_::BlockWarps::at(number<2>{});
static constexpr int kMWarpTile = BlockGemmShape_::WarpTile::at(number<0>{});
static constexpr int kNWarpTile = BlockGemmShape_::WarpTile::at(number<1>{});
static constexpr int kKWarpTile = BlockGemmShape_::WarpTile::at(number<2>{});
// Data types
using ADataType = typename GemmPipeline_::ADataType;
@@ -133,6 +137,13 @@ struct InstanceTraits<ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedCon
oss << ","
<< detail::elementwise_op_name<CDEElementwiseOperation>(); // 31.
// CDEElementwiseOperation
oss << "," << kIsStreamK; // 32. IsStreamK
if constexpr(kIsStreamK)
{
oss << ","
<< static_cast<int>(TilePartitioner_::ReductionStrategy); // 33. ReductionStrategy
oss << "," << TilePartitioner_::PERSISTENT; // 34. PersistentDP
}
oss << ">";
return oss.str();

View File

@@ -75,12 +75,20 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
auto kargs = Conv::MakeKernelArgs(host_args);
const dim3 grids = Conv::GridSize(kargs);
const dim3 blocks = Conv::BlockSize();
if(!Conv::IsSupportedArgument(kargs))
return RunResult::not_supported("unsupported ck_tile arguments");
// Workspace allocation (bwd weight only): may be non-zero for StreamK.
[[maybe_unused]] std::size_t ws_size = 0;
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
ws_size = Conv::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_dev(ws_size);
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
Conv::SetWorkSpacePointer(kargs, workspace_dev.GetDeviceBuffer());
const dim3 grids = Conv::GridSize(kargs);
const dim3 blocks = Conv::BlockSize();
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
const std::size_t zeroing_size = gemm_split_k_output_size<SIGNATURE>(kargs);
@@ -88,8 +96,18 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(kargs.k_batch > 1)
if constexpr(Conv::IsStreamK)
{
// StreamK: zero workspace flags before each kernel launch
if(ws_size > 0)
{
ck_tile::hip_check_error(hipMemsetAsync(
workspace_dev.GetDeviceBuffer(), 0, ws_size, s_conf.stream_id_));
}
}
else if(kargs.k_batch > 1)
{
// Split-K: zero weight buffer for atomic accumulation
ck_tile::hip_check_error(
hipMemsetAsync(kargs.wei_ptr,
0,

View File

@@ -241,6 +241,13 @@ enum class ConvAlgorithmSpecialization
MULTIPLE_D
};
// StreamK work distribution strategy for the tile partitioner.
enum class StreamKReductionStrategy
{
LINEAR,
TREE
};
// to_string methods for enum classes
inline std::string_view to_string(DataType dt)
{
@@ -470,6 +477,17 @@ inline std::string_view to_string(TensorLayout layout)
}
}
inline std::string_view to_string(StreamKReductionStrategy s)
{
using enum StreamKReductionStrategy;
switch(s)
{
case LINEAR: return "LINEAR";
case TREE: return "TREE";
default: return "Unknown";
}
}
// ostream operator overloads for enum classes
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << to_string(dt); }
@@ -513,4 +531,9 @@ inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
return os << to_string(layout);
}
inline std::ostream& operator<<(std::ostream& os, StreamKReductionStrategy s)
{
return os << to_string(s);
}
} // namespace ck_tile::builder

View File

@@ -189,6 +189,7 @@ set(BWD_WEIGHT_TESTS
conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp
conv/ck/test_ckb_conv_bwd_weight_dl.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_streamk.cpp
)
if (CK_USE_WMMA)

View File

@@ -0,0 +1,102 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "testing_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::FP16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NHWGC}},
.weight = {.config = {.layout = GKYXC}},
.output = {.config = {.layout = NHWGK}}};
constexpr auto ALGORITHM =
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel_StreamK{}
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
.with_tile_thread_block(cku::TileThreadBlock_128x128x32)
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
.with_tile_transfer(cku::TileTransfer_4x4x4)
.with_tile_optimizations(ckt::TileOptimizations{.num_groups_to_merge = 1,
.split_image = false,
.explicit_gemm = false,
.two_stage = false})
.with_streamk(ckt::TileStreamKConfig{
.reduction_strategy = ckb::StreamKReductionStrategy::TREE, .persistent = false});
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_2D_FP16_NHWGC_StreamK, Create)
{
cku::run_ck_tile_test<Builder>({
"grouped_convolution_backward_weight",
"fp16",
"NHWGC_GKYXC_NHWGK",
"128x128x32",
"2x2",
"16x16x16",
"Default",
"Intrawave",
"CShuffleEpilogue",
"pipeline_AgBgCrCompV3",
"DoubleSmemBuffer_0",
"NumWaveGroups_1",
"MergedGroups_1",
"SplitImage_0",
"ExplicitGemm_0",
"StreamK",
});
}
TEST(BwdWeight_2D_FP16_NHWGC_StreamK, Execution)
{
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 2,
.groups = 4,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 32, .height = 56},
.filter = {.width = 3, .height = 3},
},
.filter_strides = {.width = 1, .height = 1},
.filter_dilation = {.width = 1, .height = 1},
.input_left_pad = {.width = 0, .height = 0},
.input_right_pad = {.width = 0, .height = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -382,6 +382,15 @@ struct TileOptimizations
};
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
struct TileStreamKConfig
{
// StreamK reduction strategy (Linear or Tree).
StreamKReductionStrategy reduction_strategy;
// Use persistent DP (true) or non-persistent DP (false).
bool persistent;
};
static_assert(ckb::StreamKDescriptor<TileStreamKConfig>);
struct TileConvSpecialization_
{
TileConvSpecialization specialization;
@@ -407,6 +416,11 @@ struct TileOptimizations_
TileOptimizations optimizations;
};
struct TileStreamK_
{
TileStreamKConfig streamk;
};
// Factory
template <typename... Components>
@@ -614,6 +628,15 @@ struct ConvAlgorithmTemplate : Components...
result.optimizations = o;
return result;
}
template <typename SK>
constexpr auto with_streamk(const SK& sk) const
{
static_assert(std::is_base_of_v<TileStreamK_, ConvAlgorithmTemplate>);
auto result = *this;
result.streamk = sk;
return result;
}
};
// Fwd algorithm types
@@ -674,6 +697,15 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
TileConvSpecialization_,
TileOptimizations_>;
// CK Tile algorithm with StreamK work distribution
using ConvAlgorithm_Tile_GroupedConvolutionKernel_StreamK =
ConvAlgorithmTemplate<TileThreadBlock_,
TileBlockGemm_,
TileTransfer_,
TileConvSpecialization_,
TileOptimizations_,
TileStreamK_>;
// Reference algorithm descriptor - for GPU reference validation
// This is a simple algorithm that requires no complex configuration,
// just a specialization marker to identify it as a reference implementation.

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
namespace {
@@ -228,6 +229,130 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat)
",bf16" // EDataType
",EmptyTuple" // DsDataType
",PassThrough" // CDEElementwiseOperation
",0" // IsStreamK
">";
EXPECT_EQ(instance_str, expected_str);
}
TEST(InstanceTraits, TileStreamKInstanceStringReturnsCorrectFormat)
{
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<2 /*NDimSpatial*/,
ck_tile::ConvolutionSpecialization::Default /*ConvSpec*/,
ck_tile::tensor_layout::convolution::NHWGC /*InLayout*/,
ck_tile::tensor_layout::convolution::GKYXC /*WeiLayout*/,
ck_tile::tuple<> /*DsLayout*/,
ck_tile::tensor_layout::convolution::NHWGK /*OutLayout*/,
4 /*VectorSizeA*/,
4 /*VectorSizeB*/,
4 /*VectorSizeC*/,
1 /*NumGroupsToMerge*/,
false /*EnableSplitImage*/,
false /*ExplicitGemm*/>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>,
ck_tile::sequence<4 /*M_Warp*/, 1 /*N_Warp*/, 1 /*K_Warp*/>,
ck_tile::sequence<16 /*M_Warp_Tile*/, 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/>>;
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape,
ck_tile::StreamKReductionStrategy::Tree,
false /*Persistent*/>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GroupedConvTraitsType::FixedGemmParams::kPadM,
GroupedConvTraitsType::FixedGemmParams::kPadN,
GroupedConvTraitsType::FixedGemmParams::kPadK,
false /*DoubleSmemBuffer*/,
typename GroupedConvTraitsType::AsLayoutBwdWeight,
typename GroupedConvTraitsType::BsLayoutBwdWeight,
typename GroupedConvTraitsType::CLayoutBwdWeight,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
GroupedConvTraitsType::FixedGemmParams::Persistent,
1 /*NumWaveGroups*/>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
ck_tile::bf16_t /*OutDataType*/,
ck_tile::bf16_t /*InDataType*/,
float /*AccDataType*/,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave /*scheduler*/,
ck_tile::element_wise::PassThrough /*AElementwiseOperation*/,
ck_tile::element_wise::PassThrough /*BElementwiseOperation*/,
ck_tile::bf16_t /*WeiDataType*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeA,
GroupedConvTraitsType::VectorSizeB>;
using GemmPipeline = typename ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ck_tile::bf16_t /*OutDataType*/,
ck_tile::bf16_t /*InDataType*/,
ck_tile::tuple<> /*DsDataType*/,
float /*AccDataType*/,
ck_tile::bf16_t /*WeiDataType*/,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
ck_tile::element_wise::PassThrough /*CDElementWise*/,
128 /*MPerBlock*/,
128 /*NPerBlock*/,
4 /*M_Warp*/,
1 /*N_Warp*/,
16 /*M_Warp_Tile*/,
16 /*N_Warp_Tile*/,
16 /*K_Warp_Tile*/,
GroupedConvTraitsType::FixedGemmParams::TransposeC,
1 /*kNumWaveGroups*/,
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
GroupedConvTraitsType::VectorSizeC>>;
using GroupedConvBwdWeiKernel =
ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
std::string instance_str = ck_tile::reflect::instance_string<GroupedConvBwdWeiKernel>();
std::string expected_str = "GroupedConvolutionBackwardWeightKernel"
"<2" // NDimSpatial
",Default" // ConvSpecialization
",NHWGC" // InLayout
",GKYXC" // WeiLayout
",EmptyTuple" // DsLayout
",NHWGK" // OutLayout
",4" // VectorSizeA
",4" // VectorSizeB
",4" // VectorSizeC
",1" // NumGroupsToMerge
",0" // EnableSplitImage
",0" // ExplicitGemm
",128" // MPerBlock
",128" // NPerBlock
",32" // KPerBlock
",4" // MWarp
",1" // NWarp
",1" // KWarp
",16" // MWarpTile
",16" // NWarpTile
",16" // KWarpTile
",bf16" // ADataType
",bf16" // BDataType
",COMPUTE_V3" // BlkGemmPipelineVer
",Intrawave" // BlkGemmPipeSched
",0" // DoubleSmemBuffer
",1" // NumWaveGroups
",fp32" // AccDataType
",bf16" // EDataType
",EmptyTuple" // DsDataType
",PassThrough" // CDEElementwiseOperation
",1" // IsStreamK
",2" // ReductionStrategy (Tree=2)
",0" // PersistentDP
">";
EXPECT_EQ(instance_str, expected_str);

View File

@@ -55,6 +55,7 @@
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/numeric/ext_vector_base.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
@@ -12,4 +13,277 @@ enum StreamKReductionStrategy : uint32_t
Linear = 1u,
Tree = 2u
};
/// @brief StreamK reduction helpers: partial store/load, flag signaling, and tile accumulation.
/// Shared by StreamK GEMM and StreamK conv bwd weight kernels.
template <typename TilePartitioner_, typename GemmPipeline_, typename KernelArgs_>
struct StreamKReductionOps
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using BlockGemm = typename GemmPipeline_::BlockGemm;
using WarpGemm = typename BlockGemm::WarpGemm;
using BlockGemmShape = typename GemmPipeline_::BlockGemmShape;
/**
*@brief Signals that the current thread block(CTA) has completed storing its partial
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
// Depending on the architecture, the GLC flag will bypass the appropriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
asm volatile("s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
// Depending on the architecture, the GLC flag will bypass the
// appropriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
asm volatile("s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
const OAccTile& in_block_tile) const
{
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
constexpr auto o_spans = BlockType::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
});
});
}
/**
* @brief Loads a partial block tile from the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const KernelArgs_& kargs,
index_t cta_idx,
const OAccTileDist& c_block_tile_dist) const
{
const auto c_block_tile_buffer_size =
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
static_cast<DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
auto partials_tile = load_tile(partial_tile_window);
// Since the partials distribution is not the same as the C block distribution, we must
// describe the contents in the partials tile with the C block distribution.
// Note: The data assigned to threads does not change between distributions.
auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
c_block_tile_dist, partials_tile.get_thread_buffer());
return partials_tile_with_c_distr;
}
/**
* @brief Returns the vector size to be used for reading from and writing to partials.
* @return The vector size
*/
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
{
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
// maximum vector width
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
/**
* @brief Returns distribution used for reading from and writing to partials.
* @return The distribution.
* @note This will result in optimized reads from and writes to partials when C is row major.
* Additional functionality should be added to ensure optimized accesses to partials when C is
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
* current limitation.
*/
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
{
// Create the encoding to describe waves within a block
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Create the encoding to describe threads within a wave
constexpr index_t vector_size = GetVectorSizePartials();
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
// This inner encoding ensures that contiguous threads perform vectorized writes along the
// same row in C.
constexpr auto partials_inner_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
sequence<warp_tile_n_threads, vector_size>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
// Combine the outer and inner encoding
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
return make_static_tile_distribution(partials_dstr_encode);
}
/**
* @brief Stores a partial block tile to the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void
StorePartial(const KernelArgs_& kargs, index_t cta_idx, const OAccTile& c_block_tile) const
{
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
TilePartitioner::NPerBlock *
sizeof(typename OAccTile::DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
// Since the C block distribution is not the same as the partials distribution, we must
// describe the contents in the c_block_tile with the partials distribution.
// Note: The data assigned to threads does not change between distributions.
auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
MakePartialsDistribution(), c_block_tile.get_thread_buffer());
store_tile(partial_tile_window, c_with_partials_dist);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
};
/// @brief StreamK data-parallel (DP) dispatch: handles persistent vs non-persistent DP,
/// then delegates to the Stream-K loop. Shared by GEMM and Conv StreamK kernels.
///
/// Non-persistent: launches dp_ctas + sk_ctas workgroups. DP workgroups each process
/// one full tile; SK workgroups share the remaining tiles' K-iterations.
/// Persistent: launches num_cu * occupancy workgroups. Each loops over DP tiles
/// (round-robin), then proceeds to SK work.
///
/// @tparam TilePartitioner_ Partitioner type (persistent or non-persistent specialization).
/// @param tile_partitioner The partitioner instance from kernel args.
/// @param dp_tile_func Callable(index_t tile_idx) — processes one full DP tile.
/// @param sk_func Callable(index_t sk_cta_idx) — runs the StreamK loop for this CTA.
template <typename TilePartitioner_, typename DPTileFunc, typename SKFunc>
CK_TILE_DEVICE void
StreamKDispatch(const TilePartitioner_& tile_partitioner, DPTileFunc dp_tile_func, SKFunc sk_func)
{
const index_t block_idx = get_block_1d_id();
if constexpr(TilePartitioner_::PERSISTENT)
{
// Persistent: each workgroup loops over multiple DP tiles, then does SK work
for(index_t tile_idx = block_idx; tile_idx < tile_partitioner.get_dp_tiles();
tile_idx += tile_partitioner.get_max_active_wgs())
{
dp_tile_func(tile_idx);
block_sync_lds();
}
sk_func(block_idx);
}
else
{
// Non-persistent: dedicated DP workgroups, then dedicated SK workgroups
const index_t dp_ctas = tile_partitioner.get_dp_ctas();
if(block_idx < dp_ctas)
dp_tile_func(block_idx);
else
sk_func(block_idx - dp_ctas);
}
}
} // namespace ck_tile

View File

@@ -154,6 +154,7 @@ struct StreamKKernel
using KernelArgs = StreamKKernelArgs;
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
using StreamKOps = StreamKReductionOps<TilePartitioner, GemmPipeline, StreamKKernelArgs>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -314,231 +315,6 @@ struct StreamKKernel
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
}
/**
*@brief Signals that the current thread block(CTA) has completed storing its partial
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the approproriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// approproriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
const OAccTile& in_block_tile) const
{
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
constexpr auto o_spans = BlockType::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
});
});
}
/**
* @brief Loads a partial block tile from the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTileDist& c_block_tile_dist) const
{
const auto c_block_tile_buffer_size =
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
static_cast<DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
auto partials_tile = load_tile(partial_tile_window);
// Since the partials distribution is not the same as the C block distribution, we must
// describe the contents in the partials tile with the C block distribution.
// Note: The data assigned to threads does not change between distributions.
auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
c_block_tile_dist, partials_tile.get_thread_buffer());
return partials_tile_with_c_distr;
}
/**
* @brief Returns the vector size to be used for reading from and writing to partials.
* @return The vector size
*/
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
{
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
// maximum vector width
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
}
/**
* @brief Returns distribution used for reading from and writing to partials.
* @return The distribution.
* @note This will result in optimized reads from and writes to partials when C is row major.
* Additional functionality should be added to ensure optimized accesses to partials when C is
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
* current limitation.
*/
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
{
// Create the encoding to describe waves within a block
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// Create the encoding to describe threads within a wave
constexpr index_t vector_size = GetVectorSizePartials();
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
// This inner encoding ensures that contiguous threads perform vectorized writes along the
// same row in C.
constexpr auto partials_inner_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
sequence<warp_tile_n_threads, vector_size>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
// Combine the outer and inner encoding
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
return make_static_tile_distribution(partials_dstr_encode);
}
/**
* @brief Stores a partial block tile to the workspace buffer.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTile& c_block_tile) const
{
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
TilePartitioner::NPerBlock *
sizeof(typename OAccTile::DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GetVectorSizePartials()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
MakePartialsDistribution());
// Since the C block distribution is not the same as the partials distribution, we must
// describe the contents in the c_block_tile with the partials distribution.
// Note: The data assigned to threads does not change between distributions.
auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
MakePartialsDistribution(), c_block_tile.get_thread_buffer());
store_tile(partial_tile_window, c_with_partials_dist);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
* @brief Runs the main Stream - K algorithm.
* @param kargs Stream - K kernel arguments.
@@ -551,6 +327,7 @@ struct StreamKKernel
CK_TILE_DEVICE
void StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
{
const StreamKOps sk_ops{};
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
@@ -631,8 +408,8 @@ struct StreamKKernel
{
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
sk_ops.StorePartial(kargs, cta_idx, c_block_tile);
sk_ops.SignalStorePartialDone(kargs, cta_idx);
}
else
{
@@ -649,12 +426,12 @@ struct StreamKKernel
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
sk_ops.WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
sk_ops.AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
@@ -703,13 +480,14 @@ struct StreamKKernel
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
sk_ops.WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
@@ -717,8 +495,8 @@ struct StreamKKernel
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
sk_ops.StorePartial(kargs, cta_idx, accum_block_tile);
sk_ops.SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
@@ -739,66 +517,26 @@ struct StreamKKernel
}
/**
* @brief Entry point for the Stream-K Kernel with non-persistent DP.
* @brief Entry point for the Stream-K kernel.
*
* @par Overview
* For the Non-Persistent kernel, each data parallel workgroup will
* compute the results for their assigned macro-tile by calling `BaseGemm()`.
* The Stream-K workgroups will do their assigned work by calling
* `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
* Uses StreamKDispatch to handle both persistent and non-persistent DP sections.
* Non-persistent: dedicated DP workgroups process full tiles, then dedicated SK
* workgroups share remaining K-iterations.
* Persistent: each workgroup loops over DP tiles (round-robin), then proceeds
* to SK work.
*/
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
const index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
// Check if at the data parallel section
if(is_dp_ctas)
{
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
else
{
// Stream-K
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
}
}
/**
* @brief Entry point for the Stream-K Kernel with persistent DP.
*
* @par Overview
* For the Persistent kernel, each workgroup will first compute their
* assigned data-parallel tiles. Each data parallel tile will be computed
* by calling `BaseGemm()`. Then the workgroups will proceed with the
* Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
* in the Stream-K loop.
*/
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}
// Stream-K section
StreamKGemm(kargs, block_idx, smem_ptr_0);
StreamKDispatch(
kargs.tile_partitioner,
[&](index_t tile_idx) {
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
},
[&](index_t sk_cta_idx) { StreamKGemm(kargs, sk_cta_idx, smem_ptr_0); });
}
private:

View File

@@ -16,12 +16,25 @@
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
namespace ck_tile {
template <typename T>
struct is_streamk_partitioner : std::false_type
{
};
template <typename Shape, StreamKReductionStrategy S, bool P>
struct is_streamk_partitioner<StreamKTilePartitioner<Shape, S, P>> : std::true_type
{
};
template <typename... Args>
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
{
@@ -32,7 +45,7 @@ CK_TILE_HOST void LogInfo(Args&&... args) noexcept
}
/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
template <typename GroupedConvTraitsType_, typename TilePartitioner_ = void>
struct GroupedConvBwdWeightKernelArgs
{
@@ -354,6 +367,23 @@ struct GroupedConvBwdWeightKernelArgs
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;
void* workspace_ptr = nullptr;
// StreamK tile partitioner — stored directly when TilePartitioner_ is a real type,
// empty struct when void (Split-K path). Constructed with dummy values here;
// properly initialized in MakeKernelArgs before device-side use.
struct EmptyPartitioner
{
};
using PartitionerType =
std::conditional_t<std::is_void_v<TilePartitioner_>, EmptyPartitioner, TilePartitioner_>;
PartitionerType tile_partitioner = []() {
if constexpr(std::is_void_v<TilePartitioner_>)
return EmptyPartitioner{};
else
return TilePartitioner_(1, 1, 1, 1);
}();
};
/// @brief The Grouped Convolution Backward Weight kernel template.
@@ -424,10 +454,15 @@ struct GroupedConvolutionBackwardWeightKernel
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
static constexpr bool IsSplitKSupported = true;
static constexpr bool IsStreamK = is_streamk_partitioner<TilePartitioner>::value;
using GroupedConvBwdWeightKernelArgsSpecialized =
std::conditional_t<IsStreamK,
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_, TilePartitioner>,
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -442,6 +477,32 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Not supported!");
static_assert(!IsStreamK || NumDTensor == 0,
"D tensor per-group offsets not implemented for StreamK path");
// StreamK reduction helpers (partial store/load, flag signaling, tile accumulation).
// Shared with the StreamK GEMM kernel via StreamKReductionOps in streamk_common.hpp.
using StreamKOps = StreamKReductionOps<TilePartitioner,
GemmPipeline,
GroupedConvBwdWeightKernelArgsSpecialized>;
CK_TILE_HOST static index_t
GetWorkSpaceSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr(IsStreamK)
return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType)) * kargs.GemmBatch;
else
return 0;
}
// Post-construction setter: workspace is allocated by the caller after
// GetWorkSpaceSize() and must outlive the kernel launch. Can't be moved into
// the constructor because kargs is a POD value type copied to GPU constant memory.
CK_TILE_HOST static void SetWorkSpacePointer(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
void* workspace_ptr)
{
kargs.workspace_ptr = workspace_ptr;
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -463,7 +524,8 @@ struct GroupedConvolutionBackwardWeightKernel
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
GroupedConvTraitsType_::ExplicitGemm,
IsStreamK ? "StreamK" : "SplitK"
);
// clang-format on
}
@@ -483,11 +545,17 @@ struct GroupedConvolutionBackwardWeightKernel
}
#endif
CK_TILE_HOST static constexpr auto
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
CK_TILE_HOST static auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
if constexpr(IsStreamK)
{
auto sk_grid = kargs.tile_partitioner.grid_size();
return dim3(sk_grid.x, kargs.GemmBatch, 1);
}
else
return dim3(TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN),
kargs.GemmBatch,
kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize()
@@ -495,8 +563,10 @@ struct GroupedConvolutionBackwardWeightKernel
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
CK_TILE_HOST static GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs,
[[maybe_unused]] int num_cu = 0,
[[maybe_unused]] int occupancy = 0)
{
LogInfo("MPerBlock: ",
number<TilePartitioner::MPerBlock>{},
@@ -507,18 +577,42 @@ struct GroupedConvolutionBackwardWeightKernel
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
if constexpr(IsStreamK)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
kernel_args);
kernel_args.k_batch = optimal_split_k;
// StreamK: construct tile partitioner and embed it in the args.
// Use provided num_cu/occupancy, or query HW.
if(num_cu == 0)
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
}
if(occupancy == 0)
occupancy = 1; // conservative default; caller may use hipOccupancy API
const index_t grid = num_cu * occupancy;
kernel_args.tile_partitioner =
TilePartitioner(kernel_args.GemmM, kernel_args.GemmN, kernel_args.GemmK, grid);
kernel_args.k_batch = 1; // StreamK does its own K distribution
}
else
{
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize,
KernelImpl,
TilePartitioner_>(kernel_args);
kernel_args.k_batch = optimal_split_k;
}
}
return kernel_args;
@@ -539,6 +633,21 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
}
// Runtime arch check — complements the static_assert in operator().
// Both are needed: this check runs on the host (where get_compiler_target()
// isn't available since HIP's host pass doesn't define __gfx*__ macros),
// while the static_assert in operator() catches misuse at device compile time.
if constexpr(IsStreamK)
{
const auto name = get_device_name();
if(name != "gfx90a" && name != "gfx942" && name != "gfx950")
{
LogInfo("StreamK requires cross-CU buffer coherence. "
"Supported: gfx90a, gfx942, gfx950. Got: ",
name);
return false;
}
}
if(kargs.k_batch < 1)
{
LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
@@ -574,17 +683,20 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if(integer_divide_ceil(kargs.GemmK,
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
if constexpr(!IsStreamK)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,
", BlockGemmShape K: ",
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
", k_batch: ",
kargs.k_batch);
return false;
if(integer_divide_ceil(kargs.GemmK,
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,
", BlockGemmShape K: ",
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
", k_batch: ",
kargs.k_batch);
return false;
}
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
@@ -929,9 +1041,239 @@ struct GroupedConvolutionBackwardWeightKernel
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
CK_TILE_DEVICE void RunStreamK(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
// Device-side compile-time arch check — complements the runtime check in
// IsSupportedArgument(). Both are needed: the runtime check runs on the host
// (where get_compiler_target() isn't available since HIP's host pass doesn't
// define __gfx*__ macros), while this catches misuse at device compile time.
static_assert(
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE !=
amd_buffer_coherence_enum::coherence_default,
"StreamK requires cross-CU buffer coherence (StreamKCoherency specialization). "
"Currently supported: gfx90a, gfx942, gfx950.");
__shared__ char smem_ptr[GetSmemSize()];
// Group offset (blockIdx.y = group batch index)
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// Offset workspace per group so groups don't interfere.
// Safe to mutate kargs: on GPU each workgroup operates on its own
// register-local copy of the kernel arguments.
const auto per_group_ws_size =
kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
kargs.workspace_ptr =
static_cast<char*>(kargs.workspace_ptr) + blockIdY * per_group_ws_size;
const index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
StreamKDispatch(
kargs.tile_partitioner,
[&](index_t tile_idx) {
// Data-parallel workgroup: process one full tile
const auto tile_mn = kargs.tile_partitioner.get_output_tile_index(tile_idx);
const index_t i_m =
amd_wave_read_first_lane(tile_mn[I0] * TilePartitioner::MPerBlock);
const index_t i_n =
amd_wave_read_first_lane(tile_mn[I1] * TilePartitioner::NPerBlock);
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
dp_num_loop,
i_m,
i_n,
/*block_idx_k=*/0);
},
[&](index_t sk_cta_idx) {
RunStreamKLoop(kargs, sk_cta_idx, a_ptr, b_ptr, c_ptr, smem_ptr);
});
}
/// @brief Stream-K loop: iterate over assigned K-iterations, run GEMM pipeline,
/// and perform Linear or Tree reduction to accumulate partial results.
CK_TILE_DEVICE void RunStreamKLoop(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
index_t sk_cta_idx,
const OutDataType* a_ptr,
const InDataType* b_ptr,
WeiDataType* c_ptr,
char* smem_ptr) const
{
const StreamKOps sk_ops{};
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, sk_cta_idx);
while(iter_start < iter_end)
{
index_t tile_idx =
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
index_t tile_iter_start, tile_iter_end;
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
index_t local_iter_start = amd_wave_read_first_lane(
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
index_t local_iter_end =
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
tile_iter_start, iter_end, tile_iter_end));
index_t num_loop_sk = local_iter_end - local_iter_start;
// Compute M/N tile indices from 1D tile index
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
const index_t i_m =
amd_wave_read_first_lane(c_macro_tile_idx[I0] * TilePartitioner::MPerBlock);
const index_t i_n =
amd_wave_read_first_lane(c_macro_tile_idx[I1] * TilePartitioner::NPerBlock);
// K offset = local_iter_start * KPerBlock
const index_t i_k =
amd_wave_read_first_lane(local_iter_start * TilePartitioner::KPerBlock);
// Create block windows and run pipeline
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, i_m, i_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, i_n, i_k);
const auto& d_block_window = MakeDBlockWindows(kargs.ds_ptr, kargs, i_m, i_n);
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop_sk, has_hot_loop, tail_num, smem_ptr);
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear)
{
// Linear Reduction: tile-starter sequentially accumulates all
// partials from subsequent CTAs in order.
if(!tile_started)
{
sk_ops.StorePartial(kargs, sk_cta_idx, c_block_tile);
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = sk_cta_idx + 1;
while(accum_iters < iter_per_tile)
{
sk_ops.WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto c_block_window_out =
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
}
}
else if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree)
{
// Tree Reduction: pairwise reduction with stride doubling.
// At each round, half the CTAs store their accumulated partial
// and exit; the other half load and accumulate from their partner.
// The tile-starter writes the final result.
auto accum_block_tile = c_block_tile;
index_t tile_local_cta_idx = amd_wave_read_first_lane(
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, sk_cta_idx));
index_t stride = amd_wave_read_first_lane(1);
for(;; stride <<= 1)
{
// Partner index is a *global* SK CTA index. This works because
// CTAs contributing to the same tile always have contiguous global
// SK CTA indices (guaranteed by the partitioner's iteration assignment).
const index_t partner_cta_idx = amd_wave_read_first_lane(sk_cta_idx + stride);
const index_t partner_start_iter = amd_wave_read_first_lane(
kargs.tile_partitioner.get_start_iter(partner_cta_idx));
bool partner_in_tile =
amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
// If the partner of the tile-starter is not in this tile,
// then all partials are accumulated — write final result.
if(tile_started && !partner_in_tile)
{
auto c_block_window_out =
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
break;
}
// This CTA's turn to read from its partner and accumulate.
if(tile_local_cta_idx % (stride << 1) == 0)
{
if(partner_in_tile)
{
sk_ops.WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, partner_cta_idx, c_block_tile.get_tile_distribution()));
}
}
// This CTA's turn to write its partial and exit.
else
{
sk_ops.StorePartial(kargs, sk_cta_idx, accum_block_tile);
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
break;
}
}
}
else
{
static_assert(
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree,
"Unsupported StreamK reduction strategy for conv bwd weight.");
}
// Advance to next tile
iter_start = tile_iter_end;
block_sync_lds();
}
}
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
if constexpr(IsStreamK)
{
RunStreamK(kargs);
}
else if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
CallExplicitGemm(kargs);
}

View File

@@ -5,3 +5,9 @@
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight test_ck_tile_grouped_conv_bwd_weight.cpp)
endif()
# StreamK requires cross-CU coherence via StreamKCoherency, which only has
# specializations for CDNA architectures (gfx90a/gfx942/gfx950).
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_gtest_executable(test_ck_tile_grouped_conv_bwd_weight_streamk test_ck_tile_grouped_conv_bwd_weight_streamk.cpp)
endif()

View File

@@ -0,0 +1,641 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/fill.hpp"
using namespace ck_tile;
struct StreamKTestConvConfig
{
static constexpr index_t VectorSizeA = 4;
static constexpr index_t VectorSizeB = 8;
static constexpr index_t VectorSizeC = 8;
static constexpr index_t M_Tile = 128;
static constexpr index_t N_Tile = 128;
static constexpr index_t K_Tile = 32;
static constexpr index_t M_Warp = 2;
static constexpr index_t N_Warp = 2;
static constexpr index_t K_Warp = 1;
static constexpr index_t M_Warp_Tile = 16;
static constexpr index_t N_Warp_Tile = 16;
static constexpr index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr GemmPipeline Pipeline = GemmPipeline::COMPUTE_V3;
static constexpr index_t NumWaveGroups = 1;
static constexpr index_t NumGroupsToMerge = 1;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
};
// Build a conv bwd weight kernel type from a tile partitioner.
// Works for both StreamK and Split-K partitioners.
template <typename PrecType,
typename ConvConfig,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TilePartitioner_,
index_t NDimSpatial = 2>
struct BuildKernel
{
using GemmShape = TileGemmShape<
sequence<ConvConfig::M_Tile, ConvConfig::N_Tile, ConvConfig::K_Tile>,
sequence<ConvConfig::M_Warp, ConvConfig::N_Warp, ConvConfig::K_Warp>,
sequence<ConvConfig::M_Warp_Tile, ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile>>;
using ConvTraits = GroupedConvTraits<NDimSpatial,
ConvolutionSpecialization::Default,
InLayout,
WeiLayout,
tuple<>,
OutLayout,
ConvConfig::VectorSizeA,
ConvConfig::VectorSizeB,
ConvConfig::VectorSizeC,
ConvConfig::NumGroupsToMerge>;
using GemmUniversalTraits =
TileGemmUniversalTraits<ConvTraits::FixedGemmParams::kPadM,
ConvTraits::FixedGemmParams::kPadN,
ConvTraits::FixedGemmParams::kPadK,
ConvConfig::DoubleSmemBuffer,
typename ConvTraits::AsLayoutBwdWeight,
typename ConvTraits::BsLayoutBwdWeight,
typename ConvTraits::CLayoutBwdWeight,
ConvTraits::FixedGemmParams::TransposeC,
ConvTraits::FixedGemmParams::UseStructuredSparsity,
ConvTraits::FixedGemmParams::Persistent,
ConvConfig::NumWaveGroups>;
using UniversalGemmProblem =
UniversalGemmPipelineProblem<PrecType,
PrecType,
float,
GemmShape,
GemmUniversalTraits,
ConvConfig::Scheduler,
element_wise::PassThrough,
element_wise::PassThrough,
PrecType,
ConvTraits::FixedGemmParams::FixedVectorSize,
ConvTraits::VectorSizeA,
ConvTraits::VectorSizeB>;
using GemmPipeline_ = GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
using EpilogueProblem = CShuffleEpilogueProblem<PrecType,
PrecType,
tuple<>,
float,
PrecType,
typename ConvTraits::ImplicitGemmDsLayout,
typename ConvTraits::FixedGemmParams::ELayout,
element_wise::PassThrough,
TilePartitioner_::MPerBlock,
TilePartitioner_::NPerBlock,
ConvConfig::M_Warp,
ConvConfig::N_Warp,
ConvConfig::M_Warp_Tile,
ConvConfig::N_Warp_Tile,
ConvConfig::K_Warp_Tile,
ConvTraits::FixedGemmParams::TransposeC,
ConvConfig::NumWaveGroups,
ConvTraits::FixedGemmParams::FixedVectorSize,
ConvTraits::VectorSizeC>;
using Epilogue = CShuffleEpilogue<EpilogueProblem>;
using type = GroupedConvolutionBackwardWeightKernel<ConvTraits,
TilePartitioner_,
GemmPipeline_,
Epilogue>;
};
// Helper to create 2D host args
static GroupedConvBwdWeightHostArgs create_host_args(index_t G,
index_t N,
index_t K,
index_t C,
index_t Y,
index_t X,
index_t Hi,
index_t Wi,
index_t stride_y,
index_t stride_x,
index_t dilation_y,
index_t dilation_x,
index_t left_pad_y,
index_t left_pad_x,
index_t right_pad_y,
index_t right_pad_x,
index_t k_batch = 1)
{
auto conv_param = conv::ConvParam{2,
G,
N,
K,
C,
{Y, X},
{Hi, Wi},
{stride_y, stride_x},
{dilation_y, dilation_x},
{left_pad_y, left_pad_x},
{right_pad_y, right_pad_x}};
return GroupedConvBwdWeightHostArgs{conv_param, nullptr, nullptr, {}, nullptr, k_batch};
}
// Common type aliases
using InLayout = tensor_layout::convolution::NHWGC;
using WeiLayout = tensor_layout::convolution::GKYXC;
using OutLayout = tensor_layout::convolution::NHWGK;
using PrecType = half_t;
using TestGemmShape =
TileGemmShape<sequence<128, 128, 32>, sequence<2, 2, 1>, sequence<16, 16, 16>>;
using SplitKPartitioner = GemmSpatiallyLocalTilePartitioner<TestGemmShape, 8, 4>;
using LinearPartitioner =
StreamKTilePartitioner<TestGemmShape, StreamKReductionStrategy::Linear, false>;
using TreePartitioner =
StreamKTilePartitioner<TestGemmShape, StreamKReductionStrategy::Tree, false>;
using LinearPersistentPartitioner =
StreamKTilePartitioner<TestGemmShape, StreamKReductionStrategy::Linear, true>;
using TreePersistentPartitioner =
StreamKTilePartitioner<TestGemmShape, StreamKReductionStrategy::Tree, true>;
template <typename Partitioner>
using TestKernel = typename BuildKernel<PrecType,
StreamKTestConvConfig,
InLayout,
WeiLayout,
OutLayout,
Partitioner>::type;
// ============================================================================
// Host-side unit tests
// ============================================================================
TEST(StreamKConvBwdWeight, TypeTraitDetection)
{
EXPECT_FALSE(is_streamk_partitioner<SplitKPartitioner>::value);
EXPECT_TRUE(is_streamk_partitioner<LinearPartitioner>::value);
EXPECT_TRUE(is_streamk_partitioner<TreePartitioner>::value);
}
TEST(StreamKConvBwdWeight, KernelArgsConstruction_LinearPartitioner)
{
using Kernel = TestKernel<LinearPartitioner>;
EXPECT_TRUE(Kernel::IsStreamK);
auto host_args = create_host_args(1, 4, 128, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
EXPECT_EQ(kargs.k_batch, 1);
EXPECT_GT(kargs.GemmM, 0);
EXPECT_GT(kargs.GemmN, 0);
EXPECT_GT(kargs.GemmK, 0);
EXPECT_GT(kargs.tile_partitioner.get_max_active_wgs(), 0);
}
TEST(StreamKConvBwdWeight, KernelArgsConstruction_TreePartitioner)
{
using Kernel = TestKernel<TreePartitioner>;
EXPECT_TRUE(Kernel::IsStreamK);
auto host_args = create_host_args(1, 4, 128, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
EXPECT_EQ(kargs.k_batch, 1);
EXPECT_GT(kargs.GemmM, 0);
EXPECT_GT(kargs.GemmN, 0);
EXPECT_GT(kargs.GemmK, 0);
EXPECT_GT(kargs.tile_partitioner.get_max_active_wgs(), 0);
}
TEST(StreamKConvBwdWeight, GridSize)
{
using Kernel = TestKernel<LinearPartitioner>;
auto host_args = create_host_args(1, 4, 128, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
auto grid = Kernel::GridSize(kargs);
auto sk_grid = kargs.tile_partitioner.grid_size();
EXPECT_EQ(grid.x, sk_grid.x);
EXPECT_EQ(grid.y, static_cast<unsigned int>(kargs.GemmBatch));
EXPECT_EQ(grid.z, 1u);
}
TEST(StreamKConvBwdWeight, WorkSpaceSize)
{
using Kernel = TestKernel<LinearPartitioner>;
auto host_args = create_host_args(1, 4, 128, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
EXPECT_GT(Kernel::GetWorkSpaceSize(kargs), 0);
}
TEST(StreamKConvBwdWeight, SplitKNoWorkspace)
{
using Kernel = TestKernel<SplitKPartitioner>;
EXPECT_FALSE(Kernel::IsStreamK);
auto host_args = create_host_args(1, 4, 128, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args);
EXPECT_EQ(Kernel::GetWorkSpaceSize(kargs), 0);
}
// ============================================================================
// GPU end-to-end tests: StreamK vs Split-K=1 reference
// ============================================================================
template <typename StreamKKernelType>
static bool run_streamk_vs_splitk_test(index_t G,
index_t N,
index_t K,
index_t C,
index_t Y,
index_t X,
index_t Hi,
index_t Wi,
index_t num_cu,
index_t occupancy,
index_t stride_h = 1,
index_t stride_w = 1,
index_t dilation_h = 1,
index_t dilation_w = 1,
index_t lpad_h = 1,
index_t lpad_w = 1,
index_t rpad_h = 1,
index_t rpad_w = 1)
{
using RefKernel = TestKernel<SplitKPartitioner>;
constexpr index_t NDimSpatial = 2;
auto conv_param = conv::ConvParam{NDimSpatial,
G,
N,
K,
C,
{Y, X},
{Hi, Wi},
{stride_h, stride_w},
{dilation_h, dilation_w},
{lpad_h, lpad_w},
{rpad_h, rpad_w}};
const auto in_desc =
conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_desc =
conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_desc =
conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
HostTensor<PrecType> input(in_desc);
HostTensor<PrecType> output(out_desc);
HostTensor<PrecType> weight_ref(wei_desc);
HostTensor<PrecType> weight_streamk(wei_desc);
FillUniformDistribution<PrecType>{-1.f, 1.f}(input);
FillUniformDistribution<PrecType>{-1.f, 1.f}(output);
DeviceMem input_dev(input.get_element_space_size_in_bytes());
DeviceMem output_dev(output.get_element_space_size_in_bytes());
DeviceMem weight_ref_dev(weight_ref.get_element_space_size_in_bytes());
DeviceMem weight_streamk_dev(weight_streamk.get_element_space_size_in_bytes());
input_dev.ToDevice(input.data());
output_dev.ToDevice(output.data());
// Reference: Split-K=1
{
weight_ref_dev.SetZero();
GroupedConvBwdWeightHostArgs host_args(conv_param,
input_dev.GetDeviceBuffer(),
weight_ref_dev.GetDeviceBuffer(),
{},
output_dev.GetDeviceBuffer(),
/*k_batch=*/1);
auto kargs = RefKernel::MakeKernelArgs(host_args);
if(!RefKernel::IsSupportedArgument(kargs))
{
std::cout << "Split-K kernel does not support this shape, skipping\n";
return true;
}
auto kernel_func = make_kernel<1>(
RefKernel{}, RefKernel::GridSize(kargs), RefKernel::BlockSize(), 0, kargs);
launch_kernel(stream_config{nullptr, false}, kernel_func);
hip_check_error(hipDeviceSynchronize());
}
// StreamK under test
{
weight_streamk_dev.SetZero();
GroupedConvBwdWeightHostArgs host_args(conv_param,
input_dev.GetDeviceBuffer(),
weight_streamk_dev.GetDeviceBuffer(),
{},
output_dev.GetDeviceBuffer(),
/*k_batch=*/1);
auto kargs = StreamKKernelType::MakeKernelArgs(host_args, num_cu, occupancy);
auto ws_size = StreamKKernelType::GetWorkSpaceSize(kargs);
DeviceMem workspace_dev(ws_size);
workspace_dev.SetZero();
StreamKKernelType::SetWorkSpacePointer(kargs, workspace_dev.GetDeviceBuffer());
auto kernel_func = make_kernel<1>(StreamKKernelType{},
StreamKKernelType::GridSize(kargs),
StreamKKernelType::BlockSize(),
0,
kargs);
launch_kernel(stream_config{nullptr, false}, kernel_func);
hip_check_error(hipDeviceSynchronize());
}
weight_ref_dev.FromDevice(weight_ref.data());
weight_streamk_dev.FromDevice(weight_streamk.data());
// Compute GemmK = N * product(output_spatial_lengths) for bwd weight
const index_t GemmK = N * std::accumulate(conv_param.output_spatial_lengths_.begin(),
conv_param.output_spatial_lengths_.end(),
static_cast<index_t>(1),
std::multiplies<index_t>());
// Max accumulated value calibrates atol to the output's ULP scale.
const float max_accumulated_value =
*std::max_element(weight_ref.mData.begin(), weight_ref.mData.end());
// Tolerance follows the calculate_rtol_atol pattern from conv examples:
// (1) GEMM accumulation error: fp16 compute, fp16 output, f32 accumulator
// (2) Reduction error: accounts for fp16 output quantization differences
// when two f32 results (from different accumulation orders) round to fp16
using ComputeType = PrecType;
using AccType = float;
constexpr index_t kbatch = 1;
const auto rtol_gemm =
get_relative_threshold<ComputeType, PrecType, AccType>(integer_divide_ceil(GemmK, kbatch));
const auto atol_gemm = get_absolute_threshold<ComputeType, PrecType, AccType>(
max_accumulated_value / kbatch, integer_divide_ceil(GemmK, kbatch));
const auto rtol_reduction = get_relative_threshold<PrecType, PrecType, PrecType>(kbatch);
const auto atol_reduction =
get_absolute_threshold<PrecType, PrecType, PrecType>(max_accumulated_value, kbatch);
const double rtol = std::max(rtol_gemm, rtol_reduction);
const double atol = std::max(atol_gemm, atol_reduction);
return check_err(weight_streamk, weight_ref, "StreamK vs SplitK mismatch", rtol, atol);
}
// Linear Reduction
TEST(StreamKConvBwdWeight, Linear_EndToEnd_SmallShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 2, 1)));
}
TEST(StreamKConvBwdWeight, Linear_EndToEnd_MediumShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 8, 256, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, Linear_EndToEnd_MoreSKWork)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, Linear_EndToEnd_MultiGroup)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
2, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
// Tree Reduction
TEST(StreamKConvBwdWeight, Tree_EndToEnd_SmallShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 2, 1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_MediumShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 8, 256, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_MoreSKWork)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_MultiGroup)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
2, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
// Stride > 1 — shrinks Ho/Wo, changing the K/tile ratio and DP/SK split.
// Hi=16, Wi=16, 3x3 filter, stride=2, pad=1 → Ho=Wo=8, GemmK=N*64
TEST(StreamKConvBwdWeight, Linear_EndToEnd_Stride2)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(1,
4,
128,
128,
3,
3,
16,
16,
4,
1,
/*stride=*/2,
2,
/*dil=*/1,
1,
/*pad=*/1,
1,
1,
1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_Stride2)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(1,
4,
128,
128,
3,
3,
16,
16,
4,
1,
/*stride=*/2,
2,
/*dil=*/1,
1,
/*pad=*/1,
1,
1,
1)));
}
// Pure DP — num_tiles evenly divides grid, so sk_ctas=0.
// K=256, C=128, 3x3 → GemmM=256, GemmN=1152 → tiles=2*9=18, grid=3*1=3, 18%3=0
TEST(StreamKConvBwdWeight, Linear_EndToEnd_PureDP)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 4, 256, 128, 3, 3, 16, 16, 3, 1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_PureDP)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 4, 256, 128, 3, 3, 16, 16, 3, 1)));
}
// Single output tile — all work is SK, zero DP tiles.
// K=128, C=128, 1x1 filter, stride=1, pad=0 → GemmM=128, GemmN=128, tiles=1
TEST(StreamKConvBwdWeight, Linear_EndToEnd_SingleTile)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(1,
4,
128,
128,
1,
1,
16,
16,
4,
1,
/*stride=*/1,
1,
/*dil=*/1,
1,
/*pad=*/0,
0,
0,
0)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_SingleTile)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(1,
4,
128,
128,
1,
1,
16,
16,
4,
1,
/*stride=*/1,
1,
/*dil=*/1,
1,
/*pad=*/0,
0,
0,
0)));
}
// Large N — GemmK = 32*16*16 = 8192, many K iterations per tile.
TEST(StreamKConvBwdWeight, Linear_EndToEnd_LargeN)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 32, 128, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_LargeN)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 32, 128, 128, 3, 3, 16, 16, 4, 1)));
}
// Higher occupancy — doubles the grid, more SK CTAs share tiles.
TEST(StreamKConvBwdWeight, Linear_EndToEnd_HigherOccupancy)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 4, 2)));
}
TEST(StreamKConvBwdWeight, Tree_EndToEnd_HigherOccupancy)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 4, 2)));
}
// Persistent DP — workgroups loop over DP tiles, then do SK work.
TEST(StreamKConvBwdWeight, LinearPersistent_EndToEnd_SmallShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPersistentPartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 2, 1)));
}
TEST(StreamKConvBwdWeight, TreePersistent_EndToEnd_SmallShape)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePersistentPartitioner>>(
1, 4, 128, 128, 3, 3, 16, 16, 2, 1)));
}
TEST(StreamKConvBwdWeight, LinearPersistent_EndToEnd_MultiGroup)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<LinearPersistentPartitioner>>(
2, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
TEST(StreamKConvBwdWeight, TreePersistent_EndToEnd_MultiGroup)
{
EXPECT_TRUE((run_streamk_vs_splitk_test<TestKernel<TreePersistentPartitioner>>(
2, 4, 128, 128, 3, 3, 16, 16, 4, 1)));
}
// ============================================================================
// Negative tests: IsSupportedArgument should reject invalid shapes
// ============================================================================
// C not divisible by VectorSizeB (=8) → rejected
TEST(StreamKConvBwdWeight, IsSupportedArgument_RejectsUnalignedC)
{
using Kernel = TestKernel<LinearPartitioner>;
auto host_args = create_host_args(1, 4, 128, 100, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs));
}
// K not divisible by VectorSizeA (=4) → rejected
TEST(StreamKConvBwdWeight, IsSupportedArgument_RejectsUnalignedK)
{
using Kernel = TestKernel<TreePartitioner>;
auto host_args = create_host_args(1, 4, 103, 128, 3, 3, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1);
auto kargs = Kernel::MakeKernelArgs(host_args, /*num_cu=*/4, /*occupancy=*/1);
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs));
}