From f65e9e4c9910e488d708b88c3e42d5ddaabce770 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 16 Jan 2026 15:49:59 -0600 Subject: [PATCH] Replace nested static_for lambdas with compile-time search helper The GetTransformAndItsUpperDimension function used nested static_for loops with lambdas to search for a hidden dimension in UpperDimensionIdss. This caused 918 applier::operator() instantiations (81% of all applier instantiations). Replace with find_in_tuple_of_sequences helper that uses constexpr array lookup and if-constexpr recursion, eliminating the lambda instantiation overhead. Results on example_grouped_conv_fwd_xdl_fp16: - applier instantiations: 1132 -> 127 (89% reduction) - TensorDescriptor instantiations: 2503 -> 664 (73% reduction) - Template instantiation time: 23.4s -> 19.4s (17% reduction) --- .../tensor_description/tensor_descriptor.hpp | 21 ++---- include/ck/utility/sequence_helper.hpp | 69 +++++++++++++++++++ 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d11..f6ad26dae8 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -82,24 +82,11 @@ struct TensorDescriptor constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; + // Use compile-time search helper instead of nested static_for with lambdas + // This eliminates ~918 applier::operator() instantiations + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionIdss{}); - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); - - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } constexpr static index_t ntransform_ = GetNumOfTransform(); diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 35a6a48632..aeebf08e65 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -34,4 +34,73 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) return Sequence{}; } +// Find index of Target in Sequence, returns -1 if not found +// Uses constexpr array lookup for O(1) template depth +template +__host__ __device__ constexpr index_t sequence_find_value(Sequence) +{ + if constexpr(sizeof...(Is) == 0) + { + return -1; + } + else + { + constexpr bool matches[] = {(Is == Target)...}; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) + { + if(matches[i]) + return i; + } + return -1; + } +} + +// Result type for find_in_tuple_of_sequences +template +struct FindTransformResult +{ + static constexpr index_t itran = ITran; + static constexpr index_t idim_up = IDimUp; + static constexpr bool found = Found; +}; + +namespace detail { + +// Helper to search through a tuple of sequences for a target value +// Returns FindTransformResult with (transform_index, index_within_sequence, found) +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl() +{ + constexpr index_t idx = sequence_find_value(FirstSeq{}); + if constexpr(idx >= 0) + { + return FindTransformResult{}; + } + else if constexpr(sizeof...(RestSeqs) > 0) + { + return find_in_tuple_of_sequences_impl(); + } + else + { + return FindTransformResult<0, 0, false>{}; + } +} + +} // namespace detail + +// Find target value in a tuple of sequences +// Returns FindTransformResult +// This replaces nested static_for loops with O(1) template depth +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) +{ + if constexpr(sizeof...(Seqs) == 0) + { + return FindTransformResult<0, 0, false>{}; + } + else + { + return detail::find_in_tuple_of_sequences_impl(); + } +} } // namespace ck