mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 16:47:40 +00:00
[CK] Optimize multi-dimensional static for loop decomposition (#4447) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Recursive template implementations might initially seem attractive to minimize necessary coding. Unfortunately, this style often hurts readability and requires significant resources from the compiler to generate instantiation chains. In "high-traffic" code (e.g., used in many places + compilation units), this generally does not scale well and can bloat the overall compile times to unnecessary lengths. The aim of this PR is to take some of the most high-traffic utility code and try our best to eliminate recursive templates in favor of fold expansions and constexpr function helpers. In local tests with clang build analyzer, device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp showed high hit-rates on slow template instantiations in static_for, dimensional static_for (static_ford), which are subsequently affected by implementation of the Sequence class and associated transforms. Example: **** Templates that took longest to instantiate: 70111 ms: ck::detail::applier<int, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1... (372 times, avg 188 ms) // **70 seconds!** The above is part of the implementation of static_for which uses Sequence classes.
## Technical Details ### Summary of Optimization Techniques | Technique | Used In | Benefit | |-----------|---------|---------| | __Constexpr for-loop computation__ | sequence_reverse_inclusive_scan, sequence_map_inverse | Moves O(N) work from template instantiation to constexpr evaluation | | __Pack expansion with indexing__ | sequence_reverse, Sequence::Modify | Single template instantiation instead of recursive | | __Flat iteration + decomposition__ | ford, static_ford | O(1) template depth instead of O(N^D) | | __Pre-computed strides__ | index_decomposer | Enables O(1) linear-to-multi-index conversion | ### Impact on Compile Time These optimizations reduce template instantiation depth from O(N) or O(N^D) to O(1), which: 1. Reduces compiler memory usage 2. Reduces compile time exponentially for deep instantiation chains 3. Enables larger iteration spaces without hitting template depth limits ## Test Plan * Existing tests for Sequence are re-used to affirm correctness * Unit tests for ford and static_ford are added (dimensional looping) * 8 new regression tests specifically verify the fixes for the PR feedback: - `NonTrivialOrder3D_201` - Tests Orders<2,0,1> for static_ford - `NonTrivialOrder3D_201_Runtime` - Tests Orders<2,0,1> for ford - `ConsistencyWithNonTrivialOrder_201` - Verifies static_ford and ford consistency - `NonTrivialOrder3D_120` - Tests Orders<1,2,0> for static_ford - `NonTrivialOrder3D_120_Runtime` - Tests Orders<1,2,0> for ford - `NonTrivialOrder4D` - Tests 4D with Orders<3,1,0,2> for static_ford - `NonTrivialOrder4D_Runtime` - Tests 4D with Orders<3,1,0,2> for ford - `AsymmetricDimensionsWithOrder` - Tests asymmetric dimensions with non-trivial ordering ## Test Result ### Compile Time Comparison: `8b72bc8` (base) → `477e0686` (optimized) #### Commits in Range (8 commits) 1. `fd4ca17f48` - Optimize sequence_reverse_inclusive_scan and sequence_reverse 2. `7a7e3fdeef` - Optimize sequence_map_inverse 3. 
`92855c9913` - Optimize ford and static_ford calls to eliminate nested template recursion 4. `88a564032b` - Add unit tests for ford and static_ford 5. `1a0fb22217` - Fix clang-format 6. `8a0d26bddf` - Increase template recursion depth to 1024 7. `dc53bb6e20` - Address copilot feedback and add regression tests 8. `477e06861d` - Increase bracket depth to 1024 #### Build Timing Results | File | Base (8b72bc8759d9) | HEAD (a0438bd398) | Improvement | |------|------|------|-------------| | grouped_conv2d_fwd (f16) -j1 | 313.31s | 272.93s | __12.9% faster__ | | grouped_conv1d_fwd (bf16) -j1 | 79.33s | 68.61s | __13.5% faster__ | | grouped_conv1d_bwd_weight (f16) -j1| 15.77s | 14.31s | __9.2% faster__ | | device_grouped_conv2d_fwd_instance -j64 | s | s | __% faster__ | #### Key Optimizations 1. __sequence_reverse_inclusive_scan/sequence_reverse__: O(N) → O(1) template depth 2. __sequence_map_inverse__: O(N) → O(1) template depth 3. __ford/static_ford__: O(N^D) → O(1) template depth using flat iteration with index decomposition 4. __Copilot feedback fixes__: Corrected New2Old mapping for non-trivial orderings ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
106 lines
2.5 KiB
C++
106 lines
2.5 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
|
|
#define CK_STATICALLY_INDEXED_ARRAY_HPP
|
|
|
|
#include "functional2.hpp"
|
|
#include "tuple.hpp"
|
|
|
|
namespace ck {
|
|
|
|
namespace detail {
|
|
template <typename X, typename Y>
|
|
struct tuple_concat;
|
|
|
|
template <typename... Xs, typename... Ys>
|
|
struct tuple_concat<Tuple<Xs...>, Tuple<Ys...>>
|
|
{
|
|
using type = Tuple<Xs..., Ys...>;
|
|
};
|
|
|
|
// StaticallyIndexedArrayImpl uses binary split for O(log N) depth
|
|
template <typename T, index_t N>
|
|
struct StaticallyIndexedArrayImpl
|
|
{
|
|
using type =
|
|
typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
|
|
typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
|
|
};
|
|
|
|
template <typename T>
|
|
struct StaticallyIndexedArrayImpl<T, 0>
|
|
{
|
|
using type = Tuple<>;
|
|
};
|
|
|
|
template <typename T>
|
|
struct StaticallyIndexedArrayImpl<T, 1>
|
|
{
|
|
using type = Tuple<T>;
|
|
};
|
|
} // namespace detail
|
|
|
|
template <typename T, index_t N>
|
|
using StaticallyIndexedArray = typename detail::StaticallyIndexedArrayImpl<T, N>::type;
|
|
|
|
template <typename X, typename... Xs>
|
|
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
|
{
|
|
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
|
}
|
|
|
|
// make empty StaticallyIndexedArray
|
|
template <typename X>
|
|
__host__ __device__ constexpr auto make_statically_indexed_array()
|
|
{
|
|
return StaticallyIndexedArray<X, 0>();
|
|
}
|
|
|
|
template <typename T, index_t N>
|
|
struct StaticallyIndexedArray_v2
|
|
{
|
|
__host__ __device__ constexpr StaticallyIndexedArray_v2() = default;
|
|
|
|
__host__ __device__ static constexpr index_t Size() { return N; }
|
|
|
|
// read access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& At(Number<I>) const
|
|
{
|
|
static_assert(I < N, "wrong! out of range");
|
|
|
|
return data_[I];
|
|
}
|
|
|
|
// write access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& At(Number<I>)
|
|
{
|
|
static_assert(I < N, "wrong! out of range");
|
|
|
|
return data_[I];
|
|
}
|
|
|
|
// read access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
|
|
{
|
|
return At(i);
|
|
}
|
|
|
|
// write access
|
|
template <index_t I>
|
|
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
|
{
|
|
return At(i);
|
|
}
|
|
|
|
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
|
|
|
T data_[N];
|
|
};
|
|
|
|
} // namespace ck
|
|
#endif
|