mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] Optimize static_ford and sequence compile-time infrastructure (#5938)
## Problem Each `static_for<0, N, 1>` instantiates its lambda N times (one per `number<I>` type). When nested, intermediate lambdas capture the outer loop variable (a different type per iteration), creating unique closure types. For a 3-level nest with M=4, N=4, K=2, this produces 4 + 16 + 32 = 52 IR functions, of which 20 are intermediate closures that get inlined away but still cost frontend compile time. ck_tile's `static_ford` was supposed to eliminate these intermediates (as old CK's PR #5031 did successfully), but it used a **recursive** `static_ford_impl` that recreated the same closure pattern plus added `reorder_old_to_new`/`reorder_new_to_old` overhead. Additionally, the sequence utility layer (`sequence_sort`, `is_valid_sequence_map`) used recursive template metaprogramming that generated O(N log N) intermediate types for every permutation validation — called on every `reorder_new_to_old`/`reorder_old_to_new` invocation. ## Changes ### 1. Replace `sequence_sort` with constexpr insertion sort Replace recursive merge sort (`sequence_sort_impl` + `sorted_sequence_merge_impl`, O(N log N) intermediate type instantiations) with constexpr insertion sort using `static_array`. O(1) template depth, same `::type` and `::sorted2unsorted_map` API. ### 2. Replace `is_valid_sequence_map` with constexpr check Replace sort-based permutation validation (which instantiated the full `sequence_sort` chain) with a constexpr "seen array" loop. O(N) constexpr steps instead of O(N log N) template instantiations. ### 3. Replace recursive `static_ford` with flat-loop `index_decomposer` Replace `static_ford_impl` (recursive `static_for` nesting + `pop_front`/`push_back` + `reorder_old_to_new` per iteration) with flat `index_decomposer` using pre-computed strides. Add `decompose_reordered` alias that folds reordering into decomposition, and `inverse_perm` helper that avoids the `sequence_map_inverse` → `is_valid_sequence_map` → `sequence_sort` chain. ### 4. Eliminate internal lambda via `ford_applier` The flat-loop approach still used `static_for` with a lambda, creating M×N internal lambda instantiations per call site. Replace with `ford_applier` struct that calls `f(decompose<I>{})` directly via fold expression — zero intermediate closures: ```cpp // Before: 2×M×N function instantiations static_for<0, M*N, 1>{}([&](auto i) { f(decompose<i>{}); }); // After: M×N function instantiations (50% reduction) ford_applier<Decomposer, make_index_sequence<M*N>>{}(f); ``` Also unified identity and non-identity order paths into a single template with `constexpr if`. ### 5. Fix const-qualified sequence handling Fix `is_valid_sequence_map` to handle const-qualified sequence types via `remove_cvref_t` in callers (`tensor_adaptor.hpp`, `tile_distribution_encoding.hpp`). ## Results (this PR only, without flattening) ### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942, load ~5) | Target | Base (s) | Treat (s) | Delta | % | Wins | Significant? | |--------|----------|-----------|-------|---|------|-------------| | **flatmm** | 160.1 | 152.7 | **-7.4s** | **-4.6%** | 6/7 | **YES** (W+=1, p<0.05) | | universal_gemm | 228.4 | 224.7 | -3.7s | -1.6% | 6/7 | Trending (W+=4) | Per-trial diffs (flatmm): [-6, -20, -9, -8, -8, 4, -5] Per-trial diffs (universal_gemm): [-2, -6, 4, -3, -2, -11, -6] ### IR Function Counts (device trace, gfx942) | Target | Metric | Before | After | Delta | % | |--------|--------|--------|-------|-------|---| | **universal_gemm** | InstantiateFunction | 117,715 | 109,165 | **-8,550** | **-7.3%** | | **universal_gemm** | CodeGen Function | 47,912 | 45,044 | **-2,868** | **-6.0%** | | **flatmm** | InstantiateFunction | 100,939 | 95,127 | **-5,812** | **-5.8%** | | **flatmm** | CodeGen Function | 42,651 | 40,367 | **-2,284** | **-5.4%** | Note: The `ford_applier` (commit 3) has minimal additional effect in this PR since ck_tile code does not yet use `static_ford` extensively. Its impact compounds when the follow-up flattening PR #5939 converts 124 `static_for` nests to `static_ford`. Combined results with #5939: flatmm **-7.5%** wall time (p<0.01), CodeGen **-10.5%**. ### ASM Equivalence 7/7 PASS — 979,943 lines of device assembly verified identical (gfx942 + gfx1100). TUs: universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale. ## Test plan - [x] `test_ck_tile_static_ford`: 13 behavioral tests (identity/non-identity orders, 1D-4D, unit dimensions, edge cases) - [x] `ck_tile_unit_sequence`: 88 tests (11 new for sorted2unsorted_map, is_valid_sequence_map edge cases, sequence_unique_sort map round-trip) - [x] ASM equivalence verified (980K lines) - [x] Wilcoxon timing verified (7 trials, flatmm p<0.05) - [ ] CI 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
d80fa8831f
commit
522902f29b
@@ -39,7 +39,7 @@ CK_TILE_HOST_DEVICE constexpr auto
|
||||
container_reorder_given_new2old(const array<TData, NSize>& old_array, sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
return make_array<remove_cvref_t<TData>>(old_array[IRs]...);
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple<T
|
||||
{
|
||||
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return make_tuple(old_tuple[number<IRs>{}]...);
|
||||
}
|
||||
@@ -109,7 +109,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence<Is..
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return sequence<sequence<Is...>::at(number<IRs>{})...>{};
|
||||
}
|
||||
@@ -120,7 +120,7 @@ CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence<Is..
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
constexpr auto new2old = typename sequence_map_inverse<sequence<IRs...>>::type{};
|
||||
|
||||
|
||||
@@ -144,9 +144,11 @@ struct sequence
|
||||
static_assert(MapOld2New::size() == size(),
|
||||
"wrong! reorder map should have the same size as sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
|
||||
static_assert(is_valid_sequence_map<remove_cvref_t<MapOld2New>>::value,
|
||||
"wrong! invalid reorder map");
|
||||
|
||||
return reorder_new_to_old(typename sequence_map_inverse<MapOld2New>::type{});
|
||||
return reorder_new_to_old(
|
||||
typename sequence_map_inverse<remove_cvref_t<MapOld2New>>::type{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto reverse()
|
||||
@@ -548,163 +550,59 @@ struct sequence_reduce<Reduce, Seq>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
// Sorts a sequence using constexpr insertion sort. O(1) template instantiation
|
||||
// depth, replacing the recursive merge sort that created O(N log N) intermediate types.
|
||||
namespace detail {
|
||||
|
||||
template <typename Values, typename Compare, typename IndexSeq>
|
||||
struct sequence_sort_helper;
|
||||
|
||||
template <index_t... Vs, typename Compare, index_t... Idx>
|
||||
struct sequence_sort_helper<sequence<Vs...>, Compare, sequence<Idx...>>
|
||||
{
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
struct sort_result
|
||||
{
|
||||
static constexpr bool choose_left = LeftValues::front() < RightValues::front();
|
||||
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::front() : RightValues::front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::front() : RightIds::front();
|
||||
|
||||
using new_merged_values = decltype(MergedValues::push_back(number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::push_back(number<chosen_id>{}));
|
||||
|
||||
using new_left_values = typename std::
|
||||
conditional<choose_left, decltype(LeftValues::pop_front()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename std::conditional<choose_left, decltype(LeftIds::pop_front()), LeftIds>::type;
|
||||
|
||||
using new_right_values = typename std::
|
||||
conditional<choose_left, RightValues, decltype(RightValues::pop_front())>::type;
|
||||
using new_right_ids =
|
||||
typename std::conditional<choose_left, RightIds, decltype(RightIds::pop_front())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
static_array<index_t, sizeof...(Vs)> values;
|
||||
static_array<index_t, sizeof...(Vs)> ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
sequence<>,
|
||||
sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
static constexpr sort_result compute()
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
constexpr index_t n = sizeof...(Vs);
|
||||
sort_result r{{{Vs...}}, {{Idx...}}};
|
||||
// insertion sort — O(N^2) constexpr steps, O(1) template depth
|
||||
for(index_t i = 1; i < n; ++i)
|
||||
{
|
||||
for(index_t j = i; j > 0 && Compare{}(r.values[j], r.values[j - 1]); --j)
|
||||
{
|
||||
auto tv = r.values[j];
|
||||
r.values[j] = r.values[j - 1];
|
||||
r.values[j - 1] = tv;
|
||||
auto ti = r.ids[j];
|
||||
r.ids[j] = r.ids[j - 1];
|
||||
r.ids[j - 1] = ti;
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<sequence<>,
|
||||
sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
sequence<>,
|
||||
sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
static constexpr sort_result sorted = compute();
|
||||
using sorted_values = sequence<sorted.values[Idx]...>;
|
||||
using sorted_ids = sequence<sorted.ids[Idx]...>;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<sequence<ValueX, ValueY>, sequence<IdX, IdY>, Compare>
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
|
||||
using sorted_values = typename std::
|
||||
conditional<choose_x, sequence<ValueX, ValueY>, sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids =
|
||||
typename std::conditional<choose_x, sequence<IdX, IdY>, sequence<IdY, IdX>>::type;
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<sequence<Value>, sequence<Id>, Compare>
|
||||
{
|
||||
using sorted_values = sequence<Value>;
|
||||
using sorted_ids = sequence<Id>;
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<sequence<>, sequence<>, Compare>
|
||||
{
|
||||
using sorted_values = sequence<>;
|
||||
using sorted_ids = sequence<>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
static constexpr index_t n = Values::size();
|
||||
using idx_seq = make_index_sequence<n>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
using helper = detail::sequence_sort_helper<remove_cvref_t<Values>, Compare, idx_seq>;
|
||||
|
||||
using type = typename helper::sorted_values;
|
||||
using sorted2unsorted_map = typename helper::sorted_ids;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
@@ -782,10 +680,42 @@ struct sequence_unique_sort
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
// Validates that a sequence is a permutation of {0, 1, ..., N-1}.
|
||||
// Uses a constexpr loop instead of instantiating sequence_sort.
|
||||
namespace detail {
|
||||
|
||||
template <index_t... Is>
|
||||
constexpr bool check_valid_sequence_map()
|
||||
{
|
||||
constexpr index_t n = sizeof...(Is);
|
||||
if constexpr(n == 0)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t vals[] = {Is...};
|
||||
static_array<bool, n> seen{};
|
||||
for(index_t i = 0; i < n; ++i)
|
||||
{
|
||||
if(vals[i] < 0 || vals[i] >= n || seen[vals[i]])
|
||||
return false;
|
||||
seen[vals[i]] = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename SeqMap>
|
||||
struct is_valid_sequence_map
|
||||
: std::is_same<typename arithmetic_sequence_gen<0, SeqMap::size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, less<index_t>>::type>
|
||||
struct is_valid_sequence_map : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
struct is_valid_sequence_map<sequence<Is...>>
|
||||
: std::integral_constant<bool, detail::check_valid_sequence_map<Is...>()>
|
||||
{
|
||||
};
|
||||
|
||||
|
||||
@@ -376,9 +376,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transf
|
||||
constexpr auto all_up_dim_new_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
static_assert(
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_low_dim_old_top_ids)>>::value &&
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_up_dim_new_top_ids)>>::value,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
|
||||
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
|
||||
@@ -443,8 +444,8 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
static_assert(is_valid_sequence_map<remove_cvref_t<decltype(all_old_top_ids)>>::value &&
|
||||
is_valid_sequence_map<remove_cvref_t<decltype(all_new_top_ids)>>::value,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
|
||||
@@ -135,65 +135,147 @@ struct idx_identity
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: sequence<...>
|
||||
// Orders: sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
// Computes the inverse of a permutation as a constexpr array.
|
||||
// Avoids the sequence_map_inverse -> is_valid_sequence_map -> sequence_sort chain.
|
||||
template <class Perm>
|
||||
struct inverse_perm;
|
||||
|
||||
template <index_t... Ps>
|
||||
struct inverse_perm<sequence<Ps...>>
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford_impl()
|
||||
static constexpr auto compute()
|
||||
{
|
||||
static_assert(RemainLengths::size() > 0, "wrong! should not get here");
|
||||
constexpr index_t n = sizeof...(Ps);
|
||||
static_array<index_t, n> result{};
|
||||
constexpr index_t input[] = {Ps...};
|
||||
for(index_t i = 0; i < n; ++i)
|
||||
{
|
||||
result[input[i]] = i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
static constexpr auto value = compute();
|
||||
};
|
||||
|
||||
// Decomposes a linear index into multi-dimensional indices using pre-computed
|
||||
// strides. Uses a single flat static_for instead of recursive nesting, which
|
||||
// eliminates intermediate lambda closure instantiations.
|
||||
template <class OrderedLengths, class IndexSeq>
|
||||
struct index_decomposer;
|
||||
|
||||
template <index_t... Ls, index_t... Is>
|
||||
struct index_decomposer<sequence<Ls...>, sequence<Is...>>
|
||||
{
|
||||
static constexpr index_t n_dim = sizeof...(Ls);
|
||||
static constexpr static_array<index_t, n_dim> lengths = {{Ls...}};
|
||||
|
||||
static constexpr static_array<index_t, n_dim> compute_all_strides()
|
||||
{
|
||||
static_array<index_t, n_dim> result{};
|
||||
if constexpr(n_dim > 0)
|
||||
{
|
||||
result[n_dim - 1] = 1;
|
||||
for(index_t i = n_dim - 1; i > 0; --i)
|
||||
{
|
||||
result[i - 1] = result[i] * lengths[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...>)
|
||||
// CurrentOrderedId: sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
|
||||
static constexpr static_array<index_t, n_dim> strides = compute_all_strides();
|
||||
|
||||
// Compile-time decomposition: linear index -> sequence of per-dimension indices
|
||||
template <index_t LinearIdx>
|
||||
using decompose = sequence<((LinearIdx / strides[Is]) % lengths[Is])...>;
|
||||
|
||||
// Decompose AND reorder in one step using a pre-computed inverse permutation.
|
||||
// Produces the unordered multi-index directly, avoiding per-iteration
|
||||
// reorder_old_to_new member function instantiations on each unique sequence type.
|
||||
template <index_t LinearIdx, class New2Old>
|
||||
using decompose_reordered = sequence<((LinearIdx / strides[inverse_perm<New2Old>::value[Is]]) %
|
||||
lengths[inverse_perm<New2Old>::value[Is]])...>;
|
||||
};
|
||||
|
||||
// Calls f(decompose<I>{}) for each linear index I in the pack, using a single
|
||||
// fold expression. Bypasses the static_for lambda entirely, eliminating M*N
|
||||
// intermediate lambda closure instantiations that the lambda-based approach creates.
|
||||
template <class Decomposer, class LinearIdxSeq>
|
||||
struct ford_applier;
|
||||
|
||||
template <class Decomposer, index_t... LinearIds>
|
||||
struct ford_applier<Decomposer, sequence<LinearIds...>>
|
||||
{
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
|
||||
f, CurrentOrderedId::push_back(I));
|
||||
});
|
||||
if constexpr(sizeof...(LinearIds) > 0)
|
||||
{
|
||||
(f(typename Decomposer::template decompose<LinearIds>{}), ...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<sequence<>, Orders>
|
||||
// Same as ford_applier but applies reordering during decomposition.
|
||||
template <class Decomposer, class New2Old, class LinearIdxSeq>
|
||||
struct ford_applier_reordered;
|
||||
|
||||
template <class Decomposer, class New2Old, index_t... LinearIds>
|
||||
struct ford_applier_reordered<Decomposer, New2Old, sequence<LinearIds...>>
|
||||
{
|
||||
// F signature: F(sequence<...>)
|
||||
// OrderedId: sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::reorder_old_to_new(Orders{}));
|
||||
if constexpr(sizeof...(LinearIds) > 0)
|
||||
{
|
||||
(f(typename Decomposer::template decompose_reordered<LinearIds, New2Old>{}), ...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
// Compile-time N-dimensional loop with static multi-indices.
|
||||
// Uses direct fold expansion with index decomposition, producing zero
|
||||
// intermediate lambda closures. Each iteration calls f with a compile-time
|
||||
// sequence<i0, i1, ...> containing the multi-dimensional index.
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
static constexpr index_t n_dim = Lengths::size();
|
||||
static constexpr index_t total_size =
|
||||
reduce_on_sequence(Lengths{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr bool is_identity_order = std::is_same_v<Orders, make_index_sequence<n_dim>>;
|
||||
|
||||
// For identity order, OrderedLengths == Lengths (no reorder needed).
|
||||
// For non-identity, reorder lengths according to iteration order.
|
||||
// Both branches must be valid types, but only the active one is used.
|
||||
using OrderedLengths =
|
||||
std::conditional_t<is_identity_order,
|
||||
Lengths,
|
||||
remove_cvref_t<decltype(Lengths::reorder_new_to_old(Orders{}))>>;
|
||||
using Decomposer = detail::index_decomposer<OrderedLengths, make_index_sequence<n_dim>>;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
|
||||
if constexpr(is_identity_order)
|
||||
{
|
||||
detail::ford_applier<Decomposer, make_index_sequence<total_size>>{}(f);
|
||||
}
|
||||
else
|
||||
{
|
||||
detail::ford_applier_reordered<Decomposer, Orders, make_index_sequence<total_size>>{}(
|
||||
f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -355,6 +355,102 @@ TEST(SequenceSort, SortSingleElement)
|
||||
EXPECT_TRUE((std::is_same<Result, Expected>::value));
|
||||
}
|
||||
|
||||
// Test sequence_sort sorted2unsorted_map (index tracking)
|
||||
TEST(SequenceSort, SortedMapUnsorted)
|
||||
{
|
||||
using Seq = sequence<5, 2, 8, 1, 9>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// sorted = <1,2,5,8,9>, original indices = <3,1,0,2,4>
|
||||
using Expected = sequence<3, 1, 0, 2, 4>;
|
||||
EXPECT_TRUE((std::is_same<Map, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapAlreadySorted)
|
||||
{
|
||||
using Seq = sequence<1, 2, 3, 4, 5>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// Already sorted: map should be identity
|
||||
using Expected = sequence<0, 1, 2, 3, 4>;
|
||||
EXPECT_TRUE((std::is_same<Map, Expected>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapRoundTrip)
|
||||
{
|
||||
// Verify: sorted_values[i] == original[sorted2unsorted_map[i]]
|
||||
using Seq = sequence<5, 2, 8, 1, 9>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
// sorted = <1,2,5,8,9>, map = <3,1,0,2,4>
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(0)), Sort::type::at(0));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(1)), Sort::type::at(1));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(2)), Sort::type::at(2));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(3)), Sort::type::at(3));
|
||||
EXPECT_EQ(Seq::at(Sort::sorted2unsorted_map::at(4)), Sort::type::at(4));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapWithDuplicates)
|
||||
{
|
||||
using Seq = sequence<3, 1, 3, 1>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Sorted = typename Sort::type;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
// sorted = <1,1,3,3>
|
||||
using ExpectedSorted = sequence<1, 1, 3, 3>;
|
||||
EXPECT_TRUE((std::is_same<Sorted, ExpectedSorted>::value));
|
||||
// Verify round-trip: original[map[i]] == sorted[i] for all i
|
||||
// (don't assert specific index order for duplicates — sort stability may vary)
|
||||
EXPECT_EQ(Seq::at(Map::at(0)), Sorted::at(0));
|
||||
EXPECT_EQ(Seq::at(Map::at(1)), Sorted::at(1));
|
||||
EXPECT_EQ(Seq::at(Map::at(2)), Sorted::at(2));
|
||||
EXPECT_EQ(Seq::at(Map::at(3)), Sorted::at(3));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapReverseSorted)
|
||||
{
|
||||
using Seq = sequence<5, 4, 3, 2, 1>;
|
||||
using Sort = sequence_sort<Seq, less<index_t>>;
|
||||
using Sorted = typename Sort::type;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
using ExpSorted = sequence<1, 2, 3, 4, 5>;
|
||||
using ExpMap = sequence<4, 3, 2, 1, 0>;
|
||||
EXPECT_TRUE((std::is_same<Sorted, ExpSorted>::value));
|
||||
EXPECT_TRUE((std::is_same<Map, ExpMap>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapEmpty)
|
||||
{
|
||||
using Sort = sequence_sort<sequence<>, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
EXPECT_TRUE((std::is_same<Map, sequence<>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceSort, SortedMapSingleElement)
|
||||
{
|
||||
using Sort = sequence_sort<sequence<42>, less<index_t>>;
|
||||
using Map = typename Sort::sorted2unsorted_map;
|
||||
EXPECT_TRUE((std::is_same<Map, sequence<0>>::value));
|
||||
}
|
||||
|
||||
// Test sequence_unique_sort sorted2unsorted_map
|
||||
TEST(SequenceUniqueSort, UniqueSortMap)
|
||||
{
|
||||
using Seq = sequence<3, 1, 4, 1, 5, 9, 2, 6, 5>;
|
||||
using Result = sequence_unique_sort<Seq, less<index_t>, equal<index_t>>;
|
||||
using Map = typename Result::sorted2unsorted_map;
|
||||
// sorted unique = <1,2,3,4,5,6,9>
|
||||
// The map should reference the first occurrence of each unique value in the original
|
||||
// Verify round-trip: for each i, original[map[i]] == sorted_unique[i]
|
||||
using Values = typename Result::type;
|
||||
EXPECT_EQ(Seq::at(Map::at(0)), Values::at(0)); // 1
|
||||
EXPECT_EQ(Seq::at(Map::at(1)), Values::at(1)); // 2
|
||||
EXPECT_EQ(Seq::at(Map::at(2)), Values::at(2)); // 3
|
||||
EXPECT_EQ(Seq::at(Map::at(3)), Values::at(3)); // 4
|
||||
EXPECT_EQ(Seq::at(Map::at(4)), Values::at(4)); // 5
|
||||
EXPECT_EQ(Seq::at(Map::at(5)), Values::at(5)); // 6
|
||||
EXPECT_EQ(Seq::at(Map::at(6)), Values::at(6)); // 9
|
||||
}
|
||||
|
||||
// Test sequence_unique_sort
|
||||
TEST(SequenceUniqueSort, UniqueSort)
|
||||
{
|
||||
@@ -405,6 +501,24 @@ TEST(SequenceMap, InvalidMapMissing)
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapNegative)
|
||||
{
|
||||
using Map = sequence<0, -1, 2>;
|
||||
EXPECT_FALSE((is_valid_sequence_map<Map>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, ValidMapSingleElement)
|
||||
{
|
||||
EXPECT_TRUE((is_valid_sequence_map<sequence<0>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, InvalidMapSingleElement)
|
||||
{
|
||||
EXPECT_FALSE((is_valid_sequence_map<sequence<1>>::value));
|
||||
}
|
||||
|
||||
TEST(SequenceMap, ValidMapEmpty) { EXPECT_TRUE((is_valid_sequence_map<sequence<>>::value)); }
|
||||
|
||||
// Test sequence_map_inverse
|
||||
// Note: sequence_map_inverse inverts a mapping where Map[i] = j means old position i maps to new
|
||||
// position j The inverse gives us new position i came from old position inverse[i]
|
||||
|
||||
@@ -5,6 +5,7 @@ message("-- Adding: test/ck_tile/utility/")
|
||||
|
||||
add_gtest_executable(test_fill test_fill.cpp)
|
||||
add_gtest_executable(test_ck_tile_sequence test_sequence.cpp)
|
||||
add_gtest_executable(test_ck_tile_static_ford test_static_ford.cpp)
|
||||
|
||||
# Add print tests
|
||||
add_subdirectory(print)
|
||||
|
||||
293
test/ck_tile/utility/test_static_ford.cpp
Normal file
293
test/ck_tile/utility/test_static_ford.cpp
Normal file
@@ -0,0 +1,293 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// ============================================================================
|
||||
// static_ford Tests — Identity Order (default)
|
||||
// ============================================================================
|
||||
|
||||
TEST(CkTileStaticFord, Identity2D)
|
||||
{
|
||||
std::vector<std::pair<index_t, index_t>> visited;
|
||||
|
||||
static_ford<sequence<2, 3>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 6u);
|
||||
EXPECT_EQ(visited[0], std::make_pair(0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_pair(0, 1));
|
||||
EXPECT_EQ(visited[2], std::make_pair(0, 2));
|
||||
EXPECT_EQ(visited[3], std::make_pair(1, 0));
|
||||
EXPECT_EQ(visited[4], std::make_pair(1, 1));
|
||||
EXPECT_EQ(visited[5], std::make_pair(1, 2));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, Identity3D)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t>> visited;
|
||||
|
||||
static_ford<sequence<2, 3, 2>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
constexpr index_t k = multi_id[number<2>{}];
|
||||
visited.emplace_back(i, j, k);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 12u);
|
||||
EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1));
|
||||
EXPECT_EQ(visited[2], std::make_tuple(0, 1, 0));
|
||||
EXPECT_EQ(visited[3], std::make_tuple(0, 1, 1));
|
||||
EXPECT_EQ(visited[4], std::make_tuple(0, 2, 0));
|
||||
EXPECT_EQ(visited[5], std::make_tuple(0, 2, 1));
|
||||
EXPECT_EQ(visited[6], std::make_tuple(1, 0, 0));
|
||||
EXPECT_EQ(visited[7], std::make_tuple(1, 0, 1));
|
||||
EXPECT_EQ(visited[8], std::make_tuple(1, 1, 0));
|
||||
EXPECT_EQ(visited[9], std::make_tuple(1, 1, 1));
|
||||
EXPECT_EQ(visited[10], std::make_tuple(1, 2, 0));
|
||||
EXPECT_EQ(visited[11], std::make_tuple(1, 2, 1));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, Identity1D)
|
||||
{
|
||||
std::vector<index_t> visited;
|
||||
|
||||
static_ford<sequence<5>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
visited.push_back(i);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 5u);
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
EXPECT_EQ(visited[i], i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, SingleElement1D)
|
||||
{
|
||||
std::vector<index_t> visited;
|
||||
|
||||
static_ford<sequence<1>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
visited.push_back(i);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 1u);
|
||||
EXPECT_EQ(visited[0], 0);
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, SingleElement2D)
|
||||
{
|
||||
std::vector<std::pair<index_t, index_t>> visited;
|
||||
|
||||
static_ford<sequence<1, 1>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 1u);
|
||||
EXPECT_EQ(visited[0], std::make_pair(0, 0));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, IdentityWithUnitDim)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t>> visited;
|
||||
|
||||
static_ford<sequence<2, 1, 3>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
constexpr index_t k = multi_id[number<2>{}];
|
||||
visited.emplace_back(i, j, k);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 6u);
|
||||
EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1));
|
||||
EXPECT_EQ(visited[2], std::make_tuple(0, 0, 2));
|
||||
EXPECT_EQ(visited[3], std::make_tuple(1, 0, 0));
|
||||
EXPECT_EQ(visited[4], std::make_tuple(1, 0, 1));
|
||||
EXPECT_EQ(visited[5], std::make_tuple(1, 0, 2));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// static_ford Tests — Non-Identity Order (primary template with decompose_reordered)
|
||||
// ============================================================================
|
||||
|
||||
TEST(CkTileStaticFord, ReversedOrder2D)
|
||||
{
|
||||
std::vector<std::pair<index_t, index_t>> visited;
|
||||
|
||||
// Order (1, 0): dim 1 is outer, dim 0 is inner (column-major)
|
||||
static_ford<sequence<2, 3>, sequence<1, 0>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 6u);
|
||||
EXPECT_EQ(visited[0], std::make_pair(0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_pair(1, 0));
|
||||
EXPECT_EQ(visited[2], std::make_pair(0, 1));
|
||||
EXPECT_EQ(visited[3], std::make_pair(1, 1));
|
||||
EXPECT_EQ(visited[4], std::make_pair(0, 2));
|
||||
EXPECT_EQ(visited[5], std::make_pair(1, 2));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, CustomOrder3D_201)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t>> visited;
|
||||
|
||||
// Orders<2,0,1>: dim 2 outermost, dim 0 middle, dim 1 innermost
|
||||
static_ford<sequence<2, 3, 4>, sequence<2, 0, 1>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
constexpr index_t k = multi_id[number<2>{}];
|
||||
visited.emplace_back(i, j, k);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 24u);
|
||||
// With orders (2,0,1): k varies slowest, then i, then j fastest
|
||||
EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_tuple(0, 1, 0));
|
||||
EXPECT_EQ(visited[2], std::make_tuple(0, 2, 0));
|
||||
EXPECT_EQ(visited[3], std::make_tuple(1, 0, 0));
|
||||
EXPECT_EQ(visited[4], std::make_tuple(1, 1, 0));
|
||||
EXPECT_EQ(visited[5], std::make_tuple(1, 2, 0));
|
||||
EXPECT_EQ(visited[6], std::make_tuple(0, 0, 1));
|
||||
EXPECT_EQ(visited[7], std::make_tuple(0, 1, 1));
|
||||
// Tail: last element should be (1, 2, 3)
|
||||
EXPECT_EQ(visited[23], std::make_tuple(1, 2, 3));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, CustomOrder3D_120)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t>> visited;
|
||||
|
||||
// Orders<1,2,0>: dim 1 outermost, dim 2 middle, dim 0 innermost
|
||||
static_ford<sequence<2, 3, 2>, sequence<1, 2, 0>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
constexpr index_t k = multi_id[number<2>{}];
|
||||
visited.emplace_back(i, j, k);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 12u);
|
||||
// With orders (1,2,0): j varies slowest, then k, then i fastest
|
||||
EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_tuple(1, 0, 0));
|
||||
EXPECT_EQ(visited[2], std::make_tuple(0, 0, 1));
|
||||
EXPECT_EQ(visited[3], std::make_tuple(1, 0, 1));
|
||||
EXPECT_EQ(visited[4], std::make_tuple(0, 1, 0));
|
||||
EXPECT_EQ(visited[5], std::make_tuple(1, 1, 0));
|
||||
// Tail: last element should be (1, 2, 1)
|
||||
EXPECT_EQ(visited[11], std::make_tuple(1, 2, 1));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, NonIdentityWithUnitDim)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t>> visited;
|
||||
|
||||
// Unit dim at position 1 with non-trivial order
|
||||
static_ford<sequence<2, 1, 3>, sequence<2, 0, 1>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
constexpr index_t k = multi_id[number<2>{}];
|
||||
visited.emplace_back(i, j, k);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 6u);
|
||||
// All entries must have j == 0 (unit dimension)
|
||||
for(size_t idx = 0; idx < visited.size(); ++idx)
|
||||
{
|
||||
EXPECT_EQ(std::get<1>(visited[idx]), 0) << "Unit dim not zero at iteration " << idx;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, CustomOrder4D)
|
||||
{
|
||||
std::vector<std::tuple<index_t, index_t, index_t, index_t>> visited;
|
||||
|
||||
// 4D with order <3,1,0,2>
|
||||
static_ford<sequence<2, 3, 2, 4>, sequence<3, 1, 0, 2>>{}([&](auto multi_id) {
|
||||
constexpr index_t a = multi_id[number<0>{}];
|
||||
constexpr index_t b = multi_id[number<1>{}];
|
||||
constexpr index_t c = multi_id[number<2>{}];
|
||||
constexpr index_t d = multi_id[number<3>{}];
|
||||
visited.emplace_back(a, b, c, d);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 48u);
|
||||
// dim 3 (size 4) outermost, dim 1 (size 3) next, dim 0 (size 2) next, dim 2 (size 2) inner
|
||||
EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1, 0));
|
||||
EXPECT_EQ(visited[2], std::make_tuple(1, 0, 0, 0));
|
||||
EXPECT_EQ(visited[3], std::make_tuple(1, 0, 1, 0));
|
||||
EXPECT_EQ(visited[4], std::make_tuple(0, 1, 0, 0));
|
||||
EXPECT_EQ(visited[5], std::make_tuple(0, 1, 1, 0));
|
||||
}
|
||||
|
||||
TEST(CkTileStaticFord, AsymmetricDimsWithOrder)
|
||||
{
|
||||
std::vector<std::pair<index_t, index_t>> visited;
|
||||
|
||||
// Asymmetric: 3x5 with reversed order
|
||||
static_ford<sequence<3, 5>, sequence<1, 0>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
ASSERT_EQ(visited.size(), 15u);
|
||||
// dim 1 (size 5) outer, dim 0 (size 3) inner
|
||||
EXPECT_EQ(visited[0], std::make_pair(0, 0));
|
||||
EXPECT_EQ(visited[1], std::make_pair(1, 0));
|
||||
EXPECT_EQ(visited[2], std::make_pair(2, 0));
|
||||
EXPECT_EQ(visited[3], std::make_pair(0, 1));
|
||||
EXPECT_EQ(visited[4], std::make_pair(1, 1));
|
||||
EXPECT_EQ(visited[5], std::make_pair(2, 1));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Consistency: identity order matches explicit identity order
|
||||
// ============================================================================
|
||||
|
||||
TEST(CkTileStaticFord, IdentityOrderMatchesExplicit)
|
||||
{
|
||||
std::vector<std::pair<index_t, index_t>> default_visited;
|
||||
std::vector<std::pair<index_t, index_t>> explicit_visited;
|
||||
|
||||
static_ford<sequence<3, 4>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
default_visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
static_ford<sequence<3, 4>, sequence<0, 1>>{}([&](auto multi_id) {
|
||||
constexpr index_t i = multi_id[number<0>{}];
|
||||
constexpr index_t j = multi_id[number<1>{}];
|
||||
explicit_visited.emplace_back(i, j);
|
||||
});
|
||||
|
||||
ASSERT_EQ(default_visited.size(), explicit_visited.size());
|
||||
for(size_t i = 0; i < default_visited.size(); ++i)
|
||||
{
|
||||
EXPECT_EQ(default_visited[i], explicit_visited[i]) << "Mismatch at iteration " << i;
|
||||
}
|
||||
}
|
||||
|
||||
// index_decomposer and inverse_perm are implementation details tested
|
||||
// indirectly through the static_ford behavioral tests above.
|
||||
// The IdentityOrderMatchesExplicit test verifies both code paths
|
||||
// (identity specialization and primary template) produce identical results.
|
||||
Reference in New Issue
Block a user