Replace nested static_for lambdas with compile-time search helper

The GetTransformAndItsUpperDimension function used nested static_for
loops with lambdas to search for a hidden dimension in UpperDimensionIdss.
This caused 918 applier::operator() instantiations (81% of all applier
instantiations).

Replace with find_in_tuple_of_sequences helper that uses constexpr
array lookup and if-constexpr recursion, eliminating the lambda
instantiation overhead.

Results on example_grouped_conv_fwd_xdl_fp16:
- applier instantiations: 1132 -> 127 (89% reduction)
- TensorDescriptor instantiations: 2503 -> 664 (73% reduction)
- Template instantiation time: 23.4s -> 19.4s (17% reduction)
This commit is contained in:
Max Podkorytov
2026-01-16 15:49:59 -06:00
parent fcc9372c00
commit f65e9e4c99
2 changed files with 73 additions and 17 deletions

View File

@@ -82,24 +82,11 @@ struct TensorDescriptor
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
// Use compile-time search helper instead of nested static_for with lambdas
// This eliminates ~918 applier::operator() instantiations
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionIdss{});
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}
constexpr static index_t ntransform_ = GetNumOfTransform();

View File

@@ -34,4 +34,73 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}
// Find index of Target in Sequence, returns -1 if not found
// Uses constexpr array lookup for O(1) template depth
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t sequence_find_value(Sequence<Is...>)
{
if constexpr(sizeof...(Is) == 0)
{
return -1;
}
else
{
constexpr bool matches[] = {(Is == Target)...};
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
{
if(matches[i])
return i;
}
return -1;
}
}
// Result type for find_in_tuple_of_sequences
template <index_t ITran, index_t IDimUp, bool Found>
struct FindTransformResult
{
static constexpr index_t itran = ITran;
static constexpr index_t idim_up = IDimUp;
static constexpr bool found = Found;
};
namespace detail {
// Helper to search through a tuple of sequences for a target value
// Returns FindTransformResult with (transform_index, index_within_sequence, found)
template <index_t Target, index_t TranIdx, typename FirstSeq, typename... RestSeqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl()
{
constexpr index_t idx = sequence_find_value<Target>(FirstSeq{});
if constexpr(idx >= 0)
{
return FindTransformResult<TranIdx, idx, true>{};
}
else if constexpr(sizeof...(RestSeqs) > 0)
{
return find_in_tuple_of_sequences_impl<Target, TranIdx + 1, RestSeqs...>();
}
else
{
return FindTransformResult<0, 0, false>{};
}
}
} // namespace detail
// Find target value in a tuple of sequences
// Returns FindTransformResult<itran, idim_up, found>
// This replaces nested static_for loops with O(1) template depth
template <index_t Target, typename... Seqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
{
if constexpr(sizeof...(Seqs) == 0)
{
return FindTransformResult<0, 0, false>{};
}
else
{
return detail::find_in_tuple_of_sequences_impl<Target, 0, Seqs...>();
}
}
} // namespace ck