[rocm-libraries] ROCm/rocm-libraries#5393 (commit d51b649)

[CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution
backward weight kernel. Split-K divides the K dimension uniformly into a
fixed `k_batch`, so the launch grid is `num_tiles * k_batch` workgroups;
when that grid doesn't evenly fill the GPU, part of the hardware sits
idle. StreamK instead distributes the total number of K-iterations
evenly across a grid sized to the hardware, improving utilization on
these shapes.
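
For intuition, here is a minimal host-side sketch of the two
partitioning schemes (hypothetical helper names, not the actual
`StreamKTilePartitioner` API; the real partitioner exposes the same idea
through `get_iters_per_sk_cta` and related accessors):

```cpp
#include <cstdint>
#include <utility>

// Split-K: every output tile is split into exactly k_batch chunks along K,
// regardless of how many workgroups that produces relative to the CU count.
inline int64_t splitk_grid(int64_t num_tiles, int64_t k_batch)
{
    return num_tiles * k_batch;
}

// StreamK: the flattened iteration space (num_tiles * iters_per_tile) is divided
// as evenly as possible over a grid sized to the hardware (e.g. num_cu * occupancy).
// Each CTA gets a contiguous [start, end) range and may cross tile boundaries,
// which is what the Linear/Tree reductions described below resolve.
inline std::pair<int64_t, int64_t>
streamk_iter_range(int64_t cta, int64_t grid, int64_t total_iters)
{
    const int64_t base  = total_iters / grid; // iterations every CTA gets
    const int64_t extra = total_iters % grid; // first `extra` CTAs get one more
    const int64_t start = cta * base + (cta < extra ? cta : extra);
    const int64_t end   = start + base + (cta < extra ? 1 : 0);
    return {start, end};
}
```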

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel,
selected by the `TilePartitioner_` template parameter. Two reduction
strategies are supported (a simplified host-side model of both follows
the list):
- **Linear**: the tile-starting CTA sequentially accumulates the
partials stored by the other contributing CTAs
- **Tree**: pairwise binary-tree reduction (O(log n) depth, faster when
many CTAs contribute)
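
A simplified host-side model of the two strategies (illustrative only;
the kernel performs the same reductions across workgroups through the
partial store/load workspace and flag signaling shown in the diff):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Linear: the tile-starting CTA waits for and adds partials 1..n-1 in order.
// Depth is O(n) dependent steps.
inline float reduce_linear(const std::vector<float>& partials)
{
    assert(!partials.empty());
    float acc = partials[0];
    for(std::size_t i = 1; i < partials.size(); ++i)
        acc += partials[i];
    return acc;
}

// Tree: at each round (stride = 1, 2, 4, ...) every CTA whose local index is a
// multiple of 2*stride accumulates its partner at +stride; the others store their
// partial and exit. Depth is O(log n) dependent steps, which is why Tree wins
// when many CTAs contribute to the same tile.
inline float reduce_tree(std::vector<float> partials)
{
    assert(!partials.empty());
    for(std::size_t stride = 1; stride < partials.size(); stride <<= 1)
        for(std::size_t i = 0; i + stride < partials.size(); i += 2 * stride)
            partials[i] += partials[i + stride];
    return partials[0];
}
```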

Both persistent and non-persistent data-parallel (DP) sections are
supported.

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution
path with `RunStreamK`/`RunStreamKLoop`, partial store/load via
workspace, flag-based cross-CTA synchronization, and
`GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions (host-side
usage sketched after this list)
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers)
and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by
both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy`
template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and
`--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept,
`TilePartitionerType` factory helper, `InstanceTraits` with StreamK
fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP),
negative, builder regression
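
A rough sketch of how the new host-side pieces fit together for a
StreamK launch (assumptions: `Kernel` is a
`GroupedConvolutionBackwardWeightKernel` specialized with a
`StreamKTilePartitioner`, `launch_streamk` is a hypothetical wrapper,
and the actual kernel launch plus workspace cleanup are elided):

```cpp
#include <stdexcept>
#include <hip/hip_runtime.h>
// ck_tile kernel/host headers are assumed to be included.

template <typename Kernel, typename HostArgs>
void launch_streamk(const HostArgs& host_args)
{
    // Embeds the StreamK tile partitioner (sized from num_cu * occupancy,
    // or queried from the device when left at 0) into the kernel args.
    auto kargs = Kernel::MakeKernelArgs(host_args /*, num_cu, occupancy*/);

    // StreamK needs device workspace for cross-CTA partial accumulators,
    // one region per group; the size is known only once the partitioner is set.
    void* workspace      = nullptr;
    const auto ws_bytes  = Kernel::GetWorkSpaceSize(kargs);
    if(ws_bytes > 0)
        ck_tile::hip_check_error(hipMalloc(&workspace, ws_bytes));
    Kernel::SetWorkSpacePointer(kargs, workspace);

    if(!Kernel::IsSupportedArgument(kargs))
        throw std::runtime_error("unsupported StreamK configuration");

    const dim3 grid  = Kernel::GridSize(kargs); // (sk_grid.x, GemmBatch, 1)
    const dim3 block = Kernel::BlockSize();
    // ... launch Kernel{} on (grid, block) via the usual ck_tile helpers,
    // then free the workspace after the kernel completes ...
    (void)grid;
    (void)block;
}
```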

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 tile: Split-K | 16x64 tile: StreamK | 128x128 tile: Split-K | 128x128 tile: StreamK |
|---|---|---|---|---|
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on tile config: with larger tiles (fewer output
tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up
to 1.16x on typical large-channel convolutions. Tree reduction
consistently outperforms Linear when multiple CTAs contribute to the
same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n)
sequential accumulation. The table reports the best of Linear and Tree
for each shape.
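
To make the "fewer output tiles" point concrete, a rough worked example
for the 1x1 256x64 N=32 56x56 shape (the arithmetic below assumes the
usual backward-weight GEMM view GemmM = K, GemmN = C * Y * X,
GemmK = N * Ho * Wo, and that 256x64 means K x C; this is an assumption
for illustration, not measured data):

```cpp
// 1x1 filter, assumed K=256, C=64, N=32, 56x56 output, 128x128 output tiles.
constexpr int GemmM = 256;
constexpr int GemmN = 64 * 1 * 1;
constexpr int GemmK = 32 * 56 * 56; // 100352 K-elements to reduce per tile
constexpr int tiles = ((GemmM + 127) / 128) * ((GemmN + 127) / 128); // 2 output tiles
// Split-K grid = tiles * k_batch: even at k_batch = 32 that is only 64 workgroups,
// far fewer than the CUs on an MI355X, so most of the GPU idles. StreamK instead
// spreads the K-iterations of those 2 tiles over a grid sized to the CU count.
```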

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace
size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group,
stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher
occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference
comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK
tests pass. The full `check-builder` suite passes. Tolerances are computed
dynamically using the `calculate_rtol_atol` pattern (fp16 ULP-aware).
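
For reference, the tolerance idea in a minimal form (an illustrative
sketch of the general ULP-aware pattern with a hypothetical helper name;
this is not the actual `calculate_rtol_atol` signature): the rounding
error of a long fp16 accumulation grows with the reduction length, so
both tolerances are derived from GemmK rather than hard-coded.

```cpp
#include <cmath>
#include <cstddef>
#include <utility>

// Hypothetical helper, for illustration only. Scales the tolerance with the
// number of fp16 rounding steps in the accumulation (square-root growth assumes
// roughly random rounding) instead of using one fixed threshold for all shapes.
inline std::pair<double, double>
ulp_aware_tolerance(std::size_t reduction_length, double max_abs_result)
{
    constexpr double fp16_eps = 0.0009765625; // 2^-10, one fp16 ULP at 1.0
    const double rtol = fp16_eps * std::sqrt(static_cast<double>(reduction_length));
    const double atol = rtol * max_abs_result; // absolute slack scaled to the data
    return {rtol, atol};
}
```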

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit 58475d3f45 (parent 36f2ec23f5), authored by Johannes Graner on
2026-03-27 09:18:14 +00:00 and committed by assistant-librarian[bot].
21 changed files with 1860 additions and 348 deletions.


@@ -16,12 +16,25 @@
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
#endif
namespace ck_tile {
template <typename T>
struct is_streamk_partitioner : std::false_type
{
};
template <typename Shape, StreamKReductionStrategy S, bool P>
struct is_streamk_partitioner<StreamKTilePartitioner<Shape, S, P>> : std::true_type
{
};
template <typename... Args>
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
{
@@ -32,7 +45,7 @@ CK_TILE_HOST void LogInfo(Args&&... args) noexcept
}
/// @brief The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_>
template <typename GroupedConvTraitsType_, typename TilePartitioner_ = void>
struct GroupedConvBwdWeightKernelArgs
{
@@ -354,6 +367,23 @@ struct GroupedConvBwdWeightKernelArgs
long_index_t group_stride_a;
long_index_t group_stride_b;
long_index_t group_stride_c;
void* workspace_ptr = nullptr;
// StreamK tile partitioner — stored directly when TilePartitioner_ is a real type,
// empty struct when void (Split-K path). Constructed with dummy values here;
// properly initialized in MakeKernelArgs before device-side use.
struct EmptyPartitioner
{
};
using PartitionerType =
std::conditional_t<std::is_void_v<TilePartitioner_>, EmptyPartitioner, TilePartitioner_>;
PartitionerType tile_partitioner = []() {
if constexpr(std::is_void_v<TilePartitioner_>)
return EmptyPartitioner{};
else
return TilePartitioner_(1, 1, 1, 1);
}();
};
/// @brief The Grouped Convolution Backward Weight kernel template.
@@ -424,10 +454,15 @@ struct GroupedConvolutionBackwardWeightKernel
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using GroupedConvBwdWeightKernelArgsSpecialized =
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
static constexpr bool IsSplitKSupported = true;
static constexpr bool IsStreamK = is_streamk_partitioner<TilePartitioner>::value;
using GroupedConvBwdWeightKernelArgsSpecialized =
std::conditional_t<IsStreamK,
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_, TilePartitioner>,
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -442,6 +477,32 @@ struct GroupedConvolutionBackwardWeightKernel
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
"Not supported!");
static_assert(!IsStreamK || NumDTensor == 0,
"D tensor per-group offsets not implemented for StreamK path");
// StreamK reduction helpers (partial store/load, flag signaling, tile accumulation).
// Shared with the StreamK GEMM kernel via StreamKReductionOps in streamk_common.hpp.
using StreamKOps = StreamKReductionOps<TilePartitioner,
GemmPipeline,
GroupedConvBwdWeightKernelArgsSpecialized>;
CK_TILE_HOST static index_t
GetWorkSpaceSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr(IsStreamK)
return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType)) * kargs.GemmBatch;
else
return 0;
}
// Post-construction setter: workspace is allocated by the caller after
// GetWorkSpaceSize() and must outlive the kernel launch. Can't be moved into
// the constructor because kargs is a POD value type copied to GPU constant memory.
CK_TILE_HOST static void SetWorkSpacePointer(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
void* workspace_ptr)
{
kargs.workspace_ptr = workspace_ptr;
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -463,7 +524,8 @@ struct GroupedConvolutionBackwardWeightKernel
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
GroupedConvTraitsType_::ExplicitGemm,
IsStreamK ? "StreamK" : "SplitK"
);
// clang-format on
}
@@ -483,11 +545,17 @@ struct GroupedConvolutionBackwardWeightKernel
}
#endif
CK_TILE_HOST static constexpr auto
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
CK_TILE_HOST static auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
return dim3(
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
if constexpr(IsStreamK)
{
auto sk_grid = kargs.tile_partitioner.grid_size();
return dim3(sk_grid.x, kargs.GemmBatch, 1);
}
else
return dim3(TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN),
kargs.GemmBatch,
kargs.k_batch);
}
CK_TILE_HOST static constexpr auto BlockSize()
@@ -495,8 +563,10 @@ struct GroupedConvolutionBackwardWeightKernel
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
}
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
CK_TILE_HOST static GroupedConvBwdWeightKernelArgsSpecialized
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs,
[[maybe_unused]] int num_cu = 0,
[[maybe_unused]] int occupancy = 0)
{
LogInfo("MPerBlock: ",
number<TilePartitioner::MPerBlock>{},
@@ -507,18 +577,42 @@ struct GroupedConvolutionBackwardWeightKernel
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
if constexpr(IsStreamK)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
kernel_args);
kernel_args.k_batch = optimal_split_k;
// StreamK: construct tile partitioner and embed it in the args.
// Use provided num_cu/occupancy, or query HW.
if(num_cu == 0)
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
}
if(occupancy == 0)
occupancy = 1; // conservative default; caller may use hipOccupancy API
const index_t grid = num_cu * occupancy;
kernel_args.tile_partitioner =
TilePartitioner(kernel_args.GemmM, kernel_args.GemmN, kernel_args.GemmK, grid);
kernel_args.k_batch = 1; // StreamK does its own K distribution
}
else
{
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
TilePartitioner_,
GemmPipeline_,
EpiloguePipeline_>;
// Negative k_batch value: split-K autodeduction.
if(kernel_args.k_batch < 0)
{
const auto optimal_split_k =
calculate_optimal_k_batch<GemmPipeline_::BlockSize,
KernelImpl,
TilePartitioner_>(kernel_args);
kernel_args.k_batch = optimal_split_k;
}
}
return kernel_args;
@@ -539,6 +633,21 @@ struct GroupedConvolutionBackwardWeightKernel
return false;
}
}
// Runtime arch check — complements the static_assert in operator().
// Both are needed: this check runs on the host (where get_compiler_target()
// isn't available since HIP's host pass doesn't define __gfx*__ macros),
// while the static_assert in operator() catches misuse at device compile time.
if constexpr(IsStreamK)
{
const auto name = get_device_name();
if(name != "gfx90a" && name != "gfx942" && name != "gfx950")
{
LogInfo("StreamK requires cross-CU buffer coherence. "
"Supported: gfx90a, gfx942, gfx950. Got: ",
name);
return false;
}
}
if(kargs.k_batch < 1)
{
LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
@@ -574,17 +683,20 @@ struct GroupedConvolutionBackwardWeightKernel
}
}
if(integer_divide_ceil(kargs.GemmK,
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
if constexpr(!IsStreamK)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,
", BlockGemmShape K: ",
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
", k_batch: ",
kargs.k_batch);
return false;
if(integer_divide_ceil(kargs.GemmK,
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
kargs.k_batch)
{
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
kargs.GemmK,
", BlockGemmShape K: ",
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
", k_batch: ",
kargs.k_batch);
return false;
}
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
@@ -929,9 +1041,239 @@ struct GroupedConvolutionBackwardWeightKernel
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
}
CK_TILE_DEVICE void RunStreamK(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
// Device-side compile-time arch check — complements the runtime check in
// IsSupportedArgument(). Both are needed: the runtime check runs on the host
// (where get_compiler_target() isn't available since HIP's host pass doesn't
// define __gfx*__ macros), while this catches misuse at device compile time.
static_assert(
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE !=
amd_buffer_coherence_enum::coherence_default,
"StreamK requires cross-CU buffer coherence (StreamKCoherency specialization). "
"Currently supported: gfx90a, gfx942, gfx950.");
__shared__ char smem_ptr[GetSmemSize()];
// Group offset (blockIdx.y = group batch index)
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
// Offset workspace per group so groups don't interfere.
// Safe to mutate kargs: on GPU each workgroup operates on its own
// register-local copy of the kernel arguments.
const auto per_group_ws_size =
kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
kargs.workspace_ptr =
static_cast<char*>(kargs.workspace_ptr) + blockIdY * per_group_ws_size;
const index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
StreamKDispatch(
kargs.tile_partitioner,
[&](index_t tile_idx) {
// Data-parallel workgroup: process one full tile
const auto tile_mn = kargs.tile_partitioner.get_output_tile_index(tile_idx);
const index_t i_m =
amd_wave_read_first_lane(tile_mn[I0] * TilePartitioner::MPerBlock);
const index_t i_n =
amd_wave_read_first_lane(tile_mn[I1] * TilePartitioner::NPerBlock);
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
dp_num_loop,
i_m,
i_n,
/*block_idx_k=*/0);
},
[&](index_t sk_cta_idx) {
RunStreamKLoop(kargs, sk_cta_idx, a_ptr, b_ptr, c_ptr, smem_ptr);
});
}
/// @brief Stream-K loop: iterate over assigned K-iterations, run GEMM pipeline,
/// and perform Linear or Tree reduction to accumulate partial results.
CK_TILE_DEVICE void RunStreamKLoop(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
index_t sk_cta_idx,
const OutDataType* a_ptr,
const InDataType* b_ptr,
WeiDataType* c_ptr,
char* smem_ptr) const
{
const StreamKOps sk_ops{};
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, sk_cta_idx);
while(iter_start < iter_end)
{
index_t tile_idx =
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
index_t tile_iter_start, tile_iter_end;
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
index_t local_iter_start = amd_wave_read_first_lane(
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
index_t local_iter_end =
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
tile_iter_start, iter_end, tile_iter_end));
index_t num_loop_sk = local_iter_end - local_iter_start;
// Compute M/N tile indices from 1D tile index
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
const index_t i_m =
amd_wave_read_first_lane(c_macro_tile_idx[I0] * TilePartitioner::MPerBlock);
const index_t i_n =
amd_wave_read_first_lane(c_macro_tile_idx[I1] * TilePartitioner::NPerBlock);
// K offset = local_iter_start * KPerBlock
const index_t i_k =
amd_wave_read_first_lane(local_iter_start * TilePartitioner::KPerBlock);
// Create block windows and run pipeline
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, i_m, i_k);
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, i_n, i_k);
const auto& d_block_window = MakeDBlockWindows(kargs.ds_ptr, kargs, i_m, i_n);
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop_sk, has_hot_loop, tail_num, smem_ptr);
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear)
{
// Linear Reduction: tile-starter sequentially accumulates all
// partials from subsequent CTAs in order.
if(!tile_started)
{
sk_ops.StorePartial(kargs, sk_cta_idx, c_block_tile);
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = sk_cta_idx + 1;
while(accum_iters < iter_per_tile)
{
sk_ops.WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto c_block_window_out =
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
}
}
else if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree)
{
// Tree Reduction: pairwise reduction with stride doubling.
// At each round, half the CTAs store their accumulated partial
// and exit; the other half load and accumulate from their partner.
// The tile-starter writes the final result.
auto accum_block_tile = c_block_tile;
index_t tile_local_cta_idx = amd_wave_read_first_lane(
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, sk_cta_idx));
index_t stride = amd_wave_read_first_lane(1);
for(;; stride <<= 1)
{
// Partner index is a *global* SK CTA index. This works because
// CTAs contributing to the same tile always have contiguous global
// SK CTA indices (guaranteed by the partitioner's iteration assignment).
const index_t partner_cta_idx = amd_wave_read_first_lane(sk_cta_idx + stride);
const index_t partner_start_iter = amd_wave_read_first_lane(
kargs.tile_partitioner.get_start_iter(partner_cta_idx));
bool partner_in_tile =
amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
// If the partner of the tile-starter is not in this tile,
// then all partials are accumulated — write final result.
if(tile_started && !partner_in_tile)
{
auto c_block_window_out =
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
EpiloguePipeline{}(
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
break;
}
// This CTA's turn to read from its partner and accumulate.
if(tile_local_cta_idx % (stride << 1) == 0)
{
if(partner_in_tile)
{
sk_ops.WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
sk_ops.AddBlockTile(
accum_block_tile,
sk_ops.template LoadPartial<typename BlockType::DataType>(
kargs, partner_cta_idx, c_block_tile.get_tile_distribution()));
}
}
// This CTA's turn to write its partial and exit.
else
{
sk_ops.StorePartial(kargs, sk_cta_idx, accum_block_tile);
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
break;
}
}
}
else
{
static_assert(
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree,
"Unsupported StreamK reduction strategy for conv bwd weight.");
}
// Advance to next tile
iter_start = tile_iter_end;
block_sync_lds();
}
}
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
{
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
if constexpr(IsStreamK)
{
RunStreamK(kargs);
}
else if constexpr(GroupedConvTraitsType_::ExplicitGemm)
{
CallExplicitGemm(kargs);
}