[rocm-libraries] ROCm/rocm-libraries#4399 (commit 331512e)

[CK] Fix grouped conv fwd transform for merged groups

## Motivation

[CK] Fix grouped conv fwd transform for merged groups for 1d and 3d.

## Technical Details

After optimizations for 2d there is a lack of implementation for 1d and
3d

## Test Plan

test_grouped_convnd_fwd

## Test Result

pending CI

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-02-09 15:37:36 +00:00
committed by assistant-librarian[bot]
parent e16789b609
commit ea6363ad78

View File

@@ -1361,11 +1361,11 @@ struct TransformConvFwdToGemm
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
make_tuple(NumGroupsToMerge, K_, FilterSizeNumType{}),
make_tuple(GStrideTensorB_, KStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K_)),
make_pass_through_transform(FilterSizeNumType{})),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -1495,16 +1495,16 @@ struct TransformConvFwdToGemm
else
{
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
make_tuple(N_, Wo_, NumGroupsToMerge, 1, K_),
make_tuple(
NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
NStrideTensorC_, WoStride_, GStrideTensorC_, GStrideTensorC_, KStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
@@ -1518,13 +1518,13 @@ struct TransformConvFwdToGemm
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
@@ -1608,21 +1608,21 @@ struct TransformConvFwdToGemm
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, 1, K_),
make_tuple(NStrideTensorC_,
DoStride_,
HoStride_,
WoStride_,
GStrideTensorC_,
KStrideTensorC_,
GStrideTensorC_));
GStrideTensorC_,
KStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(K_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
@@ -1636,13 +1636,13 @@ struct TransformConvFwdToGemm
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}