[rocm-libraries] ROCm/rocm-libraries#5334 (commit bb5a3c8)

[CK][CK Tile] Improve access for merged groups and remove
 modulo from xor (#5334)

## Motivation

[CK][CK Tile] Improve access for merged groups and remove modulo from
xor

## Technical Details

- add template parameter to xor if modulo is needed. We don't need
modulo for merged groups
- use access by m for merged groups for a tensor
-
## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

passed locally

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-20 15:47:22 +00:00
committed by assistant-librarian[bot]
parent fd8714aea9
commit db40d3f517
5 changed files with 96 additions and 42 deletions

View File

@@ -108,7 +108,9 @@ struct GroupedConvTraits
using OutLayout = OutLayout_;
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using AsLayoutFwd = std::conditional_t<NumGroupsToMerge == 1,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts

View File

@@ -518,10 +518,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
@@ -652,10 +654,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
@@ -788,10 +792,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N

View File

@@ -1363,9 +1363,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N
@@ -1429,9 +1431,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N
@@ -1496,9 +1500,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N