mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
Add generate_identity_sequences helper and replace lambdas with named functors (#4828)

## Summary
- Add `generate_identity_sequences<N>()` helper that returns `Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduce `transform_tensor_descriptor` instantiations from 388 to 32 (a 92% reduction)

## Motivation
Multiple call sites use the `generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{})` pattern. A named helper reduces lambda instantiations. Additionally, each lambda in `transform_tensor_descriptor` creates a unique closure type, causing the function to be instantiated separately for every call site. Named functors share a single type, so the compiler reuses the same instantiation.

## Changes
### Part 1: generate_identity_sequences helper
- Replaces the common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate template instantiations at every call site
- The named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor
- Add `unpack_and_merge_sequences` helper to replace the lambda in `GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan
- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs
This PR merges the functionality from:
- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and ROCm/composable_kernel#3589, which can be closed once this is merged.
56 lines
1.4 KiB
C++
56 lines
1.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck/utility/functional4.hpp"
|
|
#include "ck/utility/tuple.hpp"
|
|
|
|
namespace ck {
|
|
|
|
template <index_t... Is>
|
|
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
|
|
{
|
|
return Sequence<Is...>{};
|
|
}
|
|
|
|
// F returns index_t
|
|
template <typename F, index_t N>
|
|
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
|
|
{
|
|
return typename sequence_gen<N, F>::type{};
|
|
}
|
|
|
|
// F returns Number<>
|
|
template <typename F, index_t N>
|
|
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
|
|
{
|
|
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
|
|
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
|
}
|
|
|
|
template <index_t... Is>
|
|
__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
|
|
{
|
|
return Sequence<Is...>{};
|
|
}
|
|
|
|
// Functor wrapper for merge_sequences to enable reuse across call sites
|
|
struct merge_sequences_functor
|
|
{
|
|
template <typename... Seqs>
|
|
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
|
|
{
|
|
return merge_sequences(seqs...);
|
|
}
|
|
};
|
|
|
|
// Unpacks tuple of sequences and merges them into a single sequence
|
|
template <typename TupleOfSequences>
|
|
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
|
|
{
|
|
return unpack(merge_sequences_functor{}, tuple_of_sequences);
|
|
}
|
|
|
|
} // namespace ck
|