mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Add make_uniform_tuple helper for repeated value patterns
Add make_uniform_tuple<N>(value) helper to replace common pattern:
generate_tuple([&](auto) { return value; }, Number<N>{})
This avoids unique lambda instantiations when creating tuples with
repeated values. Applied to device_grouped_conv_fwd_multiple_abd.
This commit is contained in:
@@ -699,9 +699,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 =
|
||||
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
|
||||
make_uniform_tuple(a_grid_desc_m_k_, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 =
|
||||
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
|
||||
make_uniform_tuple(b_grid_desc_n_k_, Number<NumBTensor>{});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
|
||||
@@ -59,6 +59,28 @@ __host__ __device__ constexpr auto generate_identity_sequences(Number<N>)
|
||||
return generate_identity_sequences<N>();
|
||||
}
|
||||
|
||||
// Optimized helper for common pattern: generate_tuple([&](auto) { return value; }, Number<N>{})
|
||||
// Creates Tuple<T, T, ..., T> (N copies) without lambda instantiation
|
||||
namespace detail {
|
||||
template <typename T, index_t... Is>
|
||||
__host__ __device__ constexpr auto make_uniform_tuple_impl(T&& value, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(((void)Is, value)...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <index_t N, typename T>
|
||||
__host__ __device__ constexpr auto make_uniform_tuple(T&& value)
|
||||
{
|
||||
return detail::make_uniform_tuple_impl(static_cast<T&&>(value), make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_uniform_tuple(T&& value, Number<N>)
|
||||
{
|
||||
return make_uniform_tuple<N>(static_cast<T&&>(value));
|
||||
}
|
||||
|
||||
// tx and ty are tuples of references; the return type will be a tuple of references (not rvalues)
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
|
||||
|
||||
Reference in New Issue
Block a user