mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4406 (commit 61f9f90)
[CK] CK Tile grouped convolution direct load ## Motivation CK Tile grouped convolution forward direct load support. ## Technical Details Basic pipeline for direct load and new instances for forward for v1 and v4 pipelines. ## Test Plan test_grouped_convnd_fwd_tile ## Test Result CI pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-130
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0cafa68b6f
commit
27e0a34e0f
@@ -731,6 +731,13 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
@@ -1128,17 +1135,36 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitted_k,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k,
|
||||
group_id);
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitted_k,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k,
|
||||
group_id);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitted_k,
|
||||
i_m,
|
||||
i_n,
|
||||
i_k,
|
||||
group_id);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -508,6 +508,13 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
@@ -899,7 +906,18 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr, kargs, num_loop, i_m, i_n, i_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -654,6 +654,14 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
if(get_device_name() != "gfx950")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value) ||
|
||||
!IsSplitKSupported)
|
||||
@@ -1141,19 +1149,40 @@ struct GroupedConvolutionForwardKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
if constexpr(GemmPipeline_::Async)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
kargs.k_batch,
|
||||
i_m,
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user