[rocm-libraries] ROCm/rocm-libraries#5334 (commit bb5a3c8)

[CK][CK Tile] Improve access for merged groups and remove
 modulo from xor (#5334)

## Motivation

Improve the memory access pattern for merged convolution groups in the CK Tile
grouped-convolution kernels, and remove the unnecessary modulo operation from
the xor transform used when merging groups.

## Technical Details

- Add a template parameter to the xor transform that controls whether the
modulo operation is applied; the modulo is not needed for merged groups, so
those call sites now instantiate the transform with it disabled.
- For the A tensor, use access by M (instead of access by K) when groups are
merged (`NumGroupsToMerge > 1`).
## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

`test_grouped_convnd_fwd_tile` passed locally.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-20 15:47:22 +00:00
committed by assistant-librarian[bot]
parent fd8714aea9
commit db40d3f517
5 changed files with 96 additions and 42 deletions

View File

@@ -513,7 +513,9 @@ struct GroupedConvolutionForwardKernel
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor> ||
GroupedConvTraitsType_::NumGroupsToMerge > 1,
"Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
@@ -885,20 +887,51 @@ struct GroupedConvolutionForwardKernel
CK_TILE_DEVICE static auto
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
{
// Step 1: Create tensor view
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
if constexpr(GroupedConvTraitsType_::NumGroupsToMerge == 1)
{
// Access by K
// Step 1: Create tensor view
const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
// Step 2: Create padded view
const auto& a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
// Step 2: Create padded view
const auto& a_pad_view =
pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
return make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
// Step 3: Create tile window
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{block_idx_m, 0});
}
else
{
// Access by M
const auto a_desc_reversed = transform_tensor_descriptor(
a_desc,
make_tuple(make_pass_through_transform(a_desc.get_length(I0)),
make_pass_through_transform(a_desc.get_length(I1))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// Step 1: Create tensor view
const auto& a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, a_desc_reversed);
// Step 2: Create padded view
const auto& a_pad_view =
pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<true, true>{});
// Step 3: Create tile window
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, block_idx_m});
}
}
template <typename BDescType>

View File

@@ -108,7 +108,9 @@ struct GroupedConvTraits
using OutLayout = OutLayout_;
// Forward Gemm Layouts
using AsLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
using AsLayoutFwd = std::conditional_t<NumGroupsToMerge == 1,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
using BsLayoutFwd = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayoutFwd = ck_tile::tensor_layout::gemm::RowMajor;
// Backward Data Gemm Layouts

View File

@@ -518,10 +518,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
@@ -652,10 +654,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N
@@ -788,10 +792,12 @@ struct TransformConvBwdWeightToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_),
make_pass_through_transform(Z_ * Y_ * X_),
make_pass_through_transform(C_)),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0, 3>{}, sequence<1>{}, sequence<2>{}, sequence<4>{}));
// Merge To M, N

View File

@@ -1363,9 +1363,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N
@@ -1429,9 +1431,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N
@@ -1496,9 +1500,11 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(
make_pass_through_transform(NDoHoWo),
make_xor_transform<decltype(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
false>(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}));
// Merge To M, N