[rocm-libraries] ROCm/rocm-libraries#4447 (commit 6d08a99)

[CK] Optimize multi-dimensional static for loop decomposition
 (#4447)

## Motivation
Recursive template implementations might initially seem attractive to
minimize necessary coding.

Unfortunately, this style often hurts readability and demands
significant resources from the compiler to generate the instantiation
chains. In "high-traffic" code (e.g., code used in many places and many
compilation units), this generally does not scale well and can bloat
overall compile times to unnecessary lengths.

The aim of this PR is to take some of the most high-traffic utility code
and eliminate recursive templates in favor of fold expressions and
constexpr function helpers.
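As a standalone illustration of the approach (not CK's actual code), a recursive `static_for` can be replaced by a single fold expression over an index pack; `static_for` and `static_for_impl` below are simplified sketches:

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>

// One fold expression: no recursive instantiation chain, so the
// template depth stays O(1) regardless of N.
template <typename F, std::size_t... Is>
constexpr void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}
```

Each loop body still receives its index as a compile-time constant, so callers can keep using it for array indexing or tuple access exactly as with the recursive form.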

In local tests with ClangBuildAnalyzer,
device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp
showed high hit rates on slow template instantiations in static_for and
multi-dimensional static_for (static_ford), which are in turn affected
by the implementation of the Sequence class and its associated
transforms.
Example:

```
**** Templates that took longest to instantiate:
 70111 ms: ck::detail::applier<int, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
 12, 1... (372 times, avg 188 ms)
```

That is **70 seconds** for a single template family. It is part of the
implementation of static_for, which uses Sequence classes.

## Technical Details

### Summary of Optimization Techniques

| Technique | Used In | Benefit |
|-----------|---------|---------|
| __Constexpr for-loop computation__ | `sequence_reverse_inclusive_scan`, `sequence_map_inverse` | Moves O(N) work from template instantiation to constexpr evaluation |
| __Pack expansion with indexing__ | `sequence_reverse`, `Sequence::Modify` | Single template instantiation instead of a recursive chain |
| __Flat iteration + decomposition__ | `ford`, `static_ford` | O(1) template depth instead of O(N^D) |
| __Pre-computed strides__ | `index_decomposer` | Enables O(1) linear-to-multi-index conversion |

### Impact on Compile Time

These optimizations reduce template instantiation depth from O(N) or
O(N^D) to O(1), which:

1. Reduces compiler memory usage
2. Reduces compile time dramatically for deep instantiation chains
3. Enables larger iteration spaces without hitting template depth limits
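The pre-computed-strides idea that enables the flat iteration can be sketched as follows; `make_strides` and `decompose` are illustrative names, not CK's actual `index_decomposer` API:

```cpp
#include <array>
#include <cstddef>

// Row-major strides for a set of dimension lengths: the last dimension
// varies fastest and has stride 1.
template <std::size_t D>
constexpr std::array<std::size_t, D> make_strides(const std::array<std::size_t, D>& lengths)
{
    std::array<std::size_t, D> strides{};
    std::size_t s = 1;
    for(std::size_t i = D; i-- > 0;)
    {
        strides[i] = s;
        s *= lengths[i];
    }
    return strides;
}

// Convert a flat loop index into a multi-index with div/mod against the
// pre-computed strides: constant template depth for any D.
template <std::size_t D>
constexpr std::array<std::size_t, D>
decompose(std::size_t linear, const std::array<std::size_t, D>& strides)
{
    std::array<std::size_t, D> idx{};
    for(std::size_t i = 0; i < D; ++i)
    {
        idx[i] = linear / strides[i];
        linear %= strides[i];
    }
    return idx;
}
```

For lengths {2, 3, 4} the strides are {12, 4, 1}, so a single flat loop over 24 indices can recover each 3D multi-index without nesting templates per dimension.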

## Test Plan

* Existing tests for Sequence are re-used to affirm correctness
* Unit tests for ford and static_ford are added (dimensional looping)
* 8 new regression tests specifically verify the fixes for the PR
feedback:

  - `NonTrivialOrder3D_201` - Tests Orders<2,0,1> for static_ford
  - `NonTrivialOrder3D_201_Runtime` - Tests Orders<2,0,1> for ford
  - `ConsistencyWithNonTrivialOrder_201` - Verifies static_ford and ford consistency
  - `NonTrivialOrder3D_120` - Tests Orders<1,2,0> for static_ford
  - `NonTrivialOrder3D_120_Runtime` - Tests Orders<1,2,0> for ford
  - `NonTrivialOrder4D` - Tests 4D with Orders<3,1,0,2> for static_ford
  - `NonTrivialOrder4D_Runtime` - Tests 4D with Orders<3,1,0,2> for ford
  - `AsymmetricDimensionsWithOrder` - Tests asymmetric dimensions with non-trivial ordering
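To make the ordering semantics concrete: with a non-trivial `Orders`, the dimension listed first varies slowest. The runtime sketch below mirrors the behavior the regression tests check; `ordered_ford` is an illustrative stand-in, not CK's actual `ford`:

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Iterate a 3D space in the permuted order, then scatter each loop
// counter back to its original dimension slot.
std::vector<std::array<std::size_t, 3>> ordered_ford(std::array<std::size_t, 3> lengths,
                                                     std::array<std::size_t, 3> orders)
{
    std::vector<std::array<std::size_t, 3>> visited;
    std::array<std::size_t, 3> ord_len{
        lengths[orders[0]], lengths[orders[1]], lengths[orders[2]]};
    for(std::size_t a = 0; a < ord_len[0]; ++a)
        for(std::size_t b = 0; b < ord_len[1]; ++b)
            for(std::size_t c = 0; c < ord_len[2]; ++c)
            {
                std::array<std::size_t, 3> idx{};
                idx[orders[0]] = a; // slowest-varying counter
                idx[orders[1]] = b;
                idx[orders[2]] = c; // fastest-varying counter
                visited.push_back(idx);
            }
    return visited;
}
```

With lengths {2, 2, 2} and orders {2, 0, 1}, the second visit is {0, 1, 0}: dimension 1 is listed last, so it varies fastest.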

## Test Result
### Compile Time Comparison: `8b72bc8` (base) → `477e0686` (optimized)

#### Commits in Range (8 commits)

1. `fd4ca17f48` - Optimize sequence_reverse_inclusive_scan and
sequence_reverse
2. `7a7e3fdeef` - Optimize sequence_map_inverse
3. `92855c9913` - Optimize ford and static_ford calls to eliminate
nested template recursion
4. `88a564032b` - Add unit tests for ford and static_ford
5. `1a0fb22217` - Fix clang-format
6. `8a0d26bddf` - Increase template recursion depth to 1024
7. `dc53bb6e20` - Address copilot feedback and add regression tests
8. `477e06861d` - Increase bracket depth to 1024

#### Build Timing Results

| File | Base (`8b72bc8759d9`) | HEAD (`a0438bd398`) | Improvement |
|------|------|------|-------------|
| grouped_conv2d_fwd (f16) -j1 | 313.31s | 272.93s | __12.9% faster__ |
| grouped_conv1d_fwd (bf16) -j1 | 79.33s | 68.61s | __13.5% faster__ |
| grouped_conv1d_bwd_weight (f16) -j1 | 15.77s | 14.31s | __9.2% faster__ |

#### Key Optimizations

1. __sequence_reverse_inclusive_scan/sequence_reverse__: O(N) → O(1)
template depth
2. __sequence_map_inverse__: O(N) → O(1) template depth
3. __ford/static_ford__: O(N^D) → O(1) template depth using flat
iteration with index decomposition
4. __Copilot feedback fixes__: Corrected New2Old mapping for non-trivial
orderings
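The constexpr-scan technique behind items 1 and 2 can be sketched with a summing reverse inclusive scan; the CK version also threads an `Init` value and a `Reduce` functor, which this simplified, illustratively named sketch omits:

```cpp
#include <array>
#include <cstddef>

// Do all O(N) scan work inside one constexpr function body, then splat
// the result array back into a compile-time pack at the call site.
// Here the reduction is fixed to + with an implicit Init of 0.
template <std::size_t... Vs>
constexpr auto reverse_inclusive_scan_sum()
{
    constexpr std::size_t N = sizeof...(Vs);
    std::array<std::size_t, N> in{Vs...};
    std::array<std::size_t, N> out{};
    std::size_t acc = 0;
    for(std::size_t i = N; i-- > 0;)
    {
        acc += in[i]; // running sum from the right
        out[i] = acc;
    }
    return out;
}
```

For the pack <1, 2, 3> this yields {6, 5, 3}: only one function template is instantiated, regardless of N, instead of one class template per element.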

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit `e1e2f7ac2e` (parent `ea4942cd02`)
Author: Christopher Millette, 2026-02-11 22:13:15 +00:00
Committed by: assistant-librarian[bot]
7 changed files with 1109 additions and 170 deletions


@@ -38,6 +38,30 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);
namespace detail {
/**
* @brief Helper to generate integer sequences with custom Sequence class
*/
template <typename T, T... Ints>
struct __integer_sequence;
template <index_t... Ints>
struct __integer_sequence<index_t, Ints...>
{
using seq_type = Sequence<Ints...>;
};
} // namespace detail
/**
* @brief Generate a Sequence class with index_t integers from 0 to N-1
* @tparam N The size of the sequence to generate
*/
template <index_t N>
using make_index_sequence =
typename __make_integer_seq<detail::__integer_sequence, index_t, N>::seq_type;
template <index_t... Is>
struct Sequence
{
@@ -157,18 +181,37 @@ struct Sequence
return Sequence<Type::At(Number<Ns>{})...>{};
}
/**
* @brief Modify the sequence at a specific index with a new value
* @tparam I The index of the element to modify
* @tparam X The new value to set at index I
* @return A new Sequence with the value at index I replaced by X
*/
template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
// Generate and forward an index sequence that covers all elements
static_assert(I >= 0 && I < mSize, "Index I is out of bounds");
return modify_impl(make_index_sequence<mSize>{}, Number<I>{}, Number<X>{});
}
private:
/**
* @brief Helper function to modify the sequence at a specific index
* @tparam Idxs Indices of the sequence elements (0, 1, ..., Size-1)
* @tparam ModifyIdx The index of the value in the sequence to modify
* @tparam NewVal The new value to set at ModifyIdx
* @return A new Sequence with the value at ModifyIdx replaced by NewVal
*/
template <index_t... Idxs, index_t ModifyIdx, index_t NewVal>
__host__ __device__ static constexpr auto
modify_impl(Sequence<Idxs...>, Number<ModifyIdx>, Number<NewVal>)
{
// For each index: if it equals ModifyIdx, use NewVal; otherwise use original value
return Sequence<(Idxs == ModifyIdx ? NewVal : At(Idxs))...>{};
}
public:
template <typename F>
__host__ __device__ static constexpr auto Transform(F f)
{
@@ -184,21 +227,6 @@ struct Sequence
}
};
namespace impl {
template <typename T, T... Ints>
struct __integer_sequence;
template <index_t... Ints>
struct __integer_sequence<index_t, Ints...>
{
using seq_type = Sequence<Ints...>;
};
} // namespace impl
template <index_t N>
using make_index_sequence =
typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type;
// merge sequence - optimized to avoid recursive instantiation
//
// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1)
@@ -332,13 +360,7 @@ struct arithmetic_sequence_gen
template <index_t IEnd>
struct arithmetic_sequence_gen<0, IEnd, 1>
{
template <typename T, T... Ints>
struct WrapSequence
{
using type = Sequence<Ints...>;
};
// https://reviews.llvm.org/D13786
using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type;
using type = make_index_sequence<IEnd>;
};
// uniform sequence - optimized using __make_integer_seq
@@ -367,26 +389,79 @@ struct uniform_sequence_gen<0, I>
using type = Sequence<>;
};
// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
namespace detail {
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
/**
* @brief A simple fixed-size array to hold intermediate results during constexpr computation
* @tparam N The size of the array
*/
template <index_t N>
struct index_array
{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
index_t data[N > 0 ? N : 1];
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
__host__ __device__ constexpr index_t& operator[](index_t i) { return data[i]; }
__host__ __device__ constexpr const index_t& operator[](index_t i) const { return data[i]; }
};
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
/**
* @brief Compute the reverse inclusive scan of a sequence at compile time using a constexpr
* function
* @tparam Reduce The binary reduction functor
* @tparam Init The initial value for the reduction
* @tparam Vs The input sequence values
* @return An index_array containing the reverse inclusive scan results
*/
template <typename Reduce, index_t Init, index_t... Vs>
__host__ __device__ constexpr auto compute_reverse_inclusive_scan()
{
using type = Sequence<Reduce{}(I, Init)>;
constexpr index_t N = sizeof...(Vs);
index_array<N> result{};
constexpr index_t input[N > 0 ? N : 1] = {Vs...};
if constexpr(N > 0)
{
result.data[N - 1] = Reduce{}(input[N - 1], Init);
for(index_t i = N - 2; i >= 0; --i)
{
result.data[i] = Reduce{}(input[i], result.data[i + 1]);
}
}
return result;
}
// Build result sequence with O(1) instantiation depth
template <typename Reduce, index_t Init, typename Seq, typename IndexSeq>
struct build_reverse_inclusive_scan;
template <typename Reduce, index_t Init, index_t... Vs, index_t... Is>
struct build_reverse_inclusive_scan<Reduce, Init, Sequence<Vs...>, Sequence<Is...>>
{
static constexpr auto result = compute_reverse_inclusive_scan<Reduce, Init, Vs...>();
using type = Sequence<result.data[Is]...>;
};
} // namespace detail
/**
* @brief Reverse inclusive scan of a sequence - main interface
* @tparam Seq The input sequence to scan
* @tparam Reduce The binary reduction functor
* @tparam Init The initial value for the reduction
*/
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan
{
using type = typename detail::
build_reverse_inclusive_scan<Reduce, Init, Seq, make_index_sequence<Seq::Size()>>::type;
};
/**
* @brief Specialization for empty sequence - returns empty sequence without computation
* @tparam Reduce The binary reduction functor
* @tparam Init The initial value for the reduction
*/
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
@@ -406,28 +481,34 @@ struct sequence_split
using right_type = decltype(Seq::Extract(range1{}));
};
// reverse sequence
// reverse sequence - optimized using direct pack expansion O(1) depth
namespace detail {
template <typename Seq, typename IndexSeq>
struct sequence_reverse_impl;
template <index_t... Is, index_t... Idxs>
struct sequence_reverse_impl<Sequence<Is...>, Sequence<Idxs...>>
{
static constexpr index_t N = sizeof...(Is);
// Access elements in reverse order: index (N-1-i) for position i
using type = Sequence<Sequence<Is...>::At(Number<N - 1 - Idxs>{})...>;
};
} // namespace detail
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.Size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
using type =
typename detail::sequence_reverse_impl<Seq, make_index_sequence<Seq::Size()>>::type;
};
template <index_t I>
struct sequence_reverse<Sequence<I>>
// Empty sequence specialization
template <>
struct sequence_reverse<Sequence<>>
{
using type = Sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
using type = Sequence<I1, I0>;
using type = Sequence<>;
};
#if 1
@@ -597,31 +678,59 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
{
};
template <typename SeqMap>
struct sequence_map_inverse
/**
* @brief Invert a permutation sequence: given X2Y = {a, b, c, ...}, compute Y2X where Y2X[X2Y[i]]
* = i. Example: Sequence<2,0,1> (meaning pos0->2, pos1->0, pos2->1) inverts to Sequence<1,2,0>
*
* Why this implementation is faster to compile than recursive templates:
*
* The old recursive approach created a new template type for each element:
* sequence_map_inverse<Seq<2,0,1>> -> sequence_map_inverse<Seq<0,1>> ->
* sequence_map_inverse<Seq<1>>
* Each "->" is a new type the compiler must create, track, and manage. For N elements, that's
* N template types, each with overhead (name mangling, debug info, symbol table entries).
*
* This implementation uses a constexpr for loop to build the inverse in O(N) operations:
* For input Sequence<2,0,1>, the loop sets result[input[pos]] = pos for each position:
* pos=0: result[2]=0, pos=1: result[0]=1, pos=2: result[1]=2
* This builds the inverse permutation in a single pass with O(1) template instantiation depth.
*
* @tparam Is The input permutation sequence
*/
template <index_t... Is>
struct sequence_map_inverse<Sequence<Is...>>
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
static_assert(is_valid_sequence_map<Sequence<Is...>>::value,
"sequence_map_inverse requires SeqMap to be a valid permutation sequence map");
private:
static constexpr auto build_inverse()
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
detail::index_array<sizeof...(Is)> result{};
constexpr index_t input[] = {Is...};
for(index_t pos = 0; pos < static_cast<index_t>(sizeof...(Is)); ++pos)
{
result[input[pos]] = pos;
}
return result;
}
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
static constexpr auto inverse = build_inverse();
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
template <index_t... Positions>
static constexpr auto compute(Sequence<Positions...>)
{
using type = WorkingY2X;
};
return Sequence<inverse[Positions]...>{};
}
using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
public:
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
template <>
struct sequence_map_inverse<Sequence<>>
{
using type = Sequence<>;
};
template <index_t... Xs, index_t... Ys>
@@ -799,8 +908,8 @@ __host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<I
return Sequence<Seq::At(Number<Is>{})...>{};
}
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
@@ -856,7 +965,6 @@ __host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values,
return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t