Replace lambdas with named functors in transform_tensor_descriptor

Lambda expressions in transform_tensor_descriptor created unique template
instantiations for each capture combination. This change replaces lambdas
with named functor structs to reduce instantiation count:

- Add merge_sequences_functor and unpack_and_merge_sequences helper
- Add convert_visible_to_hidden_id and convert_visible_ids_to_hidden_ids
- Add generate_arithmetic_sequence_from_scan

Build analysis shows instantiation count dropped from 388 to 32 (92% reduction).
This commit is contained in:
Max Podkorytov
2026-01-16 00:10:23 -06:00
parent d7e7fbdcff
commit 00849ac2e2
3 changed files with 72 additions and 33 deletions

View File

@@ -36,11 +36,9 @@ struct TensorDescriptor
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
@@ -311,6 +309,45 @@ struct lambda_get_up_dim_num
}
};
// Functor to convert a single visible dimension id to hidden id
// Replaces inner lambda in transform_tensor_descriptor
// Note: transform_sequences passes index_t values, not Number<> types
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
{
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
}
};
// Functor to convert a sequence of visible dimension ids to hidden ids
// Replaces outer lambda in transform_tensor_descriptor
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
template <typename LowDimVisibleIds>
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
{
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
low_dim_visible_ids);
}
};
// Functor to generate arithmetic sequences from scan results
// Replaces lambda in transform_tensor_descriptor that generates up_dim_hidden_idss
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
return typename arithmetic_sequence_gen<start, end, 1>::type{};
}
};
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
@@ -327,11 +364,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -341,17 +378,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
@@ -364,22 +393,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});
// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);
constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

View File

@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
},
Number<num_dim>{});
// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
// upper dimension Id
// lower/upper dimension Ids
const auto lower_dimss = generate_identity_sequences<num_dim>();
const auto upper_dimss = lower_dimss;
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);

View File

@@ -34,4 +34,22 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}
// Functor for merge_sequences to avoid lambda instantiation overhead
struct merge_sequences_functor
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
// Helper to unpack a tuple of sequences and merge them
// Replaces: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences)
template <typename TupleOfSequences>
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences)
{
return unpack(merge_sequences_functor{}, TupleOfSequences{});
}
} // namespace ck