mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#4828 (commit 7de19bb)
Add generate_identity_sequences helper and replace lambdas with named functors (#4828)

## Summary

- Add `generate_identity_sequences<N>()` helper that returns `Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction)

## Motivation

Multiple call sites use the `generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{})` pattern. A named helper reduces lambda instantiations. Additionally, each lambda in `transform_tensor_descriptor` creates a unique closure type, causing the function to be instantiated separately for every call site. Named functors share a single type, so the compiler reuses the same instantiation.

## Changes

### Part 1: generate_identity_sequences helper

- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate template instantiations at every call site
- Named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor

- Add `unpack_and_merge_sequences` helper to replace lambda in `GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan

- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs

This PR merges the functionality from:

- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and ROCm/composable_kernel#3589, which can be closed once this is merged.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
ef82340e05
commit
1dd47118e2
@@ -38,11 +38,9 @@ struct TensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
|
||||
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
|
||||
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
@@ -319,6 +317,41 @@ struct lambda_get_up_dim_num
|
||||
}
|
||||
};
|
||||
|
||||
// Maps a visible dimension ID to its corresponding hidden dimension ID
|
||||
template <typename OldTensorDescriptor>
|
||||
struct convert_visible_to_hidden_id
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
|
||||
{
|
||||
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
|
||||
}
|
||||
};
|
||||
|
||||
// Maps a sequence of visible IDs to their corresponding hidden IDs
|
||||
template <typename OldTensorDescriptor>
|
||||
struct convert_visible_ids_to_hidden_ids
|
||||
{
|
||||
template <typename LowDimVisibleIds>
|
||||
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
|
||||
{
|
||||
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
|
||||
low_dim_visible_ids);
|
||||
}
|
||||
};
|
||||
|
||||
// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions
|
||||
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
|
||||
struct generate_arithmetic_sequence_from_scan
|
||||
{
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr auto operator()(I) const
|
||||
{
|
||||
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
|
||||
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
|
||||
return typename arithmetic_sequence_gen<start, end, 1>::type{};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OldTensorDescriptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
@@ -335,11 +368,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
|
||||
"wrong! inconsitent number of transform");
|
||||
|
||||
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
constexpr auto all_old_top_ids =
|
||||
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
constexpr auto all_new_top_ids =
|
||||
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
@@ -349,17 +382,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
// lower dimension's hidden idss
|
||||
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
|
||||
// sequences)
|
||||
constexpr auto low_dim_hidden_idss = transform_tuples(
|
||||
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
|
||||
[](auto low_dim_visible_ids) constexpr {
|
||||
return transform_sequences(
|
||||
// convert lower dimension visible id to hidden id
|
||||
[](auto low_dim_visible_id) constexpr {
|
||||
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
|
||||
},
|
||||
low_dim_visible_ids);
|
||||
},
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
constexpr auto low_dim_hidden_idss =
|
||||
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr index_t num_new_transform = NewTransforms::Size();
|
||||
|
||||
@@ -372,22 +397,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
|
||||
|
||||
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
1>::type{};
|
||||
},
|
||||
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
|
||||
Number<num_new_transform>{});
|
||||
|
||||
// new visible dimension's hidden ids
|
||||
constexpr auto unordered_new_visible_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
unpack_and_merge_sequences(up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_visible_dim_unordered2ordered =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
constexpr auto new_visible_dim_hidden_ids =
|
||||
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
|
||||
|
||||
Reference in New Issue
Block a user