mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4399 (commit 331512e)
[CK] Fix grouped conv fwd transform for merged groups

## Motivation
[CK] Fix the grouped conv fwd transform for merged groups for 1D and 3D.

## Technical Details
After the optimizations made for 2D, the corresponding implementation was missing for 1D and 3D.

## Test Plan
test_grouped_convnd_fwd

## Test Result
Pending CI.

## Submission Checklist
- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
e16789b609
commit
ea6363ad78
@@ -1361,11 +1361,11 @@ struct TransformConvFwdToGemm
|
||||
{
|
||||
|
||||
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
|
||||
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
|
||||
make_tuple(NumGroupsToMerge, K_, FilterSizeNumType{}),
|
||||
make_tuple(GStrideTensorB_, KStrideTensorB_, CStrideTensorB_));
|
||||
return transform_tensor_descriptor(
|
||||
wei_gemmn_groups_gemmk_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
|
||||
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
|
||||
make_pass_through_transform(FilterSizeNumType{})),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
@@ -1495,16 +1495,16 @@ struct TransformConvFwdToGemm
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_tuple(N_, Wo_, NumGroupsToMerge, 1, K_),
|
||||
make_tuple(
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
|
||||
NStrideTensorC_, WoStride_, GStrideTensorC_, GStrideTensorC_, KStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only the matrices on the diagonal. XOR returns 0 for the same
|
||||
@@ -1518,13 +1518,13 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
@@ -1608,21 +1608,21 @@ struct TransformConvFwdToGemm
|
||||
else
|
||||
{
|
||||
const auto nhwo_groups_k_1_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
|
||||
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, 1, K_),
|
||||
make_tuple(NStrideTensorC_,
|
||||
DoStride_,
|
||||
HoStride_,
|
||||
WoStride_,
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_,
|
||||
GStrideTensorC_));
|
||||
GStrideTensorC_,
|
||||
KStrideTensorC_));
|
||||
// Pad 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
nhwo_groups_k_1_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(K_),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
|
||||
make_pad_transform(1, 0, NumGroupsToMerge - 1),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// We need only the matrices on the diagonal. XOR returns 0 for the same
|
||||
@@ -1636,13 +1636,13 @@ struct TransformConvFwdToGemm
|
||||
make_tuple(make_pass_through_transform(NDoHoWo),
|
||||
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
|
||||
make_pass_through_transform(K_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
|
||||
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user