[rocm-libraries] ROCm/rocm-libraries#4406 (commit 61f9f90)

[CK] CK Tile grouped convolution direct load

## Motivation

Add direct load support to CK Tile grouped convolution forward.

## Technical Details

Adds a basic pipeline for direct load, plus new forward convolution
instances for the v1 and v4 pipelines.

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

CI pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
AICK-130
This commit is contained in:
Bartłomiej Kocot
2026-02-09 21:09:42 +00:00
committed by assistant-librarian[bot]
parent 0cafa68b6f
commit 27e0a34e0f
29 changed files with 739 additions and 56 deletions

View File

@@ -731,6 +731,13 @@ struct GroupedConvolutionBackwardDataKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value)
{
@@ -1128,17 +1135,36 @@ struct GroupedConvolutionBackwardDataKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
#endif
}
else
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr,
kargs,
splitted_k,
i_m,
i_n,
i_k,
group_id);
}
}
};

View File

@@ -508,6 +508,13 @@ struct GroupedConvolutionBackwardWeightKernel
CK_TILE_HOST static bool
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if(kargs.k_batch < 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
@@ -899,7 +906,18 @@ struct GroupedConvolutionBackwardWeightKernel
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
#endif
}
else
{
RunGemm(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
}
}
}
};

View File

@@ -654,6 +654,14 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
if constexpr(GemmPipeline_::Async)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
!IsSplitKSupported)
@@ -1141,19 +1149,40 @@ struct GroupedConvolutionForwardKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
if constexpr(GemmPipeline_::Async)
{
#if defined(__gfx950__)
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
#endif
}
else
{
RunGemm(a_ptr,
b_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
kargs.k_batch,
i_m,
i_n,
kargs.elfunc);
}
}
}
};