Files
composable_kernel/include/ck/utility/tuple_helper.hpp
Max Podkorytov 1dd47118e2 [rocm-libraries] ROCm/rocm-libraries#4828 (commit 7de19bb)
Add generate_identity_sequences helper and replace lambdas with named functors (#4828)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add `generate_identity_sequences<N>()` helper that returns
`Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32
(92% reduction)

## Motivation

Multiple call sites use the
`generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{})` pattern.
A named helper reduces the number of lambda instantiations.

Additionally, each lambda in `transform_tensor_descriptor` creates a
unique closure type, causing the function to be instantiated separately
for every call site. Named functors share a single type, so the compiler
reuses the same instantiation.

## Changes

### Part 1: generate_identity_sequences helper
- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate
template instantiations at every call site
- Named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor
- Add `unpack_and_merge_sequences` helper to replace lambda in
`GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan

- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs

This PR merges the functionality from:
- ROCm/composable_kernel#3588 (generate_identity_sequences helper)
- ROCm/composable_kernel#3589 (Named functors in
transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm/composable_kernel#3588 and
ROCm/composable_kernel#3589, which can be closed once this is merged.
2026-02-28 20:11:11 +00:00

241 lines
6.8 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
#ifndef CK_CODE_GEN_RTC
#include "is_detected.hpp"
#endif
namespace ck {
// Builds a Tuple by calling `gen` once per index carried by the Sequence.
// Each index is passed as its own compile-time Number<I> tag; the results are
// packed, in sequence order, with ck::make_tuple. `gen` is invoked as an lvalue
// (it is not forwarded).
template <typename Gen, index_t... Is>
__host__ __device__ constexpr auto generate_tuple_for(Gen&& gen, Sequence<Is...>)
{
    return ck::make_tuple(gen(Number<Is>{})...);
}
// Builds Tuple<gen(Number<0>{}), ..., gen(Number<N-1>{})> by handing the index
// sequence produced by make_index_sequence<N> to generate_tuple_for.
template <typename Gen, index_t N>
__host__ __device__ constexpr auto generate_tuple(Gen&& gen, Number<N>)
{
    using Indices = make_index_sequence<N>;
    return generate_tuple_for(gen, Indices{});
}
// Overload taking the element count as a LongNumber<N> tag: applies `gen` to
// every element of arithmetic_sequence_gen<0, N, 1>::type (via unpack) and
// packs the results with make_tuple.
template <typename Gen, index_t N>
__host__ __device__ constexpr auto generate_tuple(Gen&& gen, LongNumber<N>)
{
    using Indices = typename arithmetic_sequence_gen<0, N, 1>::type;
    auto apply_all = [&gen](auto&&... idx) { return make_tuple(gen(idx)...); };
    return unpack(apply_all, Indices{});
}
// Like generate_tuple, but collects the results with tie(...) instead of
// make_tuple(...). `gen` is applied to every element of
// arithmetic_sequence_gen<0, N, 1>::type (via unpack).
template <typename Gen, index_t N>
__host__ __device__ constexpr auto generate_tie(Gen&& gen, Number<N>)
{
    using Indices = typename arithmetic_sequence_gen<0, N, 1>::type;
    auto tie_all = [&gen](auto&&... idx) { return tie(gen(idx)...); };
    return unpack(tie_all, Indices{});
}
// Creates Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>
namespace detail {
// Turns Sequence<I0, I1, ...> into Tuple<Sequence<I0>, Sequence<I1>, ...>:
// every index becomes its own one-element Sequence.
template <index_t... Idx>
__host__ __device__ constexpr auto make_identity_sequences_impl(Sequence<Idx...>)
{
    return make_tuple(Sequence<Idx>{}...);
}
} // namespace detail
// Returns Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>.
// As a named function template, all call sites with the same N share one
// instantiation, whereas a lambda passed to generate_tuple introduces a new
// closure type (and thus a new instantiation) at every call site.
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences()
{
    using Indices = make_index_sequence<N>;
    return detail::make_identity_sequences_impl(Indices{});
}
// Same as generate_identity_sequences<N>(), with N supplied as a Number<N> tag.
template <index_t N>
__host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
{
    return ck::generate_identity_sequences<N>();
}
// Concatenates two tuples of references (elements of tx first, then ty).
// Each result element type is decltype of the unpacked (forwarding-reference)
// argument, so the result is again a tuple of references rather than a tuple
// of values or rvalues; nothing is copied.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
                                                             const Tuple<Y&...>& ty)
{
    auto make_ref_tuple = [](auto&&... refs) {
        return Tuple<decltype(refs)...>{ck::forward<decltype(refs)>(refs)...};
    };
    return unpack2(make_ref_tuple, tx, ty);
}
template <typename... X, typename... Y>
auto concat_tuple_of_reference(ck::Tuple<X&...>& tx, ck::Tuple<Y&...>& ty)
{
return ck::unpack2(
[&](auto&&... zs) { return ck::Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
// Concatenates two tuples by value: the elements of tx followed by those of ty.
// The lambda takes its arguments by value (`auto...`), so each element type is
// decayed and the result holds copies, not references.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    auto join_values = [](auto... elems) {
        return Tuple<decltype(elems)...>{ck::forward<decltype(elems)>(elems)...};
    };
    return unpack2(join_values, tx, ty);
}
// concat_tuple accepts one or more tuples and concatenates them by value, in
// argument order.
//
// Base case, a single tuple: nothing to join, return a copy of it.
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
    return tx;
}
// General case: concatenate all trailing tuples first, then prepend tx using
// the two-tuple overload.
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
    const auto tail = concat_tuple(tuples...);
    return concat_tuple(tx, tail);
}
namespace detail {
// Element-wise helpers behind transform_tuples: for each index I in the given
// Sequence, apply fn to the I-th element(s) of the input tuple(s) and collect
// the results with make_tuple.

// One input: Tuple<fn(a[I])...>
template <typename Fn, typename A, index_t... I>
__host__ __device__ constexpr auto transform_tuples_impl(Fn fn, const A& a, Sequence<I...>)
{
    return make_tuple(fn(a.At(Number<I>{}))...);
}

// Two inputs: Tuple<fn(a[I], b[I])...>
template <typename Fn, typename A, typename B, index_t... I>
__host__ __device__ constexpr auto
transform_tuples_impl(Fn fn, const A& a, const B& b, Sequence<I...>)
{
    return make_tuple(fn(a.At(Number<I>{}), b.At(Number<I>{}))...);
}

// Three inputs: Tuple<fn(a[I], b[I], c[I])...>
template <typename Fn, typename A, typename B, typename C, index_t... I>
__host__ __device__ constexpr auto
transform_tuples_impl(Fn fn, const A& a, const B& b, const C& c, Sequence<I...>)
{
    return make_tuple(fn(a.At(Number<I>{}), b.At(Number<I>{}), c.At(Number<I>{}))...);
}
} // namespace detail
// transform_tuples: element-wise map over one, two or three tuple-like inputs.
// Iterates i over arithmetic_sequence_gen<0, X::Size(), 1>, i.e. the length is
// taken from x only; y and z are indexed with the same i.
// NOTE(review): nothing checks that Y::Size() / Z::Size() match X::Size().

// Tuple<f(x[i])...>
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;
    return detail::transform_tuples_impl(f, x, Indices{});
}

// Tuple<f(x[i], y[i])...>
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;
    return detail::transform_tuples_impl(f, x, y, Indices{});
}

// Tuple<f(x[i], y[i], z[i])...>
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;
    return detail::transform_tuples_impl(f, x, y, z, Indices{});
}
// UnrollNestedTuple<Depth, MaxDepth>: flattens nested Tuples into one Tuple.
// Descent stops once Depth equals MaxDepth. With the default MaxDepth = -1 that
// never happens (Depth starts at 0 and only increases), so by default the
// tuple is flattened completely.
//
// Empty tuple: nothing to flatten, returned as is.
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
    return element;
}
// Non-tuple leaf: wrapped in a one-element tuple so it can be concatenated.
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
{
    return make_tuple(element);
}
// Non-empty tuple: below MaxDepth, unroll each member one level deeper and
// concatenate the pieces; at MaxDepth, keep the tuple unchanged.
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
    if constexpr(Depth != MaxDepth)
    {
        return unpack(
            [](auto&&... members) {
                return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(members)...);
            },
            tuple);
    }
    else
    {
        return tuple;
    }
}
// Returns a new tuple holding the elements of `tuple` in reverse order:
// result element i is a copy of tuple.At(Number<Size() - i - 1>{}).
template <typename... Ts>
__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
{
    using Self = Tuple<Ts...>;
    const auto pick_mirrored = [&tuple](auto i) {
        using MirrorIdx = Number<Self::Size() - i - 1>;
        return tuple.At(MirrorIdx{});
    };
    return generate_tuple(pick_mirrored, Number<Self::Size()>{});
}
// Right fold of the elements whose indices lie in the half-open range
// [Idx, End), using the binary function op:
//   op(t[Idx], op(t[Idx + 1], ... op(t[End - 2], t[End - 1])))
// A one-element range (Idx + 1 == End) yields that element unchanged.
// Idx must be strictly less than End (enforced by static_assert).
template <index_t Idx, index_t End, typename F, typename... Ts>
__host__ __device__ constexpr auto TupleReduce(F&& op, const Tuple<Ts...>& tuple)
{
    static_assert(Idx < End, "Wrong parameters for TupleReduce");
    if constexpr(Idx + 1 < End)
    {
        return op(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(op, tuple));
    }
    else
    {
        return tuple.At(Number<Idx>{});
    }
}
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Detection alias for is_detected: names the return type of T::IsTuple(), so it
// is well-formed only for types providing a callable IsTuple() member.
// NOTE(review): this guard differs from the `#ifndef CK_CODE_GEN_RTC` guard
// around the is_detected.hpp include at the top of this file. With
// CK_CODE_GEN_RTC defined but __HIPCC_RTC__ undefined, this branch (and the one
// in IsNestedTuple below) is compiled without that include — confirm
// is_detected is reachable another way in that configuration.
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
// Returns true if at least one element type of the tuple is itself tuple-like
// (detected through an IsTuple() member); the || fold yields false for an
// empty tuple.
// NOTE(review): when the guard inside is false (both __HIPCC_RTC__ and
// CK_CODE_GEN_RTC defined) the body has no return statement, so the deduced
// return type is void and callers using the result as a bool will not compile.
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
return (is_detected<is_tuple, Ts>::value || ...);
#endif
}
// TupleDepth<depth>(value): nesting depth of `value`, counted starting at `depth`.
//
// Non-tuple leaf: adds no level, so the answer is the starting depth itself.
template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&)
{
    return index_t{depth};
}
// Tuple: its elements live one level deeper; the result is the maximum depth
// over all element types, each probed with a freshly constructed Ts{} (only the
// types matter, the argument's values are never read).
// NOTE(review): requires every Ts to be constructible as Ts{}; for an empty
// Tuple<> math::max is called with no arguments — confirm such an overload
// exists (presumably it does not).
template <index_t depth = 0, typename... Ts>
__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
{
    constexpr index_t child_depth = depth + 1;
    return math::max(TupleDepth<child_depth>(Ts{})...);
}
// Returns the elements of `tuple` with indices in the half-open range
// [from, to) as a new tuple of copies: result element i is
// tuple.At(Number<from + i>{}).
// Requires 0 <= from <= to <= number of elements (checked at compile time).
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
    // Reject invalid ranges up front with a readable diagnostic instead of an
    // error deep inside generate_tuple / Tuple::At (mirrors TupleReduce).
    static_assert(0 <= from && from <= to && to <= static_cast<index_t>(sizeof...(Ts)),
                  "Wrong parameters for TupleSlice");
    return generate_tuple(
        [&](auto i) {
            using Idx = Number<from + i>;
            return tuple.At(Idx{});
        },
        Number<to - from>{});
}
} // namespace ck