mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Replace lambdas with named functors in transform_tensor_descriptor
Lambda expressions in transform_tensor_descriptor created unique template instantiations for each capture combination. This change replaces lambdas with named functor structs to reduce instantiation count: - Add merge_sequences_functor and unpack_and_merge_sequences helper - Add convert_visible_to_hidden_id and convert_visible_ids_to_hidden_ids - Add generate_arithmetic_sequence_from_scan Build analysis shows instantiation count dropped from 388 to 32 (92% reduction).
This commit is contained in:
@@ -36,11 +36,9 @@ struct TensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
|
||||
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
|
||||
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
@@ -311,6 +309,45 @@ struct lambda_get_up_dim_num
|
||||
}
|
||||
};
|
||||
|
||||
// Functor to convert a single visible dimension id to hidden id
|
||||
// Replaces inner lambda in transform_tensor_descriptor
|
||||
// Note: transform_sequences passes index_t values, not Number<> types
|
||||
template <typename OldTensorDescriptor>
|
||||
struct convert_visible_to_hidden_id
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
|
||||
{
|
||||
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
|
||||
}
|
||||
};
|
||||
|
||||
// Functor to convert a sequence of visible dimension ids to hidden ids
|
||||
// Replaces outer lambda in transform_tensor_descriptor
|
||||
template <typename OldTensorDescriptor>
|
||||
struct convert_visible_ids_to_hidden_ids
|
||||
{
|
||||
template <typename LowDimVisibleIds>
|
||||
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
|
||||
{
|
||||
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
|
||||
low_dim_visible_ids);
|
||||
}
|
||||
};
|
||||
|
||||
// Functor to generate arithmetic sequences from scan results
|
||||
// Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss
|
||||
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
|
||||
struct generate_arithmetic_sequence_from_scan
|
||||
{
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr auto operator()(I) const
|
||||
{
|
||||
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
|
||||
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
|
||||
return typename arithmetic_sequence_gen<start, end, 1>::type{};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OldTensorDescriptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
@@ -327,11 +364,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
|
||||
"wrong! inconsitent number of transform");
|
||||
|
||||
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
constexpr auto all_old_top_ids =
|
||||
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
constexpr auto all_new_top_ids =
|
||||
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
@@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
// lower dimension's hidden idss
|
||||
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
|
||||
// sequences)
|
||||
constexpr auto low_dim_hidden_idss = transform_tuples(
|
||||
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
|
||||
[](auto low_dim_visible_ids) constexpr {
|
||||
return transform_sequences(
|
||||
// convert lower dimension visible id to hidden id
|
||||
[](auto low_dim_visible_id) constexpr {
|
||||
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
|
||||
},
|
||||
low_dim_visible_ids);
|
||||
},
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
constexpr auto low_dim_hidden_idss =
|
||||
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr index_t num_new_transform = NewTransforms::Size();
|
||||
|
||||
@@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
|
||||
|
||||
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
1>::type{};
|
||||
},
|
||||
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
|
||||
Number<num_new_transform>{});
|
||||
|
||||
// new visible dimension's hidden ids
|
||||
constexpr auto unordered_new_visible_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
unpack_and_merge_sequences(up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_visible_dim_unordered2ordered =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
constexpr auto new_visible_dim_hidden_ids =
|
||||
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
|
||||
|
||||
@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
|
||||
},
|
||||
Number<num_dim>{});
|
||||
|
||||
// lower dimension Id
|
||||
const auto lower_dimss =
|
||||
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
|
||||
|
||||
// upper dimension Id
|
||||
// lower/upper dimension Ids
|
||||
const auto lower_dimss = generate_identity_sequences<num_dim>();
|
||||
const auto upper_dimss = lower_dimss;
|
||||
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
|
||||
|
||||
@@ -34,4 +34,22 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
// Functor for merge_sequences to avoid lambda instantiation overhead
|
||||
struct merge_sequences_functor
|
||||
{
|
||||
template <typename... Seqs>
|
||||
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
|
||||
{
|
||||
return merge_sequences(seqs...);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to unpack a tuple of sequences and merge them
|
||||
// Replaces: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences)
|
||||
template <typename TupleOfSequences>
|
||||
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences)
|
||||
{
|
||||
return unpack(merge_sequences_functor{}, TupleOfSequences{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user