mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Apply same optimization pattern to TensorAdaptor
TensorAdaptor has the same InitializeElementSize and GetTransformAndItsUpperDimension patterns as TensorDescriptor, so apply the same optimization:
- Replace nested static_for lambdas with find_in_tuple_of_sequences
- Replace generate_tuple lambda with pack expansion

Results: generate_tuple lambdas 100 -> 96 (4 events, 17ms eliminated)
This commit is contained in:
@@ -45,28 +45,29 @@ struct TensorAdaptor
|
||||
return BottomDimensionHiddenIds{};
|
||||
}
|
||||
|
||||
// Helper to get length of a top dimension from transforms
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto
|
||||
GetTopDimLengthFromTransforms(const Transforms& transforms)
|
||||
{
|
||||
constexpr auto result = find_in_tuple_of_sequences<TopDimensionHiddenIds::At(Number<I>{})>(
|
||||
UpperDimensionHiddenIdss{});
|
||||
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
|
||||
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
|
||||
}
|
||||
|
||||
// Compute element size using pack expansion instead of generate_tuple with lambda
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
|
||||
Sequence<Is...>)
|
||||
{
|
||||
return (GetTopDimLengthFromTransforms<Is>(transforms) * ...);
|
||||
}
|
||||
|
||||
// Compute the element size of the adaptor from its transforms.
//
// Delegates to ComputeElementSizeImpl, passing one index per top dimension
// (arithmetic_sequence_gen<0, ndim_top_, 1> -- presumably the sequence 0 .. ndim_top_ - 1;
// confirm against its definition).
//
// The previous generate_tuple + container_reduce body has been removed: it preceded this
// delegation in the same function, leaving two return statements of which only the first
// could ever execute.
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
    using TopDimIds = typename arithmetic_sequence_gen<0, ndim_top_, 1>::type;

    return ComputeElementSizeImpl(transforms, TopDimIds{});
}
|
||||
|
||||
template <index_t IDim>
|
||||
@@ -76,24 +77,10 @@ struct TensorAdaptor
|
||||
|
||||
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
// Use compile-time search helper instead of nested static_for with lambdas
|
||||
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == idim_hidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
return make_tuple(result.itran, result.idim_up, result.found);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
|
||||
|
||||
@@ -64,43 +64,60 @@ struct FindTransformResult
|
||||
static constexpr bool found = Found;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper to search through a tuple of sequences for a target value
|
||||
// Returns FindTransformResult with (transform_index, index_within_sequence, found)
|
||||
template <index_t Target, index_t TranIdx, typename FirstSeq, typename... RestSeqs>
|
||||
__host__ __device__ constexpr auto find_in_tuple_of_sequences_impl()
|
||||
// O(1) template depth implementation using pack expansion
|
||||
// Avoids O(N) recursive template instantiations
|
||||
template <index_t Target, typename... Seqs>
|
||||
struct FindInTupleOfSequencesCompute
|
||||
{
|
||||
constexpr index_t idx = sequence_find_value<Target>(FirstSeq{});
|
||||
if constexpr(idx >= 0)
|
||||
private:
|
||||
// Result struct for constexpr computation
|
||||
struct ResultData
|
||||
{
|
||||
return FindTransformResult<TranIdx, idx, true>{};
|
||||
}
|
||||
else if constexpr(sizeof...(RestSeqs) > 0)
|
||||
{
|
||||
return find_in_tuple_of_sequences_impl<Target, TranIdx + 1, RestSeqs...>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return FindTransformResult<0, 0, false>{};
|
||||
}
|
||||
}
|
||||
index_t itran;
|
||||
index_t idim_up;
|
||||
bool found;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
// Compute result using constexpr function with array lookup
|
||||
static constexpr ResultData compute()
|
||||
{
|
||||
if constexpr(sizeof...(Seqs) == 0)
|
||||
{
|
||||
return {0, 0, false};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Pack expansion creates array - O(1) template depth
|
||||
constexpr index_t indices[] = {sequence_find_value<Target>(Seqs{})...};
|
||||
|
||||
// Find first matching sequence
|
||||
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Seqs)); ++i)
|
||||
{
|
||||
if(indices[i] >= 0)
|
||||
{
|
||||
return {i, indices[i], true};
|
||||
}
|
||||
}
|
||||
return {0, 0, false};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr ResultData result_ = compute();
|
||||
|
||||
public:
|
||||
static constexpr index_t itran = result_.itran;
|
||||
static constexpr index_t idim_up = result_.idim_up;
|
||||
static constexpr bool found = result_.found;
|
||||
|
||||
using type = FindTransformResult<itran, idim_up, found>;
|
||||
};
|
||||
|
||||
// Find target value in a tuple of sequences
|
||||
// Returns FindTransformResult<itran, idim_up, found>
|
||||
// This replaces nested static_for loops with O(1) template depth
|
||||
// Uses O(1) template depth via pack expansion (no recursion)
|
||||
template <index_t Target, typename... Seqs>
|
||||
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
|
||||
{
|
||||
if constexpr(sizeof...(Seqs) == 0)
|
||||
{
|
||||
return FindTransformResult<0, 0, false>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return detail::find_in_tuple_of_sequences_impl<Target, 0, Seqs...>();
|
||||
}
|
||||
return typename FindInTupleOfSequencesCompute<Target, Seqs...>::type{};
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user