[rocm-libraries] ROCm/rocm-libraries#4828 (commit 7de19bb)

Add generate_identity_sequences helper and replace lambdas
 with named functors (#4828)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add `generate_identity_sequences<N>()` helper that returns
`Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32
(92% reduction)

## Motivation

Multiple call sites use `generate_tuple([](auto i) { return
Sequence<i>{}; }, Number<N>{})` pattern. A named helper reduces lambda
instantiations.

Additionally, each lambda in `transform_tensor_descriptor` creates a
unique closure type, causing the function to be instantiated separately
for every call site. Named functors share a single type, so the compiler
reuses the same instantiation.

## Changes

### Part 1: generate_identity_sequences helper
- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate
template instantiations at every call site
- Named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor
- Add `unpack_and_merge_sequences` helper to replace lambda in
`GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan

- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs

This PR merges the functionality from:
- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in
transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and
ROCm/composable_kernel#3589, which can be closed once this is merged.
This commit is contained in:
Max Podkorytov
2026-02-28 20:11:11 +00:00
committed by assistant-librarian[bot]
parent ef82340e05
commit 1dd47118e2
19 changed files with 550 additions and 74 deletions

View File

@@ -38,11 +38,9 @@ struct TensorDescriptor
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -319,6 +317,41 @@ struct lambda_get_up_dim_num
}
};
// Maps a visible dimension ID to its corresponding hidden dimension ID
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
{
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
}
};
// Maps a sequence of visible IDs to their corresponding hidden IDs
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
template <typename LowDimVisibleIds>
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
{
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
low_dim_visible_ids);
}
};
// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
return typename arithmetic_sequence_gen<start, end, 1>::type{};
}
};
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
@@ -335,11 +368,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -349,17 +382,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
@@ -372,22 +397,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});
// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);
constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

View File

@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
},
Number<num_dim>{});
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
// upper dimension Id
// lower/upper dimension Ids
const auto lower_dimss = generate_identity_sequences<num_dim>();
const auto upper_dimss = lower_dimss;
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);

View File

@@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -657,8 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -707,8 +706,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
@@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck/utility/functional4.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
@@ -34,4 +35,21 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}
// Functor wrapper for merge_sequences to enable reuse across call sites
struct merge_sequences_functor
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
// Unpacks tuple of sequences and merges them into a single sequence
template <typename TupleOfSequences>
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
{
return unpack(merge_sequences_functor{}, tuple_of_sequences);
}
} // namespace ck

View File

@@ -37,6 +37,27 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
// Creates Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>
namespace detail {
template <index_t... Is>
__host__ __device__ constexpr auto make_identity_sequences_impl(Sequence<Is...>)
{
return make_tuple(Sequence<Is>{}...);
}
} // namespace detail
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences()
{
return detail::make_identity_sequences_impl(make_index_sequence<N>{});
}
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
{
return generate_identity_sequences<N>();
}
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,

View File

@@ -245,8 +245,7 @@ struct Layout
const auto lower_dims =
generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<Tuple<ShapeDims...>::Size()>{});
const auto upper_dims = generate_identity_sequences<Tuple<ShapeDims...>::Size()>();
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}

View File

@@ -259,8 +259,7 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
const auto lower_upper_dims = generate_identity_sequences<8>();
auto sliced_desc = transform_tensor_descriptor(
partition_desc,

View File

@@ -190,8 +190,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
const auto transforms = GenerateSliceTransforms(idx, shape);
using TransformsTupleType = decltype(transforms);
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}

View File

@@ -186,8 +186,7 @@ __host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout
},
Number<old_shape_dims>{});
const auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
const auto lower_dims = generate_identity_sequences<old_shape_dims>();
const auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
@@ -492,8 +491,7 @@ __host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& la
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto lower_dims = generate_identity_sequences<dims>();
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)

View File

@@ -293,8 +293,7 @@ make_local_partition(TensorType& tensor,
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
generate_identity_sequences<remove_reference_t<decltype(tensor_shape)>::Size()>();
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout