[rocm-libraries] ROCm/rocm-libraries#4406 (commit 61f9f90)

[CK] CK Tile grouped convolution direct load

## Motivation

CK Tile grouped convolution forward direct load support.

## Technical Details

Basic pipeline for direct load and new instances for forward for v1 and
v4 pipelines.

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

CI pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
AICK-130
This commit is contained in:
Bartłomiej Kocot
2026-02-09 21:09:42 +00:00
committed by assistant-librarian[bot]
parent 0cafa68b6f
commit 27e0a34e0f
29 changed files with 739 additions and 56 deletions

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>

View File

@@ -38,4 +38,12 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Default, 32, 3
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<128, 32, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 2, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 32, 64, Filter1x1Stride1Pad0, 32, 32, 2, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Default, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 128, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<64, 16, 16, 64, Filter1x1Stride1Pad0, 16, 16, 1, 1, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 64, 64, 64, Filter1x1Stride1Pad0, 16, 16, 2, 2, 2, 2, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>