[rocm-libraries] ROCm/rocm-libraries#4828 (commit 7de19bb)

Add generate_identity_sequences helper and replace lambdas
 with named functors (#4828)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add `generate_identity_sequences<N>()` helper that returns
`Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32
(92% reduction)

## Motivation

Multiple call sites use the `generate_tuple([](auto i) { return
Sequence<i>{}; }, Number<N>{})` pattern. Replacing it with a named helper
avoids instantiating a separate lambda closure type at each of those call sites.

Additionally, each lambda in `transform_tensor_descriptor` creates a
unique closure type, causing the function to be instantiated separately
for every call site. Named functors share a single type, so the compiler
reuses the same instantiation.

## Changes

### Part 1: generate_identity_sequences helper
- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate
template instantiations at every call site
- Named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor
- Add `unpack_and_merge_sequences` helper to replace lambda in
`GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan

- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs

This PR merges the functionality from:
- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in
transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and
ROCm/composable_kernel#3589, which can be closed once this is merged.
This commit is contained in:
Max Podkorytov
2026-02-28 20:11:11 +00:00
committed by assistant-librarian[bot]
parent ef82340e05
commit 1dd47118e2
19 changed files with 550 additions and 74 deletions

View File

@@ -38,11 +38,9 @@ struct TensorDescriptor
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -319,6 +317,41 @@ struct lambda_get_up_dim_num
}
};
// Functor that translates a single visible dimension ID of OldTensorDescriptor
// into the entry stored at that position of
// OldTensorDescriptor::GetVisibleDimensionIds() (the hidden dimension ID).
// Being a named type rather than a lambda, every use with the same
// OldTensorDescriptor refers to one and the same functor type.
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
    // low_dim_visible_id: position to look up in the visible-dimension-ID container.
    // Returns whatever At() yields for that position.
    __host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
    {
        const auto visible_dim_ids = OldTensorDescriptor::GetVisibleDimensionIds();
        return visible_dim_ids.At(low_dim_visible_id);
    }
};
// Functor that converts a whole sequence of visible dimension IDs into the
// matching hidden dimension IDs by applying
// convert_visible_to_hidden_id<OldTensorDescriptor> to it through
// transform_sequences.
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
    // low_dim_visible_ids: the sequence of visible IDs to translate.
    // Returns the result of transform_sequences over that sequence.
    template <typename LowDimVisibleIds>
    __host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
    {
        // Per-element converter, named so the call below reads as "map this over the ids".
        using per_id_converter = convert_visible_to_hidden_id<OldTensorDescriptor>;
        return transform_sequences(per_id_converter{}, low_dim_visible_ids);
    }
};
// Functor producing, for the transform at index I, the range of hidden
// dimension IDs assigned to that transform's upper dimensions.
//   OldHiddenDimNumber: offset added to both bounds of every range.
//   UpDimNumbersScan:   scan container; entries I and I + 1 supply the two bounds
//                       for transform I.
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
    // The argument only carries the compile-time index in its type I; its value is unused.
    template <typename I>
    __host__ __device__ constexpr auto operator()(I) const
    {
        constexpr auto up_dim_scan = UpDimNumbersScan{};

        // Both bounds are shifted by the number of hidden dimensions that already exist.
        constexpr index_t range_begin = OldHiddenDimNumber + up_dim_scan.At(I{});
        constexpr index_t range_end = OldHiddenDimNumber + up_dim_scan.At(I{} + Number<1>{});

        // Step-1 arithmetic sequence built from the two bounds.
        using hidden_id_range = typename arithmetic_sequence_gen<range_begin, range_end, 1>::type;
        return hidden_id_range{};
    }
};
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
@@ -335,11 +368,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -349,17 +382,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
@@ -372,22 +397,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});
// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);
constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);