mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
[CK_TILE] Optimize static_ford and sequence compile-time infrastructure (#5938)

## Problem

Each `static_for<0, N, 1>` instantiates its lambda N times (one per `number<I>` type). When nested, intermediate lambdas capture the outer loop variable (a different type per iteration), creating unique closure types. For a 3-level nest with M=4, N=4, K=2, this produces 4 + 16 + 32 = 52 IR functions (M outer bodies, M×N middle bodies, M×N×K innermost bodies). Of these, the 4 + 16 = 20 outer and middle bodies are intermediate closures that get inlined away but still cost frontend compile time.

ck_tile's `static_ford` was supposed to eliminate these intermediates (as old CK's PR #5031 did successfully). However, it used a **recursive** `static_ford_impl` that recreated the same closure pattern and added `reorder_old_to_new`/`reorder_new_to_old` overhead.

Additionally, the sequence utility layer (`sequence_sort`, `is_valid_sequence_map`) used recursive template metaprogramming that generated O(N log N) intermediate types for every permutation validation. That validation runs on every `reorder_new_to_old`/`reorder_old_to_new` invocation.

## Changes

### 1. Replace `sequence_sort` with constexpr insertion sort

Replace the recursive merge sort (`sequence_sort_impl` + `sorted_sequence_merge_impl`, O(N log N) intermediate type instantiations) with a constexpr insertion sort using `static_array`. This gives O(1) template depth and keeps the same `::type` and `::sorted2unsorted_map` API.

### 2. Replace `is_valid_sequence_map` with constexpr check

Replace the sort-based permutation validation (which instantiated the full `sequence_sort` chain) with a constexpr "seen array" loop. This takes O(N) constexpr steps instead of O(N log N) template instantiations.

### 3. Replace recursive `static_ford` with flat-loop `index_decomposer`

Replace `static_ford_impl` (recursive `static_for` nesting + `pop_front`/`push_back` + `reorder_old_to_new` per iteration) with a flat `index_decomposer` using pre-computed strides.
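As a rough illustration of flat, stride-based decomposition, here is a constexpr C++17 sketch. It uses `std::array` in place of ck_tile's `sequence`/`number` types. `flat_index_decomposer` and its members are hypothetical names, not ck_tile's `index_decomposer` API.

`order[p]` names the dimension visited at nest depth `p`, outermost first. That matches how the `static_ford` tests in this file read their `Orders` argument. Each digit is written straight into `idx[order[p]]`, so the reordering is folded into the decomposition, which is roughly what the `decompose_reordered` alias is described as doing.

```cpp
#include <array>
#include <cstddef>

// Hypothetical sketch (not ck_tile's index_decomposer): decompose a flat
// iteration number into a multi-index using strides computed once up front.
//   lengths[d]: extent of dimension d
//   order[p]:   dimension visited at nest depth p (order[0] = outermost)
template <std::size_t N>
struct flat_index_decomposer
{
    std::array<int, N> lengths;
    std::array<int, N> order;

    // Stride of depth p = product of the extents of all deeper (inner) depths.
    constexpr std::array<int, N> strides() const
    {
        std::array<int, N> s{};
        int acc = 1;
        for(std::size_t p = N; p-- > 0;)
        {
            s[p] = acc;
            acc *= lengths[static_cast<std::size_t>(order[p])];
        }
        return s;
    }

    // Flat iteration number t -> multi-index in the ORIGINAL dimension order.
    // Writing into idx[order[p]] applies the reordering during decomposition,
    // so no separate reorder pass is needed afterwards.
    constexpr std::array<int, N> decompose(int t) const
    {
        const std::array<int, N> s = strides();
        std::array<int, N> idx{};
        for(std::size_t p = 0; p < N; ++p)
        {
            const std::size_t dim = static_cast<std::size_t>(order[p]);
            idx[dim]              = (t / s[p]) % lengths[dim];
        }
        return idx;
    }
};

// Same shape and order as the CustomOrder3D_201 test below:
// dim 2 outermost, then dim 0, then dim 1 innermost.
constexpr flat_index_decomposer<3> d201{{2, 3, 4}, {2, 0, 1}};
static_assert(d201.decompose(1)[1] == 1); // (0, 1, 0)
static_assert(d201.decompose(3)[0] == 1); // (1, 0, 0)
static_assert(d201.decompose(6)[2] == 1); // (0, 0, 1)
```

With `t` running over `0 .. 23`, this reproduces the visit order asserted in `CustomOrder3D_201`. For example, `t = 23` gives `(1, 2, 3)`. Each step is only a divide and a modulo by values known at compile time, with no nested loop per dimension.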
Add a `decompose_reordered` alias that folds reordering into decomposition, and an `inverse_perm` helper that avoids the `sequence_map_inverse` → `is_valid_sequence_map` → `sequence_sort` chain.

### 4. Eliminate internal lambda via `ford_applier`

The flat-loop approach still used `static_for` with a lambda, creating M×N internal lambda instantiations per call site. Replace it with a `ford_applier` struct that calls `f(decompose<I>{})` directly via a fold expression, with zero intermediate closures:

```cpp
// Before: 2×M×N function instantiations
static_for<0, M*N, 1>{}([&](auto i) { f(decompose<i>{}); });

// After: M×N function instantiations (50% reduction)
ford_applier<Decomposer, make_index_sequence<M*N>>{}(f);
```

Also unified the identity and non-identity order paths into a single template with `if constexpr`.

### 5. Fix const-qualified sequence handling

Fix `is_valid_sequence_map` to handle const-qualified sequence types via `remove_cvref_t` in callers (`tensor_adaptor.hpp`, `tile_distribution_encoding.hpp`).

## Results (this PR only, without flattening)

### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942, load ~5)

| Target | Base (s) | Treat (s) | Delta | % | Wins | Significant? |
|--------|----------|-----------|-------|---|------|-------------|
| **flatmm** | 160.1 | 152.7 | **-7.4s** | **-4.6%** | 6/7 | **YES** (W+=1, p<0.05) |
| universal_gemm | 228.4 | 224.7 | -3.7s | -1.6% | 6/7 | Trending (W+=4) |

Per-trial diffs (flatmm): [-6, -20, -9, -8, -8, 4, -5]

Per-trial diffs (universal_gemm): [-2, -6, 4, -3, -2, -11, -6]

### IR Function Counts (device trace, gfx942)

| Target | Metric | Before | After | Delta | % |
|--------|--------|--------|-------|-------|---|
| **universal_gemm** | InstantiateFunction | 117,715 | 109,165 | **-8,550** | **-7.3%** |
| **universal_gemm** | CodeGen Function | 47,912 | 45,044 | **-2,868** | **-6.0%** |
| **flatmm** | InstantiateFunction | 100,939 | 95,127 | **-5,812** | **-5.8%** |
| **flatmm** | CodeGen Function | 42,651 | 40,367 | **-2,284** | **-5.4%** |

Note: The `ford_applier` (commit 3) has minimal additional effect in this PR, since ck_tile code does not yet use `static_ford` extensively. Its impact compounds when the follow-up flattening PR #5939 converts 124 `static_for` nests to `static_ford`. Combined results with #5939: flatmm **-7.5%** wall time (p<0.01), CodeGen **-10.5%**.

### ASM Equivalence

7/7 PASS: 979,943 lines of device assembly verified identical (gfx942 + gfx1100). TUs: universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.

## Test plan

- [x] `test_ck_tile_static_ford`: 13 behavioral tests (identity/non-identity orders, 1D-4D, unit dimensions, edge cases)
- [x] `ck_tile_unit_sequence`: 88 tests (11 new for sorted2unsorted_map, is_valid_sequence_map edge cases, sequence_unique_sort map round-trip)
- [x] ASM equivalence verified (980K lines)
- [x] Wilcoxon timing verified (7 trials, flatmm p<0.05)
- [ ] CI

🤖 Generated with [Claude Code](https://claude.com/claude-code)
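The constexpr insertion sort from change 1 can be illustrated with a plain C++17 sketch. It uses `std::array` where ck_tile uses `static_array`, and `sort_result` and `insertion_sort_with_map` are hypothetical names. The `sorted2unsorted` array is assumed to play the role of `::sorted2unsorted_map`: it records which input index each sorted slot came from. The whole sort runs inside one constexpr function, so template depth does not grow with N.

```cpp
#include <array>
#include <cstddef>

// Hypothetical result type: sorted values plus, for each sorted slot, the
// index that value occupied in the unsorted input.
template <std::size_t N>
struct sort_result
{
    std::array<int, N> sorted;
    std::array<std::size_t, N> sorted2unsorted;
};

// Constexpr insertion sort that carries an index map alongside the values.
// The strict '>' keeps equal keys in input order, so the sort is stable.
template <std::size_t N>
constexpr sort_result<N> insertion_sort_with_map(std::array<int, N> v)
{
    std::array<std::size_t, N> map{};
    for(std::size_t i = 0; i < N; ++i)
        map[i] = i;

    for(std::size_t i = 1; i < N; ++i)
    {
        const int key             = v[i];
        const std::size_t key_src = map[i];
        std::size_t j             = i;
        while(j > 0 && v[j - 1] > key) // shift larger values one slot right
        {
            v[j]   = v[j - 1];
            map[j] = map[j - 1];
            --j;
        }
        v[j]   = key;
        map[j] = key_src;
    }
    return {v, map};
}

// Usage: {3, 1, 2, 0} sorts to {0, 1, 2, 3}; sorted slot 0 came from index 3.
constexpr auto demo_sort = insertion_sort_with_map(std::array<int, 4>{3, 1, 2, 0});
static_assert(demo_sort.sorted[0] == 0 && demo_sort.sorted2unsorted[0] == 3);
```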
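The "seen array" permutation check from change 2 reduces to a single constexpr loop. Below is a minimal C++17 sketch over `std::array`. `is_permutation_map` is a hypothetical name; ck_tile's `is_valid_sequence_map` operates on `sequence` types instead.

```cpp
#include <array>
#include <cstddef>

// Hypothetical sketch: true iff m is a permutation of 0 .. N-1.
// One pass, O(N) constexpr steps, no recursive template instantiation.
template <std::size_t N>
constexpr bool is_permutation_map(const std::array<int, N>& m)
{
    std::array<bool, N> seen{}; // value-initialized: all false
    for(std::size_t i = 0; i < N; ++i)
    {
        const int v = m[i];
        if(v < 0 || v >= static_cast<int>(N) || seen[static_cast<std::size_t>(v)])
            return false; // out of range, or a duplicate entry
        seen[static_cast<std::size_t>(v)] = true;
    }
    return true;
}

static_assert(is_permutation_map(std::array<int, 3>{2, 0, 1}));
static_assert(!is_permutation_map(std::array<int, 3>{0, 0, 1})); // duplicate 0
```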
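The direct-call pattern behind `ford_applier` (change 4) can be sketched with `std::index_sequence` and a C++17 comma fold. `row_major_2d` and `applier` are hypothetical stand-ins, not ck_tile types. The user's generic lambda is instantiated once per distinct `decompose<I>` type, and `applier::operator()` is instantiated once. No wrapper lambda is generated per iteration.

```cpp
#include <array>
#include <cstddef>
#include <utility>

// Hypothetical decomposer: flat index I -> (row, col) of a row-major space
// with Cols columns, exposed as a distinct type per I.
template <int Cols>
struct row_major_2d
{
    template <std::size_t I>
    struct decompose
    {
        static constexpr int row = static_cast<int>(I) / Cols;
        static constexpr int col = static_cast<int>(I) % Cols;
    };
};

// Hypothetical analogue of ford_applier: the comma fold calls f directly,
// once per flat index, in ascending order.
template <typename Decomposer, typename Seq>
struct applier;

template <typename Decomposer, std::size_t... Is>
struct applier<Decomposer, std::index_sequence<Is...>>
{
    template <typename F>
    constexpr void operator()(F&& f) const
    {
        (static_cast<void>(f(typename Decomposer::template decompose<Is>{})), ...);
    }
};

// Record the visit order of a 2 x 3 space as row * 10 + col.
constexpr std::array<int, 6> visit_order_2x3()
{
    std::array<int, 6> out{};
    std::size_t n = 0;
    applier<row_major_2d<3>, std::make_index_sequence<6>>{}([&](auto id) {
        using Id   = decltype(id);
        out[n++]   = Id::row * 10 + Id::col;
    });
    return out;
}
```

`visit_order_2x3()` yields `{0, 1, 2, 10, 11, 12}`, the same row-major order that the identity-order `static_ford` tests expect.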
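The W+ values in the build-time table can be checked against the per-trial diffs. The sketch below assumes the diffs are Treat minus Base in seconds, which is consistent with the negative Delta column. It also assumes the usual Wilcoxon conventions: zero differences are dropped, and tied |d| values receive the average of the ranks they span.

```cpp
#include <array>
#include <cstddef>

// W+ for the Wilcoxon signed-rank test: the sum of the ranks of |d| over the
// positive differences. Zero differences are dropped, and tied |d| values get
// the average of the ranks they span: rank = #smaller + (#equal + 1) / 2.
template <std::size_t N>
constexpr double wilcoxon_w_plus(const std::array<int, N>& d)
{
    auto abs_of   = [](int x) { return x < 0 ? -x : x; };
    double w_plus = 0.0;
    for(std::size_t i = 0; i < N; ++i)
    {
        if(d[i] <= 0)
            continue; // only positive differences contribute to W+
        int smaller = 0;
        int equal   = 0; // counts d[i] itself
        for(std::size_t j = 0; j < N; ++j)
        {
            if(d[j] == 0)
                continue;
            if(abs_of(d[j]) < abs_of(d[i]))
                ++smaller;
            else if(abs_of(d[j]) == abs_of(d[i]))
                ++equal;
        }
        w_plus += smaller + (equal + 1) / 2.0;
    }
    return w_plus;
}

// Per-trial diffs as listed above (assumed Treat - Base; negative = faster).
constexpr std::array<int, 7> flatmm_diffs{-6, -20, -9, -8, -8, 4, -5};
constexpr std::array<int, 7> universal_gemm_diffs{-2, -6, 4, -3, -2, -11, -6};
```

This gives W+ = 1 for flatmm and W+ = 4 for universal_gemm, matching the table. For n = 7, the commonly tabulated two-sided 5% critical value is 2, which is consistent with only flatmm being marked significant.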
294 lines
10 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>
#include <vector>
#include <tuple>
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/utility/functional.hpp"

using namespace ck_tile;
|
// ============================================================================
// static_ford Tests — Identity Order (default)
// ============================================================================

TEST(CkTileStaticFord, Identity2D)
{
    std::vector<std::pair<index_t, index_t>> visited;

    static_ford<sequence<2, 3>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        visited.emplace_back(i, j);
    });

    ASSERT_EQ(visited.size(), 6u);
    EXPECT_EQ(visited[0], std::make_pair(0, 0));
    EXPECT_EQ(visited[1], std::make_pair(0, 1));
    EXPECT_EQ(visited[2], std::make_pair(0, 2));
    EXPECT_EQ(visited[3], std::make_pair(1, 0));
    EXPECT_EQ(visited[4], std::make_pair(1, 1));
    EXPECT_EQ(visited[5], std::make_pair(1, 2));
}
|
TEST(CkTileStaticFord, Identity3D)
{
    std::vector<std::tuple<index_t, index_t, index_t>> visited;

    static_ford<sequence<2, 3, 2>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        constexpr index_t k = multi_id[number<2>{}];
        visited.emplace_back(i, j, k);
    });

    ASSERT_EQ(visited.size(), 12u);
    EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
    EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1));
    EXPECT_EQ(visited[2], std::make_tuple(0, 1, 0));
    EXPECT_EQ(visited[3], std::make_tuple(0, 1, 1));
    EXPECT_EQ(visited[4], std::make_tuple(0, 2, 0));
    EXPECT_EQ(visited[5], std::make_tuple(0, 2, 1));
    EXPECT_EQ(visited[6], std::make_tuple(1, 0, 0));
    EXPECT_EQ(visited[7], std::make_tuple(1, 0, 1));
    EXPECT_EQ(visited[8], std::make_tuple(1, 1, 0));
    EXPECT_EQ(visited[9], std::make_tuple(1, 1, 1));
    EXPECT_EQ(visited[10], std::make_tuple(1, 2, 0));
    EXPECT_EQ(visited[11], std::make_tuple(1, 2, 1));
}
|
TEST(CkTileStaticFord, Identity1D)
{
    std::vector<index_t> visited;

    static_ford<sequence<5>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        visited.push_back(i);
    });

    ASSERT_EQ(visited.size(), 5u);
    for(index_t i = 0; i < 5; ++i)
    {
        EXPECT_EQ(visited[i], i);
    }
}
|
TEST(CkTileStaticFord, SingleElement1D)
{
    std::vector<index_t> visited;

    static_ford<sequence<1>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        visited.push_back(i);
    });

    ASSERT_EQ(visited.size(), 1u);
    EXPECT_EQ(visited[0], 0);
}
|
TEST(CkTileStaticFord, SingleElement2D)
{
    std::vector<std::pair<index_t, index_t>> visited;

    static_ford<sequence<1, 1>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        visited.emplace_back(i, j);
    });

    ASSERT_EQ(visited.size(), 1u);
    EXPECT_EQ(visited[0], std::make_pair(0, 0));
}
|
TEST(CkTileStaticFord, IdentityWithUnitDim)
{
    std::vector<std::tuple<index_t, index_t, index_t>> visited;

    static_ford<sequence<2, 1, 3>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        constexpr index_t k = multi_id[number<2>{}];
        visited.emplace_back(i, j, k);
    });

    ASSERT_EQ(visited.size(), 6u);
    EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
    EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1));
    EXPECT_EQ(visited[2], std::make_tuple(0, 0, 2));
    EXPECT_EQ(visited[3], std::make_tuple(1, 0, 0));
    EXPECT_EQ(visited[4], std::make_tuple(1, 0, 1));
    EXPECT_EQ(visited[5], std::make_tuple(1, 0, 2));
}
|
// ============================================================================
// static_ford Tests — Non-Identity Order (primary template with decompose_reordered)
// ============================================================================

TEST(CkTileStaticFord, ReversedOrder2D)
{
    std::vector<std::pair<index_t, index_t>> visited;

    // Order (1, 0): dim 1 is outer, dim 0 is inner (column-major)
    static_ford<sequence<2, 3>, sequence<1, 0>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        visited.emplace_back(i, j);
    });

    ASSERT_EQ(visited.size(), 6u);
    EXPECT_EQ(visited[0], std::make_pair(0, 0));
    EXPECT_EQ(visited[1], std::make_pair(1, 0));
    EXPECT_EQ(visited[2], std::make_pair(0, 1));
    EXPECT_EQ(visited[3], std::make_pair(1, 1));
    EXPECT_EQ(visited[4], std::make_pair(0, 2));
    EXPECT_EQ(visited[5], std::make_pair(1, 2));
}
|
TEST(CkTileStaticFord, CustomOrder3D_201)
{
    std::vector<std::tuple<index_t, index_t, index_t>> visited;

    // Orders<2,0,1>: dim 2 outermost, dim 0 middle, dim 1 innermost
    static_ford<sequence<2, 3, 4>, sequence<2, 0, 1>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        constexpr index_t k = multi_id[number<2>{}];
        visited.emplace_back(i, j, k);
    });

    ASSERT_EQ(visited.size(), 24u);
    // With orders (2,0,1): k varies slowest, then i, then j fastest
    EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
    EXPECT_EQ(visited[1], std::make_tuple(0, 1, 0));
    EXPECT_EQ(visited[2], std::make_tuple(0, 2, 0));
    EXPECT_EQ(visited[3], std::make_tuple(1, 0, 0));
    EXPECT_EQ(visited[4], std::make_tuple(1, 1, 0));
    EXPECT_EQ(visited[5], std::make_tuple(1, 2, 0));
    EXPECT_EQ(visited[6], std::make_tuple(0, 0, 1));
    EXPECT_EQ(visited[7], std::make_tuple(0, 1, 1));
    // Tail: last element should be (1, 2, 3)
    EXPECT_EQ(visited[23], std::make_tuple(1, 2, 3));
}
|
TEST(CkTileStaticFord, CustomOrder3D_120)
{
    std::vector<std::tuple<index_t, index_t, index_t>> visited;

    // Orders<1,2,0>: dim 1 outermost, dim 2 middle, dim 0 innermost
    static_ford<sequence<2, 3, 2>, sequence<1, 2, 0>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        constexpr index_t k = multi_id[number<2>{}];
        visited.emplace_back(i, j, k);
    });

    ASSERT_EQ(visited.size(), 12u);
    // With orders (1,2,0): j varies slowest, then k, then i fastest
    EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0));
    EXPECT_EQ(visited[1], std::make_tuple(1, 0, 0));
    EXPECT_EQ(visited[2], std::make_tuple(0, 0, 1));
    EXPECT_EQ(visited[3], std::make_tuple(1, 0, 1));
    EXPECT_EQ(visited[4], std::make_tuple(0, 1, 0));
    EXPECT_EQ(visited[5], std::make_tuple(1, 1, 0));
    // Tail: last element should be (1, 2, 1)
    EXPECT_EQ(visited[11], std::make_tuple(1, 2, 1));
}
|
TEST(CkTileStaticFord, NonIdentityWithUnitDim)
{
    std::vector<std::tuple<index_t, index_t, index_t>> visited;

    // Unit dim at position 1 with non-trivial order
    static_ford<sequence<2, 1, 3>, sequence<2, 0, 1>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        constexpr index_t k = multi_id[number<2>{}];
        visited.emplace_back(i, j, k);
    });

    ASSERT_EQ(visited.size(), 6u);
    // All entries must have j == 0 (unit dimension)
    for(size_t idx = 0; idx < visited.size(); ++idx)
    {
        EXPECT_EQ(std::get<1>(visited[idx]), 0) << "Unit dim not zero at iteration " << idx;
    }
}
|
TEST(CkTileStaticFord, CustomOrder4D)
{
    std::vector<std::tuple<index_t, index_t, index_t, index_t>> visited;

    // 4D with order <3,1,0,2>
    static_ford<sequence<2, 3, 2, 4>, sequence<3, 1, 0, 2>>{}([&](auto multi_id) {
        constexpr index_t a = multi_id[number<0>{}];
        constexpr index_t b = multi_id[number<1>{}];
        constexpr index_t c = multi_id[number<2>{}];
        constexpr index_t d = multi_id[number<3>{}];
        visited.emplace_back(a, b, c, d);
    });

    ASSERT_EQ(visited.size(), 48u);
    // dim 3 (size 4) outermost, dim 1 (size 3) next, dim 0 (size 2) next, dim 2 (size 2) inner
    EXPECT_EQ(visited[0], std::make_tuple(0, 0, 0, 0));
    EXPECT_EQ(visited[1], std::make_tuple(0, 0, 1, 0));
    EXPECT_EQ(visited[2], std::make_tuple(1, 0, 0, 0));
    EXPECT_EQ(visited[3], std::make_tuple(1, 0, 1, 0));
    EXPECT_EQ(visited[4], std::make_tuple(0, 1, 0, 0));
    EXPECT_EQ(visited[5], std::make_tuple(0, 1, 1, 0));
}
|
TEST(CkTileStaticFord, AsymmetricDimsWithOrder)
{
    std::vector<std::pair<index_t, index_t>> visited;

    // Asymmetric: 3x5 with reversed order
    static_ford<sequence<3, 5>, sequence<1, 0>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        visited.emplace_back(i, j);
    });

    ASSERT_EQ(visited.size(), 15u);
    // dim 1 (size 5) outer, dim 0 (size 3) inner
    EXPECT_EQ(visited[0], std::make_pair(0, 0));
    EXPECT_EQ(visited[1], std::make_pair(1, 0));
    EXPECT_EQ(visited[2], std::make_pair(2, 0));
    EXPECT_EQ(visited[3], std::make_pair(0, 1));
    EXPECT_EQ(visited[4], std::make_pair(1, 1));
    EXPECT_EQ(visited[5], std::make_pair(2, 1));
}
|
// ============================================================================
// Consistency: identity order matches explicit identity order
// ============================================================================

TEST(CkTileStaticFord, IdentityOrderMatchesExplicit)
{
    std::vector<std::pair<index_t, index_t>> default_visited;
    std::vector<std::pair<index_t, index_t>> explicit_visited;

    static_ford<sequence<3, 4>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        default_visited.emplace_back(i, j);
    });

    static_ford<sequence<3, 4>, sequence<0, 1>>{}([&](auto multi_id) {
        constexpr index_t i = multi_id[number<0>{}];
        constexpr index_t j = multi_id[number<1>{}];
        explicit_visited.emplace_back(i, j);
    });

    ASSERT_EQ(default_visited.size(), explicit_visited.size());
    for(size_t i = 0; i < default_visited.size(); ++i)
    {
        EXPECT_EQ(default_visited[i], explicit_visited[i]) << "Mismatch at iteration " << i;
    }
}
|
// index_decomposer and inverse_perm are implementation details tested
// indirectly through the static_ford behavioral tests above.
// The IdentityOrderMatchesExplicit test verifies that the default order and an
// explicit identity order (sequence<0, 1>) produce identical results.