composable_kernel/experimental/grouped_convolution_tile_instances/include/instance_run.inc
Bartłomiej Kocot 945849b0f5 [rocm-libraries] ROCm/rocm-libraries#6838 (commit ff7a665)
[CK_TILE] Add depthwise conv2d forward kernel (FP16/FP32) (#6838)

## Motivation

CK currently has no kernel optimized for depthwise convolution
(G=C_in=C_out, C=K=1 per group), and the existing generic paths perform
poorly for this workload. This PR adds a dedicated depthwise conv
forward kernel in CK Tile.

## Technical Details

Adds a dedicated depthwise conv2d forward op to CK Tile that performs
direct convolution rather than falling back to the generic GEMM path.
The kernel is templated on filter size, stride, and data type, and
compiled into ~60 instances covering common configurations (filter
sizes 3/5/7/9, strides 1/2, FP16/FP32). Supports both CDNA
(gfx942/gfx950) and RDNA (gfx1100/gfx1200) architectures.
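
For readers unfamiliar with the workload, the following is a minimal CPU reference sketch of a direct depthwise conv2d forward pass, where each channel is convolved with its own single filter (G=C_in=C_out, C=K=1 per group). It is not the CK Tile kernel and does not use its API: the function name `depthwise_conv2d_fwd_ref`, the CHW single-image layout, and the absence of padding and dilation are all assumptions made for illustration. The filter size and stride are template parameters only to mirror the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helper, NOT part of CK Tile.
// Assumes one image in CHW layout, square filter, no padding, no dilation.
template <typename T, int FilterSize, int Stride>
std::vector<T> depthwise_conv2d_fwd_ref(const std::vector<T>& input,  // [C][H][W]
                                        const std::vector<T>& weight, // [C][FilterSize][FilterSize]
                                        int C,
                                        int H,
                                        int W)
{
    static_assert(FilterSize > 0 && Stride > 0);
    assert(H >= FilterSize && W >= FilterSize);
    assert(input.size() == static_cast<std::size_t>(C) * H * W);
    assert(weight.size() == static_cast<std::size_t>(C) * FilterSize * FilterSize);

    // Output extent for a "valid" convolution.
    const int Ho = (H - FilterSize) / Stride + 1;
    const int Wo = (W - FilterSize) / Stride + 1;
    std::vector<T> output(static_cast<std::size_t>(C) * Ho * Wo, T{0});

    // Each channel c reads only input channel c and filter c:
    // there is no reduction across channels, unlike a dense convolution.
    for(int c = 0; c < C; ++c)
    {
        for(int ho = 0; ho < Ho; ++ho)
        {
            for(int wo = 0; wo < Wo; ++wo)
            {
                T acc{0};
                for(int y = 0; y < FilterSize; ++y)
                {
                    for(int x = 0; x < FilterSize; ++x)
                    {
                        const int hi = ho * Stride + y;
                        const int wi = wo * Stride + x;
                        acc += input[(c * H + hi) * W + wi] *
                               weight[(c * FilterSize + y) * FilterSize + x];
                    }
                }
                output[(c * Ho + ho) * Wo + wo] = acc;
            }
        }
    }
    return output;
}
```

Calling `depthwise_conv2d_fwd_ref<float, 3, 1>(in, w, C, H, W)` returns a `C x Ho x Wo` buffer. Because the per-output reduction is only over the filter window, a GEMM formulation has a reduction dimension of a single channel per group, which is one plausible reason a direct kernel suits this workload better.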

## Test Plan

- [x] Correctness and performance validated on gfx942, gfx950, and
gfx1100, with ckProfiler `grouped_conv_fwd` as baseline.
- [ ] MI300A (gfx942) and gfx1200 validation.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
AICK-1137
2026-05-15 13:48:51 +00:00

29 lines
984 B
C++

// Resolve the kernel instance type from the SIGNATURE / ALGORITHM pair.
using Builder      = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using ConvInstance = Builder::Instance;

auto conv = ConvInstance{};

// The templated lambda makes Sig and Alg template parameters, so each
// `if constexpr` below discards the branches that do not apply.
auto result = [&]<auto Sig, auto Alg>() {
    if constexpr(ConvDirectionIsBackwardWeight<Sig>)
    {
        if constexpr(ckb::SpecifiesTileOptimizations<decltype(Alg)> && Alg.optimizations.two_stage)
        {
            // Two-stage backward weight: run with an additional elementwise op instance.
            using ElementwiseOpBuilder  = ckf::ElementwiseOpTileFactory<Sig, Alg>;
            using ElementwiseOpInstance = ElementwiseOpBuilder::Instance;

            auto elementwise_op = ElementwiseOpInstance{};
            return ckt::run(conv, elementwise_op, args, inputs, outputs, s_conf);
        }
        else
        {
            return ckt::run(conv, args, inputs, outputs, s_conf);
        }
    }
    else
    {
        return ckt::run(conv, args, inputs, outputs, s_conf);
    }
}.template operator()<SIGNATURE, ALGORITHM>();

return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());