Add make_uniform_tuple helper for repeated value patterns

Add make_uniform_tuple<N>(value) helper to replace common pattern:
  generate_tuple([&](auto) { return value; }, Number<N>{})

This avoids unique lambda instantiations when creating tuples with
repeated values. Applied to device_grouped_conv_fwd_multiple_abd.
This commit is contained in:
Max Podkorytov
2026-01-16 00:42:19 -06:00
parent 00849ac2e2
commit 22a409be00
2 changed files with 24 additions and 2 deletions

View File

@@ -699,9 +699,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
if constexpr(isMultiA || isMultiB)
{
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
make_uniform_tuple(a_grid_desc_m_k_, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
make_uniform_tuple(b_grid_desc_n_k_, Number<NumBTensor>{});
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,

View File

@@ -59,6 +59,28 @@ __host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
return generate_identity_sequences<N>();
}
// Optimized helper for common pattern: generate_tuple([&](auto) { return value; }, Number<N>{})
// Creates Tuple<T, T, ..., T> (N copies) without lambda instantiation
namespace detail {
template <typename T, index_t... Is>
__host__ __device__ constexpr auto make_uniform_tuple_impl(T&& value, Sequence<Is...>)
{
return make_tuple(((void)Is, value)...);
}
} // namespace detail
template <index_t N, typename T>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
{
return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
}
template <typename T, index_t N>
__host__ __device__ constexpr auto make_uniform_tuple(T&& value, Number<N>)
{
return make_uniform_tuple<N>(static_cast<T&&>(value));
}
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,