diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 79c5881d48..031082e1a0 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -45,28 +45,29 @@ struct TensorAdaptor return BottomDimensionHiddenIds{}; } + // Helper to get length of a top dimension from transforms + template + __host__ __device__ static constexpr auto + GetTopDimLengthFromTransforms(const Transforms& transforms) + { + constexpr auto result = find_in_tuple_of_sequences{})>( + UpperDimensionHiddenIdss{}); + static_assert(result.found, "wrong! not found matching transformation and upper-dimension"); + return transforms[Number{}].GetUpperLengths()[Number{}]; + } + + // Compute element size using pack expansion instead of generate_tuple with lambda + template + __host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms, + Sequence) + { + return (GetTopDimLengthFromTransforms(transforms) * ...); + } + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) { - const auto lengths = generate_tuple( - [&](auto idim_top) { - constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); - - constexpr index_t itran = tmp[Number<0>{}]; - constexpr index_t idim_up = tmp[Number<1>{}]; - constexpr bool found = tmp[Number<2>{}]; - - static_assert(found == true, - "wrong! not found matching transformation and upper-dimension"); - - const auto length = - transforms[Number{}].GetUpperLengths()[Number{}]; - - return length; - }, - Number{}); - - // TODO: make container_reduce support tuple of Number and index_t - return container_reduce(lengths, math::multiplies{}, Number<1>{}); + return ComputeElementSizeImpl(transforms, + typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{}); } template @@ -76,24 +77,10 @@ struct TensorAdaptor constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); - index_t itran_found = 0; - index_t idim_up_found = 0; - bool found = false; + // Use compile-time search helper instead of nested static_for with lambdas + constexpr auto result = find_in_tuple_of_sequences(UpperDimensionHiddenIdss{}); - static_for<0, ntransform_, 1>{}([&](auto itran) { - constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; - - static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { - if constexpr(up_dim_ids[idim_up] == idim_hidden) - { - itran_found = itran; - idim_up_found = idim_up; - found = true; - } - }); - }); - - return make_tuple(itran_found, idim_up_found, found); + return make_tuple(result.itran, result.idim_up, result.found); } __host__ __device__ static constexpr index_t GetNumOfBottomDimension() diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index aeebf08e65..f104733f6f 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -64,43 +64,60 @@ struct FindTransformResult static constexpr bool found = Found; }; -namespace detail { - -// Helper to search through a tuple of sequences for a target value -// Returns FindTransformResult with (transform_index, index_within_sequence, found) -template -__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl() +// O(1) template depth implementation using pack expansion +// Avoids O(N) recursive template instantiations +template +struct FindInTupleOfSequencesCompute { - constexpr index_t idx = sequence_find_value(FirstSeq{}); - if constexpr(idx >= 0) + private: + // Result struct for constexpr computation + struct ResultData { - return FindTransformResult{}; - } - else if constexpr(sizeof...(RestSeqs) > 0) - { - return find_in_tuple_of_sequences_impl(); - } - else - { - return FindTransformResult<0, 0, false>{}; - } -} + index_t itran; + index_t idim_up; + bool found; + }; -} // namespace detail + // Compute result using constexpr function with array lookup + static constexpr ResultData compute() + { + if constexpr(sizeof...(Seqs) == 0) + { + return {0, 0, false}; + } + else + { + // Pack expansion creates array - O(1) template depth + constexpr index_t indices[] = {sequence_find_value(Seqs{})...}; + + // Find first matching sequence + for(index_t i = 0; i < static_cast(sizeof...(Seqs)); ++i) + { + if(indices[i] >= 0) + { + return {i, indices[i], true}; + } + } + return {0, 0, false}; + } + } + + static constexpr ResultData result_ = compute(); + + public: + static constexpr index_t itran = result_.itran; + static constexpr index_t idim_up = result_.idim_up; + static constexpr bool found = result_.found; + + using type = FindTransformResult; +}; // Find target value in a tuple of sequences // Returns FindTransformResult -// This replaces nested static_for loops with O(1) template depth +// Uses O(1) template depth via pack expansion (no recursion) template __host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple) { - if constexpr(sizeof...(Seqs) == 0) - { - return FindTransformResult<0, 0, false>{}; - } - else - { - return detail::find_in_tuple_of_sequences_impl(); - } + return typename FindInTupleOfSequencesCompute::type{}; } } // namespace ck