mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Replace nested static_for lambdas with compile-time search helper
The GetTransformAndItsUpperDimension function used nested static_for loops with lambdas to search for a hidden dimension in UpperDimensionIdss. This caused 918 applier::operator() instantiations (81% of all applier instantiations).

Replace them with a find_in_tuple_of_sequences helper that uses a constexpr array lookup within each sequence and if-constexpr recursion across the sequences, eliminating the lambda instantiation overhead.

Results on example_grouped_conv_fwd_xdl_fp16:
- applier instantiations: 1132 -> 127 (89% reduction)
- TensorDescriptor instantiations: 2503 -> 664 (73% reduction)
- Template instantiation time: 23.4s -> 19.4s (17% reduction)
This commit is contained in:
@@ -82,24 +82,11 @@ struct TensorDescriptor
|
||||
|
||||
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
// Use compile-time search helper instead of nested static_for with lambdas
|
||||
// This eliminates ~918 applier::operator() instantiations
|
||||
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionIdss{});
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == idim_hidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
return make_tuple(result.itran, result.idim_up, result.found);
|
||||
}
|
||||
|
||||
constexpr static index_t ntransform_ = GetNumOfTransform();
|
||||
|
||||
@@ -34,4 +34,73 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
// Find index of Target in Sequence, returns -1 if not found
|
||||
// Uses constexpr array lookup for O(1) template depth
|
||||
template <index_t Target, index_t... Is>
|
||||
__host__ __device__ constexpr index_t sequence_find_value(Sequence<Is...>)
|
||||
{
|
||||
if constexpr(sizeof...(Is) == 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr bool matches[] = {(Is == Target)...};
|
||||
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
|
||||
{
|
||||
if(matches[i])
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// Result type for find_in_tuple_of_sequences
|
||||
template <index_t ITran, index_t IDimUp, bool Found>
|
||||
struct FindTransformResult
|
||||
{
|
||||
static constexpr index_t itran = ITran;
|
||||
static constexpr index_t idim_up = IDimUp;
|
||||
static constexpr bool found = Found;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper to search through a tuple of sequences for a target value
|
||||
// Returns FindTransformResult with (transform_index, index_within_sequence, found)
|
||||
template <index_t Target, index_t TranIdx, typename FirstSeq, typename... RestSeqs>
|
||||
__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl()
|
||||
{
|
||||
constexpr index_t idx = sequence_find_value<Target>(FirstSeq{});
|
||||
if constexpr(idx >= 0)
|
||||
{
|
||||
return FindTransformResult<TranIdx, idx, true>{};
|
||||
}
|
||||
else if constexpr(sizeof...(RestSeqs) > 0)
|
||||
{
|
||||
return find_in_tuple_of_sequences_impl<Target, TranIdx + 1, RestSeqs...>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return FindTransformResult<0, 0, false>{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Find target value in a tuple of sequences
|
||||
// Returns FindTransformResult<itran, idim_up, found>
|
||||
// This replaces nested static_for loops with O(1) template depth
|
||||
template <index_t Target, typename... Seqs>
|
||||
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
|
||||
{
|
||||
if constexpr(sizeof...(Seqs) == 0)
|
||||
{
|
||||
return FindTransformResult<0, 0, false>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return detail::find_in_tuple_of_sequences_impl<Target, 0, Seqs...>();
|
||||
}
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user