[rocm-libraries] ROCm/rocm-libraries#5334 (commit bb5a3c8)

[CK][CK Tile] Improve access for merged groups and remove
 modulo from xor (#5334)

## Motivation

[CK][CK Tile] Improve access for merged groups and remove modulo from
xor

## Technical Details

- Add a template parameter to the xor transform that controls whether modulo is applied; modulo is not needed for merged groups.
- Use access by M (instead of K) for the A tensor when groups are merged.
## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

passed locally

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-20 15:47:22 +00:00
committed by assistant-librarian[bot]
parent fd8714aea9
commit db40d3f517
5 changed files with 96 additions and 42 deletions

View File

@@ -513,7 +513,9 @@ struct GroupedConvolutionForwardKernel
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
"Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor> ||
GroupedConvTraitsType_::NumGroupsToMerge > 1,
"Not supported!");
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
@@ -885,20 +887,51 @@ struct GroupedConvolutionForwardKernel
/// @brief Builds the tile window used to load a block of the A (input) tensor.
///
/// When @c NumGroupsToMerge == 1 the A tensor is accessed by K: the window is
/// (MPerBlock, KPerBlock) and the block offset is applied on the M dimension.
/// When groups are merged (NumGroupsToMerge > 1) the A tensor is accessed by M:
/// the two dimensions of the descriptor are swapped via pass-through transforms,
/// the window becomes (KPerBlock, MPerBlock), and the block offset moves to the
/// second (M) dimension.
///
/// @param a_ptr       global-memory pointer to the A tensor data
/// @param a_desc      descriptor of the A tensor (two dimensions)
/// @param block_idx_m block index along the GEMM M dimension
/// @return tile window over the padded A tensor view for this work-group
CK_TILE_DEVICE static auto
MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m)
{
    if constexpr(GroupedConvTraitsType_::NumGroupsToMerge == 1)
    {
        // Access by K
        // Step 1: Create tensor view
        const auto& a_tensor_view = make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
        // Step 2: Create padded view (pad both M and K to the block tile size)
        const auto& a_pad_view =
            pad_tensor_view(a_tensor_view,
                            make_tuple(number<TilePartitioner::MPerBlock>{},
                                       number<TilePartitioner::KPerBlock>{}),
                            sequence<true, true>{});
        // Step 3: Create tile window, offset by this block's M index
        return make_tile_window(a_pad_view,
                                make_tuple(number<TilePartitioner::MPerBlock>{},
                                           number<TilePartitioner::KPerBlock>{}),
                                {block_idx_m, 0});
    }
    else
    {
        // Access by M: swap the descriptor's two dimensions with pass-through
        // transforms so the window iterates the A tensor in (K, M) order.
        const auto a_desc_reversed = transform_tensor_descriptor(
            a_desc,
            make_tuple(make_pass_through_transform(a_desc.get_length(I0)),
                       make_pass_through_transform(a_desc.get_length(I1))),
            make_tuple(sequence<0>{}, sequence<1>{}),
            make_tuple(sequence<1>{}, sequence<0>{}));
        // Step 1: Create tensor view over the dimension-swapped descriptor
        const auto& a_tensor_view =
            make_tensor_view<address_space_enum::global>(a_ptr, a_desc_reversed);
        // Step 2: Create padded view — note the (K, M) tile-length order
        const auto& a_pad_view =
            pad_tensor_view(a_tensor_view,
                            make_tuple(number<TilePartitioner::KPerBlock>{},
                                       number<TilePartitioner::MPerBlock>{}),
                            sequence<true, true>{});
        // Step 3: Create tile window; the M block offset is now on dim 1
        return make_tile_window(a_pad_view,
                                make_tuple(number<TilePartitioner::KPerBlock>{},
                                           number<TilePartitioner::MPerBlock>{}),
                                {0, block_idx_m});
    }
}
template <typename BDescType>