diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 2437132d11..f6ad26dae8 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -82,24 +82,11 @@ struct TensorDescriptor constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; + // Use compile-time search helper instead of nested static_for with lambdas + // This eliminates ~918 applier::operator() instantiations. NOTE(review): the removed loops overwrote the result on every match (last match wins) while the helper returns the first match; results agree only if idim_hidden occurs once in UpperDimensionIdss (confirm) + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionIdss{}); - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); - - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } constexpr static index_t ntransform_ = GetNumOfTransform(); diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 35a6a48632..aeebf08e65 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -34,4 +34,73 @@ __host__ __device__ constexpr auto to_sequence(Tuple...>) return Sequence{}; } +// Find index of the first occurrence of Target in Sequence, returns -1 if not found +// Uses constexpr array lookup for O(1) template depth +template +__host__ __device__ constexpr index_t sequence_find_value(Sequence) +{ + if constexpr(sizeof...(Is) == 0) + { + return -1; + } + else + { + constexpr bool matches[] = {(Is == Target)...}; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) + { + if(matches[i]) + return i; + } + return -1; + } +} + +// Result type for find_in_tuple_of_sequences +template +struct 
FindTransformResult +{ + static constexpr index_t itran = ITran; + static constexpr index_t idim_up = IDimUp; + static constexpr bool found = Found; +}; + +namespace detail { + +// Helper to search through a tuple of sequences for a target value +// Returns FindTransformResult with (transform_index, index_within_sequence, found); the first match wins and a miss yields (0, 0, false) +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl() +{ + constexpr index_t idx = sequence_find_value(FirstSeq{}); + if constexpr(idx >= 0) + { + return FindTransformResult{}; + } + else if constexpr(sizeof...(RestSeqs) > 0) + { + return find_in_tuple_of_sequences_impl(); + } + else + { + return FindTransformResult<0, 0, false>{}; + } +} + +} // namespace detail + +// Find target value in a tuple of sequences +// Returns FindTransformResult +// This replaces nested static_for loops; the impl helper calls itself while sequences remain, so template depth grows with the number of sequences (only the scan inside one sequence is recursion free) +template +__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) +{ + if constexpr(sizeof...(Seqs) == 0) + { + return FindTransformResult<0, 0, false>{}; + } + else + { + return detail::find_in_tuple_of_sequences_impl(); + } +} } // namespace ck