Apply same optimization pattern to TensorAdaptor

TensorAdaptor has the same InitializeElementSize and
GetTransformAndItsUpperDimension patterns as TensorDescriptor.
Apply the same optimization:
- Replace nested static_for lambdas with find_in_tuple_of_sequences
- Replace generate_tuple lambda with pack expansion

Results: generate_tuple lambdas 100 -> 96 (4 events, 17ms eliminated)
This commit is contained in:
Max Podkorytov
2026-01-16 16:37:56 -06:00
parent 1159278d12
commit bc802ffe3a
2 changed files with 70 additions and 66 deletions

View File

@@ -45,28 +45,29 @@ struct TensorAdaptor
return BottomDimensionHiddenIds{};
}
// Returns the length of top dimension I, read from the transform that produces it.
// The hidden id of top dimension I is searched for in UpperDimensionHiddenIdss; the hit
// identifies which transform (itran) and which of its upper dimensions (idim_up) to read.
// Compilation fails via static_assert when no transform produces that hidden id.
template <index_t I>
__host__ __device__ static constexpr auto
GetTopDimLengthFromTransforms(const Transforms& transforms)
{
    // Hidden-dimension id that top dimension I maps to.
    constexpr index_t hidden_id = TopDimensionHiddenIds::At(Number<I>{});

    // Compile-time lookup of (transform index, upper-dimension index, found flag).
    constexpr auto hit = find_in_tuple_of_sequences<hidden_id>(UpperDimensionHiddenIdss{});

    static_assert(hit.found, "wrong! not found matching transformation and upper-dimension");

    return transforms[Number<hit.itran>{}].GetUpperLengths()[Number<hit.idim_up>{}];
}
// Multiplies together the lengths of the top dimensions whose indices are given by Is...
// Uses a pack-expansion fold instead of generate_tuple with a lambda.
//   transforms      - forwarded unchanged to GetTopDimLengthFromTransforms<Is> for each index
//   Sequence<Is...> - tag argument carrying the top-dimension indices to multiply over
// Returns the product of those lengths, or Number<1>{} when Is... is empty.
template <index_t... Is>
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
                                                                 Sequence<Is...>)
{
    // Binary fold seeded with Number<1>{}: a unary fold over '*' is ill-formed for an empty
    // pack, so without the seed an adaptor with zero top dimensions would fail to compile.
    // The seed is the multiplicative identity, so non-empty products are unchanged.
    return (GetTopDimLengthFromTransforms<Is>(transforms) * ... * Number<1>{});
}
// Computes the element size for the given transforms by combining the lengths of all
// ndim_top_ top dimensions.
// NOTE(review): this span looks like a diff rendered without its +/- markers. It contains
// both a generate_tuple/container_reduce body and a ComputeElementSizeImpl-based return;
// as written, the second return statement is unreachable. Presumably the first body is the
// pre-change code and the final return is its replacement — confirm against the commit.
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
// Build one length per top dimension index in [0, ndim_top_).
const auto lengths = generate_tuple(
[&](auto idim_top) {
// tmp is indexed as (0: transform index, 1: upper-dimension index, 2: found flag).
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
// Every top dimension must be produced by some transform; otherwise fail the build.
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
// Read the length from the matching transform's upper lengths.
const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
return length;
},
Number<ndim_top_>{});
// TODO: make container_reduce support tuple of Number and index_t
// Reduce the collected lengths with math::multiplies, starting from Number<1>{}.
return container_reduce(lengths, math::multiplies{}, Number<1>{});
// Same computation expressed through ComputeElementSizeImpl over indices 0..ndim_top_-1.
return ComputeElementSizeImpl(transforms,
typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{});
}
template <index_t IDim>
@@ -76,24 +77,10 @@ struct TensorAdaptor
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
// Use compile-time search helper instead of nested static_for with lambdas
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()

View File

@@ -64,43 +64,60 @@ struct FindTransformResult
static constexpr bool found = Found;
};
// NOTE(review): this span looks like a diff rendered without its +/- markers. Lines of a
// recursive detail::find_in_tuple_of_sequences_impl function are interleaved with lines of
// the FindInTupleOfSequencesCompute struct, so the text is not valid C++ as shown.
// Presumably the recursive function is the pre-change code and the struct is its
// replacement — confirm against the commit before reusing any of it.
namespace detail {
// Helper to search through a tuple of sequences for a target value
// Returns FindTransformResult with (transform_index, index_within_sequence, found)
template <index_t Target, index_t TranIdx, typename FirstSeq, typename... RestSeqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl()
// O(1) template depth implementation using pack expansion
// Avoids O(N) recursive template instantiations
template <index_t Target, typename... Seqs>
struct FindInTupleOfSequencesCompute
{
// (recursive variant) position of Target inside the first sequence; >= 0 is treated as a hit
constexpr index_t idx = sequence_find_value<Target>(FirstSeq{});
if constexpr(idx >= 0)
private:
// Result struct for constexpr computation
struct ResultData
{
return FindTransformResult<TranIdx, idx, true>{};
}
// (recursive variant) no hit here: continue with the remaining sequences, bumping TranIdx
else if constexpr(sizeof...(RestSeqs) > 0)
{
return find_in_tuple_of_sequences_impl<Target, TranIdx + 1, RestSeqs...>();
}
// (recursive variant) nothing left to search: report not found
else
{
return FindTransformResult<0, 0, false>{};
}
}
index_t itran;
index_t idim_up;
bool found;
};
} // namespace detail
// Compute result using constexpr function with array lookup
static constexpr ResultData compute()
{
// An empty list of sequences cannot contain Target: report not found.
if constexpr(sizeof...(Seqs) == 0)
{
return {0, 0, false};
}
else
{
// Pack expansion creates array - O(1) template depth
// indices[k] is sequence_find_value<Target> applied to the k-th sequence.
constexpr index_t indices[] = {sequence_find_value<Target>(Seqs{})...};
// Find first matching sequence
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Seqs)); ++i)
{
// A non-negative entry is treated as a hit: i is the sequence index and
// indices[i] is the position reported for Target within that sequence.
if(indices[i] >= 0)
{
return {i, indices[i], true};
}
}
// No sequence reported a non-negative position.
return {0, 0, false};
}
}
// Computed once per <Target, Seqs...> instantiation.
static constexpr ResultData result_ = compute();
public:
// Individual fields of the computed result, exposed as compile-time constants.
static constexpr index_t itran = result_.itran;
static constexpr index_t idim_up = result_.idim_up;
static constexpr bool found = result_.found;
// The same result encoded as a FindTransformResult type.
using type = FindTransformResult<itran, idim_up, found>;
};
// Find target value in a tuple of sequences
// Returns FindTransformResult<itran, idim_up, found>
// This replaces nested static_for loops with O(1) template depth
// Uses O(1) template depth via pack expansion (no recursion)
// NOTE(review): this span looks like a diff rendered without its +/- markers. It contains
// both an if-constexpr dispatch to detail::find_in_tuple_of_sequences_impl and a return
// built from FindInTupleOfSequencesCompute; as written, the final return is unreachable.
// Presumably the dispatch is the pre-change body and the final return is its replacement
// — confirm against the commit.
template <index_t Target, typename... Seqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
{
// An empty tuple cannot contain Target: report not found.
if constexpr(sizeof...(Seqs) == 0)
{
return FindTransformResult<0, 0, false>{};
}
else
{
// Start the search at transform index 0.
return detail::find_in_tuple_of_sequences_impl<Target, 0, Seqs...>();
}
// Value-initialised FindTransformResult type computed by FindInTupleOfSequencesCompute.
return typename FindInTupleOfSequencesCompute<Target, Seqs...>::type{};
}
} // namespace ck